use std::collections::HashMap;
/// Verdict produced by [`RepetitionDetector`] after each update.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RepetitionState {
    /// Repetition is below the warning level, or too few words have been seen to judge.
    Normal,
    /// Repetition is at least 70% of the configured threshold but below it;
    /// carries the measured repetition ratio.
    Warning(f32),
    /// Repetition ratio has reached the configured threshold.
    Degenerate,
}
/// Streaming detector for repetitive ("degenerate") text.
///
/// Words passed to `feed` are lowercased and kept in a rolling window. The detector
/// tracks how often each run of `ngram_size` consecutive words occurs in that window.
#[derive(Debug)]
pub struct RepetitionDetector {
/// Number of consecutive words per n-gram; `new` enforces a minimum of 2.
ngram_size: usize,
/// Repetition ratio at which the state becomes `Degenerate`; `new` clamps it to 0.0..=1.0.
threshold: f32,
/// Rolling window of lowercased words, oldest first.
words: Vec<String>,
/// Occurrence count of each `ngram_size`-word window of `words`.
ngram_counts: HashMap<Vec<String>, u32>,
/// Maximum number of words retained; when exceeded, the oldest words are dropped first.
max_words: usize,
}
/// Cap on the rolling word window used by `new`.
const DEFAULT_MAX_WORDS: usize = 500;
/// Below this many retained words the detector always reports `Normal`.
const MIN_WORDS_FOR_DETECTION: usize = 30;
impl RepetitionDetector {
    /// Creates a detector that compares `ngram_size`-word windows against `threshold`.
    ///
    /// `ngram_size` is raised to at least 2 and `threshold` is clamped to `0.0..=1.0`.
    /// A NaN threshold is left as NaN by `clamp`. Every state comparison against it is
    /// then false, so the detector only ever reports [`RepetitionState::Normal`].
    #[must_use]
    pub fn new(ngram_size: usize, threshold: f32) -> Self {
        Self {
            ngram_size: ngram_size.max(2),
            threshold: threshold.clamp(0.0, 1.0),
            words: Vec::new(),
            ngram_counts: HashMap::new(),
            max_words: DEFAULT_MAX_WORDS,
        }
    }

    /// Appends the whitespace-separated, lowercased words of `chunk` and returns the new state.
    ///
    /// Only the most recent `max_words` words are kept; older words are evicted first.
    /// Each chunk is tokenized on its own, so a word split across two chunks counts as
    /// two words.
    ///
    /// N-gram counts are maintained incrementally:
    /// - windows that start in the evicted prefix are subtracted;
    /// - only windows that include a newly appended word are added.
    ///
    /// The cost of a call is therefore proportional to the chunk, not to the whole window.
    pub fn feed(&mut self, chunk: &str) -> RepetitionState {
        if chunk.is_empty() {
            return self.current_state();
        }
        // Invariant: every window lying entirely within `words[..counted]` is in `ngram_counts`.
        let counted = self.words.len();
        self.words
            .extend(chunk.split_whitespace().map(str::to_lowercase));

        let excess = self.words.len().saturating_sub(self.max_words);
        if excess > 0 {
            self.remove_old_ngrams(excess, counted);
            self.words.drain(..excess);
        }
        // Eviction shifted the already-counted prefix left by `excess`, possibly to nothing.
        self.add_new_ngrams(counted.saturating_sub(excess));
        self.current_state()
    }

    /// Classifies the current window.
    ///
    /// Returns `Normal` with fewer than `MIN_WORDS_FOR_DETECTION` words. Otherwise:
    /// - `Degenerate` when the repetition ratio is at least `threshold`;
    /// - `Warning(ratio)` when it is at least 70% of `threshold`;
    /// - `Normal` otherwise.
    #[must_use]
    pub fn current_state(&self) -> RepetitionState {
        if self.words.len() < MIN_WORDS_FOR_DETECTION {
            return RepetitionState::Normal;
        }
        let ratio = self.repetition_ratio();
        if ratio >= self.threshold {
            RepetitionState::Degenerate
        } else if ratio >= self.threshold * 0.7 {
            RepetitionState::Warning(ratio)
        } else {
            RepetitionState::Normal
        }
    }

    /// Fraction of counted windows that repeat an earlier window.
    ///
    /// Computed as the sum of `count - 1` over repeated n-grams, divided by the total
    /// window count. Returns 0.0 when no windows are counted.
    // Counts are bounded by the retained word count (far below 2^24), so the f32
    // conversion is exact.
    #[allow(clippy::cast_precision_loss)]
    fn repetition_ratio(&self) -> f32 {
        let total_ngrams: u32 = self.ngram_counts.values().sum();
        if total_ngrams == 0 {
            return 0.0;
        }
        let repeated_ngrams: u32 = self
            .ngram_counts
            .values()
            .filter(|&&count| count > 1)
            .map(|&count| count - 1)
            .sum();
        repeated_ngrams as f32 / total_ngrams as f32
    }

    /// Un-counts the first `count` windows of `words[..counted]`.
    ///
    /// These are the windows that start in the prefix about to be evicted. Only windows
    /// inside the already-counted region can be present in the map. Entries that reach
    /// zero are removed, so the map holds exactly the windows still in `words`.
    fn remove_old_ngrams(&mut self, count: usize, counted: usize) {
        for window in self.words[..counted].windows(self.ngram_size).take(count) {
            // `&[String]` lookup via `Vec<String>: Borrow<[String]>`; no allocation needed.
            let exhausted = match self.ngram_counts.get_mut(window) {
                Some(n) => {
                    *n -= 1;
                    *n == 0
                }
                // Unreachable while the counting invariant holds.
                None => false,
            };
            if exhausted {
                self.ngram_counts.remove(window);
            }
        }
    }

    /// Counts every window that includes at least one word at index `counted` or later.
    ///
    /// Those are exactly the words not yet reflected in `ngram_counts`.
    fn add_new_ngrams(&mut self, counted: usize) {
        // A window includes index `counted` or later iff it starts no earlier than
        // `counted - (ngram_size - 1)`; `ngram_size >= 2`, so the subtraction is safe.
        let start = counted.saturating_sub(self.ngram_size - 1);
        for window in self.words[start..].windows(self.ngram_size) {
            if let Some(n) = self.ngram_counts.get_mut(window) {
                *n += 1;
            } else {
                // Only allocate an owned key for n-grams not seen before.
                self.ngram_counts.insert(window.to_vec(), 1);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recomputes n-gram counts from scratch so tests can check the detector's bookkeeping.
    fn rebuilt_counts(
        detector: &RepetitionDetector,
    ) -> std::collections::HashMap<Vec<String>, u32> {
        let mut counts = std::collections::HashMap::new();
        for window in detector.words.windows(detector.ngram_size) {
            *counts.entry(window.to_vec()).or_insert(0) += 1;
        }
        counts
    }

    #[test]
    fn normal_text_not_flagged() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let text = "The quick brown fox jumps over the lazy dog and then \
                    runs through the forest while birds sing in the trees \
                    and flowers bloom in the meadow near the river bank";
        let state = detector.feed(text);
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn repetitive_text_detected() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let text = "the cat sat the cat sat the cat sat the cat sat \
                    the cat sat the cat sat the cat sat the cat sat \
                    the cat sat the cat sat the cat sat the cat sat";
        let state = detector.feed(text);
        assert!(
            matches!(state, RepetitionState::Degenerate),
            "Expected Degenerate, got {state:?}"
        );
    }

    #[test]
    fn moderate_repetition_warns() {
        // 30 distinct words followed by a repeat of the first 12: 40 trigram windows,
        // 10 of which repeat an earlier window -> ratio 0.25, in [0.7 * 0.3, 0.3).
        let mut detector = RepetitionDetector::new(3, 0.3);
        let words: Vec<String> = (0..30).chain(0..12).map(|i| format!("u{i}")).collect();
        let state = detector.feed(&words.join(" "));
        assert_eq!(state, RepetitionState::Warning(0.25));
    }

    #[test]
    fn empty_chunk_returns_current_state() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let state = detector.feed("");
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn short_text_always_normal() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let state = detector.feed("hello world");
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn incremental_feeding_works() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        for _ in 0..20 {
            detector.feed("the cat sat on the mat ");
        }
        let state = detector.current_state();
        assert!(
            matches!(
                state,
                RepetitionState::Degenerate | RepetitionState::Warning(_)
            ),
            "Expected repetition detection, got {state:?}"
        );
    }

    #[test]
    fn rolling_window_prevents_unbounded_growth() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        for i in 0..1000 {
            detector.feed(&format!("unique word number {i} in the stream "));
        }
        assert!(detector.words.len() <= DEFAULT_MAX_WORDS);
    }

    #[test]
    fn ngram_counts_stay_consistent_across_eviction() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        for i in 0..160 {
            detector.feed(&format!("alpha beta {} gamma ", i % 7));
            assert_eq!(detector.ngram_counts, rebuilt_counts(&detector));
        }
        // A single chunk larger than the whole window must also leave counts consistent.
        let big: String = (0..700).map(|i| format!("w{} ", i % 11)).collect();
        detector.feed(&big);
        assert_eq!(detector.words.len(), DEFAULT_MAX_WORDS);
        assert_eq!(detector.ngram_counts, rebuilt_counts(&detector));
    }

    #[test]
    fn threshold_clamped() {
        let detector = RepetitionDetector::new(3, 1.5);
        assert!((detector.threshold - 1.0).abs() < f32::EPSILON);
        let detector2 = RepetitionDetector::new(3, -0.5);
        assert!(detector2.threshold.abs() < f32::EPSILON);
    }
}