1use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::path::PathBuf;
15
/// Pattern name carried by the fallback limit returned by `default_model_limit()`.
pub const DEFAULT_MODEL_PATTERN: &str = "default";

/// Context window, in tokens, used by the fallback limit.
pub const DEFAULT_MAX_CONTEXT_TOKENS: u32 = 200 * 1_000;

/// Explicit output-token cap set on the fallback limit.
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 64 * 1_000;

/// Floor, in tokens, applied to the safety margin derived from a context window.
pub const DEFAULT_SAFETY_MARGIN: u32 = 1_000;
30
31pub fn default_model_limit() -> ModelLimit {
33 builtin_limit(
34 DEFAULT_MODEL_PATTERN,
35 DEFAULT_MAX_CONTEXT_TOKENS,
36 DEFAULT_MAX_OUTPUT_TOKENS,
37 )
38}
39
40pub fn is_default_limit(limit: &ModelLimit) -> bool {
47 limit.max_context_tokens == DEFAULT_MAX_CONTEXT_TOKENS
48 && limit.max_output_tokens == Some(DEFAULT_MAX_OUTPUT_TOKENS)
49 && limit.safety_margin.is_none()
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ModelLimit {
55 pub model_pattern: String,
57 pub max_context_tokens: u32,
59 #[serde(default)]
61 pub max_output_tokens: Option<u32>,
62 #[serde(default)]
64 pub safety_margin: Option<u32>,
65}
66
67impl ModelLimit {
68 pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
70 Self {
71 model_pattern: model_pattern.into(),
72 max_context_tokens,
73 max_output_tokens: None,
74 safety_margin: None,
75 }
76 }
77
78 pub fn get_max_output_tokens(&self) -> u32 {
80 self.max_output_tokens
81 .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
82 }
83
84 pub fn get_safety_margin(&self) -> u32 {
86 self.safety_margin
87 .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
88 }
89}
90
91fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
92 let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
93 limit.max_output_tokens = Some(max_output_tokens);
94 limit
95}
96
97#[derive(Debug, Clone)]
99pub struct ModelLimitsRegistry {
100 user_limits: HashMap<String, ModelLimit>,
102 config_path: Option<PathBuf>,
104}
105
106impl ModelLimitsRegistry {
107 pub fn new() -> Self {
109 Self {
110 user_limits: HashMap::new(),
111 config_path: None,
112 }
113 }
114
115 pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
117 Self {
118 user_limits: HashMap::new(),
119 config_path: Some(path.into()),
120 }
121 }
122
123 pub async fn load_user_config(&mut self) -> std::io::Result<()> {
127 let path = self
128 .config_path
129 .clone()
130 .unwrap_or_else(get_default_config_path);
131
132 if !path.exists() {
133 return Ok(());
134 }
135
136 let content = tokio::fs::read_to_string(&path).await?;
137 let limits: Vec<ModelLimit> = serde_json::from_str(&content)
138 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
139
140 for limit in limits {
141 self.user_limits.insert(limit.model_pattern.clone(), limit);
142 }
143
144 tracing::info!(
145 "Loaded {} user model limits from {:?}",
146 self.user_limits.len(),
147 path
148 );
149 Ok(())
150 }
151
152 pub fn add_limit(&mut self, limit: ModelLimit) {
154 self.user_limits.insert(limit.model_pattern.clone(), limit);
155 }
156
157 pub fn get(&self, model: &str) -> Option<ModelLimit> {
168 if let Some(limit) = self.user_limits.get(model) {
170 return Some(limit.clone());
171 }
172
173 self.user_limits
177 .iter()
178 .filter(|(pattern, _)| model.contains(pattern.as_str()) || pattern.contains(model))
179 .max_by_key(|(pattern, _)| pattern.len())
180 .map(|(_, limit)| limit.clone())
181 }
182
183 pub fn get_or_default(&self, model: &str) -> ModelLimit {
185 self.get(model).unwrap_or_else(default_model_limit)
186 }
187
188 pub async fn save_user_config(&self) -> std::io::Result<()> {
190 let path = self
191 .config_path
192 .clone()
193 .unwrap_or_else(get_default_config_path);
194
195 if let Some(parent) = path.parent() {
197 tokio::fs::create_dir_all(parent).await?;
198 }
199
200 let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
201 let content = serde_json::to_string_pretty(&limits)
202 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
203 tokio::fs::write(&path, content).await?;
204
205 Ok(())
206 }
207
208 pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
210 self.user_limits.values().collect()
211 }
212}
213
214impl Default for ModelLimitsRegistry {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220pub fn get_default_config_path() -> PathBuf {
224 bamboo_infrastructure::paths::bamboo_dir().join("model_limits.json")
225}
226
227pub fn load_model_limits_from_unified_config(
234 config: &bamboo_infrastructure::Config,
235) -> Result<Option<Vec<ModelLimit>>, String> {
236 let Some(raw_limits) = config.extra.get("model_limits") else {
237 return Ok(None);
238 };
239
240 if raw_limits.is_null() {
241 return Ok(Some(Vec::new()));
242 }
243
244 match raw_limits {
245 Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
246 .map(Some)
247 .map_err(|error| format!("invalid config.model_limits format: {error}")),
248 _ => Err("invalid config.model_limits format: expected array".to_string()),
249 }
250}
251
252pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
256 let registry = ModelLimitsRegistry::default();
257 let limit = registry.get_or_default(model);
258
259 crate::TokenBudget {
260 max_context_tokens: limit.max_context_tokens,
261 max_output_tokens: limit.get_max_output_tokens(),
262 strategy,
263 safety_margin: limit.get_safety_margin(),
264 compression_trigger_percent: 85, compression_target_percent: 45,
266 working_reserve_tokens: 50_000,
267 fallback_trigger_percent: 75,
268 prompt_cache_min_tool_output_chars: 1_200,
269 prompt_cache_head_chars: 280,
270 prompt_cache_tail_chars: 180,
271 prompt_cache_recent_user_turns: 2,
272 prompt_cache_recent_tool_chains: 2,
273 max_tool_output_tokens: 0,
274 }
275}
276
#[cfg(test)]
mod tests {
    // Covers the global default limit, `is_default_limit`, registry lookup (exact and
    // partial matches), derived output/safety-margin values, the unified-config loader,
    // file-based overrides, and `create_budget_for_model`.
    use super::*;

    // The fallback limit uses the "default" pattern with 200k context and 64k output.
    #[test]
    fn default_limit_is_200k_64k() {
        let limit = default_model_limit();
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    // The pattern name is ignored; only context size, output cap and safety margin count.
    #[test]
    fn is_default_limit_detects_no_op_overrides() {
        let mut noop = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        noop.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(is_default_limit(&noop));

        assert!(is_default_limit(&default_model_limit()));

        // A different context window is a real override.
        let mut smaller = ModelLimit::new("gpt-4o", 128_000);
        smaller.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(!is_default_limit(&smaller));

        // Any explicit safety margin is a real override.
        let mut custom_margin = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        custom_margin.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        custom_margin.safety_margin = Some(500);
        assert!(!is_default_limit(&custom_margin));
    }

    // With no overrides registered, `get` returns `None` for any model name.
    #[test]
    fn registry_returns_none_for_unknown_without_overrides() {
        let registry = ModelLimitsRegistry::new();
        assert!(registry.get("gpt-5.2-codex").is_none());
        assert!(registry.get("some-brand-new-model").is_none());
    }

    // `get_or_default` falls back to `default_model_limit()`.
    #[test]
    fn registry_returns_default_for_unknown() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn user_override_exact_match_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000));
        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    // Both patterns are substrings of the model name; the longer one is selected.
    #[test]
    fn user_override_partial_match_longest_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5", 111_000));
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 222_000));

        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should partial-match a user override");
        assert_eq!(limit.max_context_tokens, 222_000);
    }

    // Derived cap: 100_000 / 4 = 25_000, clamped to 4096.
    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        assert_eq!(limit.get_max_output_tokens(), 4096);
    }

    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    // Derived cap: 8_192 / 4 = 2048, which is under the 4096 clamp.
    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    // A fresh config has no "model_limits" key in `extra`.
    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        let loaded = load_model_limits_from_unified_config(&config).expect("should parse");
        assert!(loaded.is_none());
    }

    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([
                {
                    "model_pattern": "gpt-5.2-codex",
                    "max_context_tokens": 64000,
                    "max_output_tokens": 2048,
                    "safety_margin": 512
                }
            ]),
        );

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    // A JSON object (not an array) is rejected with an "expected array" message.
    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!({"unexpected": true}),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    // Derived margin is context / 100, floored at DEFAULT_SAFETY_MARGIN (1000).
    // An explicit margin is returned as-is.
    #[test]
    fn safety_margin_scales_with_context_window() {
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }

    // Overrides read from a JSON file are used for matching models. Models without an
    // override still get the global default.
    #[tokio::test]
    async fn persisted_overrides_drive_runtime_resolution() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000,"max_output_tokens":16384}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        let gpt4o = registry.get("gpt-4o").expect("override present");
        assert_eq!(gpt4o.max_context_tokens, 128_000);
        assert_eq!(gpt4o.get_max_output_tokens(), 16_384);

        let unknown = registry.get_or_default("brand-new-frontier-model");
        assert_eq!(unknown.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(unknown.max_context_tokens, 200_000);
        assert_eq!(unknown.get_max_output_tokens(), 64_000);
    }

    // `create_budget_for_model` uses an empty registry, so any name yields the default limits.
    #[test]
    fn create_budget_for_model_uses_global_default_for_any_model() {
        let budget = create_budget_for_model("anything-at-all", crate::BudgetStrategy::default());
        assert_eq!(budget.max_context_tokens, 200_000);
        assert_eq!(budget.max_output_tokens, 64_000);
    }
}
476}