1use std::collections::HashMap;
14use std::hash::Hash;
15use std::cmp::Eq;
16use std::fmt::Debug;
17use std::iter::Iterator;
18
/// Options consumed by [`Encoder::fit`].
#[derive(Clone, Debug)]
pub struct Config<T> {
    /// Upper bound on the number of distinct codes handed out by an ordinal or
    /// one-hot fit; `None` places no limit.
    pub max_nclasses: Option<u64>,
    /// Function a custom-mapping encoder applies to each value to get its code.
    pub mapping_function: Option<fn(T) -> u64>,
}
27
/// Selects which kind of [`Encoder`] [`Encoder::new`] builds.
///
/// Fieldless, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncoderType {
    /// Integer codes assigned in order of first appearance.
    Ordinal,
    /// Boolean-vector representation of each code.
    OneHot,
    /// Codes computed by `Config::mapping_function`.
    CustomMapping,
}
37
/// A categorical encoder: maps every fitted value to its encoded form.
///
/// The variant determines the output format. Trait bounds live on the `impl`
/// blocks that need them rather than on the type itself, so merely naming or
/// storing an `Encoder<T>` imposes no requirements on `T`.
#[derive(Debug)]
pub enum Encoder<T> {
    /// Value → integer code.
    Ordinal(HashMap<T, u64>),
    /// Value → boolean vector ([`OheRepr`]).
    OneHot(HashMap<T, OheRepr>),
    /// Value → code produced by a user-supplied function.
    Custom(HashMap<T, u64>),
}

/// Boolean-vector representation stored by [`Encoder::OneHot`].
type OheRepr = Vec<bool>;
48
49#[derive(Debug, Clone)]
52pub enum Transform {
53 Ordinal(Vec<u64>),
54 OneHot(Vec<OheRepr>),
55 CustomMapping(Vec<u64>)
56}
57
58impl Transform {
59 pub fn len(&self) -> usize {
60 match self {
61 Transform::Ordinal(data) => data.len(),
62 Transform::OneHot(data) => data.len(),
63 Transform::CustomMapping(data) => data.len()
64 }
65 }
66}
67
68impl <T> Encoder<T>
69where T: Hash + Eq + Clone + Debug
70{
71 pub fn new(enctype: Option<EncoderType>) -> Encoder<T> {
72 let enctype = enctype.unwrap_or(EncoderType::Ordinal);
73
74 match enctype {
75 EncoderType::Ordinal => Encoder::Ordinal(HashMap::new()),
76 EncoderType::OneHot => Encoder::OneHot(HashMap::new()),
77 EncoderType::CustomMapping => Encoder::Custom(HashMap::new())
78 }
79 }
80
81 pub fn fit(&mut self, data: &Vec<T>, config: &Config<T>) {
84 let max_nclasses = config.max_nclasses.unwrap_or(u64::MAX) - 1;
85
86 match self {
87 Encoder::Ordinal(map) => {
88 let mut current_idx = 0u64;
89 for el in data.iter() {
90 if !map.contains_key(el) {
91 map.insert(el.clone(), current_idx);
92 if current_idx < max_nclasses {
93 current_idx += 1;
94 }
95 }
96 }
97 },
98
99 Encoder::OneHot(map) => {
100 let mut mapping: HashMap<T, u64> = HashMap::new();
101 let mut current_idx = 0u64;
102 for el in data.iter() {
104 if !mapping.contains_key(el) {
105 mapping.insert(el.clone(), current_idx);
106 if current_idx < max_nclasses {
107 current_idx += 1;
108 }
109 }
110 }
111
112 let vecsize = mapping.len();
113 for (key, value) in mapping.into_iter() {
114 let mut converted: OheRepr = format!("{:b}", value)
115 .chars()
116 .rev()
117 .enumerate()
118 .filter_map(|(_i, n)| match n {
119 '1' => {
120 Some(true)
121 },
122
123 '0' => Some(false),
124 _ => panic!("Invalid conversion to binary"),
125 })
126 .collect();
127 for _ in 0..vecsize - converted.len() {
129 converted.push(false);
130 }
131 map.insert(key, converted);
133 }
134 },
135
136 Encoder::Custom(map) => {
137 let mapping_func = config.mapping_function.unwrap();
138 for el in data.iter() {
139 if !map.contains_key(el) {
140 let value = mapping_func(el.clone());
141 map.insert(el.clone(), value);
142 }
143 }
144 },
145 }
146 }
147
148 pub fn transform(&self, data: &Vec<T>) -> Transform {
151 match self {
152 Encoder::Ordinal(map) => {
153 let res: Vec<u64> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
154 Transform::Ordinal(res)
155 }
156
157 Encoder::OneHot(map) => {
158 let res: Vec<OheRepr> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
159 Transform::OneHot(res)
160 },
161
162 Encoder::Custom(map) => {
163 let res: Vec<u64> = data.iter().filter_map(|el| map.get(el)).cloned().collect();
164 Transform::CustomMapping(res)
165 },
166
167 }
168
169 }
170
171 pub fn inverse_transform(&self, data: &Transform) -> Vec<T> {
174 match self {
175 Encoder::Ordinal(mapping) => match data {
176 Transform::Ordinal(typed_data) => {
177 let result: Vec<T> = typed_data.iter()
178 .flat_map(|&el| {
179 mapping.into_iter()
180 .filter(move |&(_key, val)| val == &el)
181 .map(|(key, &_val)| key.clone())
182 })
183 .collect();
184 result
185 },
186 _ => panic!("Transformed data not compatible with this encoder"),
187 },
188
189 Encoder::OneHot(mapping) => match data {
191 Transform::OneHot(typed_data) => {
192 let result: Vec<T> = typed_data.iter()
193 .flat_map(|el| {
194
195 mapping.into_iter()
196 .filter(move |&(_key, val)| {
197 let mut equal_el: usize = 0;
198 for i in 0..val.len() {
199 if val[i] == el[i] {
200 equal_el += 1;
201 }
202 }
203 equal_el == val.len()
205 }
206 )
207 .map(move |(key, _val)| {
208 key.clone()
211 })
212 })
213 .collect();
214 result
215 },
216 _ => panic!("Transformed data not compatible with this encoder")
217 },
218
219 Encoder::Custom(mapping) => match data {
220 Transform::CustomMapping(typed_data) => {
221 let result = typed_data.into_iter().flat_map(|&el| {
222 mapping
223 .into_iter()
224 .filter(move |&(_k, v)| v == &el)
225 .map(|(k, &_v)| k.clone())
226 })
227 .collect();
228 result
229 },
230 _ => panic!("Transformed data not compatible with this encoder"),
231 }
232 }
233 }
234
235 pub fn nclasses(&self) -> usize {
238 match self {
239 Encoder::Ordinal(mapping) => {
241 let values: Vec<u64> = mapping.values().cloned().collect();
242 let len = values.iter().max();
243 match len {
244 Some(v) => *v as usize + 1,
245 _ => 0 as usize
246 }
247 },
248 Encoder::OneHot(map) => map.len(),
249 Encoder::Custom(map) => map.len(),
250 }
251 }
252}
253
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample corpus shared by the tests: four distinct words, first seen in the
    /// order hello, world, again, goodbye.
    fn sample_data() -> Vec<String> {
        [
            "hello", "world", "world", "world", "world", "again", "hello", "again", "goodbye",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    #[test]
    fn test_one_hot_encoding() {
        // Every fitted value gets a vector of the same width (== nclasses) and
        // distinct values get distinct vectors.
        let data = sample_data();
        let config = Config { max_nclasses: None, mapping_function: None };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::OneHot));
        enc.fit(&data, &config);

        match &enc {
            Encoder::OneHot(map) => {
                assert_eq!(map.len(), 4);
                assert!(map.values().all(|bits| bits.len() == enc.nclasses()));
                let distinct: std::collections::HashSet<&Vec<bool>> = map.values().collect();
                assert_eq!(distinct.len(), map.len());
            }
            other => panic!("expected a OneHot encoder, got {:?}", other),
        }
    }

    #[test]
    fn test_fit_ordinal_encoder() {
        let data = sample_data();
        let config = Config { max_nclasses: None, mapping_function: None };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::Ordinal));
        enc.fit(&data, &config);

        let trans_data = enc.transform(&data);
        match &trans_data {
            Transform::Ordinal(codes) => assert_eq!(codes, &vec![0, 1, 1, 1, 1, 2, 0, 2, 3]),
            other => panic!("expected ordinal codes, got {:?}", other),
        }

        // Codes are unique per value here, so the round trip is exact.
        assert_eq!(enc.inverse_transform(&trans_data), data);
        assert_eq!(enc.nclasses(), 4);
    }

    #[test]
    fn test_fit_ordinal_encoder_limited_classes() {
        let data = sample_data();
        let config = Config { max_nclasses: Some(3), mapping_function: None };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::Ordinal));
        enc.fit(&data, &config);

        // "goodbye" is the fourth distinct value, so it shares the last code.
        match enc.transform(&data) {
            Transform::Ordinal(codes) => assert_eq!(codes, vec![0, 1, 1, 1, 1, 2, 0, 2, 2]),
            other => panic!("expected ordinal codes, got {:?}", other),
        }
        assert_eq!(enc.nclasses(), 3);
    }

    #[test]
    fn test_fit_one_hot_encoder() {
        let data = sample_data();
        let config = Config { max_nclasses: None, mapping_function: None };
        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::OneHot));
        enc.fit(&data, &config);

        let trans_data = enc.transform(&data);
        assert_eq!(trans_data.len(), data.len());

        // Each value has its own vector, so decoding restores the input exactly.
        assert_eq!(enc.inverse_transform(&trans_data), data);
    }

    #[test]
    fn test_fit_custom_encoder() {
        let data = sample_data();
        let config: Config<String> = Config {
            max_nclasses: Some(10),
            mapping_function: Some(|el| match el.as_str() {
                "hello" => 42,
                "goodbye" => 99,
                _ => 0,
            }),
        };

        let mut enc: Encoder<String> = Encoder::new(Some(EncoderType::CustomMapping));
        enc.fit(&data, &config);

        let trans_data = enc.transform(&data);
        match &trans_data {
            Transform::CustomMapping(codes) => {
                assert_eq!(codes, &vec![42, 0, 0, 0, 0, 0, 42, 0, 99])
            }
            other => panic!("expected custom codes, got {:?}", other),
        }

        // "world" and "again" both map to 0, so decoding that code yields both
        // values. Every decoded value must still come from the input.
        let recon_data = enc.inverse_transform(&trans_data);
        assert!(recon_data.iter().all(|v| data.contains(v)));
        assert_eq!(enc.nclasses(), 4);
    }
}