ai_lib/provider/openai.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;

/// OpenAI adapter supporting GPT series models
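///
/// A minimal usage sketch (the crate paths are assumptions, `request` is a
/// prebuilt `ChatCompletionRequest`, and `OPENAI_API_KEY` must be set):
///
/// ```ignore
/// use ai_lib::provider::OpenAiAdapter;
///
/// let adapter = OpenAiAdapter::new()?;
/// let response = adapter.chat_completion(request).await?;
/// ```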
pub struct OpenAiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl OpenAiAdapter {
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an injected object-safe transport reference
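    ///
    /// Illustrative sketch (the proxy URL is a placeholder):
    ///
    /// ```ignore
    /// let adapter = OpenAiAdapter::with_transport_ref(
    ///     HttpTransport::new().boxed(),
    ///     std::env::var("OPENAI_API_KEY")?,
    ///     "https://oai-proxy.example.com/v1".to_string(),
    /// )?;
    /// ```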
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

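    /// Construct with a custom metrics sink and the default HTTP transport.
    ///
    /// A minimal sketch using the no-op sink from this crate:
    ///
    /// ```ignore
    /// let adapter = OpenAiAdapter::with_metrics(
    ///     std::env::var("OPENAI_API_KEY")?,
    ///     "https://api.openai.com/v1".to_string(),
    ///     Arc::new(NoopMetrics::new()),
    /// )?;
    /// ```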
    pub fn with_metrics(
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url,
            metrics,
        })
    }

    #[allow(dead_code)]
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Synchronous converter: no provider uploads are performed; content is inlined as-is
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };
            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }
        openai_request["messages"] = serde_json::Value::Array(msgs);
        openai_request
    }

    /// Async version that can upload local files to OpenAI before constructing the request
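    ///
    /// Illustrative sketch of the image-upload path (the field layout of
    /// `Content::Image` is an assumption and the file name is hypothetical):
    ///
    /// ```ignore
    /// let msg = Message {
    ///     role: Role::User,
    ///     content: Content::Image {
    ///         url: None,                      // no remote URL yet
    ///         mime: Some("image/png".into()), // assumed Option<String>
    ///         name: Some("chart.png".into()), // local file to upload
    ///     },
    ///     function_call: None,
    /// };
    /// // convert_request_async will POST the file to `{base_url}/files` and
    /// // reference the returned URL or provider file id in the message content.
    /// ```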
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        // Build the OpenAI-compatible request JSON. Local image files (a `name`
        // with no URL) are uploaded via the transport below; all other content is
        // converted with the generic provider utils, which may inline file data.
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            // If it's an image with no URL but a local `name`, attempt an async upload to OpenAI
            let content_val = match &msg.content {
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        // Try provider upload; fall back to inline behavior on error
                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                // `remote` may be a full URL, a data: URL, or a provider file id.
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    // Treat as a provider file id
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            Err(_) => {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        // Optional sampling parameters
        if let Some(temp) = request.temperature {
            openai_request["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            openai_request["top_p"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            openai_request["frequency_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
            );
        }
        if let Some(presence_penalty) = request.presence_penalty {
            openai_request["presence_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
            );
        }

        // Add function calling definitions if provided
        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        // The function_call policy controls OpenAI's function-calling behavior
        if let Some(policy) = &request.function_call {
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
                }
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    if name == "auto" {
                        openai_request["function_call"] =
                            serde_json::Value::String("auto".to_string());
                    } else {
                        openai_request["function_call"] = serde_json::Value::String(name.clone());
                    }
                }
            }
        }

        Ok(openai_request)
    }
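
    // Illustrative sketch of steering function calling via the request policy
    // (the function name below is hypothetical):
    //
    //     request.function_call =
    //         Some(crate::types::function_call::FunctionCallPolicy::Auto("get_weather".into()));
    //
    // `Auto("auto".into())` lets OpenAI decide, `FunctionCallPolicy::None` maps to
    // `"none"`, and any other name forces that specific function.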

    // Note: provider-specific upload helpers were removed to avoid blocking the async
    // runtime. Use `crate::provider::utils::upload_file_to_provider` (async) if provider
    // upload behavior is desired; it will be integrated in a future change.

    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                // Use `get` rather than indexing: indexing a serde_json::Map panics on
                // missing keys, whereas a missing field should fall back to defaults.
                let role = match message.get("role").and_then(|v| v.as_str()).unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message
                    .get("content")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();

                // Build the Message and try to populate a typed FunctionCall if provided by the provider
                let mut msg_obj = Message {
                    role,
                    content: crate::types::common::Content::Text(content),
                    function_call: None,
                };

                if let Some(fc_val) = message.get("function_call").cloned() {
                    // Try direct deserialization into our typed FunctionCall first
                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        Ok(fc) => {
                            msg_obj.function_call = Some(fc);
                        }
                        Err(_) => {
                            // Fallback: some providers return `arguments` as a JSON-encoded string.
                            let name = fc_val
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or_default()
                                .to_string();
                            let args_val = match fc_val.get("arguments") {
                                Some(a) if a.is_string() => {
                                    // Parse stringified JSON
                                    a.as_str()
                                        .and_then(|s| {
                                            serde_json::from_str::<serde_json::Value>(s).ok()
                                        })
                                        .unwrap_or(serde_json::Value::Null)
                                }
                                Some(a) => a.clone(),
                                None => serde_json::Value::Null,
                            };
                            msg_obj.function_call =
                                Some(crate::types::function_call::FunctionCall {
                                    name,
                                    arguments: Some(args_val),
                                });
                        }
                    }
                }
                Ok(Choice {
                    index: index as u32,
                    message: msg_obj,
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            completion_tokens: usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0)
                as u32,
            total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for OpenAiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Use the async converter, which may perform provider uploads. Propagate
        // conversion errors instead of silently sending an empty request body.
        let openai_request = self.convert_request_async(&request).await?;
        let url = format!("{}/chat/completions", self.base_url);

        // Record a request counter and start a timer using standardized keys
        self.metrics.incr_counter("openai.requests", 1).await;
        let timer = self.metrics.start_timer("openai.request_duration_ms").await;

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let response = match self
            .transport
            .post_json(&url, Some(headers), openai_request)
            .await
        {
            Ok(v) => {
                if let Some(t) = timer {
                    t.stop();
                }
                v
            }
            Err(e) => {
                if let Some(t) = timer {
                    t.stop();
                }
                return Err(e);
            }
        };

        self.parse_response(response)
    }

    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
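        // Streaming is not yet implemented for this adapter; return an empty
        // stream as a placeholder so callers still receive a valid stream.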
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
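        // No network call is made here: return a static permission stub for the
        // requested model id.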
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
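
#[cfg(test)]
mod tests {
    use super::*;

    // A hedged unit-test sketch for the private `parse_response` helper, fed a
    // minimal OpenAI-style payload. Assumes `HttpTransport::new()` constructs
    // without performing I/O; the model name and token counts are illustrative.
    #[test]
    fn parse_response_extracts_choice_and_usage() {
        let adapter = OpenAiAdapter::with_transport_ref(
            HttpTransport::new().boxed(),
            "test-key".to_string(),
            "https://api.openai.com/v1".to_string(),
        )
        .expect("constructing the adapter should not fail");

        let payload = serde_json::json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 0,
            "model": "gpt-4o-mini",
            "choices": [{
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}
        });

        let response = adapter.parse_response(payload).expect("payload should parse");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.usage.total_tokens, 7);
        assert!(response.choices[0].message.function_call.is_none());
    }
}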