rs_ml/transformer/
embedding.rs1use num_traits::{Float, ToPrimitive};
4use std::{
5 collections::{HashMap, HashSet},
6 hash::Hash,
7};
8
9use ndarray::{Array1, Array2};
10
11use crate::Estimator;
12
13use super::Transformer;
14
/// Stateless estimator that learns a one-hot vocabulary.
///
/// It carries no configuration of its own; fitting it yields a
/// [`OneHotEmbeddingTransformer`] that assigns one output column to every
/// distinct input value.
#[derive(Default, Debug, Clone, Copy)]
pub struct OneHotEmbeddingEstimator;
18
/// Fitted one-hot encoder produced by [`OneHotEmbeddingEstimator`].
///
/// Each known category owns exactly one column of the encoded output.
#[derive(Clone, Debug)]
pub struct OneHotEmbeddingTransformer<V> {
    /// Column index assigned to each category seen during fitting.
    map: HashMap<V, usize>,
}
24
/// Stateless transformer that encodes each input value as a `usize`.
///
/// The conversion goes through `num_traits::ToPrimitive`, so no fitting step
/// is needed.
#[derive(Default, Debug, Copy, Clone)]
pub struct OrderedEnumEmbeddingTransformer;
28
29impl<V: Eq + Hash + Clone, A> Estimator<A> for OneHotEmbeddingEstimator
30where
31 for<'a> &'a A: IntoIterator<Item = &'a V>,
32{
33 type Estimator = OneHotEmbeddingTransformer<V>;
34
35 fn fit(&self, input: &A) -> Option<Self::Estimator> {
36 let distinct: HashSet<V> = input.into_iter().cloned().collect();
37 let map: HashMap<V, usize> = distinct
38 .into_iter()
39 .enumerate()
40 .map(|(idx, v)| (v, idx))
41 .collect();
42
43 Some(OneHotEmbeddingTransformer { map })
44 }
45}
46
47impl<V: Hash + Eq, F: Float, It> Transformer<It, Array2<F>> for OneHotEmbeddingTransformer<V>
48where
49 for<'a> &'a It: IntoIterator<Item = &'a V>,
50{
51 fn transform(&self, input: &It) -> Option<Array2<F>> {
52 let a: Vec<usize> = input.into_iter().map(|v| self.map[v]).collect();
53
54 let mut ret = Array2::zeros((a.len(), self.map.len()));
55
56 for (idx, a) in a.into_iter().enumerate() {
57 ret[(idx, a)] = F::one();
58 }
59
60 Some(ret)
61 }
62}
63
64impl<V: ToPrimitive, It> Transformer<It, Array1<usize>> for OrderedEnumEmbeddingTransformer
65where
66 for<'a> &'a It: IntoIterator<Item = &'a V>,
67{
68 fn transform(&self, input: &It) -> Option<Array1<usize>> {
69 Some(Array1::from_iter(
70 input.into_iter().map(|v| ToPrimitive::to_usize(v).unwrap()),
71 ))
72 }
73}