mnemo_core/budget/
planner.rs1use serde::{Deserialize, Serialize};
4
5use super::models::{ModelId, lookup};
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
11pub struct ContextBudget {
12 pub model: ModelId,
13 pub total_tokens: u32,
14 pub system_reserve: u32,
15 pub response_reserve: u32,
16 pub mem_share: f32,
21}
22
23impl ContextBudget {
24 pub fn for_model(model: ModelId) -> Self {
25 let w = lookup(model);
26 Self {
27 model,
28 total_tokens: w.total_tokens,
29 system_reserve: w.system_reserve,
30 response_reserve: w.response_reserve,
31 mem_share: 0.45,
32 }
33 }
34
35 pub fn with_mem_share(mut self, share: f32) -> Self {
36 self.mem_share = share.clamp(0.0, 1.0);
37 self
38 }
39
40 pub fn available(&self) -> u32 {
42 self.total_tokens
43 .saturating_sub(self.system_reserve)
44 .saturating_sub(self.response_reserve)
45 }
46
47 pub fn memory_budget(&self) -> u32 {
48 (self.available() as f32 * self.mem_share) as u32
49 }
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum FallbackStrategy {
55 TruncateOldest,
57 SummarizeOldestK(u32),
59 DropDuplicates,
61 None,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct RecallPlan {
67 pub k: u32,
68 pub chunk_tokens: u32,
71 pub dedup_radius: f32,
74 pub fallback: FallbackStrategy,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
78pub struct Query {
79 pub text: String,
80 pub estimated_tokens: u32,
81}
82
83pub fn plan_recall(b: &ContextBudget, history_tokens: u32, query: &Query) -> RecallPlan {
84 let avail = b.available();
85 let mem_budget = b.memory_budget();
86
87 let history_share = avail.saturating_sub(mem_budget);
90 let fallback = if history_tokens > history_share {
91 FallbackStrategy::TruncateOldest
92 } else if mem_budget > 100_000 {
93 FallbackStrategy::DropDuplicates
95 } else {
96 FallbackStrategy::None
97 };
98
99 let chunk_tokens = if b.total_tokens >= 800_000 {
103 1024
104 } else if b.total_tokens >= 200_000 {
105 512
106 } else {
107 256
108 };
109
110 let usable = (mem_budget as f32 * 0.7) as u32;
114 let k = if chunk_tokens == 0 {
115 0
116 } else {
117 (usable / chunk_tokens).clamp(1, 256)
118 };
119
120 let dedup_radius = if b.total_tokens >= 800_000 {
123 0.92
124 } else {
125 0.88
126 };
127
128 let _ = query; RecallPlan {
131 k,
132 chunk_tokens,
133 dedup_radius,
134 fallback,
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway query carrying only a token estimate.
    fn query_with_tokens(estimated_tokens: u32) -> Query {
        Query {
            text: String::from("test"),
            estimated_tokens,
        }
    }

    /// A 1M-token window should recall many chunks while staying well inside
    /// the memory budget.
    #[test]
    fn deepseek_v4_yields_high_k_and_fits_under_mem_share() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV4_1m);
        let plan = plan_recall(&budget, 200_000, &query_with_tokens(64));
        assert!(
            plan.k >= 64,
            "expected k>=64 for 1M context, got {}",
            plan.k
        );

        let injected = plan.k * plan.chunk_tokens;
        let ceiling = budget.memory_budget() as f32 * 0.8;
        assert!(
            injected as f32 <= ceiling,
            "plan injects {injected} but mem_budget is {}",
            budget.memory_budget()
        );
    }

    /// A 128k window must not get the largest chunk size.
    #[test]
    fn small_window_drops_to_smaller_chunks() {
        let budget = ContextBudget::for_model(ModelId::DeepSeekV3_128k);
        let plan = plan_recall(&budget, 8_000, &query_with_tokens(64));
        assert!(plan.chunk_tokens <= 512);
    }

    /// For every model in the table, reserves plus injected memory must fit
    /// inside the window.
    #[test]
    fn budget_does_not_overflow_total() {
        for (model_id, _) in super::super::models::MODEL_TABLE {
            let budget = ContextBudget::for_model(*model_id);
            let plan = plan_recall(&budget, 0, &query_with_tokens(0));
            let injected = plan.k * plan.chunk_tokens;
            let total = budget.system_reserve + budget.response_reserve + injected;
            assert!(
                total <= budget.total_tokens,
                "model {} overflows: total {} > {}",
                model_id.as_str(),
                total,
                budget.total_tokens
            );
        }
    }

    /// History larger than everything available must trigger truncation.
    #[test]
    fn truncate_oldest_kicks_in_when_history_overflows() {
        let budget = ContextBudget::for_model(ModelId::Gpt5_1_128k);
        let oversized_history = budget.available() + 10_000;
        let plan = plan_recall(&budget, oversized_history, &query_with_tokens(1));
        assert_eq!(plan.fallback, FallbackStrategy::TruncateOldest);
    }

    /// The dedup radius of a huge window is at least that of a small one.
    #[test]
    fn dedup_radius_is_tighter_on_large_windows() {
        let small_budget = ContextBudget::for_model(ModelId::Gpt5_1_128k);
        let huge_budget = ContextBudget::for_model(ModelId::Gemini2_5Pro2m);
        let probe = query_with_tokens(1);

        let small = plan_recall(&small_budget, 1000, &probe);
        let huge = plan_recall(&huge_budget, 1000, &probe);
        assert!(huge.dedup_radius >= small.dedup_radius);
    }

    /// Out-of-range shares are pulled back to at most 1.0.
    #[test]
    fn mem_share_is_clampable() {
        let budget = ContextBudget::for_model(ModelId::Claude3_7Sonnet1m).with_mem_share(2.0);
        assert!(budget.mem_share <= 1.0);
    }
}