1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13};
14
15#[cfg(feature = "audio")]
16use crate::client::AudioGenerationClient;
17#[cfg(feature = "image")]
18use crate::client::ImageGenerationClient;
19
20use serde::Deserialize;
21
/// Default endpoint for the official OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
/// Builder for [`Client`]; borrows the API key and base URL until `build`.
pub struct ClientBuilder<'a> {
    // Secret sent as a bearer token on every request.
    api_key: &'a str,
    // API root; defaults to `OPENAI_API_BASE_URL`, overridable for proxies
    // and OpenAI-compatible servers.
    base_url: &'a str,
    // Optional caller-supplied HTTP client; a default one is built if absent.
    http_client: Option<reqwest::Client>,
}
32
33impl<'a> ClientBuilder<'a> {
34 pub fn new(api_key: &'a str) -> Self {
35 Self {
36 api_key,
37 base_url: OPENAI_API_BASE_URL,
38 http_client: None,
39 }
40 }
41
42 pub fn base_url(mut self, base_url: &'a str) -> Self {
43 self.base_url = base_url;
44 self
45 }
46
47 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
48 self.http_client = Some(client);
49 self
50 }
51
52 pub fn build(self) -> Result<Client, ClientBuilderError> {
53 let http_client = if let Some(http_client) = self.http_client {
54 http_client
55 } else {
56 reqwest::Client::builder().build()?
57 };
58
59 Ok(Client {
60 base_url: self.base_url.to_string(),
61 api_key: self.api_key.to_string(),
62 http_client,
63 })
64 }
65}
66
/// OpenAI API client: endpoint, credentials, and a shared HTTP client.
///
/// Cheap to clone: `reqwest::Client` is internally reference-counted, so
/// model handles clone this freely.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    api_key: String,
    http_client: reqwest::Client,
}
73
// Manual Debug impl so the API key is never leaked into logs; every other
// field is printed as-is.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            // Redacted: the secret must not appear in Debug output.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
83
84impl Client {
85 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96 ClientBuilder::new(api_key)
97 }
98
99 pub fn new(api_key: &str) -> Self {
104 Self::builder(api_key)
105 .build()
106 .expect("OpenAI client should build")
107 }
108
109 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
110 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
111 self.http_client.post(url).bearer_auth(&self.api_key)
112 }
113}
114
115impl ProviderClient for Client {
116 fn from_env() -> Self {
119 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
120 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
121
122 match base_url {
123 Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
124 None => Self::new(&api_key),
125 }
126 }
127
128 fn from_val(input: crate::client::ProviderValue) -> Self {
129 let crate::client::ProviderValue::Simple(api_key) = input else {
130 panic!("Incorrect provider value type")
131 };
132 Self::new(&api_key)
133 }
134}
135
136impl CompletionClient for Client {
137 type CompletionModel = super::responses_api::ResponsesCompletionModel;
138 fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
150 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
151 }
152}
153
154impl EmbeddingsClient for Client {
155 type EmbeddingModel = EmbeddingModel;
156 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
157 let ndims = match model {
158 TEXT_EMBEDDING_3_LARGE => 3072,
159 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
160 _ => 0,
161 };
162 EmbeddingModel::new(self.clone(), model, ndims)
163 }
164
165 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
166 EmbeddingModel::new(self.clone(), model, ndims)
167 }
168}
169
170impl TranscriptionClient for Client {
171 type TranscriptionModel = TranscriptionModel;
172 fn transcription_model(&self, model: &str) -> TranscriptionModel {
184 TranscriptionModel::new(self.clone(), model)
185 }
186}
187
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image generation model handle sharing this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        Self::ImageGenerationModel::new(self.clone(), model)
    }
}
206
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Creates an audio generation model handle sharing this client.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        Self::AudioGenerationModel::new(self.clone(), model)
    }
}
225
/// Error payload returned by the OpenAI API on a failed request.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Human-readable description supplied by the API.
    pub(crate) message: String,
}
230
/// Either a successful payload or an API error body.
///
/// `untagged`: serde tries `Ok(T)` first and falls back to the error shape,
/// since the API does not label which variant it returned.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
237
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    // Verifies the provider `Message` type deserializes every content shape
    // the OpenAI chat API emits: plain string content, content-part arrays,
    // null content alongside tool calls, and multimodal user content.
    // `serde_path_to_error` is used so failures report the exact JSON path.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with `content` as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with `content` as an array of typed parts and an
        // explicitly-null `tool_calls` field.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only tool calls; content/refusal null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image_url, and input_audio parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        // Bare-string content must land as a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // `"tool_calls": null` must deserialize to an empty Vec, not fail.
        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        // Tool-call arguments arrive as a JSON-encoded string and must be
        // parsed into a structured `serde_json::Value`.
        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Multimodal user content keeps part order; image detail defaults.
        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    // Round-trips internal `message::Message` -> provider `Message` -> back,
    // asserting the conversion is lossless in both directions.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // One internal message may fan out to several provider messages.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    // Same round-trip as above but starting from the provider `Message`
    // side: provider -> internal -> provider.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original provider messages.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}