use std::sync::Arc;
use ailoop_core::{
AssistantBlock, ChatRequest, CompletionModel, Message, StreamChunk, SystemPrompt, UserBlock,
};
use async_trait::async_trait;
use futures::StreamExt;
use crate::errors::CompactionError;
/// Result of a compaction pass: the surviving messages and their pin flags.
///
/// The strategies in this file push to both vectors in lockstep, so
/// `pinned[i]` is the flag that belongs to `messages[i]`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CompactionOutput {
/// Messages that make up the compacted history, in order.
pub messages: Vec<Message>,
/// Pin flag for the message at the same index in `messages`.
pub pinned: Vec<bool>,
}
impl CompactionOutput {
/// Builds an output from `messages` and their `pinned` flags.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use
/// a struct literal and has to go through this constructor. Both vectors are
/// stored exactly as given; their lengths are not compared here.
pub fn new(messages: Vec<Message>, pinned: Vec<bool>) -> Self {
Self { messages, pinned }
}
}
/// Strategy for shrinking a conversation history.
#[async_trait]
pub trait CompactionStrategy: Send + Sync {
/// Short static identifier of the strategy. The implementations in this
/// file return `"truncate"` and `"summarize"`.
fn name(&self) -> &'static str;
/// Produces a compacted copy of `messages`.
///
/// * `messages` - the history to compact.
/// * `pinned` - pin flags, looked up by message index by the
///   implementations in this file; flagged messages are kept by them.
/// * `preserve_n_last` - number of trailing messages the caller wants to
///   keep. The implementations in this file may keep more, because they move
///   the cut point backwards until it lands on a message they consider safe.
///
/// # Errors
/// Returns a [`CompactionError`] when compaction cannot be carried out. The
/// implementations in this file return `NotEnoughHistory` when
/// `messages.len() <= preserve_n_last`.
async fn compact(
&self,
messages: &[Message],
pinned: &[bool],
preserve_n_last: usize,
) -> Result<CompactionOutput, CompactionError>;
}
/// Compaction strategy that drops the unpinned messages in front of the
/// preserved tail without replacing them with anything.
pub struct TruncateStrategy;
#[async_trait]
impl CompactionStrategy for TruncateStrategy {
    /// Returns the identifier of this strategy: `"truncate"`.
    fn name(&self) -> &'static str {
        "truncate"
    }

    /// Drops every unpinned message that precedes the preserved tail.
    ///
    /// The cut starts `preserve_n_last` messages from the end and is moved
    /// towards the beginning of the history until `is_safe_start` accepts the
    /// message it lands on. If it reaches index 0, nothing is dropped. Pinned
    /// messages in front of the cut are kept in their original order, followed
    /// by the whole tail with its original pin flags.
    ///
    /// With `preserve_n_last == 0` there is no tail, so the cut stays at the
    /// end of the history and only pinned messages survive.
    ///
    /// # Errors
    /// Returns [`CompactionError::NotEnoughHistory`] when
    /// `messages.len() <= preserve_n_last`.
    ///
    /// # Panics
    /// Panics if `pinned` is shorter than `messages`.
    async fn compact(
        &self,
        messages: &[Message],
        pinned: &[bool],
        preserve_n_last: usize,
    ) -> Result<CompactionOutput, CompactionError> {
        if messages.len() <= preserve_n_last {
            return Err(CompactionError::NotEnoughHistory);
        }

        let mut start = messages.len() - preserve_n_last;
        // `start == messages.len()` happens when `preserve_n_last` is 0: there
        // is no tail message to inspect, and indexing `messages[start]` would
        // be out of bounds, so the walk-back only runs for in-range cuts.
        while start > 0 && start < messages.len() && !is_safe_start(&messages[start]) {
            start -= 1;
        }

        let mut out_messages = Vec::with_capacity(messages.len());
        let mut out_pinned = Vec::with_capacity(messages.len());
        for (i, msg) in messages.iter().enumerate() {
            // Everything from the cut onwards is kept; in front of the cut
            // only pinned messages survive.
            if i >= start || pinned[i] {
                out_messages.push(msg.clone());
                out_pinned.push(pinned[i]);
            }
        }

        Ok(CompactionOutput {
            messages: out_messages,
            pinned: out_pinned,
        })
    }
}
/// Reports whether the preserved tail may begin at `msg`.
///
/// Only a user message without any tool-result block qualifies. Assistant
/// messages and every other message kind are rejected.
fn is_safe_start(msg: &Message) -> bool {
    if let Message::User { blocks } = msg {
        blocks
            .iter()
            .all(|block| !matches!(block, UserBlock::ToolResult { .. }))
    } else {
        false
    }
}
/// System prompt that `SummarizeStrategy::new` installs for the summarization
/// request; it can be replaced with `SummarizeStrategy::with_prompt`.
pub const DEFAULT_SUMMARIZER_PROMPT: &str = "You are summarizing a prior conversation between a user and an assistant. Produce a concise, faithful summary that captures the user's goals, decisions made, and important state (file paths, identifiers, numeric results, error messages) the next turn may need. Do not invent details. Output only the summary text — no preamble.";
/// Compaction strategy that replaces the unpinned messages in front of the
/// preserved tail with a summary written by a completion model.
pub struct SummarizeStrategy<M> {
/// Model that is asked to write the summary.
model: Arc<M>,
/// System prompt sent with the summarization request.
summarizer_prompt: String,
/// `max_tokens` value handed to the summarization `ChatRequest`.
max_tokens: u32,
}
impl<M> SummarizeStrategy<M>
where
    M: CompletionModel + Send + Sync + 'static,
{
    /// Creates a strategy that summarizes with `model`, using
    /// [`DEFAULT_SUMMARIZER_PROMPT`] and a `max_tokens` of 1024.
    pub fn new(model: Arc<M>) -> Self {
        Self {
            model,
            summarizer_prompt: DEFAULT_SUMMARIZER_PROMPT.into(),
            max_tokens: 1024,
        }
    }

    /// Replaces the system prompt used for the summarization request.
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.summarizer_prompt = prompt.into();
        self
    }

    /// Replaces the `max_tokens` value used for the summarization request.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sends `messages` to the model with the summarizer prompt as system
    /// prompt and returns the concatenated text deltas of the reply.
    ///
    /// Stream chunks other than `StreamChunk::TextDelta` are ignored.
    ///
    /// # Errors
    /// Returns [`CompactionError::SummarizationFailed`] when opening the
    /// stream fails, when any chunk is an error, or when the reply contains
    /// no text other than whitespace.
    async fn summarize(&self, messages: Vec<Message>) -> Result<String, CompactionError> {
        let mut req = ChatRequest::new(messages, self.max_tokens);
        req.system_prompt = Some(SystemPrompt::Plain(self.summarizer_prompt.clone()));

        let mut stream = self
            .model
            .chat_stream(req)
            .await
            .map_err(|e| CompactionError::SummarizationFailed(e.to_string()))?;

        let mut buf = String::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| CompactionError::SummarizationFailed(e.to_string()))?;
            if let StreamChunk::TextDelta { delta } = chunk {
                buf.push_str(&delta);
            }
        }

        // A whitespace-only reply would replace real history with a summary
        // that says nothing, so it is treated the same as an empty reply.
        if buf.trim().is_empty() {
            return Err(CompactionError::SummarizationFailed(
                "summarizer model returned no text".into(),
            ));
        }
        Ok(buf)
    }
}
#[async_trait]
impl<M> CompactionStrategy for SummarizeStrategy<M>
where
    M: CompletionModel + Send + Sync + 'static,
{
    /// Returns the identifier of this strategy: `"summarize"`.
    fn name(&self) -> &'static str {
        "summarize"
    }

    /// Replaces the unpinned messages in front of the preserved tail with a
    /// single summary message.
    ///
    /// The cut starts `preserve_n_last` messages from the end and is moved
    /// towards the beginning of the history until `is_safe_start` accepts the
    /// message it lands on. If it reaches index 0, nothing is summarized. The
    /// output consists of, in order:
    ///
    /// 1. the pinned messages in front of the cut, unchanged;
    /// 2. one unpinned user message starting with
    ///    `[Summary of prior conversation]`, present only when at least one
    ///    unpinned message sits in front of the cut (otherwise the model is
    ///    not called);
    /// 3. the tail from the cut onwards with its original pin flags.
    ///
    /// With `preserve_n_last == 0` there is no tail, so the cut stays at the
    /// end of the history and every unpinned message is summarized.
    ///
    /// # Errors
    /// Returns [`CompactionError::NotEnoughHistory`] when
    /// `messages.len() <= preserve_n_last`, and propagates the
    /// [`CompactionError::SummarizationFailed`] produced by the summarization
    /// request.
    ///
    /// # Panics
    /// Panics if `pinned` is shorter than `messages`.
    async fn compact(
        &self,
        messages: &[Message],
        pinned: &[bool],
        preserve_n_last: usize,
    ) -> Result<CompactionOutput, CompactionError> {
        if messages.len() <= preserve_n_last {
            return Err(CompactionError::NotEnoughHistory);
        }

        let mut start = messages.len() - preserve_n_last;
        // `start == messages.len()` happens when `preserve_n_last` is 0: there
        // is no tail message to inspect, and indexing `messages[start]` would
        // be out of bounds, so the walk-back only runs for in-range cuts.
        while start > 0 && start < messages.len() && !is_safe_start(&messages[start]) {
            start -= 1;
        }

        // Tool calls/results and other rich blocks are rendered as plain text
        // before they are handed to the summarizer model.
        let to_summarize: Vec<Message> = messages[..start]
            .iter()
            .enumerate()
            .filter(|(i, _)| !pinned[*i])
            .map(|(_, m)| flatten_for_summary(m))
            .collect();

        let mut out_messages = Vec::with_capacity(messages.len());
        let mut out_pinned = Vec::with_capacity(messages.len());

        for (i, msg) in messages[..start].iter().enumerate() {
            if pinned[i] {
                out_messages.push(msg.clone());
                out_pinned.push(true);
            }
        }

        if !to_summarize.is_empty() {
            let summary = self.summarize(to_summarize).await?;
            out_messages.push(Message::user(format!(
                "[Summary of prior conversation]\n{summary}"
            )));
            out_pinned.push(false);
        }

        for (i, msg) in messages.iter().enumerate().skip(start) {
            out_messages.push(msg.clone());
            out_pinned.push(pinned[i]);
        }

        Ok(CompactionOutput {
            messages: out_messages,
            pinned: out_pinned,
        })
    }
}
fn flatten_for_summary(msg: &Message) -> Message {
match msg {
Message::User { blocks } => Message::User {
blocks: blocks
.iter()
.map(|b| match b {
UserBlock::Text { text, .. } => UserBlock::text(text.clone()),
UserBlock::ToolResult {
call_id, content, ..
} => {
let parts: Vec<String> = content
.blocks
.iter()
.map(|b| match b {
ailoop_core::ToolResultBlock::Text { text } => text.clone(),
ailoop_core::ToolResultBlock::Image { .. } => "[image]".to_string(),
_ => "[unsupported tool result block]".to_string(),
})
.collect();
let body = parts.join(" ");
let body = if content.is_error {
format!("[error] {body}")
} else {
body
};
UserBlock::text(format!("[tool_result:{call_id}] {body}"))
}
UserBlock::Image { .. } => UserBlock::text("[image]"),
UserBlock::Document { .. } => UserBlock::text("[document]"),
_ => UserBlock::text("[unsupported user block]"),
})
.collect(),
},
Message::Assistant { blocks } => Message::Assistant {
blocks: blocks
.iter()
.map(|b| match b {
AssistantBlock::Text { text, .. } => AssistantBlock::text(text.clone()),
AssistantBlock::ToolCall { id, name, args, .. } => {
AssistantBlock::text(format!("[tool_call:{id} {name}] {args}"))
}
AssistantBlock::Reasoning { text, .. } => AssistantBlock::text(text.clone()),
AssistantBlock::RedactedReasoning { .. } => {
AssistantBlock::text("[redacted reasoning]".to_string())
}
_ => AssistantBlock::text("[unsupported assistant block]"),
})
.collect(),
},
_ => Message::user("[unsupported message]"),
}
}
// Unit tests for `TruncateStrategy`, `SummarizeStrategy` and
// `flatten_for_summary`.
#[cfg(test)]
mod tests {
use super::*;
use ailoop_core::testing::{ScriptedError, ScriptedModel};
use ailoop_core::{AssistantBlock, FinishReason, ToolResultContent, Usage};
use serde_json::json;
// Assistant message holding one tool call with the given id, tool name "t"
// and empty JSON arguments.
fn tool_call(id: &str) -> Message {
Message::Assistant {
blocks: vec![AssistantBlock::tool_call(id, "t", json!({}))],
}
}
// User message holding one tool result with text "ok" for `call_id`.
fn tool_result(call_id: &str) -> Message {
Message::User {
blocks: vec![UserBlock::tool_result(
call_id,
ToolResultContent::text("ok"),
)],
}
}
// Pin flags for `n` messages, none of them pinned.
fn unpinned(n: usize) -> Vec<bool> {
vec![false; n]
}
// Plain user/assistant history with a tail of 2: the cut lands on a user
// message, so exactly the last two messages remain.
#[tokio::test]
async fn keeps_normal_history_intact_when_no_pairs() {
let messages = vec![
Message::user("hi"),
Message::assistant_text("hello"),
Message::user("again"),
Message::assistant_text("yes"),
];
let out = TruncateStrategy
.compact(&messages, &unpinned(messages.len()), 2)
.await
.unwrap();
assert_eq!(out.messages.len(), 2);
assert!(matches!(out.messages[0], Message::User { .. }));
assert_eq!(out.pinned, vec![false, false]);
}
// The initial cut (index 2) is a tool result and the message before it is an
// assistant tool call, so the cut walks back to index 0 and nothing is
// dropped.
#[tokio::test]
async fn walks_back_when_cut_lands_on_tool_result() {
let messages = vec![
Message::user("solve this"),
tool_call("c1"),
tool_result("c1"),
Message::assistant_text("done"),
];
let out = TruncateStrategy
.compact(&messages, &unpinned(messages.len()), 2)
.await
.unwrap();
assert_eq!(out.messages.len(), 4);
}
// A tail of 1 puts the cut on an assistant message; it walks back one step
// to the preceding user message, so two messages are kept.
#[tokio::test]
async fn walks_back_when_cut_lands_on_assistant() {
let messages = vec![
Message::user("hi"),
Message::assistant_text("hey"),
Message::user("more"),
Message::assistant_text("done"),
];
let out = TruncateStrategy
.compact(&messages, &unpinned(messages.len()), 1)
.await
.unwrap();
assert_eq!(out.messages.len(), 2);
assert!(matches!(out.messages[0], Message::User { .. }));
}
// A pinned message in front of the cut is kept, with its flag, ahead of the
// preserved tail.
#[tokio::test]
async fn pinned_prefix_message_survives_truncation() {
let messages = vec![
Message::user("system-ish pinned"),
Message::user("turn 1 q"),
Message::assistant_text("turn 1 a"),
Message::user("turn 2 q"),
Message::assistant_text("turn 2 a"),
];
let mut pinned = unpinned(messages.len());
pinned[0] = true;
let out = TruncateStrategy
.compact(&messages, &pinned, 2)
.await
.unwrap();
assert_eq!(out.messages.len(), 3, "pinned prefix + tail of 2");
assert!(matches!(&out.messages[0], Message::User { blocks }
if matches!(&blocks[0], UserBlock::Text { text, .. } if text == "system-ish pinned")));
assert_eq!(out.pinned, vec![true, false, false]);
}
// Scripted model turn: one text delta carrying `text`, then a `TurnFinished`
// chunk.
fn summary_turn(text: &str) -> Vec<StreamChunk> {
vec![
StreamChunk::TextDelta {
delta: text.to_string(),
},
StreamChunk::TurnFinished {
reason: FinishReason::EndTurn,
usage: Usage::default(),
service_tier: None,
},
]
}
// Text of the first `Text` block of a user message; `None` for other message
// kinds or when the message has no text block.
fn first_user_text(msg: &Message) -> Option<&str> {
match msg {
Message::User { blocks } => blocks.iter().find_map(|b| match b {
UserBlock::Text { text, .. } => Some(text.as_str()),
_ => None,
}),
_ => None,
}
}
// Six messages with a tail of 2: the four prefix messages collapse into one
// summary message, giving three messages in total.
#[tokio::test]
async fn summarize_strategy_replaces_prefix_with_summary() {
let model = Arc::new(ScriptedModel::new([summary_turn(
"User asked about turn N, assistant answered.",
)]));
let strategy = SummarizeStrategy::new(model);
let messages = vec![
Message::user("turn 1 q"),
Message::assistant_text("turn 1 a"),
Message::user("turn 2 q"),
Message::assistant_text("turn 2 a"),
Message::user("turn 3 q"),
Message::assistant_text("turn 3 a"),
];
let pinned = unpinned(messages.len());
let out = strategy.compact(&messages, &pinned, 2).await.unwrap();
assert_eq!(out.messages.len(), 3);
let summary_text =
first_user_text(&out.messages[0]).expect("summary must be a User text message");
assert!(
summary_text.contains("[Summary of prior conversation]")
&& summary_text.contains("User asked about turn N"),
"summary block content unexpected: {summary_text}"
);
assert_eq!(out.pinned, vec![false, false, false]);
}
// The pinned anchor is kept verbatim at the front; the summary follows it and
// is itself unpinned.
#[tokio::test]
async fn summarize_strategy_preserves_pinned_prefix() {
let model = Arc::new(ScriptedModel::new([summary_turn("compact summary body")]));
let strategy = SummarizeStrategy::new(model);
let messages = vec![
Message::user("PIN: persistent anchor"),
Message::user("turn 1 q"),
Message::assistant_text("turn 1 a"),
Message::user("turn 2 q"),
Message::assistant_text("turn 2 a"),
Message::user("turn 3 q"),
Message::assistant_text("turn 3 a"),
];
let mut pinned = unpinned(messages.len());
pinned[0] = true;
let out = strategy.compact(&messages, &pinned, 2).await.unwrap();
assert_eq!(out.messages.len(), 4);
assert_eq!(
first_user_text(&out.messages[0]),
Some("PIN: persistent anchor")
);
assert!(
first_user_text(&out.messages[1])
.unwrap()
.contains("compact summary body"),
"expected summary right after pinned anchor"
);
assert_eq!(out.pinned, vec![true, false, false, false]);
}
// An error scripted for the model's turn must come back as
// `SummarizationFailed` with the original message embedded.
#[tokio::test]
async fn summarize_strategy_propagates_model_error() {
let model = Arc::new(ScriptedModel::with_turns([Err(ScriptedError(
"summary network outage".into(),
))]));
let strategy = SummarizeStrategy::new(model);
let messages = vec![
Message::user("turn 1 q"),
Message::assistant_text("turn 1 a"),
Message::user("turn 2 q"),
Message::assistant_text("turn 2 a"),
Message::user("turn 3 q"),
];
let pinned = unpinned(messages.len());
let err = strategy
.compact(&messages, &pinned, 2)
.await
.expect_err("model error must propagate");
match err {
CompactionError::SummarizationFailed(msg) => {
assert!(
msg.contains("summary network outage"),
"expected wrapped model error, got: {msg}"
);
}
other => panic!("expected SummarizationFailed, got {other:?}"),
}
}
// Every message in front of the cut is pinned, so there is nothing to
// summarize. The model is scripted with no turns; presumably it would fail if
// it were called (confirm against `ScriptedModel`).
#[tokio::test]
async fn summarize_strategy_skips_model_call_when_prefix_all_pinned() {
let model = Arc::new(ScriptedModel::new(Vec::<Vec<StreamChunk>>::new()));
let strategy = SummarizeStrategy::new(model);
let messages = vec![
Message::user("PIN A"),
Message::user("PIN B"),
Message::user("tail q"),
Message::assistant_text("tail a"),
];
let mut pinned = unpinned(messages.len());
pinned[0] = true;
pinned[1] = true;
let out = strategy.compact(&messages, &pinned, 2).await.unwrap();
assert_eq!(out.messages.len(), 4);
assert_eq!(first_user_text(&out.messages[0]), Some("PIN A"));
assert_eq!(first_user_text(&out.messages[1]), Some("PIN B"));
assert_eq!(first_user_text(&out.messages[2]), Some("tail q"));
assert_eq!(out.pinned, vec![true, true, false, false]);
}
// A prefix containing a tool call and its result is summarized like any other
// prefix: the output is the summary plus the two tail messages.
#[tokio::test]
async fn summarize_strategy_flattens_tool_blocks_in_prefix() {
let model = Arc::new(ScriptedModel::new([summary_turn("flattened summary")]));
let strategy = SummarizeStrategy::new(model);
let messages = vec![
Message::user("solve task"),
tool_call("c1"),
tool_result("c1"),
Message::user("next q"),
Message::assistant_text("next a"),
];
let pinned = unpinned(messages.len());
let out = strategy.compact(&messages, &pinned, 2).await.unwrap();
assert_eq!(out.messages.len(), 3);
assert!(
first_user_text(&out.messages[0])
.unwrap()
.contains("flattened summary")
);
}
// Checks the exact text that `flatten_for_summary` produces for a tool call
// (id, tool name and compact JSON args) and for a tool result.
#[test]
fn flatten_for_summary_renders_tool_blocks_as_text() {
let call = Message::Assistant {
blocks: vec![AssistantBlock::tool_call("c1", "t", json!({"k": 1}))],
};
match flatten_for_summary(&call) {
Message::Assistant { blocks } => match &blocks[0] {
AssistantBlock::Text { text, .. } => {
assert!(text.starts_with("[tool_call:c1 t]"), "got: {text}");
assert!(text.contains("\"k\":1"), "args missing: {text}");
}
other => panic!("expected text block, got {other:?}"),
},
other => panic!("expected assistant message, got {other:?}"),
}
let result = Message::User {
blocks: vec![UserBlock::tool_result(
"c1",
ToolResultContent::text("done"),
)],
};
match flatten_for_summary(&result) {
Message::User { blocks } => match &blocks[0] {
UserBlock::Text { text, .. } => {
assert_eq!(text, "[tool_result:c1] done");
}
other => panic!("expected text block, got {other:?}"),
},
other => panic!("expected user message, got {other:?}"),
}
}
}