ai_tokenopt 0.5.7

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! RAG context deduplication
//!
//! Two-pass deduplication: fast Jaccard word-set matching followed by
//! optional embedding-based (cosine) similarity. The surviving entries are
//! then trimmed to fit a token budget, highest relevance first.

use crate::estimator::TokenEstimator;

/// A memory entry with relevance score and content.
///
/// This mirrors the shape of `SimilarMemory` from the application layer
/// without creating a hard dependency on the specific type.
///
/// Within this module the three fields play distinct roles: `content` is
/// compared word-by-word to find near-duplicates and is measured against the
/// token budget, `relevance` decides which of two duplicates survives and in
/// which order the budget is filled, and `embedding` enables the optional
/// cosine-similarity pass.
#[derive(Debug, Clone)]
pub struct RagEntry {
    /// The memory content text. Split on whitespace to build the word set
    /// used by the Jaccard pass.
    pub content: String,
    /// Relevance score (0.0–1.0, higher = more relevant). The range is a
    /// convention for callers; nothing in this module enforces it.
    pub relevance: f32,
    /// Optional embedding vector for semantic dedup. Entries with `None` are
    /// never dropped by the semantic pass, and two vectors of different
    /// lengths are treated as dissimilar.
    pub embedding: Option<Vec<f32>>,
}

/// Deduplicated and budget-aware RAG context.
///
/// Produced by `deduplicate_rag`. The two counters are disjoint: an entry is
/// counted either as a duplicate or as budget-trimmed, never both, because
/// budget trimming only considers entries that survived deduplication.
#[derive(Debug)]
pub struct DeduplicatedRag {
    /// Entries after deduplication, ordered by relevance (highest first).
    /// Only entries that fit within the token budget are present.
    pub entries: Vec<RagEntry>,
    /// Number of entries removed as duplicates (input count minus the count
    /// remaining after both deduplication passes).
    pub duplicates_removed: usize,
    /// Number of entries removed for budget reasons, i.e. unique entries
    /// whose estimated token cost did not fit in the remaining budget.
    pub budget_trimmed: usize,
}

/// Deduplicate RAG entries using a two-pass strategy, then fit them to a token budget.
///
/// Pass 1: Fast Jaccard word-set similarity (threshold > 0.85)
/// Pass 2: Cosine similarity on embeddings (threshold > 0.9) if available
///
/// Then trim to fit within the token budget, preferring higher-relevance entries.
///
/// # Arguments
///
/// * `entries` - Candidate entries; the slice is cloned, never mutated.
/// * `budget_tokens` - Maximum total estimated tokens across the returned entries.
///
/// # Returns
///
/// A [`DeduplicatedRag`] whose `entries` are sorted by descending relevance and
/// whose counters report how many inputs were dropped as duplicates and how many
/// were dropped because they did not fit the budget.
///
/// # Notes
///
/// * Budget filling is greedy: an entry that does not fit is skipped, but later
///   (less relevant) entries that are small enough may still be kept.
/// * Entries whose relevance is NaN are ordered after every other entry.
#[must_use]
pub fn deduplicate_rag(entries: &[RagEntry], budget_tokens: u32) -> DeduplicatedRag {
    if entries.is_empty() {
        return DeduplicatedRag {
            entries: Vec::new(),
            duplicates_removed: 0,
            budget_trimmed: 0,
        };
    }

    let initial_count = entries.len();

    // Pass 1: Jaccard word-set deduplication
    let lexically_unique = jaccard_dedup(entries.to_vec(), 0.85);

    // Pass 2: Semantic deduplication (a no-op when no entry has an embedding)
    let mut remaining = semantic_dedup(lexically_unique, 0.9);

    // Both passes only ever remove entries, so this cannot underflow.
    let duplicates_removed = initial_count - remaining.len();

    // Sort by relevance descending. `partial_cmp` is not a total order once a
    // NaN is involved, which makes the sort result unspecified (and may panic
    // in newer standard libraries), so NaN is mapped below every real score
    // and the comparison is done with `total_cmp`.
    let sort_key = |relevance: f32| {
        if relevance.is_nan() {
            f32::NEG_INFINITY
        } else {
            relevance
        }
    };
    remaining.sort_by(|a, b| sort_key(b.relevance).total_cmp(&sort_key(a.relevance)));

    // Budget trimming: greedily add entries until budget exhausted
    let mut budget_used: u32 = 0;
    let mut kept: Vec<RagEntry> = Vec::with_capacity(remaining.len());
    let mut budget_trimmed = 0;

    for entry in remaining {
        let entry_tokens = TokenEstimator::estimate_tokens(&entry.content);
        // `budget_used <= budget_tokens` always holds, so the subtraction is
        // safe; comparing against the remaining headroom avoids the overflow
        // that `budget_used + entry_tokens` could hit for very large estimates.
        if entry_tokens <= budget_tokens - budget_used {
            budget_used += entry_tokens;
            kept.push(entry);
        } else {
            budget_trimmed += 1;
        }
    }

    DeduplicatedRag {
        entries: kept,
        duplicates_removed,
        budget_trimmed,
    }
}

/// Remove near-duplicate entries by comparing their whitespace-delimited word sets.
///
/// Every pair of still-surviving entries is scored with `jaccard_index`. When a
/// pair scores strictly above `threshold`, the entry with the lower relevance is
/// discarded; on equal relevance the earlier entry wins. Survivors are returned
/// in their original relative order.
fn jaccard_dedup(entries: Vec<RagEntry>, threshold: f64) -> Vec<RagEntry> {
    use std::collections::HashSet;

    let count = entries.len();

    // One word set per entry, built once so each pair comparison is cheap.
    let mut vocab: Vec<HashSet<&str>> = Vec::with_capacity(count);
    for item in &entries {
        vocab.push(item.content.split_whitespace().collect());
    }

    let mut alive = vec![true; count];

    for first in 0..count {
        if !alive[first] {
            continue;
        }
        for second in (first + 1)..count {
            let is_duplicate =
                alive[second] && jaccard_index(&vocab[first], &vocab[second]) > threshold;
            if !is_duplicate {
                continue;
            }
            if entries[first].relevance >= entries[second].relevance {
                alive[second] = false;
            } else {
                // `first` lost to a more relevant duplicate; nothing further
                // should be compared against it.
                alive[first] = false;
                break;
            }
        }
    }

    entries
        .into_iter()
        .zip(alive)
        .filter(|(_, survives)| *survives)
        .map(|(item, _)| item)
        .collect()
}

/// Compute the Jaccard index (`|A ∩ B| / |A ∪ B|`) between two word sets.
///
/// The result lies in `[0.0, 1.0]`: `1.0` for identical sets, `0.0` for
/// disjoint ones. Two empty sets are treated as identical and yield `1.0`.
// Word-set sizes are far below 2^52, so the `usize` → `f64` casts are exact.
#[allow(clippy::cast_precision_loss)]
fn jaccard_index(a: &std::collections::HashSet<&str>, b: &std::collections::HashSet<&str>) -> f64 {
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }

    let shared = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|. At least one set is non-empty here, so
    // the union is never zero and no division guard is needed.
    let combined = a.len() + b.len() - shared;

    shared as f64 / combined as f64
}

/// Remove entries whose embeddings point in nearly the same direction.
///
/// Only pairs where both entries carry an embedding are compared. When the
/// cosine similarity of a pair is strictly above `threshold`, the entry with
/// the lower relevance is discarded; on equal relevance the earlier entry
/// wins. Entries without an embedding always pass through, and the original
/// relative order of the survivors is preserved.
fn semantic_dedup(entries: Vec<RagEntry>, threshold: f32) -> Vec<RagEntry> {
    // Nothing to compare when no entry carries an embedding.
    if entries.iter().all(|item| item.embedding.is_none()) {
        return entries;
    }

    let total = entries.len();
    let mut alive = vec![true; total];

    for first in 0..total {
        let lhs = match &entries[first].embedding {
            Some(vector) if alive[first] => vector,
            _ => continue,
        };
        for second in (first + 1)..total {
            let rhs = match &entries[second].embedding {
                Some(vector) if alive[second] => vector,
                _ => continue,
            };
            if cosine_similarity(lhs, rhs) > threshold {
                if entries[first].relevance >= entries[second].relevance {
                    alive[second] = false;
                } else {
                    // `first` lost to a more relevant duplicate; stop
                    // comparing against it.
                    alive[first] = false;
                    break;
                }
            }
        }
    }

    let mut survivors = Vec::with_capacity(total);
    for (item, retained) in entries.into_iter().zip(alive) {
        if retained {
            survivors.push(item);
        }
    }
    survivors
}

/// Cosine of the angle between two vectors.
///
/// Yields `0.0` (treated as "not similar") when the inputs are empty, differ
/// in length, or either one has zero magnitude; otherwise returns the dot
/// product divided by the product of the Euclidean norms.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // A zero-length direction has no defined angle; report no similarity.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}

// Unit tests for the deduplication pipeline and its similarity helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an entry without an embedding, so only the Jaccard pass can
    // consider it a duplicate.
    fn entry(content: &str, relevance: f32) -> RagEntry {
        RagEntry {
            content: content.to_string(),
            relevance,
            embedding: None,
        }
    }

    // An empty input takes the early-return path and reports no duplicates.
    #[test]
    fn empty_entries_produce_empty_result() {
        let result = deduplicate_rag(&[], 1000);
        assert!(result.entries.is_empty());
        assert_eq!(result.duplicates_removed, 0);
    }

    // Identical text gives a Jaccard index of 1.0, above the 0.85 threshold.
    #[test]
    fn identical_entries_deduplicated() {
        let entries = vec![
            entry("The weather today is sunny and warm", 0.9),
            entry("The weather today is sunny and warm", 0.8),
        ];
        let result = deduplicate_rag(&entries, 1000);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.duplicates_removed, 1);
        // Should keep the higher relevance entry
        assert!((result.entries[0].relevance - 0.9).abs() < f32::EPSILON);
    }

    // The two texts share only the word "The" (1 of 8 distinct words), far
    // below the duplicate threshold.
    // NOTE(review): also assumes both entries fit in a 1000-token budget
    // under the project's token estimator — confirm.
    #[test]
    fn different_entries_kept() {
        let entries = vec![
            entry("The weather is sunny", 0.9),
            entry("The stock market crashed today", 0.8),
        ];
        let result = deduplicate_rag(&entries, 1000);
        assert_eq!(result.entries.len(), 2);
        assert_eq!(result.duplicates_removed, 0);
    }

    // NOTE(review): the expected outcome depends on how the token estimator
    // scores these strings. The test assumes the first entry fits in 10
    // tokens (otherwise `result.entries[0]` panics on an empty vector) and
    // that all three together do not — confirm against the estimator.
    #[test]
    fn budget_trims_low_relevance() {
        let entries = vec![
            entry("Very important fact about the user", 0.95),
            entry("Somewhat relevant background info", 0.7),
            entry("Barely relevant trivia about something", 0.3),
        ];
        // Tight budget: only room for ~1 entry
        let result = deduplicate_rag(&entries, 10);
        assert!(result.entries.len() < 3);
        assert!(result.budget_trimmed > 0);
        // Highest relevance should survive
        assert!((result.entries[0].relevance - 0.95).abs() < f32::EPSILON);
    }

    // Parallel vectors have cosine similarity 1.
    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    // Perpendicular vectors have cosine similarity 0.
    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    // The texts differ, but the embeddings are almost parallel, so the
    // semantic pass treats the pair as duplicates. Only the surviving count is
    // asserted, not which entry survives.
    #[test]
    fn semantic_dedup_drops_similar_embeddings() {
        let entries = vec![
            RagEntry {
                content: "Weather is sunny".to_string(),
                relevance: 0.9,
                embedding: Some(vec![1.0, 0.0, 0.0]),
            },
            RagEntry {
                content: "It is a sunny day".to_string(),
                relevance: 0.8,
                embedding: Some(vec![0.99, 0.01, 0.0]), // Very similar
            },
        ];
        let result = semantic_dedup(entries, 0.9);
        assert_eq!(result.len(), 1);
    }
}