Skip to main content

a3s_code_core/config/
provider.rs

1use crate::llm::LlmConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
// ============================================================================
// Provider Configuration
// ============================================================================
8
9/// Model cost information (per million tokens)
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11#[serde(rename_all = "camelCase")]
12pub struct ModelCost {
13    /// Input token cost
14    #[serde(default)]
15    pub input: f64,
16    /// Output token cost
17    #[serde(default)]
18    pub output: f64,
19    /// Cache read cost
20    #[serde(default)]
21    pub cache_read: f64,
22    /// Cache write cost
23    #[serde(default)]
24    pub cache_write: f64,
25}
26
27/// Model limits
28#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct ModelLimit {
30    /// Maximum context tokens
31    #[serde(default)]
32    pub context: u32,
33    /// Maximum output tokens
34    #[serde(default)]
35    pub output: u32,
36}
37
38/// Model modalities (input/output types)
39#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct ModelModalities {
41    /// Supported input types
42    #[serde(default)]
43    pub input: Vec<String>,
44    /// Supported output types
45    #[serde(default)]
46    pub output: Vec<String>,
47}
48
49/// Model configuration
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct ModelConfig {
53    /// Model ID (e.g., "claude-sonnet-4-20250514")
54    pub id: String,
55    /// Display name
56    #[serde(default)]
57    pub name: String,
58    /// Model family (e.g., "claude-sonnet")
59    #[serde(default)]
60    pub family: String,
61    /// Per-model API key override
62    #[serde(default)]
63    pub api_key: Option<String>,
64    /// Per-model base URL override
65    #[serde(default)]
66    pub base_url: Option<String>,
67    /// Static HTTP headers for this model
68    #[serde(default)]
69    pub headers: HashMap<String, String>,
70    /// Header name to receive the runtime session ID
71    #[serde(default)]
72    pub session_id_header: Option<String>,
73    /// Supports file attachments
74    #[serde(default)]
75    pub attachment: bool,
76    /// Supports reasoning/thinking
77    #[serde(default)]
78    pub reasoning: bool,
79    /// Supports tool calling
80    #[serde(default = "default_true")]
81    pub tool_call: bool,
82    /// Supports temperature setting
83    #[serde(default = "default_true")]
84    pub temperature: bool,
85    /// Release date
86    #[serde(default)]
87    pub release_date: Option<String>,
88    /// Input/output modalities
89    #[serde(default)]
90    pub modalities: ModelModalities,
91    /// Cost information
92    #[serde(default)]
93    pub cost: ModelCost,
94    /// Token limits
95    #[serde(default)]
96    pub limit: ModelLimit,
97}
98
/// Serde default helper for boolean fields that must default to `true`.
///
/// Referenced by path from `#[serde(default = "default_true")]`, since the
/// plain `#[serde(default)]` would yield `false` for a `bool`.
pub(crate) fn default_true() -> bool {
    true
}
102
103/// Provider configuration
104#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(rename_all = "camelCase")]
106pub struct ProviderConfig {
107    /// Provider name (e.g., "anthropic", "openai")
108    pub name: String,
109    /// API key for this provider
110    #[serde(default)]
111    pub api_key: Option<String>,
112    /// Base URL for the API
113    #[serde(default)]
114    pub base_url: Option<String>,
115    /// Static HTTP headers for this provider
116    #[serde(default)]
117    pub headers: HashMap<String, String>,
118    /// Header name to receive the runtime session ID
119    #[serde(default)]
120    pub session_id_header: Option<String>,
121    /// Available models
122    #[serde(default)]
123    pub models: Vec<ModelConfig>,
124}
125
126/// Apply model capability flags to an LlmConfig.
127///
128/// - `temperature = false` → omit temperature (model ignores it, e.g. o1)
129/// - `reasoning = true` + `thinking_budget` set → pass budget to client
130/// - `limit.output > 0` → use as max_tokens
131pub(crate) fn apply_model_caps(
132    mut config: LlmConfig,
133    model: &ModelConfig,
134    thinking_budget: Option<usize>,
135) -> LlmConfig {
136    // reasoning=true + thinking_budget set → pass budget to client (Anthropic only)
137    if model.reasoning {
138        if let Some(budget) = thinking_budget {
139            config = config.with_thinking_budget(budget);
140        }
141    }
142
143    // limit.output > 0 → use as max_tokens cap
144    if model.limit.output > 0 {
145        config = config.with_max_tokens(model.limit.output as usize);
146    }
147
148    // temperature=false models (e.g. o1) must not receive a temperature param.
149    // Store the flag so the LLM client can gate it at call time.
150    if !model.temperature {
151        config.disable_temperature = true;
152    }
153
154    config
155}
156
157impl ProviderConfig {
158    /// Find a model by ID
159    pub fn find_model(&self, model_id: &str) -> Option<&ModelConfig> {
160        self.models.iter().find(|m| m.id == model_id)
161    }
162
163    /// Get the effective API key for a model (model override or provider default)
164    pub fn get_api_key<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
165        model.api_key.as_deref().or(self.api_key.as_deref())
166    }
167
168    /// Get the effective base URL for a model (model override or provider default)
169    pub fn get_base_url<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
170        model.base_url.as_deref().or(self.base_url.as_deref())
171    }
172
173    /// Get the effective static headers for a model (provider defaults with model overrides)
174    pub fn get_headers(&self, model: &ModelConfig) -> HashMap<String, String> {
175        let mut headers = self.headers.clone();
176        headers.extend(model.headers.clone());
177        headers
178    }
179
180    /// Get the header name that should carry the runtime session ID.
181    pub fn get_session_id_header<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
182        model
183            .session_id_header
184            .as_deref()
185            .or(self.session_id_header.as_deref())
186    }
187}