use crate::error::{Result, TextError};
use crate::sparse::{CsrMatrix, SparseMatrixBuilder, SparseVector};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;
/// Bag-of-words vectorizer producing sparse document-term vectors.
pub struct SparseCountVectorizer {
    /// Tokenizer used to split input text into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Learned token -> column-index mapping (populated by `fit`).
    vocabulary: Vocabulary,
    /// When true, emit 1.0 for any present token instead of its raw count.
    binary: bool,
}
impl Clone for SparseCountVectorizer {
fn clone(&self) -> Self {
Self {
tokenizer: self.tokenizer.clone_box(),
vocabulary: self.vocabulary.clone(),
binary: self.binary,
}
}
}
impl SparseCountVectorizer {
    /// Create a vectorizer with the default word tokenizer.
    ///
    /// When `binary` is true, transformed vectors contain 1.0 for every
    /// token that occurs at least once instead of its occurrence count.
    pub fn new(binary: bool) -> Self {
        Self::with_tokenizer(Box::new(WordTokenizer::default()), binary)
    }

    /// Create a vectorizer using a caller-supplied tokenizer.
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
        Self {
            tokenizer,
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Learn the vocabulary from `texts`, discarding any previous fit.
    ///
    /// # Errors
    /// `TextError::InvalidInput` when `texts` is empty, or any tokenizer
    /// error raised while processing a document.
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".into(),
            ));
        }
        // Reset before scanning so a stale vocabulary never survives a refit.
        self.vocabulary = Vocabulary::new();
        for text in texts.iter().copied() {
            for token in self.tokenizer.tokenize(text)? {
                self.vocabulary.add_token(&token);
            }
        }
        Ok(())
    }

    /// Map one document to a sparse count (or binary) vector over the
    /// fitted vocabulary; out-of-vocabulary tokens are silently dropped.
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let mut counts: HashMap<usize, f64> = HashMap::new();
        for token in self.tokenizer.tokenize(text)? {
            if let Some(idx) = self.vocabulary.get_index(&token) {
                *counts.entry(idx).or_insert(0.0) += 1.0;
            }
        }
        // SparseVector expects indices in ascending order.
        let mut pairs: Vec<(usize, f64)> = counts.into_iter().collect();
        pairs.sort_unstable_by_key(|&(idx, _)| idx);
        let (indices, values): (Vec<usize>, Vec<f64>) = pairs
            .into_iter()
            .map(|(idx, count)| (idx, if self.binary { 1.0 } else { count }))
            .unzip();
        Ok(SparseVector::fromindices_values(
            indices,
            values,
            self.vocabulary.len(),
        ))
    }

    /// Transform every document into one row of a CSR matrix.
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let mut builder = SparseMatrixBuilder::new(self.vocabulary.len());
        for text in texts.iter().copied() {
            builder.add_row(self.transform(text)?)?;
        }
        Ok(builder.build())
    }

    /// Fit the vocabulary on `texts`, then transform those same texts.
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Number of distinct tokens learned during fitting.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Borrow the fitted vocabulary.
    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }
}
/// TF-IDF vectorizer layered on top of a non-binary `SparseCountVectorizer`.
#[derive(Clone)]
pub struct SparseTfidfVectorizer {
    /// Produces raw term-count vectors (constructed with `binary = false`).
    count_vectorizer: SparseCountVectorizer,
    /// Per-term IDF weights; populated by `fit` only when `useidf` is true.
    idf: Option<Array1<f64>>,
    /// Whether to apply IDF weighting at transform time.
    useidf: bool,
    /// Output-vector normalization: Some("l1"), Some("l2"), or None.
    norm: Option<String>,
}
impl SparseTfidfVectorizer {
    /// Default configuration: IDF weighting on, L2 normalization.
    pub fn new() -> Self {
        Self::with_settings(true, Some("l2".to_string()))
    }

    /// Build a vectorizer with explicit IDF / normalization settings.
    ///
    /// `norm` may be `Some("l1")`, `Some("l2")`, or `None`; any other
    /// value makes `transform` return an error.
    pub fn with_settings(useidf: bool, norm: Option<String>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::new(false),
            idf: None,
            useidf,
            norm,
        }
    }

    /// Default settings, but tokenizing with a caller-supplied tokenizer.
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::with_tokenizer(tokenizer, false),
            idf: None,
            useidf: true,
            norm: Some("l2".to_string()),
        }
    }

    /// Learn the vocabulary and, when IDF is enabled, per-term IDF weights.
    ///
    /// The weight for a term with document frequency `df > 0` is
    /// `ln(n_docs / df) + 1`; unseen terms fall back to 1.0.
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        self.count_vectorizer.fit(texts)?;
        if self.useidf {
            let n_docs = texts.len() as f64;
            let vocab_size = self.count_vectorizer.vocabulary_size();
            // Document frequency: how many documents contain each term.
            let mut doc_freq = vec![0.0_f64; vocab_size];
            for text in texts.iter().copied() {
                for &idx in self.count_vectorizer.transform(text)?.indices() {
                    doc_freq[idx] += 1.0;
                }
            }
            let idf_values = Array1::from_iter(doc_freq.iter().map(|&df| {
                if df > 0.0 {
                    (n_docs / df).ln() + 1.0
                } else {
                    1.0
                }
            }));
            self.idf = Some(idf_values);
        }
        Ok(())
    }

    /// Produce the TF-IDF vector for a single document.
    ///
    /// # Errors
    /// Tokenizer failures, or an unrecognized normalization name.
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let mut sparse_vec = self.count_vectorizer.transform(text)?;
        // Apply IDF weights (skipped if `fit` never computed them).
        if self.useidf {
            if let Some(ref idf) = self.idf {
                let cols: Vec<usize> = sparse_vec.indices().to_vec();
                for (slot, col) in sparse_vec.values_mut().iter_mut().zip(cols) {
                    *slot *= idf[col];
                }
            }
        }
        match self.norm.as_deref() {
            None => {}
            Some("l2") => {
                let norm = sparse_vec.norm();
                if norm > 0.0 {
                    sparse_vec.scale(1.0 / norm);
                }
            }
            Some("l1") => {
                let sum: f64 = sparse_vec.values().iter().map(|x| x.abs()).sum();
                if sum > 0.0 {
                    sparse_vec.scale(1.0 / sum);
                }
            }
            Some(norm_type) => {
                return Err(TextError::InvalidInput(format!(
                    "Unknown normalization type: {norm_type}"
                )));
            }
        }
        Ok(sparse_vec)
    }

    /// Transform every document into one row of a sparse CSR matrix.
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let mut builder = SparseMatrixBuilder::new(self.count_vectorizer.vocabulary_size());
        for text in texts.iter().copied() {
            builder.add_row(self.transform(text)?)?;
        }
        Ok(builder.build())
    }

    /// Convenience: `fit` followed by `transform_batch` on the same texts.
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Number of distinct terms in the fitted vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.count_vectorizer.vocabulary_size()
    }

    /// Borrow the underlying vocabulary.
    pub fn vocabulary(&self) -> &Vocabulary {
        self.count_vectorizer.vocabulary()
    }

    /// IDF weights computed by `fit`, if IDF weighting is enabled.
    pub fn idf_values(&self) -> Option<&Array1<f64>> {
        self.idf.as_ref()
    }
}
impl Default for SparseTfidfVectorizer {
    /// Equivalent to [`SparseTfidfVectorizer::new`]: IDF on, L2 norm.
    fn default() -> Self {
        Self::new()
    }
}
/// Cosine similarity between two sparse vectors of equal dimension.
///
/// Zero-norm convention: two zero vectors compare as 1.0; a zero vector
/// against a non-zero one compares as 0.0.
///
/// # Errors
/// `TextError::InvalidInput` when the vector dimensions differ, or any
/// error from the underlying sparse dot product.
#[allow(dead_code)]
pub fn sparse_cosine_similarity(v1: &SparseVector, v2: &SparseVector) -> Result<f64> {
    if v1.size() != v2.size() {
        return Err(TextError::InvalidInput(format!(
            "Vector dimensions don't match: {} vs {}",
            v1.size(),
            v2.size()
        )));
    }
    let dot = v1.dotsparse(v2)?;
    let norm1 = v1.norm();
    let norm2 = v2.norm();
    if norm1 == 0.0 || norm2 == 0.0 {
        // Both zero => identical => 1.0; exactly one zero => 0.0.
        return Ok(if norm1 == norm2 { 1.0 } else { 0.0 });
    }
    Ok(dot / (norm1 * norm2))
}
/// Memory-usage comparison between a sparse matrix and its dense equivalent.
pub struct MemoryStats {
    /// Bytes used by the sparse (CSR) representation.
    pub sparse_bytes: usize,
    /// Bytes a dense `f64` matrix of the same shape would use.
    pub dense_bytes: usize,
    /// `dense_bytes / sparse_bytes` — how much smaller the sparse form is.
    pub compression_ratio: f64,
    /// Fraction of entries that are zero: `1 - nnz / total_elements`.
    pub sparsity: f64,
}
impl MemoryStats {
    /// Compute statistics for `sparse` versus a hypothetical dense `f64`
    /// matrix of the same shape.
    ///
    /// Degenerate inputs are guarded so the result never contains
    /// NaN/inf: an empty matrix (zero rows or columns) is reported as
    /// fully sparse, and a zero-byte sparse representation yields a
    /// compression ratio of 1.0 instead of dividing by zero.
    pub fn from_sparse_matrix(sparse: &CsrMatrix) -> Self {
        let (n_rows, n_cols) = sparse.shape();
        let dense_bytes = n_rows * n_cols * std::mem::size_of::<f64>();
        let sparse_bytes = sparse.memory_usage();
        let total_elements = n_rows * n_cols;
        let nnz = sparse.nnz();
        // Guard both divisions: the original computed NaN/inf for empty
        // matrices or a zero-byte sparse representation.
        let compression_ratio = if sparse_bytes == 0 {
            1.0
        } else {
            dense_bytes as f64 / sparse_bytes as f64
        };
        let sparsity = if total_elements == 0 {
            1.0
        } else {
            1.0 - (nnz as f64 / total_elements as f64)
        };
        Self {
            sparse_bytes,
            dense_bytes,
            compression_ratio,
            sparsity,
        }
    }

    /// Print a human-readable summary of the statistics to stdout.
    pub fn print_stats(&self) {
        println!("Memory Usage Statistics:");
        println!(" Sparse representation: {} bytes", self.sparse_bytes);
        println!(" Dense representation: {} bytes", self.dense_bytes);
        println!(" Compression ratio: {:.2}x", self.compression_ratio);
        println!(" Sparsity: {:.1}%", self.sparsity * 100.0);
        println!(
            " Memory saved: {:.1}%",
            (1.0 - 1.0 / self.compression_ratio) * 100.0
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_count_vectorizer() {
        let docs = vec![
            "this is a test document with some unique words",
            "this is another test document with different vocabulary",
            "yet another example document with more text content",
            "completely different text with various other terms",
            "final document in the test set with distinct words",
        ];
        let mut vectorizer = SparseCountVectorizer::new(false);
        let matrix = vectorizer.fit_transform(&docs).expect("Operation failed");
        // One row per input document, with at least one stored value.
        assert_eq!(matrix.shape().0, 5);
        assert!(matrix.nnz() > 0);
        let stats = MemoryStats::from_sparse_matrix(&matrix);
        assert!(stats.compression_ratio > 0.0);
        assert!(stats.sparsity >= 0.0);
    }

    #[test]
    fn test_sparse_tfidf_vectorizer() {
        let docs = vec!["the quick brown fox", "the lazy dog", "brown fox jumps"];
        let mut vectorizer = SparseTfidfVectorizer::new();
        let matrix = vectorizer.fit_transform(&docs).expect("Operation failed");
        assert_eq!(matrix.shape().0, 3);
        // Default config applies L2 norm, so rows are unit vectors.
        let row = matrix.get_row(0).expect("Operation failed");
        assert!(row.norm() > 0.0);
        assert!((row.norm() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_cosine_similarity() {
        let a = SparseVector::fromindices_values(vec![0, 2, 3], vec![1.0, 2.0, 3.0], 5);
        let b = SparseVector::fromindices_values(vec![1, 2, 4], vec![1.0, 2.0, 1.0], 5);
        let got = sparse_cosine_similarity(&a, &b).expect("Operation failed");
        // Only index 2 overlaps: dot = 2 * 2 = 4.
        let want = 4.0 / (14.0_f64.sqrt() * 6.0_f64.sqrt());
        assert!((got - want).abs() < 1e-10);
    }

    #[test]
    fn test_memory_efficiency_large() {
        let owned: Vec<String> = (0..100)
            .map(|i| {
                let word_idx = i % 10;
                format!("document {i} contains word{word_idx}")
            })
            .collect();
        let docs: Vec<&str> = owned.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = SparseCountVectorizer::new(false);
        let matrix = vectorizer.fit_transform(&docs).expect("Operation failed");
        let stats = MemoryStats::from_sparse_matrix(&matrix);
        stats.print_stats();
        // Highly repetitive vocabulary => strong compression and sparsity.
        assert!(stats.compression_ratio > 5.0);
        assert!(stats.sparsity > 0.8);
    }
}