1use std::sync::Arc;
6
7use ailoop_core::{
8 AssistantBlock, ChatRequest, CompletionModel, Message, StreamChunk, SystemPrompt, UserBlock,
9};
10use async_trait::async_trait;
11use futures::StreamExt;
12
13use crate::errors::CompactionError;
14
15#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct CompactionOutput {
25 pub messages: Vec<Message>,
29 pub pinned: Vec<bool>,
33}
34
35impl CompactionOutput {
36 pub fn new(messages: Vec<Message>, pinned: Vec<bool>) -> Self {
42 Self { messages, pinned }
43 }
44}
45
46#[async_trait]
60pub trait CompactionStrategy: Send + Sync {
61 fn name(&self) -> &'static str;
67
68 async fn compact(
80 &self,
81 messages: &[Message],
82 pinned: &[bool],
83 preserve_n_last: usize,
84 ) -> Result<CompactionOutput, CompactionError>;
85}
86
/// Compaction strategy that drops old, unpinned messages outright (no model call).
///
/// It keeps every pinned message plus the most recent `preserve_n_last` messages. The
/// kept tail is extended backwards until it starts on a plain user turn.
///
/// This is a stateless unit struct, so it derives the usual cheap traits.
#[derive(Debug, Clone, Copy, Default)]
pub struct TruncateStrategy;
94
95#[async_trait]
96impl CompactionStrategy for TruncateStrategy {
97 fn name(&self) -> &'static str {
98 "truncate"
99 }
100
101 async fn compact(
102 &self,
103 messages: &[Message],
104 pinned: &[bool],
105 preserve_n_last: usize,
106 ) -> Result<CompactionOutput, CompactionError> {
107 if messages.len() <= preserve_n_last {
108 return Err(CompactionError::NotEnoughHistory);
109 }
110
111 let mut start = messages.len() - preserve_n_last;
112
113 while start > 0 && !is_safe_start(&messages[start]) {
118 start -= 1;
119 }
120
121 let mut out_messages = Vec::with_capacity(messages.len());
122 let mut out_pinned = Vec::with_capacity(messages.len());
123
124 for (i, msg) in messages.iter().enumerate().take(start) {
129 if pinned[i] {
130 out_messages.push(msg.clone());
131 out_pinned.push(true);
132 }
133 }
134
135 for (i, msg) in messages.iter().enumerate().skip(start) {
136 out_messages.push(msg.clone());
137 out_pinned.push(pinned[i]);
138 }
139
140 Ok(CompactionOutput {
141 messages: out_messages,
142 pinned: out_pinned,
143 })
144 }
145}
146
147fn is_safe_start(msg: &Message) -> bool {
148 match msg {
149 Message::User { blocks } => !blocks
150 .iter()
151 .any(|b| matches!(b, UserBlock::ToolResult { .. })),
152 Message::Assistant { .. } => false,
153 _ => false,
154 }
155}
156
/// System prompt that `SummarizeStrategy::new` installs for the summarizer model.
///
/// The literal is split with `\` line continuations purely for readability. Each
/// continuation swallows the newline and the following indentation, so the value is a
/// single line.
pub const DEFAULT_SUMMARIZER_PROMPT: &str = "You are summarizing a prior conversation \
    between a user and an assistant. Produce a concise, faithful summary that captures the \
    user's goals, decisions made, and important state (file paths, identifiers, numeric \
    results, error messages) the next turn may need. Do not invent details. Output only the \
    summary text — no preamble.";
161
/// Compaction strategy that replaces old, unpinned messages with a model-written summary.
///
/// Construct it with `SummarizeStrategy::new`. Adjust it with `with_prompt` /
/// `with_max_tokens`.
pub struct SummarizeStrategy<M> {
    /// Model that is asked to produce the summary.
    model: Arc<M>,
    /// Token budget handed to `ChatRequest::new` for the summarization request.
    max_tokens: u32,
    /// System prompt sent with the summarization request.
    summarizer_prompt: String,
}
187
188impl<M> SummarizeStrategy<M>
189where
190 M: CompletionModel + Send + Sync + 'static,
191{
192 pub fn new(model: Arc<M>) -> Self {
196 Self {
197 model,
198 summarizer_prompt: DEFAULT_SUMMARIZER_PROMPT.into(),
199 max_tokens: 1024,
200 }
201 }
202
203 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
207 self.summarizer_prompt = prompt.into();
208 self
209 }
210
211 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
215 self.max_tokens = max_tokens;
216 self
217 }
218
219 async fn summarize(&self, messages: Vec<Message>) -> Result<String, CompactionError> {
220 let mut req = ChatRequest::new(messages, self.max_tokens);
224 req.system_prompt = Some(SystemPrompt::Plain(self.summarizer_prompt.clone()));
225
226 let mut stream = self
227 .model
228 .chat_stream(req)
229 .await
230 .map_err(|e| CompactionError::SummarizationFailed(e.to_string()))?;
231
232 let mut buf = String::new();
233 while let Some(chunk) = stream.next().await {
234 let chunk = chunk.map_err(|e| CompactionError::SummarizationFailed(e.to_string()))?;
235 if let StreamChunk::TextDelta { delta } = chunk {
236 buf.push_str(&delta);
237 }
238 }
239
240 if buf.is_empty() {
241 return Err(CompactionError::SummarizationFailed(
242 "summarizer model returned no text".into(),
243 ));
244 }
245
246 Ok(buf)
247 }
248}
249
250#[async_trait]
251impl<M> CompactionStrategy for SummarizeStrategy<M>
252where
253 M: CompletionModel + Send + Sync + 'static,
254{
255 fn name(&self) -> &'static str {
256 "summarize"
257 }
258
259 async fn compact(
260 &self,
261 messages: &[Message],
262 pinned: &[bool],
263 preserve_n_last: usize,
264 ) -> Result<CompactionOutput, CompactionError> {
265 if messages.len() <= preserve_n_last {
266 return Err(CompactionError::NotEnoughHistory);
267 }
268
269 let mut start = messages.len() - preserve_n_last;
270 while start > 0 && !is_safe_start(&messages[start]) {
271 start -= 1;
272 }
273
274 let to_summarize: Vec<Message> = messages
279 .iter()
280 .enumerate()
281 .take(start)
282 .filter(|(i, _)| !pinned[*i])
283 .map(|(_, m)| flatten_for_summary(m))
284 .collect();
285
286 let mut out_messages = Vec::with_capacity(messages.len());
287 let mut out_pinned = Vec::with_capacity(messages.len());
288
289 for (i, msg) in messages.iter().enumerate().take(start) {
290 if pinned[i] {
291 out_messages.push(msg.clone());
292 out_pinned.push(true);
293 }
294 }
295
296 if !to_summarize.is_empty() {
297 let summary = self.summarize(to_summarize).await?;
298 out_messages.push(Message::user(format!(
299 "[Summary of prior conversation]\n{summary}"
300 )));
301 out_pinned.push(false);
302 }
303
304 for (i, msg) in messages.iter().enumerate().skip(start) {
305 out_messages.push(msg.clone());
306 out_pinned.push(pinned[i]);
307 }
308
309 Ok(CompactionOutput {
310 messages: out_messages,
311 pinned: out_pinned,
312 })
313 }
314}
315
316fn flatten_for_summary(msg: &Message) -> Message {
322 match msg {
323 Message::User { blocks } => Message::User {
324 blocks: blocks
325 .iter()
326 .map(|b| match b {
327 UserBlock::Text { text, .. } => UserBlock::text(text.clone()),
328 UserBlock::ToolResult {
329 call_id, content, ..
330 } => {
331 let parts: Vec<String> = content
335 .blocks
336 .iter()
337 .map(|b| match b {
338 ailoop_core::ToolResultBlock::Text { text } => text.clone(),
339 ailoop_core::ToolResultBlock::Image { .. } => "[image]".to_string(),
340 _ => "[unsupported tool result block]".to_string(),
341 })
342 .collect();
343 let body = parts.join(" ");
344 let body = if content.is_error {
345 format!("[error] {body}")
346 } else {
347 body
348 };
349 UserBlock::text(format!("[tool_result:{call_id}] {body}"))
350 }
351 UserBlock::Image { .. } => UserBlock::text("[image]"),
352 UserBlock::Document { .. } => UserBlock::text("[document]"),
353 _ => UserBlock::text("[unsupported user block]"),
358 })
359 .collect(),
360 },
361 Message::Assistant { blocks } => Message::Assistant {
362 blocks: blocks
363 .iter()
364 .map(|b| match b {
365 AssistantBlock::Text { text, .. } => AssistantBlock::text(text.clone()),
366 AssistantBlock::ToolCall { id, name, args, .. } => {
367 AssistantBlock::text(format!("[tool_call:{id} {name}] {args}"))
368 }
369 AssistantBlock::Reasoning { text, .. } => AssistantBlock::text(text.clone()),
370 AssistantBlock::RedactedReasoning { .. } => {
371 AssistantBlock::text("[redacted reasoning]".to_string())
372 }
373 _ => AssistantBlock::text("[unsupported assistant block]"),
374 })
375 .collect(),
376 },
377 _ => Message::user("[unsupported message]"),
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use ailoop_core::testing::{ScriptedError, ScriptedModel};
    use ailoop_core::{AssistantBlock, FinishReason, ToolResultContent, Usage};
    use serde_json::json;

    // ----- fixtures -------------------------------------------------------------

    /// Assistant turn consisting of a single call `id` to a tool named `t`.
    fn assistant_calls_tool(id: &str) -> Message {
        let call = AssistantBlock::tool_call(id, "t", json!({}));
        Message::Assistant { blocks: vec![call] }
    }

    /// User turn answering `call_id` with the text result `"ok"`.
    fn user_returns_tool_result(call_id: &str) -> Message {
        let result = UserBlock::tool_result(call_id, ToolResultContent::text("ok"));
        Message::User {
            blocks: vec![result],
        }
    }

    /// Pin mask of length `len` with nothing pinned.
    fn no_pins(len: usize) -> Vec<bool> {
        vec![false; len]
    }

    /// Three plain question/answer turns (six messages).
    fn three_turns() -> Vec<Message> {
        vec![
            Message::user("turn 1 q"),
            Message::assistant_text("turn 1 a"),
            Message::user("turn 2 q"),
            Message::assistant_text("turn 2 a"),
            Message::user("turn 3 q"),
            Message::assistant_text("turn 3 a"),
        ]
    }

    /// Scripted model turn that streams `text` and then finishes normally.
    fn scripted_summary(text: &str) -> Vec<StreamChunk> {
        let delta = StreamChunk::TextDelta {
            delta: text.to_string(),
        };
        let finished = StreamChunk::TurnFinished {
            reason: FinishReason::EndTurn,
            usage: Usage::default(),
            service_tier: None,
        };
        vec![delta, finished]
    }

    /// Text of the first text block of a user message, if there is one.
    fn leading_user_text(msg: &Message) -> Option<&str> {
        if let Message::User { blocks } = msg {
            blocks.iter().find_map(|block| match block {
                UserBlock::Text { text, .. } => Some(text.as_str()),
                _ => None,
            })
        } else {
            None
        }
    }

    // ----- TruncateStrategy -----------------------------------------------------

    #[tokio::test]
    async fn keeps_normal_history_intact_when_no_pairs() {
        let history = vec![
            Message::user("hi"),
            Message::assistant_text("hello"),
            Message::user("again"),
            Message::assistant_text("yes"),
        ];

        let compacted = TruncateStrategy
            .compact(&history, &no_pins(history.len()), 2)
            .await
            .unwrap();

        assert_eq!(compacted.messages.len(), 2);
        assert!(matches!(compacted.messages[0], Message::User { .. }));
        assert_eq!(compacted.pinned, vec![false, false]);
    }

    #[tokio::test]
    async fn walks_back_when_cut_lands_on_tool_result() {
        let history = vec![
            Message::user("solve this"),
            assistant_calls_tool("c1"),
            user_returns_tool_result("c1"),
            Message::assistant_text("done"),
        ];

        let compacted = TruncateStrategy
            .compact(&history, &no_pins(history.len()), 2)
            .await
            .unwrap();

        assert_eq!(compacted.messages.len(), 4);
    }

    #[tokio::test]
    async fn walks_back_when_cut_lands_on_assistant() {
        let history = vec![
            Message::user("hi"),
            Message::assistant_text("hey"),
            Message::user("more"),
            Message::assistant_text("done"),
        ];

        let compacted = TruncateStrategy
            .compact(&history, &no_pins(history.len()), 1)
            .await
            .unwrap();

        assert_eq!(compacted.messages.len(), 2);
        assert!(matches!(compacted.messages[0], Message::User { .. }));
    }

    #[tokio::test]
    async fn pinned_prefix_message_survives_truncation() {
        let history = vec![
            Message::user("system-ish pinned"),
            Message::user("turn 1 q"),
            Message::assistant_text("turn 1 a"),
            Message::user("turn 2 q"),
            Message::assistant_text("turn 2 a"),
        ];
        let mut pins = no_pins(history.len());
        pins[0] = true;

        let compacted = TruncateStrategy.compact(&history, &pins, 2).await.unwrap();

        assert_eq!(compacted.messages.len(), 3, "pinned prefix + tail of 2");
        let anchor_kept = matches!(
            &compacted.messages[0],
            Message::User { blocks }
                if matches!(&blocks[0], UserBlock::Text { text, .. } if text == "system-ish pinned")
        );
        assert!(anchor_kept);
        assert_eq!(compacted.pinned, vec![true, false, false]);
    }

    // ----- SummarizeStrategy ----------------------------------------------------

    #[tokio::test]
    async fn summarize_strategy_replaces_prefix_with_summary() {
        let model = Arc::new(ScriptedModel::new([scripted_summary(
            "User asked about turn N, assistant answered.",
        )]));
        let strategy = SummarizeStrategy::new(model);
        let history = three_turns();

        let compacted = strategy
            .compact(&history, &no_pins(history.len()), 2)
            .await
            .unwrap();

        assert_eq!(compacted.messages.len(), 3);
        let summary_text = leading_user_text(&compacted.messages[0])
            .expect("summary must be a User text message");
        assert!(
            summary_text.contains("[Summary of prior conversation]")
                && summary_text.contains("User asked about turn N"),
            "summary block content unexpected: {summary_text}"
        );
        assert_eq!(compacted.pinned, vec![false, false, false]);
    }

    #[tokio::test]
    async fn summarize_strategy_preserves_pinned_prefix() {
        let model = Arc::new(ScriptedModel::new([scripted_summary("compact summary body")]));
        let strategy = SummarizeStrategy::new(model);

        let mut history = vec![Message::user("PIN: persistent anchor")];
        history.extend(three_turns());
        let mut pins = no_pins(history.len());
        pins[0] = true;

        let compacted = strategy.compact(&history, &pins, 2).await.unwrap();

        assert_eq!(compacted.messages.len(), 4);
        assert_eq!(
            leading_user_text(&compacted.messages[0]),
            Some("PIN: persistent anchor")
        );
        let second = leading_user_text(&compacted.messages[1]).unwrap();
        assert!(
            second.contains("compact summary body"),
            "expected summary right after pinned anchor"
        );
        assert_eq!(compacted.pinned, vec![true, false, false, false]);
    }

    #[tokio::test]
    async fn summarize_strategy_propagates_model_error() {
        let model = Arc::new(ScriptedModel::with_turns([Err(ScriptedError(
            "summary network outage".into(),
        ))]));
        let strategy = SummarizeStrategy::new(model);

        // Same conversation as `three_turns`, minus the final assistant answer.
        let mut history = three_turns();
        history.pop();
        let pins = no_pins(history.len());

        let failure = strategy
            .compact(&history, &pins, 2)
            .await
            .expect_err("model error must propagate");

        match failure {
            CompactionError::SummarizationFailed(msg) => assert!(
                msg.contains("summary network outage"),
                "expected wrapped model error, got: {msg}"
            ),
            other => panic!("expected SummarizationFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn summarize_strategy_skips_model_call_when_prefix_all_pinned() {
        // No scripted turns: any model call would fail the test.
        let model = Arc::new(ScriptedModel::new(Vec::<Vec<StreamChunk>>::new()));
        let strategy = SummarizeStrategy::new(model);

        let history = vec![
            Message::user("PIN A"),
            Message::user("PIN B"),
            Message::user("tail q"),
            Message::assistant_text("tail a"),
        ];
        let mut pins = no_pins(history.len());
        pins[0] = true;
        pins[1] = true;

        let compacted = strategy.compact(&history, &pins, 2).await.unwrap();

        assert_eq!(compacted.messages.len(), 4);
        let texts: Vec<Option<&str>> = compacted
            .messages
            .iter()
            .take(3)
            .map(leading_user_text)
            .collect();
        assert_eq!(texts, vec![Some("PIN A"), Some("PIN B"), Some("tail q")]);
        assert_eq!(compacted.pinned, vec![true, true, false, false]);
    }

    #[tokio::test]
    async fn summarize_strategy_flattens_tool_blocks_in_prefix() {
        let model = Arc::new(ScriptedModel::new([scripted_summary("flattened summary")]));
        let strategy = SummarizeStrategy::new(model);

        let history = vec![
            Message::user("solve task"),
            assistant_calls_tool("c1"),
            user_returns_tool_result("c1"),
            Message::user("next q"),
            Message::assistant_text("next a"),
        ];

        let compacted = strategy
            .compact(&history, &no_pins(history.len()), 2)
            .await
            .unwrap();

        assert_eq!(compacted.messages.len(), 3);
        let summary = leading_user_text(&compacted.messages[0]).unwrap();
        assert!(summary.contains("flattened summary"));
    }

    // ----- flatten_for_summary --------------------------------------------------

    #[test]
    fn flatten_for_summary_renders_tool_blocks_as_text() {
        let call = Message::Assistant {
            blocks: vec![AssistantBlock::tool_call("c1", "t", json!({"k": 1}))],
        };
        let flattened_call = flatten_for_summary(&call);
        match flattened_call {
            Message::Assistant { blocks } => match &blocks[0] {
                AssistantBlock::Text { text, .. } => {
                    assert!(text.starts_with("[tool_call:c1 t]"), "got: {text}");
                    assert!(text.contains("\"k\":1"), "args missing: {text}");
                }
                other => panic!("expected text block, got {other:?}"),
            },
            other => panic!("expected assistant message, got {other:?}"),
        }

        let result = Message::User {
            blocks: vec![UserBlock::tool_result(
                "c1",
                ToolResultContent::text("done"),
            )],
        };
        let flattened_result = flatten_for_summary(&result);
        match flattened_result {
            Message::User { blocks } => match &blocks[0] {
                UserBlock::Text { text, .. } => assert_eq!(text, "[tool_result:c1] done"),
                other => panic!("expected text block, got {other:?}"),
            },
            other => panic!("expected user message, got {other:?}"),
        }
    }
}