use agcodex_login::AuthMode;
use agcodex_login::CodexAuth;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;

use crate::error::EnvVarError;

const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Wire protocol that a provider speaks. The two protocols use different
/// request/response shapes, so each provider entry must declare which one
/// it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The OpenAI Responses API, served under `/responses`.
    Responses,

    /// The classic Chat Completions API, served under `/chat/completions`.
    /// This is the default, and what most OpenAI-compatible providers
    /// implement.
    #[default]
    Chat,
}

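/// Serializable description of a model provider: where to send requests,
/// which wire protocol to speak, and how to authenticate.
///
/// An illustrative `config.toml` entry (the `model_providers` table name is
/// an assumption; the field names match the struct below):
///
/// ```toml
/// [model_providers.example]
/// name = "Example"
/// base_url = "https://example.com/v1"
/// env_key = "EXAMPLE_API_KEY"
/// wire_api = "chat"
/// ```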
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
    /// Friendly display name.
    pub name: String,
    /// Base URL for the provider's API. When omitted, a default is chosen
    /// from the active authentication mode (see `get_full_url`).
    pub base_url: Option<String>,
    /// Environment variable that holds the API key for this provider.
    pub env_key: Option<String>,

    /// Instructions shown to the user when `env_key` is required but the
    /// variable is unset or empty.
    pub env_key_instructions: Option<String>,

    /// Which wire protocol this provider expects.
    #[serde(default)]
    pub wire_api: WireApi,

    /// Optional query parameters appended to the request URL.
    pub query_params: Option<HashMap<String, String>>,

    /// Extra HTTP headers to send, as (header name, header value) pairs.
    pub http_headers: Option<HashMap<String, String>>,

    /// HTTP headers whose values come from the environment, as (header
    /// name, environment variable) pairs. A header is skipped when its
    /// variable is unset or blank.
    pub env_http_headers: Option<HashMap<String, String>>,

    /// Maximum number of times to retry a failed HTTP request. Falls back
    /// to `DEFAULT_REQUEST_MAX_RETRIES` when unset.
    pub request_max_retries: Option<u64>,

    /// Number of times to reconnect a dropped streaming response before
    /// failing. Falls back to `DEFAULT_STREAM_MAX_RETRIES` when unset.
    pub stream_max_retries: Option<u64>,

    /// Idle timeout, in milliseconds, for streaming responses. Falls back
    /// to `DEFAULT_STREAM_IDLE_TIMEOUT_MS` when unset.
    pub stream_idle_timeout_ms: Option<u64>,

    /// Whether this provider authenticates with OpenAI credentials (an API
    /// key or a ChatGPT login) rather than a provider-specific `env_key`.
    #[serde(default)]
    pub requires_openai_auth: bool,
}

impl ModelProviderInfo {
    /// Creates a POST `RequestBuilder` for this provider's endpoint with
    /// authentication applied. An API key resolved from `env_key` takes
    /// precedence over the supplied `auth`; if the environment lookup fails
    /// and no fallback `auth` is available, the error is propagated.
    pub async fn create_request_builder<'a>(
        &'a self,
        client: &'a reqwest::Client,
        auth: &Option<CodexAuth>,
    ) -> crate::error::Result<reqwest::RequestBuilder> {
        let effective_auth = match self.api_key() {
            Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
            Ok(None) => auth.clone(),
            Err(err) => {
                if auth.is_some() {
                    auth.clone()
                } else {
                    return Err(err);
                }
            }
        };

        let url = self.get_full_url(&effective_auth);

        let mut builder = client.post(url);

        if let Some(auth) = effective_auth.as_ref() {
            builder = builder.bearer_auth(auth.get_token().await?);
        }

        Ok(self.apply_http_headers(builder))
    }

    /// Renders `query_params` as a `?key=value&...` string, or the empty
    /// string when no parameters are configured.
    fn get_query_string(&self) -> String {
        self.query_params
            .as_ref()
            .map_or_else(String::new, |params| {
                let full_params = params
                    .iter()
                    .map(|(k, v)| format!("{k}={v}"))
                    .collect::<Vec<_>>()
                    .join("&");
                format!("?{full_params}")
            })
    }

    /// Resolves the full endpoint URL: the configured base URL (or a
    /// default chosen from the auth mode) plus the wire-API path and any
    /// query string.
    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
        let default_base_url = if matches!(
            auth,
            Some(CodexAuth {
                mode: AuthMode::ChatGPT,
                ..
            })
        ) {
            "https://chatgpt.com/backend-api/codex"
        } else {
            "https://api.openai.com/v1"
        };
        let query_string = self.get_query_string();
        let base_url = self
            .base_url
            .clone()
            .unwrap_or(default_base_url.to_string());

        match self.wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),
            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
        }
    }

    /// Applies `http_headers` and `env_http_headers` to the request,
    /// skipping env-sourced headers whose variables are unset or blank.
    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some(extra) = &self.http_headers {
            for (k, v) in extra {
                builder = builder.header(k, v);
            }
        }

        if let Some(env_headers) = &self.env_http_headers {
            for (header, env_var) in env_headers {
                if let Ok(val) = std::env::var(env_var)
                    && !val.trim().is_empty()
                {
                    builder = builder.header(header, val);
                }
            }
        }
        builder
    }

    /// Reads this provider's API key from the environment. Returns
    /// `Ok(None)` when no `env_key` is configured, and an `EnvVar` error
    /// (carrying `env_key_instructions`) when the variable is missing or
    /// blank.
    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
        match &self.env_key {
            Some(env_key) => {
                let env_value = std::env::var(env_key);
                env_value
                    .and_then(|v| {
                        if v.trim().is_empty() {
                            Err(VarError::NotPresent)
                        } else {
                            Ok(Some(v))
                        }
                    })
                    .map_err(|_| {
                        crate::error::CodexErr::EnvVar(EnvVarError {
                            var: env_key.clone(),
                            instructions: self.env_key_instructions.clone(),
                        })
                    })
            }
            None => Ok(None),
        }
    }

    /// Effective request-retry budget, falling back to the module default.
    pub fn request_max_retries(&self) -> u64 {
        self.request_max_retries
            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
    }

    /// Effective stream-reconnect budget, falling back to the module default.
    pub fn stream_max_retries(&self) -> u64 {
        self.stream_max_retries
            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
    }

    /// Effective idle timeout for streaming responses, falling back to the
    /// module default.
    pub fn stream_idle_timeout(&self) -> Duration {
        self.stream_idle_timeout_ms
            .map(Duration::from_millis)
            .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
    }
}

const DEFAULT_OLLAMA_PORT: u32 = 11434;

pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";

/// Providers that ship with the binary, keyed by provider id.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    use ModelProviderInfo as P;

    [
        (
            "openai",
            P {
                name: "OpenAI".into(),
                base_url: std::env::var("OPENAI_BASE_URL")
                    .ok()
                    .filter(|v| !v.trim().is_empty()),
                env_key: None,
                env_key_instructions: None,
                wire_api: WireApi::Responses,
                query_params: None,
                http_headers: Some(
                    [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
                        .into_iter()
                        .collect(),
                ),
                env_http_headers: Some(
                    [
                        (
                            "OpenAI-Organization".to_string(),
                            "OPENAI_ORGANIZATION".to_string(),
                        ),
                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
                    ]
                    .into_iter()
                    .collect(),
                ),
                request_max_retries: None,
                stream_max_retries: None,
                stream_idle_timeout_ms: None,
                requires_openai_auth: true,
            },
        ),
        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
    ]
    .into_iter()
    .map(|(k, v)| (k.to_string(), v))
    .collect()
}

/// Builds the OSS (Ollama-compatible) provider, honoring the
/// `CODEX_OSS_BASE_URL` and `CODEX_OSS_PORT` environment overrides.
pub fn create_oss_provider() -> ModelProviderInfo {
    let agcodex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
        .ok()
        .filter(|v| !v.trim().is_empty())
    {
        Some(url) => url,
        None => format!(
            "http://localhost:{port}/v1",
            port = std::env::var("CODEX_OSS_PORT")
                .ok()
                .filter(|v| !v.trim().is_empty())
                .and_then(|v| v.parse::<u32>().ok())
                .unwrap_or(DEFAULT_OLLAMA_PORT)
        ),
    };

    create_oss_provider_with_base_url(&agcodex_oss_base_url)
}

/// Builds the OSS provider entry pointing at `base_url`, speaking the Chat
/// Completions wire protocol with no authentication.
pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
    ModelProviderInfo {
        name: "gpt-oss".into(),
        base_url: Some(base_url.into()),
        env_key: None,
        env_key_instructions: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
        let ollama_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Ollama".into(),
            base_url: Some("http://localhost:11434/v1".into()),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

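    // Added illustrative test (not part of the original suite): `get_full_url`
    // joins the base URL (explicit, or the API default when none is set and
    // no ChatGPT auth is active) with the path for the declared wire API.
    #[test]
    fn test_get_full_url_for_each_wire_api() {
        let mut provider = create_oss_provider_with_base_url("http://localhost:11434/v1");
        assert_eq!(
            "http://localhost:11434/v1/chat/completions",
            provider.get_full_url(&None)
        );

        provider.base_url = None;
        provider.wire_api = WireApi::Responses;
        assert_eq!(
            "https://api.openai.com/v1/responses",
            provider.get_full_url(&None)
        );
    }
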
    #[test]
    fn test_deserialize_azure_model_provider_toml() {
        let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: Some("AZURE_OPENAI_API_KEY".into()),
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

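    // Added illustrative test (not part of the original suite): configured
    // query parameters are appended to the resolved endpoint URL.
    #[test]
    fn test_get_full_url_includes_query_params() {
        let provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        assert_eq!(
            "https://xxxxx.openai.azure.com/openai/chat/completions?api-version=2025-04-01-preview",
            provider.get_full_url(&None)
        );
    }
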
    #[test]
    fn test_deserialize_example_model_provider_toml() {
        let example_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Example".into(),
            base_url: Some("https://example.com".into()),
            env_key: Some("API_KEY".into()),
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: Some(maplit::hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(example_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }
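
    // Added illustrative tests (not part of the original suite), sketching
    // the fallback behavior of the accessors and the built-in registry.

    // With no overrides configured, the retry and idle-timeout accessors
    // fall back to the module-level defaults.
    #[test]
    fn test_retry_and_timeout_defaults() {
        let provider = create_oss_provider_with_base_url("http://localhost:11434/v1");
        assert_eq!(DEFAULT_REQUEST_MAX_RETRIES, provider.request_max_retries());
        assert_eq!(DEFAULT_STREAM_MAX_RETRIES, provider.stream_max_retries());
        assert_eq!(
            Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS),
            provider.stream_idle_timeout()
        );
    }

    // The built-in registry exposes the OpenAI provider plus the OSS
    // provider, which speaks the Chat Completions wire protocol.
    #[test]
    fn test_built_in_model_providers_contains_expected_ids() {
        let providers = built_in_model_providers();
        assert!(providers.contains_key("openai"));
        assert!(providers.contains_key(BUILT_IN_OSS_MODEL_PROVIDER_ID));
        assert_eq!(
            WireApi::Chat,
            providers[BUILT_IN_OSS_MODEL_PROVIDER_ID].wire_api
        );
    }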
}