Skip to main content

ferrolearn_preprocess/
binary_encoder.rs

1//! Binary encoder: encode categorical integers as binary digits.
2//!
//! [`BinaryEncoder`] encodes each categorical integer feature into `ceil(log2(k))`
//! binary columns (at least one), where `k` is the largest category value seen
//! during fitting plus one. This is more compact than one-hot encoding for
//! high-cardinality features.
6//!
7//! # Example
8//!
9//! ```text
10//! Input column with categories {0, 1, 2, 3}:
11//!   0 → [0, 0]
12//!   1 → [0, 1]
13//!   2 → [1, 0]
14//!   3 → [1, 1]
15//! ```
16
17use ferrolearn_core::error::FerroError;
18use ferrolearn_core::traits::{Fit, FitTransform, Transform};
19use ndarray::Array2;
20use num_traits::Float;
21
22// ---------------------------------------------------------------------------
23// BinaryEncoder (unfitted)
24// ---------------------------------------------------------------------------
25
/// Stateless handle for binary encoding; it holds nothing learned from data.
///
/// Fitting it on a matrix of categorical integers produces a
/// [`FittedBinaryEncoder`]. A feature whose fitted category count is `k`
/// (the largest observed value plus one) expands into `ceil(log2(k))` output
/// columns, and never fewer than one.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::binary_encoder::BinaryEncoder;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let enc = BinaryEncoder::<f64>::new();
/// let x = array![[0usize], [1], [2], [3]];
/// let fitted = enc.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// // 4 categories → ceil(log2(4)) = 2 binary columns
/// assert_eq!(out.ncols(), 2);
/// ```
#[derive(Clone, Debug)]
#[must_use]
pub struct BinaryEncoder<F> {
    /// Zero-sized marker tying the encoder to the float type `F` of its output.
    _marker: std::marker::PhantomData<F>,
}
51
52impl<F: Float + Send + Sync + 'static> BinaryEncoder<F> {
53    /// Create a new `BinaryEncoder`.
54    pub fn new() -> Self {
55        Self {
56            _marker: std::marker::PhantomData,
57        }
58    }
59}
60
61impl<F: Float + Send + Sync + 'static> Default for BinaryEncoder<F> {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67// ---------------------------------------------------------------------------
68// FittedBinaryEncoder
69// ---------------------------------------------------------------------------
70
/// The learned state of a binary encoder: for every input feature, how many
/// categories it has and how many binary digits its encoding occupies.
///
/// Obtained from [`Fit::fit`] on a [`BinaryEncoder`].
#[derive(Clone, Debug)]
pub struct FittedBinaryEncoder<F> {
    /// Category count per input column (index = input column).
    n_categories: Vec<usize>,
    /// Encoded width, in binary digits, per input column (index = input column).
    n_digits: Vec<usize>,
    /// Zero-sized marker for the float type `F` of the encoded output.
    _marker: std::marker::PhantomData<F>,
}
83
84impl<F: Float + Send + Sync + 'static> FittedBinaryEncoder<F> {
85    /// Return the number of categories per feature.
86    #[must_use]
87    pub fn n_categories(&self) -> &[usize] {
88        &self.n_categories
89    }
90
91    /// Return the number of binary digits per feature.
92    #[must_use]
93    pub fn n_digits(&self) -> &[usize] {
94        &self.n_digits
95    }
96
97    /// Return the total number of output columns.
98    #[must_use]
99    pub fn n_output_features(&self) -> usize {
100        self.n_digits.iter().sum()
101    }
102}
103
104// ---------------------------------------------------------------------------
105// Helpers
106// ---------------------------------------------------------------------------
107
/// Number of binary digits needed to tell `k` categories apart:
/// `ceil(log2(k))`, but never less than 1.
fn n_binary_digits(k: usize) -> usize {
    match k {
        // Zero or one category still occupies a single (always-zero) column.
        0 | 1 => 1,
        // The width is the bit length of the largest code, `k - 1`: the total
        // bit width of `usize` minus the zero bits above its highest set bit.
        _ => (usize::BITS - (k - 1).leading_zeros()) as usize,
    }
}
122
123// ---------------------------------------------------------------------------
124// Trait implementations
125// ---------------------------------------------------------------------------
126
127impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, ()> for BinaryEncoder<F> {
128    type Fitted = FittedBinaryEncoder<F>;
129    type Error = FerroError;
130
131    /// Fit by determining the number of categories per column.
132    ///
133    /// The number of categories for column `j` is `max(x[:, j]) + 1`.
134    ///
135    /// # Errors
136    ///
137    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
138    fn fit(&self, x: &Array2<usize>, _y: &()) -> Result<FittedBinaryEncoder<F>, FerroError> {
139        let n_samples = x.nrows();
140        if n_samples == 0 {
141            return Err(FerroError::InsufficientSamples {
142                required: 1,
143                actual: 0,
144                context: "BinaryEncoder::fit".into(),
145            });
146        }
147
148        let n_features = x.ncols();
149        let mut n_categories = Vec::with_capacity(n_features);
150        let mut n_digits_vec = Vec::with_capacity(n_features);
151
152        for j in 0..n_features {
153            let col = x.column(j);
154            let max_cat = col.iter().copied().max().unwrap_or(0);
155            let k = max_cat + 1;
156            n_categories.push(k);
157            n_digits_vec.push(n_binary_digits(k));
158        }
159
160        Ok(FittedBinaryEncoder {
161            n_categories,
162            n_digits: n_digits_vec,
163            _marker: std::marker::PhantomData,
164        })
165    }
166}
167
168impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedBinaryEncoder<F> {
169    type Output = Array2<F>;
170    type Error = FerroError;
171
172    /// Transform categorical data into binary encoded columns.
173    ///
174    /// # Errors
175    ///
176    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
177    /// from the number of features seen during fitting.
178    ///
179    /// Returns [`FerroError::InvalidParameter`] if any category value exceeds
180    /// the maximum seen during fitting.
181    fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
182        let n_features = self.n_categories.len();
183        if x.ncols() != n_features {
184            return Err(FerroError::ShapeMismatch {
185                expected: vec![x.nrows(), n_features],
186                actual: vec![x.nrows(), x.ncols()],
187                context: "FittedBinaryEncoder::transform".into(),
188            });
189        }
190
191        let n_samples = x.nrows();
192        let n_out = self.n_output_features();
193        let mut out = Array2::zeros((n_samples, n_out));
194
195        let mut col_offset = 0;
196        for j in 0..n_features {
197            let n_cats = self.n_categories[j];
198            let digits = self.n_digits[j];
199
200            for i in 0..n_samples {
201                let cat = x[[i, j]];
202                if cat >= n_cats {
203                    return Err(FerroError::InvalidParameter {
204                        name: format!("x[{i},{j}]"),
205                        reason: format!(
206                            "category {cat} exceeds max seen during fitting ({})",
207                            n_cats - 1
208                        ),
209                    });
210                }
211
212                // Encode category as binary digits (MSB first)
213                for bit in 0..digits {
214                    let bit_pos = digits - 1 - bit;
215                    if (cat >> bit_pos) & 1 == 1 {
216                        out[[i, col_offset + bit]] = F::one();
217                    }
218                }
219            }
220
221            col_offset += digits;
222        }
223
224        Ok(out)
225    }
226}
227
228/// Implement `Transform` on the unfitted encoder.
229impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for BinaryEncoder<F> {
230    type Output = Array2<F>;
231    type Error = FerroError;
232
233    /// Always returns an error — the encoder must be fitted first.
234    fn transform(&self, _x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
235        Err(FerroError::InvalidParameter {
236            name: "BinaryEncoder".into(),
237            reason: "encoder must be fitted before calling transform; use fit() first".into(),
238        })
239    }
240}
241
242impl<F: Float + Send + Sync + 'static> FitTransform<Array2<usize>> for BinaryEncoder<F> {
243    type FitError = FerroError;
244
245    /// Fit and transform in one step.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if fitting or transformation fails.
250    fn fit_transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
251        let fitted = self.fit(x, &())?;
252        fitted.transform(x)
253    }
254}
255
256// ---------------------------------------------------------------------------
257// Tests
258// ---------------------------------------------------------------------------
259
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Four categories fit exactly into two digits, MSB first.
    #[test]
    fn test_binary_encoder_basic() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 4 categories → ceil(log2(4)) = 2 columns
        assert_eq!(out.ncols(), 2);
        // 0 → [0, 0]
        assert_eq!(out.row(0).to_vec(), vec![0.0, 0.0]);
        // 1 → [0, 1]
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0]);
        // 2 → [1, 0]
        assert_eq!(out.row(2).to_vec(), vec![1.0, 0.0]);
        // 3 → [1, 1]
        assert_eq!(out.row(3).to_vec(), vec![1.0, 1.0]);
    }

    /// A category count that is not a power of two rounds the width up.
    #[test]
    fn test_binary_encoder_five_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3], [4]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 5 categories → ceil(log2(5)) = 3 columns
        assert_eq!(out.ncols(), 3);
        // 0 → [0, 0, 0]
        assert_eq!(out.row(0).to_vec(), vec![0.0, 0.0, 0.0]);
        // 3 → [0, 1, 1]
        assert_eq!(out.row(3).to_vec(), vec![0.0, 1.0, 1.0]);
        // 4 → [1, 0, 0]
        assert_eq!(out.row(4).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    /// A constant column still produces one (all-zero) output column.
    #[test]
    fn test_binary_encoder_single_category() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [0], [0]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 1 category → 1 binary column (always 0)
        assert_eq!(out.ncols(), 1);
        assert_eq!(out.nrows(), 3);
        for i in 0..3 {
            assert_eq!(out[[i, 0]], 0.0);
        }
    }

    #[test]
    fn test_binary_encoder_two_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 2 categories → ceil(log2(2)) = 1 column
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 0.0);
        assert_eq!(out[[1, 0]], 1.0);
    }

    /// Digit groups of different widths must be laid out side by side; the
    /// value checks catch any mistake in the per-feature column offset.
    #[test]
    fn test_binary_encoder_multi_feature() {
        let enc = BinaryEncoder::<f64>::new();
        // Feature 0: 3 categories → 2 digits
        // Feature 1: 2 categories → 1 digit
        let x = array![[0usize, 0], [1, 1], [2, 0]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[3, 2]);
        assert_eq!(fitted.n_digits(), &[2, 1]);
        assert_eq!(fitted.n_output_features(), 3); // 2 + 1
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 3);
        // (0, 0) → [0, 0 | 0]
        assert_eq!(out.row(0).to_vec(), vec![0.0, 0.0, 0.0]);
        // (1, 1) → [0, 1 | 1]
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0, 1.0]);
        // (2, 0) → [1, 0 | 0]
        assert_eq!(out.row(2).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_binary_encoder_n_binary_digits() {
        // Degenerate counts still get one digit.
        assert_eq!(n_binary_digits(0), 1);
        assert_eq!(n_binary_digits(1), 1);
        assert_eq!(n_binary_digits(2), 1);
        assert_eq!(n_binary_digits(3), 2);
        assert_eq!(n_binary_digits(4), 2);
        assert_eq!(n_binary_digits(5), 3);
        assert_eq!(n_binary_digits(8), 3);
        assert_eq!(n_binary_digits(9), 4);
        // Largest representable count needs every bit of a usize.
        assert_eq!(n_binary_digits(usize::MAX), usize::BITS as usize);
    }

    /// `fit_transform` must agree with `fit` followed by `transform`.
    #[test]
    fn test_binary_encoder_fit_transform() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let out: Array2<f64> = enc.fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        let two_step = enc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(out, two_step);
    }

    #[test]
    fn test_binary_encoder_zero_rows_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x: Array2<usize> = Array2::zeros((0, 2));
        assert!(matches!(
            enc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_out_of_range_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x_train = array![[0usize], [1]]; // max category = 1
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[2usize]]; // category 2 not seen
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_shape_mismatch_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x_train = array![[0usize, 1], [1, 0]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[0usize]]; // wrong number of columns
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_unfitted_error() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize]];
        assert!(matches!(
            enc.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_binary_encoder_accessors() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[4]);
        assert_eq!(fitted.n_digits(), &[2]);
        assert_eq!(fitted.n_output_features(), 2);
    }

    #[test]
    fn test_binary_encoder_eight_categories() {
        let enc = BinaryEncoder::<f64>::new();
        let x = array![[0usize], [1], [2], [3], [4], [5], [6], [7]];
        let fitted = enc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 8 categories → ceil(log2(8)) = 3 columns
        assert_eq!(out.ncols(), 3);
        // 7 → [1, 1, 1]
        assert_eq!(out.row(7).to_vec(), vec![1.0, 1.0, 1.0]);
        // 5 → [1, 0, 1]
        assert_eq!(out.row(5).to_vec(), vec![1.0, 0.0, 1.0]);
    }
}