1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Error, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17 Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21pub struct Client<C: Config> {
24 http_client: reqwest::Client,
25 config: C,
26 backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36#[derive(Debug, Deserialize)]
37struct CustomError {
38 error: String
39}
40
41impl<C: Config> Client<C> {
42 pub fn build(
44 http_client: reqwest::Client,
45 config: C,
46 backoff: backoff::ExponentialBackoff,
47 ) -> Self {
48 Self {
49 http_client,
50 config,
51 backoff,
52 }
53 }
54
55 pub fn with_config(config: C) -> Self {
57 Self {
58 http_client: reqwest::Client::new(),
59 config,
60 backoff: Default::default(),
61 }
62 }
63
64 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68 self.http_client = http_client;
69 self
70 }
71
72 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74 self.backoff = backoff;
75 self
76 }
77
78 pub fn models(&self) -> Models<C> {
82 Models::new(self)
83 }
84
85 pub fn completions(&self) -> Completions<C> {
87 Completions::new(self)
88 }
89
90 pub fn chat(&self) -> Chat<C> {
92 Chat::new(self)
93 }
94
95 pub fn images(&self) -> Images<C> {
97 Images::new(self)
98 }
99
100 pub fn moderations(&self) -> Moderations<C> {
102 Moderations::new(self)
103 }
104
105 pub fn files(&self) -> Files<C> {
107 Files::new(self)
108 }
109
110 pub fn uploads(&self) -> Uploads<C> {
112 Uploads::new(self)
113 }
114
115 pub fn fine_tuning(&self) -> FineTuning<C> {
117 FineTuning::new(self)
118 }
119
120 pub fn embeddings(&self) -> Embeddings<C> {
122 Embeddings::new(self)
123 }
124
125 pub fn audio(&self) -> Audio<C> {
127 Audio::new(self)
128 }
129
130 pub fn assistants(&self) -> Assistants<C> {
132 Assistants::new(self)
133 }
134
135 pub fn threads(&self) -> Threads<C> {
137 Threads::new(self)
138 }
139
140 pub fn vector_stores(&self) -> VectorStores<C> {
142 VectorStores::new(self)
143 }
144
145 pub fn batches(&self) -> Batches<C> {
147 Batches::new(self)
148 }
149
150 pub fn audit_logs(&self) -> AuditLogs<C> {
152 AuditLogs::new(self)
153 }
154
155 pub fn invites(&self) -> Invites<C> {
157 Invites::new(self)
158 }
159
160 pub fn users(&self) -> Users<C> {
162 Users::new(self)
163 }
164
165 pub fn projects(&self) -> Projects<C> {
167 Projects::new(self)
168 }
169
170 pub fn responses(&self) -> Responses<C> {
172 Responses::new(self)
173 }
174
175 pub fn config(&self) -> &C {
176 &self.config
177 }
178
179 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
181 where
182 O: DeserializeOwned,
183 {
184 let request_maker = || async {
185 Ok(self
186 .http_client
187 .get(self.config.url(path))
188 .query(&self.config.query())
189 .headers(self.config.headers())
190 .build()?)
191 };
192
193 self.execute(request_maker).await
194 }
195
196 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
198 where
199 O: DeserializeOwned,
200 Q: Serialize + ?Sized,
201 {
202 let request_maker = || async {
203 Ok(self
204 .http_client
205 .get(self.config.url(path))
206 .query(&self.config.query())
207 .query(query)
208 .headers(self.config.headers())
209 .build()?)
210 };
211
212 self.execute(request_maker).await
213 }
214
215 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
217 where
218 O: DeserializeOwned,
219 {
220 let request_maker = || async {
221 Ok(self
222 .http_client
223 .delete(self.config.url(path))
224 .query(&self.config.query())
225 .headers(self.config.headers())
226 .build()?)
227 };
228
229 self.execute(request_maker).await
230 }
231
232 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
234 let request_maker = || async {
235 Ok(self
236 .http_client
237 .get(self.config.url(path))
238 .query(&self.config.query())
239 .headers(self.config.headers())
240 .build()?)
241 };
242
243 self.execute_raw(request_maker).await
244 }
245
246 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
248 where
249 I: Serialize,
250 {
251 let request_maker = || async {
252 Ok(self
253 .http_client
254 .post(self.config.url(path))
255 .query(&self.config.query())
256 .headers(self.config.headers())
257 .json(&request)
258 .build()?)
259 };
260
261 self.execute_raw(request_maker).await
262 }
263
264 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
266 where
267 I: Serialize,
268 O: DeserializeOwned,
269 {
270 let request_maker = || async {
271 Ok(self
272 .http_client
273 .post(self.config.url(path))
274 .query(&self.config.query())
275 .headers(self.config.headers())
276 .json(&request)
277 .build()?)
278 };
279
280 self.execute(request_maker).await
281 }
282
283 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
285 where
286 Form: AsyncTryFrom<F, Error = OpenAIError>,
287 F: Clone,
288 {
289 let request_maker = || async {
290 Ok(self
291 .http_client
292 .post(self.config.url(path))
293 .query(&self.config.query())
294 .headers(self.config.headers())
295 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
296 .build()?)
297 };
298
299 self.execute_raw(request_maker).await
300 }
301
302 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
304 where
305 O: DeserializeOwned,
306 Form: AsyncTryFrom<F, Error = OpenAIError>,
307 F: Clone,
308 {
309 let request_maker = || async {
310 Ok(self
311 .http_client
312 .post(self.config.url(path))
313 .query(&self.config.query())
314 .headers(self.config.headers())
315 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
316 .build()?)
317 };
318
319 self.execute(request_maker).await
320 }
321
322 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
328 where
329 M: Fn() -> Fut,
330 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
331 {
332 let client = self.http_client.clone();
333
334 backoff::future::retry(self.backoff.clone(), || async {
335 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
336 let response = client
337 .execute(request)
338 .await
339 .map_err(OpenAIError::Reqwest)
340 .map_err(backoff::Error::Permanent)?;
341
342 let status = response.status();
343 let bytes = response
344 .bytes()
345 .await
346 .map_err(OpenAIError::Reqwest)
347 .map_err(backoff::Error::Permanent)?;
348
349 if status.is_server_error() {
350 let message: String = String::from_utf8_lossy(&bytes).into_owned();
352 tracing::warn!("Server error: {status} - {message}");
353 return Err(backoff::Error::Transient {
354 err: OpenAIError::ApiError(ApiError {
355 message,
356 r#type: None,
357 param: None,
358 code: None,
359 }),
360 retry_after: None,
361 });
362 }
363
364 if !status.is_success() {
366 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
367 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
368 .map_err(backoff::Error::Permanent)?;
369
370 if status.as_u16() == 429
371 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
374 {
375 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
377 return Err(backoff::Error::Transient {
378 err: OpenAIError::ApiError(wrapped_error.error),
379 retry_after: None,
380 });
381 } else {
382 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
383 wrapped_error.error,
384 )));
385 }
386 }
387
388 Ok(bytes)
389 })
390 .await
391 }
392
393 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
399 where
400 O: DeserializeOwned,
401 M: Fn() -> Fut,
402 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
403 {
404 let bytes = self.execute_raw(request_maker).await?;
405
406 let response: O = serde_json::from_slice(bytes.as_ref())
407 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
408
409 Ok(response)
410 }
411
412 pub(crate) async fn post_stream<I, O>(
414 &self,
415 path: &str,
416 request: I,
417 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
418 where
419 I: Serialize,
420 O: DeserializeOwned + std::marker::Send + 'static,
421 {
422 let event_source = self
423 .http_client
424 .post(self.config.url(path))
425 .query(&self.config.query())
426 .headers(self.config.headers())
427 .json(&request)
428 .eventsource()
429 .unwrap();
430
431 stream(event_source).await
432 }
433
434 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
435 &self,
436 path: &str,
437 request: I,
438 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
439 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
440 where
441 I: Serialize,
442 O: DeserializeOwned + std::marker::Send + 'static,
443 {
444 let event_source = self
445 .http_client
446 .post(self.config.url(path))
447 .query(&self.config.query())
448 .headers(self.config.headers())
449 .json(&request)
450 .eventsource()
451 .unwrap();
452
453 stream_mapped_raw_events(event_source, event_mapper).await
454 }
455
456 pub(crate) async fn _get_stream<Q, O>(
458 &self,
459 path: &str,
460 query: &Q,
461 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
462 where
463 Q: Serialize + ?Sized,
464 O: DeserializeOwned + std::marker::Send + 'static,
465 {
466 let event_source = self
467 .http_client
468 .get(self.config.url(path))
469 .query(query)
470 .query(&self.config.query())
471 .headers(self.config.headers())
472 .eventsource()
473 .unwrap();
474
475 stream(event_source).await
476 }
477}
478
479async fn handle_eventsource_error(e: Error) -> Result<(), OpenAIError> {
480 let error_text = e.to_string();
481 if let Error::InvalidStatusCode(code, response) = e {
482 if code.as_u16() == 401 {
483 return Err(OpenAIError::ApiError(ApiError {
484 message: "Unauthorized".to_string(),
485 r#type: None,
486 param: None,
487 code: None,
488 }));
489 }
490
491 if code.as_u16() == 429 {
492 return Err(OpenAIError::ApiError(ApiError {
493 message: "Rate limited by provider".to_string(),
494 r#type: None,
495 param: None,
496 code: None,
497 }));
498 }
499
500 if code.as_u16() == 408 {
501 return Err(OpenAIError::ApiError(ApiError {
502 message: "Request to provider timed out".to_string(),
503 r#type: None,
504 param: None,
505 code: None,
506 }));
507 }
508
509 if let Ok(text) = response.text().await {
510 if code.as_u16() == 400 {
511 let custom_error = serde_json::from_str::<CustomError>(&text);
512 if let Ok(error) = custom_error {
513 return Err(OpenAIError::ApiError(ApiError {
514 message: error.error,
515 r#type: None,
516 param: None,
517 code: None,
518 }));
519 }
520 }
521
522 let api_error = serde_json::from_str::<WrappedError>(&text);
523 if let Ok(e) = api_error {
524 return Err(OpenAIError::ApiError(e.error));
525 }
526 }
527 }
528
529 Err(OpenAIError::StreamError(error_text))
530}
531
532pub(crate) async fn stream<O>(
535 mut event_source: EventSource,
536) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
537where
538 O: DeserializeOwned + std::marker::Send + 'static,
539{
540 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
541
542 tokio::spawn(async move {
543 while let Some(ev) = event_source.next().await {
544 match ev {
545 Err(e) => {
546 if let Err(e) = handle_eventsource_error(e).await {
547 if let Err(_e) = tx.send(Err(e)) {
548 break;
550 }
551 }
552 }
553 Ok(event) => match event {
554 Event::Message(message) => {
555 if message.data == "[DONE]" {
556 break;
557 }
558
559 let response = match serde_json::from_str::<O>(&message.data) {
560 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
561 Ok(output) => Ok(output),
562 };
563
564 if let Err(_e) = tx.send(response) {
565 break;
567 }
568 }
569 Event::Open => continue,
570 },
571 }
572 }
573
574 event_source.close();
575 });
576
577 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
578}
579
580pub(crate) async fn stream_mapped_raw_events<O>(
581 mut event_source: EventSource,
582 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
583) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
584where
585 O: DeserializeOwned + std::marker::Send + 'static,
586{
587 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
588
589 tokio::spawn(async move {
590 while let Some(ev) = event_source.next().await {
591 match ev {
592 Err(e) => {
593 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
594 break;
596 }
597 }
598 Ok(event) => match event {
599 Event::Message(message) => {
600 let mut done = false;
601
602 if message.data == "[DONE]" {
603 done = true;
604 }
605
606 let response = event_mapper(message);
607
608 if let Err(_e) = tx.send(response) {
609 break;
611 }
612
613 if done {
614 break;
615 }
616 }
617 Event::Open => continue,
618 },
619 }
620 }
621
622 event_source.close();
623 });
624
625 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
626}