1use thiserror::Error;
7
8use crate::message::AgentMessage;
9use crate::session_event::CompactionReason;
10
/// Settings that control when automatic compaction may run.
#[derive(Clone, Debug, PartialEq)]
pub struct CompactionConfig {
    /// Master switch consulted by [`CompactionEngine::should_compact`] for
    /// overflow- and threshold-triggered compaction; manual requests ignore it.
    pub enabled: bool,
    /// Token count at or above which a threshold-triggered compaction is allowed.
    pub threshold_tokens: u64,
}
17
18impl Default for CompactionConfig {
19 fn default() -> Self {
20 Self {
21 enabled: true,
22 threshold_tokens: 100_000,
23 }
24 }
25}
26
27#[derive(Debug, Clone)]
29pub struct Entry {
30 pub id: String,
31 pub message: AgentMessage,
32}
33
/// Identifies which component produced a compaction summary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SummarySource {
    /// The built-in fallback summary was used.
    Core,
    /// A [`CompactionHooks`] implementation supplied the summary.
    Hook,
}
40
41#[derive(Debug, Clone)]
43pub struct CompactionOutput {
44 pub reason: CompactionReason,
45 pub summary_text: String,
46 pub first_kept_entry_id: String,
47 pub tokens_before: u64,
48 pub tokens_after: u64,
49 pub kept_entries: Vec<Entry>,
50 pub summary_source: SummarySource,
51}
52
53#[derive(Debug, Error)]
55pub enum CompactionError {
56 #[error("nothing to compact")]
57 NothingToCompact,
58}
59
60pub trait CompactionHooks: Send + Sync {
62 fn generate_summary(&self, messages: &[AgentMessage]) -> Option<String>;
65}
66
/// Hook set that never supplies a summary, so compaction always uses the
/// built-in one.
pub struct DefaultCompactionHooks;
69
70impl CompactionHooks for DefaultCompactionHooks {
71 fn generate_summary(&self, _messages: &[AgentMessage]) -> Option<String> {
72 None
73 }
74}
75
76pub struct CompactionEngine {
78 config: CompactionConfig,
79}
80
81impl CompactionEngine {
82 pub fn new(config: CompactionConfig) -> Self {
83 Self { config }
84 }
85
86 pub fn should_compact(&self, total_tokens: u64, reason: CompactionReason) -> bool {
88 match reason {
89 CompactionReason::Manual => true,
90 CompactionReason::Overflow => self.config.enabled,
91 CompactionReason::Threshold => {
92 self.config.enabled && total_tokens >= self.config.threshold_tokens
93 }
94 }
95 }
96
97 pub fn compact(
99 &self,
100 entries: &[Entry],
101 reason: CompactionReason,
102 hooks: &dyn CompactionHooks,
103 ) -> Result<CompactionOutput, CompactionError> {
104 if entries.len() < 2 {
105 return Err(CompactionError::NothingToCompact);
106 }
107
108 let tokens_before = estimate_total_tokens(entries);
109
110 let split_idx = find_split_point(entries);
113
114 let (compacted, kept) = entries.split_at(split_idx);
115 if kept.is_empty() {
116 return Err(CompactionError::NothingToCompact);
117 }
118
119 let first_kept_entry_id = kept[0].id.clone();
120
121 let compacted_messages: Vec<AgentMessage> =
123 compacted.iter().map(|e| e.message.clone()).collect();
124 let (summary_text, source) = match hooks.generate_summary(&compacted_messages) {
125 Some(s) => (s, SummarySource::Hook),
126 None => (
127 generate_core_summary(&compacted_messages),
128 SummarySource::Core,
129 ),
130 };
131
132 let kept_entries = kept.to_vec();
133 let tokens_after = estimate_total_tokens(&kept_entries);
134
135 Ok(CompactionOutput {
136 reason,
137 summary_text,
138 first_kept_entry_id,
139 tokens_before,
140 tokens_after,
141 kept_entries,
142 summary_source: source,
143 })
144 }
145}
146
147fn find_split_point(entries: &[Entry]) -> usize {
149 if entries.is_empty() {
150 return 0;
151 }
152
153 if entries.len() == 1 {
155 return 0;
156 }
157
158 let min_keep = 1;
160 let proportional = entries.len() / 4;
161 let keep_count = proportional.max(min_keep);
162
163 entries.len().saturating_sub(keep_count)
164}
165
166fn estimate_total_tokens(entries: &[Entry]) -> u64 {
168 entries.iter().map(estimate_entry_tokens).sum()
169}
170
171fn estimate_entry_tokens(entry: &Entry) -> u64 {
173 estimate_message_tokens(&entry.message)
174}
175
176fn estimate_message_tokens(msg: &AgentMessage) -> u64 {
178 let text = extract_text(msg);
179 text.len() as u64 / 4
180}
181
182fn extract_text(msg: &AgentMessage) -> String {
184 match msg {
185 AgentMessage::Llm(opi_ai::message::Message::User(u)) => u
186 .content
187 .iter()
188 .filter_map(|c| match c {
189 opi_ai::message::InputContent::Text { text } => Some(text.clone()),
190 opi_ai::message::InputContent::Image { media_type, .. } => {
191 Some(format!("[image: {}]", media_type.as_str()))
192 }
193 _ => None,
194 })
195 .collect::<Vec<_>>()
196 .join(" "),
197 AgentMessage::Llm(opi_ai::message::Message::Assistant(a)) => a
198 .content
199 .iter()
200 .filter_map(|c| match c {
201 opi_ai::message::AssistantContent::Text { text } => Some(text.clone()),
202 _ => None,
203 })
204 .collect::<Vec<_>>()
205 .join(" "),
206 AgentMessage::Llm(opi_ai::message::Message::ToolResult(tr)) => tr
207 .content
208 .iter()
209 .filter_map(|c| match c {
210 opi_ai::message::OutputContent::Text { text } => Some(text.clone()),
211 opi_ai::message::OutputContent::Image { media_type, .. } => {
212 Some(format!("[image: {}]", media_type.as_str()))
213 }
214 _ => None,
215 })
216 .collect::<Vec<_>>()
217 .join(" "),
218 AgentMessage::CompactionSummary(cs) => cs.summary.clone(),
219 AgentMessage::BranchSummary(bs) => bs.summary.clone(),
220 AgentMessage::Custom(c) => c.data.to_string(),
221 _ => String::new(),
222 }
223}
224
225fn generate_core_summary(messages: &[AgentMessage]) -> String {
227 let texts: Vec<String> = messages.iter().map(extract_text).collect();
228 let combined = texts.join(". ");
229 let byte_count = combined.len();
230
231 if byte_count <= 500 {
232 format!("Compacted {} messages: {}", messages.len(), combined)
233 } else {
234 let truncated = &combined[..combined
236 .char_indices()
237 .take_while(|(i, _)| *i < 497)
238 .last()
239 .map(|(i, _)| i)
240 .unwrap_or(497)];
241 format!("Compacted {} messages: {}...", messages.len(), truncated)
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user message that holds a single text part.
    fn user_text(text: &str) -> AgentMessage {
        AgentMessage::Llm(opi_ai::message::Message::User(
            opi_ai::message::UserMessage {
                content: vec![opi_ai::message::InputContent::Text {
                    text: text.to_owned(),
                }],
                timestamp_ms: 0,
            },
        ))
    }

    #[test]
    fn estimate_tokens_basic() {
        let msg = user_text("Hello world test");
        let tokens = estimate_message_tokens(&msg);
        // "Hello world test" is 16 bytes long.
        assert_eq!(tokens, 4, "16 bytes / 4 = 4 tokens");
    }

    #[test]
    fn split_point_keeps_tail() {
        let entries: Vec<Entry> = (0..10)
            .map(|i| Entry {
                id: format!("e{}", i),
                message: user_text(&format!("msg {}", i)),
            })
            .collect();

        let split = find_split_point(&entries);
        assert_eq!(split, 8, "should keep last 2 of 10 entries");
        assert_eq!(entries[split].id, "e8");
    }
}