1use anofox_ml_core::{Result, RustMlError};
2use ndarray::Array2;
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct OneHotEncoder;
15
16impl OneHotEncoder {
17 pub fn new() -> Self {
19 Self
20 }
21
22 pub fn fit(&self, x: &Array2<usize>) -> Result<FittedOneHotEncoder> {
26 if x.is_empty() {
27 return Err(RustMlError::EmptyInput("input array is empty".into()));
28 }
29
30 let ncols = x.ncols();
31 let mut categories = Vec::with_capacity(ncols);
32
33 for j in 0..ncols {
34 let col = x.column(j);
35 let max_val = col.iter().copied().max().unwrap_or(0);
36 categories.push(max_val + 1);
37 }
38
39 Ok(FittedOneHotEncoder { categories })
40 }
41}
42
43impl Default for OneHotEncoder {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct FittedOneHotEncoder {
52 categories: Vec<usize>,
53}
54
55impl FittedOneHotEncoder {
56 pub fn transform(&self, x: &Array2<usize>) -> Result<Array2<f64>> {
60 if x.ncols() != self.categories.len() {
61 return Err(RustMlError::ShapeMismatch(format!(
62 "expected {} columns, got {}",
63 self.categories.len(),
64 x.ncols()
65 )));
66 }
67
68 let total_out_cols: usize = self.categories.iter().sum();
69 let nrows = x.nrows();
70 let mut result = Array2::<f64>::zeros((nrows, total_out_cols));
71
72 for i in 0..nrows {
73 let mut col_offset = 0;
74 for j in 0..x.ncols() {
75 let val = x[[i, j]];
76 if val >= self.categories[j] {
77 return Err(RustMlError::InvalidParameter(format!(
78 "value {} in column {} exceeds number of categories {}",
79 val, j, self.categories[j]
80 )));
81 }
82 result[[i, col_offset + val]] = 1.0;
83 col_offset += self.categories[j];
84 }
85 }
86
87 Ok(result)
88 }
89
90 pub fn categories(&self) -> &[usize] {
92 &self.categories
93 }
94
95 pub fn n_output_features(&self) -> usize {
97 self.categories.iter().sum()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits a fresh encoder on `x`, then encodes that same `x`.
    fn fit_and_encode(x: &Array2<usize>) -> (FittedOneHotEncoder, Array2<f64>) {
        let fitted = OneHotEncoder::new().fit(x).unwrap();
        let encoded = fitted.transform(x).unwrap();
        (fitted, encoded)
    }

    /// Compares every cell of `expected` with the cell at the same index of
    /// `encoded`; rows of `encoded` past `expected.nrows()` are not inspected.
    fn assert_cells(encoded: &Array2<f64>, expected: &Array2<f64>) {
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(encoded[[i, j]], want);
        }
    }

    #[test]
    fn test_single_column() {
        let (_, encoded) = fit_and_encode(&array![[0usize], [1], [2]]);

        assert_eq!(encoded.shape(), &[3, 3]);
        assert_cells(
            &encoded,
            &array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        );
    }

    #[test]
    fn test_multiple_columns() {
        // Column 0 spans codes {0, 1}; column 1 spans codes {0, 1, 2}.
        let (fitted, encoded) = fit_and_encode(&array![[0usize, 2], [1, 0], [0, 1]]);

        assert_eq!(encoded.shape(), &[3, 5]);
        assert_eq!(fitted.n_output_features(), 5);

        // Only the first two rows are spot-checked.
        assert_cells(
            &encoded,
            &array![[1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0, 0.0]],
        );
    }

    #[test]
    fn test_binary_column() {
        let (fitted, encoded) = fit_and_encode(&array![[0usize], [1], [1], [0]]);

        assert_eq!(encoded.shape(), &[4, 2]);
        assert_eq!(fitted.categories(), &[2]);
    }

    #[test]
    fn test_empty_input() {
        let empty = Array2::<usize>::zeros((0, 0));
        assert!(OneHotEncoder::new().fit(&empty).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let fitted = OneHotEncoder::new()
            .fit(&array![[0usize, 1], [1, 0]])
            .unwrap();

        let too_wide = array![[0usize, 1, 2]];
        assert!(fitted.transform(&too_wide).is_err());
    }

    #[test]
    fn test_unknown_category_in_transform() {
        let fitted = OneHotEncoder::new().fit(&array![[0usize], [1]]).unwrap();

        // Training only saw codes 0 and 1, so 5 is out of range.
        assert!(fitted.transform(&array![[5usize]]).is_err());
    }

    #[test]
    fn test_all_zeros() {
        let (_, encoded) = fit_and_encode(&array![[0usize, 0], [0, 0], [0, 0]]);

        assert_eq!(encoded.shape(), &[3, 2]);
        assert_cells(&encoded, &Array2::from_elem((3, 2), 1.0));
    }

    #[test]
    fn test_row_sums() {
        let (_, encoded) = fit_and_encode(&array![[0usize, 2, 1], [2, 0, 0], [1, 1, 2]]);

        assert_eq!(encoded.shape(), &[3, 9]);

        // One hot entry per input column, so every row sums to the column count.
        for row in encoded.outer_iter() {
            assert_abs_diff_eq!(row.sum(), 3.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_default() {
        let fitted = OneHotEncoder::default()
            .fit(&array![[0usize], [1]])
            .unwrap();
        assert_eq!(fitted.categories(), &[2]);
    }
}