use thiserror::Error;
use crate::message::AgentMessage;
use crate::session_event::CompactionReason;
/// Settings that decide when automatic compaction is allowed to run.
#[derive(Debug, Clone, PartialEq)]
pub struct CompactionConfig {
    /// Master switch for the `Overflow` and `Threshold` triggers. `Manual`
    /// compaction is not affected by this flag.
    pub enabled: bool,
    /// Token count at or above which a `Threshold` trigger fires.
    pub threshold_tokens: u64,
}

impl Default for CompactionConfig {
    /// Compaction enabled, with a threshold of 100 000 tokens.
    fn default() -> Self {
        const DEFAULT_THRESHOLD_TOKENS: u64 = 100_000;

        CompactionConfig {
            enabled: true,
            threshold_tokens: DEFAULT_THRESHOLD_TOKENS,
        }
    }
}
/// A single message together with the identifier used to refer to it.
#[derive(Debug, Clone)]
pub struct Entry {
    /// Identifier of this entry; `CompactionEngine::compact` reports the id of
    /// the first entry it keeps.
    pub id: String,
    /// The message carried by this entry.
    pub message: AgentMessage,
}
/// Identifies which component produced a compaction summary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SummarySource {
    /// The built-in fallback summary, used when the hook returned `None`.
    Core,
    /// A summary returned by `CompactionHooks::generate_summary`.
    Hook,
}
/// Result of a successful `CompactionEngine::compact` call.
#[derive(Debug, Clone)]
pub struct CompactionOutput {
    /// The trigger that was passed to `compact`, returned unchanged.
    pub reason: CompactionReason,
    /// Summary that stands in for the compacted (older) entries.
    pub summary_text: String,
    /// Id of the first entry in `kept_entries`.
    pub first_kept_entry_id: String,
    /// Estimated token count of every entry handed to `compact`.
    pub tokens_before: u64,
    /// Estimated token count of `kept_entries` only; the summary text is not
    /// included in this figure.
    pub tokens_after: u64,
    /// The trailing entries that were retained verbatim.
    pub kept_entries: Vec<Entry>,
    /// Whether the summary came from the hook or from the built-in fallback.
    pub summary_source: SummarySource,
}
/// Failures reported by `CompactionEngine::compact`.
#[derive(Debug, Error)]
pub enum CompactionError {
    /// Fewer than two entries were supplied, or the split left nothing to keep.
    #[error("nothing to compact")]
    NothingToCompact,
}
/// Extension point that lets callers supply their own compaction summary.
pub trait CompactionHooks: Send + Sync {
    /// Summarises the messages that are about to be compacted.
    ///
    /// Return `Some(text)` to use `text` as the summary, or `None` to let the
    /// engine fall back to its built-in summary.
    fn generate_summary(&self, messages: &[AgentMessage]) -> Option<String>;
}
/// Hook implementation that never supplies a custom summary, so the engine
/// always falls back to its built-in one.
///
/// Derives the common traits so the type can be debug-printed, copied and
/// built via `Default::default()` like any other public unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultCompactionHooks;

impl CompactionHooks for DefaultCompactionHooks {
    /// Always returns `None`, whatever `_messages` contains.
    fn generate_summary(&self, _messages: &[AgentMessage]) -> Option<String> {
        None
    }
}
/// Decides when to compact and splits entries into a summarised head and a
/// retained tail.
///
/// `Debug` and `Clone` are derived because the only field,
/// `CompactionConfig`, already implements both.
#[derive(Debug, Clone)]
pub struct CompactionEngine {
    /// Settings consulted by `should_compact`.
    config: CompactionConfig,
}
impl CompactionEngine {
pub fn new(config: CompactionConfig) -> Self {
Self { config }
}
pub fn should_compact(&self, total_tokens: u64, reason: CompactionReason) -> bool {
match reason {
CompactionReason::Manual => true,
CompactionReason::Overflow => self.config.enabled,
CompactionReason::Threshold => {
self.config.enabled && total_tokens >= self.config.threshold_tokens
}
}
}
pub fn compact(
&self,
entries: &[Entry],
reason: CompactionReason,
hooks: &dyn CompactionHooks,
) -> Result<CompactionOutput, CompactionError> {
if entries.len() < 2 {
return Err(CompactionError::NothingToCompact);
}
let tokens_before = estimate_total_tokens(entries);
let split_idx = find_split_point(entries);
let (compacted, kept) = entries.split_at(split_idx);
if kept.is_empty() {
return Err(CompactionError::NothingToCompact);
}
let first_kept_entry_id = kept[0].id.clone();
let compacted_messages: Vec<AgentMessage> =
compacted.iter().map(|e| e.message.clone()).collect();
let (summary_text, source) = match hooks.generate_summary(&compacted_messages) {
Some(s) => (s, SummarySource::Hook),
None => (
generate_core_summary(&compacted_messages),
SummarySource::Core,
),
};
let kept_entries = kept.to_vec();
let tokens_after = estimate_total_tokens(&kept_entries);
Ok(CompactionOutput {
reason,
summary_text,
first_kept_entry_id,
tokens_before,
tokens_after,
kept_entries,
summary_source: source,
})
}
}
/// Returns the index of the first entry to keep.
///
/// The most recent quarter of the entries (rounded down) is kept, but never
/// fewer than one. Everything before the returned index is compacted.
fn find_split_point(entries: &[Entry]) -> usize {
    const MIN_KEEP: usize = 1;

    let total = entries.len();
    let keep = std::cmp::max(total / 4, MIN_KEEP);
    // Saturating subtraction yields 0 for empty and single-entry slices, so
    // those cases need no separate branch.
    total.saturating_sub(keep)
}
fn estimate_total_tokens(entries: &[Entry]) -> u64 {
entries.iter().map(estimate_entry_tokens).sum()
}
fn estimate_entry_tokens(entry: &Entry) -> u64 {
estimate_message_tokens(&entry.message)
}
/// Estimates a message's token count as the byte length of its extracted text
/// divided by four, rounded down.
fn estimate_message_tokens(msg: &AgentMessage) -> u64 {
    const BYTES_PER_TOKEN: u64 = 4;

    let byte_len = extract_text(msg).len() as u64;
    byte_len / BYTES_PER_TOKEN
}
/// Flattens a message into plain text.
///
/// * User and tool-result messages: text parts are copied, image parts become
///   `[image: <media type>]`, other parts are skipped; the pieces are joined
///   with a single space.
/// * Assistant messages: only text parts are used, joined with a single space.
/// * Compaction and branch summaries: the stored summary.
/// * Custom messages: `data.to_string()`.
/// * Any other message kind: an empty string.
fn extract_text(msg: &AgentMessage) -> String {
    let parts: Vec<String> = match msg {
        AgentMessage::Llm(opi_ai::message::Message::User(user)) => user
            .content
            .iter()
            .filter_map(|part| match part {
                opi_ai::message::InputContent::Text { text } => Some(text.clone()),
                opi_ai::message::InputContent::Image { media_type, .. } => {
                    Some(format!("[image: {}]", media_type.as_str()))
                }
                _ => None,
            })
            .collect(),
        AgentMessage::Llm(opi_ai::message::Message::Assistant(assistant)) => assistant
            .content
            .iter()
            .filter_map(|part| match part {
                opi_ai::message::AssistantContent::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect(),
        AgentMessage::Llm(opi_ai::message::Message::ToolResult(result)) => result
            .content
            .iter()
            .filter_map(|part| match part {
                opi_ai::message::OutputContent::Text { text } => Some(text.clone()),
                opi_ai::message::OutputContent::Image { media_type, .. } => {
                    Some(format!("[image: {}]", media_type.as_str()))
                }
                _ => None,
            })
            .collect(),
        // The remaining kinds carry a single string, so return it directly.
        AgentMessage::CompactionSummary(compaction) => return compaction.summary.clone(),
        AgentMessage::BranchSummary(branch) => return branch.summary.clone(),
        AgentMessage::Custom(custom) => return custom.data.to_string(),
        _ => return String::new(),
    };
    parts.join(" ")
}
/// Builds the built-in fallback summary for `messages`.
///
/// The extracted text of every message is joined with `". "`. If the result
/// is at most 500 bytes it is embedded in full; otherwise it is cut to at
/// most 497 bytes and `...` is appended, so the excerpt never exceeds 500
/// bytes. The cut never lands inside a multi-byte UTF-8 character.
///
/// The returned string is prefixed with `Compacted <n> messages: `.
fn generate_core_summary(messages: &[AgentMessage]) -> String {
    const MAX_EXCERPT_BYTES: usize = 500;
    const ELLIPSIS_BYTES: usize = 3;

    let combined = messages
        .iter()
        .map(extract_text)
        .collect::<Vec<_>>()
        .join(". ");

    if combined.len() <= MAX_EXCERPT_BYTES {
        return format!("Compacted {} messages: {}", messages.len(), combined);
    }

    // Walk back from the byte budget to the nearest char boundary. Index 0 is
    // always a boundary, so the loop terminates, and `cut` is below
    // `combined.len()` because the text is longer than 500 bytes.
    let mut cut = MAX_EXCERPT_BYTES - ELLIPSIS_BYTES;
    while !combined.is_char_boundary(cut) {
        cut -= 1;
    }

    format!(
        "Compacted {} messages: {}...",
        messages.len(),
        &combined[..cut]
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with id `e<index>` holding the user text `msg <index>`.
    fn user_entry(index: usize) -> Entry {
        Entry {
            id: format!("e{}", index),
            message: AgentMessage::Llm(opi_ai::message::Message::User(
                opi_ai::message::UserMessage {
                    content: vec![opi_ai::message::InputContent::Text {
                        text: format!("msg {}", index),
                    }],
                    timestamp_ms: 0,
                },
            )),
        }
    }

    /// The estimate is the text's byte length divided by four.
    #[test]
    fn estimate_tokens_basic() {
        let msg = AgentMessage::Llm(opi_ai::message::Message::User(
            opi_ai::message::UserMessage {
                content: vec![opi_ai::message::InputContent::Text {
                    text: "Hello world test".into(),
                }],
                timestamp_ms: 0,
            },
        ));
        let tokens = estimate_message_tokens(&msg);
        // "Hello world test" is 16 bytes long.
        assert_eq!(tokens, 4, "16 bytes / 4 = 4 tokens");
    }

    /// A quarter of the entries (rounded down) is kept at the tail.
    #[test]
    fn split_point_keeps_tail() {
        let entries: Vec<Entry> = (0..10).map(user_entry).collect();
        let split = find_split_point(&entries);
        assert_eq!(split, 8, "should keep last 2 of 10 entries");
        assert_eq!(entries[split].id, "e8");
    }

    /// At least one entry is kept; empty and single-entry inputs split at 0.
    #[test]
    fn split_point_small_inputs() {
        let entries: Vec<Entry> = (0..2).map(user_entry).collect();
        assert_eq!(find_split_point(&entries[..0]), 0);
        assert_eq!(find_split_point(&entries[..1]), 0);
        assert_eq!(find_split_point(&entries), 1);
    }
}