dendritic_preprocessing/
encoding.rs1use dendritic_ndarray::ndarray::NDArray;
2
3
4pub struct OneHotEncoding {
5 input_column: NDArray<f64>,
6 encoded_values: NDArray<f64>,
7 max_value: f64,
8 num_samples: f64,
9}
10
11
12impl OneHotEncoding {
13
14 pub fn new(input_column: NDArray<f64>) -> Result<OneHotEncoding, String> {
16
17 if input_column.shape().dim(1) != 1 {
18 return Err("Input col must be of size (N, 1)".to_string())
19 }
20
21 if input_column.rank() > 2 {
22 return Err("Input col must be less than rank 2".to_string())
23 }
24
25 let max_value = input_column.values().iter().max_by(
26 |a, b| a.total_cmp(b)
27 ).unwrap();
28 let max_index = *max_value + 1.0;
29
30 Ok(Self {
31 input_column: input_column.clone(),
32 encoded_values: NDArray::new(vec![
33 input_column.shape().dim(0),
34 max_index as usize
35 ]).unwrap(),
36 max_value: max_index.clone(),
37 num_samples: input_column.shape().dim(0) as f64
38 })
39 }
40
41 pub fn max_value(&self) -> f64 {
43 self.max_value
44 }
45
46 pub fn num_samples(&self) -> f64 {
48 self.num_samples
49 }
50
51 pub fn transform(&mut self) -> &NDArray<f64> {
53
54 let mut row = 0;
55 let col_stride = self.encoded_values.shape().dim(1);
56 for idx in self.input_column.values() {
57 let index = (idx + row as f64) as usize;
58 let _ = self.encoded_values.set_idx(index, 1.0);
59 row += col_stride;
60 }
61
62 &self.encoded_values
63 }
64
65}