use sha2::{Digest, Sha256};
use std::collections::HashMap;
use super::ExperimentalStats;
use crate::provider::{ContentPart, Message};
/// Minimum length, in bytes, a tool result's `content` must have before
/// [`dedup_tool_outputs`] considers it; shorter outputs are always left as-is.
pub const MIN_DEDUP_BYTES: usize = 256;

/// Replaces repeated large tool outputs with a short back-reference marker.
///
/// Walks `messages` in order and hashes (SHA-256) the `content` of every
/// [`ContentPart::ToolResult`] whose length is at least [`MIN_DEDUP_BYTES`].
/// The first result with a given hash is kept verbatim and remembered by its
/// `tool_call_id`. Every later result with the same hash is rewritten in place to
/// `"[DEDUP] identical to tool_call_id=<first id> (<original length> bytes)"`.
///
/// A duplicate is left untouched when the marker would not be shorter than the
/// content it replaces (possible with a very long `tool_call_id`), so this pass
/// never grows a message.
///
/// Returns an [`ExperimentalStats`] starting from its `Default` value:
/// - `dedup_hits` is the number of results that were replaced.
/// - `total_bytes_saved` is the total number of bytes removed by those replacements.
pub fn dedup_tool_outputs(messages: &mut [Message]) -> ExperimentalStats {
    use std::collections::hash_map::Entry;

    // Content hash -> tool_call_id of the first result seen with that content.
    let mut seen: HashMap<[u8; 32], String> = HashMap::new();
    let mut stats = ExperimentalStats::default();

    for msg in messages.iter_mut() {
        for part in msg.content.iter_mut() {
            let ContentPart::ToolResult {
                tool_call_id,
                content,
            } = part
            else {
                continue;
            };
            if content.len() < MIN_DEDUP_BYTES {
                continue;
            }
            let hash: [u8; 32] = Sha256::digest(content.as_bytes()).into();
            match seen.entry(hash) {
                Entry::Occupied(first) => {
                    let first_id = first.get();
                    let marker = format!(
                        "[DEDUP] identical to tool_call_id={first_id} ({} bytes)",
                        content.len()
                    );
                    // Replacing is only worthwhile if it actually shrinks the
                    // content; a long first tool_call_id could make the marker
                    // as large as (or larger than) the output itself.
                    if marker.len() >= content.len() {
                        continue;
                    }
                    let saved = content.len() - marker.len();
                    *content = marker;
                    stats.dedup_hits += 1;
                    stats.total_bytes_saved += saved;
                }
                Entry::Vacant(slot) => {
                    slot.insert(tool_call_id.clone());
                }
            }
        }
    }
    stats
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool-role message holding a single `ToolResult` part.
    fn tool_msg(id: &str, content: &str) -> Message {
        Message {
            role: crate::provider::Role::Tool,
            content: vec![ContentPart::ToolResult {
                tool_call_id: id.into(),
                content: content.into(),
            }],
        }
    }

    /// Returns the `content` of the first part of `msg`, panicking if that part
    /// is not a tool result.
    fn tool_content(msg: &Message) -> &str {
        let ContentPart::ToolResult { content, .. } = &msg.content[0] else {
            panic!("expected tool result");
        };
        content.as_str()
    }

    #[test]
    fn short_outputs_are_not_deduplicated() {
        let mut msgs = vec![tool_msg("a", "ok"), tool_msg("b", "ok")];
        let stats = dedup_tool_outputs(&mut msgs);
        assert_eq!(stats.dedup_hits, 0);
        assert_eq!(tool_content(&msgs[1]), "ok");
    }

    #[test]
    fn threshold_is_inclusive() {
        let at_threshold = "q".repeat(MIN_DEDUP_BYTES);
        let below_threshold = "r".repeat(MIN_DEDUP_BYTES - 1);
        let mut msgs = vec![
            tool_msg("a", &at_threshold),
            tool_msg("b", &at_threshold),
            tool_msg("c", &below_threshold),
            tool_msg("d", &below_threshold),
        ];
        let stats = dedup_tool_outputs(&mut msgs);
        assert_eq!(stats.dedup_hits, 1);
        assert!(tool_content(&msgs[1]).starts_with("[DEDUP]"));
        assert_eq!(tool_content(&msgs[3]), below_threshold);
    }

    #[test]
    fn distinct_outputs_are_preserved() {
        let mut msgs = vec![
            tool_msg("a", &"x".repeat(1024)),
            tool_msg("b", &"y".repeat(1024)),
        ];
        let stats = dedup_tool_outputs(&mut msgs);
        assert_eq!(stats.dedup_hits, 0);
        assert_eq!(stats.total_bytes_saved, 0);
    }

    #[test]
    fn three_way_dedup_references_first_sighting() {
        let big = "z".repeat(2048);
        let mut msgs = vec![
            tool_msg("first", &big),
            tool_msg("second", &big),
            tool_msg("third", &big),
        ];
        let stats = dedup_tool_outputs(&mut msgs);
        assert_eq!(stats.dedup_hits, 2);

        // The first sighting must stay verbatim so the markers have a target.
        assert_eq!(tool_content(&msgs[0]), big);

        let marker = format!(
            "[DEDUP] identical to tool_call_id=first ({} bytes)",
            big.len()
        );
        for msg in &msgs[1..] {
            assert_eq!(tool_content(msg), marker);
        }
        assert_eq!(stats.total_bytes_saved, 2 * (big.len() - marker.len()));
    }
}