Skip to main content

bamboo_compression/
limits.rs

1//! Model context window limits registry.
2//!
3//! There is intentionally **no** built-in per-model table. Real per-model
4//! values come from two sources, in priority order:
5//!   1. provider runtime metadata (e.g. Copilot reports real context/output),
6//!   2. user overrides persisted in `model_limits.json`.
7//!
8//! Anything without a match falls back to a single global default
9//! (`DEFAULT_MAX_CONTEXT_TOKENS` / `DEFAULT_MAX_OUTPUT_TOKENS`). This keeps the
10//! registry from going stale as models churn — see `token_budget.rs`.
11
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::PathBuf;
16
/// Sentinel pattern used for the single global fallback limit.
///
/// This is the `model_pattern` carried by the limit that
/// [`default_model_limit`] builds; it is a label for the fallback, not a
/// pattern that user overrides are matched against.
pub const DEFAULT_MODEL_PATTERN: &str = "default";

/// Global default context window applied to any model without a provider
/// metadata value or a user override. 1M reflects the current mainstream
/// range across frontier models (Claude 3.5, GPT-4o, Gemini 1.5, etc.).
///
/// NOTE(review): the tests in this file configure `"gpt-4o"` with a
/// `128_000` window, so not every model named above reaches 1M — confirm the
/// example list still matches the intended rationale.
pub const DEFAULT_MAX_CONTEXT_TOKENS: u32 = 1_000_000;

/// Global default maximum output tokens.
///
/// Also serves as the upper cap for the derived output budget when a
/// [`ModelLimit`] leaves `max_output_tokens` unset (see
/// [`ModelLimit::get_max_output_tokens`]).
pub const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 128_000;

/// Default safety margin for token counting errors (floor; scales with context
/// window via [`ModelLimit::get_safety_margin`]).
pub const DEFAULT_SAFETY_MARGIN: u32 = 1000;
31
32/// Build the single global default limit (`1M` context / `128K` output).
33pub fn default_model_limit() -> ModelLimit {
34    builtin_limit(
35        DEFAULT_MODEL_PATTERN,
36        DEFAULT_MAX_CONTEXT_TOKENS,
37        DEFAULT_MAX_OUTPUT_TOKENS,
38    )
39}
40
41/// Whether a user override is a no-op — identical to the global default, so it
42/// carries no information and need not be persisted (diff-only storage).
43///
44/// The `model_pattern` is irrelevant: any model pinned to exactly the default
45/// context/output with no explicit safety margin resolves to the same budget
46/// as having no override at all.
47pub fn is_default_limit(limit: &ModelLimit) -> bool {
48    limit.max_context_tokens == DEFAULT_MAX_CONTEXT_TOKENS
49        && limit.max_output_tokens == Some(DEFAULT_MAX_OUTPUT_TOKENS)
50        && limit.safety_margin.is_none()
51}
52
/// Model limit configuration (user-overridable).
///
/// Derives `Serialize`/`Deserialize`; the two optional fields are marked
/// `#[serde(default)]`, so they may be omitted from the input and then
/// deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Model identifier (partial match supported, e.g., "gpt-4" matches "gpt-4o")
    pub model_pattern: String,
    /// Maximum context window size in tokens
    pub max_context_tokens: u32,
    /// Maximum output tokens (defaults to min(max_context / 4,
    /// DEFAULT_MAX_OUTPUT_TOKENS) when unset — see [`Self::get_max_output_tokens`])
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Safety margin for token counting (defaults to max(max_context / 100,
    /// DEFAULT_SAFETY_MARGIN) when unset — see [`Self::get_safety_margin`])
    #[serde(default)]
    pub safety_margin: Option<u32>,
}
68
69impl ModelLimit {
70    /// Create a new model limit with defaults.
71    pub fn new(model_pattern: impl Into<String>, max_context_tokens: u32) -> Self {
72        Self {
73            model_pattern: model_pattern.into(),
74            max_context_tokens,
75            max_output_tokens: None,
76            safety_margin: None,
77        }
78    }
79
80    /// Get max output tokens with default calculation.
81    ///
82    /// When unset, derive from the context window (`max_context_tokens / 4`)
83    /// capped at the global [`DEFAULT_MAX_OUTPUT_TOKENS`]. The cap tracks the
84    /// global default rather than a hard-coded `4096`, so a user override like
85    /// `ModelLimit::new("gpt-4o", 128_000)` (no explicit `max_output_tokens`)
86    /// resolves to `min(32_000, 128_000) = 32_000` instead of collapsing to
87    /// `4096` — see issue #20, bug 4.
88    pub fn get_max_output_tokens(&self) -> u32 {
89        self.max_output_tokens
90            .unwrap_or_else(|| (self.max_context_tokens / 4).min(DEFAULT_MAX_OUTPUT_TOKENS))
91    }
92
93    /// Get safety margin, scaling proportionally with context window.
94    pub fn get_safety_margin(&self) -> u32 {
95        self.safety_margin
96            .unwrap_or_else(|| (self.max_context_tokens / 100).max(DEFAULT_SAFETY_MARGIN))
97    }
98}
99
100fn builtin_limit(pattern: &str, max_context_tokens: u32, max_output_tokens: u32) -> ModelLimit {
101    let mut limit = ModelLimit::new(pattern.to_string(), max_context_tokens);
102    limit.max_output_tokens = Some(max_output_tokens);
103    limit
104}
105
/// Registry of user-provided model limit overrides.
///
/// The registry holds no built-in per-model table: a lookup that matches no
/// override yields `None` from `get`, and `get_or_default` then falls back to
/// the single global default built by [`default_model_limit`].
#[derive(Debug, Clone)]
pub struct ModelLimitsRegistry {
    /// User-provided overrides, keyed by each limit's `model_pattern`
    user_limits: HashMap<String, ModelLimit>,
    /// Path of the user configuration file; `None` makes load/save a no-op
    config_path: Option<PathBuf>,
}
114
115impl ModelLimitsRegistry {
116    /// Create a new registry with built-in defaults only.
117    pub fn new() -> Self {
118        Self {
119            user_limits: HashMap::new(),
120            config_path: None,
121        }
122    }
123
124    /// Create a registry with a specific config file path.
125    pub fn with_config_path(path: impl Into<PathBuf>) -> Self {
126        Self {
127            user_limits: HashMap::new(),
128            config_path: Some(path.into()),
129        }
130    }
131
132    /// Load user overrides from the registry's configured path.
133    ///
134    /// No-op when the registry was created without a `config_path` (use
135    /// [`Self::with_config_path`] — e.g. with [`get_default_config_path`] — to
136    /// point it at `{bamboo_data_dir}/model_limits.json`).
137    pub async fn load_user_config(&mut self) -> std::io::Result<()> {
138        let Some(path) = self.config_path.clone() else {
139            return Ok(());
140        };
141
142        if !path.exists() {
143            return Ok(());
144        }
145
146        let content = tokio::fs::read_to_string(&path).await?;
147        let limits: Vec<ModelLimit> = serde_json::from_str(&content)
148            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
149
150        for limit in limits {
151            self.user_limits.insert(limit.model_pattern.clone(), limit);
152        }
153
154        tracing::info!(
155            "Loaded {} user model limits from {:?}",
156            self.user_limits.len(),
157            path
158        );
159        Ok(())
160    }
161
162    /// Add a user limit override.
163    pub fn add_limit(&mut self, limit: ModelLimit) {
164        self.user_limits.insert(limit.model_pattern.clone(), limit);
165    }
166
167    /// Get limit for a model, with user overrides taking priority.
168    ///
169    /// Returns `None` if no matching limit is found.
170    ///
171    /// # Matching Strategy
172    /// 1. Exact match (highest priority)
173    /// 2. Model contains pattern (e.g., "gpt-4o-mini" contains "gpt-4o")
174    ///
175    /// For partial matches, the longest (most specific) pattern wins.
176    ///
177    /// Only the `model.contains(pattern)` direction is correct: the configured
178    /// pattern must be a substring of the runtime model id. The reverse
179    /// (`pattern.contains(model)`) was a bug (#20, bug 3) — it let a short model
180    /// id like `"gpt-4o"` match a longer, unrelated pattern like `"gpt-4o-mini"`
181    /// and inherit the wrong limit.
182    pub fn get(&self, model: &str) -> Option<ModelLimit> {
183        // Exact user override match (highest priority).
184        if let Some(limit) = self.user_limits.get(model) {
185            return Some(limit.clone());
186        }
187
188        // Best partial match among user overrides: the pattern must be a
189        // substring of the model id. Longer (more specific) patterns win for
190        // deterministic selection. There is no built-in table; a miss returns
191        // None and the caller falls back to the global default.
192        self.user_limits
193            .iter()
194            .filter(|(pattern, _)| model.contains(pattern.as_str()))
195            .max_by_key(|(pattern, _)| pattern.len())
196            .map(|(_, limit)| limit.clone())
197    }
198
199    /// Get limit for a model with fallback to default.
200    pub fn get_or_default(&self, model: &str) -> ModelLimit {
201        self.get(model).unwrap_or_else(default_model_limit)
202    }
203
204    /// Save current user limits to the configured file.
205    ///
206    /// No-op when the registry has no `config_path`.
207    pub async fn save_user_config(&self) -> std::io::Result<()> {
208        let Some(path) = self.config_path.clone() else {
209            return Ok(());
210        };
211
212        // Ensure parent directory exists
213        if let Some(parent) = path.parent() {
214            tokio::fs::create_dir_all(parent).await?;
215        }
216
217        let limits: Vec<&ModelLimit> = self.user_limits.values().collect();
218        let content = serde_json::to_string_pretty(&limits)
219            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
220        tokio::fs::write(&path, content).await?;
221
222        Ok(())
223    }
224
225    /// List all user-defined limits.
226    pub fn list_user_limits(&self) -> Vec<&ModelLimit> {
227        self.user_limits.values().collect()
228    }
229}
230
231impl Default for ModelLimitsRegistry {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
/// Compute the default location of the user overrides file.
///
/// Appends `model_limits.json` to the supplied Bamboo data directory, giving
/// `{bamboo_data_dir}/model_limits.json`.
///
/// The base directory is passed in by the caller so this crate does not need
/// any infrastructure/filesystem-config dependency of its own.
pub fn get_default_config_path(bamboo_dir: &std::path::Path) -> PathBuf {
    let mut config_path = bamboo_dir.to_path_buf();
    config_path.push("model_limits.json");
    config_path
}
246
247/// Load user model limits from the unified `config.json` `model_limits` value.
248///
249/// The caller extracts the raw `model_limits` JSON value (e.g. from
250/// `config.extra.get("model_limits")`) and passes it here, keeping this crate
251/// independent of the concrete `Config` type.
252///
253/// Returns:
254/// - `Ok(None)` when `model_limits` is absent.
255/// - `Ok(Some(vec))` when present and valid (including empty array).
256/// - `Err(...)` when present but not a valid `Vec<ModelLimit>`.
257pub fn load_model_limits_from_unified_config(
258    raw_limits: Option<&Value>,
259) -> Result<Option<Vec<ModelLimit>>, String> {
260    let Some(raw_limits) = raw_limits else {
261        return Ok(None);
262    };
263
264    if raw_limits.is_null() {
265        return Ok(Some(Vec::new()));
266    }
267
268    match raw_limits {
269        Value::Array(_) => serde_json::from_value::<Vec<ModelLimit>>(raw_limits.clone())
270            .map(Some)
271            .map_err(|error| format!("invalid config.model_limits format: {error}")),
272        _ => Err("invalid config.model_limits format: expected array".to_string()),
273    }
274}
275
276/// Create a token budget for a specific model, resolving its limit from the
277/// supplied `registry` (with user overrides loaded) and falling back to the
278/// global default when there is no match.
279///
280/// The registry is a required parameter on purpose: a previous version built a
281/// fresh empty `ModelLimitsRegistry::default()` internally, which silently
282/// discarded every user override from `model_limits.json` and always returned
283/// the global default (#20, bug 2). Callers must pass a registry they have
284/// loaded user overrides into (or [`ModelLimitsRegistry::new`] when they
285/// genuinely want the global default).
286pub fn create_budget_for_model(
287    model: &str,
288    strategy: crate::BudgetStrategy,
289    registry: &ModelLimitsRegistry,
290) -> crate::TokenBudget {
291    let limit = registry.get_or_default(model);
292
293    crate::TokenBudget {
294        max_context_tokens: limit.max_context_tokens,
295        max_output_tokens: limit.get_max_output_tokens(),
296        strategy,
297        safety_margin: limit.get_safety_margin(),
298        compression_trigger_percent: 85, // legacy — only used when working_reserve_tokens == 0
299        compression_target_percent: 45,
300        working_reserve_tokens: 50_000,
301        fallback_trigger_percent: 75,
302        prompt_cache_min_tool_output_chars: 1_200,
303        prompt_cache_head_chars: 280,
304        prompt_cache_tail_chars: 180,
305        prompt_cache_recent_user_turns: 2,
306        prompt_cache_recent_tool_chains: 2,
307        max_tool_output_tokens: 0,
308    }
309}
310
// Unit and integration tests for the limits registry: default values,
// override matching, derived output/safety values, the unified-config loader,
// and the on-disk `model_limits.json` → registry → `TokenBudget` chain.
#[cfg(test)]
mod tests {
    use super::*;

    // The synthesized global default carries the sentinel pattern and the
    // 1M / 128K values.
    #[test]
    fn default_limit_is_1m_128k() {
        let limit = default_model_limit();
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 1_000_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    #[test]
    fn is_default_limit_detects_no_op_overrides() {
        // A row pinned to exactly the default values (any pattern) is a no-op.
        let mut noop = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        noop.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(is_default_limit(&noop));

        // The synthesized global default is itself a no-op override.
        assert!(is_default_limit(&default_model_limit()));

        // A different context window is a real override.
        let mut smaller = ModelLimit::new("gpt-4o", 128_000);
        smaller.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        assert!(!is_default_limit(&smaller));

        // An explicit safety margin is a real override even at default size.
        let mut custom_margin = ModelLimit::new("gpt-4o", DEFAULT_MAX_CONTEXT_TOKENS);
        custom_margin.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
        custom_margin.safety_margin = Some(500);
        assert!(!is_default_limit(&custom_margin));
    }

    #[test]
    fn registry_returns_none_for_unknown_without_overrides() {
        // No built-in table: an unknown model with no user override has no match.
        let registry = ModelLimitsRegistry::new();
        assert!(registry.get("gpt-5.2-codex").is_none());
        assert!(registry.get("some-brand-new-model").is_none());
    }

    // `get_or_default` substitutes the global default when `get` misses.
    #[test]
    fn registry_returns_default_for_unknown() {
        let registry = ModelLimitsRegistry::new();
        let limit = registry.get_or_default("unknown-model-xyz");
        assert_eq!(limit.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(limit.max_context_tokens, 1_000_000);
        assert_eq!(limit.get_max_output_tokens(), 128_000);
    }

    // A model id equal to an override's pattern resolves to that override.
    #[test]
    fn user_override_exact_match_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 64_000)); // Override with smaller limit

        let limit = registry
            .get("gpt-5.2-codex")
            .expect("Should find overridden limit");
        assert_eq!(limit.max_context_tokens, 64_000);
    }

    #[test]
    fn user_override_partial_match_longest_wins() {
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-5", 111_000));
        registry.add_limit(ModelLimit::new("gpt-5.2-codex", 222_000));

        // "gpt-5.2-codex-preview" contains both patterns; the longest wins.
        let limit = registry
            .get("gpt-5.2-codex-preview")
            .expect("Should partial-match a user override");
        assert_eq!(limit.max_context_tokens, 222_000);
    }

    #[test]
    fn model_limit_calculates_default_output_tokens() {
        let limit = ModelLimit::new("test", 100_000);
        // Default is min(max_context / 4, DEFAULT_MAX_OUTPUT_TOKENS)
        //        = min(25_000, 128_000) = 25_000 (no longer capped at 4096, #20 bug 4)
        assert_eq!(limit.get_max_output_tokens(), 25_000);
    }

    #[test]
    fn user_override_without_explicit_output_is_not_capped_at_4096() {
        // Issue #20 bug 4: a user override created with `ModelLimit::new` leaves
        // `max_output_tokens = None`. The derived default must scale with the
        // context window (context / 4) rather than collapsing to 4096.
        let gpt4o = ModelLimit::new("gpt-4o", 128_000);
        assert!(gpt4o.max_output_tokens.is_none());
        assert_eq!(gpt4o.get_max_output_tokens(), 32_000);

        // Very large context windows are still capped at the global default so a
        // single override can't request an unbounded output budget.
        let huge = ModelLimit::new("huge", 2_000_000);
        assert_eq!(huge.get_max_output_tokens(), DEFAULT_MAX_OUTPUT_TOKENS);
    }

    #[test]
    fn matching_is_directional_model_contains_pattern_only() {
        // Issue #20 bug 3: a short model id must NOT match a longer pattern.
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-4o-mini", 128_000));

        // "gpt-4o" does not contain "gpt-4o-mini", so it must NOT inherit the
        // mini override (the old `pattern.contains(model)` direction did).
        assert!(registry.get("gpt-4o").is_none());

        // The reverse still works: a model id that contains the pattern matches.
        let mini = registry
            .get("gpt-4o-mini-2024")
            .expect("model id contains the pattern");
        assert_eq!(mini.max_context_tokens, 128_000);
    }

    // An explicit `max_output_tokens` bypasses the derived default entirely.
    #[test]
    fn model_limit_uses_custom_output_tokens() {
        let mut limit = ModelLimit::new("test", 100_000);
        limit.max_output_tokens = Some(8192);
        assert_eq!(limit.get_max_output_tokens(), 8192);
    }

    #[test]
    fn model_limit_calculates_small_context_output() {
        let limit = ModelLimit::new("test", 8_192);
        // Default is min(8192 / 4, DEFAULT_MAX_OUTPUT_TOKENS) = min(2048, 128_000) = 2048
        assert_eq!(limit.get_max_output_tokens(), 2048);
    }

    // Absent `model_limits` (no value at all) yields `Ok(None)`.
    #[test]
    fn unified_config_loader_returns_none_when_absent() {
        let loaded = load_model_limits_from_unified_config(None).expect("should parse");
        assert!(loaded.is_none());
    }

    // A well-formed array round-trips every field, including the optional ones.
    #[test]
    fn unified_config_loader_reads_valid_model_limits() {
        let raw = serde_json::json!([
            {
                "model_pattern": "gpt-5.2-codex",
                "max_context_tokens": 64000,
                "max_output_tokens": 2048,
                "safety_margin": 512
            }
        ]);

        let loaded = load_model_limits_from_unified_config(Some(&raw))
            .expect("should parse")
            .expect("should exist");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].model_pattern, "gpt-5.2-codex");
        assert_eq!(loaded[0].max_context_tokens, 64_000);
        assert_eq!(loaded[0].max_output_tokens, Some(2048));
        assert_eq!(loaded[0].safety_margin, Some(512));
    }

    // A JSON object (not an array) is rejected with the "expected array" error.
    #[test]
    fn unified_config_loader_errors_on_invalid_shape() {
        let raw = serde_json::json!({"unexpected": true});
        let error = load_model_limits_from_unified_config(Some(&raw)).expect_err("should error");
        assert!(error.contains("expected array"));
    }

    #[test]
    fn safety_margin_scales_with_context_window() {
        // Small context → floor at DEFAULT_SAFETY_MARGIN (1000)
        let small = ModelLimit::new("test", 8_192);
        assert_eq!(small.get_safety_margin(), 1000);

        // Medium context → proportional
        let medium = ModelLimit::new("test", 200_000);
        assert_eq!(medium.get_safety_margin(), 2000);

        // Large context → proportional
        let large = ModelLimit::new("test", 1_050_000);
        assert_eq!(large.get_safety_margin(), 10_500);

        // Explicit override takes precedence
        let mut custom = ModelLimit::new("test", 200_000);
        custom.safety_margin = Some(500);
        assert_eq!(custom.get_safety_margin(), 500);
    }

    #[tokio::test]
    async fn persisted_overrides_drive_runtime_resolution() {
        // Integration: a `model_limits.json` on disk → registry load → resolve.
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000,"max_output_tokens":16384}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        // Persisted override is applied at runtime.
        let gpt4o = registry.get("gpt-4o").expect("override present");
        assert_eq!(gpt4o.max_context_tokens, 128_000);
        assert_eq!(gpt4o.get_max_output_tokens(), 16_384);

        // Unknown model with no override falls back to the single global default.
        let unknown = registry.get_or_default("brand-new-frontier-model");
        assert_eq!(unknown.model_pattern, DEFAULT_MODEL_PATTERN);
        assert_eq!(unknown.max_context_tokens, 1_000_000);
        assert_eq!(unknown.get_max_output_tokens(), 128_000);
    }

    #[tokio::test]
    async fn persisted_override_drives_runtime_token_budget() {
        // Full chain (issue #20 acceptance): set a model limit on disk →
        // load registry → build the runtime TokenBudget → the budget reflects
        // the user-configured context window, not a stale/default value.
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model_limits.json");
        // Note: NO explicit max_output_tokens, exercising the bug-4 default path.
        tokio::fs::write(
            &path,
            r#"[{"model_pattern":"gpt-4o","max_context_tokens":128000}]"#,
        )
        .await
        .expect("seed overrides");

        let mut registry = ModelLimitsRegistry::with_config_path(path);
        registry.load_user_config().await.expect("load user config");

        // The runtime budget for the configured model matches the user limit...
        let budget = create_budget_for_model("gpt-4o", crate::BudgetStrategy::default(), &registry);
        assert_eq!(budget.max_context_tokens, 128_000);
        // ...and the derived output budget is context/4 (32K), not the old 4096 cap.
        assert_eq!(budget.max_output_tokens, 32_000);

        // A model that does NOT contain the pattern is unaffected (bug-3 fix):
        // it resolves to the global default, not the gpt-4o override.
        let other =
            create_budget_for_model("claude-sonnet", crate::BudgetStrategy::default(), &registry);
        assert_eq!(other.max_context_tokens, 1_000_000);
    }

    #[test]
    fn create_budget_for_model_uses_global_default_for_unmatched_model() {
        // An empty registry yields the global default for any model.
        let registry = ModelLimitsRegistry::new();
        let budget = create_budget_for_model(
            "anything-at-all",
            crate::BudgetStrategy::default(),
            &registry,
        );
        assert_eq!(budget.max_context_tokens, 1_000_000);
        assert_eq!(budget.max_output_tokens, 128_000);
    }

    #[test]
    fn create_budget_for_model_honors_registry_user_overrides() {
        // Issue #20 bug 2: the budget must reflect the user override carried by
        // the registry, not silently fall back to the global default.
        let mut registry = ModelLimitsRegistry::new();
        registry.add_limit(ModelLimit::new("gpt-4o", 128_000));

        let budget = create_budget_for_model("gpt-4o", crate::BudgetStrategy::default(), &registry);
        assert_eq!(budget.max_context_tokens, 128_000);
        // And the derived output budget is the un-capped context/4 (bug 4), not 4096.
        assert_eq!(budget.max_output_tokens, 32_000);
    }
}