1extern crate ndarray;
2
3use ndarray::Array2;
4use std::error::Error;
5
/// Zero-sized marker type carrying no data.
///
/// NOTE(review): by name this looks like a typestate marker for a model that has not been
/// fitted yet — confirm against its users elsewhere in the crate.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Unfit;
8
/// Zero-sized marker type carrying no data.
///
/// NOTE(review): by name this looks like a typestate marker for a model that has already
/// been fitted — confirm against its users elsewhere in the crate.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Fit;
11
12pub fn one_hot_encoding_vec<T: AsRef<[usize]>>(
40 input_array: T,
41) -> Result<Array2<f64>, Box<dyn Error>> {
42 let input_array = input_array.as_ref();
43 let max_val = match input_array.iter().max() {
44 Some(&max) => max + 1,
45 None => return Err("Empty input array".into()),
46 };
47
48 let mut encoding_array: Vec<Vec<f64>> = Vec::with_capacity(input_array.len());
49
50 for &input_value in input_array {
51 let mut row = vec![0.; max_val];
52 row[input_value] = 1.;
53 encoding_array.push(row);
54 }
55
56 let data: Vec<f64> = encoding_array.into_iter().flatten().collect();
57 let n_row = input_array.len();
58 let n_col = max_val;
59
60 Array2::from_shape_vec((n_row, n_col), data).map_err(|err| err.into())
61}
62
#[cfg(test)]
mod util_tests {
    use super::*;
    use ndarray::array;
    use ndarray::Array2;

    /// An empty input has no maximum label, so encoding must fail.
    #[test]
    fn test_empty() {
        let input: Vec<usize> = vec![];
        assert!(one_hot_encoding_vec(input).is_err());
    }

    #[test]
    fn test_single_element() {
        let input: Vec<usize> = vec![0];
        let expected: Array2<f64> = array![[1.]];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    #[test]
    fn test_multiple_elements() {
        let input: Vec<usize> = vec![0, 2, 1, 3];
        let expected: Array2<f64> = array![
            [1., 0., 0., 0.],
            [0., 0., 1., 0.],
            [0., 1., 0., 0.],
            [0., 0., 0., 1.]
        ];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    /// The column count comes from the largest label, so labels that never appear
    /// still get an all-zero column.
    #[test]
    fn test_sparse_labels() {
        let input: Vec<usize> = vec![3];
        let expected: Array2<f64> = array![[0., 0., 0., 1.]];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    /// Repeated labels produce identical rows.
    #[test]
    fn test_repeated_labels() {
        let input: Vec<usize> = vec![1, 1, 0];
        let expected: Array2<f64> = array![[0., 1.], [0., 1.], [1., 0.]];
        assert_eq!(one_hot_encoding_vec(input).unwrap(), expected);
    }

    /// Borrowed slices are accepted through the `AsRef<[usize]>` bound.
    #[test]
    fn test_accepts_slice() {
        let input = [0usize, 1];
        let expected: Array2<f64> = array![[1., 0.], [0., 1.]];
        assert_eq!(one_hot_encoding_vec(&input[..]).unwrap(), expected);
    }
}