sklears_preprocessing/
text.rs

//! Text preprocessing utilities for sklears
//!
//! This module provides text preprocessing capabilities including:
//! - Text tokenization and normalization
//! - TF-IDF vectorization
//! - Word and character n-gram generation
//! - Text similarity measures (cosine, Jaccard, Dice)
//! - Bag-of-words sentence embeddings
9
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
11use sklears_core::{
12    error::{Result, SklearsError},
13    traits::{Fit, Transform},
14};
15use std::collections::{HashMap, HashSet};
16
/// How raw text is normalized before it is tokenized.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NormalizationStrategy {
    /// Leave the text exactly as given.
    None,
    /// Lowercase the text.
    Lowercase,
    /// Lowercase the text and replace every character that is neither
    /// alphanumeric nor whitespace with a space.
    LowercaseNoPunct,
    /// Like `LowercaseNoPunct`, then collapse runs of whitespace to a single
    /// space and trim both ends.
    Full,
}
29
/// Rule used to split normalized text into tokens.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TokenizationStrategy {
    /// Split on Unicode whitespace only.
    Whitespace,
    /// Split on whitespace and on ASCII punctuation characters.
    WhitespacePunct,
    /// Split on every non-alphanumeric character, keeping alphanumeric runs.
    Word,
}
40
/// Unit from which n-grams are assembled.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NgramType {
    /// Consecutive characters of the normalized text.
    Char,
    /// Consecutive tokens, joined with a single space.
    Word,
}
49
50/// Configuration for text tokenizer
51#[derive(Debug, Clone)]
52pub struct TextTokenizerConfig {
53    pub normalization: NormalizationStrategy,
54    pub tokenization: TokenizationStrategy,
55    pub min_token_length: usize,
56    pub max_token_length: usize,
57    pub stop_words: Option<HashSet<String>>,
58}
59
60impl Default for TextTokenizerConfig {
61    fn default() -> Self {
62        Self {
63            normalization: NormalizationStrategy::Lowercase,
64            tokenization: TokenizationStrategy::Word,
65            min_token_length: 1,
66            max_token_length: 50,
67            stop_words: None,
68        }
69    }
70}
71
72/// Text tokenizer for preprocessing text data
73#[derive(Debug, Clone)]
74pub struct TextTokenizer {
75    config: TextTokenizerConfig,
76}
77
78impl Default for TextTokenizer {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl TextTokenizer {
85    /// Create a new text tokenizer with default configuration
86    pub fn new() -> Self {
87        Self {
88            config: TextTokenizerConfig::default(),
89        }
90    }
91
92    /// Create a new text tokenizer with custom configuration
93    pub fn with_config(config: TextTokenizerConfig) -> Self {
94        Self { config }
95    }
96
97    /// Normalize text according to the configuration
98    pub fn normalize(&self, text: &str) -> String {
99        match self.config.normalization {
100            NormalizationStrategy::None => text.to_string(),
101            NormalizationStrategy::Lowercase => text.to_lowercase(),
102            NormalizationStrategy::LowercaseNoPunct => text
103                .to_lowercase()
104                .chars()
105                .map(|c| {
106                    if c.is_alphanumeric() || c.is_whitespace() {
107                        c
108                    } else {
109                        ' '
110                    }
111                })
112                .collect(),
113            NormalizationStrategy::Full => text
114                .to_lowercase()
115                .chars()
116                .map(|c| {
117                    if c.is_alphanumeric() || c.is_whitespace() {
118                        c
119                    } else {
120                        ' '
121                    }
122                })
123                .collect::<String>()
124                .split_whitespace()
125                .collect::<Vec<_>>()
126                .join(" "),
127        }
128    }
129
130    /// Tokenize text into tokens
131    pub fn tokenize(&self, text: &str) -> Vec<String> {
132        let normalized = self.normalize(text);
133
134        let tokens: Vec<String> = match self.config.tokenization {
135            TokenizationStrategy::Whitespace => normalized
136                .split_whitespace()
137                .map(|s| s.to_string())
138                .collect(),
139            TokenizationStrategy::WhitespacePunct => normalized
140                .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
141                .filter(|s| !s.is_empty())
142                .map(|s| s.to_string())
143                .collect(),
144            TokenizationStrategy::Word => normalized
145                .chars()
146                .collect::<String>()
147                .split(|c: char| !c.is_alphanumeric())
148                .filter(|s| !s.is_empty())
149                .map(|s| s.to_string())
150                .collect(),
151        };
152
153        // Filter by length
154        let mut filtered_tokens: Vec<String> = tokens
155            .into_iter()
156            .filter(|token| {
157                token.len() >= self.config.min_token_length
158                    && token.len() <= self.config.max_token_length
159            })
160            .collect();
161
162        // Remove stop words if configured
163        if let Some(ref stop_words) = self.config.stop_words {
164            filtered_tokens.retain(|token| !stop_words.contains(token));
165        }
166
167        filtered_tokens
168    }
169}
170
171/// Configuration for TF-IDF vectorizer
172#[derive(Debug, Clone)]
173pub struct TfIdfVectorizerConfig {
174    pub tokenizer_config: TextTokenizerConfig,
175    pub min_df: f64,
176    pub max_df: f64,
177    pub max_features: Option<usize>,
178    pub use_idf: bool,
179    pub smooth_idf: bool,
180    pub sublinear_tf: bool,
181}
182
183impl Default for TfIdfVectorizerConfig {
184    fn default() -> Self {
185        Self {
186            tokenizer_config: TextTokenizerConfig::default(),
187            min_df: 1.0,
188            max_df: 1.0,
189            max_features: None,
190            use_idf: true,
191            smooth_idf: true,
192            sublinear_tf: false,
193        }
194    }
195}
196
197/// TF-IDF vectorizer for converting text to numerical features
198#[derive(Debug, Clone)]
199pub struct TfIdfVectorizer {
200    config: TfIdfVectorizerConfig,
201    tokenizer: TextTokenizer,
202    vocabulary: HashMap<String, usize>,
203    idf_values: Array1<f64>,
204    fitted: bool,
205}
206
207impl Default for TfIdfVectorizer {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl TfIdfVectorizer {
214    /// Create a new TF-IDF vectorizer with default configuration
215    pub fn new() -> Self {
216        let config = TfIdfVectorizerConfig::default();
217        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
218
219        Self {
220            config,
221            tokenizer,
222            vocabulary: HashMap::new(),
223            idf_values: Array1::zeros(0),
224            fitted: false,
225        }
226    }
227
228    /// Create a new TF-IDF vectorizer with custom configuration
229    pub fn with_config(config: TfIdfVectorizerConfig) -> Self {
230        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
231
232        Self {
233            config,
234            tokenizer,
235            vocabulary: HashMap::new(),
236            idf_values: Array1::zeros(0),
237            fitted: false,
238        }
239    }
240
241    /// Build vocabulary from documents
242    fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
243        let mut term_doc_counts: HashMap<String, usize> = HashMap::new();
244        let n_docs = documents.len() as f64;
245
246        // Count document frequencies
247        for document in documents {
248            let tokens = self.tokenizer.tokenize(document);
249            let unique_tokens: HashSet<String> = tokens.into_iter().collect();
250
251            for token in unique_tokens {
252                *term_doc_counts.entry(token).or_insert(0) += 1;
253            }
254        }
255
256        // Filter by document frequency
257        let min_df = if self.config.min_df < 1.0 {
258            (self.config.min_df * n_docs).ceil() as usize
259        } else {
260            self.config.min_df as usize
261        };
262
263        let max_df = if self.config.max_df < 1.0 {
264            (self.config.max_df * n_docs).floor() as usize
265        } else {
266            self.config.max_df as usize
267        };
268
269        let mut filtered_terms: Vec<(String, usize)> = term_doc_counts
270            .into_iter()
271            .filter(|(_, count)| *count >= min_df && *count <= max_df)
272            .collect();
273
274        // Sort by document frequency (descending) for consistent ordering
275        filtered_terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
276
277        // Limit vocabulary size if specified
278        if let Some(max_features) = self.config.max_features {
279            filtered_terms.truncate(max_features);
280        }
281
282        // Build vocabulary mapping
283        for (idx, (term, _doc_freq)) in filtered_terms.iter().enumerate() {
284            self.vocabulary.insert(term.clone(), idx);
285        }
286
287        // Compute IDF values
288        let vocab_size = self.vocabulary.len();
289        let mut idf_values = Array1::zeros(vocab_size);
290
291        if self.config.use_idf {
292            for &idx in self.vocabulary.values() {
293                let doc_freq = filtered_terms[idx].1 as f64;
294                let idf = if self.config.smooth_idf {
295                    ((n_docs + 1.0) / (doc_freq + 1.0)).ln() + 1.0
296                } else {
297                    (n_docs / doc_freq).ln() + 1.0
298                };
299                idf_values[idx] = idf;
300            }
301        } else {
302            idf_values.fill(1.0);
303        }
304
305        self.idf_values = idf_values;
306        Ok(())
307    }
308
309    /// Transform documents to TF-IDF matrix
310    fn transform_documents(&self, documents: &[String]) -> Result<Array2<f64>> {
311        if !self.fitted {
312            return Err(SklearsError::NotFitted {
313                operation: "TfIdfVectorizer not fitted".to_string(),
314            });
315        }
316
317        let n_docs = documents.len();
318        let vocab_size = self.vocabulary.len();
319        let mut tfidf_matrix = Array2::zeros((n_docs, vocab_size));
320
321        for (doc_idx, document) in documents.iter().enumerate() {
322            let tokens = self.tokenizer.tokenize(document);
323            let mut term_counts: HashMap<usize, f64> = HashMap::new();
324
325            // Count term frequencies
326            for token in &tokens {
327                if let Some(&vocab_idx) = self.vocabulary.get(token) {
328                    *term_counts.entry(vocab_idx).or_insert(0.0) += 1.0;
329                }
330            }
331
332            // Compute TF-IDF values
333            let total_terms = tokens.len() as f64;
334            for (vocab_idx, count) in term_counts {
335                let tf = if self.config.sublinear_tf {
336                    1.0 + count.ln()
337                } else {
338                    count / total_terms
339                };
340
341                let tfidf = tf * self.idf_values[vocab_idx];
342                tfidf_matrix[[doc_idx, vocab_idx]] = tfidf;
343            }
344        }
345
346        Ok(tfidf_matrix)
347    }
348
349    /// Get the vocabulary mapping
350    pub fn get_vocabulary(&self) -> &HashMap<String, usize> {
351        &self.vocabulary
352    }
353
354    /// Get the IDF values
355    pub fn get_idf_values(&self) -> ArrayView1<'_, f64> {
356        self.idf_values.view()
357    }
358}
359
360impl Fit<Vec<String>, ()> for TfIdfVectorizer {
361    type Fitted = TfIdfVectorizer;
362
363    fn fit(mut self, x: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
364        self.build_vocabulary(x)?;
365        self.fitted = true;
366        Ok(self)
367    }
368}
369
370impl Transform<Vec<String>, Array2<f64>> for TfIdfVectorizer {
371    fn transform(&self, x: &Vec<String>) -> Result<Array2<f64>> {
372        self.transform_documents(x)
373    }
374}
375
376/// Configuration for N-gram generator
377#[derive(Debug, Clone)]
378pub struct NgramGeneratorConfig {
379    pub tokenizer_config: TextTokenizerConfig,
380    pub ngram_type: NgramType,
381    pub n_min: usize,
382    pub n_max: usize,
383}
384
385impl Default for NgramGeneratorConfig {
386    fn default() -> Self {
387        Self {
388            tokenizer_config: TextTokenizerConfig::default(),
389            ngram_type: NgramType::Word,
390            n_min: 1,
391            n_max: 2,
392        }
393    }
394}
395
396/// N-gram generator for creating n-gram features from text
397#[derive(Debug, Clone)]
398pub struct NgramGenerator {
399    config: NgramGeneratorConfig,
400    tokenizer: TextTokenizer,
401}
402
403impl Default for NgramGenerator {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409impl NgramGenerator {
410    /// Create a new N-gram generator with default configuration
411    pub fn new() -> Self {
412        let config = NgramGeneratorConfig::default();
413        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
414
415        Self { config, tokenizer }
416    }
417
418    /// Create a new N-gram generator with custom configuration
419    pub fn with_config(config: NgramGeneratorConfig) -> Self {
420        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
421        Self { config, tokenizer }
422    }
423
424    /// Generate n-grams from text
425    pub fn generate_ngrams(&self, text: &str) -> Vec<String> {
426        match self.config.ngram_type {
427            NgramType::Word => self.generate_word_ngrams(text),
428            NgramType::Char => self.generate_char_ngrams(text),
429        }
430    }
431
432    /// Generate word n-grams
433    fn generate_word_ngrams(&self, text: &str) -> Vec<String> {
434        let tokens = self.tokenizer.tokenize(text);
435        let mut ngrams = Vec::new();
436
437        for n in self.config.n_min..=self.config.n_max {
438            if n > tokens.len() {
439                break;
440            }
441
442            for window in tokens.windows(n) {
443                let ngram = window.join(" ");
444                ngrams.push(ngram);
445            }
446        }
447
448        ngrams
449    }
450
451    /// Generate character n-grams
452    fn generate_char_ngrams(&self, text: &str) -> Vec<String> {
453        let normalized = self.tokenizer.normalize(text);
454        let chars: Vec<char> = normalized.chars().collect();
455        let mut ngrams = Vec::new();
456
457        for n in self.config.n_min..=self.config.n_max {
458            if n > chars.len() {
459                break;
460            }
461
462            for window in chars.windows(n) {
463                let ngram: String = window.iter().collect();
464                ngrams.push(ngram);
465            }
466        }
467
468        ngrams
469    }
470}
471
472/// Configuration for text similarity calculator
473#[derive(Debug, Clone)]
474pub struct TextSimilarityConfig {
475    pub tokenizer_config: TextTokenizerConfig,
476    pub similarity_metric: SimilarityMetric,
477}
478
/// Metric used to compare two tokenized texts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the two term-count vectors.
    Cosine,
    /// `|A ∩ B| / |A ∪ B|` over the two token sets.
    Jaccard,
    /// `2 |A ∩ B| / (|A| + |B|)` over the two token sets.
    Dice,
}
489
490impl Default for TextSimilarityConfig {
491    fn default() -> Self {
492        Self {
493            tokenizer_config: TextTokenizerConfig::default(),
494            similarity_metric: SimilarityMetric::Cosine,
495        }
496    }
497}
498
499/// Text similarity calculator
500#[derive(Debug, Clone)]
501pub struct TextSimilarity {
502    config: TextSimilarityConfig,
503    tokenizer: TextTokenizer,
504}
505
506impl Default for TextSimilarity {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512impl TextSimilarity {
513    /// Create a new text similarity calculator
514    pub fn new() -> Self {
515        let config = TextSimilarityConfig::default();
516        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
517
518        Self { config, tokenizer }
519    }
520
521    /// Create a new text similarity calculator with custom configuration
522    pub fn with_config(config: TextSimilarityConfig) -> Self {
523        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
524        Self { config, tokenizer }
525    }
526
527    /// Calculate similarity between two texts
528    pub fn similarity(&self, text1: &str, text2: &str) -> f64 {
529        match self.config.similarity_metric {
530            SimilarityMetric::Cosine => self.cosine_similarity(text1, text2),
531            SimilarityMetric::Jaccard => self.jaccard_similarity(text1, text2),
532            SimilarityMetric::Dice => self.dice_coefficient(text1, text2),
533        }
534    }
535
536    /// Calculate cosine similarity between two texts
537    fn cosine_similarity(&self, text1: &str, text2: &str) -> f64 {
538        let tokens1 = self.tokenizer.tokenize(text1);
539        let tokens2 = self.tokenizer.tokenize(text2);
540
541        let mut term_freq1: HashMap<String, f64> = HashMap::new();
542        let mut term_freq2: HashMap<String, f64> = HashMap::new();
543
544        for token in tokens1 {
545            *term_freq1.entry(token).or_insert(0.0) += 1.0;
546        }
547
548        for token in tokens2 {
549            *term_freq2.entry(token).or_insert(0.0) += 1.0;
550        }
551
552        let mut dot_product = 0.0;
553        let mut norm1 = 0.0;
554        let mut norm2 = 0.0;
555
556        let all_terms: HashSet<String> = term_freq1
557            .keys()
558            .chain(term_freq2.keys())
559            .cloned()
560            .collect();
561
562        for term in all_terms {
563            let freq1 = term_freq1.get(&term).unwrap_or(&0.0);
564            let freq2 = term_freq2.get(&term).unwrap_or(&0.0);
565
566            dot_product += freq1 * freq2;
567            norm1 += freq1 * freq1;
568            norm2 += freq2 * freq2;
569        }
570
571        if norm1 == 0.0 || norm2 == 0.0 {
572            0.0
573        } else {
574            dot_product / (norm1.sqrt() * norm2.sqrt())
575        }
576    }
577
578    /// Calculate Jaccard similarity between two texts
579    fn jaccard_similarity(&self, text1: &str, text2: &str) -> f64 {
580        let tokens1: HashSet<String> = self.tokenizer.tokenize(text1).into_iter().collect();
581        let tokens2: HashSet<String> = self.tokenizer.tokenize(text2).into_iter().collect();
582
583        let intersection = tokens1.intersection(&tokens2).count();
584        let union = tokens1.union(&tokens2).count();
585
586        if union == 0 {
587            0.0
588        } else {
589            intersection as f64 / union as f64
590        }
591    }
592
593    /// Calculate Dice coefficient between two texts
594    fn dice_coefficient(&self, text1: &str, text2: &str) -> f64 {
595        let tokens1: HashSet<String> = self.tokenizer.tokenize(text1).into_iter().collect();
596        let tokens2: HashSet<String> = self.tokenizer.tokenize(text2).into_iter().collect();
597
598        let intersection = tokens1.intersection(&tokens2).count();
599        let total = tokens1.len() + tokens2.len();
600
601        if total == 0 {
602            0.0
603        } else {
604            2.0 * intersection as f64 / total as f64
605        }
606    }
607}
608
609/// Configuration for bag-of-words embeddings
610#[derive(Debug, Clone, Default)]
611pub struct BagOfWordsConfig {
612    pub tokenizer_config: TextTokenizerConfig,
613    pub max_features: Option<usize>,
614    pub binary: bool,
615}
616
617/// Simple bag-of-words sentence embeddings
618#[derive(Debug, Clone)]
619pub struct BagOfWordsEmbedding {
620    config: BagOfWordsConfig,
621    tokenizer: TextTokenizer,
622    vocabulary: HashMap<String, usize>,
623    fitted: bool,
624}
625
626impl Default for BagOfWordsEmbedding {
627    fn default() -> Self {
628        Self::new()
629    }
630}
631
632impl BagOfWordsEmbedding {
633    /// Create a new bag-of-words embedding
634    pub fn new() -> Self {
635        let config = BagOfWordsConfig::default();
636        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
637
638        Self {
639            config,
640            tokenizer,
641            vocabulary: HashMap::new(),
642            fitted: false,
643        }
644    }
645
646    /// Create a new bag-of-words embedding with custom configuration
647    pub fn with_config(config: BagOfWordsConfig) -> Self {
648        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
649
650        Self {
651            config,
652            tokenizer,
653            vocabulary: HashMap::new(),
654            fitted: false,
655        }
656    }
657
658    /// Build vocabulary from documents
659    fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
660        let mut term_counts: HashMap<String, usize> = HashMap::new();
661
662        // Count term frequencies
663        for document in documents {
664            let tokens = self.tokenizer.tokenize(document);
665            for token in tokens {
666                *term_counts.entry(token).or_insert(0) += 1;
667            }
668        }
669
670        // Sort terms by frequency (descending) for consistent ordering
671        let mut sorted_terms: Vec<(String, usize)> = term_counts.into_iter().collect();
672        sorted_terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
673
674        // Limit vocabulary size if specified
675        if let Some(max_features) = self.config.max_features {
676            sorted_terms.truncate(max_features);
677        }
678
679        // Build vocabulary mapping
680        for (idx, (term, _)) in sorted_terms.iter().enumerate() {
681            self.vocabulary.insert(term.clone(), idx);
682        }
683
684        Ok(())
685    }
686
687    /// Transform documents to bag-of-words matrix
688    fn transform_documents(&self, documents: &[String]) -> Result<Array2<f64>> {
689        if !self.fitted {
690            return Err(SklearsError::NotFitted {
691                operation: "BagOfWordsEmbedding not fitted".to_string(),
692            });
693        }
694
695        let n_docs = documents.len();
696        let vocab_size = self.vocabulary.len();
697        let mut bow_matrix = Array2::zeros((n_docs, vocab_size));
698
699        for (doc_idx, document) in documents.iter().enumerate() {
700            let tokens = self.tokenizer.tokenize(document);
701            let mut term_counts: HashMap<usize, f64> = HashMap::new();
702
703            // Count term frequencies
704            for token in &tokens {
705                if let Some(&vocab_idx) = self.vocabulary.get(token) {
706                    *term_counts.entry(vocab_idx).or_insert(0.0) += 1.0;
707                }
708            }
709
710            // Set values in matrix
711            for (vocab_idx, count) in term_counts {
712                let value = if self.config.binary { 1.0 } else { count };
713                bow_matrix[[doc_idx, vocab_idx]] = value;
714            }
715        }
716
717        Ok(bow_matrix)
718    }
719
720    /// Get the vocabulary mapping  
721    pub fn get_vocabulary(&self) -> &HashMap<String, usize> {
722        &self.vocabulary
723    }
724}
725
726impl Fit<Vec<String>, ()> for BagOfWordsEmbedding {
727    type Fitted = BagOfWordsEmbedding;
728
729    fn fit(mut self, x: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
730        self.build_vocabulary(x)?;
731        self.fitted = true;
732        Ok(self)
733    }
734}
735
736impl Transform<Vec<String>, Array2<f64>> for BagOfWordsEmbedding {
737    fn transform(&self, x: &Vec<String>) -> Result<Array2<f64>> {
738        self.transform_documents(x)
739    }
740}
741
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_text_tokenizer() {
        let tokenizer = TextTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World! This is a TEST.");

        assert_eq!(tokens, vec!["hello", "world", "this", "is", "a", "test"]);
    }

    #[test]
    fn test_text_tokenizer_empty_input() {
        assert!(TextTokenizer::new().tokenize("").is_empty());
    }

    #[test]
    fn test_tfidf_vectorizer() {
        let documents = vec![
            "the cat sat on the mat".to_string(),
            "the dog ran in the park".to_string(),
            "cats and dogs are pets".to_string(),
        ];

        let fitted = TfIdfVectorizer::new().fit(&documents, &()).unwrap();
        let tfidf_matrix = fitted.transform(&documents).unwrap();

        assert_eq!(tfidf_matrix.shape(), &[3, fitted.vocabulary.len()]);
        assert!(tfidf_matrix.iter().all(|&value| value >= 0.0));
    }

    #[test]
    fn test_ngram_generator() {
        let ngrams = NgramGenerator::new().generate_ngrams("the quick brown fox");

        // Default configuration emits unigrams and bigrams.
        for expected in ["the", "quick", "the quick", "quick brown"] {
            assert!(ngrams.contains(&expected.to_string()), "missing {expected:?}");
        }
    }

    #[test]
    fn test_char_ngram_generator() {
        let generator = NgramGenerator::with_config(NgramGeneratorConfig {
            ngram_type: NgramType::Char,
            ..Default::default()
        });

        assert_eq!(generator.generate_ngrams("Ab"), vec!["a", "b", "ab"]);
    }

    #[test]
    fn test_text_similarity() {
        let similarity = TextSimilarity::new();

        // Default metric is cosine.
        assert_abs_diff_eq!(
            similarity.similarity("the cat sat", "the cat sat"),
            1.0,
            epsilon = 1e-10
        );

        let partial = similarity.similarity("the cat sat", "the dog ran");
        assert!(partial > 0.0 && partial < 1.0);

        assert_eq!(similarity.similarity("hello world", "goodbye moon"), 0.0);
    }

    #[test]
    fn test_jaccard_and_dice_similarity() {
        let jaccard = TextSimilarity::with_config(TextSimilarityConfig {
            similarity_metric: SimilarityMetric::Jaccard,
            ..Default::default()
        });
        let dice = TextSimilarity::with_config(TextSimilarityConfig {
            similarity_metric: SimilarityMetric::Dice,
            ..Default::default()
        });

        // Shared {the, cat} out of the union {the, cat, sat, ran}.
        assert_abs_diff_eq!(
            jaccard.similarity("the cat sat", "the cat ran"),
            0.5,
            epsilon = 1e-12
        );
        assert_abs_diff_eq!(
            dice.similarity("the cat sat", "the cat ran"),
            2.0 / 3.0,
            epsilon = 1e-12
        );
        assert_eq!(jaccard.similarity("", ""), 0.0);
        assert_eq!(dice.similarity("", ""), 0.0);
    }

    #[test]
    fn test_bag_of_words_embedding() {
        let documents = vec![
            "the cat sat".to_string(),
            "the dog ran".to_string(),
            "cats and dogs".to_string(),
        ];

        let fitted = BagOfWordsEmbedding::new().fit(&documents, &()).unwrap();
        let bow_matrix = fitted.transform(&documents).unwrap();

        assert_eq!(bow_matrix.shape(), &[3, fitted.vocabulary.len()]);
        assert!(bow_matrix.iter().all(|&value| value >= 0.0));
    }

    #[test]
    fn test_bag_of_words_binary() {
        let documents = vec!["the the cat".to_string()];

        let counts = BagOfWordsEmbedding::new()
            .fit(&documents, &())
            .unwrap()
            .transform(&documents)
            .unwrap();
        let binary = BagOfWordsEmbedding::with_config(BagOfWordsConfig {
            binary: true,
            ..Default::default()
        })
        .fit(&documents, &())
        .unwrap()
        .transform(&documents)
        .unwrap();

        // Columns are ordered by descending corpus frequency: "the" (2), "cat" (1).
        assert_eq!(counts.row(0).to_vec(), vec![2.0, 1.0]);
        assert_eq!(binary.row(0).to_vec(), vec![1.0, 1.0]);
    }
}