1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 extractor::ExtractorBuilder,
7 http_client::{self, HttpClientExt},
8 prelude::CompletionClient,
9 wasm_compat::{WasmCompatSend, WasmCompatSync},
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14
/// Default base URL shared by both OpenAI client builders (`ProviderBuilder::BASE_URL`).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

/// Provider extension that routes completions through OpenAI's Responses API.
///
/// A fieldless marker type: it only selects behaviour through its trait impls.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExt;

/// Builder-side marker paired with [`OpenAIResponsesExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAIResponsesExtBuilder;

/// Provider extension that routes completions through OpenAI's Chat Completions API.
///
/// A fieldless marker type: it only selects behaviour through its trait impls.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExt;

/// Builder-side marker paired with [`OpenAICompletionsExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAICompletionsExtBuilder;
37
38type OpenAIApiKey = BearerAuth;
39
40pub type Client<H = reqwest::Client> = client::Client<OpenAIResponsesExt, H>;
42pub type ClientBuilder<H = reqwest::Client> =
43 client::ClientBuilder<OpenAIResponsesExtBuilder, OpenAIApiKey, H>;
44
45pub type CompletionsClient<H = reqwest::Client> = client::Client<OpenAICompletionsExt, H>;
47pub type CompletionsClientBuilder<H = reqwest::Client> =
48 client::ClientBuilder<OpenAICompletionsExtBuilder, OpenAIApiKey, H>;
49
50impl Provider for OpenAIResponsesExt {
51 type Builder = OpenAIResponsesExtBuilder;
52
53 const VERIFY_PATH: &'static str = "/models";
54
55 fn build<H>(
56 _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
57 ) -> http_client::Result<Self> {
58 Ok(Self)
59 }
60}
61
62impl Provider for OpenAICompletionsExt {
63 type Builder = OpenAICompletionsExtBuilder;
64
65 const VERIFY_PATH: &'static str = "/models";
66
67 fn build<H>(
68 _: &crate::client::ClientBuilder<Self::Builder, OpenAIApiKey, H>,
69 ) -> http_client::Result<Self> {
70 Ok(Self)
71 }
72}
73
74impl<H> Capabilities<H> for OpenAIResponsesExt {
75 type Completion = Capable<super::responses_api::ResponsesCompletionModel<H>>;
76 type Embeddings = Capable<super::EmbeddingModel<H>>;
77 type Transcription = Capable<super::TranscriptionModel<H>>;
78 #[cfg(feature = "image")]
79 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
80 #[cfg(feature = "audio")]
81 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
82}
83
84impl<H> Capabilities<H> for OpenAICompletionsExt {
85 type Completion = Capable<super::completion::CompletionModel<H>>;
86 type Embeddings = Capable<super::EmbeddingModel<H>>;
87 type Transcription = Capable<super::TranscriptionModel<H>>;
88 #[cfg(feature = "image")]
89 type ImageGeneration = Capable<super::ImageGenerationModel<H>>;
90 #[cfg(feature = "audio")]
91 type AudioGeneration = Capable<super::audio_generation::AudioGenerationModel<H>>;
92}
93
// Both extensions opt into `DebugExt` using only the trait's default items;
// the impl bodies are intentionally empty.
impl DebugExt for OpenAIResponsesExt {}

impl DebugExt for OpenAICompletionsExt {}
97
98impl ProviderBuilder for OpenAIResponsesExtBuilder {
99 type Output = OpenAIResponsesExt;
100 type ApiKey = OpenAIApiKey;
101
102 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
103}
104
105impl ProviderBuilder for OpenAICompletionsExtBuilder {
106 type Output = OpenAICompletionsExt;
107 type ApiKey = OpenAIApiKey;
108
109 const BASE_URL: &'static str = OPENAI_API_BASE_URL;
110}
111
112impl<H> Client<H>
113where
114 H: HttpClientExt
115 + Clone
116 + std::fmt::Debug
117 + Default
118 + WasmCompatSend
119 + WasmCompatSync
120 + 'static,
121{
122 pub fn extractor<U>(
125 &self,
126 model: impl Into<String>,
127 ) -> ExtractorBuilder<super::responses_api::ResponsesCompletionModel<H>, U>
128 where
129 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
130 {
131 ExtractorBuilder::new(self.completion_model(model))
132 }
133
134 pub fn completions_api(self) -> CompletionsClient<H> {
137 client::Client::from_parts(
138 self.base_url().to_string(),
139 self.headers().clone(),
140 self.http_client().clone(),
141 OpenAICompletionsExt,
142 )
143 }
144}
145
146impl<H> CompletionsClient<H>
147where
148 H: HttpClientExt
149 + Clone
150 + std::fmt::Debug
151 + Default
152 + WasmCompatSend
153 + WasmCompatSync
154 + 'static,
155{
156 pub fn extractor<U>(
159 &self,
160 model: impl Into<String>,
161 ) -> ExtractorBuilder<super::completion::CompletionModel<H>, U>
162 where
163 U: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
164 {
165 ExtractorBuilder::new(self.completion_model(model))
166 }
167
168 pub fn responses_api(self) -> Client<H> {
171 client::Client::from_parts(
172 self.base_url().to_string(),
173 self.headers().clone(),
174 self.http_client().clone(),
175 OpenAIResponsesExt,
176 )
177 }
178}
179
180impl ProviderClient for Client {
181 type Input = OpenAIApiKey;
182
183 fn from_env() -> Self {
186 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
187 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
188
189 let mut builder = Client::builder().api_key(&api_key);
190
191 if let Some(base) = base_url {
192 builder = builder.base_url(&base);
193 }
194
195 builder.build().unwrap()
196 }
197
198 fn from_val(input: Self::Input) -> Self {
199 Self::new(input).unwrap()
200 }
201}
202
203impl ProviderClient for CompletionsClient {
204 type Input = OpenAIApiKey;
205
206 fn from_env() -> Self {
209 let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
210 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
211
212 let mut builder = CompletionsClient::builder().api_key(&api_key);
213
214 if let Some(base) = base_url {
215 builder = builder.base_url(&base);
216 }
217
218 builder.build().unwrap()
219 }
220
221 fn from_val(input: Self::Input) -> Self {
222 Self::new(input).unwrap()
223 }
224}
225
226#[derive(Debug, Deserialize)]
227pub struct ApiErrorResponse {
228 pub(crate) message: String,
229}
230
231#[derive(Debug, Deserialize)]
232#[serde(untagged)]
233pub(crate) enum ApiResponse<T> {
234 Ok(T),
235 Err(ApiErrorResponse),
236}
237
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes an OpenAI [`Message`] fixture.
    ///
    /// The fixture goes through `serde_path_to_error`, so a broken fixture
    /// reports the JSON path and line/column of the first failure.
    ///
    /// # Panics
    ///
    /// Panics if `json` does not deserialize into a [`Message`].
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Checks that wire-format messages deserialize into the expected variants.
    ///
    /// Covers string content, content-part arrays, tool calls and multimodal
    /// user input.
    #[test]
    fn test_deserialize_message() {
        // Assistant reply whose content is a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant reply with array-of-parts content and explicit null tool calls.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant tool call with null content and refusal.
        // `arguments` is a JSON-encoded string.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // Multimodal user message with text, image URL and audio parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // A JSON `null` for tool_calls deserializes to an empty list.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Round-trips generic `message::Message`s through the OpenAI `Message` type.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Round-trips OpenAI `Message`s through the generic `message::Message` type.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}
532}