use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

/// Adapter for the OpenAI chat API.
///
/// Holds the HTTP transport, API credentials, base URL, and a metrics sink.
pub struct OpenAiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl OpenAiAdapter {
    /// Default HTTP timeout in seconds, read from `AI_HTTP_TIMEOUT_SECS` with
    /// a fallback of 30.
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

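    /// Build the default transport. With the `unified_transport` feature a
    /// shared reqwest client is reused across adapters; otherwise each adapter
    /// constructs its own `HttpTransport`. The request timeout is configurable
    /// via the environment, e.g. `AI_HTTP_TIMEOUT_SECS=60`.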
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

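    /// Construct an adapter from the `OPENAI_API_KEY` environment variable,
    /// using the default transport and a no-op metrics sink.
    ///
    /// A minimal usage sketch (the surrounding imports and error handling are
    /// assumed; adjust to the crate's actual re-exports):
    ///
    /// ```ignore
    /// // Assumes OPENAI_API_KEY is set in the environment.
    /// let adapter = OpenAiAdapter::new()?;
    /// ```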
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

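    /// Construct an adapter with an explicit API key and an optional base URL
    /// override, useful for proxies or OpenAI-compatible endpoints.
    ///
    /// A sketch (the endpoint URL is illustrative):
    ///
    /// ```ignore
    /// let adapter = OpenAiAdapter::new_with_overrides(
    ///     "sk-...".to_string(),
    ///     Some("https://my-proxy.example.com/v1".to_string()),
    /// )?;
    /// ```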
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

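    /// Construct an adapter with an injected transport. This is the hook for
    /// tests that substitute a mock `DynHttpTransportRef`.
    ///
    /// A sketch (`mock_transport()` is a hypothetical test helper, not part of
    /// this crate):
    ///
    /// ```ignore
    /// let adapter = OpenAiAdapter::with_transport_ref(
    ///     mock_transport(),
    ///     "test-key".to_string(),
    ///     "https://api.openai.com/v1".to_string(),
    /// )?;
    /// ```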
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

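    /// Construct an adapter with the default transport but a caller-supplied
    /// metrics sink.
    ///
    /// A sketch (`MyMetrics` is a hypothetical `Metrics` implementation):
    ///
    /// ```ignore
    /// let metrics: Arc<dyn Metrics> = Arc::new(MyMetrics::default());
    /// let adapter = OpenAiAdapter::with_metrics(
    ///     "sk-...".to_string(),
    ///     "https://api.openai.com/v1".to_string(),
    ///     metrics,
    /// )?;
    /// ```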
    pub fn with_metrics(
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics,
        })
    }

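    /// Synchronous request mapping: builds the OpenAI JSON body without the
    /// file-upload fallback performed by `convert_request_async`.
    ///
    /// For a single user message the body produced looks roughly like (model
    /// name illustrative):
    ///
    /// ```json
    /// {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}]}
    /// ```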
    #[allow(dead_code)]
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };
            let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }
        openai_request["messages"] = serde_json::Value::Array(msgs);
        openai_request
    }

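    /// Async request mapping. Beyond the field conversion in `convert_request`,
    /// optional sampling parameters and function-calling settings are attached,
    /// and a local image attachment (a `name` with no `url`) is first uploaded
    /// to `{base_url}/files` and then referenced by URL or file id, roughly as
    /// (the file id is illustrative):
    ///
    /// ```json
    /// {"role": "user", "content": {"image": {"file_id": "file-abc123"}}}
    /// ```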
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let content_val = match &msg.content {
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        // A remote URL is already usable as-is.
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        // Local file: upload it first, then reference the
                        // returned URL or file id in the message payload.
                        let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            // On upload failure, fall back to the inline representation.
                            Err(_) => {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        // Optional sampling parameters. `Number::from_f64` returns None for
        // NaN/infinity, so guard instead of unwrapping to avoid a panic.
        if let Some(temp) = request.temperature {
            if let Some(n) = serde_json::Number::from_f64(temp.into()) {
                openai_request["temperature"] = serde_json::Value::Number(n);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            if let Some(n) = serde_json::Number::from_f64(top_p.into()) {
                openai_request["top_p"] = serde_json::Value::Number(n);
            }
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            if let Some(n) = serde_json::Number::from_f64(freq_penalty.into()) {
                openai_request["frequency_penalty"] = serde_json::Value::Number(n);
            }
        }
        if let Some(presence_penalty) = request.presence_penalty {
            if let Some(n) = serde_json::Number::from_f64(presence_penalty.into()) {
                openai_request["presence_penalty"] = serde_json::Value::Number(n);
            }
        }

        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        if let Some(policy) = &request.function_call {
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] = serde_json::Value::String("none".to_string());
                }
                // `Auto("auto")` lets the model decide; any other name forces
                // a call to that specific function.
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    if name == "auto" {
                        openai_request["function_call"] =
                            serde_json::Value::String("auto".to_string());
                    } else {
                        openai_request["function_call"] = serde_json::Value::String(name.clone());
                    }
                }
            }
        }

        Ok(openai_request)
    }

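    /// Parse an OpenAI chat completion response into `ChatCompletionResponse`,
    /// including any `function_call` attached to a choice's message.
    ///
    /// The expected wire shape, abbreviated:
    ///
    /// ```json
    /// {
    ///   "id": "chatcmpl-...",
    ///   "object": "chat.completion",
    ///   "created": 1700000000,
    ///   "model": "gpt-4o-mini",
    ///   "choices": [
    ///     {"index": 0, "message": {"role": "assistant", "content": "Hi"}, "finish_reason": "stop"}
    ///   ],
    ///   "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}
    /// }
    /// ```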
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                // Use `get` rather than indexing: indexing a `serde_json::Map`
                // panics when the key is absent.
                let role = match message.get("role").and_then(|v| v.as_str()).unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message
                    .get("content")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();

                let mut msg_obj = Message {
                    role,
                    content: crate::types::common::Content::Text(content),
                    function_call: None,
                };

                if let Some(fc_val) = message.get("function_call").cloned() {
                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        Ok(fc) => {
                            msg_obj.function_call = Some(fc);
                        }
                        Err(_) => {
                            // Fallback: OpenAI sends `arguments` as a
                            // JSON-encoded string, so parse it manually when
                            // direct deserialization fails.
                            let name = fc_val
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or_default()
                                .to_string();
                            let args_val = match fc_val.get("arguments") {
                                Some(a) if a.is_string() => a
                                    .as_str()
                                    .and_then(|s| {
                                        serde_json::from_str::<serde_json::Value>(s).ok()
                                    })
                                    .unwrap_or(serde_json::Value::Null),
                                Some(a) => a.clone(),
                                None => serde_json::Value::Null,
                            };
                            msg_obj.function_call =
                                Some(crate::types::function_call::FunctionCall {
                                    name,
                                    arguments: Some(args_val),
                                });
                        }
                    }
                }
                Ok(Choice {
                    index: index as u32,
                    message: msg_obj,
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            completion_tokens: usage
                .get("completion_tokens")
                .and_then(|v| v.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}

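// End-to-end usage sketch for the `ChatApi` implementation below. Request
// construction is assumed to mirror the fields read in `convert_request_async`;
// the `Default` impl is an assumption, so adjust to the actual constructor:
//
//     let adapter = OpenAiAdapter::new()?;
//     let request = ChatCompletionRequest {
//         model: "gpt-4o-mini".to_string(),
//         messages: vec![Message {
//             role: Role::User,
//             content: Content::Text("Hello!".to_string()),
//             function_call: None,
//         }],
//         ..Default::default() // assumes ChatCompletionRequest: Default
//     };
//     let response = adapter.chat_completion(request).await?;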
#[async_trait::async_trait]
impl ChatApi for OpenAiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("openai.requests", 1).await;
        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
        let url = format!("{}/chat/completions", self.base_url);

        let openai_request = self.convert_request_async(&request).await?;

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        // Stop the timer on both success and failure so error latencies are
        // recorded as well.
        let result = self
            .transport
            .post_json(&url, Some(headers), openai_request)
            .await;
        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(result?)
    }

    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Streaming is not implemented for this adapter yet; return an empty
        // stream rather than an error so callers observe "no chunks".
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        // Collect the `id` of each entry in the `data` array; an absent or
        // malformed array yields an empty list.
        Ok(response["data"]
            .as_array()
            .map(|models| {
                models
                    .iter()
                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default())
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // The OpenAI `/models/{id}` endpoint is not queried here; return a
        // static descriptor with a permissive default permission entry.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}