ai_agent/utils/
context.rs1use crate::constants::env::ai;
3use once_cell::sync::Lazy;
4use regex::Regex;
5
/// Context window size, in tokens, that `get_context_window_for_model` returns
/// when no override, capability entry, or 1M-context rule applies.
pub const MODEL_CONTEXT_WINDOW_DEFAULT: u64 = 200_000;
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the output-token budget for compaction requests — confirm at the call sites.
pub const COMPACT_MAX_OUTPUT_TOKENS: u64 = 20_000;

/// Default output-token budget `get_model_max_output_tokens` starts from when
/// no per-family rule matches the model.
const MAX_OUTPUT_TOKENS_DEFAULT: u64 = 32_000;
/// Output-token upper limit `get_model_max_output_tokens` starts from when no
/// per-family rule matches the model.
const MAX_OUTPUT_TOKENS_UPPER_LIMIT: u64 = 64_000;

/// Ceiling applied to a model's default output budget by
/// `get_max_output_tokens_for_model` while `is_max_tokens_cap_enabled()` is true.
pub const CAPPED_DEFAULT_MAX_TOKENS: u64 = 8_000;
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the budget used when escalating past the capped default — confirm.
pub const ESCALATED_MAX_TOKENS: u64 = 64_000;

/// Case-insensitive matcher for a literal `[1m]` marker anywhere in a string.
///
/// NOTE(review): despite the `DISABLE_` prefix, `has_1m_context` uses this to
/// *detect* a request for the 1M window, not to disable it; the name is
/// misleading and worth revisiting.
static DISABLE_1M_CONTEXT_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\[1m\]").unwrap());
16
/// Returns `true` when `value` spells an affirmative flag: `1`, `true`, or `yes`.
///
/// Leading and trailing whitespace is ignored and the comparison is ASCII
/// case-insensitive, so `" TRUE\n"` counts as truthy. Every other value,
/// including the empty string, is falsy.
fn is_env_truthy(value: &str) -> bool {
    // Values exported from shell scripts or env files frequently carry a stray
    // newline or padding; treat those the same as the bare token.
    let flag = value.trim();
    // `eq_ignore_ascii_case` compares in place, so no lowercase copy of the
    // input is allocated on each call.
    ["1", "true", "yes"]
        .iter()
        .any(|accepted| flag.eq_ignore_ascii_case(accepted))
}
20
21fn is_1m_context_disabled() -> bool {
22 std::env::var(ai::CODE_DISABLE_1M_CONTEXT)
23 .map(|v| is_env_truthy(&v))
24 .unwrap_or(false)
25}
26
27fn has_1m_context(model: &str) -> bool {
28 if is_1m_context_disabled() {
29 return false;
30 }
31 DISABLE_1M_CONTEXT_REGEX.is_match(model)
32}
33
34fn get_user_type() -> String {
35 std::env::var(ai::USER_TYPE).unwrap_or_default()
36}
37
38pub fn model_supports_1m(model: &str) -> bool {
39 if is_1m_context_disabled() {
40 return false;
41 }
42 let canonical = get_canonical_name(model);
43 canonical.contains("claude-sonnet-4") || canonical.contains("opus-4-6")
44}
45
/// Maps an arbitrary model identifier onto the canonical family name used by
/// the rest of this module.
///
/// The identifier is lowercased and tested against an ordered list of
/// substring rules; the first rule with any matching substring decides the
/// result. Rule order matters: more specific families (for example
/// `sonnet-4-6`) are listed before the broader ones they would otherwise fall
/// into (`sonnet-4`). When no rule matches, the lowercased identifier itself
/// is returned.
///
/// NOTE(review): the dated snapshot ids (`sonnet-4-20250514`,
/// `opus-4-20250514`, ...) and the `opus-4-2` prefix are mapped exactly as
/// before; confirm that these map to the intended families.
fn get_canonical_name(model: &str) -> String {
    // (substrings to look for, canonical name) — evaluated top to bottom.
    const FAMILY_RULES: &[(&[&str], &str)] = &[
        (&["sonnet-4-20250514", "sonnet-4-6"], "claude-sonnet-4-6"),
        (&["sonnet-4-20250507", "sonnet-4-5"], "claude-sonnet-4-5"),
        (&["sonnet-4"], "claude-sonnet-4"),
        (&["opus-4-20250514", "opus-4-6"], "claude-opus-4-6"),
        (&["opus-4-20250501", "opus-4-5"], "claude-opus-4-5"),
        (&["opus-4-2", "opus-4-1"], "claude-opus-4-1"),
        (&["opus-4"], "claude-opus-4"),
        (&["haiku-4"], "claude-haiku-4"),
        (&["3-7-sonnet"], "claude-3-7-sonnet"),
        (&["3-5-sonnet", "sonnet-3-5"], "claude-3-5-sonnet"),
        (&["3-5-haiku", "haiku-3-5"], "claude-3-5-haiku"),
        (&["3-opus", "opus-3"], "claude-3-opus"),
        (&["3-sonnet", "sonnet-3"], "claude-3-sonnet"),
        (&["3-haiku", "haiku-3"], "claude-3-haiku"),
    ];

    let lowered = model.to_lowercase();
    let matched_rule = FAMILY_RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lowered.contains(*needle)));

    match matched_rule {
        Some((_, canonical)) => (*canonical).to_string(),
        // Unknown model: hand back the normalized identifier unchanged.
        None => lowered,
    }
}
92
93fn get_model_capability(model: &str) -> Option<ModelCapability> {
94 None
95}
96
/// Token limits reported for a model by `get_model_capability`.
///
/// Either field may be absent when the corresponding limit is not reported.
#[derive(Clone, Debug)]
pub struct ModelCapability {
    /// Maximum number of input tokens; `get_context_window_for_model` uses it
    /// as the context window when it is at least 100,000.
    pub max_input_tokens: Option<u64>,
    /// Maximum number of output tokens; `get_model_max_output_tokens` uses it
    /// as the upper limit when it is at least 4,096.
    pub max_tokens: Option<u64>,
}
102
/// Beta flag that `get_context_window_for_model` looks for in its `betas`
/// argument; when present, models accepted by `model_supports_1m` get the
/// 1,000,000-token window.
const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
104
105pub fn get_context_window_for_model(model: &str, betas: Option<&[String]>) -> u64 {
106 if get_user_type() == "ant" {
107 if let Ok(override_val) = std::env::var(ai::CODE_MAX_CONTEXT_TOKENS) {
108 if let Ok(override_num) = override_val.parse::<u64>() {
109 if override_num > 0 {
110 return override_num;
111 }
112 }
113 }
114 }
115
116 if has_1m_context(model) {
117 return 1_000_000;
118 }
119
120 if let Some(cap) = get_model_capability(model) {
121 if let Some(max_input) = cap.max_input_tokens {
122 if max_input >= 100_000 {
123 if max_input > MODEL_CONTEXT_WINDOW_DEFAULT && is_1m_context_disabled() {
124 return MODEL_CONTEXT_WINDOW_DEFAULT;
125 }
126 return max_input;
127 }
128 }
129 }
130
131 if let Some(betas_arr) = betas {
132 if betas_arr.iter().any(|b| b == CONTEXT_1M_BETA_HEADER) && model_supports_1m(model) {
133 return 1_000_000;
134 }
135 }
136
137 if get_sonnet_1m_exp_treatment_enabled(model) {
138 return 1_000_000;
139 }
140
141 MODEL_CONTEXT_WINDOW_DEFAULT
142}
143
144fn get_global_config() -> GlobalConfig {
145 GlobalConfig::default()
146}
147
/// Minimal view of the global configuration needed by this module.
#[derive(Default, Debug)]
struct GlobalConfig {
    /// String key/value pairs of cached client data, if any.
    /// `get_sonnet_1m_exp_treatment_enabled` reads the `coral_reef_sonnet`
    /// entry from it.
    client_data_cache: Option<std::collections::HashMap<String, String>>,
}
152
153fn get_sonnet_1m_exp_treatment_enabled(model: &str) -> bool {
154 if is_1m_context_disabled() {
155 return false;
156 }
157 if has_1m_context(model) {
158 return false;
159 }
160 let canonical = get_canonical_name(model);
161 if !canonical.contains("sonnet-4-6") {
162 return false;
163 }
164 let config = get_global_config();
165 config
166 .client_data_cache
167 .as_ref()
168 .map(|c| {
169 c.get("coral_reef_sonnet")
170 .map(|v| v == "true")
171 .unwrap_or(false)
172 })
173 .unwrap_or(false)
174}
175
176pub fn calculate_context_percentages(
177 current_usage: Option<&ContextUsage>,
178 context_window_size: u64,
179) -> ContextPercentages {
180 let usage = match current_usage {
181 Some(u) => u,
182 None => {
183 return ContextPercentages {
184 used: None,
185 remaining: None,
186 }
187 }
188 };
189
190 let total_input_tokens =
191 usage.input_tokens + usage.cache_creation_input_tokens + usage.cache_read_input_tokens;
192
193 let used_percentage =
194 ((total_input_tokens as f64 / context_window_size as f64) * 100.0).round() as u64;
195 let clamped_used = used_percentage.min(100).max(0);
196
197 ContextPercentages {
198 used: Some(clamped_used),
199 remaining: Some(100 - clamped_used),
200 }
201}
202
/// Input-side token counters; `calculate_context_percentages` adds the three
/// fields together to get the total number of tokens occupying the context.
#[derive(Clone, Debug)]
pub struct ContextUsage {
    /// Input token count.
    pub input_tokens: u64,
    /// Cache-creation input token count.
    pub cache_creation_input_tokens: u64,
    /// Cache-read input token count.
    pub cache_read_input_tokens: u64,
}
209
/// Result of `calculate_context_percentages`.
///
/// Both fields are `None` when no usage data was supplied; otherwise they are
/// whole percentages that add up to 100.
#[derive(Clone, Debug)]
pub struct ContextPercentages {
    /// Share of the context window in use, from 0 to 100.
    pub used: Option<u64>,
    /// Share of the context window still free, i.e. `100 - used`.
    pub remaining: Option<u64>,
}
215
216pub fn get_model_max_output_tokens(model: &str) -> MaxOutputTokens {
217 let mut default_tokens = MAX_OUTPUT_TOKENS_DEFAULT;
218 let mut upper_limit = MAX_OUTPUT_TOKENS_UPPER_LIMIT;
219
220 let m = get_canonical_name(model);
221
222 if m.contains("opus-4-6") {
223 default_tokens = 64_000;
224 upper_limit = 128_000;
225 } else if m.contains("sonnet-4-6") {
226 default_tokens = 32_000;
227 upper_limit = 128_000;
228 } else if m.contains("opus-4-5") || m.contains("sonnet-4") || m.contains("haiku-4") {
229 default_tokens = 32_000;
230 upper_limit = 64_000;
231 } else if m.contains("opus-4-1") || m.contains("opus-4") {
232 default_tokens = 32_000;
233 upper_limit = 32_000;
234 } else if m.contains("claude-3-opus") {
235 default_tokens = 4_096;
236 upper_limit = 4_096;
237 } else if m.contains("claude-3-sonnet") {
238 default_tokens = 8_192;
239 upper_limit = 8_192;
240 } else if m.contains("claude-3-haiku") {
241 default_tokens = 4_096;
242 upper_limit = 4_096;
243 } else if m.contains("3-5-sonnet") || m.contains("3-5-haiku") {
244 default_tokens = 8_192;
245 upper_limit = 8_192;
246 } else if m.contains("3-7-sonnet") {
247 default_tokens = 32_000;
248 upper_limit = 64_000;
249 }
250
251 if let Some(cap) = get_model_capability(model) {
252 if let Some(max_tokens) = cap.max_tokens {
253 if max_tokens >= 4_096 {
254 upper_limit = max_tokens;
255 default_tokens = default_tokens.min(upper_limit);
256 }
257 }
258 }
259
260 MaxOutputTokens {
261 default: default_tokens,
262 upper_limit,
263 }
264}
265
/// Output-token budgets for a model, as produced by
/// `get_model_max_output_tokens`.
#[derive(Clone, Debug)]
pub struct MaxOutputTokens {
    /// Budget used when nothing overrides it.
    pub default: u64,
    /// Largest budget the model accepts.
    pub upper_limit: u64,
}
271
272pub fn get_max_thinking_tokens_for_model(model: &str) -> u64 {
273 get_model_max_output_tokens(model).upper_limit - 1
274}
275
/// Reports whether `get_max_output_tokens_for_model` should cap a model's
/// default output budget at `CAPPED_DEFAULT_MAX_TOKENS`.
///
/// Always returns `true` in the code visible here.
fn is_max_tokens_cap_enabled() -> bool {
    true
}
281
282pub fn get_max_output_tokens_for_model(model: &str) -> u64 {
291 use crate::constants::env::ai_code::MAX_OUTPUT_TOKENS as ENV_MAX_OUTPUT_TOKENS;
292 use crate::utils::env_validation::validate_bounded_int_env_var;
293
294 let max_output = get_model_max_output_tokens(model);
295
296 let default_tokens = if is_max_tokens_cap_enabled() {
297 max_output.default.min(CAPPED_DEFAULT_MAX_TOKENS)
298 } else {
299 max_output.default
300 };
301
302 let env_value = std::env::var(ENV_MAX_OUTPUT_TOKENS).ok();
303 let result = validate_bounded_int_env_var(
304 ENV_MAX_OUTPUT_TOKENS,
305 env_value.as_deref(),
306 default_tokens as i64,
307 max_output.upper_limit as i64,
308 );
309 result.effective as u64
310}