1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::completion::CompletionModel;
4use super::embedding::{
5 EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
6};
7
8#[cfg(feature = "image")]
9use super::image_generation::ImageGenerationModel;
10use super::transcription::TranscriptionModel;
11
12use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
13
14#[cfg(feature = "audio")]
15use crate::client::AudioGenerationClient;
16#[cfg(feature = "image")]
17use crate::client::ImageGenerationClient;
18
19use serde::Deserialize;
20
21const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
25#[derive(Clone, Debug)]
26pub struct Client {
27 base_url: String,
28 http_client: reqwest::Client,
29}
30
31impl Client {
32 pub fn new(api_key: &str) -> Self {
34 Self::from_url(api_key, OPENAI_API_BASE_URL)
35 }
36
37 pub fn from_url(api_key: &str, base_url: &str) -> Self {
39 Self {
40 base_url: base_url.to_string(),
41 http_client: reqwest::Client::builder()
42 .default_headers({
43 let mut headers = reqwest::header::HeaderMap::new();
44 headers.insert(
45 "Authorization",
46 format!("Bearer {api_key}")
47 .parse()
48 .expect("Bearer token should parse"),
49 );
50 headers
51 })
52 .build()
53 .expect("OpenAI reqwest client should build"),
54 }
55 }
56
57 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
58 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
59 self.http_client.post(url)
60 }
61}
62
63impl ProviderClient for Client {
64 fn from_env() -> Self {
67 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
68 Self::new(&api_key)
69 }
70}
71
72impl CompletionClient for Client {
73 type CompletionModel = CompletionModel;
74 fn completion_model(&self, model: &str) -> CompletionModel {
86 CompletionModel::new(self.clone(), model)
87 }
88}
89
90impl EmbeddingsClient for Client {
91 type EmbeddingModel = EmbeddingModel;
92 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
93 let ndims = match model {
94 TEXT_EMBEDDING_3_LARGE => 3072,
95 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
96 _ => 0,
97 };
98 EmbeddingModel::new(self.clone(), model, ndims)
99 }
100
101 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
102 EmbeddingModel::new(self.clone(), model, ndims)
103 }
104}
105
106impl TranscriptionClient for Client {
107 type TranscriptionModel = TranscriptionModel;
108 fn transcription_model(&self, model: &str) -> TranscriptionModel {
120 TranscriptionModel::new(self.clone(), model)
121 }
122}
123
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Returns the image-generation model named `model`, backed by a clone of this
    /// client. Only compiled with the `image` feature.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
142
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Returns the audio-generation model named `model`, backed by a clone of this
    /// client. Only compiled with the `audio` feature.
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
161
/// Token counts deserialized from an API response.
///
/// NOTE(review): presumably mirrors the response's `usage` object; confirm against
/// the callers that deserialize it.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
167
168impl std::fmt::Display for Usage {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 write!(
171 f,
172 "Prompt tokens: {} Total tokens: {}",
173 self.prompt_tokens, self.total_tokens
174 )
175 }
176}
177
178#[derive(Debug, Deserialize)]
179pub struct ApiErrorResponse {
180 pub(crate) message: String,
181}
182
183#[derive(Debug, Deserialize)]
184#[serde(untagged)]
185pub(crate) enum ApiResponse<T> {
186 Ok(T),
187 Err(ApiErrorResponse),
188}
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{message, OneOrMany};
    use serde_path_to_error::deserialize;

    /// Deserializes a [`Message`] from `json`.
    ///
    /// Goes through `serde_path_to_error` so that a broken fixture panics with the JSON
    /// path and line/column of the offending value rather than a bare serde error.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                // Only the text and image parts are asserted. The trailing audio part is
                // required to deserialize, but its contents are not checked here.
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default(),
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}