// reddb_server/storage/ml/classifier/features.rs

//! Light-weight feature extractors.
//!
//! Production-grade feature engineering lives outside this module
//! (callers can plug arbitrary vectors into the classifier). The
//! helpers here exist so the classifier surface tests can exercise
//! TF-IDF-style vectors end-to-end without dragging tokeniser
//! dependencies in.

use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};

/// Shared vocabulary learnt from a corpus. Incremental — new tokens
/// allocate a fresh index on `add_document`.
#[derive(Debug, Default, Clone)]
pub struct Vocabulary {
    // Lower-cased token -> dense column index.
    index: HashMap<String, usize>,
    // Per-index count of documents containing that token.
    document_frequency: Vec<u64>,
    // Number of documents registered via `add_document`.
    total_documents: u64,
}

impl Vocabulary {
    /// Empty vocabulary — equivalent to `Vocabulary::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a document — increments document frequency once for
    /// each distinct (case-insensitive) token observed, allocating a
    /// fresh index for tokens not seen before.
    pub fn add_document(&mut self, tokens: &[&str]) {
        // Tokens already counted for *this* document, so repeated
        // occurrences bump document frequency only once.
        let mut seen: HashSet<String> = HashSet::new();
        for t in tokens {
            let key = t.to_ascii_lowercase();
            // `insert` returns false when the key was already present.
            if !seen.insert(key.clone()) {
                continue;
            }
            // Entry API: a single hash lookup covers both the hit and
            // the miss (allocate-new-index) paths.
            let idx = match self.index.entry(key) {
                Entry::Occupied(o) => *o.get(),
                Entry::Vacant(v) => {
                    let i = self.document_frequency.len();
                    self.document_frequency.push(0);
                    *v.insert(i)
                }
            };
            self.document_frequency[idx] += 1;
        }
        self.total_documents += 1;
    }

    /// Number of distinct tokens seen, i.e. the feature-vector length.
    pub fn dimensions(&self) -> usize {
        self.document_frequency.len()
    }

    /// Dense index of `token` (case-insensitive), if known.
    pub fn index_of(&self, token: &str) -> Option<usize> {
        self.index.get(&token.to_ascii_lowercase()).copied()
    }
}
58/// Build a TF-IDF vector of length `vocab.dimensions()` from a single
59/// document's tokens. Tokens not in the vocabulary are ignored
60/// (caller should have called `vocab.add_document` during training).
61pub fn tf_idf_vectorize(vocab: &Vocabulary, tokens: &[&str]) -> Vec<f32> {
62    if vocab.dimensions() == 0 {
63        return Vec::new();
64    }
65    let mut tf = vec![0f32; vocab.dimensions()];
66    let mut total = 0f32;
67    for t in tokens {
68        if let Some(idx) = vocab.index_of(t) {
69            tf[idx] += 1.0;
70            total += 1.0;
71        }
72    }
73    if total > 0.0 {
74        for v in tf.iter_mut() {
75            *v /= total;
76        }
77    }
78    let total_docs = (vocab.total_documents.max(1)) as f32;
79    for (i, value) in tf.iter_mut().enumerate().take(vocab.dimensions()) {
80        let df = vocab.document_frequency[i].max(1) as f32;
81        let idf = ((total_docs + 1.0) / (df + 1.0)).ln() + 1.0;
82        *value *= idf;
83    }
84    tf
85}
/// Build a one-hot vector of length `num_classes` with 1.0 at
/// `class` and 0.0 elsewhere. Returns an empty vector if `class`
/// is out of range (including when `num_classes == 0`).
pub fn one_hot(class: u32, num_classes: usize) -> Vec<f32> {
    let c = class as usize;
    if c >= num_classes {
        return Vec::new();
    }
    let mut v = vec![0f32; num_classes];
    v[c] = 1.0;
    v
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vocabulary_allocates_indices_incrementally() {
        let mut v = Vocabulary::new();
        v.add_document(&["cat", "dog"]);
        v.add_document(&["dog", "bird"]);
        assert_eq!(v.dimensions(), 3);
        assert!(v.index_of("cat").is_some());
        assert!(v.index_of("dog").is_some());
        assert!(v.index_of("bird").is_some());
        // DF: cat 1, dog 2, bird 1 — asserted rather than just noted.
        assert_eq!(v.document_frequency[v.index_of("cat").unwrap()], 1);
        assert_eq!(v.document_frequency[v.index_of("dog").unwrap()], 2);
        assert_eq!(v.document_frequency[v.index_of("bird").unwrap()], 1);
    }

    #[test]
    fn tf_idf_vectorises_in_vocabulary_tokens() {
        let mut v = Vocabulary::new();
        v.add_document(&["cat", "cat", "the"]);
        v.add_document(&["dog", "the"]);
        v.add_document(&["cat", "dog"]);
        let vec = tf_idf_vectorize(&v, &["cat"]);
        assert_eq!(vec.len(), v.dimensions());
        assert!(vec[v.index_of("cat").unwrap()] > 0.0);
        assert_eq!(vec[v.index_of("dog").unwrap()], 0.0);
    }

    #[test]
    fn tf_idf_ignores_oov_tokens() {
        let mut v = Vocabulary::new();
        v.add_document(&["hello"]);
        let vec = tf_idf_vectorize(&v, &["nope", "missing"]);
        for x in vec {
            assert_eq!(x, 0.0);
        }
    }

    #[test]
    fn one_hot_is_correct_length_and_position() {
        let v = one_hot(2, 4);
        assert_eq!(v, vec![0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn one_hot_rejects_out_of_range_class() {
        assert!(one_hot(5, 3).is_empty());
    }
}