1use std::collections::HashMap;
10
11use crate::error::{Result, TextError};
12
13#[derive(Debug, Clone, Default)]
22pub struct NaiveBayesClassifier {
23 class_log_priors: Vec<f64>,
25 log_likelihoods: Vec<Vec<f64>>,
27 classes: Vec<String>,
29 vocabulary: HashMap<String, usize>,
31 fitted: bool,
33}
34
35impl NaiveBayesClassifier {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 fn tokenize(text: &str) -> Vec<String> {
45 text.split(|c: char| !c.is_alphanumeric())
46 .filter(|s| !s.is_empty())
47 .map(|s| s.to_lowercase())
48 .collect()
49 }
50
51 fn text_to_counts(&self, text: &str) -> Vec<f64> {
53 let mut counts = vec![0.0f64; self.vocabulary.len()];
54 for word in Self::tokenize(text) {
55 if let Some(&idx) = self.vocabulary.get(&word) {
56 counts[idx] += 1.0;
57 }
58 }
59 counts
60 }
61
62 pub fn fit(&mut self, corpus: &[(String, String)], alpha: f64) -> Result<()> {
68 if corpus.is_empty() {
69 return Err(TextError::InvalidInput("corpus is empty".to_string()));
70 }
71 if alpha <= 0.0 {
72 return Err(TextError::InvalidInput(
73 "smoothing parameter alpha must be > 0".to_string(),
74 ));
75 }
76
77 let mut class_set: Vec<String> = corpus
79 .iter()
80 .map(|(_, label)| label.clone())
81 .collect::<std::collections::HashSet<_>>()
82 .into_iter()
83 .collect();
84 class_set.sort();
85 self.classes = class_set;
86 let n_classes = self.classes.len();
87 let class_to_id: HashMap<String, usize> = self
88 .classes
89 .iter()
90 .enumerate()
91 .map(|(i, c)| (c.clone(), i))
92 .collect();
93
94 let mut vocab_set: std::collections::HashSet<String> = std::collections::HashSet::new();
96 for (text, _) in corpus {
97 for word in Self::tokenize(text) {
98 vocab_set.insert(word);
99 }
100 }
101 let mut vocab_sorted: Vec<String> = vocab_set.into_iter().collect();
102 vocab_sorted.sort();
103 self.vocabulary = vocab_sorted
104 .iter()
105 .enumerate()
106 .map(|(i, w)| (w.clone(), i))
107 .collect();
108 let v = self.vocabulary.len();
109
110 let mut class_counts = vec![0usize; n_classes];
112 let mut word_counts_per_class: Vec<Vec<f64>> = vec![vec![0.0; v]; n_classes];
113
114 for (text, label) in corpus {
115 let ci = class_to_id[label];
116 class_counts[ci] += 1;
117 for word in Self::tokenize(text) {
118 if let Some(&wi) = self.vocabulary.get(&word) {
119 word_counts_per_class[ci][wi] += 1.0;
120 }
121 }
122 }
123
124 let total_docs = corpus.len() as f64;
125 self.class_log_priors = class_counts
126 .iter()
127 .map(|&c| (c as f64 / total_docs).ln())
128 .collect();
129
130 self.log_likelihoods = word_counts_per_class
132 .iter()
133 .map(|counts| {
134 let total: f64 = counts.iter().sum::<f64>() + alpha * v as f64;
135 counts.iter().map(|&c| ((c + alpha) / total).ln()).collect()
136 })
137 .collect();
138
139 self.fitted = true;
140 Ok(())
141 }
142
143 fn log_scores(&self, text: &str) -> Result<Vec<f64>> {
145 if !self.fitted {
146 return Err(TextError::ModelNotFitted(
147 "NaiveBayesClassifier is not fitted".to_string(),
148 ));
149 }
150 let counts = self.text_to_counts(text);
151 let scores: Vec<f64> = self
152 .class_log_priors
153 .iter()
154 .zip(self.log_likelihoods.iter())
155 .map(|(&prior, likelihoods)| {
156 let ll: f64 = counts
157 .iter()
158 .zip(likelihoods.iter())
159 .map(|(&c, &lp)| c * lp)
160 .sum();
161 prior + ll
162 })
163 .collect();
164 Ok(scores)
165 }
166
167 pub fn predict(&self, text: &str) -> Result<Option<String>> {
169 let scores = self.log_scores(text)?;
170 let best = scores
171 .iter()
172 .enumerate()
173 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
174 .map(|(i, _)| self.classes[i].clone());
175 Ok(best)
176 }
177
178 pub fn predict_proba(&self, text: &str) -> Result<Vec<(String, f64)>> {
180 let log_scores = self.log_scores(text)?;
181 let max_s = log_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
183 let exps: Vec<f64> = log_scores.iter().map(|&s| (s - max_s).exp()).collect();
184 let total: f64 = exps.iter().sum();
185 Ok(self
186 .classes
187 .iter()
188 .zip(exps.iter())
189 .map(|(cls, &e)| (cls.clone(), if total == 0.0 { 0.0 } else { e / total }))
190 .collect())
191 }
192
193 pub fn predict_batch(&self, texts: &[String]) -> Result<Vec<Option<String>>> {
195 texts.iter().map(|t| self.predict(t)).collect()
196 }
197
198 pub fn accuracy(&self, test_set: &[(String, String)]) -> Result<f64> {
200 if test_set.is_empty() {
201 return Ok(0.0);
202 }
203 let mut correct = 0usize;
204 for (text, gold) in test_set {
205 if let Ok(Some(pred)) = self.predict(text) {
206 if &pred == gold {
207 correct += 1;
208 }
209 }
210 }
211 Ok(correct as f64 / test_set.len() as f64)
212 }
213
214 pub fn class_names(&self) -> &[String] {
216 &self.classes
217 }
218}
219
220#[derive(Debug, Clone)]
230pub struct FastTextClassifier {
231 n_classes: usize,
232 classes: Vec<String>,
233 word_vectors: HashMap<String, Vec<f32>>,
234 weights: Vec<Vec<f32>>,
236 bias: Vec<f32>,
238 dim: usize,
239 fitted: bool,
240}
241
242impl FastTextClassifier {
243 pub fn new(n_classes: usize, dim: usize, classes: Vec<String>) -> Self {
249 assert_eq!(
250 classes.len(),
251 n_classes,
252 "classes.len() must equal n_classes"
253 );
254 FastTextClassifier {
255 n_classes,
256 classes,
257 word_vectors: HashMap::new(),
258 weights: vec![vec![0.0f32; n_classes]; dim],
259 bias: vec![0.0f32; n_classes],
260 dim,
261 fitted: false,
262 }
263 }
264
265 fn init_word_vec(word: &str, dim: usize) -> Vec<f32> {
270 let mut v = vec![0.0f32; dim];
271 for (i, c) in word.bytes().enumerate() {
272 let idx = i % dim;
273 v[idx] += (c as f32 - 64.0) / 128.0;
274 }
275 let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
277 if norm > 0.0 {
278 v.iter_mut().for_each(|x| *x /= norm);
279 }
280 v
281 }
282
283 fn mean_embedding(&self, tokens: &[String]) -> Vec<f32> {
285 let mut sum = vec![0.0f32; self.dim];
286 let mut count = 0usize;
287 for tok in tokens {
288 if let Some(vec) = self.word_vectors.get(tok.as_str()) {
289 for (s, &v) in sum.iter_mut().zip(vec.iter()) {
290 *s += v;
291 }
292 count += 1;
293 }
294 }
295 if count > 0 {
296 sum.iter_mut().for_each(|s| *s /= count as f32);
297 }
298 sum
299 }
300
301 fn forward(&self, embedding: &[f32]) -> Vec<f32> {
303 let mut logits = self.bias.clone();
304 for (d, &e) in embedding.iter().enumerate() {
305 for k in 0..self.n_classes {
306 logits[k] += e * self.weights[d][k];
307 }
308 }
309 logits
310 }
311
312 fn softmax(logits: &mut [f32]) {
314 let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
315 logits.iter_mut().for_each(|x| *x = (*x - max_l).exp());
316 let sum: f32 = logits.iter().sum();
317 if sum > 0.0 {
318 logits.iter_mut().for_each(|x| *x /= sum);
319 }
320 }
321
322 pub fn fit(&mut self, corpus: &[(Vec<String>, usize)], n_epochs: usize, lr: f32) -> Result<()> {
328 if corpus.is_empty() {
329 return Err(TextError::InvalidInput("corpus is empty".to_string()));
330 }
331
332 for (tokens, _) in corpus {
334 for tok in tokens {
335 self.word_vectors
336 .entry(tok.clone())
337 .or_insert_with(|| Self::init_word_vec(tok, self.dim));
338 }
339 }
340
341 for _epoch in 0..n_epochs {
342 for (tokens, gold_class) in corpus {
343 let gold_class = *gold_class;
344 if gold_class >= self.n_classes {
345 continue;
346 }
347 let emb = self.mean_embedding(tokens);
348 let mut probs = self.forward(&emb);
349 Self::softmax(&mut probs);
350
351 let mut grad = probs.clone();
353 grad[gold_class] -= 1.0;
354
355 for d in 0..self.dim {
357 for k in 0..self.n_classes {
358 self.weights[d][k] -= lr * grad[k] * emb[d];
359 }
360 }
361 for k in 0..self.n_classes {
362 self.bias[k] -= lr * grad[k];
363 }
364 }
365 }
366
367 self.fitted = true;
368 Ok(())
369 }
370
371 pub fn predict(&self, tokens: &[String]) -> usize {
373 let probs = self.predict_proba(tokens);
374 probs
375 .iter()
376 .enumerate()
377 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
378 .map(|(i, _)| i)
379 .unwrap_or(0)
380 }
381
382 pub fn predict_proba(&self, tokens: &[String]) -> Vec<f32> {
384 let emb = self.mean_embedding(tokens);
385 let mut logits = self.forward(&emb);
386 Self::softmax(&mut logits);
387 logits
388 }
389
390 pub fn class_names(&self) -> &[String] {
392 &self.classes
393 }
394
395 pub fn is_fitted(&self) -> bool {
397 self.fitted
398 }
399}
400
401#[derive(Debug, Clone)]
407pub struct CountVectorizer {
408 vocabulary: HashMap<String, usize>,
409 max_features: Option<usize>,
410 min_df: usize,
411 max_df_ratio: f64,
412 ngram_range: (usize, usize),
413 fitted: bool,
414}
415
416impl Default for CountVectorizer {
417 fn default() -> Self {
418 CountVectorizer {
419 vocabulary: HashMap::new(),
420 max_features: None,
421 min_df: 1,
422 max_df_ratio: 1.0,
423 ngram_range: (1, 1),
424 fitted: false,
425 }
426 }
427}
428
429impl CountVectorizer {
430 pub fn new() -> Self {
432 Self::default()
433 }
434
435 pub fn with_max_features(mut self, n: usize) -> Self {
437 self.max_features = Some(n);
438 self
439 }
440
441 pub fn with_ngram_range(mut self, min: usize, max: usize) -> Self {
443 self.ngram_range = (min, max);
444 self
445 }
446
447 pub fn with_min_df(mut self, min_df: usize) -> Self {
449 self.min_df = min_df;
450 self
451 }
452
453 pub fn with_max_df_ratio(mut self, ratio: f64) -> Self {
455 self.max_df_ratio = ratio;
456 self
457 }
458
459 fn ngrams(&self, tokens: &[String]) -> Vec<String> {
463 let (min_n, max_n) = self.ngram_range;
464 let mut grams = Vec::new();
465 for n in min_n..=max_n {
466 for window in tokens.windows(n) {
467 grams.push(window.join(" "));
468 }
469 }
470 grams
471 }
472
473 fn tokenize(text: &str) -> Vec<String> {
475 text.split(|c: char| !c.is_alphanumeric())
476 .filter(|s| !s.is_empty())
477 .map(|s| s.to_lowercase())
478 .collect()
479 }
480
481 pub fn fit(&mut self, corpus: &[String]) -> Result<()> {
485 if corpus.is_empty() {
486 return Err(TextError::InvalidInput("corpus is empty".to_string()));
487 }
488 let n_docs = corpus.len();
489
490 let mut df: HashMap<String, usize> = HashMap::new();
492 let mut term_freq: HashMap<String, usize> = HashMap::new();
493
494 for doc in corpus {
495 let tokens = Self::tokenize(doc);
496 let grams = self.ngrams(&tokens);
497 let unique: std::collections::HashSet<String> = grams.iter().cloned().collect();
498 for gram in unique {
499 *df.entry(gram.clone()).or_insert(0) += 1;
500 *term_freq.entry(gram).or_insert(0) += 1;
501 }
502 }
503
504 let max_df_count = (self.max_df_ratio * n_docs as f64).ceil() as usize;
506 let mut candidates: Vec<(String, usize)> = df
507 .into_iter()
508 .filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
509 .collect();
510
511 candidates.sort_by(|a, b| {
513 let fa = term_freq.get(&a.0).copied().unwrap_or(0);
514 let fb = term_freq.get(&b.0).copied().unwrap_or(0);
515 fb.cmp(&fa).then_with(|| a.0.cmp(&b.0))
516 });
517
518 if let Some(max_f) = self.max_features {
520 candidates.truncate(max_f);
521 }
522
523 self.vocabulary = candidates
525 .into_iter()
526 .enumerate()
527 .map(|(i, (gram, _))| (gram, i))
528 .collect();
529
530 self.fitted = true;
531 Ok(())
532 }
533
534 pub fn transform(&self, texts: &[String]) -> Result<Vec<Vec<f64>>> {
536 if !self.fitted {
537 return Err(TextError::ModelNotFitted(
538 "CountVectorizer is not fitted".to_string(),
539 ));
540 }
541 let v = self.vocabulary.len();
542 texts
543 .iter()
544 .map(|text| {
545 let tokens = Self::tokenize(text);
546 let grams = self.ngrams(&tokens);
547 let mut counts = vec![0.0f64; v];
548 for gram in grams {
549 if let Some(&idx) = self.vocabulary.get(&gram) {
550 counts[idx] += 1.0;
551 }
552 }
553 Ok(counts)
554 })
555 .collect()
556 }
557
558 pub fn fit_transform(&mut self, corpus: &[String]) -> Result<Vec<Vec<f64>>> {
560 self.fit(corpus)?;
561 self.transform(corpus)
562 }
563
564 pub fn vocabulary_size(&self) -> usize {
566 self.vocabulary.len()
567 }
568
569 pub fn vocabulary(&self) -> &HashMap<String, usize> {
571 &self.vocabulary
572 }
573}
574
575#[derive(Debug, Clone)]
586pub struct TfidfTransformer {
587 pub idf: Vec<f64>,
589 pub smooth_idf: bool,
591 fitted: bool,
592}
593
594impl TfidfTransformer {
595 pub fn new(smooth_idf: bool) -> Self {
598 TfidfTransformer {
599 idf: Vec::new(),
600 smooth_idf,
601 fitted: false,
602 }
603 }
604
605 pub fn fit(&mut self, count_matrix: &[Vec<f64>]) -> Result<()> {
607 if count_matrix.is_empty() {
608 return Err(TextError::InvalidInput("count_matrix is empty".to_string()));
609 }
610 let n_docs = count_matrix.len() as f64;
611 let n_features = count_matrix[0].len();
612
613 let mut df = vec![0.0f64; n_features];
614 for row in count_matrix {
615 for (j, &c) in row.iter().enumerate() {
616 if c > 0.0 {
617 df[j] += 1.0;
618 }
619 }
620 }
621
622 self.idf = if self.smooth_idf {
623 df.iter()
624 .map(|&d| ((1.0 + n_docs) / (1.0 + d)).ln() + 1.0)
625 .collect()
626 } else {
627 df.iter()
628 .map(|&d| {
629 if d == 0.0 {
630 0.0
631 } else {
632 (n_docs / d).ln() + 1.0
633 }
634 })
635 .collect()
636 };
637
638 self.fitted = true;
639 Ok(())
640 }
641
642 pub fn transform(&self, count_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
644 if !self.fitted {
645 return Err(TextError::ModelNotFitted(
646 "TfidfTransformer is not fitted".to_string(),
647 ));
648 }
649 count_matrix
650 .iter()
651 .map(|row| {
652 let mut tfidf: Vec<f64> = row
653 .iter()
654 .zip(self.idf.iter())
655 .map(|(&c, &idf)| c * idf)
656 .collect();
657 let norm: f64 = tfidf.iter().map(|&x| x * x).sum::<f64>().sqrt();
659 if norm > 0.0 {
660 tfidf.iter_mut().for_each(|x| *x /= norm);
661 }
662 Ok(tfidf)
663 })
664 .collect()
665 }
666
667 pub fn fit_transform(&mut self, count_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
669 self.fit(count_matrix)?;
670 self.transform(count_matrix)
671 }
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    /// Six short documents, two for each of the labels sports, politics and tech.
    fn news_corpus() -> Vec<(String, String)> {
        const DOCS: [(&str, &str); 6] = [
            ("football game soccer ball", "sports"),
            ("basketball players team score", "sports"),
            ("election president vote campaign", "politics"),
            ("senate congress legislation bill", "politics"),
            ("python rust programming language", "tech"),
            ("software compiler code debug", "tech"),
        ];
        DOCS.iter()
            .map(|&(text, label)| (text.to_string(), label.to_string()))
            .collect()
    }

    /// Naive Bayes model trained on `news_corpus` with `alpha = 1.0`.
    fn fitted_nb() -> NaiveBayesClassifier {
        let mut model = NaiveBayesClassifier::new();
        model.fit(&news_corpus(), 1.0).expect("fit failed");
        model
    }

    /// Converts string slices into the owned strings the APIs expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    // --- NaiveBayesClassifier ---

    #[test]
    fn test_nb_fit_predict() {
        let model = fitted_nb();
        let predicted = model
            .predict("soccer football game")
            .expect("predict failed");
        assert_eq!(predicted.as_deref(), Some("sports"));
    }

    #[test]
    fn test_nb_predict_proba_sums_to_one() {
        let model = fitted_nb();
        let total = model
            .predict_proba("vote election")
            .expect("proba failed")
            .iter()
            .fold(0.0f64, |acc, (_, p)| acc + p);
        assert!(
            (total - 1.0).abs() < 1e-9,
            "probabilities should sum to 1, got {}",
            total
        );
    }

    #[test]
    fn test_nb_accuracy() {
        let model = fitted_nb();
        let acc = model.accuracy(&news_corpus()).expect("accuracy failed");
        assert!(acc >= 0.5, "Expected accuracy >= 0.5, got {}", acc);
    }

    #[test]
    fn test_nb_class_names() {
        let model = fitted_nb();
        let names = model.class_names();
        for expected in ["sports", "tech"] {
            assert!(names.iter().any(|name| name == expected));
        }
    }

    #[test]
    fn test_nb_not_fitted_error() {
        assert!(NaiveBayesClassifier::new().predict("test").is_err());
    }

    #[test]
    fn test_nb_batch_predict() {
        let model = fitted_nb();
        let batch = owned(&["soccer game", "code compiler"]);
        let preds = model.predict_batch(&batch).expect("batch predict failed");
        assert_eq!(preds.len(), 2);
        assert!(preds[0].is_some());
    }

    // --- FastTextClassifier ---

    #[test]
    fn test_fasttext_predict_without_training() {
        let model = FastTextClassifier::new(2, 16, owned(&["sports", "tech"]));
        let class_index = model.predict(&owned(&["soccer", "game"]));
        assert!(class_index < 2);
    }

    #[test]
    fn test_fasttext_fit_and_predict() {
        let mut model = FastTextClassifier::new(2, 8, owned(&["pos", "neg"]));
        let examples: Vec<(Vec<String>, usize)> = vec![
            (owned(&["good", "great"]), 0),
            (owned(&["excellent", "wonderful"]), 0),
            (owned(&["bad", "terrible"]), 1),
            (owned(&["awful", "horrible"]), 1),
        ];
        model.fit(&examples, 10, 0.1).expect("fit failed");
        assert!(model.is_fitted());

        let probs = model.predict_proba(&owned(&["good"]));
        assert_eq!(probs.len(), 2);
        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    // --- CountVectorizer ---

    #[test]
    fn test_count_vectorizer_basic() {
        let corpus = owned(&["hello world", "hello rust", "world rust"]);
        let mut vectorizer = CountVectorizer::new();
        let matrix = vectorizer
            .fit_transform(&corpus)
            .expect("fit_transform failed");
        assert_eq!(matrix.len(), 3);
        assert!(vectorizer.vocabulary_size() > 0);
    }

    #[test]
    fn test_count_vectorizer_ngram() {
        let corpus = owned(&["the quick fox", "the lazy dog"]);
        let mut vectorizer = CountVectorizer::new().with_ngram_range(1, 2);
        vectorizer.fit(&corpus).expect("fit failed");
        assert!(vectorizer.vocabulary_size() > 3);
    }

    #[test]
    fn test_count_vectorizer_max_features() {
        let corpus = owned(&["a b c d e f", "a b c d e f"]);
        let mut vectorizer = CountVectorizer::new().with_max_features(2);
        vectorizer.fit(&corpus).expect("fit failed");
        assert_eq!(vectorizer.vocabulary_size(), 2);
    }

    #[test]
    fn test_count_vectorizer_not_fitted_error() {
        let vectorizer = CountVectorizer::new();
        assert!(vectorizer.transform(&owned(&["hello"])).is_err());
    }

    // --- TfidfTransformer ---

    #[test]
    fn test_tfidf_transformer_l2_norm() {
        let counts = vec![vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 1.0]];
        let mut transformer = TfidfTransformer::new(true);
        let weighted = transformer
            .fit_transform(&counts)
            .expect("fit_transform failed");
        for row in weighted.iter() {
            let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-9, "norm = {}", norm);
        }
    }

    #[test]
    fn test_tfidf_transformer_not_fitted_error() {
        let transformer = TfidfTransformer::new(true);
        assert!(transformer.transform(&[vec![1.0, 2.0]]).is_err());
    }

    #[test]
    fn test_tfidf_smooth_vs_no_smooth() {
        let counts = vec![vec![1.0, 2.0], vec![3.0, 0.0]];
        let mut smoothed = TfidfTransformer::new(true);
        let mut unsmoothed = TfidfTransformer::new(false);
        smoothed.fit(&counts).expect("fit");
        unsmoothed.fit(&counts).expect("fit");
        assert_ne!(smoothed.idf, unsmoothed.idf);
    }
}