a3s_code_core/config/
provider.rs1use crate::llm::LlmConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11#[serde(rename_all = "camelCase")]
12pub struct ModelCost {
13 #[serde(default)]
15 pub input: f64,
16 #[serde(default)]
18 pub output: f64,
19 #[serde(default)]
21 pub cache_read: f64,
22 #[serde(default)]
24 pub cache_write: f64,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct ModelLimit {
30 #[serde(default)]
32 pub context: u32,
33 #[serde(default)]
35 pub output: u32,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct ModelModalities {
41 #[serde(default)]
43 pub input: Vec<String>,
44 #[serde(default)]
46 pub output: Vec<String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct ModelConfig {
53 pub id: String,
55 #[serde(default)]
57 pub name: String,
58 #[serde(default)]
60 pub family: String,
61 #[serde(default)]
63 pub api_key: Option<String>,
64 #[serde(default)]
66 pub base_url: Option<String>,
67 #[serde(default)]
69 pub headers: HashMap<String, String>,
70 #[serde(default)]
72 pub session_id_header: Option<String>,
73 #[serde(default)]
75 pub attachment: bool,
76 #[serde(default)]
78 pub reasoning: bool,
79 #[serde(default = "default_true")]
81 pub tool_call: bool,
82 #[serde(default = "default_true")]
84 pub temperature: bool,
85 #[serde(default)]
87 pub release_date: Option<String>,
88 #[serde(default)]
90 pub modalities: ModelModalities,
91 #[serde(default)]
93 pub cost: ModelCost,
94 #[serde(default)]
96 pub limit: ModelLimit,
97}
98
/// Serde default for boolean flags that should be enabled unless the config
/// explicitly turns them off.
///
/// Used by `ModelConfig::tool_call` and `ModelConfig::temperature`.
pub(crate) fn default_true() -> bool {
    true
}

103#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(rename_all = "camelCase")]
106pub struct ProviderConfig {
107 pub name: String,
109 #[serde(default)]
111 pub api_key: Option<String>,
112 #[serde(default)]
114 pub base_url: Option<String>,
115 #[serde(default)]
117 pub headers: HashMap<String, String>,
118 #[serde(default)]
120 pub session_id_header: Option<String>,
121 #[serde(default)]
123 pub models: Vec<ModelConfig>,
124}
125
126pub(crate) fn apply_model_caps(
132 mut config: LlmConfig,
133 model: &ModelConfig,
134 thinking_budget: Option<usize>,
135) -> LlmConfig {
136 if model.reasoning {
138 if let Some(budget) = thinking_budget {
139 config = config.with_thinking_budget(budget);
140 }
141 }
142
143 if model.limit.output > 0 {
145 config = config.with_max_tokens(model.limit.output as usize);
146 }
147
148 if !model.temperature {
151 config.disable_temperature = true;
152 }
153
154 config
155}
156
157impl ProviderConfig {
158 pub fn find_model(&self, model_id: &str) -> Option<&ModelConfig> {
160 self.models.iter().find(|m| m.id == model_id)
161 }
162
163 pub fn get_api_key<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
165 model.api_key.as_deref().or(self.api_key.as_deref())
166 }
167
168 pub fn get_base_url<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
170 model.base_url.as_deref().or(self.base_url.as_deref())
171 }
172
173 pub fn get_headers(&self, model: &ModelConfig) -> HashMap<String, String> {
175 let mut headers = self.headers.clone();
176 headers.extend(model.headers.clone());
177 headers
178 }
179
180 pub fn get_session_id_header<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
182 model
183 .session_id_header
184 .as_deref()
185 .or(self.session_id_header.as_deref())
186 }
187}