1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
4use super::errors::ProviderError;
5use super::formats::openai::{create_request, get_usage, response_to_message};
6use super::formats::openai_responses::{
7 create_responses_request, get_responses_usage, responses_api_to_message,
8 responses_api_to_streaming_message, ResponsesApiResponse,
9};
10use super::retry::ProviderRetry;
11use super::utils::{
12 get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
13 ImageFormat,
14};
15use crate::config::declarative_providers::DeclarativeProviderConfig;
16use crate::conversation::message::Message;
17use anyhow::Result;
18use async_stream::try_stream;
19use async_trait::async_trait;
20use futures::{StreamExt, TryStreamExt};
21use reqwest::StatusCode;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::io;
25use tokio::pin;
26use tokio_util::codec::{FramedRead, LinesCodec};
27use tokio_util::io::StreamReader;
28
29use crate::model::ModelConfig;
30use crate::providers::base::MessageStream;
31use crate::providers::utils::RequestLog;
32use rmcp::model::Tool;
33
/// Default model used when the caller does not specify one.
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
/// Cheaper model substituted for "fast" (low-latency) requests.
pub const OPEN_AI_DEFAULT_FAST_MODEL: &str = "gpt-4o-mini";
/// Known models paired with their context-window size in tokens.
pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[
    ("gpt-4o", 128_000),
    ("gpt-4o-mini", 128_000),
    ("gpt-4.1", 128_000),
    ("gpt-4.1-mini", 128_000),
    ("o1", 200_000),
    ("o3", 200_000),
    ("gpt-3.5-turbo", 16_385),
    ("gpt-4-turbo", 128_000),
    ("o4-mini", 128_000),
    ("gpt-5.1-codex", 400_000),
    ("gpt-5-codex", 400_000),
];

/// Link to OpenAI's model documentation, surfaced via provider metadata.
pub const OPEN_AI_DOC_URL: &str = "https://platform.openai.com/docs/models";
51
/// Provider for OpenAI and OpenAI-compatible chat/embedding endpoints.
#[derive(Debug, serde::Serialize)]
pub struct OpenAiProvider {
    // Live HTTP client state; excluded from serialization.
    #[serde(skip)]
    api_client: ApiClient,
    // Request path for chat completions (e.g. "v1/chat/completions").
    base_path: String,
    // Optional "OpenAI-Organization" header value.
    organization: Option<String>,
    // Optional "OpenAI-Project" header value.
    project: Option<String>,
    // Model configuration used for requests.
    model: ModelConfig,
    // Extra headers applied to every request, if configured.
    custom_headers: Option<HashMap<String, String>>,
    // Whether this endpoint supports streamed responses.
    supports_streaming: bool,
    // Display name: the builtin metadata name, or the declarative config's
    // name for custom endpoints.
    name: String,
}
64
65impl OpenAiProvider {
66 pub async fn from_env(model: ModelConfig) -> Result<Self> {
67 let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string());
68
69 let config = crate::config::Config::global();
70 let secrets = config.get_secrets("OPENAI_API_KEY", &["OPENAI_CUSTOM_HEADERS"])?;
71 let api_key = secrets.get("OPENAI_API_KEY").unwrap().clone();
72 let host: String = config
73 .get_param("OPENAI_HOST")
74 .unwrap_or_else(|_| "https://api.openai.com".to_string());
75 let base_path: String = config
76 .get_param("OPENAI_BASE_PATH")
77 .unwrap_or_else(|_| "v1/chat/completions".to_string());
78 let organization: Option<String> = config.get_param("OPENAI_ORGANIZATION").ok();
79 let project: Option<String> = config.get_param("OPENAI_PROJECT").ok();
80 let custom_headers: Option<HashMap<String, String>> = secrets
81 .get("OPENAI_CUSTOM_HEADERS")
82 .cloned()
83 .map(parse_custom_headers);
84 let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);
85
86 let auth = AuthMethod::BearerToken(api_key);
87 let mut api_client =
88 ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
89
90 if let Some(org) = &organization {
91 api_client = api_client.with_header("OpenAI-Organization", org)?;
92 }
93
94 if let Some(project) = &project {
95 api_client = api_client.with_header("OpenAI-Project", project)?;
96 }
97
98 if let Some(headers) = &custom_headers {
99 let mut header_map = reqwest::header::HeaderMap::new();
100 for (key, value) in headers {
101 let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
102 let header_value = reqwest::header::HeaderValue::from_str(value)?;
103 header_map.insert(header_name, header_value);
104 }
105 api_client = api_client.with_headers(header_map)?;
106 }
107
108 Ok(Self {
109 api_client,
110 base_path,
111 organization,
112 project,
113 model,
114 custom_headers,
115 supports_streaming: true,
116 name: Self::metadata().name,
117 })
118 }
119
120 #[doc(hidden)]
121 pub fn new(api_client: ApiClient, model: ModelConfig) -> Self {
122 Self {
123 api_client,
124 base_path: "v1/chat/completions".to_string(),
125 organization: None,
126 project: None,
127 model,
128 custom_headers: None,
129 supports_streaming: true,
130 name: Self::metadata().name,
131 }
132 }
133
134 pub fn from_custom_config(
135 model: ModelConfig,
136 config: DeclarativeProviderConfig,
137 ) -> Result<Self> {
138 let global_config = crate::config::Config::global();
139 let api_key: String = global_config
140 .get_secret(&config.api_key_env)
141 .map_err(|_e| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?;
142
143 let url = url::Url::parse(&config.base_url)
144 .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?;
145
146 let host = if let Some(port) = url.port() {
147 format!(
148 "{}://{}:{}",
149 url.scheme(),
150 url.host_str().unwrap_or(""),
151 port
152 )
153 } else {
154 format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""))
155 };
156 let base_path = url.path().trim_start_matches('/').to_string();
157 let base_path = if base_path.is_empty() {
158 "v1/chat/completions".to_string()
159 } else {
160 base_path
161 };
162
163 let timeout_secs = config.timeout_seconds.unwrap_or(600);
164 let auth = AuthMethod::BearerToken(api_key);
165 let mut api_client =
166 ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
167
168 if let Some(headers) = &config.headers {
170 let mut header_map = reqwest::header::HeaderMap::new();
171 for (key, value) in headers {
172 let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
173 let header_value = reqwest::header::HeaderValue::from_str(value)?;
174 header_map.insert(header_name, header_value);
175 }
176 api_client = api_client.with_headers(header_map)?;
177 }
178
179 Ok(Self {
180 api_client,
181 base_path,
182 organization: None,
183 project: None,
184 model,
185 custom_headers: config.headers,
186 supports_streaming: config.supports_streaming.unwrap_or(true),
187 name: config.name.clone(),
188 })
189 }
190
191 fn uses_responses_api(model_name: &str) -> bool {
192 model_name.starts_with("gpt-5-codex") || model_name.starts_with("gpt-5.1-codex")
193 }
194
195 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
196 let response = self
197 .api_client
198 .response_post(&self.base_path, payload)
199 .await?;
200 handle_response_openai_compat(response).await
201 }
202
203 async fn post_responses(&self, payload: &Value) -> Result<Value, ProviderError> {
204 let response = self
205 .api_client
206 .response_post("v1/responses", payload)
207 .await?;
208 handle_response_openai_compat(response).await
209 }
210}
211
#[async_trait]
impl Provider for OpenAiProvider {
    /// Static metadata: display strings, known models, and the configuration
    /// keys that `from_env` reads.
    fn metadata() -> ProviderMetadata {
        let models = OPEN_AI_KNOWN_MODELS
            .iter()
            .map(|(name, limit)| ModelInfo::new(*name, *limit))
            .collect();
        ProviderMetadata::with_models(
            "openai",
            "OpenAI",
            "GPT-4 and other OpenAI models, including OpenAI compatible ones",
            OPEN_AI_DEFAULT_MODEL,
            models,
            OPEN_AI_DOC_URL,
            vec![
                ConfigKey::new("OPENAI_API_KEY", true, true, None),
                ConfigKey::new("OPENAI_HOST", true, false, Some("https://api.openai.com")),
                ConfigKey::new("OPENAI_BASE_PATH", true, false, Some("v1/chat/completions")),
                ConfigKey::new("OPENAI_ORGANIZATION", false, false, None),
                ConfigKey::new("OPENAI_PROJECT", false, false, None),
                ConfigKey::new("OPENAI_CUSTOM_HEADERS", false, true, None),
                ConfigKey::new("OPENAI_TIMEOUT", false, false, Some("600")),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Runs one non-streaming completion against `model_config`.
    ///
    /// Codex models are routed to the `/v1/responses` API; all other models
    /// use the OpenAI chat-completions format. Requests are retried via
    /// `with_retry`, and the request/response pair is written to the log.
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        if Self::uses_responses_api(&model_config.model_name) {
            let payload = create_responses_request(model_config, system, messages, tools)?;
            let mut log = RequestLog::start(&self.model, &payload)?;

            let json_response = self
                .with_retry(|| async {
                    // Clone so each retry attempt gets its own payload.
                    let payload_clone = payload.clone();
                    self.post_responses(&payload_clone).await
                })
                .await
                .inspect_err(|e| {
                    // Best-effort logging; never mask the original error.
                    let _ = log.error(e);
                })?;

            let responses_api_response: ResponsesApiResponse =
                serde_json::from_value(json_response.clone()).map_err(|e| {
                    ProviderError::ExecutionError(format!(
                        "Failed to parse responses API response: {}",
                        e
                    ))
                })?;

            let message = responses_api_to_message(&responses_api_response)?;
            let usage = get_responses_usage(&responses_api_response);
            let model = responses_api_response.model.clone();

            log.write(&json_response, Some(&usage))?;
            Ok((message, ProviderUsage::new(model, usage)))
        } else {
            let payload = create_request(
                model_config,
                system,
                messages,
                tools,
                &ImageFormat::OpenAi,
                false, // non-streaming request
            )?;

            let mut log = RequestLog::start(&self.model, &payload)?;
            let json_response = self
                .with_retry(|| async {
                    let payload_clone = payload.clone();
                    self.post(&payload_clone).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            let message = response_to_message(&json_response)?;
            // Missing usage data is tolerated: fall back to zeroed counters.
            let usage = json_response
                .get("usage")
                .map(get_usage)
                .unwrap_or_else(|| {
                    tracing::debug!("Failed to get usage data");
                    Usage::default()
                });

            let model = get_model(&json_response);
            log.write(&json_response, Some(&usage))?;
            Ok((message, ProviderUsage::new(model, usage)))
        }
    }

    /// Lists available model ids by querying the endpoint's models route.
    ///
    /// Returns the sorted ids, or an authentication error if the server
    /// responded with an embedded error object.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        // NOTE(review): assumes `base_path` contains the literal
        // "v1/chat/completions"; other configured paths are queried unchanged.
        let models_path = self.base_path.replace("v1/chat/completions", "v1/models");
        let response = self.api_client.response_get(&models_path).await?;
        let json = handle_response_openai_compat(response).await?;
        // Some compatible servers return 200 with an embedded error object.
        if let Some(err_obj) = json.get("error") {
            let msg = err_obj
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown error");
            return Err(ProviderError::Authentication(msg.to_string()));
        }

        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
            ProviderError::UsageError("Missing data field in JSON response".into())
        })?;
        // Entries without a string "id" are silently skipped.
        let mut models: Vec<String> = data
            .iter()
            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
            .collect();
        models.sort();
        Ok(Some(models))
    }

    fn supports_embeddings(&self) -> bool {
        true
    }

    /// Bridges the `Provider` embedding entry point to the `EmbeddingCapable`
    /// implementation, converting its error into a `ProviderError`.
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        EmbeddingCapable::create_embeddings(self, texts)
            .await
            .map_err(|e| ProviderError::ExecutionError(e.to_string()))
    }

    fn supports_streaming(&self) -> bool {
        self.supports_streaming
    }

    /// Streams a completion as an async message stream.
    ///
    /// Codex models stream from `/v1/responses` and are decoded line-by-line
    /// here; other models delegate to the shared OpenAI-compatible helper.
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        if Self::uses_responses_api(&self.model.model_name) {
            let mut payload = create_responses_request(&self.model, system, messages, tools)?;
            payload["stream"] = serde_json::Value::Bool(true);

            let mut log = RequestLog::start(&self.model, &payload)?;

            // Only the HTTP status is checked eagerly; the body is consumed
            // lazily as a byte stream below.
            let response = self
                .with_retry(|| async {
                    let payload_clone = payload.clone();
                    let resp = self
                        .api_client
                        .response_post("v1/responses", &payload_clone)
                        .await?;
                    handle_status_openai_compat(resp).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            let stream = response.bytes_stream().map_err(io::Error::other);

            Ok(Box::pin(try_stream! {
                // Adapt the raw byte stream into framed lines for the
                // responses-API streaming decoder.
                let stream_reader = StreamReader::new(stream);
                let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);

                let message_stream = responses_api_to_streaming_message(framed);
                pin!(message_stream);
                while let Some(message) = message_stream.next().await {
                    let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
                    log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
                    yield (message, usage);
                }
            }))
        } else {
            let payload = create_request(
                &self.model,
                system,
                messages,
                tools,
                &ImageFormat::OpenAi,
                true, // streaming request
            )?;
            let mut log = RequestLog::start(&self.model, &payload)?;

            let response = self
                .with_retry(|| async {
                    let resp = self
                        .api_client
                        .response_post(&self.base_path, &payload)
                        .await?;
                    handle_status_openai_compat(resp).await
                })
                .await
                .inspect_err(|e| {
                    let _ = log.error(e);
                })?;

            // Shared SSE decoding for OpenAI-compatible chat completions.
            stream_openai_compat(response, log)
        }
    }
}
426
/// Parses a comma-separated `"k1=v1,k2=v2"` string into a header map.
///
/// Keys and values are trimmed; entries without an `=` are ignored, and only
/// the first `=` splits key from value (values may contain `=`).
fn parse_custom_headers(raw: String) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    for entry in raw.split(',') {
        if let Some((key, value)) = entry.split_once('=') {
            headers.insert(key.trim().to_string(), value.trim().to_string());
        }
    }
    headers
}
437
438#[async_trait]
439impl EmbeddingCapable for OpenAiProvider {
440 async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
441 if texts.is_empty() {
442 return Ok(vec![]);
443 }
444
445 let embedding_model = std::env::var("ASTER_EMBEDDING_MODEL")
446 .unwrap_or_else(|_| "text-embedding-3-small".to_string());
447
448 let request = EmbeddingRequest {
449 input: texts,
450 model: embedding_model,
451 };
452
453 let response = self
454 .with_retry(|| async {
455 let request_clone = EmbeddingRequest {
456 input: request.input.clone(),
457 model: request.model.clone(),
458 };
459 let request_value = serde_json::to_value(request_clone)
460 .map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
461 self.api_client
462 .api_post("v1/embeddings", &request_value)
463 .await
464 .map_err(|e| ProviderError::ExecutionError(e.to_string()))
465 })
466 .await?;
467
468 if response.status != StatusCode::OK {
469 let error_text = response
470 .payload
471 .as_ref()
472 .and_then(|p| p.as_str())
473 .unwrap_or("Unknown error");
474 return Err(anyhow::anyhow!("Embedding API error: {}", error_text));
475 }
476
477 let embedding_response: EmbeddingResponse = serde_json::from_value(
478 response
479 .payload
480 .ok_or_else(|| anyhow::anyhow!("Empty response body"))?,
481 )?;
482
483 Ok(embedding_response
484 .data
485 .into_iter()
486 .map(|d| d.embedding)
487 .collect())
488 }
489}