//! reddb_server/storage/ml/classifier/features.rs
//!
//! Feature extraction for the text classifier: vocabulary / document-
//! frequency tracking, TF-IDF vectorisation and one-hot label encoding.

use std::collections::{HashMap, HashSet};
/// Token vocabulary with per-token document frequencies, used to build
/// TF-IDF feature vectors.
///
/// Tokens are case-folded with `to_ascii_lowercase` before insertion and
/// lookup, so `index_of` is ASCII case-insensitive.
#[derive(Debug, Default, Clone)]
pub struct Vocabulary {
    // Lowercased token -> dense feature index (allocated incrementally).
    index: HashMap<String, usize>,
    // document_frequency[i] = number of documents containing token i.
    document_frequency: Vec<u64>,
    // Total number of documents observed via `add_document`.
    total_documents: u64,
}

impl Vocabulary {
    /// Creates an empty vocabulary.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers one document's tokens: unknown tokens get fresh dense
    /// indices, and each *distinct* token's document frequency is bumped
    /// by one (duplicates within the same document count once).
    pub fn add_document(&mut self, tokens: &[&str]) {
        // Tokens already seen in *this* document; prevents double-counting
        // the document frequency for repeated tokens.
        let mut seen: HashSet<String> = HashSet::new();
        for t in tokens {
            let key = t.to_ascii_lowercase();
            // `insert` returns false when the key was already present.
            if !seen.insert(key.clone()) {
                continue;
            }
            // Single lookup via the entry API: allocate the next dense
            // index on first sight of the token.
            let next = self.document_frequency.len();
            let idx = *self.index.entry(key).or_insert(next);
            if idx == next {
                self.document_frequency.push(0);
            }
            self.document_frequency[idx] += 1;
        }
        self.total_documents += 1;
    }

    /// Number of distinct tokens, i.e. the feature-vector length.
    pub fn dimensions(&self) -> usize {
        self.document_frequency.len()
    }

    /// Dense index of `token` (ASCII case-insensitive), if known.
    pub fn index_of(&self, token: &str) -> Option<usize> {
        self.index.get(&token.to_ascii_lowercase()).copied()
    }
}
57
58pub fn tf_idf_vectorize(vocab: &Vocabulary, tokens: &[&str]) -> Vec<f32> {
62 if vocab.dimensions() == 0 {
63 return Vec::new();
64 }
65 let mut tf = vec![0f32; vocab.dimensions()];
66 let mut total = 0f32;
67 for t in tokens {
68 if let Some(idx) = vocab.index_of(t) {
69 tf[idx] += 1.0;
70 total += 1.0;
71 }
72 }
73 if total > 0.0 {
74 for v in tf.iter_mut() {
75 *v /= total;
76 }
77 }
78 let total_docs = (vocab.total_documents.max(1)) as f32;
79 for (i, value) in tf.iter_mut().enumerate().take(vocab.dimensions()) {
80 let df = vocab.document_frequency[i].max(1) as f32;
81 let idf = ((total_docs + 1.0) / (df + 1.0)).ln() + 1.0;
82 *value *= idf;
83 }
84 tf
85}
86
/// One-hot encodes `class` into a vector of `num_classes` floats.
///
/// The vector has exactly one `1.0` at position `class`; when `class` is
/// out of range the function returns an empty vector instead of panicking.
pub fn one_hot(class: u32, num_classes: usize) -> Vec<f32> {
    let hot = class as usize;
    if hot >= num_classes {
        return Vec::new();
    }
    (0..num_classes)
        .map(|i| if i == hot { 1.0 } else { 0.0 })
        .collect()
}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vocabulary_allocates_indices_incrementally() {
        let mut v = Vocabulary::new();
        v.add_document(&["cat", "dog"]);
        v.add_document(&["dog", "bird"]);
        // "dog" appears in both documents but is assigned only one index.
        assert_eq!(v.dimensions(), 3);
        assert!(v.index_of("cat").is_some());
        assert!(v.index_of("dog").is_some());
        assert!(v.index_of("bird").is_some());
    }

    #[test]
    fn tf_idf_vectorises_in_vocabulary_tokens() {
        let mut v = Vocabulary::new();
        v.add_document(&["cat", "cat", "the"]);
        v.add_document(&["dog", "the"]);
        v.add_document(&["cat", "dog"]);
        let vec = tf_idf_vectorize(&v, &["cat"]);
        // Output length always equals the vocabulary size.
        assert_eq!(vec.len(), v.dimensions());
        // The present token gets positive weight; absent tokens stay zero.
        assert!(vec[v.index_of("cat").unwrap()] > 0.0);
        assert_eq!(vec[v.index_of("dog").unwrap()], 0.0);
    }

    #[test]
    fn tf_idf_ignores_oov_tokens() {
        let mut v = Vocabulary::new();
        v.add_document(&["hello"]);
        // Only out-of-vocabulary tokens: every component must be zero.
        let vec = tf_idf_vectorize(&v, &["nope", "missing"]);
        for x in vec {
            assert_eq!(x, 0.0);
        }
    }

    #[test]
    fn one_hot_is_correct_length_and_position() {
        let v = one_hot(2, 4);
        assert_eq!(v, vec![0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn one_hot_rejects_out_of_range_class() {
        // Out-of-range classes yield an empty vector, not a panic.
        assert!(one_hot(5, 3).is_empty());
    }
}