//! OpenAI provider adapter (`ai_lib/provider/openai.rs`).

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

14/// OpenAI adapter, supporting GPT series models
15///
16/// OpenAI adapter supporting GPT series models
17pub struct OpenAiAdapter {
18    transport: DynHttpTransportRef,
19    api_key: String,
20    base_url: String,
21    metrics: Arc<dyn Metrics>,
22}
23
24impl OpenAiAdapter {
25    fn build_default_timeout_secs() -> u64 {
26        std::env::var("AI_HTTP_TIMEOUT_SECS")
27            .ok()
28            .and_then(|s| s.parse::<u64>().ok())
29            .unwrap_or(30)
30    }
31
32    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
33        #[cfg(feature = "unified_transport")]
34        {
35            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
36            let client = crate::transport::client_factory::build_shared_client()
37                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
38            let t = HttpTransport::with_reqwest_client(client, timeout);
39            return Ok(t.boxed());
40        }
41        #[cfg(not(feature = "unified_transport"))]
42        {
43            let t = HttpTransport::new();
44            return Ok(t.boxed());
45        }
46    }
47
48    pub fn new() -> Result<Self, AiLibError> {
49        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
50            AiLibError::AuthenticationError(
51                "OPENAI_API_KEY environment variable not set".to_string(),
52            )
53        })?;
54
55        Ok(Self {
56            transport: Self::build_default_transport()?,
57            api_key,
58            base_url: "https://api.openai.com/v1".to_string(),
59            metrics: Arc::new(NoopMetrics::new()),
60        })
61    }
62
63    /// Explicit API key override (takes precedence over env var) + optional base_url override.
64    pub fn new_with_overrides(
65        api_key: String,
66        base_url: Option<String>,
67    ) -> Result<Self, AiLibError> {
68        Ok(Self {
69            transport: Self::build_default_transport()?,
70            api_key,
71            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
72            metrics: Arc::new(NoopMetrics::new()),
73        })
74    }
75
76    /// Construct with an injected object-safe transport reference
77    pub fn with_transport_ref(
78        transport: DynHttpTransportRef,
79        api_key: String,
80        base_url: String,
81    ) -> Result<Self, AiLibError> {
82        Ok(Self {
83            transport,
84            api_key,
85            base_url,
86            metrics: Arc::new(NoopMetrics::new()),
87        })
88    }
89
90    pub fn with_transport_ref_and_metrics(
91        transport: DynHttpTransportRef,
92        api_key: String,
93        base_url: String,
94        metrics: Arc<dyn Metrics>,
95    ) -> Result<Self, AiLibError> {
96        Ok(Self {
97            transport,
98            api_key,
99            base_url,
100            metrics,
101        })
102    }
103
104    pub fn with_metrics(
105        api_key: String,
106        base_url: String,
107        metrics: Arc<dyn Metrics>,
108    ) -> Result<Self, AiLibError> {
109        Ok(Self {
110            transport: Self::build_default_transport()?,
111            api_key,
112            base_url,
113            metrics,
114        })
115    }
116
117    #[allow(dead_code)]
118    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
119        // Synchronous converter: do not perform provider uploads, just inline content
120        let mut openai_request = serde_json::json!({
121            "model": request.model,
122            "messages": serde_json::Value::Array(vec![])
123        });
124
125        let mut msgs: Vec<serde_json::Value> = Vec::new();
126        for msg in request.messages.iter() {
127            let role = match msg.role {
128                Role::System => "system",
129                Role::User => "user",
130                Role::Assistant => "assistant",
131            };
132            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
133            msgs.push(serde_json::json!({"role": role, "content": content_val}));
134        }
135        openai_request["messages"] = serde_json::Value::Array(msgs);
136        openai_request
137    }
138
139    /// Async version that can upload local files to OpenAI before constructing the request
140    async fn convert_request_async(
141        &self,
142        request: &ChatCompletionRequest,
143    ) -> Result<serde_json::Value, AiLibError> {
144        // Build the OpenAI-compatible request JSON. For now we avoid provider-specific
145        // upload flows here and rely on the generic provider utils (which may inline files)
146        // to produce content JSON values.
147        let mut openai_request = serde_json::json!({
148            "model": request.model,
149            "messages": serde_json::Value::Array(vec![])
150        });
151
152        let mut msgs: Vec<serde_json::Value> = Vec::new();
153        for msg in request.messages.iter() {
154            let role = match msg.role {
155                Role::System => "system",
156                Role::User => "user",
157                Role::Assistant => "assistant",
158            };
159
160            // If it's an Image with no URL but has a local `name`, attempt async upload to OpenAI
161            let content_val = match &msg.content {
162                crate::types::common::Content::Image { url, mime: _, name } => {
163                    if url.is_some() {
164                        crate::provider::utils::content_to_provider_value(&msg.content)
165                    } else if let Some(n) = name {
166                        // Try provider upload; fall back to inline behavior on error
167                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
168                        match crate::provider::utils::upload_file_with_transport(
169                            Some(self.transport.clone()),
170                            &upload_url,
171                            n,
172                            "file",
173                        )
174                        .await
175                        {
176                            Ok(remote) => {
177                                // remote may be a full URL, a data: URL, or a provider file id.
178                                if remote.starts_with("http://")
179                                    || remote.starts_with("https://")
180                                    || remote.starts_with("data:")
181                                {
182                                    serde_json::json!({"image": {"url": remote}})
183                                } else {
184                                    // Treat as provider file id
185                                    serde_json::json!({"image": {"file_id": remote}})
186                                }
187                            }
188                            Err(_) => {
189                                crate::provider::utils::content_to_provider_value(&msg.content)
190                            }
191                        }
192                    } else {
193                        crate::provider::utils::content_to_provider_value(&msg.content)
194                    }
195                }
196                _ => crate::provider::utils::content_to_provider_value(&msg.content),
197            };
198            msgs.push(serde_json::json!({"role": role, "content": content_val}));
199        }
200
201        openai_request["messages"] = serde_json::Value::Array(msgs);
202
203        // Optional params
204        if let Some(temp) = request.temperature {
205            openai_request["temperature"] =
206                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
207        }
208        if let Some(max_tokens) = request.max_tokens {
209            openai_request["max_tokens"] =
210                serde_json::Value::Number(serde_json::Number::from(max_tokens));
211        }
212        if let Some(top_p) = request.top_p {
213            openai_request["top_p"] =
214                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
215        }
216        if let Some(freq_penalty) = request.frequency_penalty {
217            openai_request["frequency_penalty"] = serde_json::Value::Number(
218                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
219            );
220        }
221        if let Some(presence_penalty) = request.presence_penalty {
222            openai_request["presence_penalty"] = serde_json::Value::Number(
223                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
224            );
225        }
226
227        // Add function calling definitions if provided
228        if let Some(functions) = &request.functions {
229            openai_request["functions"] =
230                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
231        }
232
233        // function_call policy may be set to control OpenAI behavior
234        if let Some(policy) = &request.function_call {
235            match policy {
236                crate::types::function_call::FunctionCallPolicy::None => {
237                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
238                }
239                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
240                    if name == "auto" {
241                        openai_request["function_call"] =
242                            serde_json::Value::String("auto".to_string());
243                    } else {
244                        openai_request["function_call"] = serde_json::Value::String(name.clone());
245                    }
246                }
247            }
248        }
249
250        Ok(openai_request)
251    }
252
253    // Note: provider-specific upload helpers were removed to avoid blocking the async
254    // runtime. Use `crate::provider::utils::upload_file_to_provider` (async) if provider
255    // upload behavior is desired; it will be integrated in a future change.
256
257    fn parse_response(
258        &self,
259        response: serde_json::Value,
260    ) -> Result<ChatCompletionResponse, AiLibError> {
261        let choices = response["choices"]
262            .as_array()
263            .ok_or_else(|| {
264                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
265            })?
266            .iter()
267            .enumerate()
268            .map(|(index, choice)| {
269                let message = choice["message"].as_object().ok_or_else(|| {
270                    AiLibError::ProviderError("Invalid choice format".to_string())
271                })?;
272
273                let role = match message["role"].as_str().unwrap_or("user") {
274                    "system" => Role::System,
275                    "assistant" => Role::Assistant,
276                    _ => Role::User,
277                };
278
279                let content = message["content"].as_str().unwrap_or("").to_string();
280
281                // Build the Message and try to populate a typed FunctionCall if provided by the provider
282                let mut msg_obj = Message {
283                    role,
284                    content: crate::types::common::Content::Text(content.clone()),
285                    function_call: None,
286                };
287
288                if let Some(fc_val) = message.get("function_call").cloned() {
289                    // Try direct deserialization into our typed FunctionCall first
290                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
291                        fc_val.clone(),
292                    ) {
293                        Ok(fc) => {
294                            msg_obj.function_call = Some(fc);
295                        }
296                        Err(_) => {
297                            // Fallback: some providers return `arguments` as a JSON-encoded string.
298                            let name = fc_val
299                                .get("name")
300                                .and_then(|v| v.as_str())
301                                .unwrap_or_default()
302                                .to_string();
303                            let args_val = match fc_val.get("arguments") {
304                                Some(a) if a.is_string() => {
305                                    // Parse stringified JSON
306                                    a.as_str()
307                                        .and_then(|s| {
308                                            serde_json::from_str::<serde_json::Value>(s).ok()
309                                        })
310                                        .unwrap_or(serde_json::Value::Null)
311                                }
312                                Some(a) => a.clone(),
313                                None => serde_json::Value::Null,
314                            };
315                            msg_obj.function_call =
316                                Some(crate::types::function_call::FunctionCall {
317                                    name,
318                                    arguments: Some(args_val),
319                                });
320                        }
321                    }
322                }
323                Ok(Choice {
324                    index: index as u32,
325                    message: msg_obj,
326                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
327                })
328            })
329            .collect::<Result<Vec<_>, AiLibError>>()?;
330
331        let usage = response["usage"].as_object().ok_or_else(|| {
332            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
333        })?;
334
335        let usage = Usage {
336            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
337            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
338            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
339        };
340
341        Ok(ChatCompletionResponse {
342            id: response["id"].as_str().unwrap_or("").to_string(),
343            object: response["object"].as_str().unwrap_or("").to_string(),
344            created: response["created"].as_u64().unwrap_or(0),
345            model: response["model"].as_str().unwrap_or("").to_string(),
346            choices,
347            usage,
348        })
349    }
350}
351
352#[async_trait::async_trait]
353impl ChatApi for OpenAiAdapter {
354    async fn chat_completion(
355        &self,
356        request: ChatCompletionRequest,
357    ) -> Result<ChatCompletionResponse, AiLibError> {
358        // Record a request counter and start a timer using standardized keys
359        self.metrics.incr_counter("openai.requests", 1).await;
360        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
361        let url = format!("{}/chat/completions", self.base_url);
362
363        // Build request body via converter
364        let openai_request = self.convert_request_async(&request).await?;
365
366        // Use unified transport
367        let mut headers = HashMap::new();
368        headers.insert(
369            "Authorization".to_string(),
370            format!("Bearer {}", self.api_key),
371        );
372        headers.insert("Content-Type".to_string(), "application/json".to_string());
373
374        let response_json = self
375            .transport
376            .post_json(&url, Some(headers), openai_request)
377            .await?;
378
379        if let Some(t) = timer {
380            t.stop();
381        }
382
383        self.parse_response(response_json)
384    }
385
386    async fn chat_completion_stream(
387        &self,
388        _request: ChatCompletionRequest,
389    ) -> Result<
390        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
391        AiLibError,
392    > {
393        let stream = stream::empty();
394        Ok(Box::new(Box::pin(stream)))
395    }
396
397    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
398        let url = format!("{}/models", self.base_url);
399        let mut headers = HashMap::new();
400        headers.insert(
401            "Authorization".to_string(),
402            format!("Bearer {}", self.api_key),
403        );
404
405        let response = self.transport.get_json(&url, Some(headers)).await?;
406
407        Ok(response["data"]
408            .as_array()
409            .unwrap_or(&vec![])
410            .iter()
411            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
412            .collect())
413    }
414
415    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
416        Ok(ModelInfo {
417            id: model_id.to_string(),
418            object: "model".to_string(),
419            created: 0,
420            owned_by: "openai".to_string(),
421            permission: vec![ModelPermission {
422                id: "default".to_string(),
423                object: "model_permission".to_string(),
424                created: 0,
425                allow_create_engine: false,
426                allow_sampling: true,
427                allow_logprobs: false,
428                allow_search_indices: false,
429                allow_view: true,
430                allow_fine_tuning: false,
431                organization: "*".to_string(),
432                group: None,
433                is_blocking: false,
434            }],
435        })
436    }
437}