ghostflow_ml/nlp.rs

//! Natural Language Processing - Tokenizers, Embeddings, Text Processing
//!
//! This module provides NLP utilities for text preprocessing and representation.
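//!
//! Components in this file:
//! - `WordTokenizer` / `CharTokenizer`: word- and character-level integer ids
//! - `BPETokenizer`: byte pair encoding subword merges
//! - `TfidfVectorizer`: bag-of-words TF-IDF vectors
//! - `Word2Vec`: a simplified skip-gram style embedding trainer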

use std::collections::HashMap;

/// Word-level tokenizer
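///
/// Splits on whitespace after optional lowercasing and punctuation removal;
/// `fit` assigns each distinct token an integer id, and tokens unseen at fit
/// time are silently dropped by `texts_to_sequences`. A minimal usage sketch
/// (the module path is an assumption; adjust to your crate layout):
///
/// ```ignore
/// use ghostflow_ml::nlp::WordTokenizer;
///
/// let texts = vec!["Hello, world!".to_string(), "hello Rust".to_string()];
/// let mut tok = WordTokenizer::new().lowercase(true).remove_punctuation(true);
/// tok.fit(&texts);
/// assert_eq!(tok.vocab_size(), 3); // {hello, world, rust}
/// let seqs = tok.texts_to_sequences(&texts);
/// assert_eq!(seqs[0].len(), 2);
/// ```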
pub struct WordTokenizer {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    pub min_word_length: usize,
    vocab_: Option<HashMap<String, usize>>,
    index_to_word_: Option<Vec<String>>,
}

impl WordTokenizer {
    pub fn new() -> Self {
        WordTokenizer {
            lowercase: true,
            remove_punctuation: true,
            min_word_length: 1,
            vocab_: None,
            index_to_word_: None,
        }
    }

    pub fn lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    pub fn remove_punctuation(mut self, remove: bool) -> Self {
        self.remove_punctuation = remove;
        self
    }

    pub fn min_word_length(mut self, length: usize) -> Self {
        self.min_word_length = length;
        self
    }

    fn preprocess(&self, text: &str) -> String {
        let mut processed = text.to_string();

        if self.lowercase {
            processed = processed.to_lowercase();
        }

        if self.remove_punctuation {
            processed = processed.chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect();
        }

        processed
    }

    pub fn tokenize(&self, text: &str) -> Vec<String> {
        let processed = self.preprocess(text);

        processed.split_whitespace()
            .filter(|word| word.len() >= self.min_word_length)
            .map(|s| s.to_string())
            .collect()
    }

    pub fn fit(&mut self, texts: &[String]) {
        let mut vocab = HashMap::new();
        let mut index = 0;

        for text in texts {
            let tokens = self.tokenize(text);
            for token in tokens {
                if !vocab.contains_key(&token) {
                    vocab.insert(token.clone(), index);
                    index += 1;
                }
            }
        }

        let mut index_to_word = vec![String::new(); vocab.len()];
        for (word, &idx) in &vocab {
            index_to_word[idx] = word.clone();
        }

        self.vocab_ = Some(vocab);
        self.index_to_word_ = Some(index_to_word);
    }

    pub fn texts_to_sequences(&self, texts: &[String]) -> Vec<Vec<usize>> {
        let vocab = self.vocab_.as_ref().expect("Tokenizer not fitted");

        texts.iter()
            .map(|text| {
                self.tokenize(text)
                    .iter()
                    .filter_map(|token| vocab.get(token).copied())
                    .collect()
            })
            .collect()
    }

    pub fn sequences_to_texts(&self, sequences: &[Vec<usize>]) -> Vec<String> {
        let index_to_word = self.index_to_word_.as_ref().expect("Tokenizer not fitted");

        sequences.iter()
            .map(|seq| {
                seq.iter()
                    .filter_map(|&idx| {
                        if idx < index_to_word.len() {
                            Some(index_to_word[idx].clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join(" ")
            })
            .collect()
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }

    pub fn vocab(&self) -> Option<&HashMap<String, usize>> {
        self.vocab_.as_ref()
    }
}

impl Default for WordTokenizer {
    fn default() -> Self { Self::new() }
}

/// Character-level tokenizer
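///
/// Maps each distinct character to an integer id. Usage sketch (module path
/// assumed, as above):
///
/// ```ignore
/// use ghostflow_ml::nlp::CharTokenizer;
///
/// let texts = vec!["abc".to_string(), "cba".to_string()];
/// let mut tok = CharTokenizer::new();
/// tok.fit(&texts);
/// assert_eq!(tok.vocab_size(), 3); // {'a', 'b', 'c'}
/// let seqs = tok.texts_to_sequences(&texts);
/// assert_eq!(seqs[0].len(), 3);
/// ```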
pub struct CharTokenizer {
    pub lowercase: bool,
    char_to_idx_: Option<HashMap<char, usize>>,
    idx_to_char_: Option<Vec<char>>,
}

impl CharTokenizer {
    pub fn new() -> Self {
        CharTokenizer {
            lowercase: true,
            char_to_idx_: None,
            idx_to_char_: None,
        }
    }

    pub fn lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    pub fn fit(&mut self, texts: &[String]) {
        let mut chars = std::collections::HashSet::new();

        for text in texts {
            let processed = if self.lowercase {
                text.to_lowercase()
            } else {
                text.clone()
            };

            for c in processed.chars() {
                chars.insert(c);
            }
        }

        // Sort so character ids are deterministic across runs
        // (HashSet iteration order is not).
        let mut sorted: Vec<char> = chars.into_iter().collect();
        sorted.sort_unstable();

        let mut char_to_idx = HashMap::new();
        let mut idx_to_char = Vec::new();

        for (i, c) in sorted.into_iter().enumerate() {
            char_to_idx.insert(c, i);
            idx_to_char.push(c);
        }

        self.char_to_idx_ = Some(char_to_idx);
        self.idx_to_char_ = Some(idx_to_char);
    }

    pub fn texts_to_sequences(&self, texts: &[String]) -> Vec<Vec<usize>> {
        let char_to_idx = self.char_to_idx_.as_ref().expect("Tokenizer not fitted");

        texts.iter()
            .map(|text| {
                let processed = if self.lowercase {
                    text.to_lowercase()
                } else {
                    text.clone()
                };

                processed.chars()
                    .filter_map(|c| char_to_idx.get(&c).copied())
                    .collect()
            })
            .collect()
    }

    pub fn sequences_to_texts(&self, sequences: &[Vec<usize>]) -> Vec<String> {
        let idx_to_char = self.idx_to_char_.as_ref().expect("Tokenizer not fitted");

        sequences.iter()
            .map(|seq| {
                seq.iter()
                    .filter_map(|&idx| {
                        if idx < idx_to_char.len() {
                            Some(idx_to_char[idx])
                        } else {
                            None
                        }
                    })
                    .collect()
            })
            .collect()
    }

    pub fn vocab_size(&self) -> usize {
        self.char_to_idx_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

impl Default for CharTokenizer {
    fn default() -> Self { Self::new() }
}

/// BPE (Byte Pair Encoding) Tokenizer
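///
/// `fit` repeatedly merges the most frequent adjacent symbol pair, so common
/// substrings become single tokens; `vocab_size` is the merge budget, not an
/// exact final vocabulary size. Usage sketch (module path assumed):
///
/// ```ignore
/// use ghostflow_ml::nlp::BPETokenizer;
///
/// let texts = vec!["low lower lowest".to_string()];
/// let mut bpe = BPETokenizer::new(10);
/// bpe.fit(&texts);
/// // Shared prefixes such as "low" tend to merge into single tokens.
/// let ids = bpe.encode("low lowest");
/// assert!(!ids.is_empty());
/// ```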
pub struct BPETokenizer {
    pub vocab_size: usize,
    merges_: Vec<(String, String)>,
    vocab_: HashMap<String, usize>,
}

impl BPETokenizer {
    pub fn new(vocab_size: usize) -> Self {
        BPETokenizer {
            vocab_size,
            merges_: Vec::new(),
            vocab_: HashMap::new(),
        }
    }

    fn get_pairs(word: &[String]) -> Vec<(String, String)> {
        let mut pairs = Vec::new();
        for i in 0..word.len().saturating_sub(1) {
            pairs.push((word[i].clone(), word[i + 1].clone()));
        }
        pairs
    }

    pub fn fit(&mut self, texts: &[String]) {
        // Initialize each word as a sequence of single-character symbols,
        // keyed by its corpus frequency
        let mut vocab: HashMap<Vec<String>, usize> = HashMap::new();

        for text in texts {
            for word in text.split_whitespace() {
                let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
                *vocab.entry(chars).or_insert(0) += 1;
            }
        }

        // Learn up to `vocab_size` merges, stopping early once no
        // adjacent pairs remain
        for _ in 0..self.vocab_size {
            let mut pair_freqs: HashMap<(String, String), usize> = HashMap::new();

            for (word, &freq) in &vocab {
                let pairs = Self::get_pairs(word);
                for pair in pairs {
                    *pair_freqs.entry(pair).or_insert(0) += freq;
                }
            }

            if pair_freqs.is_empty() {
                break;
            }

            let best_pair = pair_freqs.iter()
                .max_by_key(|(_, &freq)| freq)
                .map(|(pair, _)| pair.clone());

            if let Some((first, second)) = best_pair {
                self.merges_.push((first.clone(), second.clone()));

                // Apply the new merge to every word
                let mut new_vocab = HashMap::new();
                for (word, freq) in vocab {
                    let new_word = self.merge_pair(&word, &first, &second);
                    new_vocab.insert(new_word, freq);
                }
                vocab = new_vocab;
            } else {
                break;
            }
        }

        // Build the final token vocabulary, sorted so ids are deterministic
        // across runs (HashMap iteration order is not)
        let mut tokens: Vec<String> = vocab.into_keys().flatten().collect();
        tokens.sort_unstable();
        tokens.dedup();
        for (idx, token) in tokens.into_iter().enumerate() {
            self.vocab_.insert(token, idx);
        }
    }

    fn merge_pair(&self, word: &[String], first: &str, second: &str) -> Vec<String> {
        let mut result = Vec::new();
        let mut i = 0;

        while i < word.len() {
            if i < word.len() - 1 && word[i] == first && word[i + 1] == second {
                result.push(format!("{}{}", first, second));
                i += 2;
            } else {
                result.push(word[i].clone());
                i += 1;
            }
        }

        result
    }

    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let mut chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();

            for (first, second) in &self.merges_ {
                chars = self.merge_pair(&chars, first, second);
            }

            for token in chars {
                if let Some(&idx) = self.vocab_.get(&token) {
                    tokens.push(idx);
                }
            }
        }

        tokens
    }

    pub fn vocab_size_actual(&self) -> usize {
        self.vocab_.len()
    }
}

/// TF-IDF Vectorizer
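///
/// Each document becomes a dense vector: term counts are normalized to sum
/// to 1, scaled by `idf(t) = ln((n_docs + 1) / (df(t) + 1)) + 1`, then
/// L2-normalized. Usage sketch (module path assumed):
///
/// ```ignore
/// use ghostflow_ml::nlp::TfidfVectorizer;
///
/// let texts = vec![
///     "the cat sat".to_string(),
///     "the dog ran".to_string(),
/// ];
/// let mut vectorizer = TfidfVectorizer::new().max_features(100);
/// let matrix = vectorizer.fit_transform(&texts);
/// assert_eq!(matrix.len(), 2);
/// assert_eq!(matrix[0].len(), vectorizer.vocab_size());
/// ```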
pub struct TfidfVectorizer {
    pub max_features: Option<usize>,
    pub min_df: usize,
    pub max_df: f32,
    tokenizer: WordTokenizer,
    idf_: Option<Vec<f32>>,
    vocab_: Option<HashMap<String, usize>>,
}

impl TfidfVectorizer {
    pub fn new() -> Self {
        TfidfVectorizer {
            max_features: None,
            min_df: 1,
            max_df: 1.0,
            tokenizer: WordTokenizer::new(),
            idf_: None,
            vocab_: None,
        }
    }

    pub fn max_features(mut self, max: usize) -> Self {
        self.max_features = Some(max);
        self
    }

    pub fn min_df(mut self, min: usize) -> Self {
        self.min_df = min;
        self
    }

    pub fn fit(&mut self, texts: &[String]) {
        // Count, for each term, the number of documents containing it
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        let n_docs = texts.len();

        for text in texts {
            let tokens = self.tokenizer.tokenize(text);
            let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();

            for token in unique_tokens {
                *doc_freq.entry(token).or_insert(0) += 1;
            }
        }

        // Filter by document frequency
        let max_df_count = (self.max_df * n_docs as f32) as usize;
        let mut filtered_vocab: Vec<(String, usize)> = doc_freq.into_iter()
            .filter(|(_, freq)| *freq >= self.min_df && *freq <= max_df_count)
            .collect();

        // Limit vocabulary size, keeping the most frequent terms
        if let Some(max_feat) = self.max_features {
            filtered_vocab.sort_by(|a, b| b.1.cmp(&a.1));
            filtered_vocab.truncate(max_feat);
        }

        // Build vocabulary and IDF
        let mut vocab = HashMap::new();
        let mut idf = Vec::new();

        for (i, (term, df)) in filtered_vocab.iter().enumerate() {
            vocab.insert(term.clone(), i);
            // IDF = ln((n_docs + 1) / (df + 1)) + 1
            let idf_value = ((n_docs + 1) as f32 / (*df + 1) as f32).ln() + 1.0;
            idf.push(idf_value);
        }

        self.vocab_ = Some(vocab);
        self.idf_ = Some(idf);
    }

    pub fn transform(&self, texts: &[String]) -> Vec<Vec<f32>> {
        let vocab = self.vocab_.as_ref().expect("Vectorizer not fitted");
        let idf = self.idf_.as_ref().expect("Vectorizer not fitted");
        let vocab_size = vocab.len();

        texts.iter()
            .map(|text| {
                let tokens = self.tokenizer.tokenize(text);
                let mut tf = vec![0.0f32; vocab_size];

                // Count term frequencies
                for token in &tokens {
                    if let Some(&idx) = vocab.get(token) {
                        tf[idx] += 1.0;
                    }
                }

                // Normalize TF
                let total: f32 = tf.iter().sum();
                if total > 0.0 {
                    for t in &mut tf {
                        *t /= total;
                    }
                }

                // Apply IDF
                for (i, t) in tf.iter_mut().enumerate() {
                    *t *= idf[i];
                }

                // L2 normalization
                let norm: f32 = tf.iter().map(|&x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for t in &mut tf {
                        *t /= norm;
                    }
                }

                tf
            })
            .collect()
    }

    pub fn fit_transform(&mut self, texts: &[String]) -> Vec<Vec<f32>> {
        self.fit(texts);
        self.transform(texts)
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

impl Default for TfidfVectorizer {
    fn default() -> Self { Self::new() }
}

/// Word2Vec Skip-gram model (simplified)
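///
/// Note: training below is a rough sketch that pulls co-occurring word
/// vectors toward each other; it is not the full skip-gram objective with
/// softmax or negative sampling. Usage sketch (module path assumed;
/// requires the `rand` crate):
///
/// ```ignore
/// use ghostflow_ml::nlp::Word2Vec;
///
/// let texts = vec!["the quick brown fox jumps over the lazy dog".to_string()];
/// let mut w2v = Word2Vec::new(16).window_size(2).min_count(1);
/// w2v.fit(&texts);
/// if let Some(sim) = w2v.similarity("quick", "brown") {
///     assert!(sim.is_finite()); // cosine similarity, roughly in [-1, 1]
/// }
/// ```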
pub struct Word2Vec {
    pub embedding_dim: usize,
    pub window_size: usize,
    pub min_count: usize,
    pub learning_rate: f32,
    pub epochs: usize,
    embeddings_: Option<Vec<Vec<f32>>>,
    vocab_: Option<HashMap<String, usize>>,
}

impl Word2Vec {
    pub fn new(embedding_dim: usize) -> Self {
        Word2Vec {
            embedding_dim,
            window_size: 5,
            min_count: 5,
            learning_rate: 0.025,
            epochs: 5,
            embeddings_: None,
            vocab_: None,
        }
    }

    pub fn window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    pub fn min_count(mut self, count: usize) -> Self {
        self.min_count = count;
        self
    }

    pub fn fit(&mut self, texts: &[String]) {
        // Build vocabulary
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        let tokenizer = WordTokenizer::new();

        for text in texts {
            let tokens = tokenizer.tokenize(text);
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
            }
        }

        // Filter by min_count
        let mut vocab = HashMap::new();
        let mut idx = 0;
        for (word, count) in word_counts {
            if count >= self.min_count {
                vocab.insert(word, idx);
                idx += 1;
            }
        }

        let vocab_size = vocab.len();

        // Initialize embeddings randomly
        use rand::prelude::*;
        let mut rng = thread_rng();
        let mut embeddings = vec![vec![0.0f32; self.embedding_dim]; vocab_size];

        for emb in &mut embeddings {
            for val in emb {
                *val = (rng.gen::<f32>() - 0.5) / self.embedding_dim as f32;
            }
        }

        // Training (simplified skip-gram)
        for _epoch in 0..self.epochs {
            for text in texts {
                let tokens = tokenizer.tokenize(text);
                let indices: Vec<usize> = tokens.iter()
                    .filter_map(|t| vocab.get(t).copied())
                    .collect();

                for (i, &center_idx) in indices.iter().enumerate() {
                    let start = i.saturating_sub(self.window_size);
                    let end = (i + self.window_size + 1).min(indices.len());

                    for j in start..end {
                        if i == j { continue; }
                        let context_idx = indices[j];

                        // Naive update: pull the center vector toward each
                        // context vector. This stands in for the true
                        // skip-gram gradient, which needs softmax or
                        // negative sampling.
                        for d in 0..self.embedding_dim {
                            let grad = embeddings[context_idx][d] - embeddings[center_idx][d];
                            embeddings[center_idx][d] += self.learning_rate * grad * 0.01;
                        }
                    }
                }
            }
        }

        self.embeddings_ = Some(embeddings);
        self.vocab_ = Some(vocab);
    }

    pub fn get_vector(&self, word: &str) -> Option<&[f32]> {
        let vocab = self.vocab_.as_ref()?;
        let embeddings = self.embeddings_.as_ref()?;
        let idx = vocab.get(word)?;
        Some(&embeddings[*idx])
    }

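    /// Cosine similarity between two word vectors; returns `None` if either
    /// word is out of vocabulary or the model is not fitted.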
    pub fn similarity(&self, word1: &str, word2: &str) -> Option<f32> {
        let vec1 = self.get_vector(word1)?;
        let vec2 = self.get_vector(word2)?;

        let dot: f32 = vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum();
        let norm1: f32 = vec1.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = vec2.iter().map(|x| x * x).sum::<f32>().sqrt();

        Some(dot / (norm1 * norm2).max(1e-10))
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_word_tokenizer() {
        let mut tokenizer = WordTokenizer::new();
        let texts = vec![
            "Hello world".to_string(),
            "Hello Rust".to_string(),
        ];

        tokenizer.fit(&texts);
        assert_eq!(tokenizer.vocab_size(), 3);

        let sequences = tokenizer.texts_to_sequences(&texts);
        assert_eq!(sequences.len(), 2);
        assert_eq!(sequences[0].len(), 2);
    }

    #[test]
    fn test_char_tokenizer() {
        let mut tokenizer = CharTokenizer::new();
        let texts = vec!["abc".to_string(), "def".to_string()];

        tokenizer.fit(&texts);
        assert_eq!(tokenizer.vocab_size(), 6);

        let sequences = tokenizer.texts_to_sequences(&texts);
        assert_eq!(sequences[0].len(), 3);
    }

    #[test]
    fn test_tfidf() {
        let texts = vec![
            "the cat sat on the mat".to_string(),
            "the dog sat on the log".to_string(),
        ];

        let mut vectorizer = TfidfVectorizer::new();
        let vectors = vectorizer.fit_transform(&texts);

        assert_eq!(vectors.len(), 2);
        assert!(!vectors[0].is_empty());
    }

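    // A smoke test for the BPE tokenizer, added as a sketch: it only checks
    // that fitting learns a vocabulary and that encoding emits ids, not any
    // particular segmentation.
    #[test]
    fn test_bpe_tokenizer() {
        let texts = vec![
            "low lower lowest".to_string(),
            "low low lower".to_string(),
        ];

        let mut bpe = BPETokenizer::new(10);
        bpe.fit(&texts);

        assert!(bpe.vocab_size_actual() > 0);
        let ids = bpe.encode("low lowest");
        assert!(!ids.is_empty());
    }
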
    #[test]
    fn test_word2vec() {
        let texts = vec![
            "the quick brown fox jumps".to_string(),
            "the lazy dog sleeps".to_string(),
        ];

        let mut w2v = Word2Vec::new(10).min_count(1);
        w2v.epochs = 2;
        w2v.fit(&texts);

        assert!(w2v.vocab_size() > 0);
        assert!(w2v.get_vector("the").is_some());
    }
}