use crate::llm::types::{
CompletionRequest, ContentBlock, Message, ReasoningEffort, Role, ToolDefinition, ToolResult,
};
use super::token_estimator::{estimate_message_tokens, estimate_tokens};
/// How the stored conversation history is turned into the message list of an
/// outgoing request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextStrategy {
    /// Send the full history unchanged.
    Unlimited,
    /// Send the first message plus as many of the most recent messages as fit
    /// within an estimated token budget.
    SlidingWindow {
        /// Estimated-token budget for the windowed message list.
        max_tokens: u32,
    },
}
/// Conversation state for an agent run: the system prompt, the message
/// history, the available tools, and the settings used to build each
/// completion request.
pub(crate) struct AgentContext {
    /// System prompt copied into every request.
    system: String,
    /// Conversation history; the constructors seed it with one user message.
    messages: Vec<Message>,
    /// Tool definitions copied into every request.
    tools: Vec<ToolDefinition>,
    /// Turn limit exposed through `max_turns()`; it is only stored here, not enforced.
    max_turns: usize,
    /// `max_tokens` value placed on every request.
    max_tokens: u32,
    /// Number of turns recorded so far via `increment_turn()`.
    current_turn: usize,
    /// How the history is trimmed when a request is built.
    context_strategy: ContextStrategy,
    /// Optional reasoning-effort setting forwarded on every request.
    reasoning_effort: Option<ReasoningEffort>,
}
impl AgentContext {
pub(crate) fn new(
system: impl Into<String>,
task: impl Into<String>,
tools: Vec<ToolDefinition>,
) -> Self {
Self {
system: system.into(),
messages: vec![Message::user(task)],
tools,
max_turns: 10,
max_tokens: 4096,
current_turn: 0,
context_strategy: ContextStrategy::Unlimited,
reasoning_effort: None,
}
}
pub(crate) fn from_content(
system: impl Into<String>,
content: Vec<ContentBlock>,
tools: Vec<ToolDefinition>,
) -> Self {
Self {
system: system.into(),
messages: vec![Message {
role: Role::User,
content,
}],
tools,
max_turns: 10,
max_tokens: 4096,
current_turn: 0,
context_strategy: ContextStrategy::Unlimited,
reasoning_effort: None,
}
}
pub(crate) fn evict_media(&mut self) {
let last_user_idx = self.messages.iter().rposition(|m| m.role == Role::User);
for (i, msg) in self.messages.iter_mut().enumerate() {
if Some(i) == last_user_idx {
continue;
}
for block in &mut msg.content {
match block {
ContentBlock::Image { .. } => {
*block = ContentBlock::Text {
text: "[image previously sent]".into(),
};
}
ContentBlock::Audio { .. } => {
*block = ContentBlock::Text {
text: "[audio previously sent]".into(),
};
}
_ => {}
}
}
}
}
pub(crate) fn with_max_turns(mut self, max_turns: usize) -> Self {
self.max_turns = max_turns;
self
}
pub(crate) fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub(crate) fn with_context_strategy(mut self, strategy: ContextStrategy) -> Self {
self.context_strategy = strategy;
self
}
pub(crate) fn with_reasoning_effort(mut self, effort: Option<ReasoningEffort>) -> Self {
self.reasoning_effort = effort;
self
}
pub(crate) fn message_count(&self) -> usize {
self.messages.len()
}
pub(crate) fn current_turn(&self) -> usize {
self.current_turn
}
pub(crate) fn max_turns(&self) -> usize {
self.max_turns
}
pub(crate) fn increment_turn(&mut self) {
self.current_turn += 1;
}
pub(crate) fn add_assistant_message(&mut self, message: Message) {
self.messages.push(message);
}
pub(crate) fn add_user_message(&mut self, text: impl Into<String>) {
self.messages.push(Message::user(text));
}
pub(crate) fn add_tool_results(&mut self, results: Vec<ToolResult>) {
self.messages.push(Message::tool_results(results));
}
pub(crate) fn last_assistant_text(&self) -> Option<String> {
self.messages.iter().rev().find_map(|m| {
if m.role == Role::Assistant {
let text: String = m
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect();
Some(text)
} else {
None
}
})
}
pub(crate) fn total_tokens(&self) -> u32 {
self.messages
.iter()
.map(estimate_message_tokens)
.sum::<u32>()
+ estimate_tokens(&self.system)
}
pub(crate) fn needs_compaction(&self, max_tokens: u32) -> bool {
self.total_tokens() > max_tokens
}
pub(crate) fn inject_summary(&mut self, summary: String, keep_last_n: usize) {
let Some(first) = self.messages.first() else {
return;
};
let original_task: String = first
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect();
inject_summary_into_messages(&mut self.messages, &original_task, &summary, keep_last_n);
}
pub(crate) fn conversation_text(&self) -> String {
messages_to_text(&self.messages)
}
pub(crate) fn messages_to_be_compacted(&self, keep_last_n: usize) -> &[Message] {
if self.messages.len() <= 1 + keep_last_n {
return &[];
}
let tail_start = self.messages.len().saturating_sub(keep_last_n);
if tail_start <= 1 {
return &[];
}
&self.messages[1..tail_start]
}
pub(crate) fn to_request(&self) -> CompletionRequest {
let messages = match &self.context_strategy {
ContextStrategy::Unlimited => self.messages.clone(),
ContextStrategy::SlidingWindow { max_tokens } => {
apply_sliding_window(&self.messages, *max_tokens)
}
};
CompletionRequest {
system: self.system.clone(),
messages,
tools: self.tools.clone(),
max_tokens: self.max_tokens,
tool_choice: None,
reasoning_effort: self.reasoning_effort,
}
}
}
/// Replaces the older part of `messages` with a single user message that
/// combines `original_task` and `summary`, keeping the most recent messages.
///
/// * `messages` - history to compact in place.
/// * `original_task` - text placed at the top of the combined message.
/// * `summary` - text placed under the `[Previous conversation summary]` header.
/// * `keep_last_n` - number of trailing messages to keep; more may be kept so
///   that the retained tail does not begin with a user message directly after
///   the combined user message.
///
/// The call is a no-op when the history is empty or when nothing lies between
/// the first message and the retained tail.
pub fn inject_summary_into_messages(
    messages: &mut Vec<Message>,
    original_task: &str,
    summary: &str,
    keep_last_n: usize,
) {
    let total = messages.len();
    // `saturating_sub` handles empty histories and arbitrarily large
    // `keep_last_n` without the overflow that `1 + keep_last_n` could hit.
    let mut tail_start = total.saturating_sub(keep_last_n);
    if tail_start <= 1 {
        return;
    }

    // Walk the cut point backwards while the tail would open with a user
    // message, but never past index 1 (index 0 is replaced below).
    while tail_start > 1
        && matches!(messages.get(tail_start), Some(m) if m.role == Role::User)
    {
        tail_start -= 1;
    }

    let combined = Message::user(format!(
        "{original_task}\n\n[Previous conversation summary]\n{summary}"
    ));

    // Move the retained tail out instead of cloning it.
    let tail = messages.split_off(tail_start);
    messages.clear();
    messages.push(combined);
    messages.extend(tail);
}
pub fn messages_to_text(messages: &[Message]) -> String {
let mut parts = Vec::with_capacity(messages.len());
for msg in messages {
let role = match msg.role {
Role::User => "User",
Role::Assistant => "Assistant",
};
let text: String = msg
.content
.iter()
.map(|b| match b {
ContentBlock::Text { text } => text.as_str().into(),
ContentBlock::ToolUse { name, input, .. } => {
format!("[Tool call: {name}({input})]")
}
ContentBlock::ToolResult { content, .. } => {
format!("[Tool result: {content}]")
}
ContentBlock::Image { media_type, .. } => {
format!("[Image: {media_type}]")
}
ContentBlock::Audio { format, .. } => {
format!("[Audio: {format}]")
}
})
.collect::<Vec<String>>()
.join(" ");
parts.push(format!("{role}: {text}"));
}
parts.join("\n")
}
/// Returns the first message followed by the longest run of trailing messages
/// whose estimated tokens fit in `max_tokens`.
///
/// The first message is always kept; if it alone meets or exceeds the budget,
/// it is the only message returned. A user message carrying a tool result is
/// taken together with the message directly before it, or not at all.
pub fn apply_sliding_window(messages: &[Message], max_tokens: u32) -> Vec<Message> {
    let Some((anchor, rest)) = messages.split_first() else {
        return Vec::new();
    };
    if rest.is_empty() {
        return vec![anchor.clone()];
    }

    let anchor_cost = estimate_message_tokens(anchor);
    if anchor_cost >= max_tokens {
        return vec![anchor.clone()];
    }

    let mut remaining = max_tokens - anchor_cost;
    // `cursor` is one past the next candidate; `window_start` is the first
    // index of `rest` that made it into the window.
    let mut cursor = rest.len();
    let mut window_start = rest.len();

    while cursor > 0 {
        let idx = cursor - 1;
        let candidate = &rest[idx];
        let cost = estimate_message_tokens(candidate);
        let carries_tool_result = candidate.role == Role::User
            && candidate
                .content
                .iter()
                .any(|block| matches!(block, ContentBlock::ToolResult { .. }));

        // A tool result is only useful with the message before it, so both
        // are charged and consumed as one unit.
        let (span, needed) = if carries_tool_result && idx > 0 {
            (2, cost + estimate_message_tokens(&rest[idx - 1]))
        } else {
            (1, cost)
        };

        if needed > remaining {
            break;
        }
        remaining -= needed;
        cursor -= span;
        window_start = cursor;
    }

    let kept = &rest[window_start..];
    let mut windowed = Vec::with_capacity(1 + kept.len());
    windowed.push(anchor.clone());
    windowed.extend_from_slice(kept);
    windowed
}
// Unit tests for AgentContext and the free helper functions in this file
// (summary injection, sliding window, media eviction).
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// A new context produces a request with the system prompt and one user message.
#[test]
fn new_context_has_user_message() {
let ctx = AgentContext::new("system", "do something", vec![]);
let req = ctx.to_request();
assert_eq!(req.system, "system");
assert_eq!(req.messages.len(), 1);
assert_eq!(req.messages[0].role, Role::User);
}
// with_max_turns replaces the default turn limit.
#[test]
fn with_max_turns_overrides_default() {
let ctx = AgentContext::new("sys", "task", vec![]).with_max_turns(5);
assert_eq!(ctx.max_turns(), 5);
}
// with_max_tokens is reflected in the built request.
#[test]
fn with_max_tokens_overrides_default() {
let ctx = AgentContext::new("sys", "task", vec![]).with_max_tokens(8192);
let req = ctx.to_request();
assert_eq!(req.max_tokens, 8192);
}
// Without overrides the request carries max_tokens = 4096.
#[test]
fn default_max_tokens_is_4096() {
let ctx = AgentContext::new("sys", "task", vec![]);
let req = ctx.to_request();
assert_eq!(req.max_tokens, 4096);
}
// The turn counter starts at 0 and increment_turn adds one.
#[test]
fn turn_tracking() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
assert_eq!(ctx.current_turn(), 0);
ctx.increment_turn();
assert_eq!(ctx.current_turn(), 1);
}
// add_user_message appends a second message with the User role.
#[test]
fn add_user_message_creates_user_message() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_user_message("follow up question");
let req = ctx.to_request();
assert_eq!(req.messages.len(), 2); assert_eq!(req.messages[1].role, Role::User);
}
// Tool results are appended as a message with the User role.
#[test]
fn add_tool_results_creates_user_message() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_tool_results(vec![ToolResult::success("call-1", "result")]);
let req = ctx.to_request();
assert_eq!(req.messages.len(), 2);
assert_eq!(req.messages[1].role, Role::User);
}
// Tool definitions passed to the constructor appear in the request.
#[test]
fn request_includes_tools() {
let tools = vec![ToolDefinition {
name: "search".into(),
description: "Search".into(),
input_schema: json!({"type": "object"}),
}];
let ctx = AgentContext::new("sys", "task", tools);
let req = ctx.to_request();
assert_eq!(req.tools.len(), 1);
assert_eq!(req.tools[0].name, "search");
}
// The default context strategy is Unlimited.
#[test]
fn default_is_unlimited() {
let ctx = AgentContext::new("sys", "task", vec![]);
assert!(matches!(ctx.context_strategy, ContextStrategy::Unlimited));
}
// With the Unlimited strategy every stored message is sent (1 task + 3 replies).
#[test]
fn unlimited_passes_all() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("response 1"));
ctx.add_assistant_message(Message::assistant("response 2"));
ctx.add_assistant_message(Message::assistant("response 3"));
let req = ctx.to_request();
assert_eq!(req.messages.len(), 4); }
// The sliding window always keeps the initial user message first.
#[test]
fn sliding_window_preserves_first() {
let mut ctx = AgentContext::new("sys", "initial task", vec![])
.with_context_strategy(ContextStrategy::SlidingWindow { max_tokens: 20 });
ctx.add_assistant_message(Message::assistant("a".repeat(100)));
ctx.add_assistant_message(Message::assistant("recent"));
let req = ctx.to_request();
assert_eq!(req.messages[0].role, Role::User);
assert!(
req.messages[0]
.content
.iter()
.any(|b| matches!(b, ContentBlock::Text { text } if text == "initial task"))
);
}
// A small budget drops some of the 11 stored messages but keeps the first.
#[test]
fn sliding_window_trims_old() {
let mut ctx = AgentContext::new("sys", "task", vec![])
.with_context_strategy(ContextStrategy::SlidingWindow { max_tokens: 50 });
for i in 0..10 {
ctx.add_assistant_message(Message::assistant(format!("response {i} with some text")));
}
let req = ctx.to_request();
assert!(req.messages.len() < 11);
assert_eq!(req.messages[0].role, Role::User);
}
// The window contains a tool-use block exactly when it contains a tool-result block.
#[test]
fn sliding_window_keeps_tool_pairs() {
let mut ctx = AgentContext::new("sys", "task", vec![])
.with_context_strategy(ContextStrategy::SlidingWindow { max_tokens: 200 });
ctx.add_assistant_message(Message {
role: Role::Assistant,
content: vec![ContentBlock::ToolUse {
id: "c1".into(),
name: "search".into(),
input: json!({"q": "test"}),
}],
});
ctx.add_tool_results(vec![ToolResult::success("c1", "found it")]);
ctx.add_assistant_message(Message::assistant("Based on the search results..."));
let req = ctx.to_request();
let has_tool_use = req.messages.iter().any(|m| {
m.content
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }))
});
let has_tool_result = req.messages.iter().any(|m| {
m.content
.iter()
.any(|b| matches!(b, ContentBlock::ToolResult { .. }))
});
assert_eq!(
has_tool_use, has_tool_result,
"tool_use and tool_result must be kept together"
);
}
// A history with only the initial message is returned unchanged by the window.
#[test]
fn sliding_window_single_message() {
let ctx = AgentContext::new("sys", "task", vec![])
.with_context_strategy(ContextStrategy::SlidingWindow { max_tokens: 10 });
let req = ctx.to_request();
assert_eq!(req.messages.len(), 1);
}
// A tiny context does not need compaction against a large threshold.
#[test]
fn needs_compaction_below_threshold() {
let ctx = AgentContext::new("sys", "task", vec![]);
assert!(!ctx.needs_compaction(10000));
}
// Fifty long messages exceed a threshold of 100 estimated tokens.
#[test]
fn needs_compaction_above_threshold() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
for _ in 0..50 {
ctx.add_assistant_message(Message::assistant("a".repeat(200)));
}
assert!(ctx.needs_compaction(100));
}
// Six messages with keep_last_n = 2 collapse to 3; the first holds task and summary.
#[test]
fn inject_summary_replaces_middle() {
let mut ctx = AgentContext::new("sys", "initial task", vec![]);
ctx.add_assistant_message(Message::assistant("msg 1"));
ctx.add_assistant_message(Message::assistant("msg 2"));
ctx.add_assistant_message(Message::assistant("msg 3"));
ctx.add_assistant_message(Message::assistant("msg 4"));
ctx.add_assistant_message(Message::assistant("msg 5"));
ctx.inject_summary("summary of earlier conversation".into(), 2);
assert_eq!(ctx.messages.len(), 3);
let first_text: String = ctx.messages[0]
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
assert!(first_text.contains("initial task"));
assert!(first_text.contains("summary of earlier"));
}
// With keep_last_n = 3 the result has 4 messages and ends with the newest one.
#[test]
fn inject_summary_preserves_first_and_last() {
let mut ctx = AgentContext::new("sys", "first task", vec![]);
ctx.add_assistant_message(Message::assistant("old 1"));
ctx.add_assistant_message(Message::assistant("old 2"));
ctx.add_assistant_message(Message::assistant("recent 1"));
ctx.add_assistant_message(Message::assistant("recent 2"));
ctx.add_assistant_message(Message::assistant("recent 3"));
ctx.inject_summary("compressed".into(), 3);
assert_eq!(ctx.messages.len(), 4);
assert!(
ctx.messages[3]
.content
.iter()
.any(|b| matches!(b, ContentBlock::Text { text } if text == "recent 3"))
);
}
// Injecting a summary leaves a history shorter than 1 + keep_last_n untouched.
#[test]
fn inject_summary_noop_few_messages() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("only one"));
ctx.inject_summary("summary".into(), 4);
assert_eq!(ctx.messages.len(), 2);
}
// After compaction the first message is User and the second is Assistant.
#[test]
fn inject_summary_maintains_alternating_roles() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("a1"));
ctx.add_assistant_message(Message::assistant("a2"));
ctx.add_assistant_message(Message::assistant("a3"));
ctx.add_assistant_message(Message::assistant("a4"));
ctx.inject_summary("summary".into(), 2);
assert_eq!(ctx.messages[0].role, Role::User);
assert_eq!(ctx.messages[1].role, Role::Assistant);
}
// When the nominal tail would start with a tool-result (User) message, the cut
// moves back so that no two adjacent messages share a role.
#[test]
fn inject_summary_adjusts_tail_when_starting_with_user() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("a1"));
ctx.add_tool_results(vec![ToolResult::success("c1", "result1")]);
ctx.add_assistant_message(Message::assistant("a2"));
ctx.add_tool_results(vec![ToolResult::success("c2", "result2")]);
ctx.add_assistant_message(Message::assistant("a3"));
ctx.inject_summary("summary".into(), 2);
assert_eq!(ctx.messages[0].role, Role::User);
assert_eq!(ctx.messages[1].role, Role::Assistant);
for w in ctx.messages.windows(2) {
assert_ne!(w[0].role, w[1].role, "adjacent messages have same role");
}
}
// Adding a message increases the estimated token total.
#[test]
fn total_tokens_grows_with_messages() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
let initial = ctx.total_tokens();
ctx.add_assistant_message(Message::assistant("a".repeat(100)));
assert!(ctx.total_tokens() > initial);
}
// The free function keeps roles alternating and merges task and summary text.
#[test]
fn shared_inject_summary_preserves_alternation() {
let mut messages = vec![
Message::user("original task"),
Message::assistant("a1"),
Message::tool_results(vec![ToolResult::success("c1", "result1")]),
Message::assistant("a2"),
Message::tool_results(vec![ToolResult::success("c2", "result2")]),
Message::assistant("a3"),
];
inject_summary_into_messages(&mut messages, "original task", "summary of conversation", 2);
assert_eq!(messages[0].role, Role::User);
assert_eq!(messages[1].role, Role::Assistant);
for w in messages.windows(2) {
assert_ne!(w[0].role, w[1].role, "adjacent messages have same role");
}
let first_text: String = messages[0]
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
assert!(first_text.contains("original task"));
assert!(first_text.contains("summary of conversation"));
}
// Four messages with keep_last_n = 2: the cut lands at index 2 and yields 3 messages.
#[test]
fn inject_summary_tail_start_near_beginning() {
let mut messages = vec![
Message::user("original task"),
Message::assistant("first response"),
Message::assistant("second response"),
Message::assistant("third response"),
];
inject_summary_into_messages(&mut messages, "original task", "summary", 2);
assert_eq!(messages.len(), 3);
assert_eq!(messages[0].role, Role::User);
assert_eq!(messages[1].role, Role::Assistant);
let first_text: String = messages[0]
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
assert!(first_text.contains("original task"));
assert!(first_text.contains("summary"));
}
// from_content stores all given blocks in a single initial user message.
#[test]
fn from_content_creates_multimodal_message() {
let content = vec![
ContentBlock::Text {
text: "describe this".into(),
},
ContentBlock::Image {
media_type: "image/jpeg".into(),
data: "base64data".into(),
},
];
let ctx = AgentContext::from_content("system", content, vec![]);
let req = ctx.to_request();
assert_eq!(req.messages.len(), 1);
assert_eq!(req.messages[0].role, Role::User);
assert_eq!(req.messages[0].content.len(), 2);
assert!(matches!(
&req.messages[0].content[1],
ContentBlock::Image { .. }
));
}
// Images in older messages become placeholders; the latest user message keeps its image.
#[test]
fn evict_media_replaces_old_images_with_placeholder() {
let mut ctx = AgentContext::from_content(
"sys",
vec![
ContentBlock::Text {
text: "describe this".into(),
},
ContentBlock::Image {
media_type: "image/jpeg".into(),
data: "data1".into(),
},
],
vec![],
);
ctx.add_assistant_message(Message::assistant("It shows a cat."));
ctx.messages.push(Message {
role: Role::User,
content: vec![ContentBlock::Image {
media_type: "image/png".into(),
data: "data2".into(),
}],
});
ctx.evict_media();
assert_eq!(
ctx.messages[0].content[1],
ContentBlock::Text {
text: "[image previously sent]".into()
}
);
assert!(matches!(
&ctx.messages[2].content[0],
ContentBlock::Image { media_type, .. } if media_type == "image/png"
));
}
// Audio in older messages becomes a placeholder; the latest user message keeps its audio.
#[test]
fn evict_media_replaces_old_audio_with_placeholder() {
let mut ctx = AgentContext::from_content(
"sys",
vec![
ContentBlock::Text {
text: "listen to this".into(),
},
ContentBlock::Audio {
format: "ogg".into(),
data: "audiodata1".into(),
},
],
vec![],
);
ctx.add_assistant_message(Message::assistant("I heard it."));
ctx.messages.push(Message {
role: Role::User,
content: vec![ContentBlock::Audio {
format: "mp3".into(),
data: "audiodata2".into(),
}],
});
ctx.evict_media();
assert_eq!(
ctx.messages[0].content[1],
ContentBlock::Text {
text: "[audio previously sent]".into()
}
);
assert!(matches!(
&ctx.messages[2].content[0],
ContentBlock::Audio { format, .. } if format == "mp3"
));
}
// evict_media leaves the message count unchanged when there is no media.
#[test]
fn evict_media_noop_when_no_media() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("reply"));
let msg_count = ctx.message_count();
ctx.evict_media();
assert_eq!(ctx.message_count(), msg_count);
}
// Injecting into an empty message list leaves it empty.
#[test]
fn inject_summary_empty_messages_is_noop() {
let mut messages = vec![];
inject_summary_into_messages(&mut messages, "task", "summary", 2);
assert!(messages.is_empty());
}
// With eight alternating messages the cut steps back from a tool-result message
// to the preceding assistant message, so roles still alternate.
#[test]
fn inject_summary_while_loop_steps_back_to_assistant() {
let mut messages = vec![
Message::user("original task"),
Message::assistant("a1"),
Message::tool_results(vec![ToolResult::success("c1", "r1")]),
Message::assistant("a2"),
Message::tool_results(vec![ToolResult::success("c2", "r2")]),
Message::assistant("a3"),
Message::tool_results(vec![ToolResult::success("c3", "r3")]),
Message::assistant("a4"),
];
inject_summary_into_messages(&mut messages, "original task", "summary", 2);
assert_eq!(messages[0].role, Role::User);
assert_eq!(messages[1].role, Role::Assistant);
for w in messages.windows(2) {
assert_ne!(w[0].role, w[1].role, "adjacent messages have same role");
}
}
// Five messages with keep_last_n = 2 leave two messages to compact.
#[test]
fn messages_to_be_compacted_returns_middle() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("a1"));
ctx.add_assistant_message(Message::assistant("a2"));
ctx.add_assistant_message(Message::assistant("a3"));
ctx.add_assistant_message(Message::assistant("a4"));
let compacted = ctx.messages_to_be_compacted(2);
assert_eq!(compacted.len(), 2);
}
// A history no longer than 1 + keep_last_n has nothing to compact.
#[test]
fn messages_to_be_compacted_empty_when_few_messages() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("a1"));
let compacted = ctx.messages_to_be_compacted(2);
assert!(compacted.is_empty());
}
// Only the "old" middle messages are returned, not the task or the recent tail.
#[test]
fn messages_to_be_compacted_excludes_first_and_last() {
let mut ctx = AgentContext::new("sys", "task", vec![]);
ctx.add_assistant_message(Message::assistant("old1"));
ctx.add_assistant_message(Message::assistant("old2"));
ctx.add_assistant_message(Message::assistant("recent1"));
ctx.add_assistant_message(Message::assistant("recent2"));
let compacted = ctx.messages_to_be_compacted(2);
for msg in compacted {
let text: String = msg
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect();
assert!(
text.starts_with("old"),
"compacted messages should be old ones, got: {text}"
);
}
}
}