1use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
/// An unfitted one-hot encoder for integer-coded categorical columns.
///
/// The encoder carries no configuration or learned state of its own; the
/// type parameter `F` only selects the element type of the encoded output.
#[derive(Clone, Debug)]
pub struct OneHotEncoder<F> {
    /// Zero-sized marker tying the encoder to its output element type `F`.
    _marker: std::marker::PhantomData<F>,
}
44
45impl<F: Float + Send + Sync + 'static> OneHotEncoder<F> {
46 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 _marker: std::marker::PhantomData,
51 }
52 }
53}
54
55impl<F: Float + Send + Sync + 'static> Default for OneHotEncoder<F> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
/// A one-hot encoder that has learned how many categories each column has.
///
/// The type parameter `F` is the element type of the encoded output.
#[derive(Clone, Debug)]
pub struct FittedOneHotEncoder<F> {
    /// Number of categories per input column, in column order.
    pub(crate) n_categories: Vec<usize>,
    /// Zero-sized marker tying the encoder to its output element type `F`.
    _marker: std::marker::PhantomData<F>,
}
74
75impl<F: Float + Send + Sync + 'static> FittedOneHotEncoder<F> {
76 #[must_use]
78 pub fn n_categories(&self) -> &[usize] {
79 &self.n_categories
80 }
81
82 #[must_use]
84 pub fn n_output_features(&self) -> usize {
85 self.n_categories.iter().sum()
86 }
87}
88
89impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, ()> for OneHotEncoder<F> {
94 type Fitted = FittedOneHotEncoder<F>;
95 type Error = FerroError;
96
97 fn fit(&self, x: &Array2<usize>, _y: &()) -> Result<FittedOneHotEncoder<F>, FerroError> {
105 let n_samples = x.nrows();
106 if n_samples == 0 {
107 return Err(FerroError::InsufficientSamples {
108 required: 1,
109 actual: 0,
110 context: "OneHotEncoder::fit".into(),
111 });
112 }
113
114 let n_features = x.ncols();
115 let mut n_categories = Vec::with_capacity(n_features);
116
117 for j in 0..n_features {
118 let col = x.column(j);
119 let max_cat = col.iter().copied().max().unwrap_or(0);
120 n_categories.push(max_cat + 1);
121 }
122
123 Ok(FittedOneHotEncoder {
124 n_categories,
125 _marker: std::marker::PhantomData,
126 })
127 }
128}
129
130impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedOneHotEncoder<F> {
131 type Output = Array2<F>;
132 type Error = FerroError;
133
134 fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
144 let n_features = self.n_categories.len();
145 if x.ncols() != n_features {
146 return Err(FerroError::ShapeMismatch {
147 expected: vec![x.nrows(), n_features],
148 actual: vec![x.nrows(), x.ncols()],
149 context: "FittedOneHotEncoder::transform".into(),
150 });
151 }
152
153 let n_out_cols = self.n_output_features();
154 let n_samples = x.nrows();
155 let mut out = Array2::zeros((n_samples, n_out_cols));
156
157 let mut col_offset = 0;
158 for j in 0..n_features {
159 let n_cats = self.n_categories[j];
160 for i in 0..n_samples {
161 let cat = x[[i, j]];
162 if cat >= n_cats {
163 return Err(FerroError::InvalidParameter {
164 name: format!("x[{i},{j}]"),
165 reason: format!(
166 "category {cat} exceeds max seen during fitting ({})",
167 n_cats - 1
168 ),
169 });
170 }
171 out[[i, col_offset + cat]] = F::one();
172 }
173 col_offset += n_cats;
174 }
175
176 Ok(out)
177 }
178}
179
180impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for OneHotEncoder<F> {
183 type Output = Array2<F>;
184 type Error = FerroError;
185
186 fn transform(&self, _x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
191 Err(FerroError::InvalidParameter {
192 name: "OneHotEncoder".into(),
193 reason: "encoder must be fitted before calling transform; use fit() first".into(),
194 })
195 }
196}
197
198impl<F: Float + Send + Sync + 'static> FitTransform<Array2<usize>> for OneHotEncoder<F> {
199 type FitError = FerroError;
200
201 fn fit_transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
207 let fitted = self.fit(x, &())?;
208 fitted.transform(x)
209 }
210}
211
212impl<F: Float + Send + Sync + 'static> FittedOneHotEncoder<F> {
217 pub fn transform_1d(&self, x: &[usize]) -> Result<Array2<F>, FerroError> {
223 if self.n_categories.len() != 1 {
224 return Err(FerroError::InvalidParameter {
225 name: "transform_1d".into(),
226 reason: "encoder was fitted on more than one column; use transform instead".into(),
227 });
228 }
229 let col = Array2::from_shape_vec((x.len(), 1), x.to_vec()).map_err(|e| {
230 FerroError::InvalidParameter {
231 name: "x".into(),
232 reason: e.to_string(),
233 }
234 })?;
235 self.transform(&col)
236 }
237}
238
// Unit tests for the one-hot encoder: fitting, encoding, and every error path.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_one_hot_single_column() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize], [1], [2]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[3]);
        assert_eq!(fitted.n_output_features(), 3);

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 3]);
        // Row 0 -> category 0.
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out[[0, 1]], 0.0);
        assert_eq!(out[[0, 2]], 0.0);
        // Row 1 -> category 1.
        assert_eq!(out[[1, 0]], 0.0);
        assert_eq!(out[[1, 1]], 1.0);
        assert_eq!(out[[1, 2]], 0.0);
        // Row 2 -> category 2.
        assert_eq!(out[[2, 0]], 0.0);
        assert_eq!(out[[2, 1]], 0.0);
        assert_eq!(out[[2, 2]], 1.0);
    }

    #[test]
    fn test_one_hot_multi_column() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize, 0], [1, 1], [2, 0]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[3, 2]);
        assert_eq!(fitted.n_output_features(), 5);

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 5]);
        // First three output columns encode input column 0, the last two column 1.
        assert_eq!(out.row(0).to_vec(), vec![1.0, 0.0, 0.0, 1.0, 0.0]);
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0, 0.0, 0.0, 1.0]);
        assert_eq!(out.row(2).to_vec(), vec![0.0, 0.0, 1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_out_of_range_category_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x_train = array![[0usize], [1]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Category 2 was never seen during fitting.
        let x_bad = array![[2usize]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize, 1], [1, 0], [2, 1]];
        let via_fit_transform: Array2<f64> = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        // Whole-array equality also checks the shapes; an element-wise `zip`
        // would silently stop at the shorter of the two arrays.
        assert_eq!(via_fit_transform, via_separate);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x_train = array![[0usize, 1], [1, 0]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[0usize]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x = Array2::<usize>::zeros((0, 2));
        assert!(matches!(
            enc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_unfitted_transform_error() {
        let enc = OneHotEncoder::<f64>::default();
        let x = array![[0usize], [1]];
        assert!(enc.transform(&x).is_err());
    }

    #[test]
    fn test_transform_1d() {
        let enc = OneHotEncoder::<f64>::new();
        let fitted = enc.fit(&array![[0usize], [1], [2]], &()).unwrap();
        let out = fitted.transform_1d(&[2, 0]).unwrap();
        assert_eq!(out, array![[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]);
    }

    #[test]
    fn test_transform_1d_rejects_multi_column_encoder() {
        let enc = OneHotEncoder::<f64>::new();
        let fitted = enc.fit(&array![[0usize, 1], [1, 0]], &()).unwrap();
        assert!(fitted.transform_1d(&[0]).is_err());
    }
}