1use ferrolearn_core::error::FerroError;
18use ferrolearn_core::traits::{Fit, FitTransform, Transform};
19use ndarray::Array2;
20use num_traits::Float;
21
/// Unfitted binary encoder for integer-coded categorical features.
///
/// Each input column is expected to hold category indices `0, 1, ..., k - 1`.
/// After fitting, a column with `k` categories is expanded into
/// `n_binary_digits(k)` output columns (at least one) containing the index
/// written in binary, most significant bit first. This is more compact than
/// one-hot encoding for high-cardinality features.
///
/// `F` is the floating-point type of the encoded output.
#[must_use]
#[derive(Clone, Debug)]
pub struct BinaryEncoder<F> {
    /// Binds the encoder to its output float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
51
52impl<F: Float + Send + Sync + 'static> BinaryEncoder<F> {
53 pub fn new() -> Self {
55 Self {
56 _marker: std::marker::PhantomData,
57 }
58 }
59}
60
61impl<F: Float + Send + Sync + 'static> Default for BinaryEncoder<F> {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
/// A [`BinaryEncoder`] whose per-column category counts have been learned.
///
/// Produced by `BinaryEncoder::fit`; use `transform` to encode data with the
/// learned layout.
#[derive(Clone, Debug)]
pub struct FittedBinaryEncoder<F> {
    /// Category count for each input column (largest index seen in fit + 1).
    n_categories: Vec<usize>,
    /// Number of binary output columns allotted to each input column.
    n_digits: Vec<usize>,
    /// Binds the fitted encoder to its output float type `F`.
    _marker: std::marker::PhantomData<F>,
}
83
84impl<F: Float + Send + Sync + 'static> FittedBinaryEncoder<F> {
85 #[must_use]
87 pub fn n_categories(&self) -> &[usize] {
88 &self.n_categories
89 }
90
91 #[must_use]
93 pub fn n_digits(&self) -> &[usize] {
94 &self.n_digits
95 }
96
97 #[must_use]
99 pub fn n_output_features(&self) -> usize {
100 self.n_digits.iter().sum()
101 }
102}
103
/// Number of binary digits needed to represent every index in `0..k`.
///
/// This is `ceil(log2(k))`, i.e. the bit length of the largest index `k - 1`,
/// with a floor of one digit so that `k <= 1` still yields a single
/// (always-zero) output column.
fn n_binary_digits(k: usize) -> usize {
    if k <= 1 {
        1
    } else {
        // Bit length of the largest index = total bits minus leading zero bits.
        (usize::BITS - (k - 1).leading_zeros()) as usize
    }
}
122
123impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, ()> for BinaryEncoder<F> {
128 type Fitted = FittedBinaryEncoder<F>;
129 type Error = FerroError;
130
131 fn fit(&self, x: &Array2<usize>, _y: &()) -> Result<FittedBinaryEncoder<F>, FerroError> {
139 let n_samples = x.nrows();
140 if n_samples == 0 {
141 return Err(FerroError::InsufficientSamples {
142 required: 1,
143 actual: 0,
144 context: "BinaryEncoder::fit".into(),
145 });
146 }
147
148 let n_features = x.ncols();
149 let mut n_categories = Vec::with_capacity(n_features);
150 let mut n_digits_vec = Vec::with_capacity(n_features);
151
152 for j in 0..n_features {
153 let col = x.column(j);
154 let max_cat = col.iter().copied().max().unwrap_or(0);
155 let k = max_cat + 1;
156 n_categories.push(k);
157 n_digits_vec.push(n_binary_digits(k));
158 }
159
160 Ok(FittedBinaryEncoder {
161 n_categories,
162 n_digits: n_digits_vec,
163 _marker: std::marker::PhantomData,
164 })
165 }
166}
167
168impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedBinaryEncoder<F> {
169 type Output = Array2<F>;
170 type Error = FerroError;
171
172 fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
182 let n_features = self.n_categories.len();
183 if x.ncols() != n_features {
184 return Err(FerroError::ShapeMismatch {
185 expected: vec![x.nrows(), n_features],
186 actual: vec![x.nrows(), x.ncols()],
187 context: "FittedBinaryEncoder::transform".into(),
188 });
189 }
190
191 let n_samples = x.nrows();
192 let n_out = self.n_output_features();
193 let mut out = Array2::zeros((n_samples, n_out));
194
195 let mut col_offset = 0;
196 for j in 0..n_features {
197 let n_cats = self.n_categories[j];
198 let digits = self.n_digits[j];
199
200 for i in 0..n_samples {
201 let cat = x[[i, j]];
202 if cat >= n_cats {
203 return Err(FerroError::InvalidParameter {
204 name: format!("x[{i},{j}]"),
205 reason: format!(
206 "category {cat} exceeds max seen during fitting ({})",
207 n_cats - 1
208 ),
209 });
210 }
211
212 for bit in 0..digits {
214 let bit_pos = digits - 1 - bit;
215 if (cat >> bit_pos) & 1 == 1 {
216 out[[i, col_offset + bit]] = F::one();
217 }
218 }
219 }
220
221 col_offset += digits;
222 }
223
224 Ok(out)
225 }
226}
227
228impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for BinaryEncoder<F> {
230 type Output = Array2<F>;
231 type Error = FerroError;
232
233 fn transform(&self, _x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
235 Err(FerroError::InvalidParameter {
236 name: "BinaryEncoder".into(),
237 reason: "encoder must be fitted before calling transform; use fit() first".into(),
238 })
239 }
240}
241
242impl<F: Float + Send + Sync + 'static> FitTransform<Array2<usize>> for BinaryEncoder<F> {
243 type FitError = FerroError;
244
245 fn fit_transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
251 let fitted = self.fit(x, &())?;
252 fitted.transform(x)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_binary_encoder_basic() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        // 0 -> 00
        assert_eq!(out.row(0).to_vec(), vec![0.0, 0.0]);
        // 1 -> 01
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0]);
        // 2 -> 10
        assert_eq!(out.row(2).to_vec(), vec![1.0, 0.0]);
        // 3 -> 11
        assert_eq!(out.row(3).to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_binary_encoder_five_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3], [4]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 3);
        let expected = array![
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 1.0],
            [1.0, 0.0, 0.0]
        ];
        assert_eq!(out, expected);
    }

    #[test]
    fn test_binary_encoder_single_category() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [0], [0]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // A single category still gets one (always-zero) output column.
        assert_eq!(out.ncols(), 1);
        for i in 0..3 {
            assert_eq!(out[[i, 0]], 0.0);
        }
    }

    #[test]
    fn test_binary_encoder_two_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 0.0);
        assert_eq!(out[[1, 0]], 1.0);
    }

    #[test]
    fn test_binary_encoder_multi_feature() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize, 0], [1, 1], [2, 0]];
        let fitted = enc.fit(&x, &()).unwrap();
        // Column 0 has 3 categories -> 2 digits; column 1 has 2 -> 1 digit.
        assert_eq!(fitted.n_categories(), &[3, 2]);
        assert_eq!(fitted.n_digits(), &[2, 1]);
        assert_eq!(fitted.n_output_features(), 3);
        let out = fitted.transform(&x).unwrap();
        let expected = array![[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [1.0, 0.0, 0.0]];
        assert_eq!(out, expected);
    }

    #[test]
    fn test_binary_encoder_sparse_categories_use_max() {
        // The category count comes from the largest index, not the number of
        // distinct values observed.
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [5]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[6]);
        assert_eq!(fitted.n_digits(), &[3]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, array![[0.0, 0.0, 0.0], [1.0, 0.0, 1.0]]);
    }

    #[test]
    fn test_binary_encoder_n_binary_digits() {
        assert_eq!(n_binary_digits(0), 1);
        assert_eq!(n_binary_digits(1), 1);
        assert_eq!(n_binary_digits(2), 1);
        assert_eq!(n_binary_digits(3), 2);
        assert_eq!(n_binary_digits(4), 2);
        assert_eq!(n_binary_digits(5), 3);
        assert_eq!(n_binary_digits(8), 3);
        assert_eq!(n_binary_digits(9), 4);
        assert_eq!(n_binary_digits(usize::MAX), usize::BITS as usize);
    }

    #[test]
    fn test_binary_encoder_fit_transform() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let out: Array2<f64> = enc.fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        // fit_transform must agree with an explicit fit followed by transform.
        let expected = enc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_binary_encoder_zero_rows_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x: Array2<usize> = Array2::zeros((0, 2));
        assert!(matches!(
            enc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_out_of_range_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x_train = array![[0usize], [1]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[2usize]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_shape_mismatch_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x_train = array![[0usize, 1], [1, 0]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[0usize]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_unfitted_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize]];
        assert!(matches!(
            enc.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_accessors() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[4]);
        assert_eq!(fitted.n_digits(), &[2]);
        assert_eq!(fitted.n_output_features(), 2);
    }

    #[test]
    fn test_binary_encoder_eight_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3], [4], [5], [6], [7]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 3);
        // 7 -> 111
        assert_eq!(out.row(7).to_vec(), vec![1.0, 1.0, 1.0]);
        // 5 -> 101
        assert_eq!(out.row(5).to_vec(), vec![1.0, 0.0, 1.0]);
        // 3 -> 011
        assert_eq!(out.row(3).to_vec(), vec![0.0, 1.0, 1.0]);
    }
}