Skip to main content

bamboo_compression/
limits.rs

//! Model context window limits registry.
//!
//! There is intentionally **no** built-in per-model table. Real per-model
//! values come from two sources, in priority order:
//!   1. provider runtime metadata (e.g. Copilot reports real context/output
//!      limits),
//!   2. user overrides persisted in `model_limits.json`.
//!
//! This module implements only the user-override side (source 2). Anything
//! without a match falls back to a single global default
//! (`DEFAULT_MAX_CONTEXT_TOKENS` / `DEFAULT_MAX_OUTPUT_TOKENS`), which keeps the
//! registry from going stale as models churn — see `token_budget.rs`.

12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::PathBuf;
16
/// Pattern name reserved for the single global fallback limit.
pub const DEFAULT_MODEL_PATTERN: &str = "default";

/// Context window (in tokens) used for every model that has neither provider
/// metadata nor a user override. 200K sits in the mainstream range of current
/// frontier models.
pub const DEFAULT_MAX_CONTEXT_TOKENS: u32 = 200_000;

/// Output-token ceiling of the global fallback limit.
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 64_000;

/// Lower bound of the safety margin reserved for token-counting error; the
/// effective margin grows with the context window (see
/// [`ModelLimit::get_safety_margin`]).
pub const DEFAULT_SAFETY_MARGIN: u32 = 1_000;

32/// Build the single global default limit (`200K` context / `64K` output).
33pub fn default_model_limit() -> ModelLimit {
34    builtin_limit(
35        DEFAULT_MODEL_PATTERN,
36        DEFAULT_MAX_CONTEXT_TOKENS,
37        DEFAULT_MAX_OUTPUT_TOKENS,
38    )
39}
40
41/// Whether a user override is a no-op — identical to the global default, so it
42/// carries no information and need not be persisted (diff-only storage).
43///
44/// The `model_pattern` is irrelevant: any model pinned to exactly the default
45/// context/output with no explicit safety margin resolves to the same budget
46/// as having no override at all.
47pub fn is_default_limit(limit: &ModelLimit) -> bool {
48    limit.max_context_tokens == DEFAULT_MAX_CONTEXT_TOKENS
49        && limit.max_output_tokens == Some(DEFAULT_MAX_OUTPUT_TOKENS)
50        && limit.safety_margin.is_none()
51}
52
53/// Model limit configuration (user-overridable).
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ModelLimit {
56    /// Model identifier (partial match supported, e.g., "gpt-4" matches "gpt-4o")
57    pub model_pattern: String,
58    /// Maximum context window size in tokens
59    pub max_context_tokens: u32,
60    /// Maximum output tokens (defaults to min(4096, max_context / 4))
61    #[serde(default)]
62    pub max_output_tokens: Option<u32>,
63    /// Safety margin for token counting (defaults to 1000)
64    #[serde(default)]
65    pub safety_margin: Option<u32>,
66}
67
68impl ModelLimit {
69    /// Create a new model limit with defaults.
70    pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
71        Self {
72            model_pattern: model_pattern.into(),
73            max_context_tokens,
74            max_output_tokens: None,
75            safety_margin: None,
76        }
77    }
78
79    /// Get max output tokens with default calculation.
80    pub fn get_max_output_tokens(&self) -> u32 {
81        self.max_output_tokens
82            .unwrap_or_else(|| (self.max_context_tokens / 4).min(4096))
83    }
84
85    /// Get safety margin, scaling proportionally with context window.
86    pub fn get_safety_margin(&self) -> u32 {
87        self.safety_margin
88            .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
89    }
90}
91
92fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
93    let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
94    limit.max_output_tokens = Some(max_output_tokens);
95    limit
96}
97
98/// Registry for model limits with built-in defaults and user overrides.
99#[derive(Debug, Clone)]
100pub struct ModelLimitsRegistry {
101    /// User-provided overrides (higher priority than built-in)
102    user_limits: HashMap<String, ModelLimit>,
103    /// Default path for user configuration file
104    config_path: Option<PathBuf>,
105}
106
107impl ModelLimitsRegistry {
108    /// Create a new registry with built-in defaults only.
109    pub fn new() -> Self {
110        Self {
111            user_limits: HashMap::new(),
112            config_path: None,
113        }
114    }
115
116    /// Create a registry with a specific config file path.
117    pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
118        Self {
119            user_limits: HashMap::new(),
120            config_path: Some(path.into()),
121        }
122    }
123
124    /// Load user overrides from the registry's configured path.
125    ///
126    /// No-op when the registry was created without a `config_path` (use
127    /// [`Self::with_config_path`] — e.g. with [`get_default_config_path`] — to
128    /// point it at `{bamboo_data_dir}/model_limits.json`).
129    pub async fn load_user_config(&mut self) -> std::io::Result<()> {
130        let Some(path) = self.config_path.clone() else {
131            return Ok(());
132        };
133
134        if !path.exists() {
135            return Ok(());
136        }
137
138        let content = tokio::fs::read_to_string(&path).await?;
139        let limits: Vec<ModelLimit> = serde_json::from_str(&content)
140            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
141
142        for limit in limits {
143            self.user_limits.insert(limit.model_pattern.clone(), limit);
144        }
145
146        tracing::info!(
147            "Loaded {} user model limits from {:?}",
148            self.user_limits.len(),
149            path
150        );
151        Ok(())
152    }
153
154    /// Add a user limit override.
155    pub fn add_limit(&mut self, limit: ModelLimit) {
156        self.user_limits.insert(limit.model_pattern.clone(), limit);
157    }
158
159    /// Get limit for a model, with user overrides taking priority.
160    ///
161    /// Returns `None` if no matching limit is found.
162    ///
163    /// # Matching Strategy
164    /// 1. Exact match (highest priority)
165    /// 2. Model contains pattern (e.g., "gpt-4o-mini" contains "gpt-4o")
166    /// 3. Pattern contains model (e.g., "gpt-4" contains "gpt")
167    ///
168    /// For partial matches, the longest (most specific) pattern wins.
169    pub fn get(&self, model: &str) -> Option<ModelLimit> {
170        // Exact user override match (highest priority).
171        if let Some(limit) = self.user_limits.get(model) {
172            return Some(limit.clone());
173        }
174
175        // Best partial match among user overrides. Longer (more specific)
176        // patterns win for deterministic selection. There is no built-in table;
177        // a miss returns None and the caller falls back to the global default.
178        self.user_limits
179            .iter()
180            .filter(|(pattern, _)| model.contains(pattern.as_str()) || pattern.contains(model))
181            .max_by_key(|(pattern, _)| pattern.len())
182            .map(|(_, limit)| limit.clone())
183    }
184
185    /// Get limit for a model with fallback to default.
186    pub fn get_or_default(&self, model: &str) -> ModelLimit {
187        self.get(model).unwrap_or_else(default_model_limit)
188    }
189
190    /// Save current user limits to the configured file.
191    ///
192    /// No-op when the registry has no `config_path`.
193    pub async fn save_user_config(&self) -> std::io::Result<()> {
194        let Some(path) = self.config_path.clone() else {
195            return Ok(());
196        };
197
198        // Ensure parent directory exists
199        if let Some(parent) = path.parent() {
200            tokio::fs::create_dir_all(parent).await?;
201        }
202
203        let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
204        let content = serde_json::to_string_pretty(&limits)
205            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
206        tokio::fs::write(&path, content).await?;
207
208        Ok(())
209    }
210
211    /// List all user-defined limits.
212    pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
213        self.user_limits.values().collect()
214    }
215}
216
217impl Default for ModelLimitsRegistry {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
/// Resolve the user model-limits file inside the Bamboo data directory, i.e.
/// `{bamboo_dir}/model_limits.json`.
///
/// The caller passes the directory in (rather than this crate looking it up)
/// so this crate stays free of any infrastructure/filesystem-config dependency.
pub fn get_default_config_path(bamboo_dir: &std::path::Path) -> PathBuf {
    let mut config_path = bamboo_dir.to_path_buf();
    config_path.push("model_limits.json");
    config_path
}

233/// Load user model limits from the unified `config.json` `model_limits` value.
234///
235/// The caller extracts the raw `model_limits` JSON value (e.g. from
236/// `config.extra.get("model_limits")`) and passes it here, keeping this crate
237/// independent of the concrete `Config` type.
238///
239/// Returns:
240/// - `Ok(None)` when `model_limits` is absent.
241/// - `Ok(Some(vec))` when present and valid (including empty array).
242/// - `Err(...)` when present but not a valid `Vec<ModelLimit>`.
243pub fn load_model_limits_from_unified_config(
244    raw_limits: Option<&Value>,
245) -> Result<Option<Vec<ModelLimit>>, String> {
246    let Some(raw_limits) = raw_limits else {
247        return Ok(None);
248    };
249
250    if raw_limits.is_null() {
251        return Ok(Some(Vec::new()));
252    }
253
254    match raw_limits {
255        Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
256            .map(Some)
257            .map_err(|error| format!("invalid config.model_limits format: {error}")),
258        _ => Err("invalid config.model_limits format: expected array".to_string()),
259    }
260}
261
262/// Create a token budget for a specific model.
263///
264/// This is a convenience function that creates a budget with appropriate defaults.
265pub fn create_budget_for_model(model: &str, strategy: crate::BudgetStrategy) -> crate::TokenBudget {
266    let registry = ModelLimitsRegistry::default();
267    let limit = registry.get_or_default(model);
268
269    crate::TokenBudget {
270        max_context_tokens: limit.max_context_tokens,
271        max_output_tokens: limit.get_max_output_tokens(),
272        strategy,
273        safety_margin: limit.get_safety_margin(),
274        compression_trigger_percent: 85, // legacy — only used when working_reserve_tokens == 0
275        compression_target_percent: 45,
276        working_reserve_tokens: 50_000,
277        fallback_trigger_percent: 75,
278        prompt_cache_min_tool_output_chars: 1_200,
279        prompt_cache_head_chars: 280,
280        prompt_cache_tail_chars: 180,
281        prompt_cache_recent_user_turns: 2,
282        prompt_cache_recent_tool_chains: 2,
283        max_tool_output_tokens: 0,
284    }
285}
286
// Unit tests for the global default limit, the no-op override check,
// user-override matching, derived output/safety-margin values, the unified
// `config.json` loader, and `create_budget_for_model`. One async test seeds a
// real `model_limits.json` in a temporary directory.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_limit_is_200k_64k() {
        let limit = default_model_limit();
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn is_default_limit_detects_no_op_overrides() {
        // A row pinned to exactly the default values (any pattern) is a no-op.
        let mut noop = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        noop.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(is_default_limit(&noop));

        // The synthesized global default is itself a no-op override.
        assert!(is_default_limit(&default_model_limit()));

        // A different context window is a real override.
        let mut smaller = ModelLimit::new("gpt-4o", 128_000);
        smaller.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(!is_default_limit(&smaller));

        // An explicit safety margin is a real override even at default size.
        let mut custom_margin = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        custom_margin.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        custom_margin.safety_margin = Some(500);
        assert!(!is_default_limit(&custom_margin));
    }

    #[test]
    fn registry_returns_none_for_unknown_without_overrides() {
        // No built-in table: an unknown model with no user override has no match.
        let registry = ModelLimitsRegistry::new();
        assert!(registry.get("gpt-5.2-codex").is_none());
        assert!(registry.get("some-brand-new-model").is_none());
    }

    #[test]
    fn registry_returns_default_for_unknown() {
        // `get_or_default` substitutes the global default on a miss.
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 200_000);
        assert_eq!(limit.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn user_override_exact_match_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000)); // Override with smaller limit

        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    #[test]
    fn user_override_partial_match_longest_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5", 111_000));
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 222_000));

        // "gpt-5.2-codex-preview" contains both patterns; the longest wins.
        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should partial-match a user override");
        assert_eq!(limit.max_context_tokens, 222_000);
    }

    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        // Default is min(max_context / 4, 4096) = min(25000, 4096) = 4096
        assert_eq!(limit.get_max_output_tokens(), 4096);
    }

    #[test]
    fn model_limit_uses_custom_output_tokens() {
        // An explicit `max_output_tokens` bypasses the derived default.
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        // Default is min(8192 / 4, 4096) = 2048
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let loaded = load_model_limits_from_unified_config(None).expect("should parse");
        assert!(loaded.is_none());
    }

    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        // All four fields are set explicitly, so every value must round-trip.
        let raw = serde_json::json!([
            {
                "model_pattern": "gpt-5.2-codex",
                "max_context_tokens": 64000,
                "max_output_tokens": 2048,
                "safety_margin": 512
            }
        ]);

        let loaded = load_model_limits_from_unified_config(Some(&raw))
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        // A JSON object (rather than an array) is rejected with a shape error.
        let raw = serde_json::json!({"unexpected": true});
        let error = load_model_limits_from_unified_config(Some(&raw)).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    #[test]
    fn safety_margin_scales_with_context_window() {
        // Small context → floor at DEFAULT_SAFETY_MARGIN (1000)
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        // Medium context → proportional
        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        // Large context → proportional
        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        // Explicit override takes precedence
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }

    #[tokio::test]
    async fn persisted_overrides_drive_runtime_resolution() {
        // Integration: a `model_limits.json` on disk → registry load → resolve.
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000,"max_output_tokens":16384}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        // Persisted override is applied at runtime.
        let gpt4o = registry.get("gpt-4o").expect("override present");
        assert_eq!(gpt4o.max_context_tokens, 128_000);
        assert_eq!(gpt4o.get_max_output_tokens(), 16_384);

        // Unknown model with no override falls back to the single global default.
        let unknown = registry.get_or_default("brand-new-frontier-model");
        assert_eq!(unknown.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(unknown.max_context_tokens, 200_000);
        assert_eq!(unknown.get_max_output_tokens(), 64_000);
    }

    #[test]
    fn create_budget_for_model_uses_global_default_for_any_model() {
        // `create_budget_for_model` resolves against an empty registry, so any
        // model name yields the global default limit.
        let budget = create_budget_for_model("anything-at-all", crate::BudgetStrategy::default());
        assert_eq!(budget.max_context_tokens, 200_000);
        assert_eq!(budget.max_output_tokens, 64_000);
    }
}