Skip to main content

claude_agent/client/adapter/
config.rs

1//! Provider and model configuration.
2
3use std::collections::{HashMap, HashSet};
4use std::env;
5
6use crate::client::messages::{DEFAULT_MAX_TOKENS, MIN_THINKING_BUDGET};
7
// Anthropic API models (first-party API defaults used by `ModelConfig::anthropic`).

/// Default primary model; overridden by the `ANTHROPIC_MODEL` environment variable.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-5-20250929";
/// Default small model; overridden by the `ANTHROPIC_SMALL_FAST_MODEL` environment variable.
pub const DEFAULT_SMALL_MODEL: &str = "claude-haiku-4-5-20251001";
/// Default reasoning model; overridden by the `ANTHROPIC_REASONING_MODEL` environment variable.
pub const DEFAULT_REASONING_MODEL: &str = "claude-opus-4-6";
/// Alias for [`DEFAULT_REASONING_MODEL`].
pub const FRONTIER_MODEL: &str = DEFAULT_REASONING_MODEL;
13
// AWS Bedrock models (using the `global.` endpoint prefix for maximum availability).
// Defaults for `ModelConfig::bedrock`; the same `ANTHROPIC_*` environment
// variables override them.

/// Default primary Bedrock model ID.
#[cfg(feature = "aws")]
pub const BEDROCK_MODEL: &str = "global.anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Default small Bedrock model ID.
#[cfg(feature = "aws")]
pub const BEDROCK_SMALL_MODEL: &str = "global.anthropic.claude-haiku-4-5-20251001-v1:0";
/// Default reasoning Bedrock model ID.
#[cfg(feature = "aws")]
pub const BEDROCK_REASONING_MODEL: &str = "global.anthropic.claude-opus-4-6-v1:0";
21
// GCP Vertex AI models. These IDs separate the release date with `@`.
// Defaults for `ModelConfig::vertex`; the same `ANTHROPIC_*` environment
// variables override them.

/// Default primary Vertex AI model ID.
#[cfg(feature = "gcp")]
pub const VERTEX_MODEL: &str = "claude-sonnet-4-5@20250929";
/// Default small Vertex AI model ID.
#[cfg(feature = "gcp")]
pub const VERTEX_SMALL_MODEL: &str = "claude-haiku-4-5@20251001";
/// Default reasoning Vertex AI model ID.
// NOTE(review): unlike the other Vertex IDs this has no `@date` suffix —
// confirm it matches the published Vertex model ID.
#[cfg(feature = "gcp")]
pub const VERTEX_REASONING_MODEL: &str = "claude-opus-4-6";
29
// Azure Foundry models (names carry no release date).
// Defaults for `ModelConfig::foundry`; the same `ANTHROPIC_*` environment
// variables override them.

/// Default primary Foundry model name.
#[cfg(feature = "azure")]
pub const FOUNDRY_MODEL: &str = "claude-sonnet-4-5";
/// Default small Foundry model name.
#[cfg(feature = "azure")]
pub const FOUNDRY_SMALL_MODEL: &str = "claude-haiku-4-5";
/// Default reasoning Foundry model name.
#[cfg(feature = "azure")]
pub const FOUNDRY_REASONING_MODEL: &str = "claude-opus-4-6";
37
/// Selects which configured model a request should use.
///
/// Serialized in lowercase (`"primary"`, `"small"`, `"reasoning"`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelType {
    /// The main model (`ModelConfig::primary`); this is the default variant.
    #[default]
    Primary,
    /// The small model (`ModelConfig::small`).
    Small,
    /// The reasoning model (`ModelConfig::reasoning`), falling back to the
    /// primary model when none is configured.
    Reasoning,
}
46
47#[derive(Clone, Debug)]
48pub struct ModelConfig {
49    pub primary: String,
50    pub small: String,
51    pub reasoning: Option<String>,
52}
53
54impl ModelConfig {
55    pub fn new(primary: impl Into<String>, small: impl Into<String>) -> Self {
56        Self {
57            primary: primary.into(),
58            small: small.into(),
59            reasoning: None,
60        }
61    }
62
63    pub fn anthropic() -> Self {
64        Self::from_env_with_defaults(DEFAULT_MODEL, DEFAULT_SMALL_MODEL, DEFAULT_REASONING_MODEL)
65    }
66
67    fn from_env_with_defaults(
68        default_primary: &str,
69        default_small: &str,
70        default_reasoning: &str,
71    ) -> Self {
72        Self {
73            primary: env::var("ANTHROPIC_MODEL").unwrap_or_else(|_| default_primary.into()),
74            small: env::var("ANTHROPIC_SMALL_FAST_MODEL").unwrap_or_else(|_| default_small.into()),
75            reasoning: Some(
76                env::var("ANTHROPIC_REASONING_MODEL").unwrap_or_else(|_| default_reasoning.into()),
77            ),
78        }
79    }
80
81    #[cfg(feature = "aws")]
82    pub fn bedrock() -> Self {
83        Self::from_env_with_defaults(BEDROCK_MODEL, BEDROCK_SMALL_MODEL, BEDROCK_REASONING_MODEL)
84    }
85
86    #[cfg(feature = "gcp")]
87    pub fn vertex() -> Self {
88        Self::from_env_with_defaults(VERTEX_MODEL, VERTEX_SMALL_MODEL, VERTEX_REASONING_MODEL)
89    }
90
91    #[cfg(feature = "azure")]
92    pub fn foundry() -> Self {
93        Self::from_env_with_defaults(FOUNDRY_MODEL, FOUNDRY_SMALL_MODEL, FOUNDRY_REASONING_MODEL)
94    }
95
96    pub fn primary(mut self, model: impl Into<String>) -> Self {
97        self.primary = model.into();
98        self
99    }
100
101    pub fn small(mut self, model: impl Into<String>) -> Self {
102        self.small = model.into();
103        self
104    }
105
106    pub fn reasoning(mut self, model: impl Into<String>) -> Self {
107        self.reasoning = Some(model.into());
108        self
109    }
110
111    pub fn get(&self, model_type: ModelType) -> &str {
112        match model_type {
113            ModelType::Primary => &self.primary,
114            ModelType::Small => &self.small,
115            ModelType::Reasoning => self.reasoning.as_deref().unwrap_or(&self.primary),
116        }
117    }
118
119    pub fn resolve_alias<'a>(&'a self, alias: &'a str) -> &'a str {
120        match alias {
121            "sonnet" => &self.primary,
122            "haiku" => &self.small,
123            "opus" => self.reasoning.as_deref().unwrap_or(&self.primary),
124            other => other,
125        }
126    }
127}
128
129impl Default for ModelConfig {
130    fn default() -> Self {
131        Self::anthropic()
132    }
133}
134
/// Opt-in beta features, each identified by a dated header token
/// (see [`BetaFeature::header_value`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BetaFeature {
    InterleavedThinking,
    ContextManagement,
    StructuredOutputs,
    PromptCaching,
    MaxTokens128k,
    CodeExecution,
    Mcp,
    WebSearch,
    WebFetch,
    OAuth,
    FilesApi,
    Effort,
    /// 1M token context window (for Sonnet 4.5 on Bedrock/Vertex).
    Context1M,
    /// Tool search for progressive disclosure of MCP tools.
    AdvancedToolUse,
}

impl BetaFeature {
    /// Every variant, in declaration order.
    ///
    /// `header_value` is checked for exhaustiveness by the compiler, but this
    /// list is not: add new variants here too, or they will be missing from
    /// [`BetaFeature::all`] and from header parsing.
    const ALL: &'static [BetaFeature] = &[
        Self::InterleavedThinking,
        Self::ContextManagement,
        Self::StructuredOutputs,
        Self::PromptCaching,
        Self::MaxTokens128k,
        Self::CodeExecution,
        Self::Mcp,
        Self::WebSearch,
        Self::WebFetch,
        Self::OAuth,
        Self::FilesApi,
        Self::Effort,
        Self::Context1M,
        Self::AdvancedToolUse,
    ];

    /// Returns the header token that enables this feature.
    ///
    /// An exhaustive `match` makes a new variant without a token a compile
    /// error instead of a runtime panic.
    pub fn header_value(&self) -> &'static str {
        match self {
            Self::InterleavedThinking => "interleaved-thinking-2025-05-14",
            Self::ContextManagement => "context-management-2025-06-27",
            Self::StructuredOutputs => "structured-outputs-2025-11-13",
            Self::PromptCaching => "prompt-caching-2024-07-31",
            // NOTE(review): the variant says 128k, but this token names a
            // Claude 3.5 Sonnet max-tokens beta; Anthropic documents
            // `output-128k-2025-02-19` for 128k output — confirm which is intended.
            Self::MaxTokens128k => "max-tokens-3-5-sonnet-2024-07-15",
            Self::CodeExecution => "code-execution-2025-01-24",
            // NOTE(review): Anthropic's MCP connector beta appears to be
            // `mcp-client-2025-04-04` — confirm this token is still accepted.
            Self::Mcp => "mcp-2025-04-08",
            Self::WebSearch => "web-search-2025-03-05",
            Self::WebFetch => "web-fetch-2025-09-10",
            Self::OAuth => "oauth-2025-04-20",
            Self::FilesApi => "files-api-2025-04-14",
            Self::Effort => "effort-2025-11-24",
            Self::Context1M => "context-1m-2025-08-07",
            Self::AdvancedToolUse => "advanced-tool-use-2025-11-20",
        }
    }

    /// Parses a header token back into its feature; `None` for unknown tokens.
    fn from_header(value: &str) -> Option<Self> {
        Self::all().find(|feature| feature.header_value() == value)
    }

    /// Iterates over every known feature in declaration order.
    pub fn all() -> impl Iterator<Item = BetaFeature> {
        Self::ALL.iter().copied()
    }
}
192
193#[derive(Clone, Debug, Default)]
194pub struct BetaConfig {
195    features: HashSet<BetaFeature>,
196    custom: Vec<String>,
197}
198
199impl BetaConfig {
200    pub fn new() -> Self {
201        Self::default()
202    }
203
204    pub fn all() -> Self {
205        Self {
206            features: BetaFeature::all().collect(),
207            custom: Vec::new(),
208        }
209    }
210
211    pub fn feature(mut self, feature: BetaFeature) -> Self {
212        self.features.insert(feature);
213        self
214    }
215
216    pub fn custom(mut self, flag: impl Into<String>) -> Self {
217        self.custom.push(flag.into());
218        self
219    }
220
221    pub fn add(&mut self, feature: BetaFeature) {
222        self.features.insert(feature);
223    }
224
225    pub fn add_custom(&mut self, flag: impl Into<String>) {
226        self.custom.push(flag.into());
227    }
228
229    pub fn from_env() -> Self {
230        let mut config = Self::new();
231
232        if let Ok(flags) = env::var("ANTHROPIC_BETA_FLAGS") {
233            for flag in flags.split(',').map(str::trim).filter(|s| !s.is_empty()) {
234                if let Some(feature) = BetaFeature::from_header(flag) {
235                    config.features.insert(feature);
236                } else {
237                    config.custom.push(flag.to_string());
238                }
239            }
240        }
241
242        config
243    }
244
245    pub fn header_value(&self) -> Option<String> {
246        let mut flags: Vec<&str> = self.features.iter().map(|f| f.header_value()).collect();
247        flags.sort();
248
249        for custom in &self.custom {
250            if !flags.contains(&custom.as_str()) {
251                flags.push(custom);
252            }
253        }
254
255        if flags.is_empty() {
256            None
257        } else {
258            Some(flags.join(","))
259        }
260    }
261
262    pub fn is_empty(&self) -> bool {
263        self.features.is_empty() && self.custom.is_empty()
264    }
265
266    pub fn has(&self, feature: BetaFeature) -> bool {
267        self.features.contains(&feature)
268    }
269}
270
271#[derive(Clone, Debug)]
272pub struct ProviderConfig {
273    pub models: ModelConfig,
274    pub max_tokens: u32,
275    pub thinking_budget: Option<u32>,
276    pub enable_caching: bool,
277    pub api_version: String,
278    pub beta: BetaConfig,
279    pub extra_headers: HashMap<String, String>,
280}
281
282impl ProviderConfig {
283    pub fn new(models: ModelConfig) -> Self {
284        Self {
285            models,
286            max_tokens: DEFAULT_MAX_TOKENS,
287            thinking_budget: None,
288            enable_caching: !env::var("DISABLE_PROMPT_CACHING")
289                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
290                .unwrap_or(false),
291            api_version: "2023-06-01".into(),
292            beta: BetaConfig::from_env(),
293            extra_headers: HashMap::new(),
294        }
295    }
296
297    pub fn max_tokens(mut self, tokens: u32) -> Self {
298        self.max_tokens = tokens;
299        if tokens > DEFAULT_MAX_TOKENS {
300            self.beta.add(BetaFeature::MaxTokens128k);
301        }
302        self
303    }
304
305    pub fn thinking(mut self, budget: u32) -> Self {
306        self.thinking_budget = Some(budget.max(MIN_THINKING_BUDGET));
307        self.beta.add(BetaFeature::InterleavedThinking);
308        self
309    }
310
311    pub fn disable_caching(mut self) -> Self {
312        self.enable_caching = false;
313        self
314    }
315
316    pub fn api_version(mut self, version: impl Into<String>) -> Self {
317        self.api_version = version.into();
318        self
319    }
320
321    pub fn beta(mut self, feature: BetaFeature) -> Self {
322        self.beta.add(feature);
323        self
324    }
325
326    pub fn beta_config(mut self, config: BetaConfig) -> Self {
327        self.beta = config;
328        self
329    }
330
331    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
332        self.extra_headers.insert(key.into(), value.into());
333        self
334    }
335
336    pub fn requires_128k_beta(&self) -> bool {
337        self.max_tokens > DEFAULT_MAX_TOKENS
338    }
339}
340
341impl Default for ProviderConfig {
342    fn default() -> Self {
343        Self::new(ModelConfig::default())
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider config with an empty beta set, so beta assertions do not
    /// depend on `ANTHROPIC_BETA_FLAGS` in the test environment.
    fn provider_without_env_betas() -> ProviderConfig {
        ProviderConfig::new(ModelConfig::new(DEFAULT_MODEL, DEFAULT_SMALL_MODEL))
            .beta_config(BetaConfig::new())
    }

    #[test]
    fn test_model_config_get() {
        // Built explicitly rather than via `anthropic()` so `ANTHROPIC_*_MODEL`
        // overrides in the environment cannot change the outcome.
        let config = ModelConfig::new(DEFAULT_MODEL, DEFAULT_SMALL_MODEL)
            .reasoning(DEFAULT_REASONING_MODEL);
        assert!(config.get(ModelType::Primary).contains("sonnet"));
        assert!(config.get(ModelType::Small).contains("haiku"));
        assert!(config.get(ModelType::Reasoning).contains("opus"));
    }

    #[test]
    fn test_model_config_reasoning_falls_back_to_primary() {
        let config = ModelConfig::new("primary-model", "small-model");
        assert_eq!(config.get(ModelType::Reasoning), "primary-model");
        assert_eq!(config.resolve_alias("opus"), "primary-model");
    }

    #[test]
    fn test_model_config_resolve_alias() {
        let config = ModelConfig::new("p", "s").reasoning("r");
        assert_eq!(config.resolve_alias("sonnet"), "p");
        assert_eq!(config.resolve_alias("haiku"), "s");
        assert_eq!(config.resolve_alias("opus"), "r");
        assert_eq!(config.resolve_alias("claude-custom"), "claude-custom");
    }

    #[test]
    fn test_provider_config_default_max_tokens() {
        let config = ProviderConfig::default();
        assert_eq!(config.max_tokens, DEFAULT_MAX_TOKENS);
        assert!(!config.requires_128k_beta());
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new(ModelConfig::anthropic())
            .max_tokens(16384)
            .thinking(10000)
            .disable_caching();

        assert_eq!(config.max_tokens, 16384);
        assert_eq!(config.thinking_budget, Some(10000));
        assert!(!config.enable_caching);
        assert!(config.requires_128k_beta());
        assert!(config.beta.has(BetaFeature::MaxTokens128k));
        assert!(config.beta.has(BetaFeature::InterleavedThinking));
    }

    #[test]
    fn test_provider_config_auto_128k_beta() {
        let config = provider_without_env_betas().max_tokens(DEFAULT_MAX_TOKENS);
        assert!(!config.beta.has(BetaFeature::MaxTokens128k));

        let config = provider_without_env_betas().max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(config.beta.has(BetaFeature::MaxTokens128k));
    }

    #[test]
    fn test_provider_config_thinking_auto_beta() {
        let config = ProviderConfig::default().thinking(5000);
        assert!(config.beta.has(BetaFeature::InterleavedThinking));
        assert_eq!(config.thinking_budget, Some(5000));
    }

    #[test]
    fn test_provider_config_thinking_min_budget() {
        let config = ProviderConfig::default().thinking(500);
        assert_eq!(config.thinking_budget, Some(MIN_THINKING_BUDGET));
    }

    #[test]
    fn test_beta_feature_header() {
        assert_eq!(
            BetaFeature::InterleavedThinking.header_value(),
            "interleaved-thinking-2025-05-14"
        );
        assert_eq!(
            BetaFeature::MaxTokens128k.header_value(),
            "max-tokens-3-5-sonnet-2024-07-15"
        );
    }

    #[test]
    fn test_beta_feature_header_round_trip() {
        for feature in BetaFeature::all() {
            assert_eq!(BetaFeature::from_header(feature.header_value()), Some(feature));
        }
        assert_eq!(BetaFeature::from_header("not-a-real-beta"), None);
    }

    #[test]
    fn test_beta_config_with_features() {
        let config = BetaConfig::new()
            .feature(BetaFeature::InterleavedThinking)
            .feature(BetaFeature::ContextManagement);

        assert!(config.has(BetaFeature::InterleavedThinking));
        assert!(config.has(BetaFeature::ContextManagement));
        assert!(!config.has(BetaFeature::MaxTokens128k));

        let header = config.header_value().unwrap();
        assert!(header.contains("interleaved-thinking"));
        assert!(header.contains("context-management"));
    }

    #[test]
    fn test_beta_config_custom() {
        let config = BetaConfig::new()
            .feature(BetaFeature::InterleavedThinking)
            .custom("new-feature-2026-01-01");

        let header = config.header_value().unwrap();
        assert!(header.contains("interleaved-thinking"));
        assert!(header.contains("new-feature-2026-01-01"));
    }

    #[test]
    fn test_beta_config_custom_deduplicated_in_header() {
        let config = BetaConfig::new().custom("x-flag").custom("x-flag");
        assert_eq!(config.header_value().as_deref(), Some("x-flag"));
    }

    #[test]
    fn test_beta_config_all() {
        let config = BetaConfig::all();
        assert!(BetaFeature::all().all(|feature| config.has(feature)));
    }

    #[test]
    fn test_provider_config_beta() {
        let config = ProviderConfig::default()
            .beta(BetaFeature::ContextManagement)
            .beta_config(
                BetaConfig::new()
                    .feature(BetaFeature::InterleavedThinking)
                    .custom("experimental-feature"),
            );

        assert!(config.beta.has(BetaFeature::InterleavedThinking));
        // `beta_config` replaces the whole set, discarding earlier additions.
        assert!(!config.beta.has(BetaFeature::ContextManagement));
        let header = config.beta.header_value().unwrap();
        assert!(header.contains("experimental-feature"));
    }

    #[test]
    fn test_beta_config_empty() {
        let config = BetaConfig::new();
        assert!(config.is_empty());
        assert!(config.header_value().is_none());
    }

    #[test]
    fn test_provider_config_extra_headers() {
        let config = ProviderConfig::default()
            .header("x-custom", "value")
            .header("x-another", "test");

        assert_eq!(config.extra_headers.get("x-custom"), Some(&"value".into()));
        assert_eq!(config.extra_headers.get("x-another"), Some(&"test".into()));
    }
}