use serde::{Deserialize, Serialize};
use super::models::{ModelId, lookup};
/// Token-budget split for one model's context window.
///
/// `available()` is `total_tokens` minus the two reserves (saturating);
/// `mem_share` is the fraction of that remainder granted to recalled
/// memories — see [`ContextBudget::memory_budget`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ContextBudget {
// Model whose window/reserves this budget was looked up for.
pub model: ModelId,
// Full context window size, in tokens.
pub total_tokens: u32,
// Tokens set aside for the system prompt.
pub system_reserve: u32,
// Tokens set aside for the model's response.
pub response_reserve: u32,
// Fraction of the post-reserve window given to memories;
// `with_mem_share` clamps it to [0.0, 1.0].
pub mem_share: f32,
}
impl ContextBudget {
    /// Builds the budget for `model` from the static model table,
    /// with a default `mem_share` of 0.45.
    pub fn for_model(model: ModelId) -> Self {
        let window = lookup(model);
        Self {
            model,
            total_tokens: window.total_tokens,
            system_reserve: window.system_reserve,
            response_reserve: window.response_reserve,
            mem_share: 0.45,
        }
    }

    /// Returns `self` with `mem_share` replaced by `share` clamped to
    /// `[0.0, 1.0]`.
    ///
    /// FIX: a NaN `share` is now ignored (the current value is kept).
    /// Previously NaN survived `f32::clamp`, propagated into
    /// `memory_budget`, and the `as u32` cast silently collapsed the
    /// whole memory budget to 0.
    pub fn with_mem_share(mut self, share: f32) -> Self {
        if !share.is_nan() {
            self.mem_share = share.clamp(0.0, 1.0);
        }
        self
    }

    /// Tokens left after both reserves; saturates at 0 if the reserves
    /// exceed the window.
    pub fn available(&self) -> u32 {
        self.total_tokens
            .saturating_sub(self.system_reserve)
            .saturating_sub(self.response_reserve)
    }

    /// Portion of `available()` allotted to recalled memories
    /// (`available * mem_share`, truncated toward zero).
    pub fn memory_budget(&self) -> u32 {
        (self.available() as f32 * self.mem_share) as u32
    }
}
/// Mitigation applied when history plus recalled memories would not
/// fit in the usable context window.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FallbackStrategy {
// Chosen by `plan_recall` when history exceeds its share of the
// window; presumably drops oldest turns first — confirm at call site.
TruncateOldest,
// NOTE(review): not produced by `plan_recall`; presumably summarizes
// the oldest K items into a digest — verify against consumers.
SummarizeOldestK(u32),
// Chosen when the memory budget is large (> 100k tokens).
DropDuplicates,
// No mitigation needed.
None,
}
/// Output of [`plan_recall`]: how many memory chunks to fetch, their
/// size, the de-duplication threshold, and the overflow fallback.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecallPlan {
// Number of memory chunks to inject (`plan_recall` caps this at 256).
pub k: u32,
// Target size of each chunk, in tokens (256/512/1024 by window size).
pub chunk_tokens: u32,
// Similarity radius for de-duplication (0.88, or 0.92 on 800k+ windows).
pub dedup_radius: f32,
// Strategy applied if the plan still overflows the window.
pub fallback: FallbackStrategy,
}
/// A user query paired with a precomputed token-count estimate.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Query {
// Raw query text.
pub text: String,
// Estimated token footprint of `text` — TODO confirm which estimator
// produces this; `plan_recall` currently does not consume it.
pub estimated_tokens: u32,
}
/// Derives a memory-recall plan from the budget and the current
/// history size.
///
/// Picks a fallback strategy (truncate history if it overflows its
/// share, dedup when the memory budget is very large), sizes chunks
/// by total window, and computes `k` so that `k * chunk_tokens` stays
/// within 70% of the memory budget.
///
/// `query` is accepted for API stability but not yet consulted.
pub fn plan_recall(b: &ContextBudget, history_tokens: u32, query: &Query) -> RecallPlan {
    let avail = b.available();
    let mem_budget = b.memory_budget();
    // Whatever the memory share does not claim belongs to history.
    let history_share = avail.saturating_sub(mem_budget);
    let fallback = if history_tokens > history_share {
        FallbackStrategy::TruncateOldest
    } else if mem_budget > 100_000 {
        FallbackStrategy::DropDuplicates
    } else {
        FallbackStrategy::None
    };
    // Larger windows tolerate coarser chunks.
    let chunk_tokens = if b.total_tokens >= 800_000 {
        1024
    } else if b.total_tokens >= 200_000 {
        512
    } else {
        256
    };
    // Keep 30% headroom inside the memory budget.
    let usable = (mem_budget as f32 * 0.7) as u32;
    // FIX: the old `clamp(1, 256)` forced k >= 1 even when `usable`
    // was smaller than a single chunk, so a tiny memory budget still
    // injected a whole chunk and could overflow the budget. `min`
    // lets k drop to 0. (`chunk_tokens` is always 256/512/1024 above,
    // so the division cannot panic; the old `== 0` guard was dead.)
    let k = (usable / chunk_tokens).min(256);
    let dedup_radius = if b.total_tokens >= 800_000 { 0.92 } else { 0.88 };
    let _ = query; // reserved for future per-query sizing
    RecallPlan {
        k,
        chunk_tokens,
        dedup_radius,
        fallback,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway query carrying the given token estimate.
    fn mk_query(tokens: u32) -> Query {
        Query {
            text: "test".into(),
            estimated_tokens: tokens,
        }
    }

    #[test]
    fn deepseek_v4_yields_high_k_and_fits_under_mem_share() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV4_1m);
        let plan = plan_recall(&budget, 200_000, &mk_query(64));
        assert!(
            plan.k >= 64,
            "expected k>=64 for 1M context, got {}",
            plan.k
        );
        let injected = plan.k * plan.chunk_tokens;
        assert!(
            injected as f32 <= budget.memory_budget() as f32 * 0.8,
            "plan injects {injected} but mem_budget is {}",
            budget.memory_budget()
        );
    }

    #[test]
    fn small_window_drops_to_smaller_chunks() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV3_128k);
        assert!(plan_recall(&budget, 8_000, &mk_query(64)).chunk_tokens <= 512);
    }

    #[test]
    fn budget_does_not_overflow_total() {
        for (m, _) in super::super::models::MODEL_TABLE {
            let budget = ContextBudget::for_model(*m);
            let plan = plan_recall(&budget, 0, &mk_query(0));
            let injected = plan.k * plan.chunk_tokens;
            let total = budget.system_reserve + budget.response_reserve + injected;
            assert!(
                total <= budget.total_tokens,
                "model {} overflows: total {} > {}",
                m.as_str(),
                total,
                budget.total_tokens
            );
        }
    }

    #[test]
    fn truncate_oldest_kicks_in_when_history_overflows() {
        let budget = ContextBudget::for_model(ModelId::Gpt5_1_128k);
        let overflowing = budget.available() + 10_000;
        let plan = plan_recall(&budget, overflowing, &mk_query(1));
        assert_eq!(plan.fallback, FallbackStrategy::TruncateOldest);
    }

    #[test]
    fn dedup_radius_is_tighter_on_large_windows() {
        let small_budget = ContextBudget::for_model(ModelId::Gpt5_1_128k);
        let huge_budget = ContextBudget::for_model(ModelId::Gemini2_5Pro2m);
        let small = plan_recall(&small_budget, 1000, &mk_query(1));
        let huge = plan_recall(&huge_budget, 1000, &mk_query(1));
        assert!(huge.dedup_radius >= small.dedup_radius);
    }

    #[test]
    fn mem_share_is_clampable() {
        let budget = ContextBudget::for_model(ModelId::Claude3_7Sonnet1m).with_mem_share(2.0);
        assert!(budget.mem_share <= 1.0);
    }
}