ai_tokenopt 0.5.7

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! Relevance-based message scoring for history retention.
//!
//! Uses BM25 (Okapi Best Match 25) to score conversation messages
//! against the current user query. Messages with higher relevance
//! scores are retained preferentially during compaction, instead of
//! the default FIFO (oldest-first) pruning strategy.
//!
//! Standalone-compatible: no embedding model or external service needed.

use std::collections::{HashMap, HashSet};

use crate::types::{ChatMessage, MessageRole};

/// BM25 term-frequency saturation constant (`k1`).
///
/// In `score_messages` it appears as `tf * (k1 + 1) / (k1 * norm + tf)`:
/// a larger value lets repeated occurrences of a query term keep adding
/// to the score for longer before the contribution flattens out.
const BM25_K1: f64 = 1.2;
/// BM25 length-normalization constant (`b`).
///
/// Weights the `dl / avg_dl` ratio in the scoring denominator
/// (`1 - b + b * dl / avg_dl`): `0.0` would ignore message length,
/// `1.0` would normalize fully by length relative to the average.
const BM25_B: f64 = 0.75;

/// Score each non-system message by relevance to a query using BM25.
///
/// Returns `(index, score)` pairs for every non-system message in the
/// input slice, sorted by **descending** score. The index refers to
/// the position in the original `messages` slice. The sort is stable,
/// so messages with equal scores keep their original relative order.
///
/// System messages are excluded from scoring because they are always
/// retained during compaction.
///
/// # Scoring
///
/// For each query term the contribution is
/// `idf * tf * (k1 + 1) / (k1 * (1 - b + b * dl / avg_dl) + tf)` with
/// `idf = ln(1 + (N - df + 0.5) / (df + 0.5))`, where `N` is the number
/// of scored messages and `df` the number of messages containing the term.
/// A term repeated in the query contributes once per repetition, but its
/// document frequency is counted only once per message, so `df <= N` and
/// every score is finite and non-negative.
///
/// # Degenerate inputs
///
/// * If the query yields no tokens, every non-system message is returned
///   in original order with a score of `0.0`.
/// * If no non-system message yields any token, the same zero-score list
///   is returned (the average document length would otherwise be zero and
///   the length normalization would evaluate `0 / 0`).
/// * If there are no non-system messages, the result is empty.
pub fn score_messages(query: &str, messages: &[ChatMessage]) -> Vec<(usize, f64)> {
    let query_terms = tokenize(query);
    if query_terms.is_empty() {
        // No meaningful query — return indices in original order with zero scores.
        return messages
            .iter()
            .enumerate()
            .filter(|(_, m)| m.role != MessageRole::System)
            .map(|(i, _)| (i, 0.0))
            .collect();
    }

    // Tokenize every non-system message, remembering its original index.
    let docs: Vec<(usize, Vec<String>)> = messages
        .iter()
        .enumerate()
        .filter(|(_, m)| m.role != MessageRole::System)
        .map(|(i, m)| (i, tokenize(&m.content)))
        .collect();

    if docs.is_empty() {
        return Vec::new();
    }

    // If no message produced a single token, nothing can match and the
    // average length is zero; bail out before dividing by it.
    let total_terms: usize = docs.iter().map(|(_, terms)| terms.len()).sum();
    if total_terms == 0 {
        return docs.iter().map(|(idx, _)| (*idx, 0.0)).collect();
    }

    // Message and token counts are far below 2^53, so the casts are exact.
    #[allow(clippy::cast_precision_loss)]
    let n = docs.len() as f64;
    #[allow(clippy::cast_precision_loss)]
    let avg_dl = total_terms as f64 / n;

    // Distinct query terms: document frequency must be counted once per
    // (term, message) pair even when the query repeats a term.
    let unique_query_terms: HashSet<&str> = query_terms.iter().map(String::as_str).collect();

    // Count how many documents contain each distinct query term.
    let mut df: HashMap<&str, usize> = HashMap::new();
    for (_, terms) in &docs {
        let doc_terms: HashSet<&str> = terms.iter().map(String::as_str).collect();
        for &qt in &unique_query_terms {
            if doc_terms.contains(qt) {
                *df.entry(qt).or_default() += 1;
            }
        }
    }

    // IDF depends only on the term, so compute it once instead of per message.
    let idf: HashMap<&str, f64> = unique_query_terms
        .iter()
        .map(|&qt| {
            #[allow(clippy::cast_precision_loss)]
            let doc_freq = df.get(qt).copied().unwrap_or(0) as f64;
            (qt, ((n - doc_freq + 0.5) / (doc_freq + 0.5)).ln_1p())
        })
        .collect();

    // Score each document
    let mut scores: Vec<(usize, f64)> = docs
        .iter()
        .map(|(idx, terms)| {
            #[allow(clippy::cast_precision_loss)]
            let dl = terms.len() as f64;
            // Length normalization is constant for the whole message.
            let length_norm = 1.0 - BM25_B + BM25_B * dl / avg_dl;
            let tf_map = term_freq(terms);
            let score: f64 = query_terms
                .iter()
                .map(|qt| {
                    #[allow(clippy::cast_precision_loss)]
                    let tf = tf_map.get(qt.as_str()).copied().unwrap_or(0) as f64;
                    let term_idf = idf.get(qt.as_str()).copied().unwrap_or(0.0);
                    term_idf * (tf * (BM25_K1 + 1.0)) / BM25_K1.mul_add(length_norm, tf)
                })
                .sum();
            (*idx, score)
        })
        .collect();

    // Sort descending by score (stable: ties keep original message order).
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scores
}

/// Count how many times each distinct token occurs in `terms`.
///
/// The returned map borrows its keys from the input slice.
fn term_freq(terms: &[String]) -> HashMap<&str, usize> {
    terms.iter().fold(HashMap::new(), |mut counts, term| {
        *counts.entry(term.as_str()).or_default() += 1;
        counts
    })
}

/// Split `text` into lowercase tokens.
///
/// Whitespace and ASCII punctuation act as separators; non-ASCII
/// punctuation stays inside tokens. Pieces shorter than two **bytes**
/// are discarded, so a single ASCII letter is dropped while a single
/// multi-byte character is kept. No stemming is performed.
fn tokenize(text: &str) -> Vec<String> {
    let is_separator = |ch: char| ch.is_ascii_punctuation() || ch.is_whitespace();

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.len() < 2 {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ChatMessage` with the constructor matching `role`.
    ///
    /// `Tool` is routed through the user constructor, same as `User`.
    fn msg(role: MessageRole, content: &str) -> ChatMessage {
        match role {
            MessageRole::System => ChatMessage::system(content),
            MessageRole::Assistant => ChatMessage::assistant(content),
            MessageRole::User | MessageRole::Tool => ChatMessage::user(content),
        }
    }

    /// A query about Rust must rank a Rust-related message first.
    #[test]
    fn relevant_message_scores_highest() {
        let history = [
            msg(MessageRole::User, "What is the weather in Berlin?"),
            msg(
                MessageRole::Assistant,
                "The weather in Berlin is sunny, 22°C.",
            ),
            msg(MessageRole::User, "Tell me about Rust programming."),
            msg(
                MessageRole::Assistant,
                "Rust is a systems programming language focused on safety.",
            ),
        ];

        let ranked = score_messages("Rust programming language", &history);
        assert!(!ranked.is_empty());

        let (best_idx, _) = ranked[0];
        let best = &history[best_idx];
        assert!(
            best.content.contains("Rust"),
            "Expected Rust-related message to score highest, got: {}",
            best.content
        );
    }

    /// System messages never appear in the scored output.
    #[test]
    fn system_messages_excluded() {
        let history = [
            msg(MessageRole::System, "You are a helpful assistant."),
            msg(MessageRole::User, "Hello!"),
            msg(MessageRole::Assistant, "Hi there!"),
        ];

        let ranked = score_messages("Hello", &history);
        assert_eq!(ranked.len(), 2, "System messages should be excluded");

        let no_system = ranked
            .iter()
            .all(|&(i, _)| history[i].role != MessageRole::System);
        assert!(no_system);
    }

    /// An empty query yields one zero score per non-system message.
    #[test]
    fn empty_query_returns_zero_scores() {
        let history = [
            msg(MessageRole::User, "Hello"),
            msg(MessageRole::Assistant, "Hi"),
        ];

        let ranked = score_messages("", &history);
        assert_eq!(ranked.len(), 2);
        assert!(ranked.iter().all(|&(_, score)| score == 0.0));
    }

    /// No messages in, no scores out.
    #[test]
    fn empty_messages_returns_empty() {
        assert!(score_messages("query", &[]).is_empty());
    }

    /// The message matching the most query terms ranks first.
    #[test]
    fn multiple_term_overlap_scores_higher() {
        let history = [
            msg(MessageRole::User, "I like apples"),
            msg(
                MessageRole::User,
                "I like apples and oranges from the market",
            ),
            msg(MessageRole::User, "The cat sat on the mat"),
        ];

        let ranked = score_messages("apples oranges market", &history);
        let (best_idx, _) = ranked[0];
        assert_eq!(best_idx, 1, "Message with most query terms should rank first");
    }

    /// Returned indices point into the original slice and skip system messages.
    #[test]
    fn preserves_original_indices() {
        let history = [
            msg(MessageRole::System, "System prompt"),
            msg(MessageRole::User, "First user message"),
            msg(MessageRole::Assistant, "First response"),
            msg(MessageRole::User, "Second user message"),
        ];

        let ranked = score_messages("first", &history);
        for &(i, _) in &ranked {
            assert!(i < history.len());
            assert_ne!(history[i].role, MessageRole::System);
        }
    }

    /// Scores never drop below zero for an ordinary query.
    #[test]
    fn scores_are_non_negative() {
        let history = [
            msg(MessageRole::User, "Hello world"),
            msg(MessageRole::Assistant, "Goodbye moon"),
        ];

        let ranked = score_messages("Hello world", &history);
        assert!(ranked.iter().all(|&(_, score)| score >= 0.0));
    }
}