1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::{
12 client::{
13 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient,
14 TranscriptionClient, VerifyClient, VerifyError,
15 },
16 extractor::ExtractorBuilder,
17 providers::openai::CompletionModel,
18};
19
20#[cfg(feature = "audio")]
21use crate::client::AudioGenerationClient;
22#[cfg(feature = "image")]
23use crate::client::ImageGenerationClient;
24
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27
/// Default base URL of the OpenAI REST API; override via [`ClientBuilder::base_url`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
32
/// Borrowing builder for [`Client`]; configure options, then call `build()`.
pub struct ClientBuilder<'a> {
    // Bearer token attached to every request.
    api_key: &'a str,
    // Defaults to `OPENAI_API_BASE_URL` unless overridden.
    base_url: &'a str,
    // Caller-supplied HTTP client; a default one is constructed when `None`.
    http_client: Option<reqwest::Client>,
}
38
39impl<'a> ClientBuilder<'a> {
40 pub fn new(api_key: &'a str) -> Self {
41 Self {
42 api_key,
43 base_url: OPENAI_API_BASE_URL,
44 http_client: None,
45 }
46 }
47
48 pub fn base_url(mut self, base_url: &'a str) -> Self {
49 self.base_url = base_url;
50 self
51 }
52
53 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
54 self.http_client = Some(client);
55 self
56 }
57
58 pub fn build(self) -> Result<Client, ClientBuilderError> {
59 let http_client = if let Some(http_client) = self.http_client {
60 http_client
61 } else {
62 reqwest::Client::builder().build()?
63 };
64
65 Ok(Client {
66 base_url: self.base_url.to_string(),
67 api_key: self.api_key.to_string(),
68 http_client,
69 })
70 }
71}
72
/// Configured OpenAI API client: resolved base URL, API key, and the shared
/// `reqwest` HTTP client. `Debug` output redacts the API key.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    api_key: String,
    http_client: reqwest::Client,
}
79
// Manual `Debug` implementation so the API key is never leaked into logs.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
89
90impl Client {
91 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
102 ClientBuilder::new(api_key)
103 }
104
105 pub fn new(api_key: &str) -> Self {
110 Self::builder(api_key)
111 .build()
112 .expect("OpenAI client should build")
113 }
114
115 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
116 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
117 self.http_client.post(url).bearer_auth(&self.api_key)
118 }
119
120 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
121 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
122 self.http_client.get(url).bearer_auth(&self.api_key)
123 }
124
125 pub fn extractor_completions_api<T>(&self, model: &str) -> ExtractorBuilder<CompletionModel, T>
129 where
130 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
131 {
132 ExtractorBuilder::new(self.completion_model(model).completions_api())
133 }
134}
135
136impl ProviderClient for Client {
137 fn from_env() -> Self {
140 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
141 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
142
143 match base_url {
144 Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
145 None => Self::new(&api_key),
146 }
147 }
148
149 fn from_val(input: crate::client::ProviderValue) -> Self {
150 let crate::client::ProviderValue::Simple(api_key) = input else {
151 panic!("Incorrect provider value type")
152 };
153 Self::new(&api_key)
154 }
155}
156
impl CompletionClient for Client {
    // Completions are served through the OpenAI Responses API model wrapper.
    type CompletionModel = super::responses_api::ResponsesCompletionModel;
    /// Create a completion model handle for the named model.
    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
    }
}
174
175impl EmbeddingsClient for Client {
176 type EmbeddingModel = EmbeddingModel;
177 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
178 let ndims = match model {
179 TEXT_EMBEDDING_3_LARGE => 3072,
180 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
181 _ => 0,
182 };
183 EmbeddingModel::new(self.clone(), model, ndims)
184 }
185
186 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
187 EmbeddingModel::new(self.clone(), model, ndims)
188 }
189}
190
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;
    /// Create a transcription model handle for the named model.
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
208
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;
    /// Create an image generation model handle for the named model.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
227
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;
    /// Create an audio generation model handle for the named model.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        AudioGenerationModel::new(self.clone(), model)
    }
}
246
247impl VerifyClient for Client {
248 #[cfg_attr(feature = "worker", worker::send)]
249 async fn verify(&self) -> Result<(), VerifyError> {
250 let response = self.get("/models").send().await?;
251 match response.status() {
252 reqwest::StatusCode::OK => Ok(()),
253 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
254 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
255 Err(VerifyError::ProviderError(response.text().await?))
256 }
257 _ => {
258 response.error_for_status()?;
259 Ok(())
260 }
261 }
262 }
263}
264
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub(crate) message: String,
}

/// Untagged wrapper: either a successful payload `T` or an API error body.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
276
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserialize a provider [`Message`] from JSON, panicking with the serde
    /// error path and line/column on failure. Extracted to remove the four
    /// identical error-reporting closures the tests previously repeated.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string content.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with structured content and a null tool_calls field.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only a tool call (null content/refusal).
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image, and audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_message(assistant_message_json);
        let assistant_message2: Message = parse_message(assistant_message_json2);
        let assistant_message3: Message = parse_message(assistant_message_json3);
        let user_message: Message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
                // A JSON `null` tool_calls field deserializes to an empty vec.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                // Only the first two parts (text + image) are asserted here.
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    // Round-trip: internal message -> provider message -> internal message.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back should reproduce the original internal messages.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    // Round-trip: provider message -> internal message -> provider message.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back should reproduce the original provider messages.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}