use std::sync::Arc;
use crate::engine::token_estimate::TokenEstimator;
use crate::session::Session;
use crate::working_set::WorkingSet;
/// Layer a context source belongs to.
///
/// The derived ordering (`StaticPrefix < SemiStatic < Volatile`) is the primary
/// sort key `ContextCompiler` applies before rendering, so lower layers are
/// emitted first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ContextLayer {
    /// Emitted first; the budget solver in this file never drops or shrinks it.
    StaticPrefix = 0,
    /// Emitted after the static prefix; under budget pressure a source with an
    /// `Elastic` policy may be accounted at its `min`.
    SemiStatic = 1,
    /// Emitted last and dropped first when the compiled context is over budget.
    Volatile = 2,
}
/// Budget policy attached to a context source.
///
/// In this file only the `min` of [`BudgetPolicy::Elastic`] is consulted (by the
/// budget solver, for `SemiStatic` sources); the other values are carried but
/// not read here.
#[derive(Debug, Clone, Copy)]
pub enum BudgetPolicy {
    /// A fixed allowance (presumably in tokens — confirm against callers).
    Fixed(u32),
    /// A fractional allowance (presumably a share of the total budget — confirm
    /// against callers).
    Fraction(f32),
    /// An allowance that may be reduced down to `min` when the context is over
    /// budget.
    Elastic { min: u32, max: u32 },
}
/// Identifier of a context source, wrapping a `'static` string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SourceId(pub &'static str);

impl std::fmt::Display for SourceId {
    /// Writes the wrapped string verbatim; formatter flags such as width are
    /// not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(name) = self;
        f.write_str(name)
    }
}
/// One block of output produced by a context source's render closure.
#[derive(Debug, Clone)]
pub struct RenderedBlock {
    /// Rendered text; empty for placeholder blocks.
    pub text: String,
    /// Token count recorded when the block was built. `ContextCompiler` uses
    /// this value directly only when `text` is empty and re-estimates from the
    /// text otherwise.
    pub token_count: u32,
    /// Optional layer override; every constructor in this file leaves it `None`
    /// and nothing in this file reads it.
    pub layer_override: Option<ContextLayer>,
}
impl RenderedBlock {
    /// Builds a block from `text`, estimating its token count with
    /// `estimate_text_tokens`.
    #[must_use]
    pub fn new(text: impl Into<String>) -> Self {
        let owned: String = text.into();
        let estimated = crate::engine::token_estimate::estimate_text_tokens(&owned) as u32;
        Self::with_tokens(owned, estimated)
    }

    /// Builds a block from `text` with a caller-supplied `token_count`; no
    /// estimation is performed.
    #[must_use]
    pub fn with_tokens(text: impl Into<String>, token_count: u32) -> Self {
        let text = text.into();
        Self {
            text,
            token_count,
            layer_override: None,
        }
    }

    /// Builds a text-less block that only carries `token_count`.
    #[must_use]
    pub fn placeholder(token_count: u32) -> Self {
        Self::with_tokens(String::new(), token_count)
    }
}
/// Per-source accounting entry recorded in a `CompiledContext`.
#[derive(Debug, Clone)]
pub struct SourceContribution {
    /// Source this entry describes.
    pub source_id: SourceId,
    /// Tokens accounted to the source in the compiled context.
    pub token_count: u32,
    /// Truncation flag for this source, set by the code that assembles the
    /// compiled context.
    pub was_truncated: bool,
}
/// Borrowed view of session state that is passed to every render closure.
pub struct ContextProjection<'a> {
    /// The session being projected.
    pub session: &'a Session,
    /// Working set; `from_session` points this at `session.working_set`.
    pub working_set: &'a WorkingSet,
    /// Caller-supplied step index.
    pub step_idx: u32,
    /// Whether a compaction summary prompt is present; `from_session` derives
    /// it from `session.compaction_summary_prompt`.
    pub has_compaction_summary: bool,
    /// Number of cycle briefings; `from_session` takes it from
    /// `session.cycle_briefings.len()`.
    pub cycle_briefing_count: usize,
}
impl<'a> ContextProjection<'a> {
    /// Builds a projection of `session` for step `step_idx`.
    ///
    /// The derived fields are snapshots read from the session at call time.
    #[must_use]
    pub fn from_session(session: &'a Session, step_idx: u32) -> Self {
        let working_set = &session.working_set;
        let has_compaction_summary = session.compaction_summary_prompt.is_some();
        let cycle_briefing_count = session.cycle_briefings.len();
        Self {
            session,
            working_set,
            step_idx,
            has_compaction_summary,
            cycle_briefing_count,
        }
    }
}
/// Shared, thread-safe (`Send + Sync`) render callback: given a projection it
/// returns the blocks a source contributes.
pub type RenderFn = Arc<dyn Fn(&ContextProjection<'_>) -> Vec<RenderedBlock> + Send + Sync>;
/// A registered producer of context blocks.
pub struct ContextSource {
    /// Identifier reported in contributions and matched against budget overrides.
    pub id: SourceId,
    /// Layer the source belongs to; primary sort key (ascending).
    pub layer: ContextLayer,
    /// Secondary sort key: within a layer, higher values are emitted first.
    pub priority: u8,
    /// Budget policy. The solver in this file reads it only for `SemiStatic`
    /// sources, and only the `min` of an `Elastic` policy.
    pub budget: BudgetPolicy,
    /// Callback that renders the source's blocks from a projection.
    pub render: RenderFn,
}
impl std::fmt::Debug for ContextSource {
    /// Prints `id`, `layer` and `priority`; the remaining fields (including the
    /// render closure, which has no `Debug` form) are elided with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id,
            layer,
            priority,
            ..
        } = self;
        let mut builder = f.debug_struct("ContextSource");
        builder.field("id", id);
        builder.field("layer", layer);
        builder.field("priority", priority);
        builder.finish_non_exhaustive()
    }
}
/// Result of compiling all registered sources into one context.
#[derive(Debug, Clone, Default)]
pub struct CompiledContext {
    /// Blocks of every included source, in compiled (layer, then priority) order.
    pub blocks: Vec<RenderedBlock>,
    /// One accounting entry per included source, in the same order.
    pub contributions: Vec<SourceContribution>,
    /// Sum of the contributions' token counts (accumulated with saturation).
    pub total_tokens: u32,
    /// Aggregate truncation flag; `false` by default.
    pub any_truncated: bool,
    /// Whether the budget solver had to evict or shrink a source to fit.
    pub overflow_recovered: bool,
}
/// Per-call replacement of a source's registered budget policy.
///
/// `ContextCompiler::compile_with_budget_override` looks overrides up by
/// `source_id`; the first match wins.
#[derive(Debug, Clone)]
pub struct BudgetOverride {
    /// Source whose policy is replaced.
    pub source_id: SourceId,
    /// Policy used instead of the one the source was registered with.
    pub new_budget: BudgetPolicy,
}
/// Failure modes of budget-constrained compilation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
    /// The context still exceeds `budget` after eviction; `total_tokens` is the
    /// amount that remained.
    Overflow { total_tokens: u32, budget: u32 },
    /// The compiler has no registered sources.
    EmptySources,
}

impl std::fmt::Display for CompileError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::EmptySources => f.write_str("no context sources registered"),
            Self::Overflow {
                total_tokens,
                budget,
            } => write!(
                f,
                "context overflow: {total_tokens} tokens > {budget} token budget after eviction"
            ),
        }
    }
}

impl std::error::Error for CompileError {}
/// Selector for the context compiler implementation; `V2` is the only mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContextCompilerMode {
    /// The sole (and default) mode.
    #[default]
    V2,
}

impl ContextCompilerMode {
    /// Returns `V2` for every input; the argument is ignored.
    #[must_use]
    pub fn parse(_value: Option<&str>) -> Self {
        Self::V2
    }

    /// Returns the mode's string name.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::V2 => "v2",
        }
    }
}
/// Collects context sources and compiles their rendered blocks into a single
/// context.
#[derive(Debug, Default)]
pub struct ContextCompiler {
    /// Registered sources, in registration order (sorted per compile call).
    sources: Vec<ContextSource>,
    /// Estimator used to count tokens of blocks that carry text.
    pub token_estimator: TokenEstimator,
}
impl ContextCompiler {
    /// Creates a compiler with no registered sources.
    #[must_use]
    pub fn new() -> Self {
        Self {
            sources: Vec::new(),
            token_estimator: TokenEstimator,
        }
    }

    /// Appends `source` and returns the compiler, builder style.
    #[must_use]
    pub fn register(mut self, source: ContextSource) -> Self {
        self.sources.push(source);
        self
    }

    /// Number of registered sources.
    #[must_use]
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }

    /// Renders every source and returns `(id, blocks)` pairs in compiled order
    /// (layer ascending, then priority descending).
    #[must_use]
    pub fn render_all<'p>(
        &self,
        projection: &ContextProjection<'p>,
    ) -> Vec<(&SourceId, Vec<RenderedBlock>)> {
        self.sorted_sources()
            .into_iter()
            .map(|src| (&src.id, (src.render)(projection)))
            .collect()
    }

    /// Compiles all sources without any budget: nothing is evicted or shrunk.
    #[must_use]
    pub fn compile<'p>(&self, projection: &ContextProjection<'p>) -> CompiledContext {
        let sorted = self.sorted_sources();
        let (blocks_per_source, source_tokens) = self.render_and_measure(&sorted, projection);
        self.assemble_compiled(&sorted, &blocks_per_source, &source_tokens, &[], &[], false)
    }

    /// Compiles all sources so that the accounted token total fits `budget`.
    ///
    /// When the rendered total exceeds `budget`, two passes run until it fits:
    ///
    /// 1. `Volatile` sources are dropped, walking the compiled order backwards
    ///    (lowest priority first).
    /// 2. `SemiStatic` sources whose policy is `Elastic` are accounted at their
    ///    `min`. The policy is taken from the first matching entry in
    ///    `overrides`, falling back to the source's registered policy.
    ///
    /// A shrunk source is reported with `was_truncated = true` and sets
    /// `any_truncated` on the result.
    ///
    /// NOTE(review): shrinking only changes the accounted token count; the
    /// source's blocks are still emitted unmodified. Trimming the text itself
    /// would need support from the token estimator.
    ///
    /// # Errors
    ///
    /// * [`CompileError::EmptySources`] when no source is registered.
    /// * [`CompileError::Overflow`] when the total still exceeds `budget` after
    ///   both passes; `total_tokens` is clamped to `u32::MAX`.
    pub fn compile_with_budget_override<'p>(
        &self,
        projection: &ContextProjection<'p>,
        budget: u32,
        overrides: &[BudgetOverride],
    ) -> Result<CompiledContext, CompileError> {
        if self.sources.is_empty() {
            return Err(CompileError::EmptySources);
        }
        let sorted = self.sorted_sources();
        let (blocks_per_source, source_tokens) = self.render_and_measure(&sorted, projection);

        // Solver accounting is done in u64: a sum of u32 counts cannot overflow
        // it, so large placeholder counts neither panic (debug) nor wrap (release).
        let budget_wide = u64::from(budget);
        let mut remaining: u64 = source_tokens.iter().map(|&t| u64::from(t)).sum();
        if remaining <= budget_wide {
            return Ok(self.assemble_compiled(
                &sorted,
                &blocks_per_source,
                &source_tokens,
                &[],
                &[],
                false,
            ));
        }

        let mut enabled = vec![true; sorted.len()];
        let mut shrunk = vec![false; sorted.len()];
        let mut effective_tokens = source_tokens.clone();
        let mut overflow_recovered = false;

        // Pass 1: drop volatile sources, lowest priority first.
        for i in (0..sorted.len()).rev() {
            if remaining <= budget_wide {
                break;
            }
            if sorted[i].layer == ContextLayer::Volatile {
                remaining = remaining.saturating_sub(u64::from(effective_tokens[i]));
                enabled[i] = false;
                overflow_recovered = true;
            }
        }

        // Pass 2: account elastic semi-static sources at their minimum.
        for i in (0..sorted.len()).rev() {
            if remaining <= budget_wide {
                break;
            }
            if sorted[i].layer != ContextLayer::SemiStatic {
                continue;
            }
            let policy = overrides
                .iter()
                .find(|o| o.source_id == sorted[i].id)
                .map_or(sorted[i].budget, |o| o.new_budget);
            if let BudgetPolicy::Elastic { min, .. } = policy
                && effective_tokens[i] > min
            {
                let freed = effective_tokens[i] - min;
                remaining = remaining.saturating_sub(u64::from(freed));
                effective_tokens[i] = min;
                shrunk[i] = true;
                overflow_recovered = true;
            }
        }

        if remaining > budget_wide {
            return Err(CompileError::Overflow {
                total_tokens: u32::try_from(remaining).unwrap_or(u32::MAX),
                budget,
            });
        }
        Ok(self.assemble_compiled(
            &sorted,
            &blocks_per_source,
            &effective_tokens,
            &enabled,
            &shrunk,
            overflow_recovered,
        ))
    }

    /// Returns the registered sources ordered by layer (ascending) and then by
    /// priority (descending).
    fn sorted_sources(&self) -> Vec<&ContextSource> {
        let mut sorted: Vec<&ContextSource> = self.sources.iter().collect();
        sorted.sort_unstable_by(|a, b| a.layer.cmp(&b.layer).then(b.priority.cmp(&a.priority)));
        sorted
    }

    /// Renders each source in `sorted` and returns its blocks together with the
    /// per-source token totals (parallel vectors).
    fn render_and_measure(
        &self,
        sorted: &[&ContextSource],
        projection: &ContextProjection<'_>,
    ) -> (Vec<Vec<RenderedBlock>>, Vec<u32>) {
        let blocks_per_source: Vec<Vec<RenderedBlock>> =
            sorted.iter().map(|s| (s.render)(projection)).collect();
        let source_tokens = blocks_per_source
            .iter()
            .map(|blocks| self.measure_blocks(blocks))
            .collect();
        (blocks_per_source, source_tokens)
    }

    /// Token total of one source's blocks, saturating at `u32::MAX`.
    ///
    /// Blocks with text are re-estimated; text-less (placeholder) blocks
    /// contribute their recorded `token_count`.
    fn measure_blocks(&self, blocks: &[RenderedBlock]) -> u32 {
        blocks
            .iter()
            .map(|b| {
                if b.text.is_empty() {
                    b.token_count
                } else {
                    self.token_estimator.estimate_text(&b.text) as u32
                }
            })
            .fold(0u32, u32::saturating_add)
    }

    /// Builds the `CompiledContext` from parallel per-source slices.
    ///
    /// `enabled` selects which sources are included; when its length does not
    /// match `sorted` (e.g. it is empty) every source is included. `truncated`
    /// marks sources to report as truncated; missing entries count as `false`.
    fn assemble_compiled(
        &self,
        sorted: &[&ContextSource],
        blocks_per_source: &[Vec<RenderedBlock>],
        source_tokens: &[u32],
        enabled: &[bool],
        truncated: &[bool],
        overflow_recovered: bool,
    ) -> CompiledContext {
        let mut out = CompiledContext {
            overflow_recovered,
            ..Default::default()
        };
        let all_enabled = enabled.len() != sorted.len();
        for (idx, src) in sorted.iter().enumerate() {
            if !(all_enabled || enabled[idx]) {
                continue;
            }
            let tok = source_tokens[idx];
            let was_truncated = truncated.get(idx).copied().unwrap_or(false);
            out.total_tokens = out.total_tokens.saturating_add(tok);
            out.any_truncated |= was_truncated;
            out.contributions.push(SourceContribution {
                source_id: src.id.clone(),
                token_count: tok,
                was_truncated,
            });
            out.blocks.extend_from_slice(&blocks_per_source[idx]);
        }
        out
    }
}
// Unit tests for source ordering, token aggregation and the budget solver.
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::*;
/// Builds a source that renders a single `block:{id}` text block and carries an
/// `Elastic { min: 0, max: 4096 }` budget.
fn dummy_source(id: &'static str, layer: ContextLayer, priority: u8) -> ContextSource {
ContextSource {
id: SourceId(id),
layer,
priority,
budget: BudgetPolicy::Elastic { min: 0, max: 4096 },
render: Arc::new(move |_| vec![RenderedBlock::new(format!("block:{id}"))]),
}
}
/// Creates a session with fixed placeholder arguments for use in projections.
fn test_session() -> crate::session::Session {
crate::session::Session::new(
"test-model".into(),
PathBuf::from("/tmp"),
false,
false,
PathBuf::from("/tmp/notes.txt"),
PathBuf::from("/tmp/mcp.json"),
)
}
/// Sources come back ordered by layer first, then by descending priority,
/// regardless of registration order.
#[test]
fn compiler_render_order_is_layer_then_priority_desc() {
let compiler = ContextCompiler::new()
.register(dummy_source("volatile.low", ContextLayer::Volatile, 10))
.register(dummy_source("static.high", ContextLayer::StaticPrefix, 255))
.register(dummy_source("semi.mid", ContextLayer::SemiStatic, 128))
.register(dummy_source("static.low", ContextLayer::StaticPrefix, 10));
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let rendered = compiler.render_all(&proj);
let ids: Vec<&str> = rendered.iter().map(|(id, _)| id.0).collect();
assert_eq!(
ids,
["static.high", "static.low", "semi.mid", "volatile.low"]
);
}
/// `total_tokens` equals the sum of the per-source contributions.
#[test]
fn compiled_context_aggregates_token_counts() {
let compiler = ContextCompiler::new()
.register(dummy_source("a", ContextLayer::StaticPrefix, 100))
.register(dummy_source("b", ContextLayer::Volatile, 50));
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let ctx = compiler.compile(&proj);
assert_eq!(ctx.contributions.len(), 2);
assert_eq!(
ctx.total_tokens,
ctx.contributions.iter().map(|c| c.token_count).sum::<u32>()
);
}
/// Every input, including `None` and unknown names, parses to `V2`.
#[test]
fn context_compiler_mode_parse_all_map_to_v2() {
for input in [
Some("v2"),
Some("legacy"),
Some("shadow"),
None,
Some("unknown"),
] {
assert_eq!(
ContextCompilerMode::parse(input),
ContextCompilerMode::V2,
"input={input:?}"
);
}
assert_eq!(ContextCompilerMode::V2.as_str(), "v2");
}
/// Builds a source that renders `text` as a single block with the given budget
/// policy.
fn budget_source(
id: &'static str,
layer: ContextLayer,
priority: u8,
text: &'static str,
budget: BudgetPolicy,
) -> ContextSource {
ContextSource {
id: SourceId(id),
layer,
priority,
budget,
render: Arc::new(move |_| vec![RenderedBlock::new(text)]),
}
}
/// With a budget far above the rendered total nothing is evicted.
#[test]
fn budget_solve_no_eviction_when_under_budget() {
let compiler = ContextCompiler::new()
.register(budget_source(
"static",
ContextLayer::StaticPrefix,
255,
"system prompt text",
BudgetPolicy::Fixed(200),
))
.register(budget_source(
"volatile",
ContextLayer::Volatile,
100,
"turn meta",
BudgetPolicy::Elastic { min: 0, max: 500 },
));
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let result = compiler.compile_with_budget_override(&proj, 100_000, &[]);
assert!(result.is_ok(), "should succeed with huge budget");
let ctx = result.unwrap();
assert!(!ctx.overflow_recovered, "no eviction expected");
assert_eq!(ctx.contributions.len(), 2);
}
/// Volatile sources are dropped while static and semi-static sources survive.
#[test]
fn budget_solve_evicts_volatile_before_semistatic() {
let compiler = ContextCompiler::new()
.register(budget_source(
"static",
ContextLayer::StaticPrefix,
255,
"system",
BudgetPolicy::Fixed(100),
))
.register(budget_source(
"semi",
ContextLayer::SemiStatic,
180,
"compaction summary",
BudgetPolicy::Elastic { min: 0, max: 500 },
))
.register(budget_source(
"volatile.hi",
ContextLayer::Volatile,
160,
"turn meta",
BudgetPolicy::Elastic { min: 0, max: 500 },
))
.register(budget_source(
"volatile.lo",
ContextLayer::Volatile,
100,
"steer text",
BudgetPolicy::Elastic { min: 0, max: 500 },
));
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
// Measure the unconstrained totals first so the budget can be derived from them.
let unconstrained = compiler.compile(&proj);
let total = unconstrained.total_tokens;
let lo_tokens = unconstrained
.contributions
.iter()
.find(|c| c.source_id.0 == "volatile.lo")
.map(|c| c.token_count)
.unwrap_or(0);
// One token below `total - lo_tokens`: dropping volatile.lo alone still leaves
// the context over budget, so the solver continues with the next volatile source.
let budget = total.saturating_sub(lo_tokens).saturating_sub(1);
let result = compiler.compile_with_budget_override(&proj, budget, &[]);
assert!(result.is_ok(), "should succeed by evicting volatile.lo");
let ctx = result.unwrap();
assert!(
ctx.overflow_recovered,
"eviction should set overflow_recovered"
);
assert!(
ctx.contributions
.iter()
.all(|c| c.source_id.0 != "volatile.lo"),
"volatile.lo should be evicted"
);
assert!(ctx.contributions.iter().any(|c| c.source_id.0 == "static"));
assert!(ctx.contributions.iter().any(|c| c.source_id.0 == "semi"));
}
/// A lone static-prefix source that exceeds the budget yields `Overflow`,
/// because the solver has nothing it may evict.
#[test]
fn budget_solve_returns_overflow_when_fixed_sources_exceed_budget() {
let compiler = ContextCompiler::new().register(budget_source(
"static",
ContextLayer::StaticPrefix,
255,
"this is a fixed system prompt that cannot be evicted",
BudgetPolicy::Fixed(10_000),
));
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let result = compiler.compile_with_budget_override(&proj, 1, &[]);
assert!(matches!(result, Err(CompileError::Overflow { .. })));
}
/// A compiler without sources reports `EmptySources`.
#[test]
fn budget_solve_empty_compiler_returns_error() {
let compiler = ContextCompiler::new();
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let result = compiler.compile_with_budget_override(&proj, 1000, &[]);
assert!(
matches!(result, Err(CompileError::EmptySources)),
"empty compiler should return EmptySources"
);
}
/// Times one budget solve over eight sources spread across the three layers.
// NOTE(review): this asserts on wall-clock time, so it may be flaky on slow or
// heavily loaded machines; the solve result itself is discarded.
#[test]
fn budget_solve_latency_under_10ms() {
use std::time::Instant;
let mut compiler = ContextCompiler::new();
for i in 0..8u8 {
let layer = match i % 3 {
0 => ContextLayer::StaticPrefix,
1 => ContextLayer::SemiStatic,
_ => ContextLayer::Volatile,
};
compiler = compiler.register(budget_source(
// `SourceId` needs a `&'static str`, so the formatted id is leaked.
Box::leak(format!("source-{i}").into_boxed_str()),
layer,
255u8.saturating_sub(i * 10),
"representative content block with typical size payload",
BudgetPolicy::Elastic { min: 10, max: 4096 },
));
}
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
// A budget of 1 forces the solver through both of its passes.
let budget = 1; let start = Instant::now();
let _ = compiler.compile_with_budget_override(&proj, budget, &[]);
let elapsed = start.elapsed();
assert!(
elapsed.as_millis() < 10,
"compile_with_budget_override took {}ms (must be < 10ms)",
elapsed.as_millis()
);
}
/// A text-less placeholder block contributes its recorded `token_count`, and
/// the solver evicts the volatile source rather than the static placeholder.
#[test]
fn placeholder_block_token_count_used_by_budget_solver() {
let compiler = ContextCompiler::new()
.register(ContextSource {
id: SourceId("tools.catalog"),
layer: ContextLayer::StaticPrefix,
priority: 254,
budget: BudgetPolicy::Fixed(500),
render: Arc::new(|_| vec![RenderedBlock::placeholder(500)]),
})
.register(ContextSource {
id: SourceId("volatile.low"),
layer: ContextLayer::Volatile,
priority: 10,
budget: BudgetPolicy::Elastic { min: 0, max: 200 },
render: Arc::new(|_| {
// Presumably sized to estimate at roughly 200 tokens — TODO confirm against
// the estimator; the test needs placeholder + this block to exceed 600.
vec![RenderedBlock::new("x".repeat(200 * 3))] }),
});
let session = test_session();
let proj = ContextProjection::from_session(&session, 0);
let ctx = compiler.compile(&proj);
let catalog = ctx
.contributions
.iter()
.find(|c| c.source_id.0 == "tools.catalog");
assert_eq!(
catalog.map(|c| c.token_count).unwrap_or(0),
500,
"placeholder token_count must be 500, not 0"
);
let result = compiler.compile_with_budget_override(&proj, 600, &[]);
assert!(
result.is_ok(),
"budget solve should succeed: evict volatile to fit under 600"
);
let ctx2 = result.unwrap();
let has_volatile = ctx2
.contributions
.iter()
.any(|c| c.source_id.0 == "volatile.low");
assert!(
!has_volatile,
"volatile.low should be evicted when total > 600"
);
let has_catalog = ctx2
.contributions
.iter()
.any(|c| c.source_id.0 == "tools.catalog");
assert!(has_catalog, "tools.catalog (Fixed) must survive eviction");
}
}