1use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::PathBuf;
16
/// Pattern name carried by the fallback limit returned from [`default_model_limit`].
pub const DEFAULT_MODEL_PATTERN: &str = "default";

/// Context-window size, in tokens, of the fallback limit.
pub const DEFAULT_MAX_CONTEXT_TOKENS: u32 = 200_000;

/// Output-token cap of the fallback limit.
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 64_000;

/// Lower bound for the safety margin that [`ModelLimit::get_safety_margin`] derives when a
/// limit has no explicit `safety_margin`.
pub const DEFAULT_SAFETY_MARGIN: u32 = 1000;
31
32pub fn default_model_limit() -> ModelLimit {
34 builtin_limit(
35 DEFAULT_MODEL_PATTERN,
36 DEFAULT_MAX_CONTEXT_TOKENS,
37 DEFAULT_MAX_OUTPUT_TOKENS,
38 )
39}
40
41pub fn is_default_limit(limit: &ModelLimit) -> bool {
48 limit.max_context_tokens == DEFAULT_MAX_CONTEXT_TOKENS
49 && limit.max_output_tokens == Some(DEFAULT_MAX_OUTPUT_TOKENS)
50 && limit.safety_margin.is_none()
51}
52
/// Token limits associated with a model name or name pattern.
///
/// Serialized to and from JSON via serde; the two optional fields may be omitted from the
/// input, in which case they deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Model name or pattern this limit applies to. [`ModelLimitsRegistry::get`] matches it
    /// exactly first and by substring containment otherwise.
    pub model_pattern: String,
    /// Size of the model's context window, in tokens.
    pub max_context_tokens: u32,
    /// Explicit output-token cap. When `None`, [`ModelLimit::get_max_output_tokens`]
    /// derives a value from `max_context_tokens`.
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Explicit safety margin. When `None`, [`ModelLimit::get_safety_margin`] derives a
    /// value from `max_context_tokens`.
    #[serde(default)]
    pub safety_margin: Option<u32>,
}
67
68impl ModelLimit {
69 pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
71 Self {
72 model_pattern: model_pattern.into(),
73 max_context_tokens,
74 max_output_tokens: None,
75 safety_margin: None,
76 }
77 }
78
79 pub fn get_max_output_tokens(&self) -> u32 {
81 self.max_output_tokens
82 .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
83 }
84
85 pub fn get_safety_margin(&self) -> u32 {
87 self.safety_margin
88 .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
89 }
90}
91
92fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
93 let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
94 limit.max_output_tokens = Some(max_output_tokens);
95 limit
96}
97
/// In-memory collection of user-supplied [`ModelLimit`] overrides, optionally backed by a
/// JSON file on disk.
#[derive(Debug, Clone)]
pub struct ModelLimitsRegistry {
    /// Overrides keyed by their `model_pattern`; inserting a limit with an existing
    /// pattern replaces the previous entry.
    user_limits: HashMap<String, ModelLimit>,
    /// Location of the JSON file used by `load_user_config` / `save_user_config`.
    /// When `None`, the path from `get_default_config_path` is used instead.
    config_path: Option<PathBuf>,
}
106
107impl ModelLimitsRegistry {
108 pub fn new() -> Self {
110 Self {
111 user_limits: HashMap::new(),
112 config_path: None,
113 }
114 }
115
116 pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
118 Self {
119 user_limits: HashMap::new(),
120 config_path: Some(path.into()),
121 }
122 }
123
124 pub async fn load_user_config(&mut self) -> std::io::Result<()> {
128 let path = self
129 .config_path
130 .clone()
131 .unwrap_or_else(get_default_config_path);
132
133 if !path.exists() {
134 return Ok(());
135 }
136
137 let content = tokio::fs::read_to_string(&path).await?;
138 let limits: Vec<ModelLimit> = serde_json::from_str(&content)
139 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
140
141 for limit in limits {
142 self.user_limits.insert(limit.model_pattern.clone(), limit);
143 }
144
145 tracing::info!(
146 "Loaded {} user model limits from {:?}",
147 self.user_limits.len(),
148 path
149 );
150 Ok(())
151 }
152
153 pub fn add_limit(&mut self, limit: ModelLimit) {
155 self.user_limits.insert(limit.model_pattern.clone(), limit);
156 }
157
158 pub fn get(&self, model: &str) -> Option<ModelLimit> {
169 if let Some(limit) = self.user_limits.get(model) {
171 return Some(limit.clone());
172 }
173
174 self.user_limits
178 .iter()
179 .filter(|(pattern, _)| model.contains(pattern.as_str()) || pattern.contains(model))
180 .max_by_key(|(pattern, _)| pattern.len())
181 .map(|(_, limit)| limit.clone())
182 }
183
184 pub fn get_or_default(&self, model: &str) -> ModelLimit {
186 self.get(model).unwrap_or_else(default_model_limit)
187 }
188
189 pub async fn save_user_config(&self) -> std::io::Result<()> {
191 let path = self
192 .config_path
193 .clone()
194 .unwrap_or_else(get_default_config_path);
195
196 if let Some(parent) = path.parent() {
198 tokio::fs::create_dir_all(parent).await?;
199 }
200
201 let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
202 let content = serde_json::to_string_pretty(&limits)
203 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
204 tokio::fs::write(&path, content).await?;
205
206 Ok(())
207 }
208
209 pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
211 self.user_limits.values().collect()
212 }
213}
214
215impl Default for ModelLimitsRegistry {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221pub fn get_default_config_path() -> PathBuf {
225 bamboo_infrastructure::paths::bamboo_dir().join("model_limits.json")
226}
227
228pub fn load_model_limits_from_unified_config(
235 config: &bamboo_infrastructure::Config,
236) -> Result<Option<Vec<ModelLimit>>, String> {
237 let Some(raw_limits) = config.extra.get("model_limits") else {
238 return Ok(None);
239 };
240
241 if raw_limits.is_null() {
242 return Ok(Some(Vec::new()));
243 }
244
245 match raw_limits {
246 Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
247 .map(Some)
248 .map_err(|error| format!("invalid config.model_limits format: {error}")),
249 _ => Err("invalid config.model_limits format: expected array".to_string()),
250 }
251}
252
253pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
257 let registry = ModelLimitsRegistry::default();
258 let limit = registry.get_or_default(model);
259
260 crate::TokenBudget {
261 max_context_tokens: limit.max_context_tokens,
262 max_output_tokens: limit.get_max_output_tokens(),
263 strategy,
264 safety_margin: limit.get_safety_margin(),
265 compression_trigger_percent: 85, compression_target_percent: 45,
267 working_reserve_tokens: 50_000,
268 fallback_trigger_percent: 75,
269 prompt_cache_min_tool_output_chars: 1_200,
270 prompt_cache_head_chars: 280,
271 prompt_cache_tail_chars: 180,
272 prompt_cache_recent_user_turns: 2,
273 prompt_cache_recent_tool_chains: 2,
274 max_tool_output_tokens: 0,
275 }
276}
277
// Unit tests for the default limit, derived getters, registry lookup, and both config loaders.
#[cfg(test)]
mod tests {
    use super::*;

    /// The fallback limit carries the default pattern, context window and output cap.
    #[test]
    fn default_limit_is_200k_64k() {
        let limit = default_model_limit();
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    /// `is_default_limit` ignores the pattern and compares only the numeric fields.
    #[test]
    fn is_default_limit_detects_no_op_overrides() {
        // Same numbers as the default under a different pattern: still "default".
        let mut noop = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        noop.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(is_default_limit(&noop));

        assert!(is_default_limit(&default_model_limit()));

        // A different context window is a real override.
        let mut smaller = ModelLimit::new("gpt-4o", 128_000);
        smaller.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(!is_default_limit(&smaller));

        // Any explicit safety margin is a real override.
        let mut custom_margin = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        custom_margin.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        custom_margin.safety_margin = Some(500);
        assert!(!is_default_limit(&custom_margin));
    }

    /// An empty registry has nothing to match, so `get` yields `None` for every model.
    #[test]
    fn registry_returns_none_for_unknown_without_overrides() {
        let registry = ModelLimitsRegistry::new();
        assert!(registry.get("gpt-5.2-codex").is_none());
        assert!(registry.get("some-brand-new-model").is_none());
    }

    /// `get_or_default` falls back to the default limit when nothing matches.
    #[test]
    fn registry_returns_default_for_unknown() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    /// A pattern equal to the model name is returned directly.
    #[test]
    fn user_override_exact_match_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000));
        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    /// When several patterns are substrings of the model name, the longest one is chosen.
    #[test]
    fn user_override_partial_match_longest_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5", 111_000));
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 222_000));

        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should partial-match a user override");
        assert_eq!(limit.max_context_tokens, 222_000);
    }

    /// Without an explicit cap, a large context window hits the 4096 ceiling.
    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        assert_eq!(limit.get_max_output_tokens(), 4096);
    }

    /// An explicit cap is returned unchanged.
    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    /// Without an explicit cap, a small context window yields a quarter of it (8192 / 4).
    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    /// A config without a `model_limits` entry produces `Ok(None)`.
    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        let loaded = load_model_limits_from_unified_config(&config).expect("should parse");
        assert!(loaded.is_none());
    }

    /// A well-formed `model_limits` array is deserialized field by field.
    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!([
                {
                    "model_pattern": "gpt-5.2-codex",
                    "max_context_tokens": 64000,
                    "max_output_tokens": 2048,
                    "safety_margin": 512
                }
            ]),
        );

        let loaded = load_model_limits_from_unified_config(&config)
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    /// A non-array, non-null `model_limits` value is rejected with an "expected array" error.
    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let mut config =
            bamboo_infrastructure::Config::from_data_dir(Some(temp_dir.path().to_path_buf()));
        config.extra.insert(
            "model_limits".to_string(),
            serde_json::json!({"unexpected": true}),
        );

        let error = load_model_limits_from_unified_config(&config).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    /// The derived safety margin is 1% of the context window with a floor of 1000.
    #[test]
    fn safety_margin_scales_with_context_window() {
        // 8_192 / 100 = 81, which is below the floor.
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        // 200_000 / 100 = 2000.
        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        // 1_050_000 / 100 = 10_500.
        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        // An explicit margin wins, even when it is below the floor.
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }

    /// Overrides seeded in a JSON file are loaded and used for lookup; other models still
    /// fall back to the default limit.
    #[tokio::test]
    async fn persisted_overrides_drive_runtime_resolution() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000,"max_output_tokens":16384}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        let gpt4o = registry.get("gpt-4o").expect("override present");
        assert_eq!(gpt4o.max_context_tokens, 128_000);
        assert_eq!(gpt4o.get_max_output_tokens(), 16_384);

        let unknown = registry.get_or_default("brand-new-frontier-model");
        assert_eq!(unknown.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(unknown.max_context_tokens, 200_000);
        assert_eq!(unknown.get_max_output_tokens(), 64_000);
    }

    /// `create_budget_for_model` yields the default limits for an arbitrary model name.
    #[test]
    fn create_budget_for_model_uses_global_default_for_any_model() {
        let budget = create_budget_for_model("anything-at-all", crate::BudgetStrategy::default());
        assert_eq!(budget.max_context_tokens, 200_000);
        assert_eq!(budget.max_output_tokens, 64_000);
    }
}