//! OpenAI provider adapter (`ai_lib/provider/openai.rs`).

1use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
2#[cfg(not(feature = "unified_sse"))]
3use crate::api::{ChoiceDelta, MessageDelta};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
8    UsageStatus,
9};
10use futures::stream::{Stream, StreamExt};
11use std::collections::HashMap;
12use std::env;
13use std::sync::Arc;
14#[cfg(feature = "unified_transport")]
15use std::time::Duration;
16
/// OpenAI adapter supporting GPT series models.
///
/// Bundles the HTTP transport, credentials, endpoint base URL and a
/// metrics sink; constructed via `new`, `new_with_overrides`,
/// `with_transport_ref*` or `with_metrics`.
pub struct OpenAiAdapter {
    // Object-safe transport used for all HTTP calls (chat, stream, models).
    transport: DynHttpTransportRef,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // API base URL, e.g. "https://api.openai.com/v1" (no trailing slash).
    base_url: String,
    // Metrics sink; defaults to NoopMetrics in the plain constructors.
    metrics: Arc<dyn Metrics>,
}
26
27impl OpenAiAdapter {
28    #[allow(dead_code)]
29    fn build_default_timeout_secs() -> u64 {
30        std::env::var("AI_HTTP_TIMEOUT_SECS")
31            .ok()
32            .and_then(|s| s.parse::<u64>().ok())
33            .unwrap_or(30)
34    }
35
36    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
37        #[cfg(feature = "unified_transport")]
38        {
39            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
40            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
41                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
42            })?;
43            let t = HttpTransport::with_reqwest_client(client, timeout);
44            Ok(t.boxed())
45        }
46        #[cfg(not(feature = "unified_transport"))]
47        {
48            let t = HttpTransport::new();
49            return Ok(t.boxed());
50        }
51    }
52
53    pub fn new() -> Result<Self, AiLibError> {
54        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
55            AiLibError::AuthenticationError(
56                "OPENAI_API_KEY environment variable not set".to_string(),
57            )
58        })?;
59
60        Ok(Self {
61            transport: Self::build_default_transport()?,
62            api_key,
63            base_url: "https://api.openai.com/v1".to_string(),
64            metrics: Arc::new(NoopMetrics::new()),
65        })
66    }
67
68    /// Explicit API key override (takes precedence over env var) + optional base_url override.
69    pub fn new_with_overrides(
70        api_key: String,
71        base_url: Option<String>,
72    ) -> Result<Self, AiLibError> {
73        Ok(Self {
74            transport: Self::build_default_transport()?,
75            api_key,
76            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
77            metrics: Arc::new(NoopMetrics::new()),
78        })
79    }
80
81    /// Construct with an injected object-safe transport reference
82    pub fn with_transport_ref(
83        transport: DynHttpTransportRef,
84        api_key: String,
85        base_url: String,
86    ) -> Result<Self, AiLibError> {
87        Ok(Self {
88            transport,
89            api_key,
90            base_url,
91            metrics: Arc::new(NoopMetrics::new()),
92        })
93    }
94
    /// Construct with an injected transport reference and a custom metrics sink.
    ///
    /// Infallible today; the `Result` return keeps signature parity with the
    /// other constructors, which can fail while building a transport.
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }
108
109    pub fn with_metrics(
110        api_key: String,
111        base_url: String,
112        metrics: Arc<dyn Metrics>,
113    ) -> Result<Self, AiLibError> {
114        Ok(Self {
115            transport: Self::build_default_transport()?,
116            api_key,
117            base_url,
118            metrics,
119        })
120    }
121
122    #[allow(dead_code)]
123    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
124        // Synchronous converter: do not perform provider uploads, just inline content
125        let mut openai_request = serde_json::json!({
126            "model": request.model,
127            "messages": serde_json::Value::Array(vec![])
128        });
129
130        let mut msgs: Vec<serde_json::Value> = Vec::new();
131        for msg in request.messages.iter() {
132            let role = match msg.role {
133                Role::System => "system",
134                Role::User => "user",
135                Role::Assistant => "assistant",
136            };
137            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
138            msgs.push(serde_json::json!({"role": role, "content": content_val}));
139        }
140        openai_request["messages"] = serde_json::Value::Array(msgs);
141        request.apply_extensions(&mut openai_request);
142        openai_request
143    }
144
    /// Async version that can upload local files to OpenAI before constructing the request.
    ///
    /// Images with a local `name` (and no URL) are uploaded to `{base_url}/files`;
    /// on upload failure the content falls back to the generic inline form.
    /// Optional sampling parameters and function-calling fields are forwarded
    /// only when present on `request`.
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        // Build the OpenAI-compatible request JSON. For now we avoid provider-specific
        // upload flows here and rely on the generic provider utils (which may inline files)
        // to produce content JSON values.
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            // Map the internal role enum onto OpenAI's wire-format strings.
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            // If it's an Image with no URL but has a local `name`, attempt async upload to OpenAI
            let content_val = match &msg.content {
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        // Already remote: serialize as-is.
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        // Try provider upload; fall back to inline behavior on error
                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                // remote may be a full URL, a data: URL, or a provider file id.
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    // Treat as provider file id
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            Err(_) => {
                                // Upload failed: degrade to the generic (possibly inline) form.
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        // No URL and no local file name: nothing to upload.
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        // Optional params
        if let Some(temp) = request.temperature {
            openai_request["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            openai_request["top_p"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            openai_request["frequency_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
            );
        }
        if let Some(presence_penalty) = request.presence_penalty {
            openai_request["presence_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
            );
        }

        // Add function calling definitions if provided
        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        // function_call policy may be set to control OpenAI behavior
        if let Some(policy) = &request.function_call {
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
                }
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    // `Auto("auto")` means provider-chosen; any other name forces
                    // a specific function.
                    if name == "auto" {
                        openai_request["function_call"] =
                            serde_json::Value::String("auto".to_string());
                    } else {
                        openai_request["function_call"] = serde_json::Value::String(name.clone());
                    }
                }
            }
        }

        request.apply_extensions(&mut openai_request);

        Ok(openai_request)
    }
260
261    // Note: provider-specific upload helpers were removed to avoid blocking the async
262    // runtime. Use `crate::provider::utils::upload_file_to_provider` (async) if provider
263    // upload behavior is desired; it will be integrated in a future change.
264
265    fn parse_response(
266        &self,
267        response: serde_json::Value,
268    ) -> Result<ChatCompletionResponse, AiLibError> {
269        let choices = response["choices"]
270            .as_array()
271            .ok_or_else(|| {
272                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
273            })?
274            .iter()
275            .enumerate()
276            .map(|(index, choice)| {
277                let message = choice["message"].as_object().ok_or_else(|| {
278                    AiLibError::ProviderError("Invalid choice format".to_string())
279                })?;
280
281                let role = match message["role"].as_str().unwrap_or("user") {
282                    "system" => Role::System,
283                    "assistant" => Role::Assistant,
284                    _ => Role::User,
285                };
286
287                let content = message["content"].as_str().unwrap_or("").to_string();
288
289                // Build the Message and try to populate a typed FunctionCall if provided by the provider
290                let mut msg_obj = Message {
291                    role,
292                    content: crate::types::common::Content::Text(content.clone()),
293                    function_call: None,
294                };
295
296                if let Some(fc_val) = message.get("function_call").cloned() {
297                    // Try direct deserialization into our typed FunctionCall first
298                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
299                        fc_val.clone(),
300                    ) {
301                        Ok(fc) => {
302                            msg_obj.function_call = Some(fc);
303                        }
304                        Err(_) => {
305                            // Fallback: some providers return `arguments` as a JSON-encoded string.
306                            let name = fc_val
307                                .get("name")
308                                .and_then(|v| v.as_str())
309                                .unwrap_or_default()
310                                .to_string();
311                            let args_val = match fc_val.get("arguments") {
312                                Some(a) if a.is_string() => {
313                                    // Parse stringified JSON
314                                    a.as_str()
315                                        .and_then(|s| {
316                                            serde_json::from_str::<serde_json::Value>(s).ok()
317                                        })
318                                        .unwrap_or(serde_json::Value::Null)
319                                }
320                                Some(a) => a.clone(),
321                                None => serde_json::Value::Null,
322                            };
323                            msg_obj.function_call =
324                                Some(crate::types::function_call::FunctionCall {
325                                    name,
326                                    arguments: Some(args_val),
327                                });
328                        }
329                    }
330                }
331                Ok(Choice {
332                    index: index as u32,
333                    message: msg_obj,
334                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
335                })
336            })
337            .collect::<Result<Vec<_>, AiLibError>>()?;
338
339        let usage = response["usage"].as_object().ok_or_else(|| {
340            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
341        })?;
342
343        let usage = Usage {
344            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
345            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
346            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
347        };
348
349        Ok(ChatCompletionResponse {
350            id: response["id"].as_str().unwrap_or("").to_string(),
351            object: response["object"].as_str().unwrap_or("").to_string(),
352            created: response["created"].as_u64().unwrap_or(0),
353            model: response["model"].as_str().unwrap_or("").to_string(),
354            choices,
355            usage,
356            usage_status: UsageStatus::Finalized, // OpenAI provides accurate usage data
357        })
358    }
359}
360
361#[async_trait::async_trait]
362impl ChatProvider for OpenAiAdapter {
    fn name(&self) -> &str {
        // Human-readable provider identifier (used e.g. in logs).
        "OpenAI"
    }
366
367    async fn chat(
368        &self,
369        request: ChatCompletionRequest,
370    ) -> Result<ChatCompletionResponse, AiLibError> {
371        // Record a request counter and start a timer using standardized keys
372        self.metrics.incr_counter("openai.requests", 1).await;
373        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
374        let url = format!("{}/chat/completions", self.base_url);
375
376        // Build request body via converter
377        let openai_request = self.convert_request_async(&request).await?;
378
379        // Use unified transport
380        let mut headers = HashMap::new();
381        headers.insert(
382            "Authorization".to_string(),
383            format!("Bearer {}", self.api_key),
384        );
385        headers.insert("Content-Type".to_string(), "application/json".to_string());
386
387        let response_json = self
388            .transport
389            .post_json(&url, Some(headers), openai_request)
390            .await
391            .map_err(|e| e.with_context(&format!("OpenAI chat request to {}", url)))?;
392
393        if let Some(t) = timer {
394            t.stop();
395        }
396
397        self.parse_response(response_json)
398    }
399
    /// Start a streaming chat completion (`stream: true`) and return a stream
    /// of parsed chunks.
    ///
    /// A background task reads the raw byte stream, buffers it, splits it on
    /// SSE event boundaries and forwards parsed chunks through an unbounded
    /// channel; the task exits on `[DONE]`, on a parse error, or when the
    /// receiver is dropped.
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!("{}/chat/completions", self.base_url);

        // Build request body with stream=true
        let mut openai_request = self.convert_request_async(&request).await?;
        openai_request["stream"] = serde_json::Value::Bool(true);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());

        let byte_stream_res = self
            .transport
            .post_stream(&url, Some(headers), openai_request)
            .await;

        match byte_stream_res {
            Ok(mut byte_stream) => {
                // Channel bridging the parser task to the returned stream.
                let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
                tokio::spawn(async move {
                    // Accumulates raw bytes until at least one complete SSE
                    // event (terminated by a blank line) is available.
                    let mut buffer = Vec::new();
                    while let Some(result) = byte_stream.next().await {
                        match result {
                            Ok(bytes) => {
                                buffer.extend_from_slice(&bytes);
                                #[cfg(feature = "unified_sse")]
                                {
                                    // Drain every complete event currently buffered.
                                    while let Some(event_end) =
                                        crate::sse::parser::find_event_boundary(&buffer)
                                    {
                                        let event_bytes =
                                            buffer.drain(..event_end).collect::<Vec<_>>();
                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                            if let Some(chunk) =
                                                crate::sse::parser::parse_sse_event(event_text)
                                            {
                                                match chunk {
                                                    Ok(Some(c)) => {
                                                        // Receiver dropped: stop the task.
                                                        if tx.send(Ok(c)).is_err() {
                                                            return;
                                                        }
                                                    }
                                                    Ok(None) => return, // [DONE] signal
                                                    Err(e) => {
                                                        let _ = tx.send(Err(e));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                                #[cfg(not(feature = "unified_sse"))]
                                {
                                    // Same draining loop using the local SSE helpers.
                                    while let Some(event_end) = find_sse_event_boundary(&buffer) {
                                        let event_bytes =
                                            buffer.drain(..event_end).collect::<Vec<_>>();
                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                            if let Some(chunk) = parse_openai_sse_event(event_text)
                                            {
                                                match chunk {
                                                    Ok(Some(c)) => {
                                                        if tx.send(Ok(c)).is_err() {
                                                            return;
                                                        }
                                                    }
                                                    Ok(None) => return, // [DONE] signal
                                                    Err(e) => {
                                                        let _ = tx.send(Err(e));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                // Transport failure mid-stream: surface it and stop.
                                let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                    "Stream error: {}",
                                    e
                                ))));
                                break;
                            }
                        }
                    }
                });
                let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
                Ok(Box::new(Box::pin(stream)))
            }
            Err(e) => Err(e),
        }
    }
503
504    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
505        let url = format!("{}/models", self.base_url);
506        let mut headers = HashMap::new();
507        headers.insert(
508            "Authorization".to_string(),
509            format!("Bearer {}", self.api_key),
510        );
511
512        let response = self.transport.get_json(&url, Some(headers)).await?;
513
514        Ok(response["data"]
515            .as_array()
516            .unwrap_or(&vec![])
517            .iter()
518            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
519            .collect())
520    }
521
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Synthesizes a ModelInfo locally: no network call is made and
        // `model_id` is not validated against the provider.
        // NOTE(review): `created` is hard-coded to 0 and permissions are a
        // generic default — confirm whether callers expect real values here.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
544}
545
546// Local SSE parsing functions for when unified_sse feature is not enabled
/// Find the end of the first complete SSE event in `buffer`.
///
/// Returns the offset one past the event terminator — a blank line encoded as
/// either `\n\n` or `\r\n\r\n` — or `None` if no full event is buffered yet.
#[cfg(not(feature = "unified_sse"))]
fn find_sse_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (i, pair) in buffer.windows(2).enumerate() {
        // LF LF terminator.
        if pair == b"\n\n" {
            return Some(i + 2);
        }
        // CR LF CR LF terminator.
        if buffer[i..].starts_with(b"\r\n\r\n") {
            return Some(i + 4);
        }
    }
    None
}
568
569#[cfg(not(feature = "unified_sse"))]
570fn parse_openai_sse_event(
571    event_text: &str,
572) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
573    for line in event_text.lines() {
574        let line = line.trim();
575        if let Some(data) = line.strip_prefix("data: ") {
576            if data == "[DONE]" {
577                return Some(Ok(None));
578            }
579            return Some(parse_openai_chunk_data(data));
580        }
581    }
582    None
583}
584
/// Parse one SSE `data:` payload into a streaming `ChatCompletionChunk`.
///
/// Fails only on invalid JSON; missing fields degrade to defaults so a
/// partially-populated delta does not abort the stream.
#[cfg(not(feature = "unified_sse"))]
fn parse_openai_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
    let json: serde_json::Value = serde_json::from_str(data)
        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;

    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
    if let Some(arr) = json["choices"].as_array() {
        for (index, choice) in arr.iter().enumerate() {
            let delta = &choice["delta"];
            // Unrecognized role strings default to Assistant.
            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
                "assistant" => Role::Assistant,
                "user" => Role::User,
                "system" => Role::System,
                _ => Role::Assistant,
            });
            let content = delta
                .get("content")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string());
            let md = MessageDelta { role, content };
            let cd = ChoiceDelta {
                index: index as u32,
                delta: md,
                finish_reason: choice
                    .get("finish_reason")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
            };
            choices_vec.push(cd);
        }
    }

    Ok(Some(ChatCompletionChunk {
        id: json["id"].as_str().unwrap_or_default().to_string(),
        object: json["object"]
            .as_str()
            .unwrap_or("chat.completion.chunk")
            .to_string(),
        created: json["created"].as_u64().unwrap_or(0),
        model: json["model"].as_str().unwrap_or_default().to_string(),
        choices: choices_vec,
    }))
}
627}