1use crate::error::{Result, TextError};
15use crate::tokenize::{Tokenizer, WordTokenizer};
16use crate::vectorize::{TfidfVectorizer, Vectorizer};
17use scirs2_core::ndarray::{Array1, Array2, Axis};
18use std::collections::HashMap;
19
20fn tokenize_lower(text: &str) -> Vec<String> {
26 let tokenizer = WordTokenizer::default();
27 tokenizer
28 .tokenize(text)
29 .unwrap_or_default()
30 .into_iter()
31 .map(|t| t.to_lowercase())
32 .collect()
33}
34
/// Multinomial naive Bayes classifier over lowercased word tokens.
///
/// All fields are empty until `fit` succeeds; an empty `classes` list is what
/// marks the model as "not fitted".
#[derive(Debug, Clone)]
pub struct NaiveBayesClassifier {
    /// Maps each training token to its column index in `class_word_log_probs`.
    vocabulary: HashMap<String, usize>,
    /// Natural log of each class's share of the training documents.
    class_log_priors: Vec<f64>,
    /// `class_word_log_probs[c][w]` is the smoothed log-probability of
    /// vocabulary column `w` given class `c`.
    class_word_log_probs: Vec<Vec<f64>>,
    /// Class labels, ordered by first appearance in the training labels.
    classes: Vec<String>,
    /// Additive smoothing constant applied to every word count.
    alpha: f64,
}
67
68impl NaiveBayesClassifier {
69 pub fn new(alpha: f64) -> Self {
77 Self {
78 vocabulary: HashMap::new(),
79 class_log_priors: Vec::new(),
80 class_word_log_probs: Vec::new(),
81 classes: Vec::new(),
82 alpha,
83 }
84 }
85
86 pub fn fit(&mut self, texts: &[&str], labels: &[&str]) -> Result<()> {
95 if texts.len() != labels.len() {
96 return Err(TextError::InvalidInput(format!(
97 "texts ({}) and labels ({}) must have the same length",
98 texts.len(),
99 labels.len()
100 )));
101 }
102 if texts.is_empty() {
103 return Err(TextError::InvalidInput("Empty training corpus".to_string()));
104 }
105 if self.alpha <= 0.0 {
106 return Err(TextError::InvalidInput(
107 "alpha must be positive".to_string(),
108 ));
109 }
110
111 let mut class_index: HashMap<String, usize> = HashMap::new();
113 for &label in labels {
114 let n = class_index.len();
115 class_index.entry(label.to_string()).or_insert(n);
116 }
117
118 let n_classes = class_index.len();
119 let mut class_names: Vec<String> = vec![String::new(); n_classes];
120 for (name, &idx) in &class_index {
121 class_names[idx] = name.clone();
122 }
123
124 let mut class_doc_counts = vec![0usize; n_classes];
125 let mut class_word_counts: Vec<HashMap<String, f64>> =
127 (0..n_classes).map(|_| HashMap::new()).collect();
128
129 for (&text, &label) in texts.iter().zip(labels.iter()) {
130 let class_idx = *class_index
131 .get(label)
132 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))?;
133 class_doc_counts[class_idx] += 1;
134
135 for word in tokenize_lower(text) {
136 let n = self.vocabulary.len();
138 self.vocabulary.entry(word.clone()).or_insert(n);
139 *class_word_counts[class_idx].entry(word).or_insert(0.0) += 1.0;
140 }
141 }
142
143 let n_docs = texts.len() as f64;
144 let vocab_size = self.vocabulary.len();
145
146 let class_log_priors: Vec<f64> = class_doc_counts
148 .iter()
149 .map(|&c| (c as f64 / n_docs).ln())
150 .collect();
151
152 let mut class_word_log_probs: Vec<Vec<f64>> =
154 (0..n_classes).map(|_| vec![0.0; vocab_size]).collect();
155
156 for c in 0..n_classes {
157 let total: f64 = class_word_counts[c].values().sum();
158 let denom = total + self.alpha * vocab_size as f64;
159
160 for (word, &col_idx) in &self.vocabulary {
161 let cnt = class_word_counts[c].get(word).copied().unwrap_or(0.0);
162 class_word_log_probs[c][col_idx] = ((cnt + self.alpha) / denom).ln();
163 }
164 }
165
166 self.classes = class_names;
167 self.class_log_priors = class_log_priors;
168 self.class_word_log_probs = class_word_log_probs;
169
170 Ok(())
171 }
172
173 pub fn predict(&self, text: &str) -> Result<String> {
179 let proba = self.predict_log_proba(text)?;
180 let best = proba
181 .iter()
182 .enumerate()
183 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
184 .map(|(i, _)| i)
185 .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))?;
186 Ok(self.classes[best].clone())
187 }
188
189 pub fn predict_proba(&self, text: &str) -> Result<Vec<(String, f64)>> {
197 let log_proba = self.predict_log_proba(text)?;
198
199 let max_val = log_proba.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
201 let exps: Vec<f64> = log_proba.iter().map(|&v| (v - max_val).exp()).collect();
202 let sum: f64 = exps.iter().sum();
203
204 let mut result: Vec<(String, f64)> = self
205 .classes
206 .iter()
207 .zip(exps.iter())
208 .map(|(name, &e)| (name.clone(), e / sum))
209 .collect();
210
211 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212
213 Ok(result)
214 }
215
216 pub fn score(&self, texts: &[&str], labels: &[&str]) -> Result<f64> {
222 if texts.len() != labels.len() {
223 return Err(TextError::InvalidInput(
224 "texts and labels must have the same length".to_string(),
225 ));
226 }
227 if texts.is_empty() {
228 return Ok(0.0);
229 }
230
231 let mut correct = 0usize;
232 for (&text, &label) in texts.iter().zip(labels.iter()) {
233 if let Ok(pred) = self.predict(text) {
234 if pred == label {
235 correct += 1;
236 }
237 }
238 }
239
240 Ok(correct as f64 / texts.len() as f64)
241 }
242
243 fn predict_log_proba(&self, text: &str) -> Result<Vec<f64>> {
248 if self.classes.is_empty() {
249 return Err(TextError::ModelNotFitted(
250 "NaiveBayesClassifier has not been fitted yet".to_string(),
251 ));
252 }
253
254 let tokens = tokenize_lower(text);
255 let n_classes = self.classes.len();
256 let mut log_scores: Vec<f64> = self.class_log_priors.clone();
257
258 for word in &tokens {
259 if let Some(&col) = self.vocabulary.get(word) {
260 for c in 0..n_classes {
261 log_scores[c] += self.class_word_log_probs[c][col];
262 }
263 }
264 }
265
266 Ok(log_scores)
267 }
268}
269
/// Lightweight n-gram classifier: bag-of-n-gram counts fed into a softmax
/// linear layer.
///
/// An empty `classes` list marks the model as "not fitted".
#[derive(Debug, Clone)]
pub struct TextCnnLite {
    /// N-gram lengths (in words) extracted from each text.
    filter_sizes: Vec<usize>,
    /// Feature budget per filter size; the vocabulary holds at most
    /// `n_filters * filter_sizes.len()` n-grams.
    n_filters: usize,
    /// Maps each selected n-gram to its feature index.
    vocab: HashMap<String, usize>,
    /// `weights[c][j]` multiplies feature `j` for class `c`.
    weights: Vec<Vec<f64>>,
    /// Per-class bias terms.
    bias: Vec<f64>,
    /// Class labels, ordered by first appearance in the training labels.
    classes: Vec<String>,
}
310
311impl TextCnnLite {
312 pub fn new(filter_sizes: Vec<usize>, n_filters: usize) -> Self {
319 Self {
320 filter_sizes,
321 n_filters,
322 vocab: HashMap::new(),
323 weights: Vec::new(),
324 bias: Vec::new(),
325 classes: Vec::new(),
326 }
327 }
328
329 pub fn fit(&mut self, texts: &[&str], labels: &[&str], epochs: usize) -> Result<()> {
341 if texts.len() != labels.len() {
342 return Err(TextError::InvalidInput(format!(
343 "texts ({}) and labels ({}) must have the same length",
344 texts.len(),
345 labels.len()
346 )));
347 }
348 if texts.is_empty() {
349 return Err(TextError::InvalidInput("Empty training corpus".to_string()));
350 }
351
352 let mut class_index: HashMap<String, usize> = HashMap::new();
354 for &label in labels {
355 let n = class_index.len();
356 class_index.entry(label.to_string()).or_insert(n);
357 }
358 let n_classes = class_index.len();
359 let mut class_names: Vec<String> = vec![String::new(); n_classes];
360 for (name, &idx) in &class_index {
361 class_names[idx] = name.clone();
362 }
363
364 let mut ngram_counts: HashMap<String, usize> = HashMap::new();
366 for &text in texts {
367 let words = tokenize_lower(text);
368 for &size in &self.filter_sizes {
369 for ngram in ngrams(&words, size) {
370 *ngram_counts.entry(ngram).or_insert(0) += 1;
371 }
372 }
373 }
374
375 let mut ngram_vec: Vec<(String, usize)> = ngram_counts.into_iter().collect();
377 ngram_vec.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
378 let max_feats = self.n_filters * self.filter_sizes.len();
379 self.vocab = ngram_vec
380 .into_iter()
381 .take(max_feats)
382 .enumerate()
383 .map(|(i, (ng, _))| (ng, i))
384 .collect();
385
386 let n_features = self.vocab.len();
387 if n_features == 0 {
388 return Err(TextError::InvalidInput(
389 "No n-gram features found in corpus".to_string(),
390 ));
391 }
392
393 let mut x_data: Vec<Vec<f64>> = Vec::with_capacity(texts.len());
395 let mut y_labels: Vec<usize> = Vec::with_capacity(texts.len());
396
397 for (&text, &label) in texts.iter().zip(labels.iter()) {
398 x_data.push(self.vectorize(text));
399 let class_idx = *class_index
400 .get(label)
401 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))?;
402 y_labels.push(class_idx);
403 }
404
405 self.weights = vec![vec![0.0f64; n_features]; n_classes];
407 self.bias = vec![0.0f64; n_classes];
408
409 let lr = 0.1_f64;
411 let n_samples = texts.len();
412
413 for _epoch in 0..epochs {
414 for i in 0..n_samples {
415 let x = &x_data[i];
416 let y = y_labels[i];
417
418 let logits: Vec<f64> = (0..n_classes)
420 .map(|c| {
421 self.bias[c]
422 + x.iter()
423 .zip(self.weights[c].iter())
424 .map(|(xi, wi)| xi * wi)
425 .sum::<f64>()
426 })
427 .collect();
428
429 let probs = softmax(&logits);
430
431 for c in 0..n_classes {
433 let delta = probs[c] - if c == y { 1.0 } else { 0.0 };
434 self.bias[c] -= lr * delta;
435 for (j, &xj) in x.iter().enumerate() {
436 self.weights[c][j] -= lr * delta * xj;
437 }
438 }
439 }
440 }
441
442 self.classes = class_names;
443 Ok(())
444 }
445
446 pub fn predict(&self, text: &str) -> Result<String> {
452 if self.classes.is_empty() {
453 return Err(TextError::ModelNotFitted(
454 "TextCnnLite has not been fitted yet".to_string(),
455 ));
456 }
457 let x = self.vectorize(text);
458 let n_classes = self.classes.len();
459 let logits: Vec<f64> = (0..n_classes)
460 .map(|c| {
461 self.bias[c]
462 + x.iter()
463 .zip(self.weights[c].iter())
464 .map(|(xi, wi)| xi * wi)
465 .sum::<f64>()
466 })
467 .collect();
468
469 let best = logits
470 .iter()
471 .enumerate()
472 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
473 .map(|(i, _)| i)
474 .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))?;
475
476 Ok(self.classes[best].clone())
477 }
478
479 fn vectorize(&self, text: &str) -> Vec<f64> {
485 let mut v = vec![0.0f64; self.vocab.len()];
486 let words = tokenize_lower(text);
487 for &size in &self.filter_sizes {
488 for ngram in ngrams(&words, size) {
489 if let Some(&idx) = self.vocab.get(&ngram) {
490 v[idx] += 1.0;
491 }
492 }
493 }
494 v
495 }
496}
497
/// Multinomial (softmax) logistic-regression classifier over TF-IDF features.
pub struct TfIdfLogisticClassifier {
    /// Vectorizer fitted on the training texts and reused at prediction time.
    vectorizer: TfidfVectorizer,
    /// `weights[c][j]` multiplies feature `j` for class `c`.
    weights: Vec<Vec<f64>>,
    /// Per-class bias terms.
    bias: Vec<f64>,
    /// Class labels, ordered by first appearance in the training labels.
    classes: Vec<String>,
    /// Set to `true` once `fit` has completed successfully.
    fitted: bool,
}
531
532impl Default for TfIdfLogisticClassifier {
533 fn default() -> Self {
534 Self::new()
535 }
536}
537
538impl TfIdfLogisticClassifier {
539 pub fn new() -> Self {
541 Self {
542 vectorizer: TfidfVectorizer::new(false, true, Some("l2".to_string())),
543 weights: Vec::new(),
544 bias: Vec::new(),
545 classes: Vec::new(),
546 fitted: false,
547 }
548 }
549
550 pub fn fit(&mut self, texts: &[&str], labels: &[&str], max_iter: usize, lr: f64) -> Result<()> {
563 if texts.len() != labels.len() {
564 return Err(TextError::InvalidInput(format!(
565 "texts ({}) and labels ({}) must have the same length",
566 texts.len(),
567 labels.len()
568 )));
569 }
570 if texts.is_empty() {
571 return Err(TextError::InvalidInput("Empty training corpus".to_string()));
572 }
573 if lr <= 0.0 {
574 return Err(TextError::InvalidInput(
575 "Learning rate must be positive".to_string(),
576 ));
577 }
578
579 let mut class_index: HashMap<String, usize> = HashMap::new();
581 for &label in labels {
582 let n = class_index.len();
583 class_index.entry(label.to_string()).or_insert(n);
584 }
585 let n_classes = class_index.len();
586 let mut class_names: Vec<String> = vec![String::new(); n_classes];
587 for (name, &idx) in &class_index {
588 class_names[idx] = name.clone();
589 }
590
591 let x_matrix = self.vectorizer.fit_transform(texts)?;
593 let n_features = x_matrix.ncols();
594
595 let y_labels: Vec<usize> = labels
596 .iter()
597 .map(|&label| {
598 class_index
599 .get(label)
600 .copied()
601 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))
602 })
603 .collect::<Result<_>>()?;
604
605 self.weights = vec![vec![0.0f64; n_features]; n_classes];
607 self.bias = vec![0.0f64; n_classes];
608
609 let n_samples = x_matrix.nrows();
611
612 for _epoch in 0..max_iter {
613 for i in 0..n_samples {
614 let x_row = x_matrix.row(i);
615 let y = y_labels[i];
616
617 let logits: Vec<f64> = (0..n_classes)
619 .map(|c| {
620 self.bias[c]
621 + x_row
622 .iter()
623 .zip(self.weights[c].iter())
624 .map(|(xi, wi)| xi * wi)
625 .sum::<f64>()
626 })
627 .collect();
628
629 let probs = softmax(&logits);
630
631 for c in 0..n_classes {
633 let delta = probs[c] - if c == y { 1.0 } else { 0.0 };
634 self.bias[c] -= lr * delta;
635 for (j, &xj) in x_row.iter().enumerate() {
636 self.weights[c][j] -= lr * delta * xj;
637 }
638 }
639 }
640 }
641
642 self.classes = class_names;
643 self.fitted = true;
644 Ok(())
645 }
646
647 pub fn predict(&self, text: &str) -> Result<String> {
653 let probs = self.predict_proba(text)?;
654 probs
655 .iter()
656 .enumerate()
657 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
658 .map(|(i, _)| self.classes[i].clone())
659 .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))
660 }
661
662 pub fn predict_proba(&self, text: &str) -> Result<Vec<f64>> {
668 if !self.fitted {
669 return Err(TextError::ModelNotFitted(
670 "TfIdfLogisticClassifier has not been fitted yet".to_string(),
671 ));
672 }
673
674 let x_vec = self.vectorizer.transform(text)?;
675 let n_classes = self.classes.len();
676
677 let logits: Vec<f64> = (0..n_classes)
678 .map(|c| {
679 self.bias[c]
680 + x_vec
681 .iter()
682 .zip(self.weights[c].iter())
683 .map(|(xi, wi)| xi * wi)
684 .sum::<f64>()
685 })
686 .collect();
687
688 Ok(softmax(&logits))
689 }
690}
691
/// Returns every contiguous run of `size` words, each joined with a single
/// space, in left-to-right order.
///
/// The result is empty when `size` is zero or larger than `words.len()`.
fn ngrams(words: &[String], size: usize) -> Vec<String> {
    // `windows` panics on a zero width, so that case is handled up front.
    if size == 0 {
        return Vec::new();
    }
    // A width larger than the slice simply yields no windows.
    words.windows(size).map(|run| run.join(" ")).collect()
}
705
/// Numerically stable softmax of `logits`.
///
/// For finite inputs this is the usual max-shifted softmax. Degenerate inputs
/// get a well-defined distribution instead of a vector of NaNs:
///
/// * empty input returns an empty vector;
/// * any NaN logit, or all logits equal to `-inf`, returns the uniform
///   distribution;
/// * if one or more logits are `+inf`, the probability mass is split evenly
///   among them and every other entry is `0.0`.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let uniform = || vec![1.0 / logits.len() as f64; logits.len()];

    // NaN carries no ordering information, so no class can be preferred.
    if logits.iter().any(|v| v.is_nan()) {
        return uniform();
    }

    let max_val = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // All logits are -inf: `v - max_val` would be NaN for every entry.
    if max_val == f64::NEG_INFINITY {
        return uniform();
    }

    // `inf - inf` is NaN, so take the limit explicitly: all mass goes to the
    // +inf entries.
    if max_val == f64::INFINITY {
        let n_inf = logits.iter().filter(|&&v| v == f64::INFINITY).count() as f64;
        return logits
            .iter()
            .map(|&v| if v == f64::INFINITY { 1.0 / n_inf } else { 0.0 })
            .collect();
    }

    // Finite maximum: the largest entry contributes exp(0) = 1, so the sum is
    // at least 1 and the division below is always well defined.
    let exps: Vec<f64> = logits.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
719
#[cfg(test)]
mod tests {
    use super::*;

    // ---- NaiveBayesClassifier -------------------------------------------

    /// Two-class spam/ham corpus: each query shares words with one class only.
    #[test]
    fn test_nb_fit_predict_binary() {
        let corpus = [
            "spam buy now cheap discount",
            "spam offer free money",
            "hello how are you doing",
            "good morning friend",
        ];
        let targets = ["spam", "spam", "ham", "ham"];

        let mut model = NaiveBayesClassifier::new(1.0);
        model.fit(&corpus, &targets).expect("fit should succeed");

        let spam_guess = model
            .predict("buy cheap spam now")
            .expect("predict should succeed");
        assert_eq!(spam_guess, "spam");

        let ham_guess = model
            .predict("good morning hello")
            .expect("predict should succeed");
        assert_eq!(ham_guess, "ham");
    }

    /// Probabilities cover every class, sum to one and are sorted descending.
    #[test]
    fn test_nb_predict_proba() {
        let corpus = [
            "machine learning data science",
            "deep learning neural network",
            "cooking recipe food delicious",
            "restaurant menu chef dinner",
        ];
        let targets = ["tech", "tech", "food", "food"];

        let mut model = NaiveBayesClassifier::new(1.0);
        model.fit(&corpus, &targets).expect("fit should succeed");

        let ranked = model
            .predict_proba("neural network deep learning")
            .expect("predict_proba should succeed");
        assert_eq!(ranked.len(), 2);

        let mass: f64 = ranked.iter().map(|entry| entry.1).sum();
        assert!((mass - 1.0).abs() < 1e-9, "probabilities should sum to 1");

        let (top_label, top_prob) = &ranked[0];
        assert_eq!(top_label, "tech");
        assert!(*top_prob > 0.5, "tech probability should exceed 0.5");
    }

    /// Accuracy on held-out sentiment snippets beats a coin flip.
    #[test]
    fn test_nb_score_above_chance() {
        let train_texts = [
            "positive great excellent wonderful",
            "positive good happy nice",
            "positive fantastic brilliant awesome",
            "negative bad terrible awful",
            "negative horrible disappointing poor",
            "negative dreadful appalling dire",
        ];
        let train_labels = [
            "positive", "positive", "positive", "negative", "negative", "negative",
        ];

        let mut model = NaiveBayesClassifier::new(1.0);
        model
            .fit(&train_texts, &train_labels)
            .expect("fit should succeed");

        let held_out = ["excellent wonderful", "terrible awful bad"];
        let expected = ["positive", "negative"];
        let acc = model
            .score(&held_out, &expected)
            .expect("score should succeed");

        assert!(acc > 0.5, "accuracy should exceed chance: {}", acc);
    }

    #[test]
    fn test_nb_error_on_empty_corpus() {
        let mut model = NaiveBayesClassifier::new(1.0);
        assert!(model.fit(&[], &[]).is_err());
    }

    #[test]
    fn test_nb_error_on_length_mismatch() {
        let mut model = NaiveBayesClassifier::new(1.0);
        assert!(model.fit(&["text"], &["label1", "label2"]).is_err());
    }

    #[test]
    fn test_nb_not_fitted_error() {
        let model = NaiveBayesClassifier::new(1.0);
        assert!(model.predict("some text").is_err());
    }

    /// Three interleaved sport classes with a fractional smoothing constant.
    #[test]
    fn test_nb_multiclass() {
        let corpus = [
            "soccer football goal kick",
            "basketball dunk three pointer",
            "baseball pitcher batting home run",
            "soccer penalty kick goal",
            "basketball court dribble shoot",
            "baseball strike out innings",
        ];
        let targets = [
            "soccer",
            "basketball",
            "baseball",
            "soccer",
            "basketball",
            "baseball",
        ];

        let mut model = NaiveBayesClassifier::new(0.5);
        model.fit(&corpus, &targets).expect("fit should succeed");

        let guess = model
            .predict("goal kick soccer field")
            .expect("predict should succeed");
        assert_eq!(guess, "soccer");
    }

    // ---- TextCnnLite ----------------------------------------------------

    #[test]
    fn test_cnn_fit_succeeds() {
        let corpus = [
            "good movie fun entertaining",
            "bad film boring terrible",
            "great show exciting wonderful",
            "awful program dull waste",
        ];
        let targets = ["pos", "neg", "pos", "neg"];

        let mut model = TextCnnLite::new(vec![2, 3], 8);
        let outcome = model.fit(&corpus, &targets, 20);
        assert!(outcome.is_ok(), "fit should succeed: {:?}", outcome);
    }

    /// Whatever the model predicts, it must be one of the training labels.
    #[test]
    fn test_cnn_predict_returns_valid_class() {
        let corpus = [
            "wonderful excellent amazing brilliant",
            "terrible dreadful awful horrible",
            "great fantastic superb outstanding",
            "poor disappointing bad mediocre",
        ];
        let targets = ["pos", "neg", "pos", "neg"];

        let mut model = TextCnnLite::new(vec![2, 3], 8);
        model.fit(&corpus, &targets, 30).expect("fit should succeed");

        let guess = model
            .predict("excellent wonderful")
            .expect("predict should succeed");
        assert!(
            guess == "pos" || guess == "neg",
            "prediction should be a valid class"
        );
    }

    #[test]
    fn test_cnn_not_fitted_error() {
        let model = TextCnnLite::new(vec![2], 4);
        assert!(model.predict("text").is_err());
    }

    // ---- TfIdfLogisticClassifier ----------------------------------------

    #[test]
    fn test_tfidf_logistic_fit_and_predict() {
        let corpus = [
            "machine learning artificial intelligence",
            "deep neural network training",
            "cooking recipe baking flour",
            "restaurant food delicious chef",
            "algorithm data science research",
            "dinner menu ingredients spices",
        ];
        let targets = ["tech", "tech", "food", "food", "tech", "food"];

        let mut model = TfIdfLogisticClassifier::new();
        model
            .fit(&corpus, &targets, 50, 0.1)
            .expect("fit should succeed");

        let guess = model
            .predict("neural network algorithm")
            .expect("predict should succeed");
        assert_eq!(guess, "tech");
    }

    #[test]
    fn test_tfidf_logistic_predict_proba_sums_to_one() {
        let corpus = [
            "positive happy good",
            "negative sad bad",
            "positive great excellent",
            "negative terrible awful",
        ];
        let targets = ["pos", "neg", "pos", "neg"];

        let mut model = TfIdfLogisticClassifier::new();
        model
            .fit(&corpus, &targets, 30, 0.05)
            .expect("fit should succeed");

        let distribution = model
            .predict_proba("happy good great")
            .expect("predict_proba should succeed");
        let sum: f64 = distribution.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1, got {}",
            sum
        );
    }

    #[test]
    fn test_tfidf_logistic_not_fitted_error() {
        let model = TfIdfLogisticClassifier::new();
        assert!(model.predict("text").is_err());
    }

    #[test]
    fn test_tfidf_logistic_error_on_empty_corpus() {
        let mut model = TfIdfLogisticClassifier::new();
        assert!(model.fit(&[], &[], 10, 0.1).is_err());
    }

    #[test]
    fn test_tfidf_logistic_error_on_length_mismatch() {
        let mut model = TfIdfLogisticClassifier::new();
        assert!(model.fit(&["text"], &["a", "b"], 10, 0.1).is_err());
    }

    // ---- helpers ----------------------------------------------------------

    #[test]
    fn test_ngrams_bigrams() {
        let words: Vec<String> = ["the", "quick", "brown", "fox"]
            .into_iter()
            .map(String::from)
            .collect();
        let pairs = ngrams(&words, 2);
        assert_eq!(pairs.len(), 3);
        assert_eq!(pairs[0], "the quick");
        assert_eq!(pairs[2], "brown fox");
    }

    #[test]
    fn test_ngrams_empty_when_size_too_large() {
        let words: Vec<String> = ["a", "b"].into_iter().map(String::from).collect();
        assert!(ngrams(&words, 5).is_empty());
    }

    /// Softmax output sums to one and preserves the ordering of the logits.
    #[test]
    fn test_softmax_properties() {
        let probs = softmax(&[1.0, 2.0, 3.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }

    /// Large logits must not overflow thanks to the max-shift.
    #[test]
    fn test_softmax_large_values_stable() {
        let probs = softmax(&[1000.0, 1001.0, 999.0]);
        assert!(probs.iter().all(|&p| p.is_finite()));
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
    }
}