agcodex_core/
model_provider_info.rs

//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//!   1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//!   2. User-defined entries inside `~/.agcodex/config.toml` under the `model_providers`
//!      key. These override or extend the defaults at runtime.
7
8use agcodex_login::AuthMode;
9use agcodex_login::CodexAuth;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::HashMap;
13use std::env::VarError;
14use std::time::Duration;
15
16use crate::error::EnvVarError;
/// Default idle timeout for a streaming response, in milliseconds (5 minutes).
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
/// Default number of attempts to reconnect a dropped streaming response.
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
/// Default number of retries for a failed (non-streaming) HTTP request.
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
20
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
///
/// Serialized in lowercase (`"responses"` / `"chat"`) in config files.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,

    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    /// This is the default for user-defined providers.
    #[default]
    Chat,
}
37
/// Serializable representation of a provider definition.
///
/// Deserialized from the `model_providers` table in `config.toml` and also
/// constructed in code for the built-in defaults (see
/// `built_in_model_providers`). All optional fields fall back to sensible
/// defaults at the point of use.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
    /// Friendly display name.
    pub name: String,
    /// Base URL for the provider's OpenAI-compatible API.
    /// When `None`, a default is chosen based on the auth mode at request time.
    pub base_url: Option<String>,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: Option<String>,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it. Surfaced in the error when `env_key` is missing.
    pub env_key_instructions: Option<String>,

    /// Which wire protocol this provider expects. Defaults to `WireApi::Chat`.
    #[serde(default)]
    pub wire_api: WireApi,

    /// Optional query parameters to append to the base URL.
    pub query_params: Option<HashMap<String, String>>,

    /// Additional HTTP headers to include in requests to this provider where
    /// the (key, value) pairs are the header name and value.
    pub http_headers: Option<HashMap<String, String>>,

    /// Optional HTTP headers to include in requests to this provider where the
    /// (key, value) pairs are the header name and _environment variable_ whose
    /// value should be used. If the environment variable is not set, or the
    /// value is empty, the header will not be included in the request.
    pub env_http_headers: Option<HashMap<String, String>>,

    /// Maximum number of times to retry a failed HTTP request to this provider.
    /// Falls back to `DEFAULT_REQUEST_MAX_RETRIES` when unset.
    pub request_max_retries: Option<u64>,

    /// Number of times to retry reconnecting a dropped streaming response before failing.
    /// Falls back to `DEFAULT_STREAM_MAX_RETRIES` when unset.
    pub stream_max_retries: Option<u64>,

    /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
    /// the connection as lost. Falls back to `DEFAULT_STREAM_IDLE_TIMEOUT_MS` when unset.
    pub stream_idle_timeout_ms: Option<u64>,

    /// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
    #[serde(default)]
    pub requires_openai_auth: bool,
}
83
84impl ModelProviderInfo {
85    /// Construct a `POST` RequestBuilder for the given URL using the provided
86    /// reqwest Client applying:
87    ///   • provider-specific headers (static + env based)
88    ///   • Bearer auth header when an API key is available.
89    ///   • Auth token for OAuth.
90    ///
91    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
92    /// one produced by [`ModelProviderInfo::api_key`].
93    pub async fn create_request_builder<'a>(
94        &'a self,
95        client: &'a reqwest::Client,
96        auth: &Option<CodexAuth>,
97    ) -> crate::error::Result<reqwest::RequestBuilder> {
98        let effective_auth = match self.api_key() {
99            Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
100            Ok(None) => auth.clone(),
101            Err(err) => {
102                if auth.is_some() {
103                    auth.clone()
104                } else {
105                    return Err(err);
106                }
107            }
108        };
109
110        let url = self.get_full_url(&effective_auth);
111
112        let mut builder = client.post(url);
113
114        if let Some(auth) = effective_auth.as_ref() {
115            builder = builder.bearer_auth(auth.get_token().await?);
116        }
117
118        Ok(self.apply_http_headers(builder))
119    }
120
121    fn get_query_string(&self) -> String {
122        self.query_params
123            .as_ref()
124            .map_or_else(String::new, |params| {
125                let full_params = params
126                    .iter()
127                    .map(|(k, v)| format!("{k}={v}"))
128                    .collect::<Vec<_>>()
129                    .join("&");
130                format!("?{full_params}")
131            })
132    }
133
134    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
135        let default_base_url = if matches!(
136            auth,
137            Some(CodexAuth {
138                mode: AuthMode::ChatGPT,
139                ..
140            })
141        ) {
142            "https://chatgpt.com/backend-api/codex"
143        } else {
144            "https://api.openai.com/v1"
145        };
146        let query_string = self.get_query_string();
147        let base_url = self
148            .base_url
149            .clone()
150            .unwrap_or(default_base_url.to_string());
151
152        match self.wire_api {
153            WireApi::Responses => format!("{base_url}/responses{query_string}"),
154            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
155        }
156    }
157
158    /// Apply provider-specific HTTP headers (both static and environment-based)
159    /// onto an existing `reqwest::RequestBuilder` and return the updated
160    /// builder.
161    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
162        if let Some(extra) = &self.http_headers {
163            for (k, v) in extra {
164                builder = builder.header(k, v);
165            }
166        }
167
168        if let Some(env_headers) = &self.env_http_headers {
169            for (header, env_var) in env_headers {
170                if let Ok(val) = std::env::var(env_var)
171                    && !val.trim().is_empty()
172                {
173                    builder = builder.header(header, val);
174                }
175            }
176        }
177        builder
178    }
179
180    /// If `env_key` is Some, returns the API key for this provider if present
181    /// (and non-empty) in the environment. If `env_key` is required but
182    /// cannot be found, returns an error.
183    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
184        match &self.env_key {
185            Some(env_key) => {
186                let env_value = std::env::var(env_key);
187                env_value
188                    .and_then(|v| {
189                        if v.trim().is_empty() {
190                            Err(VarError::NotPresent)
191                        } else {
192                            Ok(Some(v))
193                        }
194                    })
195                    .map_err(|_| {
196                        crate::error::CodexErr::EnvVar(EnvVarError {
197                            var: env_key.clone(),
198                            instructions: self.env_key_instructions.clone(),
199                        })
200                    })
201            }
202            None => Ok(None),
203        }
204    }
205
206    /// Effective maximum number of request retries for this provider.
207    pub fn request_max_retries(&self) -> u64 {
208        self.request_max_retries
209            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
210    }
211
212    /// Effective maximum number of stream reconnection attempts for this provider.
213    pub fn stream_max_retries(&self) -> u64 {
214        self.stream_max_retries
215            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
216    }
217
218    /// Effective idle timeout for streaming responses.
219    pub fn stream_idle_timeout(&self) -> Duration {
220        self.stream_idle_timeout_ms
221            .map(Duration::from_millis)
222            .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
223    }
224}
225
/// Default port of a local Ollama server, used when `CODEX_OSS_PORT` is unset.
const DEFAULT_OLLAMA_PORT: u32 = 11434;

/// Key under which the built-in open source ("oss") provider is registered.
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
229
230/// Built-in default provider list.
231pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
232    use ModelProviderInfo as P;
233
234    // We do not want to be in the business of adjucating which third-party
235    // providers are bundled with Codex CLI, so we only include the OpenAI and
236    // open source ("oss") providers by default. Users are encouraged to add to
237    // `model_providers` in config.toml to add their own providers.
238    [
239        (
240            "openai",
241            P {
242                name: "OpenAI".into(),
243                // Allow users to override the default OpenAI endpoint by
244                // exporting `OPENAI_BASE_URL`. This is useful when pointing
245                // Codex at a proxy, mock server, or Azure-style deployment
246                // without requiring a full TOML override for the built-in
247                // OpenAI provider.
248                base_url: std::env::var("OPENAI_BASE_URL")
249                    .ok()
250                    .filter(|v| !v.trim().is_empty()),
251                env_key: None,
252                env_key_instructions: None,
253                wire_api: WireApi::Responses,
254                query_params: None,
255                http_headers: Some(
256                    [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
257                        .into_iter()
258                        .collect(),
259                ),
260                env_http_headers: Some(
261                    [
262                        (
263                            "OpenAI-Organization".to_string(),
264                            "OPENAI_ORGANIZATION".to_string(),
265                        ),
266                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
267                    ]
268                    .into_iter()
269                    .collect(),
270                ),
271                // Use global defaults for retry/timeout unless overridden in config.toml.
272                request_max_retries: None,
273                stream_max_retries: None,
274                stream_idle_timeout_ms: None,
275                requires_openai_auth: true,
276            },
277        ),
278        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
279    ]
280    .into_iter()
281    .map(|(k, v)| (k.to_string(), v))
282    .collect()
283}
284
285pub fn create_oss_provider() -> ModelProviderInfo {
286    // These CODEX_OSS_ environment variables are experimental: we may
287    // switch to reading values from config.toml instead.
288    let agcodex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
289        .ok()
290        .filter(|v| !v.trim().is_empty())
291    {
292        Some(url) => url,
293        None => format!(
294            "http://localhost:{port}/v1",
295            port = std::env::var("CODEX_OSS_PORT")
296                .ok()
297                .filter(|v| !v.trim().is_empty())
298                .and_then(|v| v.parse::<u32>().ok())
299                .unwrap_or(DEFAULT_OLLAMA_PORT)
300        ),
301    };
302
303    create_oss_provider_with_base_url(&agcodex_oss_base_url)
304}
305
306pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
307    ModelProviderInfo {
308        name: "gpt-oss".into(),
309        base_url: Some(base_url.into()),
310        env_key: None,
311        env_key_instructions: None,
312        wire_api: WireApi::Chat,
313        query_params: None,
314        http_headers: None,
315        env_http_headers: None,
316        request_max_retries: None,
317        stream_max_retries: None,
318        stream_idle_timeout_ms: None,
319        requires_openai_auth: false,
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A minimal TOML entry deserializes with all optional fields defaulted
    /// (notably `wire_api` → `Chat`).
    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
        // Renamed from the copy-pasted `azure_provider_toml` to match what it
        // actually holds.
        let ollama_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Ollama".into(),
            base_url: Some("http://localhost:11434/v1".into()),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    /// `env_key` and `query_params` round-trip through TOML deserialization.
    #[test]
    fn test_deserialize_azure_model_provider_toml() {
        let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: Some("AZURE_OPENAI_API_KEY".into()),
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    /// Static and environment-based header maps deserialize from TOML tables.
    #[test]
    fn test_deserialize_example_model_provider_toml() {
        // Renamed from the copy-pasted `azure_provider_toml` to match what it
        // actually holds.
        let example_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
        "#;
        let expected_provider = ModelProviderInfo {
            name: "Example".into(),
            base_url: Some("https://example.com".into()),
            env_key: Some("API_KEY".into()),
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: Some(maplit::hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(example_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }
}