// scry_learn/text/mod.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Text processing and feature extraction for NLP tasks.
//!
//! Provides tokenization, count-based vectorization, and TF-IDF weighting.
//! All vectorizers produce sparse CSR matrices via [`crate::sparse::CsrMatrix`].
//!
//! # Example
//!
//! ```ignore
//! use scry_learn::text::{CountVectorizer, TfidfVectorizer};
//!
//! let docs = ["the cat sat", "the dog sat", "the cat played"];
//!
//! // Count vectorizer
//! let mut cv = CountVectorizer::new();
//! let counts = cv.fit_transform(&docs);
//!
//! // TF-IDF vectorizer
//! let mut tfidf = TfidfVectorizer::new();
//! let matrix = tfidf.fit_transform(&docs);
//! ```

23pub mod count;
24pub mod tfidf;
25pub mod tokenizer;
26
27pub use count::CountVectorizer;
28pub use tfidf::{TfidfNorm, TfidfVectorizer};
29
30use crate::dataset::Dataset;
31use crate::sparse::CsrMatrix;
32
33/// Convert a sparse CSR matrix (from a text vectorizer) into a [`Dataset`].
34///
35/// The CsrMatrix is row-major (documents × features). This function
36/// transposes into column-major format and attaches the provided target
37/// vector and feature names.
38///
39/// # Example
40///
41/// ```ignore
42/// use scry_learn::text::{CountVectorizer, sparse_to_dataset};
43///
44/// let docs = ["good movie", "bad movie", "good film"];
45/// let target = vec![1.0, 0.0, 1.0];
46///
47/// let mut cv = CountVectorizer::new();
48/// let matrix = cv.fit_transform(&docs);
49/// let dataset = sparse_to_dataset(&matrix, target, cv.get_feature_names(), "label");
50/// ```
51pub fn sparse_to_dataset(
52    matrix: &CsrMatrix,
53    target: Vec<f64>,
54    feature_names: Vec<String>,
55    target_name: &str,
56) -> Dataset {
57    let n_rows = matrix.n_rows();
58    let n_cols = matrix.n_cols();
59    assert_eq!(
60        target.len(),
61        n_rows,
62        "target length must match number of documents"
63    );
64
65    // Convert row-major sparse to column-major dense.
66    let dense_rows = matrix.to_dense(); // Vec<Vec<f64>>, [n_rows][n_cols]
67    let mut features_col_major = vec![vec![0.0; n_rows]; n_cols];
68    for (i, row) in dense_rows.iter().enumerate() {
69        for (j, &val) in row.iter().enumerate() {
70            features_col_major[j][i] = val;
71        }
72    }
73
74    Dataset::new(features_col_major, target, feature_names, target_name)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip: counts -> Dataset preserves sample count, feature count,
    /// and the target vector.
    #[test]
    fn sparse_to_dataset_roundtrip() {
        let docs = ["good movie", "bad movie", "good film", "bad film"];
        let target = vec![1.0, 0.0, 1.0, 0.0];

        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);
        let dataset =
            sparse_to_dataset(&matrix, target.clone(), cv.get_feature_names(), "sentiment");

        assert_eq!(dataset.n_samples(), 4);
        assert_eq!(dataset.n_features(), cv.n_features());
        assert_eq!(dataset.target, target);
    }

    /// End-to-end smoke test: count features feed a MultinomialNB fit/predict.
    #[test]
    fn sparse_to_dataset_feeds_into_multinomial_nb() {
        let docs = [
            "good great awesome",
            "good nice wonderful",
            "bad terrible awful",
            "bad horrible nasty",
            "good fantastic",
            "bad disgusting",
        ];
        let target = vec![1.0, 1.0, 0.0, 0.0, 1.0, 0.0];

        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);
        let dataset = sparse_to_dataset(&matrix, target, cv.get_feature_names(), "sentiment");

        let mut nb = crate::naive_bayes::MultinomialNB::new();
        nb.fit(&dataset).unwrap();

        // Predict on training data (just checking it doesn't crash).
        let rows = dataset.feature_matrix();
        let preds = nb.predict(&rows).unwrap();
        assert_eq!(preds.len(), 6);
    }

    /// End-to-end smoke test: TF-IDF features feed a LogisticRegression fit/predict.
    #[test]
    fn tfidf_to_dataset_feeds_into_logistic() {
        let docs = [
            "good great awesome",
            "good nice wonderful",
            "bad terrible awful",
            "bad horrible nasty",
            "good fantastic nice",
            "bad disgusting terrible",
        ];
        let target = vec![1.0, 1.0, 0.0, 0.0, 1.0, 0.0];

        let mut tfidf = TfidfVectorizer::new();
        let matrix = tfidf.fit_transform(&docs);
        let dataset = sparse_to_dataset(&matrix, target, tfidf.get_feature_names(), "sentiment");

        let mut lr = crate::linear::LogisticRegression::new();
        lr.fit(&dataset).unwrap();

        let rows = dataset.feature_matrix();
        let preds = lr.predict(&rows).unwrap();
        assert_eq!(preds.len(), 6);
    }
}