// dendritic_preprocessing/encoding.rs

use dendritic_ndarray::ndarray::NDArray;


/// One hot encoder for a single column of numeric class labels.
///
/// The encoder keeps the original label column together with the output
/// array that `transform` fills in.
pub struct OneHotEncoding {
    /// Label column handed to the constructor.
    input_column: NDArray<f64>,
    /// Output array with one row per sample and one column per class.
    encoded_values: NDArray<f64>,
    /// Largest label in `input_column` plus one (the encoded column count).
    max_value: f64,
    /// Number of rows in `input_column`, stored as a float.
    num_samples: f64,
}


impl OneHotEncoding {

    /// Create new instance of one hot encoding.
    ///
    /// `input_column` must be a rank 2 array of shape `(N, 1)` holding at
    /// least one label, and every label must be finite and non-negative.
    /// The output array is allocated here with shape `(N, max_label + 1)`
    /// but is only populated once [`OneHotEncoding::transform`] is called.
    ///
    /// # Errors
    ///
    /// Returns an `Err` when the input is not of shape `(N, 1)`, has a rank
    /// above 2, contains no values, contains a negative / NaN / infinite
    /// label, or when the output array cannot be created.
    pub fn new(input_column: NDArray<f64>) -> Result<OneHotEncoding, String> {

        // The rank is checked before querying dim(1): an input with fewer
        // than two dimensions has no second dimension to ask for.
        if input_column.rank() < 2 || input_column.shape().dim(1) != 1 {
            return Err("Input col must be of size (N, 1)".to_string());
        }

        if input_column.rank() > 2 {
            return Err("Input col must be less than rank 2".to_string());
        }

        // Labels are later cast to `usize` to build column offsets. That cast
        // saturates negatives/NaN to 0 and infinities to usize::MAX, which
        // would silently mark the wrong cell, so reject them up front.
        let labels = input_column.values();
        if labels.iter().any(|label| !label.is_finite() || *label < 0.0) {
            return Err(
                "Input col values must be finite and non-negative".to_string()
            );
        }

        // An empty column has no maximum; report it instead of panicking.
        let max_label = labels
            .iter()
            .copied()
            .max_by(|a, b| a.total_cmp(b))
            .ok_or_else(|| "Input col must not be empty".to_string())?;

        // Labels start at zero, so the number of output columns is max + 1.
        let num_classes = max_label + 1.0;
        let num_rows = input_column.shape().dim(0);
        let encoded_values = NDArray::new(vec![num_rows, num_classes as usize])?;

        Ok(Self {
            input_column,
            encoded_values,
            max_value: num_classes,
            num_samples: num_rows as f64,
        })
    }

    /// Get the upper bound of the one hot encoder.
    ///
    /// This is the largest label of the input column plus one, i.e. the
    /// number of columns in the encoded output.
    pub fn max_value(&self) -> f64 {
        self.max_value
    }

    /// Get number of samples (rows of the input column) in one hot encoding
    pub fn num_samples(&self) -> f64 {
        self.num_samples
    }

    /// Transform data to be one hot encoded.
    ///
    /// For every row of the input column, writes `1.0` into the output cell
    /// whose column equals that row's label (fractional labels are truncated
    /// toward zero), then returns a reference to the encoded array. Calling
    /// this repeatedly writes the same cells again.
    pub fn transform(&mut self) -> &NDArray<f64> {

        let col_stride = self.encoded_values.shape().dim(1);
        for (row, label) in self.input_column.values().iter().enumerate() {
            // Flat offset of (row, label) in the row-major output buffer.
            let index = row * col_stride + *label as usize;
            let written = self.encoded_values.set_idx(index, 1.0);
            // `new` validated every label against the allocated width, so a
            // failed write here means an internal invariant was broken.
            debug_assert!(written.is_ok(), "one hot index out of bounds");
        }

        &self.encoded_values
    }

}