1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6};
7use futures::stream::{self, Stream};
8use std::collections::HashMap;
9use std::env;
10use std::sync::Arc;
11
12pub struct OpenAiAdapter {
16 transport: DynHttpTransportRef,
17 api_key: String,
18 base_url: String,
19 metrics: Arc<dyn Metrics>,
20}
21
22impl OpenAiAdapter {
23 pub fn new() -> Result<Self, AiLibError> {
24 let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
25 AiLibError::AuthenticationError(
26 "OPENAI_API_KEY environment variable not set".to_string(),
27 )
28 })?;
29
30 Ok(Self {
31 transport: HttpTransport::new().boxed(),
32 api_key,
33 base_url: "https://api.openai.com/v1".to_string(),
34 metrics: Arc::new(NoopMetrics::new()),
35 })
36 }
37
38 pub fn with_transport_ref(
40 transport: DynHttpTransportRef,
41 api_key: String,
42 base_url: String,
43 ) -> Result<Self, AiLibError> {
44 Ok(Self {
45 transport,
46 api_key,
47 base_url,
48 metrics: Arc::new(NoopMetrics::new()),
49 })
50 }
51
52 pub fn with_transport_ref_and_metrics(
53 transport: DynHttpTransportRef,
54 api_key: String,
55 base_url: String,
56 metrics: Arc<dyn Metrics>,
57 ) -> Result<Self, AiLibError> {
58 Ok(Self {
59 transport,
60 api_key,
61 base_url,
62 metrics,
63 })
64 }
65
66 pub fn with_metrics(
67 api_key: String,
68 base_url: String,
69 metrics: Arc<dyn Metrics>,
70 ) -> Result<Self, AiLibError> {
71 Ok(Self {
72 transport: HttpTransport::new().boxed(),
73 api_key,
74 base_url,
75 metrics,
76 })
77 }
78
79 #[allow(dead_code)]
80 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
81 let mut openai_request = serde_json::json!({
83 "model": request.model,
84 "messages": serde_json::Value::Array(vec![])
85 });
86
87 let mut msgs: Vec<serde_json::Value> = Vec::new();
88 for msg in request.messages.iter() {
89 let role = match msg.role {
90 Role::System => "system",
91 Role::User => "user",
92 Role::Assistant => "assistant",
93 };
94 let content_val = crate::provider::utils::content_to_provider_value(&msg.content);
95 msgs.push(serde_json::json!({"role": role, "content": content_val}));
96 }
97 openai_request["messages"] = serde_json::Value::Array(msgs);
98 openai_request
99 }
100
101 async fn convert_request_async(
103 &self,
104 request: &ChatCompletionRequest,
105 ) -> Result<serde_json::Value, AiLibError> {
106 let mut openai_request = serde_json::json!({
110 "model": request.model,
111 "messages": serde_json::Value::Array(vec![])
112 });
113
114 let mut msgs: Vec<serde_json::Value> = Vec::new();
115 for msg in request.messages.iter() {
116 let role = match msg.role {
117 Role::System => "system",
118 Role::User => "user",
119 Role::Assistant => "assistant",
120 };
121
122 let content_val = match &msg.content {
124 crate::types::common::Content::Image { url, mime: _, name } => {
125 if url.is_some() {
126 crate::provider::utils::content_to_provider_value(&msg.content)
127 } else if let Some(n) = name {
128 let upload_url = format!("{}/files", self.base_url.trim_end_matches('/'));
130 match crate::provider::utils::upload_file_with_transport(
131 Some(self.transport.clone()),
132 &upload_url,
133 n,
134 "file",
135 )
136 .await
137 {
138 Ok(remote) => {
139 if remote.starts_with("http://")
141 || remote.starts_with("https://")
142 || remote.starts_with("data:")
143 {
144 serde_json::json!({"image": {"url": remote}})
145 } else {
146 serde_json::json!({"image": {"file_id": remote}})
148 }
149 }
150 Err(_) => {
151 crate::provider::utils::content_to_provider_value(&msg.content)
152 }
153 }
154 } else {
155 crate::provider::utils::content_to_provider_value(&msg.content)
156 }
157 }
158 _ => crate::provider::utils::content_to_provider_value(&msg.content),
159 };
160 msgs.push(serde_json::json!({"role": role, "content": content_val}));
161 }
162
163 openai_request["messages"] = serde_json::Value::Array(msgs);
164
165 if let Some(temp) = request.temperature {
167 openai_request["temperature"] =
168 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
169 }
170 if let Some(max_tokens) = request.max_tokens {
171 openai_request["max_tokens"] =
172 serde_json::Value::Number(serde_json::Number::from(max_tokens));
173 }
174 if let Some(top_p) = request.top_p {
175 openai_request["top_p"] =
176 serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
177 }
178 if let Some(freq_penalty) = request.frequency_penalty {
179 openai_request["frequency_penalty"] = serde_json::Value::Number(
180 serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
181 );
182 }
183 if let Some(presence_penalty) = request.presence_penalty {
184 openai_request["presence_penalty"] = serde_json::Value::Number(
185 serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
186 );
187 }
188
189 if let Some(functions) = &request.functions {
191 openai_request["functions"] =
192 serde_json::to_value(functions).unwrap_or(serde_json::Value::Null);
193 }
194
195 if let Some(policy) = &request.function_call {
197 match policy {
198 crate::types::function_call::FunctionCallPolicy::None => {
199 openai_request["function_call"] = serde_json::Value::String("none".to_string());
200 }
201 crate::types::function_call::FunctionCallPolicy::Auto(name) => {
202 if name == "auto" {
203 openai_request["function_call"] =
204 serde_json::Value::String("auto".to_string());
205 } else {
206 openai_request["function_call"] = serde_json::Value::String(name.clone());
207 }
208 }
209 }
210 }
211
212 Ok(openai_request)
213 }
214
215 fn parse_response(
220 &self,
221 response: serde_json::Value,
222 ) -> Result<ChatCompletionResponse, AiLibError> {
223 let choices = response["choices"]
224 .as_array()
225 .ok_or_else(|| {
226 AiLibError::ProviderError("Invalid response format: choices not found".to_string())
227 })?
228 .iter()
229 .enumerate()
230 .map(|(index, choice)| {
231 let message = choice["message"].as_object().ok_or_else(|| {
232 AiLibError::ProviderError("Invalid choice format".to_string())
233 })?;
234
235 let role = match message["role"].as_str().unwrap_or("user") {
236 "system" => Role::System,
237 "assistant" => Role::Assistant,
238 _ => Role::User,
239 };
240
241 let content = message["content"].as_str().unwrap_or("").to_string();
242
243 let mut msg_obj = Message {
245 role,
246 content: crate::types::common::Content::Text(content.clone()),
247 function_call: None,
248 };
249
250 if let Some(fc_val) = message.get("function_call").cloned() {
251 match serde_json::from_value::<crate::types::function_call::FunctionCall>(
253 fc_val.clone(),
254 ) {
255 Ok(fc) => {
256 msg_obj.function_call = Some(fc);
257 }
258 Err(_) => {
259 let name = fc_val
261 .get("name")
262 .and_then(|v| v.as_str())
263 .unwrap_or_default()
264 .to_string();
265 let args_val = match fc_val.get("arguments") {
266 Some(a) if a.is_string() => {
267 a.as_str()
269 .and_then(|s| {
270 serde_json::from_str::<serde_json::Value>(s).ok()
271 })
272 .unwrap_or(serde_json::Value::Null)
273 }
274 Some(a) => a.clone(),
275 None => serde_json::Value::Null,
276 };
277 msg_obj.function_call =
278 Some(crate::types::function_call::FunctionCall {
279 name,
280 arguments: Some(args_val),
281 });
282 }
283 }
284 }
285 Ok(Choice {
286 index: index as u32,
287 message: msg_obj,
288 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
289 })
290 })
291 .collect::<Result<Vec<_>, AiLibError>>()?;
292
293 let usage = response["usage"].as_object().ok_or_else(|| {
294 AiLibError::ProviderError("Invalid response format: usage not found".to_string())
295 })?;
296
297 let usage = Usage {
298 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
299 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
300 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
301 };
302
303 Ok(ChatCompletionResponse {
304 id: response["id"].as_str().unwrap_or("").to_string(),
305 object: response["object"].as_str().unwrap_or("").to_string(),
306 created: response["created"].as_u64().unwrap_or(0),
307 model: response["model"].as_str().unwrap_or("").to_string(),
308 choices,
309 usage,
310 })
311 }
312}
313
314#[async_trait::async_trait]
315impl ChatApi for OpenAiAdapter {
316 async fn chat_completion(
317 &self,
318 request: ChatCompletionRequest,
319 ) -> Result<ChatCompletionResponse, AiLibError> {
320 let openai_request = self
322 .convert_request_async(&request)
323 .await
324 .unwrap_or(serde_json::json!({}));
325 let url = format!("{}/chat/completions", self.base_url);
326
327 self.metrics.incr_counter("openai.requests", 1).await;
329 let timer = self.metrics.start_timer("openai.request_duration_ms").await;
330
331 let mut headers = HashMap::new();
332 headers.insert(
333 "Authorization".to_string(),
334 format!("Bearer {}", self.api_key),
335 );
336 headers.insert("Content-Type".to_string(), "application/json".to_string());
337
338 let response = match self
339 .transport
340 .post_json(&url, Some(headers), openai_request)
341 .await
342 {
343 Ok(v) => {
344 if let Some(t) = timer {
345 t.stop();
346 }
347 v
348 }
349 Err(e) => {
350 if let Some(t) = timer {
351 t.stop();
352 }
353 return Err(e);
354 }
355 };
356
357 self.parse_response(response)
358 }
359
360 async fn chat_completion_stream(
361 &self,
362 _request: ChatCompletionRequest,
363 ) -> Result<
364 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
365 AiLibError,
366 > {
367 let stream = stream::empty();
368 Ok(Box::new(Box::pin(stream)))
369 }
370
371 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
372 let url = format!("{}/models", self.base_url);
373 let mut headers = HashMap::new();
374 headers.insert(
375 "Authorization".to_string(),
376 format!("Bearer {}", self.api_key),
377 );
378
379 let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
380
381 Ok(response["data"]
382 .as_array()
383 .unwrap_or(&vec![])
384 .iter()
385 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
386 .collect())
387 }
388
389 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
390 Ok(ModelInfo {
391 id: model_id.to_string(),
392 object: "model".to_string(),
393 created: 0,
394 owned_by: "openai".to_string(),
395 permission: vec![ModelPermission {
396 id: "default".to_string(),
397 object: "model_permission".to_string(),
398 created: 0,
399 allow_create_engine: false,
400 allow_sampling: true,
401 allow_logprobs: false,
402 allow_search_indices: false,
403 allow_view: true,
404 allow_fine_tuning: false,
405 organization: "*".to_string(),
406 group: None,
407 is_blocking: false,
408 }],
409 })
410 }
411}