use crate::estimator::TokenEstimator;
/// A retrieved passage that is a candidate for deduplication and budgeting.
#[derive(Debug, Clone)]
pub struct RagEntry {
/// Passage text. It is split on whitespace for the word-overlap comparison
/// and handed to the token estimator when the budget is applied.
pub content: String,
/// Relevance score. When two entries are judged duplicates the one with the
/// higher score is kept, and budgeting visits entries in descending order of
/// this value.
pub relevance: f32,
/// Optional embedding vector compared by cosine similarity. Entries without
/// one are skipped by the embedding-based deduplication pass.
pub embedding: Option<Vec<f32>>,
}
/// Outcome of `deduplicate_rag`: the surviving entries plus removal counters.
#[derive(Debug)]
pub struct DeduplicatedRag {
/// Entries that survived deduplication and fit in the token budget, in the
/// relevance-sorted order in which they were admitted.
pub entries: Vec<RagEntry>,
/// Number of input entries removed by the two deduplication passes combined.
pub duplicates_removed: usize,
/// Number of deduplicated entries skipped because adding them would have
/// exceeded the token budget.
pub budget_trimmed: usize,
}
/// Deduplicates retrieved entries and trims the survivors to a token budget.
///
/// The input is processed in three stages:
/// 1. word-overlap (Jaccard) deduplication, dropping pairs scoring above `0.85`;
/// 2. embedding (cosine) deduplication, dropping pairs scoring above `0.9`;
/// 3. greedy admission in descending relevance order for as long as the
///    running token estimate stays within `budget_tokens`.
///
/// An entry that does not fit is skipped and counted in `budget_trimmed`;
/// later entries are still considered, so a smaller one may be admitted after
/// a larger one was skipped. Entries whose relevance is NaN are ranked last.
///
/// `duplicates_removed` counts the entries removed by stages 1 and 2 together.
#[must_use]
pub fn deduplicate_rag(entries: &[RagEntry], budget_tokens: u32) -> DeduplicatedRag {
    if entries.is_empty() {
        return DeduplicatedRag {
            entries: Vec::new(),
            duplicates_removed: 0,
            budget_trimmed: 0,
        };
    }

    let initial_count = entries.len();
    let mut remaining = semantic_dedup(jaccard_dedup(entries.to_vec(), 0.85), 0.9);
    let duplicates_removed = initial_count - remaining.len();

    // `sort_by` needs a total order: with an inconsistent comparator the
    // resulting order is unspecified and the sort may panic. `total_cmp`
    // provides that order; NaN is mapped to negative infinity so that an
    // invalid score is ranked last rather than first.
    let rank = |relevance: f32| {
        if relevance.is_nan() {
            f32::NEG_INFINITY
        } else {
            relevance
        }
    };
    remaining.sort_by(|a, b| rank(b.relevance).total_cmp(&rank(a.relevance)));

    let mut budget_used: u32 = 0;
    let mut kept: Vec<RagEntry> = Vec::new();
    let mut budget_trimmed = 0;
    for entry in remaining {
        let entry_tokens: u32 = TokenEstimator::estimate_tokens(&entry.content);
        // `checked_add` keeps a huge estimate from overflowing the running
        // total; an overflowing sum cannot fit the budget, so it is trimmed.
        match budget_used.checked_add(entry_tokens) {
            Some(total) if total <= budget_tokens => {
                budget_used = total;
                kept.push(entry);
            }
            _ => budget_trimmed += 1,
        }
    }

    DeduplicatedRag {
        entries: kept,
        duplicates_removed,
        budget_trimmed,
    }
}
/// Drops entries whose whitespace-separated word sets overlap too heavily.
///
/// Each pair of still-live entries is scored with `jaccard_index`. When the
/// score is strictly greater than `threshold`, the later entry is discarded
/// unless its relevance is higher than the earlier one's, in which case the
/// earlier entry is discarded instead (so ties favour the earlier entry).
/// Survivors are returned in their original order.
fn jaccard_dedup(entries: Vec<RagEntry>, threshold: f64) -> Vec<RagEntry> {
    let total = entries.len();
    let mut discarded = vec![false; total];

    {
        // The word sets borrow from `entries`, so they are confined to this
        // scope and gone before `entries` is consumed below.
        let vocab: Vec<std::collections::HashSet<&str>> = entries
            .iter()
            .map(|item| item.content.split_whitespace().collect())
            .collect();

        'leads: for lead in 0..total {
            if discarded[lead] {
                continue;
            }
            for rival in (lead + 1)..total {
                if discarded[rival] {
                    continue;
                }
                if jaccard_index(&vocab[lead], &vocab[rival]) > threshold {
                    if entries[rival].relevance <= entries[lead].relevance {
                        discarded[rival] = true;
                    } else {
                        // The lead lost this comparison, so it must not go on
                        // to eliminate any further entries.
                        discarded[lead] = true;
                        continue 'leads;
                    }
                }
            }
        }
    }

    entries
        .into_iter()
        .zip(discarded)
        .filter(|(_, gone)| !*gone)
        .map(|(survivor, _)| survivor)
        .collect()
}
/// Computes the Jaccard index `|A ∩ B| / |A ∪ B|` of two word sets.
///
/// Returns `1.0` when both sets are empty (two empty texts are treated as
/// identical); otherwise the result lies in `0.0..=1.0`.
// Set sizes are word counts, which stay far below 2^53, so the `usize -> f64`
// conversions are exact in practice.
#[allow(clippy::cast_precision_loss)]
fn jaccard_index(a: &std::collections::HashSet<&str>, b: &std::collections::HashSet<&str>) -> f64 {
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let shared = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|. At least one set is non-empty here, so the
    // union size is >= 1 and the division is always well defined.
    let combined = a.len() + b.len() - shared;
    shared as f64 / combined as f64
}
/// Drops entries whose embeddings are nearly parallel.
///
/// Only entries that carry an embedding take part; the rest always survive.
/// Each pair of still-live embedded entries is scored with
/// `cosine_similarity`. When the score is strictly greater than `threshold`,
/// the later entry is discarded unless its relevance is higher than the
/// earlier one's, in which case the earlier entry is discarded instead (so
/// ties favour the earlier entry). Survivors keep their original order.
fn semantic_dedup(entries: Vec<RagEntry>, threshold: f32) -> Vec<RagEntry> {
    // Without a single embedding there is nothing to compare.
    if entries.iter().all(|item| item.embedding.is_none()) {
        return entries;
    }

    let total = entries.len();
    let mut discarded = vec![false; total];

    'leads: for lead in 0..total {
        if discarded[lead] {
            continue;
        }
        let lead_vec = match entries[lead].embedding.as_deref() {
            Some(vector) => vector,
            None => continue,
        };
        for rival in (lead + 1)..total {
            if discarded[rival] {
                continue;
            }
            let rival_vec = match entries[rival].embedding.as_deref() {
                Some(vector) => vector,
                None => continue,
            };
            if cosine_similarity(lead_vec, rival_vec) > threshold {
                if entries[rival].relevance <= entries[lead].relevance {
                    discarded[rival] = true;
                } else {
                    // The lead lost this comparison, so it must not go on to
                    // eliminate any further entries.
                    discarded[lead] = true;
                    continue 'leads;
                }
            }
        }
    }

    entries
        .into_iter()
        .zip(discarded)
        .filter(|(_, gone)| !*gone)
        .map(|(survivor, _)| survivor)
        .collect()
}
/// Cosine of the angle between two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let magnitude = |vector: &[f32]| vector.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    // A zero-length vector has no direction; report "no similarity" instead
    // of dividing by zero.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot_product / (mag_a * mag_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for cosine-similarity comparisons.
    const COSINE_TOLERANCE: f32 = 1e-6;

    /// Builds an entry that carries no embedding.
    fn entry(content: &str, relevance: f32) -> RagEntry {
        RagEntry {
            content: String::from(content),
            relevance,
            embedding: None,
        }
    }

    /// Builds an entry that carries the given embedding vector.
    fn embedded(content: &str, relevance: f32, embedding: &[f32]) -> RagEntry {
        RagEntry {
            embedding: Some(embedding.to_vec()),
            ..entry(content, relevance)
        }
    }

    /// True when two relevance scores agree to within `f32::EPSILON`.
    fn same_relevance(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn empty_entries_produce_empty_result() {
        let outcome = deduplicate_rag(&[], 1000);
        assert!(outcome.entries.is_empty());
        assert_eq!(outcome.duplicates_removed, 0);
    }

    #[test]
    fn identical_entries_deduplicated() {
        let text = "The weather today is sunny and warm";
        let outcome = deduplicate_rag(&[entry(text, 0.9), entry(text, 0.8)], 1000);
        assert_eq!(outcome.entries.len(), 1);
        assert_eq!(outcome.duplicates_removed, 1);
        // The higher-relevance copy is the one that survives.
        assert!(same_relevance(outcome.entries[0].relevance, 0.9));
    }

    #[test]
    fn different_entries_kept() {
        let inputs = [
            entry("The weather is sunny", 0.9),
            entry("The stock market crashed today", 0.8),
        ];
        let outcome = deduplicate_rag(&inputs, 1000);
        assert_eq!(outcome.entries.len(), 2);
        assert_eq!(outcome.duplicates_removed, 0);
    }

    #[test]
    fn budget_trims_low_relevance() {
        let inputs = [
            entry("Very important fact about the user", 0.95),
            entry("Somewhat relevant background info", 0.7),
            entry("Barely relevant trivia about something", 0.3),
        ];
        let outcome = deduplicate_rag(&inputs, 10);
        assert!(outcome.entries.len() < 3);
        assert!(outcome.budget_trimmed > 0);
        // The most relevant entry is admitted first.
        assert!(same_relevance(outcome.entries[0].relevance, 0.95));
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = [1.0_f32, 0.0, 0.0];
        let w = [1.0_f32, 0.0, 0.0];
        assert!((cosine_similarity(&v, &w) - 1.0).abs() < COSINE_TOLERANCE);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let v = [1.0_f32, 0.0];
        let w = [0.0_f32, 1.0];
        assert!(cosine_similarity(&v, &w).abs() < COSINE_TOLERANCE);
    }

    #[test]
    fn semantic_dedup_drops_similar_embeddings() {
        let inputs = vec![
            embedded("Weather is sunny", 0.9, &[1.0, 0.0, 0.0]),
            embedded("It is a sunny day", 0.8, &[0.99, 0.01, 0.0]),
        ];
        let survivors = semantic_dedup(inputs, 0.9);
        assert_eq!(survivors.len(), 1);
    }
}