mnemo_core/budget/
planner.rs1use serde::{Deserialize, Serialize};
4
5use super::models::{ModelId, lookup};
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
11pub struct ContextBudget {
12 pub model: ModelId,
13 pub total_tokens: u32,
14 pub system_reserve: u32,
15 pub response_reserve: u32,
16 pub mem_share: f32,
21}
22
23impl ContextBudget {
24 pub fn for_model(model: ModelId) -> Self {
25 let w = lookup(model);
26 Self {
27 model,
28 total_tokens: w.total_tokens,
29 system_reserve: w.system_reserve,
30 response_reserve: w.response_reserve,
31 mem_share: 0.45,
32 }
33 }
34
35 pub fn with_mem_share(mut self, share: f32) -> Self {
36 self.mem_share = share.clamp(0.0, 1.0);
37 self
38 }
39
40 pub fn available(&self) -> u32 {
42 self.total_tokens
43 .saturating_sub(self.system_reserve)
44 .saturating_sub(self.response_reserve)
45 }
46
47 pub fn memory_budget(&self) -> u32 {
48 (self.available() as f32 * self.mem_share) as u32
49 }
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum FallbackStrategy {
55 TruncateOldest,
57 SummarizeOldestK(u32),
59 DropDuplicates,
61 None,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct RecallPlan {
67 pub k: u32,
68 pub chunk_tokens: u32,
71 pub dedup_radius: f32,
74 pub fallback: FallbackStrategy,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
78pub struct Query {
79 pub text: String,
80 pub estimated_tokens: u32,
81}
82
83pub fn plan_recall(b: &ContextBudget, history_tokens: u32, query: &Query) -> RecallPlan {
84 let avail = b.available();
85 let mem_budget = b.memory_budget();
86
87 let history_share = avail.saturating_sub(mem_budget);
90 let fallback = if history_tokens > history_share {
91 FallbackStrategy::TruncateOldest
92 } else if mem_budget > 100_000 {
93 FallbackStrategy::DropDuplicates
95 } else {
96 FallbackStrategy::None
97 };
98
99 let chunk_tokens = if b.total_tokens >= 800_000 {
103 1024
104 } else if b.total_tokens >= 200_000 {
105 512
106 } else {
107 256
108 };
109
110 let usable = (mem_budget as f32 * 0.7) as u32;
114 let k = usable
115 .checked_div(chunk_tokens)
116 .map_or(0, |q| q.clamp(1, 256));
117
118 let dedup_radius = if b.total_tokens >= 800_000 {
121 0.92
122 } else {
123 0.88
124 };
125
126 let _ = query; RecallPlan {
129 k,
130 chunk_tokens,
131 dedup_radius,
132 fallback,
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway query; only the token estimate varies between tests.
    fn query_of(tokens: u32) -> Query {
        Query {
            text: String::from("test"),
            estimated_tokens: tokens,
        }
    }

    /// A 1M-token window should recall many chunks yet stay within 80% of
    /// the memory budget.
    #[test]
    fn deepseek_v4_yields_high_k_and_fits_under_mem_share() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV4_1m);
        let plan = plan_recall(&budget, 200_000, &query_of(64));
        assert!(
            plan.k >= 64,
            "expected k>=64 for 1M context, got {}",
            plan.k
        );

        let injected = plan.k * plan.chunk_tokens;
        let ceiling = budget.memory_budget() as f32 * 0.8;
        assert!(
            injected as f32 <= ceiling,
            "plan injects {injected} but mem_budget is {}",
            budget.memory_budget()
        );
    }

    /// A 128k window must not pick the largest chunk size.
    #[test]
    fn small_window_drops_to_smaller_chunks() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV3_128k);
        let chunk_tokens = plan_recall(&budget, 8_000, &query_of(64)).chunk_tokens;
        assert!(chunk_tokens <= 512);
    }

    /// For every known model, reserves plus injected memory fit in the window.
    #[test]
    fn budget_does_not_overflow_total() {
        for (m, _) in super::super::models::MODEL_TABLE {
            let budget = ContextBudget::for_model(*m);
            let plan = plan_recall(&budget, 0, &query_of(0));
            let injected = plan.k * plan.chunk_tokens;
            let total = budget.system_reserve + budget.response_reserve + injected;
            assert!(
                total <= budget.total_tokens,
                "model {} overflows: total {} > {}",
                m.as_str(),
                total,
                budget.total_tokens
            );
        }
    }

    /// History larger than everything available must trigger truncation.
    #[test]
    fn truncate_oldest_kicks_in_when_history_overflows() {
        let budget = ContextBudget::for_model(ModelId::Gpt5_1_128k);
        let oversized_history = budget.available() + 10_000;
        let plan = plan_recall(&budget, oversized_history, &query_of(1));
        assert_eq!(plan.fallback, FallbackStrategy::TruncateOldest);
    }

    /// The dedup radius on a 2M window is at least the one on a 128k window.
    #[test]
    fn dedup_radius_is_tighter_on_large_windows() {
        let radius_for = |model: ModelId| {
            plan_recall(&ContextBudget::for_model(model), 1000, &query_of(1)).dedup_radius
        };
        let small = radius_for(ModelId::Gpt5_1_128k);
        let huge = radius_for(ModelId::Gemini2_5Pro2m);
        assert!(huge >= small);
    }

    /// Shares above 1.0 are pulled back into range.
    #[test]
    fn mem_share_is_clampable() {
        let budget = ContextBudget::for_model(ModelId::Claude3_7Sonnet1m).with_mem_share(2.0);
        assert!(budget.mem_share <= 1.0);
    }
}