ai_lib/provider/openai.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;

/// OpenAI adapter supporting GPT series models.
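///
/// # Example
///
/// A minimal construction sketch (the import path is an assumption about the
/// crate's re-exports; adjust to the actual module tree):
///
/// ```no_run
/// use ai_lib::provider::OpenAiAdapter;
///
/// // `new()` reads the key from the OPENAI_API_KEY environment variable.
/// let adapter = OpenAiAdapter::new().expect("OPENAI_API_KEY must be set");
/// ```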
pub struct OpenAiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl OpenAiAdapter {
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Explicit API key override (takes precedence over env var) + optional base_url override.
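    ///
    /// A usage sketch (the key and proxy URL below are placeholders; the
    /// import path is an assumption about the crate's re-exports):
    ///
    /// ```no_run
    /// # use ai_lib::provider::OpenAiAdapter;
    /// // Point the adapter at an OpenAI-compatible endpoint instead of the default.
    /// let adapter = OpenAiAdapter::new_with_overrides(
    ///     "sk-...".to_string(),
    ///     Some("https://my-proxy.example.com/v1".to_string()),
    /// )
    /// .expect("constructing the adapter should not fail");
    /// ```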
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an injected object-safe transport reference.
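    ///
    /// A sketch of explicit transport injection, e.g. to share one transport
    /// across adapters (assumes `HttpTransport` and its `boxed()` helper are
    /// publicly exported; adjust paths to the actual crate layout):
    ///
    /// ```no_run
    /// # use ai_lib::provider::OpenAiAdapter;
    /// # use ai_lib::transport::HttpTransport;
    /// let transport = HttpTransport::new().boxed();
    /// let adapter = OpenAiAdapter::with_transport_ref(
    ///     transport,
    ///     "sk-...".to_string(),
    ///     "https://api.openai.com/v1".to_string(),
    /// )
    /// .expect("constructing the adapter should not fail");
    /// ```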
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    pub fn with_metrics(
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url,
            metrics,
        })
    }

    #[allow(dead_code)]
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Synchronous converter: do not perform provider uploads, just inline content
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };
            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }
        openai_request["messages"] = serde_json::Value::Array(msgs);
        openai_request
    }

    /// Async version that can upload local files to OpenAI before constructing the request.
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        // Build the OpenAI-compatible request JSON. For now we avoid provider-specific
        // upload flows here and rely on the generic provider utils (which may inline files)
        // to produce content JSON values.
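        //
        // Target wire shape (sketch):
        //   {"model": "...",
        //    "messages": [{"role": "user", "content": ...}, ...],
        //    "temperature": 0.7, "max_tokens": 256, ...}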
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            // If it's an Image with no URL but has a local `name`, attempt async upload to OpenAI
            let content_val = match &msg.content {
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        // Try provider upload; fall back to inline behavior on error
                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                // remote may be a full URL, a data: URL, or a provider file id.
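                                // e.g. "https://example.com/cat.png",
                                // "data:image/png;base64,...", or an id such
                                // as "file-abc123".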
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    // Treat as provider file id
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            Err(_) => {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        // Optional parameters. `serde_json::json!` maps non-finite floats to
        // null instead of panicking, unlike `Number::from_f64(..).unwrap()`.
        if let Some(temp) = request.temperature {
            openai_request["temperature"] = serde_json::json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] = serde_json::json!(max_tokens);
        }
        if let Some(top_p) = request.top_p {
            openai_request["top_p"] = serde_json::json!(top_p);
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            openai_request["frequency_penalty"] = serde_json::json!(freq_penalty);
        }
        if let Some(presence_penalty) = request.presence_penalty {
            openai_request["presence_penalty"] = serde_json::json!(presence_penalty);
        }

        // Add function calling definitions if provided
        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        // function_call policy may be set to control OpenAI behavior
        if let Some(policy) = &request.function_call {
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
                }
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    if name == "auto" {
                        openai_request["function_call"] =
                            serde_json::Value::String("auto".to_string());
                    } else {
                        // OpenAI expects an object, not a bare string, when
                        // forcing a specific function.
                        openai_request["function_call"] = serde_json::json!({ "name": name });
                    }
                }
            }
        }

        Ok(openai_request)
    }

    // Note: provider-specific upload helpers were removed to avoid blocking the async
    // runtime. Use `crate::provider::utils::upload_file_to_provider` (async) if provider
    // upload behavior is desired; it will be integrated in a future change.

    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                let role = match message["role"].as_str().unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message["content"].as_str().unwrap_or("").to_string();

                // Build the Message and try to populate a typed FunctionCall if provided by the provider
                let mut msg_obj = Message {
                    role,
                    content: crate::types::common::Content::Text(content.clone()),
                    function_call: None,
                };

                if let Some(fc_val) = message.get("function_call").cloned() {
                    // Try direct deserialization into our typed FunctionCall first
                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        Ok(fc) => {
                            msg_obj.function_call = Some(fc);
                        }
                        Err(_) => {
                            // Fallback: some providers return `arguments` as a JSON-encoded string.
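                            // e.g. {"name": "get_weather",
                            //       "arguments": "{\"location\": \"Boston\"}"}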
                            let name = fc_val
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or_default()
                                .to_string();
                            let args_val = match fc_val.get("arguments") {
                                Some(a) if a.is_string() => {
                                    // Parse stringified JSON
                                    a.as_str()
                                        .and_then(|s| {
                                            serde_json::from_str::<serde_json::Value>(s).ok()
                                        })
                                        .unwrap_or(serde_json::Value::Null)
                                }
                                Some(a) => a.clone(),
                                None => serde_json::Value::Null,
                            };
                            msg_obj.function_call =
                                Some(crate::types::function_call::FunctionCall {
                                    name,
                                    arguments: Some(args_val),
                                });
                        }
                    }
                }
                Ok(Choice {
                    index: index as u32,
                    message: msg_obj,
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for OpenAiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Use the async converter, which may perform provider uploads.
        // Propagate conversion errors instead of silently sending an empty body.
        let openai_request = self.convert_request_async(&request).await?;
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));

        // Record a request counter and start a timer using standardized keys
        self.metrics.incr_counter("openai.requests", 1).await;
        let timer = self.metrics.start_timer("openai.request_duration_ms").await;

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let response = match self
            .transport
            .post_json(&url, Some(headers), openai_request)
            .await
        {
            Ok(v) => {
                if let Some(t) = timer {
                    t.stop();
                }
                v
            }
            Err(e) => {
                if let Some(t) = timer {
                    t.stop();
                }
                return Err(e);
            }
        };

        self.parse_response(response)
    }

    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Streaming is not implemented yet for this adapter: return an empty
        // stream so callers complete without receiving any chunks.
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["data"]
            .as_array()
            .map(|models| {
                models
                    .iter()
                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default())
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Static placeholder: the /models/{id} endpoint is not queried here.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
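
// A minimal sketch of a unit test for `parse_response`, exercising the private
// parser against a hand-written fixture shaped like an OpenAI chat completion
// response. Fixture values are illustrative, not captured API output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_response_extracts_choices_and_usage() {
        // No network involved: construction only wires up the default transport.
        let adapter = OpenAiAdapter::new_with_overrides("test-key".to_string(), None)
            .expect("constructing the adapter should not fail");

        let fixture = serde_json::json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 0,
            "model": "gpt-4o-mini",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}
        });

        let resp = adapter
            .parse_response(fixture)
            .expect("fixture should parse");
        assert_eq!(resp.id, "chatcmpl-123");
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.choices[0].finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.total_tokens, 7);
    }
}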