ai_tokenopt 0.5.7

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! N-gram based repetition detection for streaming output
//!
//! Monitors the output stream for degenerate repetitive content and
//! signals when the stream should be terminated early to save tokens.

use std::collections::HashMap;

/// State of the repetition detector after processing a chunk.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RepetitionState {
    /// Output is normal, no repetition detected.
    Normal,
    /// Repetition ratio is elevated (at least `WARNING_FRACTION` of the
    /// threshold) but still below the threshold; carries the current ratio.
    Warning(f32),
    /// Repetition ratio reached the threshold; the stream should stop.
    Degenerate,
}

/// N-gram repetition detector for streaming output.
///
/// Maintains a rolling window of the most recent words of the stream and a
/// count of every n-gram inside that window. The repetition ratio is the
/// share of n-gram occurrences that repeat an n-gram already present in the
/// window; when it reaches the configured threshold,
/// [`RepetitionDetector::feed`] returns [`RepetitionState::Degenerate`].
///
/// Chunks are treated as consecutive pieces of one text stream: if a chunk
/// ends in the middle of a word and the next chunk does not start with
/// whitespace, the two fragments are joined into a single word, so tokenizers
/// that split words across chunks do not distort the counts.
#[derive(Debug)]
pub struct RepetitionDetector {
    /// Number of words per n-gram (never below 2).
    ngram_size: usize,
    /// Repetition ratio at which the stream is reported as degenerate (0.0–1.0).
    threshold: f32,
    /// Lowercased words currently inside the rolling window, oldest first.
    words: Vec<String>,
    /// Occurrences of every n-gram of `words`. An entry is removed as soon as
    /// its count drops to zero, so no stored count is ever 0.
    ngram_counts: HashMap<Vec<String>, u32>,
    /// Maximum words to track (rolling window).
    max_words: usize,
    /// Whether the last non-empty chunk ended on a non-whitespace character,
    /// i.e. the newest word in `words` may continue in the next chunk.
    mid_word: bool,
}

/// Default rolling window size (in words).
const DEFAULT_MAX_WORDS: usize = 500;

/// Minimum words needed before detection activates.
const MIN_WORDS_FOR_DETECTION: usize = 30;

/// Fraction of the threshold from which `Warning` is reported.
const WARNING_FRACTION: f32 = 0.7;

impl RepetitionDetector {
    /// Create a new repetition detector.
    ///
    /// # Arguments
    ///
    /// * `ngram_size` — Size of n-grams to track (e.g., 3 for trigrams);
    ///   values below 2 are raised to 2.
    /// * `threshold` — Ratio of repeated n-grams that triggers the degenerate
    ///   state; clamped to 0.0–1.0.
    #[must_use]
    pub fn new(ngram_size: usize, threshold: f32) -> Self {
        Self {
            ngram_size: ngram_size.max(2),
            threshold: threshold.clamp(0.0, 1.0),
            words: Vec::new(),
            ngram_counts: HashMap::new(),
            max_words: DEFAULT_MAX_WORDS,
            mid_word: false,
        }
    }

    /// Feed the next chunk of the stream and get the resulting repetition state.
    ///
    /// The chunk is split on Unicode whitespace and each word is lowercased.
    /// If the previous chunk ended mid-word and this chunk does not start with
    /// whitespace, the first fragment is appended to the previous word instead
    /// of being counted as a separate word. An empty chunk changes nothing.
    pub fn feed(&mut self, chunk: &str) -> RepetitionState {
        if chunk.is_empty() {
            return self.current_state();
        }

        let mut pieces = chunk.split_whitespace().map(str::to_lowercase);

        // Re-glue a word that the stream split across the chunk boundary.
        if self.mid_word && !chunk.starts_with(char::is_whitespace) {
            if let Some(fragment) = pieces.next() {
                let joined = match self.pop_word() {
                    Some(mut word) => {
                        word.push_str(&fragment);
                        word
                    }
                    None => fragment,
                };
                self.push_word(joined);
            }
        }
        for word in pieces {
            self.push_word(word);
        }
        self.mid_word = !chunk.ends_with(char::is_whitespace);

        // Enforce rolling window.
        let excess = self.words.len().saturating_sub(self.max_words);
        if excess > 0 {
            self.evict_front(excess);
        }

        self.current_state()
    }

    /// Get the current repetition state without feeding new data.
    #[must_use]
    pub fn current_state(&self) -> RepetitionState {
        if self.words.len() < MIN_WORDS_FOR_DETECTION {
            return RepetitionState::Normal;
        }

        let ratio = self.repetition_ratio();

        if ratio >= self.threshold {
            RepetitionState::Degenerate
        } else if ratio >= self.threshold * WARNING_FRACTION {
            RepetitionState::Warning(ratio)
        } else {
            RepetitionState::Normal
        }
    }

    /// Fraction of n-gram occurrences in the window that repeat an earlier one.
    ///
    /// Equals `sum(count - 1) / sum(count)` over all tracked n-grams. Because
    /// every window position holds exactly one n-gram and no stored count is
    /// zero, this is `(windows - distinct) / windows`.
    // Both operands are bounded by the window size, so the casts are exact.
    #[allow(clippy::cast_precision_loss)]
    fn repetition_ratio(&self) -> f32 {
        let total = (self.words.len() + 1).saturating_sub(self.ngram_size);
        if total == 0 {
            return 0.0;
        }
        let repeated = total.saturating_sub(self.ngram_counts.len());
        repeated as f32 / total as f32
    }

    /// Append a word and count the n-gram that now ends at it.
    fn push_word(&mut self, word: String) {
        self.words.push(word);
        if let Some(start) = self.words.len().checked_sub(self.ngram_size) {
            self.increment_ngram(start);
        }
    }

    /// Remove the newest word, uncounting the n-gram that ended at it.
    fn pop_word(&mut self) -> Option<String> {
        if let Some(start) = self.words.len().checked_sub(self.ngram_size) {
            self.decrement_ngram(start);
        }
        self.words.pop()
    }

    /// Drop the `count` oldest words and uncount every n-gram starting at them.
    fn evict_front(&mut self, count: usize) {
        let count = count.min(self.words.len());
        let windows = (self.words.len() + 1).saturating_sub(self.ngram_size);
        for start in 0..count.min(windows) {
            self.decrement_ngram(start);
        }
        self.words.drain(..count);
    }

    /// Count one more occurrence of the n-gram starting at `start`.
    ///
    /// The slice is only copied into an owned key when the n-gram is new.
    fn increment_ngram(&mut self, start: usize) {
        let key = &self.words[start..start + self.ngram_size];
        if let Some(count) = self.ngram_counts.get_mut(key) {
            *count += 1;
        } else {
            self.ngram_counts.insert(key.to_vec(), 1);
        }
    }

    /// Count one fewer occurrence of the n-gram starting at `start`, removing
    /// its entry when the count reaches zero.
    fn decrement_ngram(&mut self, start: usize) {
        let key = &self.words[start..start + self.ngram_size];
        let remaining = match self.ngram_counts.get_mut(key) {
            Some(count) => {
                *count -= 1;
                *count
            }
            // Every window was counted when its last word was pushed, so a
            // missing key would indicate broken bookkeeping; nothing to undo.
            None => return,
        };
        if remaining == 0 {
            self.ngram_counts.remove(key);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normal_text_not_flagged() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let text = "The quick brown fox jumps over the lazy dog and then \
                     runs through the forest while birds sing in the trees \
                     and flowers bloom in the meadow near the river bank";
        let state = detector.feed(text);
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn repetitive_text_detected() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        // Highly repetitive text
        let text = "the cat sat the cat sat the cat sat the cat sat \
                     the cat sat the cat sat the cat sat the cat sat \
                     the cat sat the cat sat the cat sat the cat sat";
        let state = detector.feed(text);
        assert!(
            matches!(state, RepetitionState::Degenerate),
            "Expected Degenerate, got {state:?}"
        );
    }

    #[test]
    fn empty_chunk_returns_current_state() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        let state = detector.feed("");
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn short_text_always_normal() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        // Too short for detection to activate
        let state = detector.feed("hello world");
        assert_eq!(state, RepetitionState::Normal);
    }

    #[test]
    fn incremental_feeding_works() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        // Feed repetitive content in small chunks
        for _ in 0..20 {
            detector.feed("the cat sat on the mat ");
        }
        let state = detector.current_state();
        assert!(
            matches!(
                state,
                RepetitionState::Degenerate | RepetitionState::Warning(_)
            ),
            "Expected repetition detection, got {state:?}"
        );
    }

    #[test]
    fn rolling_window_prevents_unbounded_growth() {
        let mut detector = RepetitionDetector::new(3, 0.3);
        for i in 0..1000 {
            detector.feed(&format!("unique word number {i} in the stream "));
        }
        assert!(detector.words.len() <= DEFAULT_MAX_WORDS);
    }

    #[test]
    fn threshold_clamped() {
        let detector = RepetitionDetector::new(3, 1.5);
        assert!((detector.threshold - 1.0).abs() < f32::EPSILON);

        let detector2 = RepetitionDetector::new(3, -0.5);
        assert!(detector2.threshold.abs() < f32::EPSILON);
    }

    #[test]
    fn ngram_size_has_floor_of_two() {
        assert_eq!(RepetitionDetector::new(0, 0.3).ngram_size, 2);
        assert_eq!(RepetitionDetector::new(1, 0.3).ngram_size, 2);
        assert_eq!(RepetitionDetector::new(4, 0.3).ngram_size, 4);
    }

    #[test]
    fn warning_band_reported() {
        // 30 distinct words followed by the first 11 again: 39 trigrams, 9 of
        // which repeat an earlier one (ratio 9/39 ≈ 0.23, inside the warning
        // band [0.7 * 0.3, 0.3)).
        let text = (0..30)
            .chain(0..11)
            .map(|i| format!("w{i}"))
            .collect::<Vec<_>>()
            .join(" ");
        let mut detector = RepetitionDetector::new(3, 0.3);
        let state = detector.feed(&text);
        assert!(
            matches!(state, RepetitionState::Warning(r) if (r - 9.0 / 39.0).abs() < 1e-6),
            "Expected Warning, got {state:?}"
        );
    }

    #[test]
    fn ngram_counts_match_retained_window() {
        // Push far past the window so eviction runs many times, then check the
        // tracked counts against a from-scratch count of the retained words.
        let mut detector = RepetitionDetector::new(3, 0.3);
        for i in 0..400 {
            detector.feed(&format!("alpha beta {} gamma ", i % 7));
        }
        assert_eq!(detector.words.len(), DEFAULT_MAX_WORDS);

        let mut expected: std::collections::HashMap<Vec<String>, u32> =
            std::collections::HashMap::new();
        for window in detector.words.windows(detector.ngram_size) {
            *expected.entry(window.to_vec()).or_insert(0) += 1;
        }
        assert_eq!(detector.ngram_counts, expected);
    }
}