Skip to main content

scirs2_text/
text_classification.rs

1//! Advanced text classification module.
2//!
3//! This module provides three complementary classifiers:
4//!
5//! - [`NaiveBayesClassifier`] — multinomial Naïve Bayes with Laplace smoothing.
6//! - [`TextCnnLite`] — convolutional n-gram feature extraction followed by
7//!   multinomial logistic regression (gradient-descent trained).
//! - [`TfIdfLogisticClassifier`] — TF-IDF features + multinomial logistic
//!   regression trained with per-sample stochastic gradient descent.
//!
//! All classifiers follow the same basic `fit` / `predict` pattern (the
//! gradient-descent models take extra training hyper-parameters in `fit`),
//! and library code never calls `unwrap()`.
13
14use crate::error::{Result, TextError};
15use crate::tokenize::{Tokenizer, WordTokenizer};
16use crate::vectorize::{TfidfVectorizer, Vectorizer};
17use scirs2_core::ndarray::{Array1, Array2, Axis};
18use std::collections::HashMap;
19
20// ---------------------------------------------------------------------------
21// Tokenisation helper
22// ---------------------------------------------------------------------------
23
24/// Tokenise `text` into lowercased word tokens.
25fn tokenize_lower(text: &str) -> Vec<String> {
26    let tokenizer = WordTokenizer::default();
27    tokenizer
28        .tokenize(text)
29        .unwrap_or_default()
30        .into_iter()
31        .map(|t| t.to_lowercase())
32        .collect()
33}
34
35// ---------------------------------------------------------------------------
36// NaiveBayesClassifier
37// ---------------------------------------------------------------------------
38
39/// Multinomial Naïve Bayes text classifier with Laplace (additive) smoothing.
40///
41/// Internally works in log-space to prevent floating-point underflow.
42///
43/// # Example
44///
45/// ```rust
46/// use scirs2_text::text_classification::NaiveBayesClassifier;
47///
48/// let mut clf = NaiveBayesClassifier::new(1.0);
49/// let texts  = &["spam spam buy now", "hello friend good morning"];
50/// let labels = &["spam", "ham"];
51/// clf.fit(texts, labels).unwrap();
52/// assert_eq!(clf.predict("buy now cheap").unwrap(), "spam");
53/// ```
54pub struct NaiveBayesClassifier {
55    /// Vocabulary: word → column index
56    vocabulary: HashMap<String, usize>,
57    /// Log prior probability for each class: `log P(class)`
58    class_log_priors: Vec<f64>,
59    /// Log word probability for each class:
60    /// `class_word_log_probs[c][w] = log P(word w | class c)`
61    class_word_log_probs: Vec<Vec<f64>>,
62    /// Ordered class names
63    classes: Vec<String>,
64    /// Laplace smoothing parameter α (> 0)
65    alpha: f64,
66}
67
68impl NaiveBayesClassifier {
69    /// Create a new classifier with smoothing parameter `alpha`.
70    ///
71    /// `alpha = 1.0` is the standard Laplace (add-one) smoothing.
72    ///
73    /// # Errors
74    ///
75    /// Returns [`TextError::InvalidInput`] when `alpha <= 0`.
76    pub fn new(alpha: f64) -> Self {
77        Self {
78            vocabulary: HashMap::new(),
79            class_log_priors: Vec::new(),
80            class_word_log_probs: Vec::new(),
81            classes: Vec::new(),
82            alpha,
83        }
84    }
85
86    /// Train the classifier.
87    ///
88    /// # Errors
89    ///
90    /// Returns [`TextError::InvalidInput`] when:
91    /// - `texts` and `labels` have different lengths.
92    /// - The corpus is empty.
93    /// - `alpha <= 0`.
94    pub fn fit(&mut self, texts: &[&str], labels: &[&str]) -> Result<()> {
95        if texts.len() != labels.len() {
96            return Err(TextError::InvalidInput(format!(
97                "texts ({}) and labels ({}) must have the same length",
98                texts.len(),
99                labels.len()
100            )));
101        }
102        if texts.is_empty() {
103            return Err(TextError::InvalidInput("Empty training corpus".to_string()));
104        }
105        if self.alpha <= 0.0 {
106            return Err(TextError::InvalidInput(
107                "alpha must be positive".to_string(),
108            ));
109        }
110
111        // --- build vocabulary and class index ---
112        let mut class_index: HashMap<String, usize> = HashMap::new();
113        for &label in labels {
114            let n = class_index.len();
115            class_index.entry(label.to_string()).or_insert(n);
116        }
117
118        let n_classes = class_index.len();
119        let mut class_names: Vec<String> = vec![String::new(); n_classes];
120        for (name, &idx) in &class_index {
121            class_names[idx] = name.clone();
122        }
123
124        let mut class_doc_counts = vec![0usize; n_classes];
125        // word counts per class: class → (word → count)
126        let mut class_word_counts: Vec<HashMap<String, f64>> =
127            (0..n_classes).map(|_| HashMap::new()).collect();
128
129        for (&text, &label) in texts.iter().zip(labels.iter()) {
130            let class_idx = *class_index
131                .get(label)
132                .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))?;
133            class_doc_counts[class_idx] += 1;
134
135            for word in tokenize_lower(text) {
136                // Add to vocabulary
137                let n = self.vocabulary.len();
138                self.vocabulary.entry(word.clone()).or_insert(n);
139                *class_word_counts[class_idx].entry(word).or_insert(0.0) += 1.0;
140            }
141        }
142
143        let n_docs = texts.len() as f64;
144        let vocab_size = self.vocabulary.len();
145
146        // Log priors
147        let class_log_priors: Vec<f64> = class_doc_counts
148            .iter()
149            .map(|&c| (c as f64 / n_docs).ln())
150            .collect();
151
152        // Log word probabilities: P(w|c) = (count(w,c) + alpha) / (total_words(c) + alpha * V)
153        let mut class_word_log_probs: Vec<Vec<f64>> =
154            (0..n_classes).map(|_| vec![0.0; vocab_size]).collect();
155
156        for c in 0..n_classes {
157            let total: f64 = class_word_counts[c].values().sum();
158            let denom = total + self.alpha * vocab_size as f64;
159
160            for (word, &col_idx) in &self.vocabulary {
161                let cnt = class_word_counts[c].get(word).copied().unwrap_or(0.0);
162                class_word_log_probs[c][col_idx] = ((cnt + self.alpha) / denom).ln();
163            }
164        }
165
166        self.classes = class_names;
167        self.class_log_priors = class_log_priors;
168        self.class_word_log_probs = class_word_log_probs;
169
170        Ok(())
171    }
172
173    /// Predict the most probable class for `text`.
174    ///
175    /// # Errors
176    ///
177    /// Returns [`TextError::ModelNotFitted`] if `fit` has not been called.
178    pub fn predict(&self, text: &str) -> Result<String> {
179        let proba = self.predict_log_proba(text)?;
180        let best = proba
181            .iter()
182            .enumerate()
183            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
184            .map(|(i, _)| i)
185            .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))?;
186        Ok(self.classes[best].clone())
187    }
188
189    /// Return `(class_name, probability)` pairs sorted by probability descending.
190    ///
191    /// Probabilities are derived from log-scores via softmax normalisation.
192    ///
193    /// # Errors
194    ///
195    /// Returns [`TextError::ModelNotFitted`] if `fit` has not been called.
196    pub fn predict_proba(&self, text: &str) -> Result<Vec<(String, f64)>> {
197        let log_proba = self.predict_log_proba(text)?;
198
199        // Softmax
200        let max_val = log_proba.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
201        let exps: Vec<f64> = log_proba.iter().map(|&v| (v - max_val).exp()).collect();
202        let sum: f64 = exps.iter().sum();
203
204        let mut result: Vec<(String, f64)> = self
205            .classes
206            .iter()
207            .zip(exps.iter())
208            .map(|(name, &e)| (name.clone(), e / sum))
209            .collect();
210
211        result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212
213        Ok(result)
214    }
215
216    /// Compute accuracy on a labelled test set.
217    ///
218    /// # Errors
219    ///
220    /// Returns [`TextError::InvalidInput`] when lengths differ.
221    pub fn score(&self, texts: &[&str], labels: &[&str]) -> Result<f64> {
222        if texts.len() != labels.len() {
223            return Err(TextError::InvalidInput(
224                "texts and labels must have the same length".to_string(),
225            ));
226        }
227        if texts.is_empty() {
228            return Ok(0.0);
229        }
230
231        let mut correct = 0usize;
232        for (&text, &label) in texts.iter().zip(labels.iter()) {
233            if let Ok(pred) = self.predict(text) {
234                if pred == label {
235                    correct += 1;
236                }
237            }
238        }
239
240        Ok(correct as f64 / texts.len() as f64)
241    }
242
243    // ------------------------------------------------------------------
244    // Internal helpers
245    // ------------------------------------------------------------------
246
247    fn predict_log_proba(&self, text: &str) -> Result<Vec<f64>> {
248        if self.classes.is_empty() {
249            return Err(TextError::ModelNotFitted(
250                "NaiveBayesClassifier has not been fitted yet".to_string(),
251            ));
252        }
253
254        let tokens = tokenize_lower(text);
255        let n_classes = self.classes.len();
256        let mut log_scores: Vec<f64> = self.class_log_priors.clone();
257
258        for word in &tokens {
259            if let Some(&col) = self.vocabulary.get(word) {
260                for c in 0..n_classes {
261                    log_scores[c] += self.class_word_log_probs[c][col];
262                }
263            }
264        }
265
266        Ok(log_scores)
267    }
268}
269
270// ---------------------------------------------------------------------------
271// TextCnnLite
272// ---------------------------------------------------------------------------
273
274/// Lightweight text-CNN classifier.
275///
276/// Extracts n-gram features (bag-of-n-grams) for each filter size and then
277/// trains a multinomial logistic regression head using mini-batch stochastic
278/// gradient descent.
279///
280/// This is intentionally a "lite" variant — it approximates a CNN's 1-D
281/// convolution by computing frequency counts of character/word n-grams and
282/// using the max-pooled (most-frequent) value as the feature.  A proper
283/// convolutional network would require a full neural-network framework.
284///
285/// # Example
286///
287/// ```rust
288/// use scirs2_text::text_classification::TextCnnLite;
289///
290/// let mut clf = TextCnnLite::new(vec![2, 3], 8);
291/// let texts  = &["good movie fun", "bad film boring", "great show entertaining", "terrible awful waste"];
292/// let labels = &["pos", "neg", "pos", "neg"];
293/// clf.fit(texts, labels, 20).unwrap();
294/// ```
295pub struct TextCnnLite {
296    /// Filter window sizes (in words).
297    filter_sizes: Vec<usize>,
298    /// Number of feature maps per filter size (for n-gram counting, this sets
299    /// the top-k cutoff for the feature vector).
300    n_filters: usize,
301    /// Vocabulary: n-gram → feature index
302    vocab: HashMap<String, usize>,
303    /// Logistic regression weights: `weights[c][f]`
304    weights: Vec<Vec<f64>>,
305    /// Per-class bias terms
306    bias: Vec<f64>,
307    /// Ordered class names
308    classes: Vec<String>,
309}
310
311impl TextCnnLite {
312    /// Create a new `TextCnnLite`.
313    ///
314    /// # Parameters
315    ///
316    /// - `filter_sizes`: n-gram window sizes, e.g. `[2, 3, 4]`.
317    /// - `n_filters`: number of top n-gram features to keep per filter size.
318    pub fn new(filter_sizes: Vec<usize>, n_filters: usize) -> Self {
319        Self {
320            filter_sizes,
321            n_filters,
322            vocab: HashMap::new(),
323            weights: Vec::new(),
324            bias: Vec::new(),
325            classes: Vec::new(),
326        }
327    }
328
329    /// Train the classifier.
330    ///
331    /// Internally:
332    /// 1. Build the n-gram vocabulary from all documents.
333    /// 2. Vectorise each document into a frequency vector.
334    /// 3. Train multinomial logistic regression with SGD.
335    ///
336    /// # Errors
337    ///
338    /// Returns [`TextError::InvalidInput`] when `texts` and `labels` lengths differ
339    /// or when the corpus is empty.
340    pub fn fit(&mut self, texts: &[&str], labels: &[&str], epochs: usize) -> Result<()> {
341        if texts.len() != labels.len() {
342            return Err(TextError::InvalidInput(format!(
343                "texts ({}) and labels ({}) must have the same length",
344                texts.len(),
345                labels.len()
346            )));
347        }
348        if texts.is_empty() {
349            return Err(TextError::InvalidInput("Empty training corpus".to_string()));
350        }
351
352        // --- class mapping ---
353        let mut class_index: HashMap<String, usize> = HashMap::new();
354        for &label in labels {
355            let n = class_index.len();
356            class_index.entry(label.to_string()).or_insert(n);
357        }
358        let n_classes = class_index.len();
359        let mut class_names: Vec<String> = vec![String::new(); n_classes];
360        for (name, &idx) in &class_index {
361            class_names[idx] = name.clone();
362        }
363
364        // --- build n-gram vocabulary ---
365        let mut ngram_counts: HashMap<String, usize> = HashMap::new();
366        for &text in texts {
367            let words = tokenize_lower(text);
368            for &size in &self.filter_sizes {
369                for ngram in ngrams(&words, size) {
370                    *ngram_counts.entry(ngram).or_insert(0) += 1;
371                }
372            }
373        }
374
375        // Keep top n_filters * filter_sizes.len() n-grams by frequency
376        let mut ngram_vec: Vec<(String, usize)> = ngram_counts.into_iter().collect();
377        ngram_vec.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
378        let max_feats = self.n_filters * self.filter_sizes.len();
379        self.vocab = ngram_vec
380            .into_iter()
381            .take(max_feats)
382            .enumerate()
383            .map(|(i, (ng, _))| (ng, i))
384            .collect();
385
386        let n_features = self.vocab.len();
387        if n_features == 0 {
388            return Err(TextError::InvalidInput(
389                "No n-gram features found in corpus".to_string(),
390            ));
391        }
392
393        // --- vectorise documents ---
394        let mut x_data: Vec<Vec<f64>> = Vec::with_capacity(texts.len());
395        let mut y_labels: Vec<usize> = Vec::with_capacity(texts.len());
396
397        for (&text, &label) in texts.iter().zip(labels.iter()) {
398            x_data.push(self.vectorize(text));
399            let class_idx = *class_index
400                .get(label)
401                .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))?;
402            y_labels.push(class_idx);
403        }
404
405        // --- initialise weights ---
406        self.weights = vec![vec![0.0f64; n_features]; n_classes];
407        self.bias = vec![0.0f64; n_classes];
408
409        // --- mini-batch SGD ---
410        let lr = 0.1_f64;
411        let n_samples = texts.len();
412
413        for _epoch in 0..epochs {
414            for i in 0..n_samples {
415                let x = &x_data[i];
416                let y = y_labels[i];
417
418                // Forward: compute logits and softmax
419                let logits: Vec<f64> = (0..n_classes)
420                    .map(|c| {
421                        self.bias[c]
422                            + x.iter()
423                                .zip(self.weights[c].iter())
424                                .map(|(xi, wi)| xi * wi)
425                                .sum::<f64>()
426                    })
427                    .collect();
428
429                let probs = softmax(&logits);
430
431                // Backward: cross-entropy gradient
432                for c in 0..n_classes {
433                    let delta = probs[c] - if c == y { 1.0 } else { 0.0 };
434                    self.bias[c] -= lr * delta;
435                    for (j, &xj) in x.iter().enumerate() {
436                        self.weights[c][j] -= lr * delta * xj;
437                    }
438                }
439            }
440        }
441
442        self.classes = class_names;
443        Ok(())
444    }
445
446    /// Predict the most probable class for `text`.
447    ///
448    /// # Errors
449    ///
450    /// Returns [`TextError::ModelNotFitted`] when the model has not been trained.
451    pub fn predict(&self, text: &str) -> Result<String> {
452        if self.classes.is_empty() {
453            return Err(TextError::ModelNotFitted(
454                "TextCnnLite has not been fitted yet".to_string(),
455            ));
456        }
457        let x = self.vectorize(text);
458        let n_classes = self.classes.len();
459        let logits: Vec<f64> = (0..n_classes)
460            .map(|c| {
461                self.bias[c]
462                    + x.iter()
463                        .zip(self.weights[c].iter())
464                        .map(|(xi, wi)| xi * wi)
465                        .sum::<f64>()
466            })
467            .collect();
468
469        let best = logits
470            .iter()
471            .enumerate()
472            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
473            .map(|(i, _)| i)
474            .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))?;
475
476        Ok(self.classes[best].clone())
477    }
478
479    // ------------------------------------------------------------------
480    // Internal helpers
481    // ------------------------------------------------------------------
482
483    /// Compute the n-gram frequency feature vector for `text`.
484    fn vectorize(&self, text: &str) -> Vec<f64> {
485        let mut v = vec![0.0f64; self.vocab.len()];
486        let words = tokenize_lower(text);
487        for &size in &self.filter_sizes {
488            for ngram in ngrams(&words, size) {
489                if let Some(&idx) = self.vocab.get(&ngram) {
490                    v[idx] += 1.0;
491                }
492            }
493        }
494        v
495    }
496}
497
498// ---------------------------------------------------------------------------
499// TfIdfLogisticClassifier
500// ---------------------------------------------------------------------------
501
502/// Logistic regression classifier trained on TF-IDF features.
503///
504/// Uses the existing [`TfidfVectorizer`] to compute features and mini-batch
505/// stochastic gradient descent (SGD) to train the model.
506///
507/// # Example
508///
509/// ```rust
510/// use scirs2_text::text_classification::TfIdfLogisticClassifier;
511///
512/// let mut clf = TfIdfLogisticClassifier::new();
513/// let texts  = &["good great excellent", "bad terrible awful", "okay decent fine", "poor mediocre subpar"];
514/// let labels = &["pos", "neg", "pos", "neg"];
515/// clf.fit(texts, labels, 50, 0.1).unwrap();
516/// let pred = clf.predict("excellent wonderful").unwrap();
517/// assert!(!pred.is_empty());
518/// ```
519pub struct TfIdfLogisticClassifier {
520    /// Underlying TF-IDF vectorizer.
521    vectorizer: TfidfVectorizer,
522    /// Weight matrix: `weights[class_idx]` is a vector over vocabulary.
523    weights: Vec<Vec<f64>>,
524    /// Per-class bias terms.
525    bias: Vec<f64>,
526    /// Ordered class names.
527    classes: Vec<String>,
528    /// Whether the vectorizer has been fitted.
529    fitted: bool,
530}
531
532impl Default for TfIdfLogisticClassifier {
533    fn default() -> Self {
534        Self::new()
535    }
536}
537
538impl TfIdfLogisticClassifier {
539    /// Create a new classifier with default TF-IDF settings.
540    pub fn new() -> Self {
541        Self {
542            vectorizer: TfidfVectorizer::new(false, true, Some("l2".to_string())),
543            weights: Vec::new(),
544            bias: Vec::new(),
545            classes: Vec::new(),
546            fitted: false,
547        }
548    }
549
550    /// Fit the classifier.
551    ///
552    /// # Parameters
553    ///
554    /// - `texts`: Training documents.
555    /// - `labels`: Corresponding class labels.
556    /// - `max_iter`: Number of SGD epochs.
557    /// - `lr`: Learning rate for SGD.
558    ///
559    /// # Errors
560    ///
561    /// Returns [`TextError::InvalidInput`] when inputs are inconsistent.
562    pub fn fit(&mut self, texts: &[&str], labels: &[&str], max_iter: usize, lr: f64) -> Result<()> {
563        if texts.len() != labels.len() {
564            return Err(TextError::InvalidInput(format!(
565                "texts ({}) and labels ({}) must have the same length",
566                texts.len(),
567                labels.len()
568            )));
569        }
570        if texts.is_empty() {
571            return Err(TextError::InvalidInput("Empty training corpus".to_string()));
572        }
573        if lr <= 0.0 {
574            return Err(TextError::InvalidInput(
575                "Learning rate must be positive".to_string(),
576            ));
577        }
578
579        // --- build class mapping ---
580        let mut class_index: HashMap<String, usize> = HashMap::new();
581        for &label in labels {
582            let n = class_index.len();
583            class_index.entry(label.to_string()).or_insert(n);
584        }
585        let n_classes = class_index.len();
586        let mut class_names: Vec<String> = vec![String::new(); n_classes];
587        for (name, &idx) in &class_index {
588            class_names[idx] = name.clone();
589        }
590
591        // --- vectorise ---
592        let x_matrix = self.vectorizer.fit_transform(texts)?;
593        let n_features = x_matrix.ncols();
594
595        let y_labels: Vec<usize> = labels
596            .iter()
597            .map(|&label| {
598                class_index
599                    .get(label)
600                    .copied()
601                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown label '{label}'")))
602            })
603            .collect::<Result<_>>()?;
604
605        // --- initialise weights ---
606        self.weights = vec![vec![0.0f64; n_features]; n_classes];
607        self.bias = vec![0.0f64; n_classes];
608
609        // --- SGD ---
610        let n_samples = x_matrix.nrows();
611
612        for _epoch in 0..max_iter {
613            for i in 0..n_samples {
614                let x_row = x_matrix.row(i);
615                let y = y_labels[i];
616
617                // Logits
618                let logits: Vec<f64> = (0..n_classes)
619                    .map(|c| {
620                        self.bias[c]
621                            + x_row
622                                .iter()
623                                .zip(self.weights[c].iter())
624                                .map(|(xi, wi)| xi * wi)
625                                .sum::<f64>()
626                    })
627                    .collect();
628
629                let probs = softmax(&logits);
630
631                // Gradient update
632                for c in 0..n_classes {
633                    let delta = probs[c] - if c == y { 1.0 } else { 0.0 };
634                    self.bias[c] -= lr * delta;
635                    for (j, &xj) in x_row.iter().enumerate() {
636                        self.weights[c][j] -= lr * delta * xj;
637                    }
638                }
639            }
640        }
641
642        self.classes = class_names;
643        self.fitted = true;
644        Ok(())
645    }
646
647    /// Predict the most probable class for `text`.
648    ///
649    /// # Errors
650    ///
651    /// Returns [`TextError::ModelNotFitted`] when the model has not been trained.
652    pub fn predict(&self, text: &str) -> Result<String> {
653        let probs = self.predict_proba(text)?;
654        probs
655            .iter()
656            .enumerate()
657            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
658            .map(|(i, _)| self.classes[i].clone())
659            .ok_or_else(|| TextError::ModelNotFitted("No classes available".to_string()))
660    }
661
662    /// Return the probability distribution over classes for `text`.
663    ///
664    /// # Errors
665    ///
666    /// Returns [`TextError::ModelNotFitted`] when the model has not been trained.
667    pub fn predict_proba(&self, text: &str) -> Result<Vec<f64>> {
668        if !self.fitted {
669            return Err(TextError::ModelNotFitted(
670                "TfIdfLogisticClassifier has not been fitted yet".to_string(),
671            ));
672        }
673
674        let x_vec = self.vectorizer.transform(text)?;
675        let n_classes = self.classes.len();
676
677        let logits: Vec<f64> = (0..n_classes)
678            .map(|c| {
679                self.bias[c]
680                    + x_vec
681                        .iter()
682                        .zip(self.weights[c].iter())
683                        .map(|(xi, wi)| xi * wi)
684                        .sum::<f64>()
685            })
686            .collect();
687
688        Ok(softmax(&logits))
689    }
690}
691
692// ---------------------------------------------------------------------------
693// Free helper functions
694// ---------------------------------------------------------------------------
695
/// Contiguous word n-grams of length `size`, each joined with a single space.
///
/// Returns an empty vector when `size` is 0 or larger than `words.len()`.
fn ngrams(words: &[String], size: usize) -> Vec<String> {
    // `windows(0)` panics, so the zero-width case is handled up front;
    // `windows` already yields nothing when `size > words.len()`.
    if size == 0 {
        return Vec::new();
    }
    words.windows(size).map(|window| window.join(" ")).collect()
}
705
/// Numerically stable softmax.
///
/// For finite inputs this is `exp(x_i - max) / Σ exp(x_j - max)`. Non-finite
/// inputs, which would otherwise turn the whole output into NaN, are handled
/// explicitly:
/// - if any logit is `+∞`, the probability mass is split evenly among the
///   `+∞` entries;
/// - if every logit is `-∞` or NaN, or the normaliser is not a positive finite
///   number (e.g. because some entry is NaN), the uniform distribution is
///   returned.
///
/// An empty input yields an empty output.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let n = logits.len();
    let uniform = || vec![1.0 / n as f64; n];

    // `f64::max` ignores NaN operands, so NaN entries never become the maximum.
    let max_val = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max_val == f64::INFINITY {
        // `inf - inf` is NaN, so the usual shift cannot be applied.
        let n_inf = logits.iter().filter(|&&v| v == f64::INFINITY).count();
        return logits
            .iter()
            .map(|&v| if v == f64::INFINITY { 1.0 / n_inf as f64 } else { 0.0 })
            .collect();
    }
    if !max_val.is_finite() {
        // Every entry is -inf or NaN: no entry carries any information.
        return uniform();
    }

    let exps: Vec<f64> = logits.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    if !(sum.is_finite() && sum > 0.0) {
        return uniform();
    }
    exps.iter().map(|&e| e / sum).collect()
}
719
720// ---------------------------------------------------------------------------
721// Tests
722// ---------------------------------------------------------------------------
723
724#[cfg(test)]
725mod tests {
726    use super::*;
727
728    // --- NaiveBayesClassifier tests ---
729
730    #[test]
731    fn test_nb_fit_predict_binary() {
732        let mut clf = NaiveBayesClassifier::new(1.0);
733        let texts = &[
734            "spam buy now cheap discount",
735            "spam offer free money",
736            "hello how are you doing",
737            "good morning friend",
738        ];
739        let labels = &["spam", "spam", "ham", "ham"];
740        clf.fit(texts, labels).expect("fit should succeed");
741
742        let pred = clf
743            .predict("buy cheap spam now")
744            .expect("predict should succeed");
745        assert_eq!(pred, "spam");
746
747        let pred2 = clf
748            .predict("good morning hello")
749            .expect("predict should succeed");
750        assert_eq!(pred2, "ham");
751    }
752
753    #[test]
754    fn test_nb_predict_proba() {
755        let mut clf = NaiveBayesClassifier::new(1.0);
756        let texts = &[
757            "machine learning data science",
758            "deep learning neural network",
759            "cooking recipe food delicious",
760            "restaurant menu chef dinner",
761        ];
762        let labels = &["tech", "tech", "food", "food"];
763        clf.fit(texts, labels).expect("fit should succeed");
764
765        let proba = clf
766            .predict_proba("neural network deep learning")
767            .expect("predict_proba should succeed");
768        assert_eq!(proba.len(), 2);
769
770        // Probabilities sum to ~1.0
771        let total: f64 = proba.iter().map(|(_, p)| p).sum();
772        assert!((total - 1.0).abs() < 1e-9, "probabilities should sum to 1");
773
774        // First entry should be "tech" with highest probability
775        assert_eq!(proba[0].0, "tech");
776        assert!(proba[0].1 > 0.5, "tech probability should exceed 0.5");
777    }
778
779    #[test]
780    fn test_nb_score_above_chance() {
781        let mut clf = NaiveBayesClassifier::new(1.0);
782        let train_texts = &[
783            "positive great excellent wonderful",
784            "positive good happy nice",
785            "positive fantastic brilliant awesome",
786            "negative bad terrible awful",
787            "negative horrible disappointing poor",
788            "negative dreadful appalling dire",
789        ];
790        let train_labels = &[
791            "positive", "positive", "positive", "negative", "negative", "negative",
792        ];
793        clf.fit(train_texts, train_labels)
794            .expect("fit should succeed");
795
796        let test_texts = &["excellent wonderful", "terrible awful bad"];
797        let test_labels = &["positive", "negative"];
798        let acc = clf
799            .score(test_texts, test_labels)
800            .expect("score should succeed");
801
802        assert!(acc > 0.5, "accuracy should exceed chance: {}", acc);
803    }
804
805    #[test]
806    fn test_nb_error_on_empty_corpus() {
807        let mut clf = NaiveBayesClassifier::new(1.0);
808        let result = clf.fit(&[], &[]);
809        assert!(result.is_err());
810    }
811
812    #[test]
813    fn test_nb_error_on_length_mismatch() {
814        let mut clf = NaiveBayesClassifier::new(1.0);
815        let result = clf.fit(&["text"], &["label1", "label2"]);
816        assert!(result.is_err());
817    }
818
819    #[test]
820    fn test_nb_not_fitted_error() {
821        let clf = NaiveBayesClassifier::new(1.0);
822        let result = clf.predict("some text");
823        assert!(result.is_err());
824    }
825
826    #[test]
827    fn test_nb_multiclass() {
828        let mut clf = NaiveBayesClassifier::new(0.5);
829        let texts = &[
830            "soccer football goal kick",
831            "basketball dunk three pointer",
832            "baseball pitcher batting home run",
833            "soccer penalty kick goal",
834            "basketball court dribble shoot",
835            "baseball strike out innings",
836        ];
837        let labels = &[
838            "soccer",
839            "basketball",
840            "baseball",
841            "soccer",
842            "basketball",
843            "baseball",
844        ];
845        clf.fit(texts, labels).expect("fit should succeed");
846
847        let pred = clf
848            .predict("goal kick soccer field")
849            .expect("predict should succeed");
850        assert_eq!(pred, "soccer");
851    }
852
853    // --- TextCnnLite tests ---
854
855    #[test]
856    fn test_cnn_fit_succeeds() {
857        let mut clf = TextCnnLite::new(vec![2, 3], 8);
858        let texts = &[
859            "good movie fun entertaining",
860            "bad film boring terrible",
861            "great show exciting wonderful",
862            "awful program dull waste",
863        ];
864        let labels = &["pos", "neg", "pos", "neg"];
865        let result = clf.fit(texts, labels, 20);
866        assert!(result.is_ok(), "fit should succeed: {:?}", result);
867    }
868
869    #[test]
870    fn test_cnn_predict_returns_valid_class() {
871        let mut clf = TextCnnLite::new(vec![2, 3], 8);
872        let texts = &[
873            "wonderful excellent amazing brilliant",
874            "terrible dreadful awful horrible",
875            "great fantastic superb outstanding",
876            "poor disappointing bad mediocre",
877        ];
878        let labels = &["pos", "neg", "pos", "neg"];
879        clf.fit(texts, labels, 30).expect("fit should succeed");
880
881        let pred = clf
882            .predict("excellent wonderful")
883            .expect("predict should succeed");
884        assert!(
885            pred == "pos" || pred == "neg",
886            "prediction should be a valid class"
887        );
888    }
889
890    #[test]
891    fn test_cnn_not_fitted_error() {
892        let clf = TextCnnLite::new(vec![2], 4);
893        assert!(clf.predict("text").is_err());
894    }
895
896    // --- TfIdfLogisticClassifier tests ---
897
898    #[test]
899    fn test_tfidf_logistic_fit_and_predict() {
900        let mut clf = TfIdfLogisticClassifier::new();
901        let texts = &[
902            "machine learning artificial intelligence",
903            "deep neural network training",
904            "cooking recipe baking flour",
905            "restaurant food delicious chef",
906            "algorithm data science research",
907            "dinner menu ingredients spices",
908        ];
909        let labels = &["tech", "tech", "food", "food", "tech", "food"];
910        clf.fit(texts, labels, 50, 0.1).expect("fit should succeed");
911
912        let pred = clf
913            .predict("neural network algorithm")
914            .expect("predict should succeed");
915        assert_eq!(pred, "tech");
916    }
917
918    #[test]
919    fn test_tfidf_logistic_predict_proba_sums_to_one() {
920        let mut clf = TfIdfLogisticClassifier::new();
921        let texts = &[
922            "positive happy good",
923            "negative sad bad",
924            "positive great excellent",
925            "negative terrible awful",
926        ];
927        let labels = &["pos", "neg", "pos", "neg"];
928        clf.fit(texts, labels, 30, 0.05)
929            .expect("fit should succeed");
930
931        let probs = clf
932            .predict_proba("happy good great")
933            .expect("predict_proba should succeed");
934        let sum: f64 = probs.iter().sum();
935        assert!(
936            (sum - 1.0).abs() < 1e-9,
937            "probabilities must sum to 1, got {}",
938            sum
939        );
940    }
941
942    #[test]
943    fn test_tfidf_logistic_not_fitted_error() {
944        let clf = TfIdfLogisticClassifier::new();
945        assert!(clf.predict("text").is_err());
946    }
947
948    #[test]
949    fn test_tfidf_logistic_error_on_empty_corpus() {
950        let mut clf = TfIdfLogisticClassifier::new();
951        let result = clf.fit(&[], &[], 10, 0.1);
952        assert!(result.is_err());
953    }
954
955    #[test]
956    fn test_tfidf_logistic_error_on_length_mismatch() {
957        let mut clf = TfIdfLogisticClassifier::new();
958        let result = clf.fit(&["text"], &["a", "b"], 10, 0.1);
959        assert!(result.is_err());
960    }
961
962    // --- Utility helpers ---
963
964    #[test]
965    fn test_ngrams_bigrams() {
966        let words: Vec<String> = ["the", "quick", "brown", "fox"]
967            .iter()
968            .map(|s| s.to_string())
969            .collect();
970        let bigrams = ngrams(&words, 2);
971        assert_eq!(bigrams.len(), 3);
972        assert_eq!(bigrams[0], "the quick");
973        assert_eq!(bigrams[2], "brown fox");
974    }
975
976    #[test]
977    fn test_ngrams_empty_when_size_too_large() {
978        let words: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
979        let result = ngrams(&words, 5);
980        assert!(result.is_empty());
981    }
982
983    #[test]
984    fn test_softmax_properties() {
985        let logits = vec![1.0, 2.0, 3.0];
986        let probs = softmax(&logits);
987        let sum: f64 = probs.iter().sum();
988        assert!((sum - 1.0).abs() < 1e-9);
989        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
990    }
991
992    #[test]
993    fn test_softmax_large_values_stable() {
994        let logits = vec![1000.0, 1001.0, 999.0];
995        let probs = softmax(&logits);
996        assert!(probs.iter().all(|&p| p.is_finite()));
997        let sum: f64 = probs.iter().sum();
998        assert!((sum - 1.0).abs() < 1e-9);
999    }
1000}