1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13 VerifyClient, VerifyError,
14};
15
16#[cfg(feature = "audio")]
17use crate::client::AudioGenerationClient;
18#[cfg(feature = "image")]
19use crate::client::ImageGenerationClient;
20
21use serde::Deserialize;
22
/// Base URL that [`ClientBuilder::new`] installs when the caller does not
/// supply one through [`ClientBuilder::base_url`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
27
28pub struct ClientBuilder<'a> {
29 api_key: &'a str,
30 base_url: &'a str,
31 http_client: Option<reqwest::Client>,
32}
33
34impl<'a> ClientBuilder<'a> {
35 pub fn new(api_key: &'a str) -> Self {
36 Self {
37 api_key,
38 base_url: OPENAI_API_BASE_URL,
39 http_client: None,
40 }
41 }
42
43 pub fn base_url(mut self, base_url: &'a str) -> Self {
44 self.base_url = base_url;
45 self
46 }
47
48 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
49 self.http_client = Some(client);
50 self
51 }
52
53 pub fn build(self) -> Result<Client, ClientBuilderError> {
54 let http_client = if let Some(http_client) = self.http_client {
55 http_client
56 } else {
57 reqwest::Client::builder().build()?
58 };
59
60 Ok(Client {
61 base_url: self.base_url.to_string(),
62 api_key: self.api_key.to_string(),
63 http_client,
64 })
65 }
66}
67
/// Handle used to issue authenticated HTTP requests against the configured
/// base URL. Construct it with [`Client::new`] or [`Client::builder`].
#[derive(Clone)]
pub struct Client {
    // Root URL that request paths are appended to.
    base_url: String,
    // Sent as the bearer token on each request; redacted in the `Debug` output.
    api_key: String,
    // Underlying HTTP client used to create request builders.
    http_client: reqwest::Client,
}
74
75impl std::fmt::Debug for Client {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("Client")
78 .field("base_url", &self.base_url)
79 .field("http_client", &self.http_client)
80 .field("api_key", &"<REDACTED>")
81 .finish()
82 }
83}
84
85impl Client {
86 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
97 ClientBuilder::new(api_key)
98 }
99
100 pub fn new(api_key: &str) -> Self {
105 Self::builder(api_key)
106 .build()
107 .expect("OpenAI client should build")
108 }
109
110 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
111 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
112 self.http_client.post(url).bearer_auth(&self.api_key)
113 }
114
115 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
116 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
117 self.http_client.get(url).bearer_auth(&self.api_key)
118 }
119}
120
121impl ProviderClient for Client {
122 fn from_env() -> Self {
125 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
126 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
127
128 match base_url {
129 Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
130 None => Self::new(&api_key),
131 }
132 }
133
134 fn from_val(input: crate::client::ProviderValue) -> Self {
135 let crate::client::ProviderValue::Simple(api_key) = input else {
136 panic!("Incorrect provider value type")
137 };
138 Self::new(&api_key)
139 }
140}
141
142impl CompletionClient for Client {
143 type CompletionModel = super::responses_api::ResponsesCompletionModel;
144 fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
156 super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
157 }
158}
159
160impl EmbeddingsClient for Client {
161 type EmbeddingModel = EmbeddingModel;
162 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
163 let ndims = match model {
164 TEXT_EMBEDDING_3_LARGE => 3072,
165 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
166 _ => 0,
167 };
168 EmbeddingModel::new(self.clone(), model, ndims)
169 }
170
171 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
172 EmbeddingModel::new(self.clone(), model, ndims)
173 }
174}
175
176impl TranscriptionClient for Client {
177 type TranscriptionModel = TranscriptionModel;
178 fn transcription_model(&self, model: &str) -> TranscriptionModel {
190 TranscriptionModel::new(self.clone(), model)
191 }
192}
193
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image generation model named `model`, backed by a clone of this client.
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
212
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Creates an audio generation model named `model`, backed by a clone of this client.
    fn audio_generation_model(&self, model: &str) -> AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
231
232impl VerifyClient for Client {
233 #[cfg_attr(feature = "worker", worker::send)]
234 async fn verify(&self) -> Result<(), VerifyError> {
235 let response = self.get("/models").send().await?;
236 match response.status() {
237 reqwest::StatusCode::OK => Ok(()),
238 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
239 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
240 Err(VerifyError::ProviderError(response.text().await?))
241 }
242 _ => {
243 response.error_for_status()?;
244 Ok(())
245 }
246 }
247 }
248}
249
/// Error payload deserialized from a response body; it is the `Err` arm of
/// [`ApiResponse`].
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Value of the payload's `message` field.
    pub(crate) message: String,
}
254
/// Response body that is either a successful payload `T` or an
/// [`ApiErrorResponse`].
///
/// The enum is `untagged`, so there is no discriminator field: serde tries the
/// variants in declaration order (`Ok` first, then `Err`). Reordering the
/// variants would change which one matches.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    // Body matched the success payload `T`.
    Ok(T),
    // Body did not match `T` but matched the error payload.
    Err(ApiErrorResponse),
}
261
262#[cfg(test)]
263mod tests {
264 use crate::message::ImageDetail;
265 use crate::providers::openai::{
266 AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
267 };
268 use crate::{OneOrMany, message};
269 use serde_path_to_error::deserialize;
270
271 #[test]
272 fn test_deserialize_message() {
273 let assistant_message_json = r#"
274 {
275 "role": "assistant",
276 "content": "\n\nHello there, how may I assist you today?"
277 }
278 "#;
279
280 let assistant_message_json2 = r#"
281 {
282 "role": "assistant",
283 "content": [
284 {
285 "type": "text",
286 "text": "\n\nHello there, how may I assist you today?"
287 }
288 ],
289 "tool_calls": null
290 }
291 "#;
292
293 let assistant_message_json3 = r#"
294 {
295 "role": "assistant",
296 "tool_calls": [
297 {
298 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
299 "type": "function",
300 "function": {
301 "name": "subtract",
302 "arguments": "{\"x\": 2, \"y\": 5}"
303 }
304 }
305 ],
306 "content": null,
307 "refusal": null
308 }
309 "#;
310
311 let user_message_json = r#"
312 {
313 "role": "user",
314 "content": [
315 {
316 "type": "text",
317 "text": "What's in this image?"
318 },
319 {
320 "type": "image_url",
321 "image_url": {
322 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
323 }
324 },
325 {
326 "type": "audio",
327 "input_audio": {
328 "data": "...",
329 "format": "mp3"
330 }
331 }
332 ]
333 }
334 "#;
335
336 let assistant_message: Message = {
337 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
338 deserialize(jd).unwrap_or_else(|err| {
339 panic!(
340 "Deserialization error at {} ({}:{}): {}",
341 err.path(),
342 err.inner().line(),
343 err.inner().column(),
344 err
345 );
346 })
347 };
348
349 let assistant_message2: Message = {
350 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
351 deserialize(jd).unwrap_or_else(|err| {
352 panic!(
353 "Deserialization error at {} ({}:{}): {}",
354 err.path(),
355 err.inner().line(),
356 err.inner().column(),
357 err
358 );
359 })
360 };
361
362 let assistant_message3: Message = {
363 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
364 &mut serde_json::Deserializer::from_str(assistant_message_json3);
365 deserialize(jd).unwrap_or_else(|err| {
366 panic!(
367 "Deserialization error at {} ({}:{}): {}",
368 err.path(),
369 err.inner().line(),
370 err.inner().column(),
371 err
372 );
373 })
374 };
375
376 let user_message: Message = {
377 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
378 deserialize(jd).unwrap_or_else(|err| {
379 panic!(
380 "Deserialization error at {} ({}:{}): {}",
381 err.path(),
382 err.inner().line(),
383 err.inner().column(),
384 err
385 );
386 })
387 };
388
389 match assistant_message {
390 Message::Assistant { content, .. } => {
391 assert_eq!(
392 content[0],
393 AssistantContent::Text {
394 text: "\n\nHello there, how may I assist you today?".to_string()
395 }
396 );
397 }
398 _ => panic!("Expected assistant message"),
399 }
400
401 match assistant_message2 {
402 Message::Assistant {
403 content,
404 tool_calls,
405 ..
406 } => {
407 assert_eq!(
408 content[0],
409 AssistantContent::Text {
410 text: "\n\nHello there, how may I assist you today?".to_string()
411 }
412 );
413
414 assert_eq!(tool_calls, vec![]);
415 }
416 _ => panic!("Expected assistant message"),
417 }
418
419 match assistant_message3 {
420 Message::Assistant {
421 content,
422 tool_calls,
423 refusal,
424 ..
425 } => {
426 assert!(content.is_empty());
427 assert!(refusal.is_none());
428 assert_eq!(
429 tool_calls[0],
430 ToolCall {
431 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
432 r#type: ToolType::Function,
433 function: Function {
434 name: "subtract".to_string(),
435 arguments: serde_json::json!({"x": 2, "y": 5}),
436 },
437 }
438 );
439 }
440 _ => panic!("Expected assistant message"),
441 }
442
443 match user_message {
444 Message::User { content, .. } => {
445 let (first, second) = {
446 let mut iter = content.into_iter();
447 (iter.next().unwrap(), iter.next().unwrap())
448 };
449 assert_eq!(
450 first,
451 UserContent::Text {
452 text: "What's in this image?".to_string()
453 }
454 );
455 assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
456 }
457 _ => panic!("Expected user message"),
458 }
459 }
460
461 #[test]
462 fn test_message_to_message_conversion() {
463 let user_message = message::Message::User {
464 content: OneOrMany::one(message::UserContent::text("Hello")),
465 };
466
467 let assistant_message = message::Message::Assistant {
468 id: None,
469 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
470 };
471
472 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
473 let converted_assistant_message: Vec<Message> =
474 assistant_message.clone().try_into().unwrap();
475
476 match converted_user_message[0].clone() {
477 Message::User { content, .. } => {
478 assert_eq!(
479 content.first(),
480 UserContent::Text {
481 text: "Hello".to_string()
482 }
483 );
484 }
485 _ => panic!("Expected user message"),
486 }
487
488 match converted_assistant_message[0].clone() {
489 Message::Assistant { content, .. } => {
490 assert_eq!(
491 content[0].clone(),
492 AssistantContent::Text {
493 text: "Hi there!".to_string()
494 }
495 );
496 }
497 _ => panic!("Expected assistant message"),
498 }
499
500 let original_user_message: message::Message =
501 converted_user_message[0].clone().try_into().unwrap();
502 let original_assistant_message: message::Message =
503 converted_assistant_message[0].clone().try_into().unwrap();
504
505 assert_eq!(original_user_message, user_message);
506 assert_eq!(original_assistant_message, assistant_message);
507 }
508
509 #[test]
510 fn test_message_from_message_conversion() {
511 let user_message = Message::User {
512 content: OneOrMany::one(UserContent::Text {
513 text: "Hello".to_string(),
514 }),
515 name: None,
516 };
517
518 let assistant_message = Message::Assistant {
519 content: vec![AssistantContent::Text {
520 text: "Hi there!".to_string(),
521 }],
522 refusal: None,
523 audio: None,
524 name: None,
525 tool_calls: vec![],
526 };
527
528 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
529 let converted_assistant_message: message::Message =
530 assistant_message.clone().try_into().unwrap();
531
532 match converted_user_message.clone() {
533 message::Message::User { content } => {
534 assert_eq!(content.first(), message::UserContent::text("Hello"));
535 }
536 _ => panic!("Expected user message"),
537 }
538
539 match converted_assistant_message.clone() {
540 message::Message::Assistant { content, .. } => {
541 assert_eq!(
542 content.first(),
543 message::AssistantContent::text("Hi there!")
544 );
545 }
546 _ => panic!("Expected assistant message"),
547 }
548
549 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
550 let original_assistant_message: Vec<Message> =
551 converted_assistant_message.try_into().unwrap();
552
553 assert_eq!(original_user_message[0], user_message);
554 assert_eq!(original_assistant_message[0], assistant_message);
555 }
556}