use crate::estimator::TokenEstimator;
use crate::ports::SummarizationPort;
use crate::types::{ChatMessage, MessageRole};
/// Minimum `pressure` at which `paraphrase_old_messages` does any work;
/// for lower values the function returns `(0, 0)` without touching messages.
const PRESSURE_THRESHOLD: f64 = 0.8;
/// Minimum estimated age, in turns, for a message to be a paraphrase candidate.
/// Age is computed as `current_turn - index / 2` (saturating at zero).
const MIN_AGE_TURNS: usize = 5;
/// Minimum estimated token count a message must have to be a paraphrase
/// candidate; shorter messages are left as they are.
const MIN_MESSAGE_TOKENS: u32 = 40;
/// Compresses old, long assistant messages in place using `summarizer`.
///
/// Nothing happens when `pressure` is below `PRESSURE_THRESHOLD`. Otherwise a
/// message is a candidate when all of the following hold:
///
/// * its role is `MessageRole::Assistant`;
/// * its estimated age, `current_turn - index / 2` (saturating), is at least
///   `MIN_AGE_TURNS`;
/// * its estimated token count is at least `MIN_MESSAGE_TOKENS`.
///
/// NOTE(review): the `index / 2` turn estimate presumably assumes strictly
/// alternating user/assistant messages — confirm against callers.
///
/// Each candidate is sent to the summarizer one at a time. The message content
/// is replaced only when the summarizer succeeds, the returned text is not
/// blank, and its estimated token count is strictly lower than the original's.
/// Summarizer errors are ignored and leave the message untouched.
///
/// # Returns
///
/// `(paraphrased, tokens_saved)`: the number of messages replaced and the sum
/// of the estimated token reductions (saturating at `u32::MAX`).
pub async fn paraphrase_old_messages(
    messages: &mut [ChatMessage],
    pressure: f64,
    current_turn: usize,
    summarizer: &dyn SummarizationPort,
) -> (usize, u32) {
    if pressure < PRESSURE_THRESHOLD {
        return (0, 0);
    }

    // Pick candidates up front and keep each token estimate alongside its
    // index so the estimator runs once per message instead of twice.
    let candidates: Vec<(usize, u32)> = messages
        .iter()
        .enumerate()
        .filter(|(i, m)| {
            m.role == MessageRole::Assistant
                && current_turn.saturating_sub(*i / 2) >= MIN_AGE_TURNS
        })
        .filter_map(|(i, m)| {
            let tokens = TokenEstimator::estimate_tokens(&m.content);
            if tokens >= MIN_MESSAGE_TOKENS {
                Some((i, tokens))
            } else {
                None
            }
        })
        .collect();

    let mut paraphrased = 0usize;
    let mut tokens_saved = 0u32;

    for (idx, original_tokens) in candidates {
        let prompt = format!(
            "Compress this assistant response to its key information in 1-2 sentences:\n\n{}",
            messages[idx].content
        );
        let result = summarizer
            .summarize(
                "You are a precise text compressor. Output only the compressed text.",
                &prompt,
            )
            .await;

        // Best effort: a failed summarization simply leaves the message as is.
        if let Ok(compressed) = result {
            // A blank "summary" would pass the size check below and erase the
            // message entirely, so treat it like a failure.
            if compressed.trim().is_empty() {
                continue;
            }

            let new_tokens = TokenEstimator::estimate_tokens(&compressed);
            if new_tokens < original_tokens {
                messages[idx].content = compressed;
                // `new_tokens < original_tokens`, so the subtraction cannot
                // underflow; the running total saturates rather than overflows.
                tokens_saved = tokens_saved.saturating_add(original_tokens - new_tokens);
                paraphrased += 1;
            }
        }
    }

    (paraphrased, tokens_saved)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::TokenOptError;

    /// Number of user/assistant pairs in the standard fixture conversation.
    const TURN_PAIRS: usize = 10;
    /// Turn number passed to the function under test in every scenario.
    const CURRENT_TURN: usize = 9;
    /// Pressure above the threshold, so paraphrasing is attempted.
    const HIGH_PRESSURE: f64 = 0.9;
    /// Pressure below the threshold, so paraphrasing is skipped.
    const LOW_PRESSURE: f64 = 0.5;

    /// Summarizer that always succeeds with a short fixed string.
    struct MockSummarizer;

    #[async_trait::async_trait]
    impl SummarizationPort for MockSummarizer {
        async fn summarize(
            &self,
            _system_prompt: &str,
            _text: &str,
        ) -> Result<String, TokenOptError> {
            Ok(String::from("Key info retained."))
        }
    }

    /// Summarizer that always reports a failure.
    struct FailingSummarizer;

    #[async_trait::async_trait]
    impl SummarizationPort for FailingSummarizer {
        async fn summarize(
            &self,
            _system_prompt: &str,
            _text: &str,
        ) -> Result<String, TokenOptError> {
            Err(TokenOptError::SummarizationFailed(String::from(
                "mock failure",
            )))
        }
    }

    /// Summarizer whose output is longer than the fixture's assistant messages.
    struct LongerSummarizer;

    #[async_trait::async_trait]
    impl SummarizationPort for LongerSummarizer {
        async fn summarize(
            &self,
            _system_prompt: &str,
            _text: &str,
        ) -> Result<String, TokenOptError> {
            let bloated = "This is actually a much longer response that is even bigger \
                than the original because the model went off on a tangent and \
                produced way more text than was asked for. It includes extra \
                details about various topics, side tangents about the weather, \
                descriptions of unrelated events, philosophical musings about \
                the nature of compression itself, some poetry thrown in for \
                good measure, a brief aside on the history of text summarization, \
                a few paragraphs of filler text, and more extraneous content \
                that nobody would ever want to read in a compressed summary. \
                It just keeps going and going with no end in sight, making \
                it thoroughly unsuitable as a compressed version of anything.";
            Ok(bloated.to_string())
        }
    }

    /// Builds `count` alternating user question / long assistant answer pairs.
    fn make_messages(count: usize) -> Vec<ChatMessage> {
        (0..count)
            .flat_map(|i| {
                [
                    ChatMessage::user(format!("Question {i}")),
                    ChatMessage::assistant(format!(
                        "This is a very long and detailed answer to question {i} that contains \
                         a lot of information, explanations, examples, and context that could \
                         be compressed significantly without losing the key points."
                    )),
                ]
            })
            .collect()
    }

    /// Copies the content of every message, in order.
    fn all_contents(msgs: &[ChatMessage]) -> Vec<String> {
        msgs.iter().map(|m| m.content.clone()).collect()
    }

    /// Copies the content of every user-role message, in order.
    fn user_contents(msgs: &[ChatMessage]) -> Vec<String> {
        msgs.iter()
            .filter(|m| m.role == MessageRole::User)
            .map(|m| m.content.clone())
            .collect()
    }

    #[tokio::test]
    async fn no_paraphrase_below_pressure_threshold() {
        let mut conversation = make_messages(TURN_PAIRS);
        let outcome =
            paraphrase_old_messages(&mut conversation, LOW_PRESSURE, CURRENT_TURN, &MockSummarizer)
                .await;
        assert_eq!(outcome.0, 0);
        assert_eq!(outcome.1, 0);
    }

    #[tokio::test]
    async fn paraphrases_old_assistant_messages() {
        let mut conversation = make_messages(TURN_PAIRS);
        let (replaced, saved) =
            paraphrase_old_messages(&mut conversation, HIGH_PRESSURE, CURRENT_TURN, &MockSummarizer)
                .await;
        assert!(replaced > 0, "Should have paraphrased some messages");
        assert!(saved > 0, "Should have saved some tokens");
    }

    #[tokio::test]
    async fn does_not_paraphrase_recent_messages() {
        let mut conversation = make_messages(TURN_PAIRS);
        let newest_before = conversation.last().unwrap().content.clone();
        let _ =
            paraphrase_old_messages(&mut conversation, HIGH_PRESSURE, CURRENT_TURN, &MockSummarizer)
                .await;
        let newest_after = &conversation.last().unwrap().content;
        assert_eq!(*newest_after, newest_before);
    }

    #[tokio::test]
    async fn does_not_paraphrase_user_messages() {
        let mut conversation = make_messages(TURN_PAIRS);
        let before = user_contents(&conversation);
        let _ =
            paraphrase_old_messages(&mut conversation, HIGH_PRESSURE, CURRENT_TURN, &MockSummarizer)
                .await;
        assert_eq!(before, user_contents(&conversation));
    }

    #[tokio::test]
    async fn handles_summarizer_failure_gracefully() {
        let mut conversation = make_messages(TURN_PAIRS);
        let before = all_contents(&conversation);
        let (replaced, saved) = paraphrase_old_messages(
            &mut conversation,
            HIGH_PRESSURE,
            CURRENT_TURN,
            &FailingSummarizer,
        )
        .await;
        assert_eq!(replaced, 0);
        assert_eq!(saved, 0);
        assert_eq!(before, all_contents(&conversation));
    }

    #[tokio::test]
    async fn does_not_replace_if_compressed_is_longer() {
        let mut conversation = make_messages(TURN_PAIRS);
        let before = all_contents(&conversation);
        let (replaced, _) = paraphrase_old_messages(
            &mut conversation,
            HIGH_PRESSURE,
            CURRENT_TURN,
            &LongerSummarizer,
        )
        .await;
        assert_eq!(replaced, 0);
        assert_eq!(before, all_contents(&conversation));
    }
}