dendritic_preprocessing/encoding.rs

1use dendritic_ndarray::ndarray::NDArray;
2
3
4pub struct OneHotEncoding {
5    input_column: NDArray<f64>,
6    encoded_values: NDArray<f64>,
7    max_value: f64, 
8    num_samples: f64,
9}
10
11
12impl OneHotEncoding {
13
14    /// Create new instance of one hot encoding
15    pub fn new(input_column: NDArray<f64>) -> Result<OneHotEncoding, String>  {
16
17        if input_column.shape().dim(1) != 1 {
18            return Err("Input col must be of size (N, 1)".to_string())
19        }
20
21        if input_column.rank() > 2 {
22            return Err("Input col must be less than rank 2".to_string())
23        }
24
25        let max_value = input_column.values().iter().max_by(
26            |a, b| a.total_cmp(b)
27        ).unwrap();
28        let max_index = *max_value + 1.0;
29
30        Ok(Self {
31            input_column: input_column.clone(),
32            encoded_values: NDArray::new(vec![
33                input_column.shape().dim(0),
34                max_index as usize
35            ]).unwrap(),
36            max_value: max_index.clone(), 
37            num_samples: input_column.shape().dim(0) as f64
38        })
39    }
40
41    /// Get the maximum bound of one hot encoder
42    pub fn max_value(&self) -> f64 {
43        self.max_value
44    }
45
46    /// Get number of samples in one hot encoding
47    pub fn num_samples(&self) -> f64 {
48        self.num_samples
49    }
50
51    /// Transform data to be one hot encoded
52    pub fn transform(&mut self) -> &NDArray<f64> {
53
54        let mut row = 0;
55        let col_stride = self.encoded_values.shape().dim(1);
56        for idx in self.input_column.values() {
57            let index = (idx + row as f64) as usize; 
58            let _ = self.encoded_values.set_idx(index, 1.0);
59            row += col_stride;
60        }
61
62        &self.encoded_values
63    }
64
65}