// aster/providers/openai.rs — OpenAI provider implementation.
1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
4use super::errors::ProviderError;
5use super::formats::openai::{create_request, get_usage, response_to_message};
6use super::formats::openai_responses::{
7    create_responses_request, get_responses_usage, responses_api_to_message,
8    responses_api_to_streaming_message, ResponsesApiResponse,
9};
10use super::retry::ProviderRetry;
11use super::utils::{
12    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
13    ImageFormat,
14};
15use crate::config::declarative_providers::DeclarativeProviderConfig;
16use crate::conversation::message::Message;
17use anyhow::Result;
18use async_stream::try_stream;
19use async_trait::async_trait;
20use futures::{StreamExt, TryStreamExt};
21use reqwest::StatusCode;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::io;
25use tokio::pin;
26use tokio_util::codec::{FramedRead, LinesCodec};
27use tokio_util::io::StreamReader;
28
29use crate::model::ModelConfig;
30use crate::providers::base::MessageStream;
31use crate::providers::utils::RequestLog;
32use rmcp::model::Tool;
33
/// Model used when the caller does not specify one.
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
/// Cheaper/faster model used for lightweight auxiliary tasks.
pub const OPEN_AI_DEFAULT_FAST_MODEL: &str = "gpt-4o-mini";
/// Known OpenAI model names paired with their context-window size
/// (in tokens, per OpenAI's published model limits).
pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[
    ("gpt-4o", 128_000),
    ("gpt-4o-mini", 128_000),
    ("gpt-4.1", 128_000),
    ("gpt-4.1-mini", 128_000),
    ("o1", 200_000),
    ("o3", 200_000),
    ("gpt-3.5-turbo", 16_385),
    ("gpt-4-turbo", 128_000),
    ("o4-mini", 128_000),
    ("gpt-5.1-codex", 400_000),
    ("gpt-5-codex", 400_000),
];

/// Link to OpenAI's model documentation, surfaced in provider metadata.
pub const OPEN_AI_DOC_URL: &str = "https://platform.openai.com/docs/models";
51
/// Provider for OpenAI and OpenAI-compatible chat-completions APIs.
///
/// Built either from global configuration/env ([`OpenAiProvider::from_env`])
/// or from a declarative provider config ([`OpenAiProvider::from_custom_config`]).
#[derive(Debug, serde::Serialize)]
pub struct OpenAiProvider {
    /// Configured HTTP client (host, auth, headers); excluded from serialization.
    #[serde(skip)]
    api_client: ApiClient,
    /// Request path appended to the host, e.g. `v1/chat/completions`.
    base_path: String,
    /// Optional `OpenAI-Organization` header value.
    organization: Option<String>,
    /// Optional `OpenAI-Project` header value.
    project: Option<String>,
    /// Model configuration used for requests.
    model: ModelConfig,
    /// Extra headers attached to every request, if configured.
    custom_headers: Option<HashMap<String, String>>,
    /// Whether this endpoint supports server-sent streaming responses.
    supports_streaming: bool,
    /// Display name of this provider instance (may differ for custom providers).
    name: String,
}
64
impl OpenAiProvider {
    /// Builds a provider from global configuration and secrets.
    ///
    /// Reads the required `OPENAI_API_KEY` secret plus optional host,
    /// base path, organization, project, custom headers and timeout
    /// parameters, then constructs the HTTP client with bearer-token auth.
    ///
    /// # Errors
    /// Fails if the API key secret is missing, a custom header has an
    /// invalid name/value, or the HTTP client cannot be constructed.
    pub async fn from_env(model: ModelConfig) -> Result<Self> {
        // Always attach a fast-model fallback to the model config.
        let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string());

        let config = crate::config::Config::global();
        let secrets = config.get_secrets("OPENAI_API_KEY", &["OPENAI_CUSTOM_HEADERS"])?;
        // `get_secrets` returned Ok above, so the required key is present;
        // the unwrap cannot fire unless that contract is broken.
        let api_key = secrets.get("OPENAI_API_KEY").unwrap().clone();
        let host: String = config
            .get_param("OPENAI_HOST")
            .unwrap_or_else(|_| "https://api.openai.com".to_string());
        let base_path: String = config
            .get_param("OPENAI_BASE_PATH")
            .unwrap_or_else(|_| "v1/chat/completions".to_string());
        let organization: Option<String> = config.get_param("OPENAI_ORGANIZATION").ok();
        let project: Option<String> = config.get_param("OPENAI_PROJECT").ok();
        // Custom headers arrive as a single "k=v,k=v" secret string.
        let custom_headers: Option<HashMap<String, String>> = secrets
            .get("OPENAI_CUSTOM_HEADERS")
            .cloned()
            .map(parse_custom_headers);
        let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);

        let auth = AuthMethod::BearerToken(api_key);
        let mut api_client =
            ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;

        if let Some(org) = &organization {
            api_client = api_client.with_header("OpenAI-Organization", org)?;
        }

        if let Some(project) = &project {
            api_client = api_client.with_header("OpenAI-Project", project)?;
        }

        // Validate and install any user-supplied headers; invalid header
        // names/values surface as errors rather than being dropped silently.
        if let Some(headers) = &custom_headers {
            let mut header_map = reqwest::header::HeaderMap::new();
            for (key, value) in headers {
                let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
                let header_value = reqwest::header::HeaderValue::from_str(value)?;
                header_map.insert(header_name, header_value);
            }
            api_client = api_client.with_headers(header_map)?;
        }

        Ok(Self {
            api_client,
            base_path,
            organization,
            project,
            model,
            custom_headers,
            supports_streaming: true,
            name: Self::metadata().name,
        })
    }

    /// Test/internal constructor with default base path and no extra headers.
    #[doc(hidden)]
    pub fn new(api_client: ApiClient, model: ModelConfig) -> Self {
        Self {
            api_client,
            base_path: "v1/chat/completions".to_string(),
            organization: None,
            project: None,
            model,
            custom_headers: None,
            supports_streaming: true,
            name: Self::metadata().name,
        }
    }

    /// Builds a provider from a declarative (user-defined) provider config,
    /// splitting `base_url` into host and request path.
    ///
    /// # Errors
    /// Fails if the API key named by `api_key_env` is missing, `base_url`
    /// does not parse, a header is invalid, or the client cannot be built.
    pub fn from_custom_config(
        model: ModelConfig,
        config: DeclarativeProviderConfig,
    ) -> Result<Self> {
        let global_config = crate::config::Config::global();
        let api_key: String = global_config
            .get_secret(&config.api_key_env)
            .map_err(|_e| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?;

        let url = url::Url::parse(&config.base_url)
            .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?;

        // Rebuild "scheme://host[:port]"; a host-less URL (e.g. file:) would
        // yield "scheme://" here — presumably base_url always has a host.
        let host = if let Some(port) = url.port() {
            format!(
                "{}://{}:{}",
                url.scheme(),
                url.host_str().unwrap_or(""),
                port
            )
        } else {
            format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""))
        };
        // The URL path (sans leading slash) becomes the request path,
        // defaulting to the standard chat-completions endpoint.
        let base_path = url.path().trim_start_matches('/').to_string();
        let base_path = if base_path.is_empty() {
            "v1/chat/completions".to_string()
        } else {
            base_path
        };

        let timeout_secs = config.timeout_seconds.unwrap_or(600);
        let auth = AuthMethod::BearerToken(api_key);
        let mut api_client =
            ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;

        // Add custom headers if present
        if let Some(headers) = &config.headers {
            let mut header_map = reqwest::header::HeaderMap::new();
            for (key, value) in headers {
                let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
                let header_value = reqwest::header::HeaderValue::from_str(value)?;
                header_map.insert(header_name, header_value);
            }
            api_client = api_client.with_headers(header_map)?;
        }

        Ok(Self {
            api_client,
            base_path,
            organization: None,
            project: None,
            model,
            custom_headers: config.headers,
            supports_streaming: config.supports_streaming.unwrap_or(true),
            name: config.name.clone(),
        })
    }

    /// True for models served through OpenAI's Responses API
    /// (`v1/responses`) instead of chat completions.
    fn uses_responses_api(model_name: &str) -> bool {
        model_name.starts_with("gpt-5-codex") || model_name.starts_with("gpt-5.1-codex")
    }

    /// POSTs `payload` to the configured chat-completions path and
    /// normalizes the response via the OpenAI-compat handler.
    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
        let response = self
            .api_client
            .response_post(&self.base_path, payload)
            .await?;
        handle_response_openai_compat(response).await
    }

    /// POSTs `payload` to the fixed `v1/responses` endpoint
    /// (Responses API models bypass `base_path`).
    async fn post_responses(&self, payload: &Value) -> Result<Value, ProviderError> {
        let response = self
            .api_client
            .response_post("v1/responses", payload)
            .await?;
        handle_response_openai_compat(response).await
    }
}
211
#[async_trait]
impl Provider for OpenAiProvider {
    /// Static metadata: display name, known models, and configurable keys.
    fn metadata() -> ProviderMetadata {
        let models = OPEN_AI_KNOWN_MODELS
            .iter()
            .map(|(name, limit)| ModelInfo::new(*name, *limit))
            .collect();
        ProviderMetadata::with_models(
            "openai",
            "OpenAI",
            "GPT-4 and other OpenAI models, including OpenAI compatible ones",
            OPEN_AI_DEFAULT_MODEL,
            models,
            OPEN_AI_DOC_URL,
            vec![
                ConfigKey::new("OPENAI_API_KEY", true, true, None),
                ConfigKey::new("OPENAI_HOST", true, false, Some("https://api.openai.com")),
                ConfigKey::new("OPENAI_BASE_PATH", true, false, Some("v1/chat/completions")),
                ConfigKey::new("OPENAI_ORGANIZATION", false, false, None),
                ConfigKey::new("OPENAI_PROJECT", false, false, None),
                ConfigKey::new("OPENAI_CUSTOM_HEADERS", false, true, None),
                ConfigKey::new("OPENAI_TIMEOUT", false, false, Some("600")),
            ],
        )
    }

    /// Instance name (differs from the metadata name for custom providers).
    fn get_name(&self) -> &str {
        &self.name
    }

    /// A clone of this provider's model configuration.
    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Runs a single non-streaming completion with `model_config`.
    ///
    /// Routes to either the Responses API (codex models) or the standard
    /// chat-completions endpoint, retrying transient failures and logging
    /// the request/response pair.
    ///
    /// # Errors
    /// Propagates request construction, transport, and parse failures
    /// as [`ProviderError`].
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        if Self::uses_responses_api(&model_config.model_name) {
            // Responses API path (gpt-5*-codex models).
            let payload = create_responses_request(model_config, system, messages, tools)?;
            let mut log = RequestLog::start(&self.model, &payload)?;

            let json_response = self
                .with_retry(|| async {
                    // Clone per attempt: the closure may run multiple times.
                    let payload_clone = payload.clone();
                    self.post_responses(&payload_clone).await
                })
                .await
                .inspect_err(|e| {
                    // Best-effort: logging failures must not mask the error.
                    let _ = log.error(e);
                })?;

            let responses_api_response: ResponsesApiResponse =
                serde_json::from_value(json_response.clone()).map_err(|e| {
                    ProviderError::ExecutionError(format!(
                        "Failed to parse responses API response: {}",
                        e
                    ))
                })?;

            let message = responses_api_to_message(&responses_api_response)?;
            let usage = get_responses_usage(&responses_api_response);
            let model = responses_api_response.model.clone();

            log.write(&json_response, Some(&usage))?;
            Ok((message, ProviderUsage::new(model, usage)))
        } else {
            // Standard chat-completions path (stream flag off).
            let payload = create_request(
                model_config,
                system,
                messages,
                tools,
                &ImageFormat::OpenAi,
                false,
            )?;

            let mut log = RequestLog::start(&self.model, &payload)?;
            let json_response = self
                .with_retry(|| async {
                    let payload_clone = payload.clone();
                    self.post(&payload_clone).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            let message = response_to_message(&json_response)?;
            // Missing usage data is tolerated; fall back to zeros.
            let usage = json_response
                .get("usage")
                .map(get_usage)
                .unwrap_or_else(|| {
                    tracing::debug!("Failed to get usage data");
                    Usage::default()
                });

            let model = get_model(&json_response);
            log.write(&json_response, Some(&usage))?;
            Ok((message, ProviderUsage::new(model, usage)))
        }
    }

    /// Fetches the model list from the endpoint's `v1/models` route.
    ///
    /// NOTE(review): derives the models path by substring replacement on
    /// `base_path`; a custom base path without "v1/chat/completions" will
    /// hit the original path unchanged — confirm intended for custom hosts.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let models_path = self.base_path.replace("v1/chat/completions", "v1/models");
        let response = self.api_client.response_get(&models_path).await?;
        let json = handle_response_openai_compat(response).await?;
        // Some compatible servers return 200 with an embedded error object.
        if let Some(err_obj) = json.get("error") {
            let msg = err_obj
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown error");
            return Err(ProviderError::Authentication(msg.to_string()));
        }

        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
            ProviderError::UsageError("Missing data field in JSON response".into())
        })?;
        // Keep only entries with a string "id"; return them sorted.
        let mut models: Vec<String> = data
            .iter()
            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
            .collect();
        models.sort();
        Ok(Some(models))
    }

    /// OpenAI endpoints support embeddings.
    fn supports_embeddings(&self) -> bool {
        true
    }

    /// Delegates to the [`EmbeddingCapable`] impl, converting its
    /// `anyhow` error into a [`ProviderError::ExecutionError`].
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        EmbeddingCapable::create_embeddings(self, texts)
            .await
            .map_err(|e| ProviderError::ExecutionError(e.to_string()))
    }

    /// Whether streaming was enabled for this instance (configurable
    /// for declarative providers; always true for the built-in one).
    fn supports_streaming(&self) -> bool {
        self.supports_streaming
    }

    /// Streams a completion as a [`MessageStream`].
    ///
    /// Responses-API models get a hand-rolled SSE line decoder; all other
    /// models use the shared OpenAI-compat streaming helper.
    ///
    /// # Errors
    /// Request/transport failures surface immediately; decode failures
    /// are yielded through the stream as [`ProviderError::RequestFailed`].
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        if Self::uses_responses_api(&self.model.model_name) {
            let mut payload = create_responses_request(&self.model, system, messages, tools)?;
            // Responses API expects an explicit stream flag in the body.
            payload["stream"] = serde_json::Value::Bool(true);

            let mut log = RequestLog::start(&self.model, &payload)?;

            let response = self
                .with_retry(|| async {
                    let payload_clone = payload.clone();
                    let resp = self
                        .api_client
                        .response_post("v1/responses", &payload_clone)
                        .await?;
                    handle_status_openai_compat(resp).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            // Adapt the byte stream into an AsyncRead, then split on lines
            // so each SSE line can be fed to the Responses-API decoder.
            let stream = response.bytes_stream().map_err(io::Error::other);

            Ok(Box::pin(try_stream! {
                let stream_reader = StreamReader::new(stream);
                let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);

                let message_stream = responses_api_to_streaming_message(framed);
                pin!(message_stream);
                while let Some(message) = message_stream.next().await {
                    let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
                    log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
                    yield (message, usage);
                }
            }))
        } else {
            // Chat-completions path with the stream flag set.
            let payload = create_request(
                &self.model,
                system,
                messages,
                tools,
                &ImageFormat::OpenAi,
                true,
            )?;
            let mut log = RequestLog::start(&self.model, &payload)?;

            let response = self
                .with_retry(|| async {
                    let resp = self
                        .api_client
                        .response_post(&self.base_path, &payload)
                        .await?;
                    handle_status_openai_compat(resp).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            stream_openai_compat(response, log)
        }
    }
}
426
/// Parses a comma-separated `key=value` list (e.g. `"A=1,B=2"`) into a
/// header map. Keys and values are trimmed; entries without `=` are
/// ignored. Only the first `=` splits, so values may contain `=`.
fn parse_custom_headers(s: String) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    for entry in s.split(',') {
        if let Some((key, value)) = entry.split_once('=') {
            headers.insert(key.trim().to_string(), value.trim().to_string());
        }
    }
    headers
}
437
#[async_trait]
impl EmbeddingCapable for OpenAiProvider {
    /// Creates embeddings for `texts` via the `v1/embeddings` endpoint.
    ///
    /// The embedding model defaults to `text-embedding-3-small` and can be
    /// overridden with the `ASTER_EMBEDDING_MODEL` environment variable.
    /// Returns one embedding vector per input text, in order.
    ///
    /// # Errors
    /// Fails if the request errors after retries, the server returns a
    /// non-200 status, or the response body is empty/unparseable.
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip the network round-trip entirely.
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let embedding_model = std::env::var("ASTER_EMBEDDING_MODEL")
            .unwrap_or_else(|_| "text-embedding-3-small".to_string());

        let request = EmbeddingRequest {
            input: texts,
            model: embedding_model,
        };

        let response = self
            .with_retry(|| async {
                // Rebuild the request per attempt; the retry closure may run
                // more than once and must not move `request` out.
                let request_clone = EmbeddingRequest {
                    input: request.input.clone(),
                    model: request.model.clone(),
                };
                let request_value = serde_json::to_value(request_clone)
                    .map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
                self.api_client
                    .api_post("v1/embeddings", &request_value)
                    .await
                    .map_err(|e| ProviderError::ExecutionError(e.to_string()))
            })
            .await?;

        // Non-200: surface the body text (when it is a plain string) as the error.
        if response.status != StatusCode::OK {
            let error_text = response
                .payload
                .as_ref()
                .and_then(|p| p.as_str())
                .unwrap_or("Unknown error");
            return Err(anyhow::anyhow!("Embedding API error: {}", error_text));
        }

        let embedding_response: EmbeddingResponse = serde_json::from_value(
            response
                .payload
                .ok_or_else(|| anyhow::anyhow!("Empty response body"))?,
        )?;

        // Flatten the API's data entries down to the raw vectors.
        Ok(embedding_response
            .data
            .into_iter()
            .map(|d| d.embedding)
            .collect())
    }
}
489}