1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
12
13#[cfg(feature = "audio")]
14use crate::client::AudioGenerationClient;
15#[cfg(feature = "image")]
16use crate::client::ImageGenerationClient;
17
18use serde::Deserialize;
19
/// Default base URL for the OpenAI REST API; used by [`Client::new`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Client for the OpenAI API.
///
/// Stores the base URL that request paths are appended to, the API key sent as a
/// bearer token, and the underlying `reqwest` HTTP client. `Clone` is derived because
/// every model constructor below hands a clone of the client to the model it builds.
#[derive(Clone)]
pub struct Client {
    // Base URL that request paths are joined onto (no trailing path component required).
    base_url: String,
    // Secret API key; deliberately redacted in the `Debug` impl.
    api_key: String,
    // Underlying HTTP client used to issue requests.
    http_client: reqwest::Client,
}
31
32impl std::fmt::Debug for Client {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("Client")
35 .field("base_url", &self.base_url)
36 .field("http_client", &self.http_client)
37 .field("api_key", &"<REDACTED>")
38 .finish()
39 }
40}
41
42impl Client {
43 pub fn new(api_key: &str) -> Self {
45 Self::from_url(api_key, OPENAI_API_BASE_URL)
46 }
47
48 pub fn from_url(api_key: &str, base_url: &str) -> Self {
50 Self {
51 base_url: base_url.to_string(),
52 api_key: api_key.to_string(),
53 http_client: reqwest::Client::builder()
54 .build()
55 .expect("OpenAI reqwest client should build"),
56 }
57 }
58
59 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
62 self.http_client = client;
63
64 self
65 }
66
67 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url).bearer_auth(&self.api_key)
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self {
77 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
78 Self::new(&api_key)
79 }
80}
81
82impl CompletionClient for Client {
83 type CompletionModel = super::responses_api::ResponsesCompletionModel;
84 fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
96 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
97 }
98}
99
100impl EmbeddingsClient for Client {
101 type EmbeddingModel = EmbeddingModel;
102 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
103 let ndims = match model {
104 TEXT_EMBEDDING_3_LARGE => 3072,
105 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
106 _ => 0,
107 };
108 EmbeddingModel::new(self.clone(), model, ndims)
109 }
110
111 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
112 EmbeddingModel::new(self.clone(), model, ndims)
113 }
114}
115
116impl TranscriptionClient for Client {
117 type TranscriptionModel = TranscriptionModel;
118 fn transcription_model(&self, model: &str) -> TranscriptionModel {
130 TranscriptionModel::new(self.clone(), model)
131 }
132}
133
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Returns an image-generation model named `model`, backed by a clone of this
    /// client. Only compiled with the `image` feature.
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = Client::clone(self);
        ImageGenerationModel::new(client, model)
    }
}
152
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Returns an audio-generation model named `model`, backed by a clone of this
    /// client. Only compiled with the `audio` feature.
    fn audio_generation_model(&self, model: &str) -> AudioGenerationModel {
        let client = Client::clone(self);
        AudioGenerationModel::new(client, model)
    }
}
171
/// Error body deserialized by the `Err` arm of [`ApiResponse`].
///
/// NOTE(review): only a top-level `message` field is read. OpenAI error bodies are
/// commonly nested under an `error` object; confirm that callers unwrap that layer
/// before deserializing into this type.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Error message text taken from the response body.
    pub(crate) message: String,
}
176
/// A response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. It attempts
/// `Ok(T)` first and falls back to `Err` only when the body does not deserialize as
/// `T`, so the variant order is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
183
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into a provider [`Message`].
    ///
    /// On failure, panics with the JSON path and line/column of the first error so the
    /// test output points straight at the offending field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default(),
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}