use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

/// Adapter for the OpenAI chat completions API.
///
/// Bundles the HTTP transport, API credentials, base URL, and a metrics sink.
pub struct OpenAiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl OpenAiAdapter {
    /// Resolve the default HTTP timeout in seconds from `AI_HTTP_TIMEOUT_SECS`,
    /// falling back to 30 when the variable is unset or unparsable.
    fn build_default_timeout_secs() -> u64 {
        env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    /// Build the default boxed transport. With the `unified_transport` feature
    /// this wraps a shared reqwest client with the configured timeout;
    /// otherwise it falls back to a plain `HttpTransport`.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }
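    // A sketch of the wiring above (assumed feature/env configuration): with
    // `unified_transport` enabled and `AI_HTTP_TIMEOUT_SECS=60`, the shared
    // reqwest client is wrapped with a 60-second timeout; without the feature,
    // `HttpTransport::new()` supplies its own defaults.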

    /// Construct an adapter from the `OPENAI_API_KEY` environment variable,
    /// targeting the default `https://api.openai.com/v1` base URL.
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
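    // Construction sketch, assuming `OPENAI_API_KEY` is exported:
    //
    //     let adapter = OpenAiAdapter::new()?;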

    /// Construct an adapter with an explicit API key and an optional base URL
    /// override (useful for OpenAI-compatible gateways and proxies).
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
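    // Sketch: point the adapter at a compatible gateway (URL is illustrative):
    //
    //     let adapter = OpenAiAdapter::new_with_overrides(
    //         api_key,
    //         Some("https://gateway.example.com/v1".to_string()),
    //     )?;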

    /// Construct an adapter around a caller-supplied transport.
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
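    // Sketch: inject a custom transport, e.g. a mock in tests
    // (`mock_transport` here is hypothetical):
    //
    //     let adapter = OpenAiAdapter::with_transport_ref(mock_transport, key, url)?;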

    /// Construct an adapter around a caller-supplied transport and metrics sink.
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Construct an adapter with the default transport and a custom metrics sink.
    pub fn with_metrics(
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Synchronous request mapping: model plus role/content message pairs.
    /// Currently unused; see `convert_request_async` for the full mapping.
    #[allow(dead_code)]
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let msgs: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                };
                let content_val =
                    crate::provider::utils::content_to_provider_value(&msg.content);
                serde_json::json!({"role": role, "content": content_val})
            })
            .collect();
        serde_json::json!({
            "model": request.model,
            "messages": msgs
        })
    }
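    // The sync mapping yields the minimal OpenAI payload shape, e.g.
    // (model name illustrative):
    //
    //     {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}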

    /// Async request mapping. In addition to the basic fields, this uploads
    /// local image files to `{base_url}/files` and forwards sampling
    /// parameters and function-calling options.
    async fn convert_request_async(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        let mut openai_request = serde_json::json!({
            "model": request.model,
            "messages": serde_json::Value::Array(vec![])
        });

        let mut msgs: Vec<serde_json::Value> = Vec::new();
        for msg in request.messages.iter() {
            let role = match msg.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let content_val = match &msg.content {
                // Image content: pass URLs (and data URIs) through unchanged;
                // upload named local files and reference the uploaded artifact.
                crate::types::common::Content::Image { url, mime: _, name } => {
                    if url.is_some() {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        let upload_url =
                            format!("{}/files", self.base_url.trim_end_matches('/'));
                        match crate::provider::utils::upload_file_with_transport(
                            Some(self.transport.clone()),
                            &upload_url,
                            n,
                            "file",
                        )
                        .await
                        {
                            Ok(remote) => {
                                if remote.starts_with("http://")
                                    || remote.starts_with("https://")
                                    || remote.starts_with("data:")
                                {
                                    serde_json::json!({"image": {"url": remote}})
                                } else {
                                    serde_json::json!({"image": {"file_id": remote}})
                                }
                            }
                            // On upload failure, fall back to the inline representation.
                            Err(_) => {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };
            msgs.push(serde_json::json!({"role": role, "content": content_val}));
        }

        openai_request["messages"] = serde_json::Value::Array(msgs);

        // Optional sampling parameters. `json!` maps non-finite floats to null
        // instead of panicking like `Number::from_f64(..).unwrap()` would.
        if let Some(temp) = request.temperature {
            openai_request["temperature"] = serde_json::json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] = serde_json::json!(max_tokens);
        }
        if let Some(top_p) = request.top_p {
            openai_request["top_p"] = serde_json::json!(top_p);
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            openai_request["frequency_penalty"] = serde_json::json!(freq_penalty);
        }
        if let Some(presence_penalty) = request.presence_penalty {
            openai_request["presence_penalty"] = serde_json::json!(presence_penalty);
        }

        if let Some(functions) = &request.functions {
            openai_request["functions"] =
                serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
        }

        if let Some(policy) = &request.function_call {
            // "none" disables calls; any named policy (including "auto") is
            // forwarded verbatim, so the two Auto branches collapse to one.
            match policy {
                crate::types::function_call::FunctionCallPolicy::None => {
                    openai_request["function_call"] =
                        serde_json::Value::String("none".to_string());
                }
                crate::types::function_call::FunctionCallPolicy::Auto(name) => {
                    openai_request["function_call"] = serde_json::Value::String(name.clone());
                }
            }
        }

        Ok(openai_request)
    }
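    // With sampling fields set, the async mapping produces, e.g.:
    //
    //     {"model": "...", "messages": [...], "temperature": 0.7,
    //      "max_tokens": 256, "function_call": "auto"}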

    /// Map an OpenAI chat completion response body onto the library types.
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                // Use `.get()` rather than indexing the map: indexing panics on
                // missing keys, `.get()` degrades to the default instead.
                let role = match message.get("role").and_then(|v| v.as_str()).unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                // `content` may be null when the model returns a function call.
                let content = message
                    .get("content")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();

                let mut msg_obj = Message {
                    role,
                    content: crate::types::common::Content::Text(content),
                    function_call: None,
                };

                if let Some(fc_val) = message.get("function_call").cloned() {
                    match serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        Ok(fc) => {
                            msg_obj.function_call = Some(fc);
                        }
                        // Fallback: OpenAI encodes `arguments` as a JSON string;
                        // parse it manually when direct deserialization fails.
                        Err(_) => {
                            let name = fc_val
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or_default()
                                .to_string();
                            let args_val = match fc_val.get("arguments") {
                                Some(a) if a.is_string() => a
                                    .as_str()
                                    .and_then(|s| {
                                        serde_json::from_str::<serde_json::Value>(s).ok()
                                    })
                                    .unwrap_or(serde_json::Value::Null),
                                Some(a) => a.clone(),
                                None => serde_json::Value::Null,
                            };
                            msg_obj.function_call =
                                Some(crate::types::function_call::FunctionCall {
                                    name,
                                    arguments: Some(args_val),
                                });
                        }
                    }
                }
                Ok(Choice {
                    index: index as u32,
                    message: msg_obj,
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            completion_tokens: usage
                .get("completion_tokens")
                .and_then(|v| v.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
        })
    }
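    // The fallback accepts both `function_call` encodings seen in responses:
    //
    //     {"name": "f", "arguments": "{\"x\":1}"}   // arguments as a JSON string
    //     {"name": "f", "arguments": {"x": 1}}      // arguments as a JSON object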
}

#[async_trait::async_trait]
impl ChatApi for OpenAiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Record request count and latency via the injected metrics sink.
        self.metrics.incr_counter("openai.requests", 1).await;
        let timer = self.metrics.start_timer("openai.request_duration_ms").await;
        let url = format!("{}/chat/completions", self.base_url);

        let openai_request = self.convert_request_async(&request).await?;

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let response_json = self
            .transport
            .post_json(&url, Some(headers), openai_request)
            .await?;

        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }
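    // On the wire this amounts to (sketch):
    //
    //     POST {base_url}/chat/completions
    //     Authorization: Bearer <api_key>
    //     Content-Type: application/json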

    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Placeholder: SSE streaming is not implemented for this adapter yet,
        // so callers receive an immediately-finished empty stream.
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        // Collect the `id` of every entry under `data`.
        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Static metadata: this adapter does not query `GET /models/{id}`;
        // it synthesizes a permissive default record instead.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "openai".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
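
#[cfg(test)]
mod tests {
    // Sketch tests: they exercise constructor wiring and env handling only,
    // with no network traffic. Assumes `build_default_transport` succeeds
    // without connectivity (true for both transport feature configurations).
    use super::*;

    #[test]
    fn overrides_fall_back_to_default_base_url() {
        let adapter = OpenAiAdapter::new_with_overrides("test-key".to_string(), None)
            .expect("building the default transport should not require network access");
        assert_eq!(adapter.base_url, "https://api.openai.com/v1");
        assert_eq!(adapter.api_key, "test-key");
    }

    #[test]
    fn timeout_env_var_overrides_default() {
        // Note: mutating process-wide env vars can race with parallel tests.
        std::env::set_var("AI_HTTP_TIMEOUT_SECS", "5");
        assert_eq!(OpenAiAdapter::build_default_timeout_secs(), 5);
        std::env::remove_var("AI_HTTP_TIMEOUT_SECS");
        assert_eq!(OpenAiAdapter::build_default_timeout_secs(), 30);
    }
}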