1use crate::error::{Result, TextError};
15use crate::tokenize::{Tokenizer, WordTokenizer};
16use crate::vectorize::{TfidfVectorizer, Vectorizer};
17use scirs2_core::ndarray::{Array1, Array2, Axis};
18use scirs2_core::random::prelude::*;
19use scirs2_core::random::seq::SliceRandom;
20use scirs2_core::random::SeedableRng;
21use std::collections::{HashMap, HashSet};
22
/// Frequency-based feature selector for document-term matrices.
///
/// Keeps only the columns whose document frequency falls within
/// `[min_df, max_df]`; the bounds are fractions of the corpus by default,
/// or absolute document counts when `use_counts` is enabled.
#[derive(Debug, Clone)]
pub struct TextFeatureSelector {
    // Lower document-frequency bound (fraction, or count if `use_counts`).
    min_df: f64,
    // Upper document-frequency bound (fraction, or count if `use_counts`).
    max_df: f64,
    // Interpret `min_df`/`max_df` as absolute document counts.
    use_counts: bool,
    // Column indices retained by `fit`; `None` until fitted.
    selected_features: Option<Vec<usize>>,
}
39
40impl Default for TextFeatureSelector {
41 fn default() -> Self {
42 Self {
43 min_df: 0.0,
44 max_df: 1.0,
45 use_counts: false,
46 selected_features: None,
47 }
48 }
49}
50
51impl TextFeatureSelector {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
59 if mindf < 0.0 {
60 return Err(TextError::InvalidInput(
61 "min_df must be non-negative".to_string(),
62 ));
63 }
64 self.min_df = mindf;
65 Ok(self)
66 }
67
68 pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
70 if !(0.0..=1.0).contains(&maxdf) {
71 return Err(TextError::InvalidInput(
72 "max_df must be between 0 and 1 for fractions".to_string(),
73 ));
74 }
75 self.max_df = maxdf;
76 Ok(self)
77 }
78
79 pub fn set_max_features(self, maxfeatures: f64) -> Result<Self> {
81 self.set_max_df(maxfeatures)
82 }
83
84 pub fn use_counts(mut self, usecounts: bool) -> Self {
86 self.use_counts = usecounts;
87 self
88 }
89
90 pub fn fit(&mut self, x: &Array2<f64>) -> Result<&mut Self> {
92 let n_samples = x.nrows();
93 let n_features = x.ncols();
94
95 let mut document_frequencies = vec![0; n_features];
96
97 for sample in x.axis_iter(Axis(0)) {
98 for (feature_idx, &value) in sample.iter().enumerate() {
99 if value > 0.0 {
100 document_frequencies[feature_idx] += 1;
101 }
102 }
103 }
104
105 let min_count = if self.use_counts {
106 self.min_df
107 } else {
108 self.min_df * n_samples as f64
109 };
110
111 let max_count = if self.use_counts {
112 self.max_df
113 } else {
114 self.max_df * n_samples as f64
115 };
116
117 let mut selected_features = Vec::new();
118 for (idx, &df) in document_frequencies.iter().enumerate() {
119 let df_f64 = df as f64;
120 if df_f64 >= min_count && df_f64 <= max_count {
121 selected_features.push(idx);
122 }
123 }
124
125 self.selected_features = Some(selected_features);
126 Ok(self)
127 }
128
129 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
131 let selected_features = self
132 .selected_features
133 .as_ref()
134 .ok_or_else(|| TextError::ModelNotFitted("Feature selector not fitted".to_string()))?;
135
136 if selected_features.is_empty() {
137 return Err(TextError::InvalidInput(
138 "No features selected. Try adjusting min_df and max_df".to_string(),
139 ));
140 }
141
142 let n_samples = x.nrows();
143 let n_selected = selected_features.len();
144
145 let mut result = Array2::zeros((n_samples, n_selected));
146
147 for (i, row) in x.axis_iter(Axis(0)).enumerate() {
148 for (j, &feature_idx) in selected_features.iter().enumerate() {
149 result[[i, j]] = row[feature_idx];
150 }
151 }
152
153 Ok(result)
154 }
155
156 pub fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
158 self.fit(x)?;
159 self.transform(x)
160 }
161
162 pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
164 self.selected_features.as_ref()
165 }
166}
167
/// Stateless calculator for text-classification evaluation metrics
/// (accuracy, precision, recall, F1).
#[derive(Debug, Clone)]
pub struct TextClassificationMetrics;
173
impl Default for TextClassificationMetrics {
    /// Unit struct: `default()` is equivalent to `new()`.
    fn default() -> Self {
        Self
    }
}
179
180impl TextClassificationMetrics {
181 pub fn new() -> Self {
183 Self
184 }
185
186 pub fn precision<T>(
188 &self,
189 predictions: &[T],
190 true_labels: &[T],
191 class_idx: Option<T>,
192 ) -> Result<f64>
193 where
194 T: PartialEq + Copy + Default,
195 {
196 let positive_class = class_idx.unwrap_or_default();
197
198 if predictions.len() != true_labels.len() {
199 return Err(TextError::InvalidInput(
200 "Predictions and labels must have the same length".to_string(),
201 ));
202 }
203
204 let mut true_positives = 0;
205 let mut predicted_positives = 0;
206
207 for i in 0..predictions.len() {
208 if predictions[i] == positive_class {
209 predicted_positives += 1;
210 if true_labels[i] == positive_class {
211 true_positives += 1;
212 }
213 }
214 }
215
216 if predicted_positives == 0 {
217 return Ok(0.0);
218 }
219
220 Ok(true_positives as f64 / predicted_positives as f64)
221 }
222
223 pub fn recall<T>(
225 &self,
226 predictions: &[T],
227 true_labels: &[T],
228 class_idx: Option<T>,
229 ) -> Result<f64>
230 where
231 T: PartialEq + Copy + Default,
232 {
233 let positive_class = class_idx.unwrap_or_default();
234
235 if predictions.len() != true_labels.len() {
236 return Err(TextError::InvalidInput(
237 "Predictions and labels must have the same length".to_string(),
238 ));
239 }
240
241 let mut true_positives = 0;
242 let mut actual_positives = 0;
243
244 for i in 0..predictions.len() {
245 if true_labels[i] == positive_class {
246 actual_positives += 1;
247 if predictions[i] == positive_class {
248 true_positives += 1;
249 }
250 }
251 }
252
253 if actual_positives == 0 {
254 return Ok(0.0);
255 }
256
257 Ok(true_positives as f64 / actual_positives as f64)
258 }
259
260 pub fn f1_score<T>(
262 &self,
263 predictions: &[T],
264 true_labels: &[T],
265 class_idx: Option<T>,
266 ) -> Result<f64>
267 where
268 T: PartialEq + Copy + Default,
269 {
270 let precision = self.precision(predictions, true_labels, class_idx)?;
271 let recall = self.recall(predictions, true_labels, class_idx)?;
272
273 if precision + recall == 0.0 {
274 return Ok(0.0);
275 }
276
277 Ok(2.0 * precision * recall / (precision + recall))
278 }
279
280 pub fn accuracy<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<f64>
282 where
283 T: PartialEq,
284 {
285 if predictions.len() != truelabels.len() {
286 return Err(TextError::InvalidInput(
287 "Predictions and labels must have the same length".to_string(),
288 ));
289 }
290
291 if predictions.is_empty() {
292 return Err(TextError::InvalidInput(
293 "Cannot calculate accuracy for empty arrays".to_string(),
294 ));
295 }
296
297 let correct = predictions
298 .iter()
299 .zip(truelabels.iter())
300 .filter(|(pred, true_label)| pred == true_label)
301 .count();
302
303 Ok(correct as f64 / predictions.len() as f64)
304 }
305
306 pub fn binary_metrics<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<(f64, f64, f64)>
308 where
309 T: PartialEq + Copy + Default + PartialEq<usize>,
310 {
311 if predictions.len() != truelabels.len() {
312 return Err(TextError::InvalidInput(
313 "Predictions and labels must have the same length".to_string(),
314 ));
315 }
316
317 let mut tp = 0;
318 let mut fp = 0;
319 let mut fn_ = 0;
320
321 for (pred, true_label) in predictions.iter().zip(truelabels.iter()) {
322 if *pred == 1 && *true_label == 1 {
323 tp += 1;
324 } else if *pred == 1 && *true_label == 0 {
325 fp += 1;
326 } else if *pred == 0 && *true_label == 1 {
327 fn_ += 1;
328 }
329 }
330
331 let precision = if tp + fp > 0 {
332 tp as f64 / (tp + fp) as f64
333 } else {
334 0.0
335 };
336
337 let recall = if tp + fn_ > 0 {
338 tp as f64 / (tp + fn_) as f64
339 } else {
340 0.0
341 };
342
343 let f1 = if precision + recall > 0.0 {
344 2.0 * precision * recall / (precision + recall)
345 } else {
346 0.0
347 };
348
349 Ok((precision, recall, f1))
350 }
351}
352
/// A labeled text corpus: parallel vectors of documents and labels.
#[derive(Debug, Clone)]
pub struct TextDataset {
    /// Raw document texts.
    pub texts: Vec<String>,
    /// One label per document (same length as `texts`).
    pub labels: Vec<String>,
    // Lazily-built label -> integer index map; `None` until
    // `build_label_index` is called.
    label_index: Option<HashMap<String, usize>>,
}
365
366impl TextDataset {
367 pub fn new(texts: Vec<String>, labels: Vec<String>) -> Result<Self> {
369 if texts.len() != labels.len() {
370 return Err(TextError::InvalidInput(
371 "Texts and labels must have the same length".to_string(),
372 ));
373 }
374
375 Ok(Self {
376 texts,
377 labels,
378 label_index: None,
379 })
380 }
381
382 pub fn len(&self) -> usize {
384 self.texts.len()
385 }
386
387 pub fn is_empty(&self) -> bool {
389 self.texts.is_empty()
390 }
391
392 pub fn unique_labels(&self) -> Vec<String> {
394 let mut unique = HashSet::new();
395 for label in &self.labels {
396 unique.insert(label.clone());
397 }
398 unique.into_iter().collect()
399 }
400
401 pub fn build_label_index(&mut self) -> Result<&mut Self> {
403 let mut index = HashMap::new();
404 let unique_labels = self.unique_labels();
405
406 for (i, label) in unique_labels.iter().enumerate() {
407 index.insert(label.clone(), i);
408 }
409
410 self.label_index = Some(index);
411 Ok(self)
412 }
413
414 pub fn get_label_indices(&self) -> Result<Vec<usize>> {
416 let index = self
417 .label_index
418 .as_ref()
419 .ok_or_else(|| TextError::ModelNotFitted("Label index not built".to_string()))?;
420
421 self.labels
422 .iter()
423 .map(|label| {
424 index
425 .get(label)
426 .copied()
427 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
428 })
429 .collect()
430 }
431
432 pub fn train_test_split(
434 &self,
435 test_size: f64,
436 random_seed: Option<u64>,
437 ) -> Result<(Self, Self)> {
438 if test_size <= 0.0 || test_size >= 1.0 {
439 return Err(TextError::InvalidInput(
440 "test_size must be between 0 and 1".to_string(),
441 ));
442 }
443
444 if self.is_empty() {
445 return Err(TextError::InvalidInput("Dataset is empty".to_string()));
446 }
447
448 let mut indices: Vec<usize> = (0..self.len()).collect();
449
450 if let Some(seed) = random_seed {
451 let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
452 indices.shuffle(&mut rng);
453 } else {
454 let mut rng = scirs2_core::random::rng();
455 indices.shuffle(&mut rng);
456 }
457
458 let test_count = (self.len() as f64 * test_size).ceil() as usize;
459 let test_indices = indices[0..test_count].to_vec();
460 let train_indices = indices[test_count..].to_vec();
461
462 let train_texts = train_indices
463 .iter()
464 .map(|&i| self.texts[i].clone())
465 .collect();
466 let train_labels = train_indices
467 .iter()
468 .map(|&i| self.labels[i].clone())
469 .collect();
470 let test_texts = test_indices
471 .iter()
472 .map(|&i| self.texts[i].clone())
473 .collect();
474 let test_labels = test_indices
475 .iter()
476 .map(|&i| self.labels[i].clone())
477 .collect();
478
479 let mut train_dataset = Self::new(train_texts, train_labels)?;
480 let mut test_dataset = Self::new(test_texts, test_labels)?;
481
482 if self.label_index.is_some() {
483 train_dataset.build_label_index()?;
484 test_dataset.build_label_index()?;
485 }
486
487 Ok((train_dataset, test_dataset))
488 }
489}
490
/// Feature-extraction pipeline: TF-IDF vectorization optionally followed
/// by frequency-based feature selection.
pub struct TextClassificationPipeline {
    // Converts raw text into TF-IDF feature vectors.
    vectorizer: TfidfVectorizer,
    // Optional column filter applied after vectorization.
    feature_selector: Option<TextFeatureSelector>,
}
500
501impl TextClassificationPipeline {
502 pub fn with_tfidf() -> Self {
504 Self::new(TfidfVectorizer::default())
505 }
506
507 pub fn new(vectorizer: TfidfVectorizer) -> Self {
509 Self {
510 vectorizer,
511 feature_selector: None,
512 }
513 }
514
515 pub fn with_feature_selector(mut self, selector: TextFeatureSelector) -> Self {
517 self.feature_selector = Some(selector);
518 self
519 }
520
521 pub fn fit(&mut self, dataset: &TextDataset) -> Result<&mut Self> {
523 let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
524 self.vectorizer.fit(&texts)?;
525 Ok(self)
526 }
527
528 pub fn transform(&self, dataset: &TextDataset) -> Result<Array2<f64>> {
530 let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
531 let mut features = self.vectorizer.transform_batch(&texts)?;
532
533 if let Some(selector) = &self.feature_selector {
534 features = selector.transform(&features)?;
535 }
536
537 Ok(features)
538 }
539
540 pub fn fit_transform(&mut self, dataset: &TextDataset) -> Result<Array2<f64>> {
542 self.fit(dataset)?;
543 self.transform(dataset)
544 }
545}
546
/// Multinomial Naive Bayes classifier over count-like feature vectors.
#[derive(Debug, Clone)]
pub struct MultinomialNaiveBayes {
    // Per-class smoothed log P(feature | class), keyed by class label.
    feature_log_probs: HashMap<String, Vec<f64>>,
    // Per-class log prior probabilities.
    class_log_priors: HashMap<String, f64>,
    // Number of feature columns seen during `fit`.
    n_features: usize,
    // Laplace (additive) smoothing factor.
    alpha: f64,
    // Sorted class labels observed during `fit`.
    classes: Vec<String>,
}
566
567impl MultinomialNaiveBayes {
568 pub fn new(alpha: f64) -> Self {
570 Self {
571 feature_log_probs: HashMap::new(),
572 class_log_priors: HashMap::new(),
573 n_features: 0,
574 alpha,
575 classes: Vec::new(),
576 }
577 }
578
579 pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
585 if features.nrows() != labels.len() {
586 return Err(TextError::InvalidInput(
587 "Features and labels must have the same number of rows".into(),
588 ));
589 }
590
591 let n_samples = features.nrows();
592 self.n_features = features.ncols();
593
594 let mut class_set = HashSet::new();
596 for label in labels {
597 class_set.insert(label.clone());
598 }
599 self.classes = class_set.into_iter().collect();
600 self.classes.sort();
601
602 for class in &self.classes {
604 let class_indices: Vec<usize> = labels
606 .iter()
607 .enumerate()
608 .filter(|(_, l)| *l == class)
609 .map(|(i, _)| i)
610 .collect();
611
612 let class_count = class_indices.len();
613
614 let log_prior = (class_count as f64 / n_samples as f64).ln();
616 self.class_log_priors.insert(class.clone(), log_prior);
617
618 let mut feature_sums = vec![0.0; self.n_features];
620 for &idx in &class_indices {
621 for j in 0..self.n_features {
622 feature_sums[j] += features[[idx, j]];
623 }
624 }
625
626 let total: f64 = feature_sums.iter().sum::<f64>() + self.alpha * self.n_features as f64;
628
629 let log_probs: Vec<f64> = feature_sums
631 .iter()
632 .map(|&count| ((count + self.alpha) / total).ln())
633 .collect();
634
635 self.feature_log_probs.insert(class.clone(), log_probs);
636 }
637
638 Ok(())
639 }
640
641 pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
643 let mut predictions = Vec::with_capacity(features.nrows());
644
645 for row in features.axis_iter(Axis(0)) {
646 let (label, _) = self.predict_single(&row.to_owned())?;
647 predictions.push(label);
648 }
649
650 Ok(predictions)
651 }
652
653 fn predict_single(&self, features: &Array1<f64>) -> Result<(String, f64)> {
655 if self.classes.is_empty() {
656 return Err(TextError::ModelNotFitted("Classifier not trained".into()));
657 }
658
659 let mut best_class = String::new();
660 let mut best_score = f64::NEG_INFINITY;
661
662 for class in &self.classes {
663 let log_prior = self
664 .class_log_priors
665 .get(class)
666 .copied()
667 .unwrap_or(f64::NEG_INFINITY);
668
669 let log_probs = self
670 .feature_log_probs
671 .get(class)
672 .ok_or_else(|| TextError::RuntimeError("Missing feature probs".into()))?;
673
674 let log_likelihood: f64 = features
675 .iter()
676 .zip(log_probs.iter())
677 .map(|(&feat, &log_p)| feat * log_p)
678 .sum();
679
680 let score = log_prior + log_likelihood;
681 if score > best_score {
682 best_score = score;
683 best_class = class.clone();
684 }
685 }
686
687 Ok((best_class, best_score))
688 }
689}
690
/// Bernoulli Naive Bayes classifier over binarized feature vectors.
#[derive(Debug, Clone)]
pub struct BernoulliNaiveBayes {
    // Per-class log P(feature present | class).
    feature_log_probs: HashMap<String, Vec<f64>>,
    // Per-class log P(feature absent | class).
    feature_log_neg_probs: HashMap<String, Vec<f64>>,
    // Per-class log prior probabilities.
    class_log_priors: HashMap<String, f64>,
    // Number of feature columns seen during `fit`.
    n_features: usize,
    // Laplace smoothing factor.
    alpha: f64,
    // Values strictly above this threshold count as "present".
    binarize_threshold: f64,
    // Sorted class labels observed during `fit`.
    classes: Vec<String>,
}
714
impl BernoulliNaiveBayes {
    /// Creates an untrained classifier with smoothing factor `alpha` and a
    /// binarization threshold of 0.0 (any positive value counts as present).
    pub fn new(alpha: f64) -> Self {
        Self {
            feature_log_probs: HashMap::new(),
            feature_log_neg_probs: HashMap::new(),
            class_log_priors: HashMap::new(),
            n_features: 0,
            alpha,
            binarize_threshold: 0.0,
            classes: Vec::new(),
        }
    }

    /// Builder-style setter: feature values strictly greater than
    /// `threshold` are treated as present.
    pub fn with_binarize_threshold(mut self, threshold: f64) -> Self {
        self.binarize_threshold = threshold;
        self
    }

    /// Fits per-class priors and Bernoulli presence/absence log-probabilities.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `features` and `labels` differ in length.
    pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
        if features.nrows() != labels.len() {
            return Err(TextError::InvalidInput(
                "Features and labels must have the same number of rows".into(),
            ));
        }

        let n_samples = features.nrows();
        self.n_features = features.ncols();

        // Distinct class labels, sorted for determinism.
        let mut class_set = HashSet::new();
        for label in labels {
            class_set.insert(label.clone());
        }
        self.classes = class_set.into_iter().collect();
        self.classes.sort();

        for class in &self.classes {
            // Row indices belonging to this class.
            let class_indices: Vec<usize> = labels
                .iter()
                .enumerate()
                .filter(|(_, l)| *l == class)
                .map(|(i, _)| i)
                .collect();

            let class_count = class_indices.len() as f64;

            let log_prior = (class_count / n_samples as f64).ln();
            self.class_log_priors.insert(class.clone(), log_prior);

            // Per feature: in how many of the class's documents it is
            // "present" (value above the binarization threshold).
            let mut feature_present = vec![0.0; self.n_features];
            for &idx in &class_indices {
                for j in 0..self.n_features {
                    if features[[idx, j]] > self.binarize_threshold {
                        feature_present[j] += 1.0;
                    }
                }
            }

            // Smoothed log P(present | class); the 2*alpha in the
            // denominator covers both Bernoulli outcomes.
            let log_probs: Vec<f64> = feature_present
                .iter()
                .map(|&count| ((count + self.alpha) / (class_count + 2.0 * self.alpha)).ln())
                .collect();

            // Complementary log P(absent | class).
            let log_neg_probs: Vec<f64> = feature_present
                .iter()
                .map(|&count| {
                    ((class_count - count + self.alpha) / (class_count + 2.0 * self.alpha)).ln()
                })
                .collect();

            self.feature_log_probs.insert(class.clone(), log_probs);
            self.feature_log_neg_probs
                .insert(class.clone(), log_neg_probs);
        }

        Ok(())
    }

    /// Predicts a class label for each row of `features`.
    pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
        let mut predictions = Vec::with_capacity(features.nrows());

        for row in features.axis_iter(Axis(0)) {
            let label = self.predict_single(&row.to_owned())?;
            predictions.push(label);
        }

        Ok(predictions)
    }

    /// Scores one binarized sample against every class and returns the
    /// maximum-a-posteriori label.
    ///
    /// # Errors
    /// Returns `ModelNotFitted` when called before `fit`.
    fn predict_single(&self, features: &Array1<f64>) -> Result<String> {
        if self.classes.is_empty() {
            return Err(TextError::ModelNotFitted("Classifier not trained".into()));
        }

        let mut best_class = String::new();
        let mut best_score = f64::NEG_INFINITY;

        for class in &self.classes {
            let log_prior = self
                .class_log_priors
                .get(class)
                .copied()
                .unwrap_or(f64::NEG_INFINITY);

            let log_probs = self
                .feature_log_probs
                .get(class)
                .ok_or_else(|| TextError::RuntimeError("Missing probs".into()))?;
            let log_neg_probs = self
                .feature_log_neg_probs
                .get(class)
                .ok_or_else(|| TextError::RuntimeError("Missing neg probs".into()))?;

            // Each feature contributes its "present" or "absent" term.
            let mut log_likelihood = 0.0;
            for j in 0..self.n_features {
                if features[j] > self.binarize_threshold {
                    log_likelihood += log_probs[j];
                } else {
                    log_likelihood += log_neg_probs[j];
                }
            }

            let score = log_prior + log_likelihood;
            if score > best_score {
                best_score = score;
                best_class = class.clone();
            }
        }

        Ok(best_class)
    }
}
852
/// k-nearest-neighbour classifier using cosine similarity over stored
/// training vectors (e.g. TF-IDF rows).
pub struct TfidfCosineClassifier {
    // Training feature matrix; `None` until `fit` is called.
    train_vectors: Option<Array2<f64>>,
    // Labels aligned with the rows of `train_vectors`.
    train_labels: Vec<String>,
    // Number of neighbours consulted per prediction.
    k: usize,
}
867
868impl TfidfCosineClassifier {
869 pub fn new(k: usize) -> Self {
871 Self {
872 train_vectors: None,
873 train_labels: Vec::new(),
874 k,
875 }
876 }
877
878 pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
880 if features.nrows() != labels.len() {
881 return Err(TextError::InvalidInput(
882 "Features and labels must have the same number of rows".into(),
883 ));
884 }
885
886 self.train_vectors = Some(features.clone());
887 self.train_labels = labels.to_vec();
888 Ok(())
889 }
890
891 pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
893 let train_vectors = self
894 .train_vectors
895 .as_ref()
896 .ok_or_else(|| TextError::ModelNotFitted("Classifier not trained".into()))?;
897
898 let mut predictions = Vec::with_capacity(features.nrows());
899
900 for row in features.axis_iter(Axis(0)) {
901 let query = row.to_owned();
902
903 let mut similarities: Vec<(usize, f64)> = Vec::with_capacity(train_vectors.nrows());
905
906 for (idx, train_row) in train_vectors.axis_iter(Axis(0)).enumerate() {
907 let sim = cosine_similarity(&query, &train_row.to_owned());
908 similarities.push((idx, sim));
909 }
910
911 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
913
914 let mut class_votes: HashMap<&str, usize> = HashMap::new();
916 let k = self.k.min(similarities.len());
917
918 for &(idx, _) in similarities.iter().take(k) {
919 *class_votes.entry(&self.train_labels[idx]).or_insert(0) += 1;
920 }
921
922 let best_class = class_votes
923 .iter()
924 .max_by_key(|(_, &count)| count)
925 .map(|(label, _)| label.to_string())
926 .unwrap_or_default();
927
928 predictions.push(best_class);
929 }
930
931 Ok(predictions)
932 }
933}
934
935fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
937 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
938 let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
939 let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
940
941 if norm_a > 0.0 && norm_b > 0.0 {
942 dot / (norm_a * norm_b)
943 } else {
944 0.0
945 }
946}
947
/// Hashing-trick vectorizer: tokens are hashed into a fixed number of
/// buckets with a signed contribution, so no vocabulary is stored.
pub struct FeatureHasher {
    // Number of hash buckets (output dimensionality).
    n_features: usize,
    // Tokenizer used to split input text into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
}
961
impl std::fmt::Debug for FeatureHasher {
    // Manual impl reporting only `n_features`; presumably the boxed
    // `dyn Tokenizer` does not implement Debug — confirm against the trait.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FeatureHasher")
            .field("n_features", &self.n_features)
            .finish()
    }
}
969
970impl FeatureHasher {
971 pub fn new(n_features: usize) -> Self {
973 Self {
974 n_features,
975 tokenizer: Box::new(WordTokenizer::default()),
976 }
977 }
978
979 fn hash_feature(&self, token: &str) -> usize {
981 let mut hash: u64 = 2166136261;
982 for byte in token.bytes() {
983 hash ^= u64::from(byte);
984 hash = hash.wrapping_mul(16777619);
985 }
986 (hash % (self.n_features as u64)) as usize
987 }
988
989 fn hash_sign(&self, token: &str) -> f64 {
991 let mut hash: u64 = 84696351;
992 for byte in token.bytes() {
993 hash ^= u64::from(byte);
994 hash = hash.wrapping_mul(16777619);
995 }
996 if hash.is_multiple_of(2) {
997 1.0
998 } else {
999 -1.0
1000 }
1001 }
1002
1003 pub fn transform_text(&self, text: &str) -> Result<Array1<f64>> {
1005 let tokens = self.tokenizer.tokenize(text)?;
1006 let mut features = Array1::zeros(self.n_features);
1007
1008 for token in &tokens {
1009 let idx = self.hash_feature(&token.to_lowercase());
1010 let sign = self.hash_sign(&token.to_lowercase());
1011 features[idx] += sign;
1012 }
1013
1014 Ok(features)
1015 }
1016
1017 pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
1019 let mut matrix = Array2::zeros((texts.len(), self.n_features));
1020
1021 for (i, &text) in texts.iter().enumerate() {
1022 let features = self.transform_text(text)?;
1023 for j in 0..self.n_features {
1024 matrix[[i, j]] = features[j];
1025 }
1026 }
1027
1028 Ok(matrix)
1029 }
1030
1031 pub fn num_features(&self) -> usize {
1033 self.n_features
1034 }
1035}
1036
/// Prediction output for one sample in multi-label classification.
#[derive(Debug, Clone)]
pub struct MultiLabelPrediction {
    /// Labels predicted as applicable to the sample.
    pub labels: Vec<String>,
    /// Per-label decision scores (currently hard 1.0/0.0 votes).
    pub scores: HashMap<String, f64>,
}
1047
/// One-vs-rest multi-label classifier built from per-label binary
/// Multinomial Naive Bayes models.
#[derive(Debug, Clone)]
pub struct MultiLabelClassifier {
    // One binary classifier per known label.
    classifiers: HashMap<String, MultinomialNaiveBayes>,
    // Score threshold. NOTE(review): not consulted by the current
    // `predict` implementation, which uses hard per-label decisions.
    threshold: f64,
    // Sorted list of all labels seen during `fit`.
    all_labels: Vec<String>,
}
1061
1062impl MultiLabelClassifier {
1063 pub fn new(threshold: f64) -> Self {
1065 Self {
1066 classifiers: HashMap::new(),
1067 threshold,
1068 all_labels: Vec::new(),
1069 }
1070 }
1071
1072 pub fn fit(&mut self, features: &Array2<f64>, label_sets: &[Vec<String>]) -> Result<()> {
1078 if features.nrows() != label_sets.len() {
1079 return Err(TextError::InvalidInput(
1080 "Features and label_sets must have the same length".into(),
1081 ));
1082 }
1083
1084 let mut all_labels_set = HashSet::new();
1086 for labels in label_sets {
1087 for label in labels {
1088 all_labels_set.insert(label.clone());
1089 }
1090 }
1091 self.all_labels = all_labels_set.into_iter().collect();
1092 self.all_labels.sort();
1093
1094 for label in &self.all_labels {
1096 let binary_labels: Vec<String> = label_sets
1097 .iter()
1098 .map(|ls| {
1099 if ls.contains(label) {
1100 "positive".to_string()
1101 } else {
1102 "negative".to_string()
1103 }
1104 })
1105 .collect();
1106
1107 let mut clf = MultinomialNaiveBayes::new(1.0);
1108 clf.fit(features, &binary_labels)?;
1109 self.classifiers.insert(label.clone(), clf);
1110 }
1111
1112 Ok(())
1113 }
1114
1115 pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<MultiLabelPrediction>> {
1117 let mut predictions = Vec::with_capacity(features.nrows());
1118
1119 for row in features.axis_iter(Axis(0)) {
1120 let row_arr = row.to_owned();
1121 let mut labels = Vec::new();
1122 let mut scores = HashMap::new();
1123
1124 let single_row = Array2::from_shape_fn((1, row_arr.len()), |(_, j)| row_arr[j]);
1126
1127 for label in &self.all_labels {
1128 if let Some(clf) = self.classifiers.get(label) {
1129 let pred = clf.predict(&single_row)?;
1130 if !pred.is_empty() && pred[0] == "positive" {
1131 labels.push(label.clone());
1132 scores.insert(label.clone(), 1.0);
1133 } else {
1134 scores.insert(label.clone(), 0.0);
1135 }
1136 }
1137 }
1138
1139 predictions.push(MultiLabelPrediction { labels, scores });
1140 }
1141
1142 Ok(predictions)
1143 }
1144}
1145
/// Evaluation results for a single cross-validation fold.
#[derive(Debug, Clone)]
pub struct FoldResult {
    /// Zero-based fold index.
    pub fold: usize,
    /// Fraction of correct predictions on this fold's test split.
    pub accuracy: f64,
    /// Predicted labels for the fold's test samples.
    pub predictions: Vec<String>,
    /// Ground-truth labels aligned with `predictions`.
    pub true_labels: Vec<String>,
}
1160
/// Aggregated k-fold cross-validation results.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Per-fold metrics and predictions.
    pub fold_results: Vec<FoldResult>,
    /// Mean accuracy across folds.
    pub mean_accuracy: f64,
    /// Population standard deviation of fold accuracies.
    pub std_accuracy: f64,
}
1171
1172pub fn cross_validate_nb(
1181 features: &Array2<f64>,
1182 labels: &[String],
1183 k: usize,
1184 alpha: f64,
1185 seed: Option<u64>,
1186) -> Result<CrossValidationResult> {
1187 if features.nrows() != labels.len() {
1188 return Err(TextError::InvalidInput(
1189 "Features and labels must have the same length".into(),
1190 ));
1191 }
1192
1193 let n = features.nrows();
1194 if k < 2 || k > n {
1195 return Err(TextError::InvalidInput(format!(
1196 "k must be between 2 and {} (number of samples)",
1197 n
1198 )));
1199 }
1200
1201 let mut indices: Vec<usize> = (0..n).collect();
1203 if let Some(s) = seed {
1204 let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(s);
1205 indices.shuffle(&mut rng);
1206 } else {
1207 let mut rng = scirs2_core::random::rng();
1208 indices.shuffle(&mut rng);
1209 }
1210
1211 let fold_size = n / k;
1212 let mut fold_results = Vec::with_capacity(k);
1213
1214 for fold in 0..k {
1215 let test_start = fold * fold_size;
1216 let test_end = if fold == k - 1 {
1217 n
1218 } else {
1219 (fold + 1) * fold_size
1220 };
1221
1222 let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
1223 let train_indices: Vec<usize> = indices
1224 .iter()
1225 .enumerate()
1226 .filter(|(i, _)| *i < test_start || *i >= test_end)
1227 .map(|(_, &idx)| idx)
1228 .collect();
1229
1230 let n_train = train_indices.len();
1232 let n_test = test_indices.len();
1233 let n_features = features.ncols();
1234
1235 let mut train_features = Array2::zeros((n_train, n_features));
1236 let mut train_labels = Vec::with_capacity(n_train);
1237
1238 for (i, &idx) in train_indices.iter().enumerate() {
1239 for j in 0..n_features {
1240 train_features[[i, j]] = features[[idx, j]];
1241 }
1242 train_labels.push(labels[idx].clone());
1243 }
1244
1245 let mut test_features = Array2::zeros((n_test, n_features));
1246 let mut test_labels = Vec::with_capacity(n_test);
1247
1248 for (i, &idx) in test_indices.iter().enumerate() {
1249 for j in 0..n_features {
1250 test_features[[i, j]] = features[[idx, j]];
1251 }
1252 test_labels.push(labels[idx].clone());
1253 }
1254
1255 let mut clf = MultinomialNaiveBayes::new(alpha);
1257 clf.fit(&train_features, &train_labels)?;
1258 let predictions = clf.predict(&test_features)?;
1259
1260 let correct = predictions
1262 .iter()
1263 .zip(test_labels.iter())
1264 .filter(|(p, t)| p == t)
1265 .count();
1266 let accuracy = correct as f64 / n_test as f64;
1267
1268 fold_results.push(FoldResult {
1269 fold,
1270 accuracy,
1271 predictions,
1272 true_labels: test_labels,
1273 });
1274 }
1275
1276 let accuracies: Vec<f64> = fold_results.iter().map(|f| f.accuracy).collect();
1278 let mean_accuracy = accuracies.iter().sum::<f64>() / k as f64;
1279 let variance = accuracies
1280 .iter()
1281 .map(|&a| (a - mean_accuracy).powi(2))
1282 .sum::<f64>()
1283 / k as f64;
1284 let std_accuracy = variance.sqrt();
1285
1286 Ok(CrossValidationResult {
1287 fold_results,
1288 mean_accuracy,
1289 std_accuracy,
1290 })
1291}
1292
1293#[cfg(test)]
1294mod tests {
1295 use super::*;
1296
1297 #[test]
1298 fn test_text_dataset() {
1299 let texts = vec![
1300 "This is document 1".to_string(),
1301 "Another document".to_string(),
1302 "A third document".to_string(),
1303 ];
1304 let labels = vec!["A".to_string(), "B".to_string(), "A".to_string()];
1305
1306 let mut dataset = TextDataset::new(texts, labels).expect("Operation failed");
1307
1308 let mut label_index = HashMap::new();
1309 label_index.insert("A".to_string(), 0);
1310 label_index.insert("B".to_string(), 1);
1311 dataset.label_index = Some(label_index);
1312
1313 let label_indices = dataset.get_label_indices().expect("Operation failed");
1314 assert_eq!(label_indices[0], 0);
1315 assert_eq!(label_indices[1], 1);
1316 assert_eq!(label_indices[2], 0);
1317
1318 let unique_labels = dataset.unique_labels();
1319 assert_eq!(unique_labels.len(), 2);
1320 }
1321
1322 #[test]
1323 fn test_train_test_split() {
1324 let texts = (0..10).map(|i| format!("Text {i}")).collect();
1325 let labels = (0..10).map(|_| "A".to_string()).collect();
1326
1327 let dataset = TextDataset::new(texts, labels).expect("Operation failed");
1328 let (train, test) = dataset
1329 .train_test_split(0.3, Some(42))
1330 .expect("Operation failed");
1331
1332 assert_eq!(train.len(), 7);
1333 assert_eq!(test.len(), 3);
1334 }
1335
1336 #[test]
1337 fn test_feature_selector() {
1338 let mut features = Array2::zeros((5, 3));
1339 features[[0, 0]] = 1.0;
1340 features[[1, 0]] = 1.0;
1341 features[[2, 0]] = 1.0;
1342
1343 for i in 0..5 {
1344 features[[i, 1]] = 1.0;
1345 }
1346
1347 features[[0, 2]] = 1.0;
1348
1349 let mut selector = TextFeatureSelector::new()
1350 .set_min_df(0.25)
1351 .expect("Operation failed")
1352 .set_max_df(0.75)
1353 .expect("Operation failed");
1354
1355 let filtered = selector.fit_transform(&features).expect("Operation failed");
1356 assert_eq!(filtered.ncols(), 1);
1357 }
1358
1359 #[test]
1360 fn test_classification_metrics() {
1361 let predictions = vec![1_usize, 0, 1, 1, 0];
1362 let true_labels = vec![1_usize, 0, 1, 0, 0];
1363
1364 let metrics = TextClassificationMetrics::new();
1365 let accuracy = metrics
1366 .accuracy(&predictions, &true_labels)
1367 .expect("Operation failed");
1368 assert_eq!(accuracy, 0.8);
1369
1370 let (precision, recall, f1) = metrics
1371 .binary_metrics(&predictions, &true_labels)
1372 .expect("Operation failed");
1373 assert!((precision - 0.667).abs() < 0.001);
1374 assert_eq!(recall, 1.0);
1375 assert!((f1 - 0.8).abs() < 0.001);
1376 }
1377
1378 #[test]
1381 fn test_multinomial_nb_basic() {
1382 let features = Array2::from_shape_vec(
1384 (6, 3),
1385 vec![
1386 3.0, 1.0, 0.0, 2.0, 2.0, 0.0, 4.0, 0.0, 1.0, 0.0, 1.0, 3.0, 0.0, 2.0, 2.0, 1.0, 0.0, 4.0, ],
1393 )
1394 .expect("shape");
1395
1396 let labels = vec![
1397 "pos".to_string(),
1398 "pos".to_string(),
1399 "pos".to_string(),
1400 "neg".to_string(),
1401 "neg".to_string(),
1402 "neg".to_string(),
1403 ];
1404
1405 let mut clf = MultinomialNaiveBayes::new(1.0);
1406 clf.fit(&features, &labels).expect("fit");
1407
1408 let test = Array2::from_shape_vec((1, 3), vec![5.0, 0.0, 0.0]).expect("shape");
1410 let pred = clf.predict(&test).expect("predict");
1411 assert_eq!(pred[0], "pos");
1412
1413 let test = Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 5.0]).expect("shape");
1415 let pred = clf.predict(&test).expect("predict");
1416 assert_eq!(pred[0], "neg");
1417 }
1418
1419 #[test]
1422 fn test_bernoulli_nb_basic() {
1423 let features = Array2::from_shape_vec(
1424 (6, 4),
1425 vec![
1426 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, ],
1433 )
1434 .expect("shape");
1435
1436 let labels = vec![
1437 "pos".to_string(),
1438 "pos".to_string(),
1439 "pos".to_string(),
1440 "neg".to_string(),
1441 "neg".to_string(),
1442 "neg".to_string(),
1443 ];
1444
1445 let mut clf = BernoulliNaiveBayes::new(1.0);
1446 clf.fit(&features, &labels).expect("fit");
1447
1448 let test = Array2::from_shape_vec((1, 4), vec![1.0, 1.0, 0.0, 0.0]).expect("shape");
1449 let pred = clf.predict(&test).expect("predict");
1450 assert_eq!(pred[0], "pos");
1451
1452 let test = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 1.0]).expect("shape");
1453 let pred = clf.predict(&test).expect("predict");
1454 assert_eq!(pred[0], "neg");
1455 }
1456
1457 #[test]
1460 fn test_tfidf_cosine_classifier() {
1461 let features = Array2::from_shape_vec(
1462 (4, 3),
1463 vec![
1464 1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 1.0, 0.1, 0.0, 0.9, ],
1469 )
1470 .expect("shape");
1471
1472 let labels = vec![
1473 "A".to_string(),
1474 "A".to_string(),
1475 "B".to_string(),
1476 "B".to_string(),
1477 ];
1478
1479 let mut clf = TfidfCosineClassifier::new(1);
1480 clf.fit(&features, &labels).expect("fit");
1481
1482 let test = Array2::from_shape_vec((1, 3), vec![0.8, 0.2, 0.0]).expect("shape");
1483 let pred = clf.predict(&test).expect("predict");
1484 assert_eq!(pred[0], "A");
1485 }
1486
1487 #[test]
1490 fn test_feature_hasher() {
1491 let hasher = FeatureHasher::new(100);
1492
1493 let features = hasher.transform_text("the quick brown fox").expect("hash");
1494 assert_eq!(features.len(), 100);
1495
1496 let nnz = features.iter().filter(|&&v| v != 0.0).count();
1498 assert!(nnz > 0);
1499 }
1500
1501 #[test]
1502 fn test_feature_hasher_batch() {
1503 let hasher = FeatureHasher::new(50);
1504
1505 let texts = vec!["hello world", "foo bar baz"];
1506 let matrix = hasher.transform_batch(&texts).expect("batch");
1507
1508 assert_eq!(matrix.nrows(), 2);
1509 assert_eq!(matrix.ncols(), 50);
1510 }
1511
1512 #[test]
1513 fn test_feature_hasher_deterministic() {
1514 let hasher = FeatureHasher::new(100);
1515
1516 let f1 = hasher.transform_text("hello world").expect("h1");
1517 let f2 = hasher.transform_text("hello world").expect("h2");
1518
1519 for i in 0..100 {
1520 assert_eq!(f1[i], f2[i]);
1521 }
1522 }
1523
1524 #[test]
1527 fn test_multi_label_classifier() {
1528 let features = Array2::from_shape_vec(
1529 (4, 3),
1530 vec![
1531 3.0, 1.0, 0.0, 2.0, 2.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 4.0, ],
1536 )
1537 .expect("shape");
1538
1539 let label_sets = vec![
1540 vec!["sports".to_string(), "positive".to_string()],
1541 vec!["sports".to_string()],
1542 vec!["tech".to_string(), "negative".to_string()],
1543 vec!["tech".to_string()],
1544 ];
1545
1546 let mut clf = MultiLabelClassifier::new(0.5);
1547 clf.fit(&features, &label_sets).expect("fit");
1548
1549 let test = Array2::from_shape_vec((1, 3), vec![4.0, 0.0, 0.0]).expect("shape");
1550 let preds = clf.predict(&test).expect("predict");
1551 assert!(!preds.is_empty());
1552 }
1554
1555 #[test]
1558 fn test_cross_validation() {
1559 let n = 20;
1561 let features = Array2::from_shape_fn((n, 2), |(i, j)| {
1562 if i < n / 2 {
1563 if j == 0 {
1564 3.0
1565 } else {
1566 0.0
1567 }
1568 } else {
1569 if j == 0 {
1570 0.0
1571 } else {
1572 3.0
1573 }
1574 }
1575 });
1576
1577 let labels: Vec<String> = (0..n)
1578 .map(|i| {
1579 if i < n / 2 {
1580 "A".to_string()
1581 } else {
1582 "B".to_string()
1583 }
1584 })
1585 .collect();
1586
1587 let result = cross_validate_nb(&features, &labels, 5, 1.0, Some(42)).expect("cv");
1588
1589 assert_eq!(result.fold_results.len(), 5);
1590 assert!(
1592 result.mean_accuracy >= 0.5,
1593 "Mean accuracy: {}",
1594 result.mean_accuracy
1595 );
1596 }
1597
1598 #[test]
1599 fn test_cross_validation_invalid_k() {
1600 let features = Array2::zeros((5, 2));
1601 let labels = vec!["A".to_string(); 5];
1602
1603 let result = cross_validate_nb(&features, &labels, 1, 1.0, None);
1604 assert!(result.is_err());
1605
1606 let result = cross_validate_nb(&features, &labels, 10, 1.0, None);
1607 assert!(result.is_err());
1608 }
1609}