// ducky_learn/util.rs

extern crate ndarray;

use std::error::Error;

use ndarray::Array2;

/// Marker type indicating a model that has not been fit.
///
/// This is a zero-sized unit struct carrying no data. It is presumably used as a
/// type-state parameter (e.g. `Model<Unfit>`) so that "not yet fit" is tracked
/// at compile time; confirm against the models that use it.
///
/// The common traits are derived so that generic wrappers parameterised by this
/// marker can themselves derive `Debug`, `Clone`, `Copy`, `Default`, and
/// `PartialEq`/`Eq`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Unfit;

/// Marker type indicating a model that has been fit.
///
/// This is a zero-sized unit struct carrying no data. It is presumably used as a
/// type-state parameter (e.g. `Model<Fit>`) so that "already fit" is tracked at
/// compile time; confirm against the models that use it.
///
/// The common traits are derived so that generic wrappers parameterised by this
/// marker can themselves derive `Debug`, `Clone`, `Copy`, `Default`, and
/// `PartialEq`/`Eq`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Fit;

12/// Generates a one-hot encoding for a vector of integers.
13///
14/// # Arguments
15///
16/// * `input_array`: List of integers to be encoded. Each integer should be less than or equal to the maximum integer in the array.
17///
18/// Returns: `Array2<f64>` where each row represents the one-hot encoding of the corresponding integer from the input array. The columns represent the range of integers from the input array.
19///
20/// # Errors
21///
22/// Returns an error if the `input_array` is empty.
23///
24/// # Examples
25///
26/// ```
27/// use ducky_learn::util::one_hot_encoding_vec;
28/// use ndarray::prelude::*;
29///
30/// let input_array = vec![2, 0, 1];
31/// let output_array = array![
32///     [0., 0., 1.],
33///     [1., 0., 0.],
34///     [0., 1., 0.]
35/// ];
36///
37/// assert_eq!(one_hot_encoding_vec(input_array).unwrap(), output_array);
38/// ```
39pub fn one_hot_encoding_vec<T: AsRef<[usize]>>(
40    input_array: T,
41) -> Result<Array2<f64>, Box<dyn Error>> {
42    let input_array = input_array.as_ref();
43    let max_val = match input_array.iter().max() {
44        Some(&max) => max + 1,
45        None => return Err("Empty input array".into()),
46    };
47
48    let mut encoding_array: Vec<Vec<f64>> = Vec::with_capacity(input_array.len());
49
50    for &input_value in input_array {
51        let mut row = vec![0.; max_val];
52        row[input_value] = 1.;
53        encoding_array.push(row);
54    }
55
56    let data: Vec<f64> = encoding_array.into_iter().flatten().collect();
57    let n_row = input_array.len();
58    let n_col = max_val;
59
60    Array2::from_shape_vec((n_row, n_col), data).map_err(|err| err.into())
61}
62
#[cfg(test)]
mod util_tests {
    use super::*;
    use ndarray::array;
    use ndarray::Array2;

    #[test]
    fn test_empty() {
        let input: Vec<usize> = vec![];
        assert!(one_hot_encoding_vec(input).is_err());
    }

    #[test]
    fn test_single_element() {
        let input: Vec<usize> = vec![0];
        let expected: Array2<f64> = array![[1.]];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    #[test]
    fn test_multiple_elements() {
        let input: Vec<usize> = vec![0, 2, 1, 3];
        let expected: Array2<f64> = array![
            [1., 0., 0., 0.],
            [0., 0., 1., 0.],
            [0., 1., 0., 0.],
            [0., 0., 0., 1.]
        ];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    #[test]
    fn test_repeated_values_and_gaps() {
        // The column count is `max + 1` even when smaller labels (0 and 2 here)
        // never appear. Repeated labels produce identical rows.
        let input: Vec<usize> = vec![3, 3, 1];
        let expected: Array2<f64> = array![
            [0., 0., 0., 1.],
            [0., 0., 0., 1.],
            [0., 1., 0., 0.]
        ];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    #[test]
    fn test_accepts_slices_and_arrays() {
        // The `AsRef<[usize]>` bound should accept borrowed slices and fixed
        // arrays, not just `Vec<usize>`.
        let data = [1usize, 0];
        let expected: Array2<f64> = array![[0., 1.], [1., 0.]];
        assert_eq!(one_hot_encoding_vec(&data[..]).unwrap(), expected);
        assert_eq!(one_hot_encoding_vec(data).unwrap(), expected);
    }
}
93}