use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use crate::vocabulary::Vocabulary;
use regex::Regex;
use std::collections::{HashMap, HashSet};
/// Filtering interface over tokenized text.
pub trait TokenFilter {
    /// Returns the subset of `tokens` that this filter keeps, in order.
    fn apply(&self, tokens: &[String]) -> Vec<String>;

    /// Tokenizes `text`, applies the filter, and rejoins the survivors with
    /// single spaces.
    ///
    /// # Errors
    /// Propagates any error produced by the tokenizer.
    fn filtertext(&self, text: &str, tokenizer: &dyn Tokenizer) -> Result<String> {
        tokenizer
            .tokenize(text)
            .map(|tokens| self.apply(&tokens).join(" "))
    }
}
/// Keeps tokens whose character count falls within an inclusive range.
#[derive(Debug, Clone)]
pub struct LengthFilter {
    /// Minimum token length in characters (inclusive).
    pub min_length: usize,
    /// Maximum token length in characters (inclusive).
    pub max_length: usize,
}
impl Default for LengthFilter {
fn default() -> Self {
Self {
min_length: 1,
max_length: usize::MAX,
}
}
}
impl LengthFilter {
    /// Creates a filter keeping tokens whose character count lies in
    /// `[min_length, max_length]` (both bounds inclusive).
    pub fn new(min_length: usize, max_length: usize) -> Self {
        Self {
            min_length,
            max_length,
        }
    }

    /// Builder-style setter for the minimum token length (inclusive).
    pub fn with_min_length(mut self, min_length: usize) -> Self {
        self.min_length = min_length;
        self
    }

    /// Builder-style setter for the maximum token length (inclusive).
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }
}
impl TokenFilter for LengthFilter {
    /// Keeps tokens whose Unicode scalar count lies within
    /// `min_length..=max_length`.
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        let mut kept = Vec::new();
        for token in tokens {
            // chars().count() measures Unicode scalar values, not bytes.
            let length = token.chars().count();
            if (self.min_length..=self.max_length).contains(&length) {
                kept.push(token.clone());
            }
        }
        kept
    }
}
/// Keeps tokens based on how often they occur in learned counts.
#[derive(Debug, Clone)]
pub struct FrequencyFilter {
    /// Minimum absolute occurrence count required to keep a token.
    pub min_count: usize,
    /// Optional maximum absolute count; tokens counted above it are dropped.
    pub max_count: Option<usize>,
    /// Optional maximum relative frequency (count / total), in `[0.0, 1.0]`.
    pub max_freq: Option<f64>,
    // Learned per-token occurrence counts.
    token_counts: HashMap<String, usize>,
    // Sum of all learned counts; denominator for relative frequency.
    total_count: usize,
}
impl FrequencyFilter {
    /// Builds a filter by counting `tokens`, keeping counts only for tokens
    /// present in `vocabulary`. Out-of-vocabulary tokens contribute neither
    /// to the per-token counts nor to the total.
    pub fn from_tokens_with_vocabulary(
        tokens: &[String],
        vocabulary: &Vocabulary,
        min_count: usize,
    ) -> Self {
        let mut token_counts = HashMap::new();
        for token in tokens {
            if vocabulary.contains(token) {
                *token_counts.entry(token.clone()).or_insert(0) += 1;
            }
        }
        // Total is the number of counted (in-vocabulary) occurrences.
        let total_count: usize = token_counts.values().sum();
        Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts,
            total_count,
        }
    }

    /// Builds a filter from precomputed per-token counts.
    pub fn from_counts(token_counts: HashMap<String, usize>, min_count: usize) -> Self {
        let total_count = token_counts.values().sum();
        Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts,
            total_count,
        }
    }

    /// Learns counts by tokenizing every text in `texts`.
    ///
    /// # Errors
    /// Propagates any error produced by the tokenizer.
    pub fn learn_from_corpus(
        texts: &[&str],
        tokenizer: &dyn Tokenizer,
        min_count: usize,
    ) -> Result<Self> {
        let mut counts = HashMap::new();
        let mut total = 0;
        for &text in texts {
            let tokens = tokenizer.tokenize(text)?;
            for token in tokens {
                *counts.entry(token).or_insert(0) += 1;
                total += 1;
            }
        }
        Ok(Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts: counts,
            total_count: total,
        })
    }

    /// Builder: drop tokens whose absolute count exceeds `max_count`.
    pub fn with_max_count(mut self, max_count: usize) -> Self {
        self.max_count = Some(max_count);
        self
    }

    /// Builder: drop tokens whose relative frequency exceeds `max_freq`.
    ///
    /// # Errors
    /// Returns `TextError::InvalidInput` when `max_freq` is outside `[0.0, 1.0]`.
    pub fn with_max_freq(mut self, max_freq: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&max_freq) {
            return Err(TextError::InvalidInput(
                "max_freq must be between 0.0 and 1.0".to_string(),
            ));
        }
        self.max_freq = Some(max_freq);
        Ok(self)
    }
}
impl TokenFilter for FrequencyFilter {
    /// Keeps a token when its learned count satisfies every configured
    /// constraint; tokens with no learned count are treated as count 0.
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                let count = self.token_counts.get(*token).copied().unwrap_or(0);
                let above_min = count >= self.min_count;
                // Unset bounds pass unconditionally.
                let below_max = self.max_count.map_or(true, |cap| count <= cap);
                let below_freq = self.max_freq.map_or(true, |cap| {
                    // With no counted occurrences the frequency check is vacuous.
                    self.total_count == 0
                        || (count as f64 / self.total_count as f64) <= cap
                });
                above_min && below_max && below_freq
            })
            .cloned()
            .collect()
    }
}
/// Keeps or drops tokens according to whether they match a compiled regex.
#[derive(Debug, Clone)]
pub struct RegexFilter {
    // Compiled pattern tested against each token.
    pattern: Regex,
    // `true`: keep matching tokens; `false`: keep non-matching tokens.
    keep_matching: bool,
}
impl RegexFilter {
    /// Compiles `pattern`; `keep_matching` decides whether matching tokens
    /// are kept (`true`) or dropped (`false`).
    ///
    /// # Errors
    /// Returns `TextError::InvalidInput` when the pattern fails to compile.
    pub fn new(pattern: &str, keep_matching: bool) -> Result<Self> {
        // map_err + ? instead of a match that only re-wraps the error.
        let regex = Regex::new(pattern)
            .map_err(|e| TextError::InvalidInput(format!("Invalid regex pattern: {e}")))?;
        Ok(Self {
            pattern: regex,
            keep_matching,
        })
    }
}
impl TokenFilter for RegexFilter {
    /// Keeps each token whose match status equals `keep_matching`: matching
    /// tokens when `true`, non-matching tokens when `false`.
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        let mut kept = Vec::with_capacity(tokens.len());
        for token in tokens {
            if self.pattern.is_match(token) == self.keep_matching {
                kept.push(token.clone());
            }
        }
        kept
    }
}
/// Removes (or keeps only) tokens found in a stopword set.
#[derive(Debug, Clone)]
pub struct StopwordsFilter {
    // Stopword set; `apply` compares it against lowercased tokens.
    stopwords: HashSet<String>,
    // `true`: drop stopwords; `false`: keep only stopwords.
    remove_stopwords: bool,
}
impl StopwordsFilter {
    /// Creates a filter from a stopword list.
    ///
    /// Words are stored lowercased: `apply` lowercases each token before the
    /// set lookup (and `from_file` lowercases on load), so mixed-case entries
    /// inserted verbatim could never match.
    pub fn new(stopwords: Vec<String>, remove_stopwords: bool) -> Self {
        Self {
            stopwords: stopwords.into_iter().map(|w| w.to_lowercase()).collect(),
            remove_stopwords,
        }
    }

    /// Loads stopwords from a file, one word per line. Blank lines and lines
    /// starting with `#` (after trimming) are skipped; words are lowercased.
    /// The resulting filter is in removal mode.
    ///
    /// # Errors
    /// Returns `TextError::IoError` if the file cannot be opened or read.
    pub fn from_file(path: &str) -> Result<Self> {
        use std::fs::File;
        use std::io::{BufRead, BufReader};
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let reader = BufReader::new(file);
        let mut stopwords = HashSet::new();
        for line in reader.lines() {
            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
            // Trim before the '#' check so indented comment lines are skipped too.
            let word = line.trim();
            if !word.is_empty() && !word.starts_with('#') {
                stopwords.insert(word.to_lowercase());
            }
        }
        Ok(Self {
            stopwords,
            remove_stopwords: true,
        })
    }

    /// Builder: choose removal mode (`true`) vs keep-only-stopwords mode
    /// (`false`).
    pub fn remove_stopwords(mut self, remove: bool) -> Self {
        self.remove_stopwords = remove;
        self
    }

    /// Adds stopwords, lowercased to match `apply`'s comparison.
    pub fn add_stopwords(&mut self, words: &[String]) {
        for word in words {
            self.stopwords.insert(word.to_lowercase());
        }
    }

    /// Returns the current stopwords as a vector (unspecified order).
    pub fn get_stopwords(&self) -> Vec<String> {
        self.stopwords.iter().cloned().collect()
    }
}
impl TokenFilter for StopwordsFilter {
    /// In removal mode, keeps non-stopwords; otherwise keeps only stopwords.
    /// The lookup is case-insensitive (tokens are lowercased first).
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                let hit = self.stopwords.contains(&token.to_lowercase());
                // Keep exactly when the hit status differs from removal mode:
                // remove=true keeps misses, remove=false keeps hits.
                self.remove_stopwords != hit
            })
            .cloned()
            .collect()
    }
}
/// Chains several filters; each filter consumes the previous one's output.
pub struct CompositeFilter {
    // Applied in insertion order.
    filters: Vec<Box<dyn TokenFilter + Send + Sync>>,
}
impl CompositeFilter {
    /// Creates an empty pipeline, which keeps every token.
    pub fn new() -> Self {
        Self { filters: Vec::new() }
    }

    /// Appends `filter` to the end of the pipeline.
    pub fn add_filter<F>(&mut self, filter: F)
    where
        F: TokenFilter + Send + Sync + 'static,
    {
        self.filters.push(Box::new(filter));
    }

    /// Builder-style variant of `add_filter`, for chained construction.
    pub fn with_filter<F>(mut self, filter: F) -> Self
    where
        F: TokenFilter + Send + Sync + 'static,
    {
        self.filters.push(Box::new(filter));
        self
    }
}
impl Default for CompositeFilter {
    /// Same as [`CompositeFilter::new`]: an empty pipeline.
    fn default() -> Self {
        Self::new()
    }
}
impl std::fmt::Debug for CompositeFilter {
    // Manual impl: `Box<dyn TokenFilter + Send + Sync>` carries no `Debug`
    // bound, so only the number of filters can be reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeFilter")
            .field("num_filters", &self.filters.len())
            .finish()
    }
}
impl Clone for CompositeFilter {
    // NOTE(review): this "clone" discards every contained filter because
    // `Box<dyn TokenFilter + Send + Sync>` cannot be cloned. The clone is an
    // empty pass-through pipeline, violating the usual expectation that a
    // clone behaves like the original — consider removing `Clone` or adding
    // a `clone_box` method to the `TokenFilter` trait.
    fn clone(&self) -> Self {
        Self::new()
    }
}
impl TokenFilter for CompositeFilter {
    /// Runs every filter in insertion order, feeding each filter's output
    /// into the next; an empty pipeline returns the tokens unchanged.
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        self.filters
            .iter()
            .fold(tokens.to_vec(), |current, filter| filter.apply(&current))
    }
}
/// Filter driven by an arbitrary user-supplied predicate.
pub struct CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    // Returns `true` for tokens that should be kept.
    predicate: F,
}
impl<F> CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    /// Wraps `predicate`; tokens for which it returns `true` are kept.
    pub fn new(predicate: F) -> Self {
        Self { predicate }
    }
}
impl<F> TokenFilter for CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    /// Keeps exactly the tokens for which the predicate returns `true`.
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        let mut kept = Vec::new();
        for token in tokens {
            if (self.predicate)(token) {
                kept.push(token.clone());
            }
        }
        kept
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WordTokenizer;

    /// Shared fixture: the nine tokens of "the quick brown fox jumps over
    /// the lazy dog", with "the" appearing twice.
    fn get_test_tokens() -> Vec<String> {
        vec![
            "the".to_string(),
            "quick".to_string(),
            "brown".to_string(),
            "fox".to_string(),
            "jumps".to_string(),
            "over".to_string(),
            "the".to_string(),
            "lazy".to_string(),
            "dog".to_string(),
        ]
    }

    #[test]
    fn test_length_filter() {
        let tokens = get_test_tokens();
        // Keep tokens with at least 4 characters.
        let filter = LengthFilter::new(4, usize::MAX);
        let filtered = filter.apply(&tokens);
        // Sort before comparing: only membership matters here, not order.
        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        assert_eq!(
            sorted_filtered,
            vec!["brown", "jumps", "lazy", "over", "quick"]
        );
        // Exact-length case: min == max == 3.
        let filter = LengthFilter::new(3, 3);
        let filtered = filter.apply(&tokens);
        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        // "the" is duplicated in the fixture, so it appears twice.
        assert_eq!(sorted_filtered, vec!["dog", "fox", "the", "the"]);
    }

    #[test]
    fn test_frequency_filter() {
        let tokens = get_test_tokens();
        // Count the fixture itself so counts match the tokens being filtered.
        let mut counts = HashMap::new();
        for token in &tokens {
            *counts.entry(token.clone()).or_insert(0) += 1;
        }
        // Only "the" occurs at least twice.
        let filter = FrequencyFilter::from_counts(counts, 2);
        let filtered = filter.apply(&tokens);
        assert_eq!(filtered, vec!["the", "the"]);
    }

    #[test]
    fn test_regex_filter() {
        let tokens = get_test_tokens();
        // keep_matching = true: keep only tokens starting with 'b'.
        let filter = RegexFilter::new(r"^b", true).expect("Operation failed");
        let filtered = filter.apply(&tokens);
        assert_eq!(filtered, vec!["brown"]);
        let test_tokens = vec![
            "the".to_string(),
            "jumps".to_string(),
            "the".to_string(),
            "lazy".to_string(),
        ];
        // keep_matching = false: drop tokens containing 'o' (none here).
        let filter = RegexFilter::new(r"o", false).expect("Operation failed");
        let filtered = filter.apply(&test_tokens);
        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        let expected = vec!["jumps", "lazy", "the", "the"];
        assert_eq!(sorted_filtered, expected);
    }

    #[test]
    fn test_stopwords_filter() {
        let tokens = get_test_tokens();
        let stopwords = vec!["the".to_string(), "over".to_string()];
        // Removal mode: stopwords are dropped, order preserved.
        let filter = StopwordsFilter::new(stopwords, true);
        let filtered = filter.apply(&tokens);
        assert_eq!(
            filtered,
            vec!["quick", "brown", "fox", "jumps", "lazy", "dog"]
        );
    }

    #[test]
    fn test_composite_filter() {
        let tokens = get_test_tokens();
        // Pipeline: length >= 4, then must contain 'o'.
        let length_filter = LengthFilter::new(4, usize::MAX);
        let regex_filter = RegexFilter::new(r"o", true).expect("Operation failed");
        let composite = CompositeFilter::new()
            .with_filter(length_filter)
            .with_filter(regex_filter);
        let filtered = composite.apply(&tokens);
        assert_eq!(filtered, vec!["brown", "over"]);
    }

    #[test]
    fn test_custom_filter() {
        let tokens = get_test_tokens();
        // Arbitrary predicate: keep tokens containing 'o'.
        let filter = CustomFilter::new(|token: &str| token.contains('o'));
        let filtered = filter.apply(&tokens);
        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        assert_eq!(sorted_filtered, vec!["brown", "dog", "fox", "over"]);
    }

    #[test]
    fn test_filtertext() {
        let text = "The quick brown fox jumps over the lazy dog";
        let tokenizer = WordTokenizer::default();
        // End-to-end: tokenize, keep length >= 5, rejoin with spaces.
        let filter = LengthFilter::new(5, usize::MAX);
        let filtered = filter
            .filtertext(text, &tokenizer)
            .expect("Operation failed");
        assert_eq!(filtered, "quick brown jumps");
    }
}