labello/
lib.rs

1//! Labello: a fast label encoder in Rust
2//!
//! With Labello it is possible to create three types of encoders: ordinal, one-hot and custom.
4//!
//! A custom encoder does not guarantee that the mapping can be inverted: whether an
//! inverse-mapping reconstructs the original data depends on the mapping defined by the user
//! (labels that are mapped to the same value cannot be told apart).
//! The ordinal and one-hot encoders reconstruct the original data losslessly, as long as
//! `max_nclasses` does not force several labels to share the last class index. Labels that
//! were not seen during fitting are skipped when transforming.
10//!
11//!
12
13use std::collections::HashMap;
14use std::hash::Hash;
15use std::cmp::Eq;
16use std::fmt::Debug;
17use std::iter::Iterator;
18
/// Configuration (metadata) passed to [`Encoder::fit`].
#[derive(Debug, Clone)]
pub struct Config<T> {
    /// Maximum number of classes for the ordinal and one-hot encoders.
    ///
    /// Once the cap is reached, every further new label shares the last class index.
    /// `None` means no limit. Ignored by the custom encoder.
    pub max_nclasses: Option<u64>,
    /// Mapping applied to each distinct label by the custom encoder.
    ///
    /// This is a plain function pointer, so only non-capturing closures can be used.
    /// Required when fitting a custom encoder; ignored by the other encoder types.
    pub mapping_function: Option<fn(T) -> u64>,
}

/// A configuration with no class limit and no custom mapping function.
///
/// Implemented by hand rather than derived so that `T` is not required to implement
/// `Default`.
impl<T> Default for Config<T> {
    fn default() -> Self {
        Config {
            max_nclasses: None,
            mapping_function: None,
        }
    }
}
27
/// Kind of encoder to build with [`Encoder::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncoderType {
    /// Encode categorical features with an ordinal encoding (one integer index per class).
    Ordinal,
    /// Encode categorical features as a one-hot numeric array.
    OneHot,
    /// Encode categorical features with a user-defined mapping function.
    CustomMapping,
}
37
38#[derive(Debug)]
39pub enum Encoder<T>
40where T: Hash + Eq + Debug
41{
42    Ordinal(HashMap<T, u64>),
43    OneHot(HashMap<T, OheRepr>),
44    Custom(HashMap<T, u64>)
45}
46
/// Encoded form of a single label produced by the one-hot encoder (a vector of bits).
pub type OheRepr = Vec<bool>;
48
49/// transformed data type
50///
51#[derive(Debug, Clone)]
52pub enum Transform {
53    Ordinal(Vec<u64>),
54    OneHot(Vec<OheRepr>),
55    CustomMapping(Vec<u64>)
56}
57
58impl Transform {
59    pub fn len(&self) -> usize {
60        match self {
61            Transform::Ordinal(data) => data.len(),
62            Transform::OneHot(data) => data.len(),
63            Transform::CustomMapping(data) => data.len()
64        }
65    }
66}
67
68impl <T> Encoder<T>
69where T: Hash + Eq + Clone + Debug
70{
71    pub fn new(enctype: Option<EncoderType>) -> Encoder<T> {
72        let enctype = enctype.unwrap_or(EncoderType::Ordinal);
73
74        match enctype {
75            EncoderType::Ordinal => Encoder::Ordinal(HashMap::new()),
76            EncoderType::OneHot => Encoder::OneHot(HashMap::new()),
77            EncoderType::CustomMapping => Encoder::Custom(HashMap::new())
78        }
79    }
80
81    /// Fit label encoder given the type (ordinal, one-hot, custom)
82    ///
83    pub fn fit(&mut self, data: &Vec<T>, config: &Config<T>) {
84        let max_nclasses = config.max_nclasses.unwrap_or(u64::MAX) - 1;
85
86        match self {
87            Encoder::Ordinal(map) => {
88                let mut current_idx = 0u64;
89                for el in data.iter() {
90                    if !map.contains_key(el) {
91                        map.insert(el.clone(), current_idx);
92                        if current_idx < max_nclasses {
93                            current_idx += 1;
94                        }
95                    }
96                }
97            },
98
99            Encoder::OneHot(map) => {
100                let mut mapping: HashMap<T, u64> = HashMap::new();
101                let mut current_idx = 0u64;
102                // encode in a temporary hashmap (mapping)
103                for el in data.iter() {
104                    if !mapping.contains_key(el) {
105                        mapping.insert(el.clone(), current_idx);
106                        if current_idx < max_nclasses {
107                            current_idx += 1;
108                        }
109                    }
110                }
111
112                let vecsize = mapping.len();
113                for (key, value) in mapping.into_iter() {
114                    let mut converted: OheRepr = format!("{:b}", value)
115                                                .chars()
116                                                .rev()
117                                                .enumerate()
118                                                .filter_map(|(_i, n)| match n {
119                                                    '1' => {
120                                                        Some(true)
121                                                    },
122
123                                                    '0' => Some(false),
124                                                    _ => panic!("Invalid conversion to binary"),
125                                                })
126                                                .collect();
127                    // push remaining zeros (vecsize - current len)
128                    for _ in 0..vecsize - converted.len() {
129                        converted.push(false);
130                    }
131                    // insert into final hashmap
132                    map.insert(key, converted);
133                }
134            },
135
136            Encoder::Custom(map) => {
137                let mapping_func = config.mapping_function.unwrap();
138                for el in data.iter() {
139                    if !map.contains_key(el) {
140                        let value = mapping_func(el.clone());
141                        map.insert(el.clone(), value);
142                    }
143                }
144            },
145        }
146    }
147
148    /// Transform data to normalized encoding
149    ///
150    pub fn transform(&self, data: &Vec<T>) -> Transform  {
151        match self {
152            Encoder::Ordinal(map) => {
153                let res: Vec<u64> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
154                Transform::Ordinal(res)
155            }
156
157            Encoder::OneHot(map) => {
158                let res: Vec<OheRepr> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
159                Transform::OneHot(res)
160            },
161
162            Encoder::Custom(map) => {
163                let res: Vec<u64> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
164                Transform::CustomMapping(res)
165            },
166
167        }
168
169    }
170
171    /// Transforms labels back to the original data (not necessarily true with custom encoder)
172    ///
173    pub fn inverse_transform(&self, data: &Transform) -> Vec<T> {
174        match self {
175            Encoder::Ordinal(mapping) => match data {
176                Transform::Ordinal(typed_data) => {
177                    let result: Vec<T> = typed_data.iter()
178                    .flat_map(|&el| {
179                        mapping.into_iter()
180                        .filter(move |&(_key, val)| val == &el)
181                        .map(|(key, &_val)| key.clone())
182                    })
183                    .collect();
184                    result
185                },
186                _ => panic!("Transformed data not compatible with this encoder"),
187            },
188
189            // TODO WIP inverse mapping is not reversible for one-hot (ERROR!!)
190            Encoder::OneHot(mapping) => match data {
191                Transform::OneHot(typed_data) => {
192                    let result: Vec<T> = typed_data.iter()
193                    .flat_map(|el| {
194
195                        mapping.into_iter()
196                        .filter(move |&(_key, val)| {
197                            let mut equal_el: usize = 0;
198                            for i in 0..val.len() {
199                                if val[i] == el[i] {
200                                    equal_el += 1;
201                                }
202                            }
203                            // println!("comparing {:?} with {:?} matched {:?}", el, val, equal_el == val.len());
204                            equal_el == val.len()
205                        }
206                    )
207                        .map(move |(key, _val)| {
208                            // dbg!("typed_data: ", el.clone());
209                            // dbg!("key: ", key.clone());
210                            key.clone()
211                        })
212                    })
213                    .collect();
214                    result
215                },
216                _ => panic!("Transformed data not compatible with this encoder")
217            },
218
219            Encoder::Custom(mapping) => match data {
220                Transform::CustomMapping(typed_data) => {
221                    let result = typed_data.into_iter().flat_map(|&el| {
222                        mapping
223                            .into_iter()
224                            .filter(move |&(_k, v)| v == &el)
225                        .map(|(k, &_v)| k.clone())
226                    })
227                    .collect();
228                    result
229                },
230                _ => panic!("Transformed data not compatible with this encoder"),
231            }
232        }
233    }
234
235    /// Return number of unique categories
236    ///
237    pub fn nclasses(&self) -> usize {
238        match self {
239            // TODO len is the same for every type
240            Encoder::Ordinal(mapping) => {
241                let values: Vec<u64> = mapping.values().cloned().collect();
242                let len = values.iter().max();
243                match len {
244                    Some(v) => *v as usize + 1,
245                    _ => 0 as usize
246                }
247            },
248            Encoder::OneHot(map) => map.len(),
249            Encoder::Custom(map) => map.len(),
250        }
251    }
252}
253
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Nine labels with four distinct values, first seen in the order
    /// "hello", "world", "again", "goodbye".
    fn sample_data() -> Vec<String> {
        [
            "hello", "world", "world", "world", "world", "again", "hello", "again", "goodbye",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// Configuration without a class limit or custom mapping.
    fn unlimited() -> Config<String> {
        Config {
            max_nclasses: None,
            mapping_function: None,
        }
    }

    #[test]
    fn test_one_hot_encoding() {
        let data = sample_data();
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::OneHot));
        enc.fit(&data, &unlimited());

        let vectors = match enc.transform(&data) {
            Transform::OneHot(vectors) => vectors,
            other => panic!("expected one-hot output, got {:?}", other),
        };
        assert_eq!(vectors.len(), data.len());

        // Every vector has one slot per class.
        let width = enc.nclasses();
        assert_eq!(width, 4);
        assert!(vectors.iter().all(|v| v.len() == width));

        // Equal labels share a vector; different labels do not.
        assert_eq!(vectors[0], vectors[6]); // hello, hello
        assert_ne!(vectors[0], vectors[1]); // hello, world
        assert_ne!(vectors[5], vectors[8]); // again, goodbye
    }

    #[test]
    fn test_fit_ordinal_encoder() {
        let data = sample_data();
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::Ordinal));
        enc.fit(&data, &unlimited());

        let trans_data = enc.transform(&data);
        match &trans_data {
            // Indices follow the order of first appearance.
            Transform::Ordinal(codes) => {
                assert_eq!(*codes, vec![0u64, 1, 1, 1, 1, 2, 0, 2, 3])
            }
            other => panic!("expected ordinal output, got {:?}", other),
        }

        assert_eq!(enc.inverse_transform(&trans_data), data);
        assert_eq!(enc.nclasses(), 4);
    }

    #[test]
    fn test_fit_ordinal_encoder_limited_classes() {
        let data = sample_data();
        let config = Config {
            max_nclasses: Some(3),
            mapping_function: None,
        };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::Ordinal));
        enc.fit(&data, &config);

        match enc.transform(&data) {
            // "goodbye" arrives after the cap and shares the last index with "again".
            Transform::Ordinal(codes) => assert_eq!(codes, vec![0u64, 1, 1, 1, 1, 2, 0, 2, 2]),
            other => panic!("expected ordinal output, got {:?}", other),
        }
        assert_eq!(enc.nclasses(), 3);
    }

    #[test]
    fn test_fit_one_hot_encoder() {
        let data = sample_data();
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::OneHot));
        enc.fit(&data, &unlimited());

        let trans_data = enc.transform(&data);
        assert_eq!(trans_data.len(), data.len());
        assert_eq!(enc.inverse_transform(&trans_data), data);
    }

    #[test]
    fn test_fit_custom_encoder() {
        let data = sample_data();
        let config: Config<String> = Config {
            max_nclasses: Some(10),
            mapping_function: Some(|el| match el.as_str() {
                "hello" => 42,
                "goodbye" => 99,
                _ => 0,
            }),
        };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::CustomMapping));
        enc.fit(&data, &config);
        assert_eq!(enc.nclasses(), 4);

        let trans_data = enc.transform(&data);
        match &trans_data {
            Transform::CustomMapping(codes) => {
                assert_eq!(*codes, vec![42u64, 0, 0, 0, 0, 0, 42, 0, 99])
            }
            other => panic!("expected custom output, got {:?}", other),
        }

        // "world" and "again" share the value 0, so every 0 expands to both labels: the
        // inverse mapping is not a faithful reconstruction for this mapping function.
        let recon_data = enc.inverse_transform(&trans_data);
        let count = |label: &str| recon_data.iter().filter(|s| s.as_str() == label).count();
        assert_eq!(recon_data.len(), 15);
        assert_eq!(count("hello"), 2);
        assert_eq!(count("goodbye"), 1);
        assert_eq!(count("world"), 6);
        assert_eq!(count("again"), 6);
    }
}
397}