use saorsa_ai::message::Message;
use saorsa_ai::tokens::{estimate_conversation_tokens, estimate_message_tokens};
/// Strategy [`compact`] uses once a conversation exceeds its token budget.
///
/// Only `TruncateOldest` currently has a dedicated implementation; `compact` routes the other
/// variants to the same oldest-first truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Drop the oldest messages first while keeping the most recent ones.
    TruncateOldest,
    /// Presumably intended to summarize blocks of older messages; currently falls back to
    /// oldest-first truncation.
    SummarizeBlocks,
    /// Presumably intended to mix summarization and truncation; currently falls back to
    /// oldest-first truncation.
    Hybrid,
}
#[derive(Debug, Clone)]
pub struct CompactionConfig {
pub max_tokens: u32,
pub preserve_recent_count: usize,
pub strategy: CompactionStrategy,
}
impl Default for CompactionConfig {
fn default() -> Self {
Self {
max_tokens: 100_000,
preserve_recent_count: 5,
strategy: CompactionStrategy::TruncateOldest,
}
}
}
/// Summary of what a call to [`compact`] did.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactionStats {
    /// Estimated tokens of the input conversation (messages plus system prompt).
    pub original_tokens: u32,
    /// Estimated tokens of the returned conversation (messages plus system prompt).
    pub compacted_tokens: u32,
    /// How many input messages were dropped.
    pub messages_removed: usize,
}
/// Shrinks `messages` so the conversation, including the optional `system` prompt, fits within
/// `config.max_tokens`.
///
/// When the estimated size is already within budget, `compact` returns an owned copy of
/// `messages` unchanged. In that case the stats report zero removals and equal
/// original/compacted token counts. Otherwise it applies the configured strategy.
///
/// Every [`CompactionStrategy`] variant currently dispatches to oldest-first truncation.
pub fn compact(
    messages: &[Message],
    system: Option<&str>,
    config: &CompactionConfig,
) -> (Vec<Message>, CompactionStats) {
    let original_tokens = estimate_conversation_tokens(messages, system);

    let over_budget = original_tokens > config.max_tokens;
    if !over_budget {
        let untouched = CompactionStats {
            original_tokens,
            compacted_tokens: original_tokens,
            messages_removed: 0,
        };
        return (messages.to_vec(), untouched);
    }

    // Variants are listed explicitly, not via `_`, so adding a strategy forces a decision here.
    match config.strategy {
        CompactionStrategy::TruncateOldest
        | CompactionStrategy::SummarizeBlocks
        | CompactionStrategy::Hybrid => {
            truncate_oldest(messages, system, config, original_tokens)
        }
    }
}
/// Oldest-first truncation used by [`compact`].
///
/// The last `config.preserve_recent_count` messages are always kept, even when they alone
/// exceed `config.max_tokens`. The remaining budget is what is left after the system prompt and
/// that preserved tail. It is filled with older messages, walking backwards from the newest
/// one; the first older message that does not fit stops the walk. The result is therefore
/// always a contiguous suffix of `messages`, in the original order.
///
/// `original_tokens` is copied into the returned stats unchanged. `compacted_tokens` is
/// re-estimated over the kept messages plus `system`.
///
/// NOTE(review): the kept history may start with an assistant message; confirm downstream
/// consumers accept that.
fn truncate_oldest(
    messages: &[Message],
    system: Option<&str>,
    config: &CompactionConfig,
    original_tokens: u32,
) -> (Vec<Message>, CompactionStats) {
    let system_tokens = system.map_or(0, saorsa_ai::tokens::estimate_tokens);

    // Everything from `recent_start` onwards is the protected tail.
    let recent_start = messages.len().saturating_sub(config.preserve_recent_count);
    let (old_messages, recent_messages) = messages.split_at(recent_start);

    // Plain `+` / `sum()` on u32 panic on overflow in debug builds and wrap in release. Wrapping
    // could admit messages far over budget. A saturated tail total simply leaves no room for
    // older messages.
    let recent_tokens = recent_messages
        .iter()
        .map(estimate_message_tokens)
        .fold(0u32, u32::saturating_add);
    let available_for_old = config
        .max_tokens
        .saturating_sub(system_tokens)
        .saturating_sub(recent_tokens);

    // Grow the kept window backwards one message at a time while it still fits. An overflowing
    // running total can never be within budget, so `None` from `checked_add` stops the walk.
    let mut kept_start = recent_start;
    let mut used_tokens = 0u32;
    for msg in old_messages.iter().rev() {
        match used_tokens.checked_add(estimate_message_tokens(msg)) {
            Some(total) if total <= available_for_old => {
                used_tokens = total;
                kept_start -= 1;
            }
            _ => break,
        }
    }

    let result = messages[kept_start..].to_vec();
    let compacted_tokens = estimate_conversation_tokens(&result, system);
    (
        result,
        CompactionStats {
            original_tokens,
            compacted_tokens,
            // Exactly the messages before the kept window were dropped.
            messages_removed: kept_start,
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use saorsa_ai::message::{Message, Role};

    /// Builds a user or assistant message from a short role name.
    fn make_message(role: &str, text: &str) -> Message {
        if role == "user" {
            Message::user(text)
        } else if role == "assistant" {
            Message::assistant(text)
        } else {
            unreachable!("Invalid role")
        }
    }

    /// A `TruncateOldest` config with the given budget and preserved tail length.
    fn truncating_config(max_tokens: u32, preserve_recent_count: usize) -> CompactionConfig {
        CompactionConfig {
            max_tokens,
            preserve_recent_count,
            strategy: CompactionStrategy::TruncateOldest,
        }
    }

    /// A 1000-character string made of `ch`, large enough to blow a 100-token budget.
    fn filler(ch: char) -> String {
        ch.to_string().repeat(1000)
    }

    #[test]
    fn test_no_compaction_when_under_limit() {
        let history = vec![
            make_message("user", "Hello"),
            make_message("assistant", "Hi"),
        ];
        let cfg = CompactionConfig {
            max_tokens: 100_000,
            ..CompactionConfig::default()
        };

        let (out, stats) = compact(&history, None, &cfg);

        assert_eq!(out.len(), history.len());
        assert_eq!(stats.messages_removed, 0);
        assert_eq!(stats.original_tokens, stats.compacted_tokens);
    }

    #[test]
    fn test_truncate_oldest_removes_old_messages() {
        let big = filler('x');
        let history = vec![
            make_message("user", &big),
            make_message("assistant", &big),
            make_message("user", &big),
            make_message("assistant", &big),
            make_message("user", "Recent message"),
            make_message("assistant", "Recent response"),
        ];
        let cfg = truncating_config(100, 2);

        let (out, stats) = compact(&history, None, &cfg);

        assert!(out.len() >= 2);
        assert!(stats.messages_removed > 0);
        assert!(stats.compacted_tokens <= cfg.max_tokens);
    }

    #[test]
    fn test_recent_messages_always_preserved() {
        let big = filler('a');
        let history = vec![
            make_message("user", &big),
            make_message("assistant", "Old response"),
            make_message("user", "Recent 1"),
            make_message("assistant", "Recent 2"),
        ];
        let cfg = truncating_config(100, 2);

        let (out, _stats) = compact(&history, None, &cfg);

        assert!(out.len() >= 2);
        let n = out.len();
        assert_eq!(out[n - 2].role, Role::User);
        assert_eq!(out[n - 1].role, Role::Assistant);
    }

    #[test]
    fn test_compaction_with_system_prompt() {
        let big = filler('a');
        let history = vec![
            make_message("user", &big),
            make_message("assistant", "Response"),
        ];
        let cfg = truncating_config(100, 1);

        let (_out, stats) = compact(&history, Some("System prompt here"), &cfg);

        assert!(stats.compacted_tokens <= cfg.max_tokens);
    }

    #[test]
    fn test_compaction_achieves_target() {
        let (a, b, c, d) = (filler('a'), filler('b'), filler('c'), filler('d'));
        let history = vec![
            make_message("user", &a),
            make_message("assistant", &b),
            make_message("user", &c),
            make_message("assistant", &d),
            make_message("user", "Recent"),
        ];
        let cfg = truncating_config(100, 1);

        let (out, stats) = compact(&history, None, &cfg);

        assert!(stats.compacted_tokens <= cfg.max_tokens);
        assert!(stats.messages_removed > 0);
        assert!(out.len() < history.len());
    }

    #[test]
    fn test_statistics_tracked_correctly() {
        let history = vec![
            make_message("user", "Message 1"),
            make_message("assistant", "Response 1"),
            make_message("user", "Message 2"),
        ];
        let cfg = truncating_config(20, 1);

        let (out, stats) = compact(&history, None, &cfg);

        assert_eq!(stats.messages_removed, history.len() - out.len());
        assert!(stats.original_tokens > 0);
        assert!(stats.compacted_tokens > 0);
        assert!(stats.compacted_tokens <= stats.original_tokens);
    }

    #[test]
    fn test_default_config() {
        let cfg = CompactionConfig::default();
        assert_eq!(cfg.max_tokens, 100_000);
        assert_eq!(cfg.preserve_recent_count, 5);
        assert_eq!(cfg.strategy, CompactionStrategy::TruncateOldest);
    }
}