use crate::error::{CoreError, CoreResult};
use crate::messages::{Message, Role};
use crate::template::ChatTemplate;
use crate::tools::ToolSchema;
/// How a conversation that no longer fits the context window should be shrunk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PruneStrategy {
    /// Keep the most recent turns that fit and drop older ones.
    SlidingWindow,
    /// Summarize older turns instead of dropping them.
    /// NOTE(review): `plan_prune` in this module does not read the strategy —
    /// confirm where this variant is handled.
    Summarize,
}

/// Token budget settings used when assembling a prompt.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Total token window, shared by the prompt and the generated response.
    pub max_context_tokens: u32,
    /// Tokens reserved for the response; subtracted from `max_context_tokens`
    /// to obtain the prompt budget.
    pub max_response_tokens: u32,
    /// Strategy to apply when the conversation exceeds the prompt budget.
    pub prune_strategy: PruneStrategy,
}

impl Default for ContextConfig {
    /// A 4096-token window with 2048 tokens held back for the response,
    /// pruned with a sliding window.
    fn default() -> Self {
        ContextConfig {
            prune_strategy: PruneStrategy::SlidingWindow,
            max_response_tokens: 2048,
            max_context_tokens: 4096,
        }
    }
}
/// A fully formatted prompt plus bookkeeping about what went into it.
#[derive(Debug, Clone)]
pub struct PreparedContext {
    /// The formatted prompt text (system block, kept messages, tools).
    pub prompt: String,
    /// Token count of `prompt` as measured by the caller's token counter.
    pub token_count: u32,
    /// Number of individual messages included in `prompt`.
    pub messages_included: u32,
    /// Number of input messages left out of `prompt`.
    pub messages_pruned: u32,
}

/// Which turns of a conversation to keep and which to drop.
///
/// Each range indexes into the message slice the plan was computed from.
/// Within each vector the ranges are in ascending order.
#[derive(Debug, Clone)]
pub struct PrunePlan {
    /// Turns that fit within the budget.
    pub kept: Vec<std::ops::Range<usize>>,
    /// Turns that were pruned.
    pub dropped: Vec<std::ops::Range<usize>>,
}
/// Splits `messages` into turns.
///
/// A turn is a contiguous index range that begins at a `Role::User` message
/// and runs up to (but not including) the next `Role::User` message.
///
/// The first turn always starts at index 0. Any messages that precede the
/// first user message (for example an opening assistant greeting) therefore
/// belong to that turn instead of falling outside every turn and silently
/// disappearing from the prompt.
///
/// Returns an empty vector for empty input. Otherwise the ranges are
/// ascending, non-overlapping, and together cover every index of `messages`.
fn group_into_turns(messages: &[Message]) -> Vec<std::ops::Range<usize>> {
    let mut turns = Vec::new();
    if messages.is_empty() {
        return turns;
    }
    let mut turn_start = 0;
    // Index 0 already opens the first turn, whatever its role.
    for (i, msg) in messages.iter().enumerate().skip(1) {
        if msg.role == Role::User {
            turns.push(turn_start..i);
            turn_start = i;
        }
    }
    turns.push(turn_start..messages.len());
    turns
}
/// Decides which conversation turns fit into the model's context window.
///
/// Messages are grouped into turns by `group_into_turns`. The prompt budget is
/// `max_context_tokens - max_response_tokens` (saturating at 0). From that
/// budget the function subtracts the token cost of the formatted system block
/// (system prompt + tools) and of the template's assistant prefix. Each message
/// costs `token_counter(&template.format_message(msg))`.
///
/// Turns are admitted in this order:
///
/// 1. the latest turn (mandatory);
/// 2. every pinned turn, oldest first (mandatory);
/// 3. unpinned turns walking backwards from the newest, stopping at the first
///    one that does not fit. This keeps the retained unpinned history
///    contiguous back from the latest turn.
///
/// `config.prune_strategy` is not read here; this always behaves as a
/// sliding window.
///
/// # Errors
///
/// Returns `CoreError::Context` in any of these cases:
/// - the fixed overhead alone uses up the budget;
/// - the latest turn does not fit next to the fixed overhead;
/// - the pinned turns do not fit alongside the latest turn and the fixed
///   overhead.
pub fn plan_prune(
    template: &dyn ChatTemplate,
    system_prompt: &str,
    messages: &[Message],
    tools: &[ToolSchema],
    config: &ContextConfig,
    token_counter: &dyn Fn(&str) -> u32,
) -> CoreResult<PrunePlan> {
    // Tokens left for the prompt once the response reservation is set aside.
    let available = config
        .max_context_tokens
        .saturating_sub(config.max_response_tokens);
    let system_block = template.format_system(system_prompt, tools);
    // Saturating add: an overflowing overhead becomes u32::MAX and is rejected
    // by the check below, instead of panicking (debug) or wrapping (release).
    let fixed_overhead = token_counter(&system_block)
        .saturating_add(token_counter(template.assistant_prefix()));
    if fixed_overhead >= available {
        return Err(CoreError::Context(format!(
            "System prompt and tools ({fixed_overhead} tokens) exceed \
             available context budget ({available} tokens)"
        )));
    }
    let mut budget = available - fixed_overhead;

    let turns = group_into_turns(messages);
    if turns.is_empty() {
        return Ok(PrunePlan {
            kept: vec![],
            dropped: vec![],
        });
    }

    // Saturating sums, so a pathological turn cannot wrap around to a small
    // cost and be admitted by mistake.
    let turn_costs: Vec<u32> = turns
        .iter()
        .map(|range| {
            messages[range.clone()]
                .iter()
                .map(|msg| token_counter(&template.format_message(msg)))
                .fold(0u32, u32::saturating_add)
        })
        .collect();
    let turn_pinned: Vec<bool> = turns
        .iter()
        .map(|range| messages[range.clone()].iter().any(|m| m.pinned))
        .collect();

    // `turns` is non-empty here, so this cannot underflow.
    let last = turns.len() - 1;
    let mut keep = vec![false; turns.len()];

    // 1. The latest turn is mandatory.
    if turn_costs[last] > budget {
        return Err(CoreError::Context(format!(
            "Latest message ({} tokens) plus system prompt \
             ({fixed_overhead} tokens) exceeds context budget ({available} tokens). \
             Clear the conversation or increase context size.",
            turn_costs[last]
        )));
    }
    budget -= turn_costs[last];
    keep[last] = true;

    // 2. Pinned turns are mandatory too; admit them oldest first.
    for i in 0..last {
        if !turn_pinned[i] {
            continue;
        }
        if turn_costs[i] > budget {
            // Report the real demand: every pinned turn other than the latest,
            // plus the latest turn and the fixed overhead (both already charged).
            let pinned_total = (0..last)
                .filter(|&j| turn_pinned[j])
                .map(|j| turn_costs[j])
                .fold(0u32, u32::saturating_add);
            return Err(CoreError::Context(format!(
                "Pinned messages ({pinned_total} tokens), the latest message ({} tokens) \
                 and the system prompt ({fixed_overhead} tokens) together exceed the \
                 available context budget ({available} tokens). Unpin some messages or \
                 increase context size.",
                turn_costs[last]
            )));
        }
        budget -= turn_costs[i];
        keep[i] = true;
    }

    // 3. Fill the remaining budget with unpinned turns, newest first. Stop at
    //    the first one that does not fit rather than skipping it, so no gap
    //    opens up in the recent unpinned history.
    for i in (0..last).rev() {
        if keep[i] {
            continue;
        }
        if turn_costs[i] > budget {
            break;
        }
        budget -= turn_costs[i];
        keep[i] = true;
    }

    let mut kept = Vec::new();
    let mut dropped = Vec::new();
    for (range, is_kept) in turns.into_iter().zip(keep) {
        if is_kept {
            kept.push(range);
        } else {
            dropped.push(range);
        }
    }
    Ok(PrunePlan { kept, dropped })
}
pub fn prepare_context(
template: &dyn ChatTemplate,
system_prompt: &str,
messages: &[Message],
tools: &[ToolSchema],
config: &ContextConfig,
token_counter: &dyn Fn(&str) -> u32,
) -> CoreResult<PreparedContext> {
let plan = plan_prune(
template,
system_prompt,
messages,
tools,
config,
token_counter,
)?;
let kept: Vec<Message> = plan
.kept
.iter()
.flat_map(|range| messages[range.clone()].iter().cloned())
.collect();
let kept_count = kept.len() as u32;
let pruned = messages.len() as u32 - kept_count;
let prompt = template.format(system_prompt, &kept, tools);
let token_count = token_counter(&prompt);
Ok(PreparedContext {
prompt,
token_count,
messages_included: kept_count,
messages_pruned: pruned,
})
}