use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
#[cfg(not(feature = "unified_sse"))]
use crate::api::{ChoiceDelta, MessageDelta};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::{Stream, StreamExt};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

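/// Adapter for the OpenAI chat completions API.
///
/// Requests go through an injectable [`DynHttpTransportRef`], so the HTTP
/// layer can be swapped out (e.g. with a mock in tests); metrics default to
/// a no-op sink.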
pub struct OpenAiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl OpenAiAdapter {
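    /// Read the request timeout from the `AI_HTTP_TIMEOUT_SECS` environment
    /// variable, falling back to 30 seconds when unset or unparseable.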
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

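    /// Build the default transport. With the `unified_transport` feature this
    /// reuses the shared reqwest client with the configured timeout; without
    /// it, a plain `HttpTransport` is constructed.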
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            Ok(t.boxed())
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

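    /// Create an adapter using the `OPENAI_API_KEY` environment variable and
    /// the default `https://api.openai.com/v1` base URL.
    ///
    /// Illustrative usage (a sketch; the exact import path depends on this
    /// crate's public re-exports):
    ///
    /// ```ignore
    /// std::env::set_var("OPENAI_API_KEY", "sk-...");
    /// let adapter = OpenAiAdapter::new()?;
    /// let response = adapter.chat(request).await?;
    /// ```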
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

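    /// Create an adapter with an explicit API key and an optional base URL
    /// override, useful for proxies or OpenAI-compatible endpoints.
    ///
    /// ```ignore
    /// // Illustrative; assumes this crate's public re-exports.
    /// let adapter = OpenAiAdapter::new_with_overrides(
    ///     key,
    ///     Some("https://proxy.example/v1".to_string()),
    /// )?;
    /// ```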
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

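    /// Create an adapter around a caller-supplied transport, e.g. a mock
    /// transport in tests or a client with custom pooling/proxy settings.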
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    pub fn with_metrics(
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics,
        })
    }

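    /// Convert a provider-agnostic request into the OpenAI wire format
    /// (synchronous variant; performs no file uploads). The result has
    /// roughly this shape:
    ///
    /// ```json
    /// { "model": "...", "messages": [{ "role": "user", "content": "..." }] }
    /// ```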
    #[allow(dead_code)]
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };
            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }
        openai_request["messages"] = serde_json::Value::Array(msgs);
        request.apply_extensions(&mut openai_request);
        openai_request
    }

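    /// Async variant of [`Self::convert_request`]: local image attachments
    /// (an `Image` with a `name` but no `url`) are first uploaded to
    /// `{base_url}/files`, and the returned URL or file id is inlined into
    /// the message content; on upload failure the content falls back to the
    /// generic provider encoding.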
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let content_val = match &msg.content {
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            Err(_) => {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        if let Some(temp) = request.temperature {
            openai_request["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            openai_request["top_p"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            openai_request["frequency_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
            );
        }
        if let Some(presence_penalty) = request.presence_penalty {
            openai_request["presence_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
            );
        }

        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        if let Some(policy) = &request.function_call {
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
                }
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    if name == "auto" {
                        openai_request["function_call"] =
                            serde_json::Value::String("auto".to_string());
                    } else {
                        openai_request["function_call"] = serde_json::Value::String(name.clone());
                    }
                }
            }
        }

        request.apply_extensions(&mut openai_request);

        Ok(openai_request)
    }

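    /// Parse an OpenAI chat completion response into the unified types.
    ///
    /// `function_call` payloads are handled tolerantly: if strict
    /// deserialization fails, the `name` and `arguments` fields are recovered
    /// manually, decoding string-encoded argument JSON when possible.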
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                let role = match message["role"].as_str().unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message["content"].as_str().unwrap_or("").to_string();

                let mut msg_obj = Message {
                    role,
                    content: crate::types::common::Content::Text(content.clone()),
                    function_call: None,
                };

                if let Some(fc_val) = message.get("function_call").cloned() {
                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        Ok(fc) => {
                            msg_obj.function_call = Some(fc);
                        }
                        Err(_) => {
                            let name = fc_val
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or_default()
                                .to_string();
                            let args_val = match fc_val.get("arguments") {
                                Some(a) if a.is_string() => {
                                    a.as_str()
                                        .and_then(|s| {
                                            serde_json::from_str::<serde_json::Value>(s).ok()
                                        })
                                        .unwrap_or(serde_json::Value::Null)
                                }
                                Some(a) => a.clone(),
                                None => serde_json::Value::Null,
                            };
                            msg_obj.function_call =
                                Some(crate::types::function_call::FunctionCall {
                                    name,
                                    arguments: Some(args_val),
                                });
                        }
                    }
                }
                Ok(Choice {
                    index: index as u32,
                    message: msg_obj,
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}

#[async_trait::async_trait]
impl ChatProvider for OpenAiAdapter {
    fn name(&self) -> &str {
        "OpenAI"
    }

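    /// Send a non-streaming chat completion request.
    ///
    /// Increments the `openai.requests` counter and times the call via the
    /// `openai.request_duration_ms` metric before POSTing to
    /// `{base_url}/chat/completions`.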
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("openai.requests", 1).await;
        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
        let url = format!("{}/chat/completions", self.base_url);

        let openai_request = self.convert_request_async(&request).await?;

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let response_json = self
            .transport
            .post_json(&url, Some(headers), openai_request)
            .await
            .map_err(|e| e.with_context(&format!("OpenAI chat request to {}", url)))?;

        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }

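    /// Send a streaming chat completion request.
    ///
    /// Sets `"stream": true`, requests `text/event-stream`, and spawns a task
    /// that splits the byte stream on SSE event boundaries, forwarding parsed
    /// [`ChatCompletionChunk`]s through an unbounded channel until the stream
    /// ends or a `[DONE]` sentinel is seen.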
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!("{}/chat/completions", self.base_url);

        let mut openai_request = self.convert_request_async(&request).await?;
        openai_request["stream"] = serde_json::Value::Bool(true);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());

        let byte_stream_res = self
            .transport
            .post_stream(&url, Some(headers), openai_request)
            .await;

        match byte_stream_res {
            Ok(mut byte_stream) => {
                let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
                tokio::spawn(async move {
                    let mut buffer = Vec::new();
                    while let Some(result) = byte_stream.next().await {
                        match result {
                            Ok(bytes) => {
                                buffer.extend_from_slice(&bytes);
                                #[cfg(feature = "unified_sse")]
                                {
                                    while let Some(event_end) =
                                        crate::sse::parser::find_event_boundary(&buffer)
                                    {
                                        let event_bytes =
                                            buffer.drain(..event_end).collect::<Vec<_>>();
                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                            if let Some(chunk) =
                                                crate::sse::parser::parse_sse_event(event_text)
                                            {
                                                match chunk {
                                                    Ok(Some(c)) => {
                                                        if tx.send(Ok(c)).is_err() {
                                                            return;
                                                        }
                                                    }
                                                    Ok(None) => return,
                                                    Err(e) => {
                                                        let _ = tx.send(Err(e));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                                #[cfg(not(feature = "unified_sse"))]
                                {
                                    while let Some(event_end) = find_sse_event_boundary(&buffer) {
                                        let event_bytes =
                                            buffer.drain(..event_end).collect::<Vec<_>>();
                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                            if let Some(chunk) = parse_openai_sse_event(event_text)
                                            {
                                                match chunk {
                                                    Ok(Some(c)) => {
                                                        if tx.send(Ok(c)).is_err() {
                                                            return;
                                                        }
                                                    }
                                                    Ok(None) => return,
                                                    Err(e) => {
                                                        let _ = tx.send(Err(e));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                    "Stream error: {}",
                                    e
                                ))));
                                break;
                            }
                        }
                    }
                });
                let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
                Ok(Box::new(Box::pin(stream)))
            }
            Err(e) => Err(e),
        }
    }

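    /// List available model ids via `GET {base_url}/models`, extracting the
    /// `id` field from each entry of the response's `data` array.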
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

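    /// Return model metadata. Note this is a locally synthesized placeholder
    /// with default permissions; it does not call `GET /models/{id}`.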
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}

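/// Find the end of the next complete SSE event in `buffer`, i.e. the index
/// just past the first blank line (`\n\n` or `\r\n\r\n`); returns `None`
/// when no complete event has been buffered yet.
///
/// For example, `find_sse_event_boundary(b"data: {}\n\n")` returns `Some(10)`.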
#[cfg(not(feature = "unified_sse"))]
fn find_sse_event_boundary(buffer: &[u8]) -> Option<usize> {
    let mut i = 0;
    while i + 1 < buffer.len() {
        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
            return Some(i + 2);
        }
        if i + 3 < buffer.len()
            && buffer[i] == b'\r'
            && buffer[i + 1] == b'\n'
            && buffer[i + 2] == b'\r'
            && buffer[i + 3] == b'\n'
        {
            return Some(i + 4);
        }
        i += 1;
    }
    None
}

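/// Parse one SSE event's text, looking for a `data: ` line. Returns
/// `Some(Ok(None))` for the `[DONE]` sentinel, `Some(...)` with a parsed
/// chunk (or an error) for JSON payloads, and `None` when the event carries
/// no data line (e.g. comments or keep-alives).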
569#[cfg(not(feature = "unified_sse"))]
570fn parse_openai_sse_event(
571 event_text: &str,
572) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
573 for line in event_text.lines() {
574 let line = line.trim();
575 if let Some(data) = line.strip_prefix("data: ") {
576 if data == "[DONE]" {
577 return Some(Ok(None));
578 }
579 return Some(parse_openai_chunk_data(data));
580 }
581 }
582 None
583}
584
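/// Deserialize a single chunk payload into a [`ChatCompletionChunk`],
/// mapping each entry of `choices[].delta` onto a [`ChoiceDelta`] with an
/// optional role and content fragment.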
#[cfg(not(feature = "unified_sse"))]
fn parse_openai_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
    let json: serde_json::Value = serde_json::from_str(data)
        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;

    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
    if let Some(arr) = json["choices"].as_array() {
        for (index, choice) in arr.iter().enumerate() {
            let delta = &choice["delta"];
            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
                "assistant" => Role::Assistant,
                "user" => Role::User,
                "system" => Role::System,
                _ => Role::Assistant,
            });
            let content = delta
                .get("content")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string());
            let md = MessageDelta { role, content };
            let cd = ChoiceDelta {
                index: index as u32,
                delta: md,
                finish_reason: choice
                    .get("finish_reason")
                    .and_then(|v| v.as_str())
                    .map(|s| s.to_string()),
            };
            choices_vec.push(cd);
        }
    }

    Ok(Some(ChatCompletionChunk {
        id: json["id"].as_str().unwrap_or_default().to_string(),
        object: json["object"]
            .as_str()
            .unwrap_or("chat.completion.chunk")
            .to_string(),
        created: json["created"].as_u64().unwrap_or(0),
        model: json["model"].as_str().unwrap_or_default().to_string(),
        choices: choices_vec,
    }))
}
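
// Minimal sanity checks for the fallback SSE helpers above (a sketch; they
// compile only when the `unified_sse` feature is disabled, matching the
// helpers' cfg gates).
#[cfg(all(test, not(feature = "unified_sse")))]
mod sse_fallback_tests {
    use super::*;

    #[test]
    fn boundary_detected_after_blank_line() {
        // An SSE event ends at the first blank line; `\n\n` and `\r\n\r\n`
        // both terminate an event.
        let lf = b"data: {}\n\n";
        assert_eq!(find_sse_event_boundary(lf), Some(lf.len()));
        let crlf = b"data: {}\r\n\r\nrest";
        assert_eq!(find_sse_event_boundary(crlf), Some(12));
        // No blank line yet: the event is still incomplete.
        assert_eq!(find_sse_event_boundary(b"data: {}"), None);
    }

    #[test]
    fn done_sentinel_ends_stream() {
        // `[DONE]` maps to Ok(None), which the streaming task treats as EOF.
        assert!(matches!(
            parse_openai_sse_event("data: [DONE]\n"),
            Some(Ok(None))
        ));
    }

    #[test]
    fn chunk_delta_content_is_extracted() {
        let data = r#"{"id":"c1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"delta":{"role":"assistant","content":"hi"},"finish_reason":null}]}"#;
        let Ok(Some(chunk)) = parse_openai_chunk_data(data) else {
            panic!("expected a parsed, non-terminal chunk");
        };
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("hi"));
    }
}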