// ai_lib/provider/openai.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::common::Content;
use crate::types::function_call::{FunctionCall, FunctionCallPolicy};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

14/// OpenAI adapter, supporting GPT series models
15///
16/// OpenAI adapter supporting GPT series models
17pub struct OpenAiAdapter {
18    transport: DynHttpTransportRef,
19    api_key: String,
20    base_url: String,
21    metrics: Arc<dyn Metrics>,
22}
23
24impl OpenAiAdapter {
25    #[allow(dead_code)]
26    fn build_default_timeout_secs() -> u64 {
27        std::env::var("AI_HTTP_TIMEOUT_SECS")
28            .ok()
29            .and_then(|s| s.parse::<u64>().ok())
30            .unwrap_or(30)
31    }
32
33    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
34        #[cfg(feature = "unified_transport")]
35        {
36            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
37            let client = crate::transport::client_factory::build_shared_client()
38                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
39            let t = HttpTransport::with_reqwest_client(client, timeout);
40            return Ok(t.boxed());
41        }
42        #[cfg(not(feature = "unified_transport"))]
43        {
44            let t = HttpTransport::new();
45            return Ok(t.boxed());
46        }
47    }
48
49    pub fn new() -> Result<Self, AiLibError> {
50        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
51            AiLibError::AuthenticationError(
52                "OPENAI_API_KEY environment variable not set".to_string(),
53            )
54        })?;
55
56        Ok(Self {
57            transport: Self::build_default_transport()?,
58            api_key,
59            base_url: "https://api.openai.com/v1".to_string(),
60            metrics: Arc::new(NoopMetrics::new()),
61        })
62    }
63
64    /// Explicit API key override (takes precedence over env var) + optional base_url override.
65    pub fn new_with_overrides(
66        api_key: String,
67        base_url: Option<String>,
68    ) -> Result<Self, AiLibError> {
69        Ok(Self {
70            transport: Self::build_default_transport()?,
71            api_key,
72            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
73            metrics: Arc::new(NoopMetrics::new()),
74        })
75    }
76
77    /// Construct with an injected object-safe transport reference
78    pub fn with_transport_ref(
79        transport: DynHttpTransportRef,
80        api_key: String,
81        base_url: String,
82    ) -> Result<Self, AiLibError> {
83        Ok(Self {
84            transport,
85            api_key,
86            base_url,
87            metrics: Arc::new(NoopMetrics::new()),
88        })
89    }
90
91    pub fn with_transport_ref_and_metrics(
92        transport: DynHttpTransportRef,
93        api_key: String,
94        base_url: String,
95        metrics: Arc<dyn Metrics>,
96    ) -> Result<Self, AiLibError> {
97        Ok(Self {
98            transport,
99            api_key,
100            base_url,
101            metrics,
102        })
103    }
104
105    pub fn with_metrics(
106        api_key: String,
107        base_url: String,
108        metrics: Arc<dyn Metrics>,
109    ) -> Result<Self, AiLibError> {
110        Ok(Self {
111            transport: Self::build_default_transport()?,
112            api_key,
113            base_url,
114            metrics,
115        })
116    }
117
118    #[allow(dead_code)]
119    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
120        // Synchronous converter: do not perform provider uploads, just inline content
121        let mut openai_request = serde_json::json!({
122            "model": request.model,
123            "messages": serde_json::Value::Array(vec![])
124        });
125
126        let mut msgs: Vec<serde_json::Value> = Vec::new();
127        for msg in request.messages.iter() {
128            let role = match msg.role {
129                Role::System => "system",
130                Role::User => "user",
131                Role::Assistant => "assistant",
132            };
133            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
134            msgs.push(serde_json::json!({"role": role, "content": content_val}));
135        }
136        openai_request["messages"] = serde_json::Value::Array(msgs);
137        openai_request
138    }
139
140    /// Async version that can upload local files to OpenAI before constructing the request
141    async fn convert_request_async(
142        &self,
143        request: &ChatCompletionRequest,
144    ) -> Result<serde_json::Value, AiLibError> {
145        // Build the OpenAI-compatible request JSON. For now we avoid provider-specific
146        // upload flows here and rely on the generic provider utils (which may inline files)
147        // to produce content JSON values.
148        let mut openai_request = serde_json::json!({
149            "model": request.model,
150            "messages": serde_json::Value::Array(vec![])
151        });
152
153        let mut msgs: Vec<serde_json::Value> = Vec::new();
154        for msg in request.messages.iter() {
155            let role = match msg.role {
156                Role::System => "system",
157                Role::User => "user",
158                Role::Assistant => "assistant",
159            };
160
161            // If it's an Image with no URL but has a local `name`, attempt async upload to OpenAI
162            let content_val = match &msg.content {
163                crate::types::common::Content::Image { url, mime: _, name } => {
164                    if url.is_some() {
165                        crate::provider::utils::content_to_provider_value(&msg.content)
166                    } else if let Some(n) = name {
167                        // Try provider upload; fall back to inline behavior on error
168                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
169                        match crate::provider::utils::upload_file_with_transport(
170                            Some(self.transport.clone()),
171                            &upload_url,
172                            n,
173                            "file",
174                        )
175                        .await
176                        {
177                            Ok(remote) => {
178                                // remote may be a full URL, a data: URL, or a provider file id.
179                                if remote.starts_with("http://")
180                                    || remote.starts_with("https://")
181                                    || remote.starts_with("data:")
182                                {
183                                    serde_json::json!({"image": {"url": remote}})
184                                } else {
185                                    // Treat as provider file id
186                                    serde_json::json!({"image": {"file_id": remote}})
187                                }
188                            }
189                            Err(_) => {
190                                crate::provider::utils::content_to_provider_value(&msg.content)
191                            }
192                        }
193                    } else {
194                        crate::provider::utils::content_to_provider_value(&msg.content)
195                    }
196                }
197                _ => crate::provider::utils::content_to_provider_value(&msg.content),
198            };
199            msgs.push(serde_json::json!({"role": role, "content": content_val}));
200        }
201
202        openai_request["messages"] = serde_json::Value::Array(msgs);
203
204        // Optional params
205        if let Some(temp) = request.temperature {
206            openai_request["temperature"] =
207                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
208        }
209        if let Some(max_tokens) = request.max_tokens {
210            openai_request["max_tokens"] =
211                serde_json::Value::Number(serde_json::Number::from(max_tokens));
212        }
213        if let Some(top_p) = request.top_p {
214            openai_request["top_p"] =
215                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
216        }
217        if let Some(freq_penalty) = request.frequency_penalty {
218            openai_request["frequency_penalty"] = serde_json::Value::Number(
219                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
220            );
221        }
222        if let Some(presence_penalty) = request.presence_penalty {
223            openai_request["presence_penalty"] = serde_json::Value::Number(
224                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
225            );
226        }
227
228        // Add function calling definitions if provided
229        if let Some(functions) = &request.functions {
230            openai_request["functions"] =
231                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
232        }
233
234        // function_call policy may be set to control OpenAI behavior
235        if let Some(policy) = &request.function_call {
236            match policy {
237                crate::types::function_call::FunctionCallPolicy::None => {
238                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
239                }
240                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
241                    if name == "auto" {
242                        openai_request["function_call"] =
243                            serde_json::Value::String("auto".to_string());
244                    } else {
245                        openai_request["function_call"] = serde_json::Value::String(name.clone());
246                    }
247                }
248            }
249        }
250
251        Ok(openai_request)
252    }
253
254    // Note: provider-specific upload helpers were removed to avoid blocking the async
255    // runtime. Use `crate::provider::utils::upload_file_to_provider` (async) if provider
256    // upload behavior is desired; it will be integrated in a future change.
257
258    fn parse_response(
259        &self,
260        response: serde_json::Value,
261    ) -> Result<ChatCompletionResponse, AiLibError> {
262        let choices = response["choices"]
263            .as_array()
264            .ok_or_else(|| {
265                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
266            })?
267            .iter()
268            .enumerate()
269            .map(|(index, choice)| {
270                let message = choice["message"].as_object().ok_or_else(|| {
271                    AiLibError::ProviderError("Invalid choice format".to_string())
272                })?;
273
274                let role = match message["role"].as_str().unwrap_or("user") {
275                    "system" => Role::System,
276                    "assistant" => Role::Assistant,
277                    _ => Role::User,
278                };
279
280                let content = message["content"].as_str().unwrap_or("").to_string();
281
282                // Build the Message and try to populate a typed FunctionCall if provided by the provider
283                let mut msg_obj = Message {
284                    role,
285                    content: crate::types::common::Content::Text(content.clone()),
286                    function_call: None,
287                };
288
289                if let Some(fc_val) = message.get("function_call").cloned() {
290                    // Try direct deserialization into our typed FunctionCall first
291                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
292                        fc_val.clone(),
293                    ) {
294                        Ok(fc) => {
295                            msg_obj.function_call = Some(fc);
296                        }
297                        Err(_) => {
298                            // Fallback: some providers return `arguments` as a JSON-encoded string.
299                            let name = fc_val
300                                .get("name")
301                                .and_then(|v| v.as_str())
302                                .unwrap_or_default()
303                                .to_string();
304                            let args_val = match fc_val.get("arguments") {
305                                Some(a) if a.is_string() => {
306                                    // Parse stringified JSON
307                                    a.as_str()
308                                        .and_then(|s| {
309                                            serde_json::from_str::<serde_json::Value>(s).ok()
310                                        })
311                                        .unwrap_or(serde_json::Value::Null)
312                                }
313                                Some(a) => a.clone(),
314                                None => serde_json::Value::Null,
315                            };
316                            msg_obj.function_call =
317                                Some(crate::types::function_call::FunctionCall {
318                                    name,
319                                    arguments: Some(args_val),
320                                });
321                        }
322                    }
323                }
324                Ok(Choice {
325                    index: index as u32,
326                    message: msg_obj,
327                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
328                })
329            })
330            .collect::<Result<Vec<_>, AiLibError>>()?;
331
332        let usage = response["usage"].as_object().ok_or_else(|| {
333            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
334        })?;
335
336        let usage = Usage {
337            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
338            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
339            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
340        };
341
342        Ok(ChatCompletionResponse {
343            id: response["id"].as_str().unwrap_or("").to_string(),
344            object: response["object"].as_str().unwrap_or("").to_string(),
345            created: response["created"].as_u64().unwrap_or(0),
346            model: response["model"].as_str().unwrap_or("").to_string(),
347            choices,
348            usage,
349            usage_status: UsageStatus::Finalized, // OpenAI provides accurate usage data
350        })
351    }
352}
353
354#[async_trait::async_trait]
355impl ChatApi for OpenAiAdapter {
356    async fn chat_completion(
357        &self,
358        request: ChatCompletionRequest,
359    ) -> Result<ChatCompletionResponse, AiLibError> {
360        // Record a request counter and start a timer using standardized keys
361        self.metrics.incr_counter("openai.requests", 1).await;
362        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
363        let url = format!("{}/chat/completions", self.base_url);
364
365        // Build request body via converter
366        let openai_request = self.convert_request_async(&request).await?;
367
368        // Use unified transport
369        let mut headers = HashMap::new();
370        headers.insert(
371            "Authorization".to_string(),
372            format!("Bearer {}", self.api_key),
373        );
374        headers.insert("Content-Type".to_string(), "application/json".to_string());
375
376        let response_json = self
377            .transport
378            .post_json(&url, Some(headers), openai_request)
379            .await?;
380
381        if let Some(t) = timer {
382            t.stop();
383        }
384
385        self.parse_response(response_json)
386    }
387
388    async fn chat_completion_stream(
389        &self,
390        _request: ChatCompletionRequest,
391    ) -> Result<
392        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
393        AiLibError,
394    > {
395        let stream = stream::empty();
396        Ok(Box::new(Box::pin(stream)))
397    }
398
399    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
400        let url = format!("{}/models", self.base_url);
401        let mut headers = HashMap::new();
402        headers.insert(
403            "Authorization".to_string(),
404            format!("Bearer {}", self.api_key),
405        );
406
407        let response = self.transport.get_json(&url, Some(headers)).await?;
408
409        Ok(response["data"]
410            .as_array()
411            .unwrap_or(&vec![])
412            .iter()
413            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
414            .collect())
415    }
416
417    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
418        Ok(ModelInfo {
419            id: model_id.to_string(),
420            object: "model".to_string(),
421            created: 0,
422            owned_by: "openai".to_string(),
423            permission: vec![ModelPermission {
424                id: "default".to_string(),
425                object: "model_permission".to_string(),
426                created: 0,
427                allow_create_engine: false,
428                allow_sampling: true,
429                allow_logprobs: false,
430                allow_search_indices: false,
431                allow_view: true,
432                allow_fine_tuning: false,
433                organization: "*".to_string(),
434                group: None,
435                is_blocking: false,
436            }],
437        })
438    }
439}