rs_ml/transformer/
embedding.rs

//! Embed categorical features as floats
2
3use num_traits::{Float, ToPrimitive};
4use std::{
5    collections::{HashMap, HashSet},
6    hash::Hash,
7};
8
9use ndarray::{Array1, Array2};
10
11use crate::Estimator;
12
13use super::Transformer;
14
/// Estimator for one-hot embedding.
///
/// Fitting it on a collection of categorical values yields a
/// [`OneHotEmbeddingTransformer`] that knows every distinct value seen.
#[derive(Clone, Copy, Default, Debug)]
pub struct OneHotEmbeddingEstimator;
18
/// One-hot embedding transformer.
///
/// Holds the category-to-column assignment used to encode each input value
/// as a row containing a single one.
#[derive(Clone, Debug)]
pub struct OneHotEmbeddingTransformer<V> {
    /// Column index assigned to each known category.
    map: HashMap<V, usize>,
}
24
/// Transformer for ordered, enum-like values.
///
/// Encodes each input value as its `usize` representation, obtained through
/// `ToPrimitive::to_usize`.
#[derive(Copy, Clone, Default, Debug)]
pub struct OrderedEnumEmbeddingTransformer;
28
29impl<V: Eq + Hash + Clone, A> Estimator<A> for OneHotEmbeddingEstimator
30where
31    for<'a> &'a A: IntoIterator<Item = &'a V>,
32{
33    type Estimator = OneHotEmbeddingTransformer<V>;
34
35    fn fit(&self, input: &A) -> Option<Self::Estimator> {
36        let distinct: HashSet<V> = input.into_iter().cloned().collect();
37        let map: HashMap<V, usize> = distinct
38            .into_iter()
39            .enumerate()
40            .map(|(idx, v)| (v, idx))
41            .collect();
42
43        Some(OneHotEmbeddingTransformer { map })
44    }
45}
46
47impl<V: Hash + Eq, F: Float, It> Transformer<It, Array2<F>> for OneHotEmbeddingTransformer<V>
48where
49    for<'a> &'a It: IntoIterator<Item = &'a V>,
50{
51    fn transform(&self, input: &It) -> Option<Array2<F>> {
52        let a: Vec<usize> = input.into_iter().map(|v| self.map[v]).collect();
53
54        let mut ret = Array2::zeros((a.len(), self.map.len()));
55
56        for (idx, a) in a.into_iter().enumerate() {
57            ret[(idx, a)] = F::one();
58        }
59
60        Some(ret)
61    }
62}
63
64impl<V: ToPrimitive, It> Transformer<It, Array1<usize>> for OrderedEnumEmbeddingTransformer
65where
66    for<'a> &'a It: IntoIterator<Item = &'a V>,
67{
68    fn transform(&self, input: &It) -> Option<Array1<usize>> {
69        Some(Array1::from_iter(
70            input.into_iter().map(|v| ToPrimitive::to_usize(v).unwrap()),
71        ))
72    }
73}