1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
12
13#[cfg(feature = "audio")]
14use crate::client::AudioGenerationClient;
15#[cfg(feature = "image")]
16use crate::client::ImageGenerationClient;
17
18use serde::Deserialize;
19
/// Default base URL for the OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
24
#[derive(Clone)]
pub struct Client {
    /// Base URL requests are made against (defaults to [`OPENAI_API_BASE_URL`]).
    base_url: String,
    /// Secret API key sent as a bearer token; redacted in the `Debug` impl.
    api_key: String,
    /// Shared HTTP client used for all requests.
    http_client: reqwest::Client,
}
31
32impl std::fmt::Debug for Client {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("Client")
35 .field("base_url", &self.base_url)
36 .field("http_client", &self.http_client)
37 .field("api_key", &"<REDACTED>")
38 .finish()
39 }
40}
41
42impl Client {
43 pub fn new(api_key: &str) -> Self {
45 Self::from_url(api_key, OPENAI_API_BASE_URL)
46 }
47
48 pub fn from_url(api_key: &str, base_url: &str) -> Self {
50 Self {
51 base_url: base_url.to_string(),
52 api_key: api_key.to_string(),
53 http_client: reqwest::Client::builder()
54 .build()
55 .expect("OpenAI reqwest client should build"),
56 }
57 }
58
59 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
62 self.http_client = client;
63
64 self
65 }
66
67 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url).bearer_auth(&self.api_key)
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self {
77 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
78 Self::new(&api_key)
79 }
80
81 fn from_val(input: crate::client::ProviderValue) -> Self {
82 let crate::client::ProviderValue::Simple(api_key) = input else {
83 panic!("Incorrect provider value type")
84 };
85 Self::new(&api_key)
86 }
87}
88
89impl CompletionClient for Client {
90 type CompletionModel = super::responses_api::ResponsesCompletionModel;
91 fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
103 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
104 }
105}
106
107impl EmbeddingsClient for Client {
108 type EmbeddingModel = EmbeddingModel;
109 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
110 let ndims = match model {
111 TEXT_EMBEDDING_3_LARGE => 3072,
112 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
113 _ => 0,
114 };
115 EmbeddingModel::new(self.clone(), model, ndims)
116 }
117
118 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
119 EmbeddingModel::new(self.clone(), model, ndims)
120 }
121}
122
123impl TranscriptionClient for Client {
124 type TranscriptionModel = TranscriptionModel;
125 fn transcription_model(&self, model: &str) -> TranscriptionModel {
137 TranscriptionModel::new(self.clone(), model)
138 }
139}
140
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image generation model with the given model name.
    /// Only available with the `image` feature enabled.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        Self::ImageGenerationModel::new(self.clone(), model)
    }
}
159
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Creates an audio generation model with the given model name.
    /// Only available with the `audio` feature enabled.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        Self::AudioGenerationModel::new(self.clone(), model)
    }
}
178
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error description from the API.
    pub(crate) message: String,
}
183
/// Either a successful payload of type `T` or an API error body.
///
/// `#[serde(untagged)]` attempts the variants in order: a successful
/// `T` payload first, falling back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
190
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    // Verifies that provider `Message` deserialization accepts the wire
    // shapes OpenAI produces: plain-string assistant content, content
    // parts, tool calls with null content, and multi-part user content.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string content.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with content parts and an explicit null `tool_calls`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only a tool call; content and refusal are null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // `serde_path_to_error` reports the exact JSON path on failure,
        // which makes fixture regressions much easier to diagnose.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        // Plain-string content should land as a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // A JSON `null` for `tool_calls` deserializes to an empty vec.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                // The stringified `arguments` payload is parsed into JSON.
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    // Round-trips rig `message::Message` -> provider `Message` -> rig message.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // One rig message may fan out to several provider messages.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original rig messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    // Round-trips provider `Message` -> rig `message::Message` -> provider message.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original provider messages exactly.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}