1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6};
7use futures::stream::{self, Stream};
8use std::collections::HashMap;
9use std::env;
10use std::sync::Arc;
11
12pub struct OpenAiAdapter {
16 transport: DynHttpTransportRef,
17 api_key: String,
18 base_url: String,
19 metrics: Arc<dyn Metrics>,
20}
21
22impl OpenAiAdapter {
23 pub fn new() -> Result<Self, AiLibError> {
24 let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
25 AiLibError::AuthenticationError(
26 "OPENAI_API_KEY environment variable not set".to_string(),
27 )
28 })?;
29
30 Ok(Self {
31 transport: HttpTransport::new().boxed(),
32 api_key,
33 base_url: "https://api.openai.com/v1".to_string(),
34 metrics: Arc::new(NoopMetrics::new()),
35 })
36 }
37
38 pub fn new_with_overrides(
40 api_key: String,
41 base_url: Option<String>,
42 ) -> Result<Self, AiLibError> {
43 Ok(Self {
44 transport: HttpTransport::new().boxed(),
45 api_key,
46 base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
47 metrics: Arc::new(NoopMetrics::new()),
48 })
49 }
50
51 pub fn with_transport_ref(
53 transport: DynHttpTransportRef,
54 api_key: String,
55 base_url: String,
56 ) -> Result<Self, AiLibError> {
57 Ok(Self {
58 transport,
59 api_key,
60 base_url,
61 metrics: Arc::new(NoopMetrics::new()),
62 })
63 }
64
65 pub fn with_transport_ref_and_metrics(
66 transport: DynHttpTransportRef,
67 api_key: String,
68 base_url: String,
69 metrics: Arc<dyn Metrics>,
70 ) -> Result<Self, AiLibError> {
71 Ok(Self {
72 transport,
73 api_key,
74 base_url,
75 metrics,
76 })
77 }
78
79 pub fn with_metrics(
80 api_key: String,
81 base_url: String,
82 metrics: Arc<dyn Metrics>,
83 ) -> Result<Self, AiLibError> {
84 Ok(Self {
85 transport: HttpTransport::new().boxed(),
86 api_key,
87 base_url,
88 metrics,
89 })
90 }
91
92 #[allow(dead_code)]
93 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
94 let mut openai_request = serde_json::json!({
96 "model": request.model,
97 "messages": serde_json::Value::Array(vec![])
98 });
99
100 let mut msgs: Vec<serde_json::Value> = Vec::new();
101 for msg in request.messages.iter() {
102 let role = match msg.role {
103 Role::System => "system",
104 Role::User => "user",
105 Role::Assistant => "assistant",
106 };
107 let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
108 msgs.push(serde_json::json!({"role": role, "content": content_val}));
109 }
110 openai_request["messages"] = serde_json::Value::Array(msgs);
111 openai_request
112 }
113
114 async fn convert_request_async(
116 &self,
117 request: &ChatCompletionRequest,
118 ) -> Result<serde_json::Value, AiLibError> {
119 let mut openai_request = serde_json::json!({
123 "model": request.model,
124 "messages": serde_json::Value::Array(vec![])
125 });
126
127 let mut msgs: Vec<serde_json::Value> = Vec::new();
128 for msg in request.messages.iter() {
129 let role = match msg.role {
130 Role::System => "system",
131 Role::User => "user",
132 Role::Assistant => "assistant",
133 };
134
135 let content_val = match &msg.content {
137 crate::types::common::Content::Image { url, mime: _, name } => {
138 if url.is_some() {
139 crate::provider::utils::content_to_provider_value(&msg.content)
140 } else if let Some(n) = name {
141 let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
143 match crate::provider::utils::upload_file_with_transport(
144 Some(self.transport.clone()),
145 &upload_url,
146 n,
147 "file",
148 )
149 .await
150 {
151 Ok(remote) => {
152 if remote.starts_with("http://")
154 || remote.starts_with("https://")
155 || remote.starts_with("data:")
156 {
157 serde_json::json!({"image": {"url": remote}})
158 } else {
159 serde_json::json!({"image": {"file_id": remote}})
161 }
162 }
163 Err(_) => {
164 crate::provider::utils::content_to_provider_value(&msg.content)
165 }
166 }
167 } else {
168 crate::provider::utils::content_to_provider_value(&msg.content)
169 }
170 }
171 _ => crate::provider::utils::content_to_provider_value(&msg.content),
172 };
173 msgs.push(serde_json::json!({"role": role, "content": content_val}));
174 }
175
176 openai_request["messages"] = serde_json::Value::Array(msgs);
177
178 if let Some(temp) = request.temperature {
180 openai_request["temperature"] =
181 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
182 }
183 if let Some(max_tokens) = request.max_tokens {
184 openai_request["max_tokens"] =
185 serde_json::Value::Number(serde_json::Number::from(max_tokens));
186 }
187 if let Some(top_p) = request.top_p {
188 openai_request["top_p"] =
189 serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
190 }
191 if let Some(freq_penalty) = request.frequency_penalty {
192 openai_request["frequency_penalty"] = serde_json::Value::Number(
193 serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
194 );
195 }
196 if let Some(presence_penalty) = request.presence_penalty {
197 openai_request["presence_penalty"] = serde_json::Value::Number(
198 serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
199 );
200 }
201
202 if let Some(functions) = &request.functions {
204 openai_request["functions"] =
205 serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
206 }
207
208 if let Some(policy) = &request.function_call {
210 match policy {
211 crate::types::function_call::FunctionCallPolicy::None => {
212 openai_request["function_call"] = serde_json::Value::String("none".to_string());
213 }
214 crate::types::function_call::FunctionCallPolicy::Auto(name) => {
215 if name == "auto" {
216 openai_request["function_call"] =
217 serde_json::Value::String("auto".to_string());
218 } else {
219 openai_request["function_call"] = serde_json::Value::String(name.clone());
220 }
221 }
222 }
223 }
224
225 Ok(openai_request)
226 }
227
228 fn parse_response(
233 &self,
234 response: serde_json::Value,
235 ) -> Result<ChatCompletionResponse, AiLibError> {
236 let choices = response["choices"]
237 .as_array()
238 .ok_or_else(|| {
239 AiLibError::ProviderError("Invalid response format: choices not found".to_string())
240 })?
241 .iter()
242 .enumerate()
243 .map(|(index, choice)| {
244 let message = choice["message"].as_object().ok_or_else(|| {
245 AiLibError::ProviderError("Invalid choice format".to_string())
246 })?;
247
248 let role = match message["role"].as_str().unwrap_or("user") {
249 "system" => Role::System,
250 "assistant" => Role::Assistant,
251 _ => Role::User,
252 };
253
254 let content = message["content"].as_str().unwrap_or("").to_string();
255
256 let mut msg_obj = Message {
258 role,
259 content: crate::types::common::Content::Text(content.clone()),
260 function_call: None,
261 };
262
263 if let Some(fc_val) = message.get("function_call").cloned() {
264 match serde_json::from_value::<crate::types::function_call::FunctionCall>(
266 fc_val.clone(),
267 ) {
268 Ok(fc) => {
269 msg_obj.function_call = Some(fc);
270 }
271 Err(_) => {
272 let name = fc_val
274 .get("name")
275 .and_then(|v| v.as_str())
276 .unwrap_or_default()
277 .to_string();
278 let args_val = match fc_val.get("arguments") {
279 Some(a) if a.is_string() => {
280 a.as_str()
282 .and_then(|s| {
283 serde_json::from_str::<serde_json::Value>(s).ok()
284 })
285 .unwrap_or(serde_json::Value::Null)
286 }
287 Some(a) => a.clone(),
288 None => serde_json::Value::Null,
289 };
290 msg_obj.function_call =
291 Some(crate::types::function_call::FunctionCall {
292 name,
293 arguments: Some(args_val),
294 });
295 }
296 }
297 }
298 Ok(Choice {
299 index: index as u32,
300 message: msg_obj,
301 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
302 })
303 })
304 .collect::<Result<Vec<_>, AiLibError>>()?;
305
306 let usage = response["usage"].as_object().ok_or_else(|| {
307 AiLibError::ProviderError("Invalid response format: usage not found".to_string())
308 })?;
309
310 let usage = Usage {
311 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
312 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
313 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
314 };
315
316 Ok(ChatCompletionResponse {
317 id: response["id"].as_str().unwrap_or("").to_string(),
318 object: response["object"].as_str().unwrap_or("").to_string(),
319 created: response["created"].as_u64().unwrap_or(0),
320 model: response["model"].as_str().unwrap_or("").to_string(),
321 choices,
322 usage,
323 })
324 }
325}
326
327#[async_trait::async_trait]
328impl ChatApi for OpenAiAdapter {
329 async fn chat_completion(
330 &self,
331 request: ChatCompletionRequest,
332 ) -> Result<ChatCompletionResponse, AiLibError> {
333 let openai_request = self
335 .convert_request_async(&request)
336 .await
337 .unwrap_or(serde_json::json!({}));
338 let url = format!("{}/chat/completions", self.base_url);
339
340 self.metrics.incr_counter("openai.requests", 1).await;
342 let timer = self.metrics.start_timer("openai.request_duration_ms").await;
343
344 let mut headers = HashMap::new();
345 headers.insert(
346 "Authorization".to_string(),
347 format!("Bearer {}", self.api_key),
348 );
349 headers.insert("Content-Type".to_string(), "application/json".to_string());
350
351 let response = match self
352 .transport
353 .post_json(&url, Some(headers), openai_request)
354 .await
355 {
356 Ok(v) => {
357 if let Some(t) = timer {
358 t.stop();
359 }
360 v
361 }
362 Err(e) => {
363 if let Some(t) = timer {
364 t.stop();
365 }
366 return Err(e);
367 }
368 };
369
370 self.parse_response(response)
371 }
372
373 async fn chat_completion_stream(
374 &self,
375 _request: ChatCompletionRequest,
376 ) -> Result<
377 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
378 AiLibError,
379 > {
380 let stream = stream::empty();
381 Ok(Box::new(Box::pin(stream)))
382 }
383
384 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
385 let url = format!("{}/models", self.base_url);
386 let mut headers = HashMap::new();
387 headers.insert(
388 "Authorization".to_string(),
389 format!("Bearer {}", self.api_key),
390 );
391
392 let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
393
394 Ok(response["data"]
395 .as_array()
396 .unwrap_or(&vec![])
397 .iter()
398 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
399 .collect())
400 }
401
402 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
403 Ok(ModelInfo {
404 id: model_id.to_string(),
405 object: "model".to_string(),
406 created: 0,
407 owned_by: "openai".to_string(),
408 permission: vec![ModelPermission {
409 id: "default".to_string(),
410 object: "model_permission".to_string(),
411 created: 0,
412 allow_create_engine: false,
413 allow_sampling: true,
414 allow_logprobs: false,
415 allow_search_indices: false,
416 allow_view: true,
417 allow_fine_tuning: false,
418 organization: "*".to_string(),
419 group: None,
420 is_blocking: false,
421 }],
422 })
423 }
424}