1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13};
14
15#[cfg(feature = "audio")]
16use crate::client::AudioGenerationClient;
17#[cfg(feature = "image")]
18use crate::client::ImageGenerationClient;
19
20use serde::Deserialize;
21
/// Default base URL for the OpenAI REST API.
///
/// [`ClientBuilder::new`] starts from this value; override it with
/// [`ClientBuilder::base_url`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
27pub struct ClientBuilder<'a> {
28 api_key: &'a str,
29 base_url: &'a str,
30 http_client: Option<reqwest::Client>,
31}
32
33impl<'a> ClientBuilder<'a> {
34 pub fn new(api_key: &'a str) -> Self {
35 Self {
36 api_key,
37 base_url: OPENAI_API_BASE_URL,
38 http_client: None,
39 }
40 }
41
42 pub fn base_url(mut self, base_url: &'a str) -> Self {
43 self.base_url = base_url;
44 self
45 }
46
47 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
48 self.http_client = Some(client);
49 self
50 }
51
52 pub fn build(self) -> Result<Client, ClientBuilderError> {
53 let http_client = if let Some(http_client) = self.http_client {
54 http_client
55 } else {
56 reqwest::Client::builder().build()?
57 };
58
59 Ok(Client {
60 base_url: self.base_url.to_string(),
61 api_key: self.api_key.to_string(),
62 http_client,
63 })
64 }
65}
66
67#[derive(Clone)]
68pub struct Client {
69 base_url: String,
70 api_key: String,
71 http_client: reqwest::Client,
72}
73
74impl std::fmt::Debug for Client {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("Client")
77 .field("base_url", &self.base_url)
78 .field("http_client", &self.http_client)
79 .field("api_key", &"<REDACTED>")
80 .finish()
81 }
82}
83
84impl Client {
85 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96 ClientBuilder::new(api_key)
97 }
98
99 pub fn new(api_key: &str) -> Self {
104 Self::builder(api_key)
105 .build()
106 .expect("OpenAI client should build")
107 }
108
109 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
110 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
111 self.http_client.post(url).bearer_auth(&self.api_key)
112 }
113}
114
115impl ProviderClient for Client {
116 fn from_env() -> Self {
119 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
120 Self::new(&api_key)
121 }
122
123 fn from_val(input: crate::client::ProviderValue) -> Self {
124 let crate::client::ProviderValue::Simple(api_key) = input else {
125 panic!("Incorrect provider value type")
126 };
127 Self::new(&api_key)
128 }
129}
130
131impl CompletionClient for Client {
132 type CompletionModel = super::responses_api::ResponsesCompletionModel;
133 fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
145 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
146 }
147}
148
149impl EmbeddingsClient for Client {
150 type EmbeddingModel = EmbeddingModel;
151 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
152 let ndims = match model {
153 TEXT_EMBEDDING_3_LARGE => 3072,
154 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
155 _ => 0,
156 };
157 EmbeddingModel::new(self.clone(), model, ndims)
158 }
159
160 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
161 EmbeddingModel::new(self.clone(), model, ndims)
162 }
163}
164
165impl TranscriptionClient for Client {
166 type TranscriptionModel = TranscriptionModel;
167 fn transcription_model(&self, model: &str) -> TranscriptionModel {
179 TranscriptionModel::new(self.clone(), model)
180 }
181}
182
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image-generation model named `model` that owns a clone of this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
201
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Creates an audio-generation model named `model` that owns a clone of this client.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
220
221#[derive(Debug, Deserialize)]
222pub struct ApiErrorResponse {
223 pub(crate) message: String,
224}
225
/// Envelope for an API response body: either the expected payload `T` or an error.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration order. A
/// body is therefore only treated as `Err` if it fails to deserialize as `T`, so `Ok`
/// must stay first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// A body that deserialized as the expected payload.
    Ok(T),
    /// A body that did not match `T` but did match [`ApiErrorResponse`].
    Err(ApiErrorResponse),
}
232
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes an OpenAI-format [`Message`] from `json`.
    ///
    /// On failure it panics with the JSON path and line/column of the error, so a broken
    /// fixture is easy to locate.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                // NOTE(review): the fixture has three parts, but only the text and image
                // parts are asserted; the `audio` part is parsed and not checked.
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}