use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use scirs2_core::parallel_ops;
use std::collections::HashMap;
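
/// Counts how often each token produced by `tokenizer` occurs in `text`.
///
/// Returns a map from token to its number of occurrences.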
#[allow(dead_code)]
pub fn count_tokens(text: &str, tokenizer: &dyn Tokenizer) -> Result<HashMap<String, usize>> {
    let tokens = tokenizer.tokenize(text)?;
    let mut counts = HashMap::new();
    for token in tokens {
        *counts.entry(token).or_insert(0) += 1;
    }
    Ok(counts)
}
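
/// Counts token occurrences across a batch of texts, summing the per-text
/// counts from [`count_tokens`] into a single map.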
#[allow(dead_code)]
pub fn count_tokens_batch(
    texts: &[&str],
    tokenizer: &dyn Tokenizer,
) -> Result<HashMap<String, usize>> {
    let mut total_counts = HashMap::new();
    for &text in texts {
        let counts = count_tokens(text, tokenizer)?;
        for (token, count) in counts {
            *total_counts.entry(token).or_insert(0) += count;
        }
    }
    Ok(total_counts)
}
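
/// Parallel variant of [`count_tokens_batch`]: processes the texts through
/// `parallel_ops::parallel_map_result` and merges the per-text counts into a
/// single map. Requires a tokenizer that is `Send + Sync`.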
#[allow(dead_code)]
pub fn count_tokens_batch_parallel<T>(
    texts: &[&str],
    tokenizer: &T,
) -> Result<HashMap<String, usize>>
where
    T: Tokenizer + Send + Sync,
{
    // Own the input strings and box a clone of the tokenizer so the closure
    // does not borrow from the caller when handed to the parallel map.
    let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
    let tokenizer_boxed = tokenizer.clone_box();
    let token_counts = parallel_ops::parallel_map_result(&texts_owned, move |text| {
        // Wrap tokenization failures in a core error so they can cross the
        // `parallel_map_result` boundary.
        count_tokens(text, &*tokenizer_boxed).map_err(|e| {
            scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
                format!("Text processing error: {e}"),
            ))
        })
    })?;
    // Merge the per-text counts into a single total.
    let mut total_counts = HashMap::new();
    for counts in token_counts {
        for (token, count) in counts {
            *total_counts.entry(token).or_insert(0) += count;
        }
    }
    Ok(total_counts)
}
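
/// Tokenizes `text`, keeps only the tokens for which `predicate` returns
/// `true`, and joins the remaining tokens with single spaces.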
#[allow(dead_code)]
pub fn filter_tokens<F>(text: &str, tokenizer: &dyn Tokenizer, predicate: F) -> Result<String>
where
    F: Fn(&str) -> bool,
{
    let tokens = tokenizer.tokenize(text)?;
    let filtered_tokens: Vec<String> = tokens
        .iter()
        .filter(|token| predicate(token))
        .cloned()
        .collect();
    Ok(filtered_tokens.join(" "))
}
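
/// Extracts all consecutive n-grams of size `n` from the tokenized text, each
/// joined with single spaces.
///
/// Returns an error if `n` is 0, and an empty vector if the text has fewer
/// than `n` tokens.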
#[allow(dead_code)]
pub fn extract_ngrams(text: &str, tokenizer: &dyn Tokenizer, n: usize) -> Result<Vec<String>> {
    if n == 0 {
        return Err(TextError::InvalidInput(
            "n-gram size must be greater than 0".to_string(),
        ));
    }
    let tokens = tokenizer.tokenize(text)?;
    if tokens.len() < n {
        return Ok(Vec::new());
    }
    let ngrams: Vec<String> = (0..=(tokens.len() - n))
        .map(|i| tokens[i..(i + n)].join(" "))
        .collect();
    Ok(ngrams)
}
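
/// Counts ordered token pairs `(a, b)` where `b` occurs within `window_size`
/// tokens after `a`, keeping only pairs seen at least `min_count` times.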
#[allow(dead_code)]
pub fn extract_collocations(
    text: &str,
    tokenizer: &dyn Tokenizer,
    window_size: usize,
    min_count: usize,
) -> Result<HashMap<(String, String), usize>> {
    let tokens = tokenizer.tokenize(text)?;
    let mut collocations = HashMap::new();
    if tokens.len() < 2 {
        return Ok(collocations);
    }
    for i in 0..tokens.len() {
        let end = std::cmp::min(i + window_size + 1, tokens.len());
        for j in (i + 1)..end {
            let pair = (tokens[i].clone(), tokens[j].clone());
            *collocations.entry(pair).or_insert(0) += 1;
        }
    }
    collocations.retain(|_, &mut count| count >= min_count);
    Ok(collocations)
}
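
/// Randomly shuffles `texts` and splits them into `(train, test)` sets, with
/// roughly `test_size` (a fraction in `[0.0, 1.0]`) of the items going to the
/// test set. Supplying `random_seed` makes the split reproducible.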
#[allow(dead_code)]
pub fn train_test_split(
    texts: &[String],
    test_size: f64,
    random_seed: Option<u64>,
) -> Result<(Vec<String>, Vec<String>)> {
    use scirs2_core::random::seq::SliceRandom;
    use scirs2_core::random::SeedableRng;
    if !(0.0..=1.0).contains(&test_size) {
        return Err(TextError::InvalidInput(
            "test_size must be between 0.0 and 1.0".to_string(),
        ));
    }
    if texts.is_empty() {
        return Ok((Vec::new(), Vec::new()));
    }
    let mut rng = match random_seed {
        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
        None => {
            let mut temp_rng = scirs2_core::random::rng();
            scirs2_core::random::rngs::StdRng::from_rng(&mut temp_rng)
        }
    };
    let mut texts_copy = texts.to_vec();
    texts_copy.shuffle(&mut rng);
    let test_count = (texts.len() as f64 * test_size).round() as usize;
    let train_count = texts.len() - test_count;
    let train_texts = texts_copy.iter().take(train_count).cloned().collect();
    let test_texts = texts_copy.iter().skip(train_count).cloned().collect();
    Ok((train_texts, test_texts))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WordTokenizer;

    #[test]
    fn test_count_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";
        let counts = count_tokens(text, &tokenizer).expect("Operation failed");
        assert_eq!(counts.get("this").expect("Operation failed"), &2);
        assert_eq!(counts.get("is").expect("Operation failed"), &2);
        assert_eq!(counts.get("a").expect("Operation failed"), &2);
        assert_eq!(counts.get("test").expect("Operation failed"), &2);
        assert_eq!(counts.get("only").expect("Operation failed"), &1);
    }
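
    // A minimal extra check for `count_tokens_batch`, added as a sketch: it
    // assumes `WordTokenizer::default()` splits lowercase words on whitespace,
    // as the counts in `test_count_tokens` above suggest.
    #[test]
    fn test_count_tokens_batch() {
        let tokenizer = WordTokenizer::default();
        let texts = ["apple banana apple", "banana cherry"];
        let counts = count_tokens_batch(&texts, &tokenizer).expect("Operation failed");
        assert_eq!(counts.get("apple"), Some(&2));
        assert_eq!(counts.get("banana"), Some(&2));
        assert_eq!(counts.get("cherry"), Some(&1));
    }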

    #[test]
    fn test_filter_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";
        let predicate = |token: &str| !["this", "is", "a"].contains(&token);
        let filtered = filter_tokens(text, &tokenizer, predicate).expect("Operation failed");
        assert_eq!(filtered, "test only test");
    }

    #[test]
    fn test_extract_ngrams() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a simple test";
        let bigrams = extract_ngrams(text, &tokenizer, 2).expect("Operation failed");
        assert_eq!(bigrams, vec!["this is", "is a", "a simple", "simple test"]);
        let trigrams = extract_ngrams(text, &tokenizer, 3).expect("Operation failed");
        assert_eq!(trigrams, vec!["this is a", "is a simple", "a simple test"]);
    }

    #[test]
    fn test_extract_collocations() {
        let tokenizer = WordTokenizer::default();
        let text = "machine learning is a subset of artificial intelligence that provides systems with the ability to learn";
        let collocations = extract_collocations(text, &tokenizer, 2, 1).expect("Operation failed");
        assert!(collocations.contains_key(&("machine".to_string(), "learning".to_string())));
        assert!(collocations.contains_key(&("artificial".to_string(), "intelligence".to_string())));
    }

    #[test]
    fn test_train_test_split() {
        let texts = vec![
            "text 1".to_string(),
            "text 2".to_string(),
            "text 3".to_string(),
            "text 4".to_string(),
            "text 5".to_string(),
        ];
        let (train, test) = train_test_split(&texts, 0.4, Some(42)).expect("Operation failed");
        assert_eq!(train.len(), 3);
        assert_eq!(test.len(), 2);
        for text in &texts {
            assert_eq!(
                train.iter().filter(|&t| t == text).count()
                    + test.iter().filter(|&t| t == text).count(),
                1
            );
        }
    }
}