use ailoop_core::{AssistantBlock, CharTokenizer, Message, Tokenizer, UserBlock};
use crate::{
compaction::{CompactionStrategy, TruncateStrategy},
errors::{CompactionError, FromMessagesError},
};
/// Summary of a single compaction pass performed by `ContextManager::compact_if_needed`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactionReport {
    /// Number of messages in the history immediately before compaction ran.
    pub before: usize,
    /// Number of messages remaining once the strategy's output was installed.
    pub after: usize,
    /// Name reported by the compaction strategy that produced the new history.
    pub strategy: &'static str,
}
/// Conversation history plus the policy used to keep it within a token budget.
///
/// `messages` and `pinned` are parallel vectors: `pinned[i]` is the pin flag for
/// `messages[i]`, and every method that mutates one keeps the other the same length.
pub struct ContextManager {
    /// Token threshold; compaction is attempted once the estimate reaches this value.
    max_tokens: usize,
    /// Forwarded unchanged to the strategy on every compaction.
    preserve_n_last: usize,
    /// Estimates the token cost of the history.
    tokenizer: Box<dyn Tokenizer>,
    /// Produces the compacted history (and matching pin mask) when over budget.
    strategy: Box<dyn CompactionStrategy>,
    /// Conversation history in insertion order.
    messages: Vec<Message>,
    /// Per-message pin flags, parallel to `messages`; handed to the strategy.
    pinned: Vec<bool>,
}
impl ContextManager {
    /// Starts configuring a manager whose compaction threshold is `max_tokens`.
    pub fn builder(max_tokens: usize) -> ContextManagerBuilder {
        ContextManagerBuilder::new(max_tokens)
    }

    /// Rebuilds a manager from a previously captured history and its pin mask.
    ///
    /// The builder supplies the budget, tokenizer and strategy; `messages` and
    /// `pinned` replace the (empty) history the builder would otherwise start with.
    ///
    /// # Errors
    ///
    /// Returns [`FromMessagesError::LengthMismatch`] when `messages` and `pinned`
    /// have different lengths — the mask must carry exactly one flag per message.
    pub fn from_messages(
        builder: ContextManagerBuilder,
        messages: Vec<Message>,
        pinned: Vec<bool>,
    ) -> Result<Self, FromMessagesError> {
        let message_count = messages.len();
        let flag_count = pinned.len();
        if message_count == flag_count {
            let mut restored = builder.build();
            restored.messages = messages;
            restored.pinned = pinned;
            Ok(restored)
        } else {
            Err(FromMessagesError::LengthMismatch {
                messages: message_count,
                pinned: flag_count,
            })
        }
    }
}
impl ContextManager {
    /// Appends `message` to the end of the history as an unpinned entry.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
        self.pinned.push(false);
    }

    /// Returns the configured tokenizer's estimate for the whole history.
    pub fn estimated_tokens(&self) -> usize {
        self.tokenizer.count_messages(&self.messages)
    }

    /// Returns the current history in insertion order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Appends every message in `new_messages` to the history, all unpinned.
    pub fn extend(&mut self, new_messages: Vec<Message>) {
        self.messages.extend(new_messages);
        // One `false` flag per newly appended message keeps the mask parallel
        // to the history.
        self.pinned.resize(self.messages.len(), false);
    }

    /// Returns the pin mask; `pinned()[i]` is the flag for `messages()[i]`.
    pub fn pinned(&self) -> &[bool] {
        &self.pinned
    }

    /// Returns whether the message at `idx` is pinned.
    ///
    /// Out-of-range indices report `false` rather than panicking.
    pub fn is_pinned(&self, idx: usize) -> bool {
        self.pinned.get(idx).copied().unwrap_or(false)
    }

    /// Pins the most recently added message. Does nothing on an empty history.
    pub fn pin_last(&mut self) {
        if let Some(last) = self.pinned.last_mut() {
            *last = true;
        }
    }

    /// Pins the message at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index into [`messages`](Self::messages).
    pub fn pin_at(&mut self, idx: usize) {
        assert!(
            idx < self.messages.len(),
            "pin_at: index {idx} out of bounds"
        );
        self.pinned[idx] = true;
    }

    /// Clears the pin flag of the message at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index into [`messages`](Self::messages).
    pub fn unpin_at(&mut self, idx: usize) {
        assert!(
            idx < self.messages.len(),
            "unpin_at: index {idx} out of bounds"
        );
        self.pinned[idx] = false;
    }

    /// Pins the message at `idx` together with its tool-call partners.
    ///
    /// - If the target is an assistant message, every user message containing a
    ///   `ToolResult` whose `call_id` matches one of the target's `ToolCall` ids
    ///   is pinned as well.
    /// - If the target is a user message, every assistant message containing a
    ///   `ToolCall` whose id matches one of the target's `ToolResult` call ids is
    ///   pinned.
    /// - Any other message kind, or a target with no tool blocks, is pinned alone.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index into [`messages`](Self::messages).
    pub fn pin_with_tool_result(&mut self, idx: usize) {
        assert!(
            idx < self.messages.len(),
            "pin_with_tool_result: index {idx} out of bounds"
        );
        self.pinned[idx] = true;

        // `true` when the target issued the calls (partners are results),
        // `false` when it carries results (partners are calls).
        let (target_is_call, target_ids): (bool, Vec<&str>) = match &self.messages[idx] {
            Message::Assistant { blocks } => (
                true,
                blocks
                    .iter()
                    .filter_map(|b| match b {
                        AssistantBlock::ToolCall { id, .. } => Some(id.as_str()),
                        _ => None,
                    })
                    .collect(),
            ),
            Message::User { blocks } => (
                false,
                blocks
                    .iter()
                    .filter_map(|b| match b {
                        UserBlock::ToolResult { call_id, .. } => Some(call_id.as_str()),
                        _ => None,
                    })
                    .collect(),
            ),
            _ => return,
        };
        if target_ids.is_empty() {
            return;
        }

        // `messages` (read) and `pinned` (written) are disjoint fields, so the
        // mask can be updated while the history is being scanned.
        for (msg, pinned) in self.messages.iter().zip(self.pinned.iter_mut()) {
            // Already-pinned entries, including the target itself, need no work.
            if *pinned {
                continue;
            }
            let is_partner = match (target_is_call, msg) {
                (true, Message::User { blocks }) => blocks.iter().any(|b| {
                    matches!(
                        b,
                        UserBlock::ToolResult { call_id, .. }
                            if target_ids.contains(&call_id.as_str())
                    )
                }),
                (false, Message::Assistant { blocks }) => blocks.iter().any(|b| {
                    matches!(
                        b,
                        AssistantBlock::ToolCall { id, .. }
                            if target_ids.contains(&id.as_str())
                    )
                }),
                _ => false,
            };
            if is_partner {
                *pinned = true;
            }
        }
    }

    /// Runs the compaction strategy if the history has reached the token budget.
    ///
    /// Returns `Ok(None)` without touching the history while
    /// [`estimated_tokens`](Self::estimated_tokens) is below `max_tokens`.
    /// Otherwise the strategy receives the history, the pin mask and
    /// `preserve_n_last`, and its output replaces both. The returned report
    /// records the message counts before and after, plus the strategy's name.
    ///
    /// # Errors
    ///
    /// Propagates any [`CompactionError`] from the strategy; the history and pin
    /// mask are left unchanged in that case.
    pub async fn compact_if_needed(&mut self) -> Result<Option<CompactionReport>, CompactionError> {
        if self.estimated_tokens() < self.max_tokens {
            return Ok(None);
        }
        let before = self.messages.len();
        let output = self
            .strategy
            .compact(&self.messages, &self.pinned, self.preserve_n_last)
            .await?;
        debug_assert_eq!(
            output.messages.len(),
            output.pinned.len(),
            "strategy must return a pinned mask matching the message vector",
        );
        // `debug_assert_eq!` is compiled out of release builds. Re-establish the
        // one-flag-per-message invariant here so a misbehaving strategy cannot
        // make later `pin_at`/`unpin_at`/`pin_with_tool_result` calls panic on
        // the mask or pin the wrong entries. Missing flags default to unpinned.
        let mut pinned = output.pinned;
        pinned.resize(output.messages.len(), false);
        self.messages = output.messages;
        self.pinned = pinned;
        Ok(Some(CompactionReport {
            before,
            after: self.messages.len(),
            strategy: self.strategy.name(),
        }))
    }
}
/// Configures and constructs a [`ContextManager`].
///
/// Obtain one via `ContextManager::builder`, adjust it with the setter methods,
/// then call `build`.
pub struct ContextManagerBuilder {
    /// Compaction strategy installed into the built manager.
    strategy: Box<dyn CompactionStrategy>,
    /// Tokenizer installed into the built manager.
    tokenizer: Box<dyn Tokenizer>,
    /// Token threshold handed to the built manager.
    max_tokens: usize,
    /// Trailing-window size handed to the built manager.
    preserve_n_last: usize,
}
impl ContextManagerBuilder {
fn new(max_tokens: usize) -> Self {
Self {
max_tokens,
preserve_n_last: 4,
tokenizer: Box::new(CharTokenizer),
strategy: Box::new(TruncateStrategy),
}
}
}
impl ContextManagerBuilder {
    /// Sets the `preserve_n_last` value forwarded to the compaction strategy on
    /// every compaction.
    #[must_use]
    pub fn preserve_n_last(mut self, n: usize) -> Self {
        self.preserve_n_last = n;
        self
    }

    /// Replaces the tokenizer used to estimate the history's token cost.
    #[must_use]
    pub fn tokenizer(mut self, tokenizer: Box<dyn Tokenizer>) -> ContextManagerBuilder {
        // Assign in place (like `preserve_n_last`) instead of re-listing every
        // field, so adding a builder field does not require touching each setter.
        self.tokenizer = tokenizer;
        self
    }

    /// Replaces the strategy invoked when the history reaches the token budget.
    #[must_use]
    pub fn strategy(mut self, strategy: Box<dyn CompactionStrategy>) -> ContextManagerBuilder {
        self.strategy = strategy;
        self
    }

    /// Consumes the builder and returns a manager with an empty history and
    /// an empty pin mask.
    #[must_use]
    pub fn build(self) -> ContextManager {
        ContextManager {
            messages: Vec::new(),
            pinned: Vec::new(),
            max_tokens: self.max_tokens,
            preserve_n_last: self.preserve_n_last,
            strategy: self.strategy,
            tokenizer: self.tokenizer,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ailoop_core::{AssistantBlock, ToolResultContent, UserBlock};
    use serde_json::json;

    /// Assistant message carrying a single tool call with the given id.
    fn tool_call_msg(id: &str) -> Message {
        Message::Assistant {
            blocks: vec![AssistantBlock::tool_call(id, "t", json!({}))],
        }
    }

    /// User message carrying a single "ok" tool result for `call_id`.
    fn tool_result_msg(call_id: &str) -> Message {
        Message::User {
            blocks: vec![UserBlock::tool_result(
                call_id,
                ToolResultContent::text("ok"),
            )],
        }
    }

    #[tokio::test]
    async fn compact_if_needed_returns_none_when_under_budget() {
        let mut mgr = ContextManager::builder(10_000).build();
        mgr.add_message(Message::user("hi"));
        mgr.add_message(Message::assistant_text("hello"));
        let report = mgr
            .compact_if_needed()
            .await
            .expect("compaction should succeed");
        assert!(report.is_none(), "no compaction expected when under budget");
    }

    #[tokio::test]
    async fn compact_if_needed_returns_report_when_over_budget() {
        let mut mgr = ContextManager::builder(10).preserve_n_last(2).build();
        mgr.add_message(Message::user("first turn"));
        mgr.add_message(Message::assistant_text("first reply"));
        mgr.add_message(Message::user("second turn"));
        mgr.add_message(Message::assistant_text("second reply"));
        mgr.add_message(Message::user("third turn"));
        let report = mgr
            .compact_if_needed()
            .await
            .expect("compaction should succeed")
            .expect("expected compaction to run");
        assert_eq!(report.strategy, "truncate");
        assert!(
            report.after < report.before,
            "compaction must drop messages"
        );
    }

    #[tokio::test]
    async fn pin_last_survives_compaction() {
        let mut mgr = ContextManager::builder(10).preserve_n_last(2).build();
        mgr.add_message(Message::user("pinned anchor"));
        mgr.pin_last();
        for i in 0..5 {
            mgr.add_message(Message::user(format!("turn {i} q")));
            mgr.add_message(Message::assistant_text(format!("turn {i} a")));
        }
        let report = mgr
            .compact_if_needed()
            .await
            .expect("compaction should succeed")
            .expect("expected compaction to run");
        assert!(report.after < report.before);
        let first = mgr.messages().first().expect("history should be non-empty");
        match first {
            Message::User { blocks } => match &blocks[0] {
                UserBlock::Text { text, .. } => assert_eq!(text, "pinned anchor"),
                other => panic!("expected pinned text block, got {other:?}"),
            },
            other => panic!("expected pinned user message, got {other:?}"),
        }
        assert!(
            mgr.is_pinned(0),
            "pinned mask must be preserved across compaction"
        );
    }

    #[tokio::test]
    async fn pin_with_tool_result_keeps_pair_intact() {
        let mut mgr = ContextManager::builder(10).preserve_n_last(2).build();
        mgr.add_message(Message::user("task"));
        mgr.add_message(tool_call_msg("c1"));
        mgr.add_message(tool_result_msg("c1"));
        mgr.add_message(Message::assistant_text("result"));
        mgr.pin_with_tool_result(1);
        assert!(mgr.is_pinned(1));
        assert!(mgr.is_pinned(2), "partner tool_result must be pinned too");
        for i in 0..6 {
            mgr.add_message(Message::user(format!("filler {i}")));
            mgr.add_message(Message::assistant_text(format!("ack {i}")));
        }
        mgr.compact_if_needed()
            .await
            .expect("compaction should succeed")
            .expect("expected compaction to run");
        let mut saw_call = false;
        let mut saw_result = false;
        for msg in mgr.messages() {
            match msg {
                Message::Assistant { blocks } => {
                    if blocks
                        .iter()
                        .any(|b| matches!(b, AssistantBlock::ToolCall { id, .. } if id == "c1"))
                    {
                        saw_call = true;
                    }
                }
                Message::User { blocks } => {
                    if blocks.iter().any(
                        |b| matches!(b, UserBlock::ToolResult { call_id, .. } if call_id == "c1"),
                    ) {
                        assert!(saw_call, "tool_result must follow its tool_call");
                        saw_result = true;
                    }
                }
                _ => {}
            }
        }
        assert!(
            saw_call && saw_result,
            "pinned pair must survive compaction"
        );
    }

    // Purely synchronous: no `.await`, so no async runtime is needed.
    #[test]
    fn pin_with_tool_result_on_result_pins_the_call() {
        let mut mgr = ContextManager::builder(10).preserve_n_last(1).build();
        mgr.add_message(tool_call_msg("c1"));
        mgr.add_message(tool_result_msg("c1"));
        mgr.add_message(Message::user("later"));
        mgr.pin_with_tool_result(1);
        assert!(mgr.is_pinned(0), "tool_call partner must be pinned");
        assert!(mgr.is_pinned(1));
    }

    #[tokio::test]
    async fn compact_uses_tokenizer_budget_not_character_count() {
        /// Charges a flat 10 tokens for any text, regardless of length.
        struct PerMessageTokenizer;
        impl Tokenizer for PerMessageTokenizer {
            fn count_text(&self, _text: &str) -> usize {
                10
            }
        }
        let mut mgr = ContextManager::builder(35)
            .tokenizer(Box::new(PerMessageTokenizer))
            .preserve_n_last(2)
            .build();
        for i in 0..5 {
            mgr.add_message(Message::user(format!("q{i}")));
        }
        assert_eq!(mgr.estimated_tokens(), 50);
        let report = mgr
            .compact_if_needed()
            .await
            .expect("compaction should succeed")
            .expect("over-budget history must compact");
        assert_eq!(report.before, 5);
        assert!(report.after < report.before);
        assert_eq!(report.after, 2);
    }

    #[test]
    fn from_messages_restores_history_and_pin_mask() {
        let messages = vec![
            Message::user("first"),
            Message::assistant_text("ack"),
            Message::user("second"),
        ];
        let pinned = vec![true, false, true];
        let mgr = ContextManager::from_messages(ContextManager::builder(10_000), messages, pinned)
            .expect("equal lengths");
        assert_eq!(mgr.messages().len(), 3);
        assert!(mgr.is_pinned(0));
        assert!(!mgr.is_pinned(1));
        assert!(mgr.is_pinned(2));
    }

    #[test]
    fn from_messages_rejects_length_mismatch() {
        let result = ContextManager::from_messages(
            ContextManager::builder(10_000),
            vec![Message::user("solo")],
            vec![],
        );
        match result {
            Err(FromMessagesError::LengthMismatch {
                messages: 1,
                pinned: 0,
            }) => {}
            Ok(_) => panic!("length mismatch must error, not panic"),
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    }
}