1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 chatkit::Chatkit,
11 config::{Config, OpenAIConfig},
12 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
13 file::Files,
14 image::Images,
15 moderation::Moderations,
16 traits::AsyncTryFrom,
17 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
18 Embeddings, Evals, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
19 VectorStores, Videos,
20};
21
/// HTTP client for the API described by a [`Config`].
///
/// Bundles the transport, the configuration that supplies each request's URL, default query
/// parameters and headers, and the backoff policy used to retry failed non-streaming requests.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// Underlying HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Supplies the request URL (`url`), default query parameters (`query`) and headers (`headers`).
    config: C,
    /// Retry policy applied by `execute_raw` (non-streaming requests only).
    backoff: backoff::ExponentialBackoff,
}
30
31impl Client<OpenAIConfig> {
32 pub fn new() -> Self {
34 Self::default()
35 }
36}
37
38impl<C: Config> Client<C> {
39 pub fn build(
41 http_client: reqwest::Client,
42 config: C,
43 backoff: backoff::ExponentialBackoff,
44 ) -> Self {
45 Self {
46 http_client,
47 config,
48 backoff,
49 }
50 }
51
52 pub fn with_config(config: C) -> Self {
54 Self {
55 http_client: reqwest::Client::new(),
56 config,
57 backoff: Default::default(),
58 }
59 }
60
61 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
65 self.http_client = http_client;
66 self
67 }
68
69 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
71 self.backoff = backoff;
72 self
73 }
74
75 pub fn models(&self) -> Models<'_, C> {
79 Models::new(self)
80 }
81
82 pub fn completions(&self) -> Completions<'_, C> {
84 Completions::new(self)
85 }
86
87 pub fn chat(&self) -> Chat<'_, C> {
89 Chat::new(self)
90 }
91
92 pub fn images(&self) -> Images<'_, C> {
94 Images::new(self)
95 }
96
97 pub fn moderations(&self) -> Moderations<'_, C> {
99 Moderations::new(self)
100 }
101
102 pub fn files(&self) -> Files<'_, C> {
104 Files::new(self)
105 }
106
107 pub fn uploads(&self) -> Uploads<'_, C> {
109 Uploads::new(self)
110 }
111
112 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
114 FineTuning::new(self)
115 }
116
117 pub fn embeddings(&self) -> Embeddings<'_, C> {
119 Embeddings::new(self)
120 }
121
122 pub fn audio(&self) -> Audio<'_, C> {
124 Audio::new(self)
125 }
126
127 pub fn videos(&self) -> Videos<'_, C> {
129 Videos::new(self)
130 }
131
132 pub fn assistants(&self) -> Assistants<'_, C> {
134 Assistants::new(self)
135 }
136
137 pub fn threads(&self) -> Threads<'_, C> {
139 Threads::new(self)
140 }
141
142 pub fn vector_stores(&self) -> VectorStores<'_, C> {
144 VectorStores::new(self)
145 }
146
147 pub fn batches(&self) -> Batches<'_, C> {
149 Batches::new(self)
150 }
151
152 pub fn audit_logs(&self) -> AuditLogs<'_, C> {
154 AuditLogs::new(self)
155 }
156
157 pub fn invites(&self) -> Invites<'_, C> {
159 Invites::new(self)
160 }
161
162 pub fn users(&self) -> Users<'_, C> {
164 Users::new(self)
165 }
166
167 pub fn projects(&self) -> Projects<'_, C> {
169 Projects::new(self)
170 }
171
172 pub fn responses(&self) -> Responses<'_, C> {
174 Responses::new(self)
175 }
176
177 pub fn conversations(&self) -> Conversations<'_, C> {
179 Conversations::new(self)
180 }
181
182 pub fn containers(&self) -> Containers<'_, C> {
184 Containers::new(self)
185 }
186
187 pub fn evals(&self) -> Evals<'_, C> {
189 Evals::new(self)
190 }
191
192 pub fn chatkit(&self) -> Chatkit<'_, C> {
193 Chatkit::new(self)
194 }
195
196 pub fn config(&self) -> &C {
197 &self.config
198 }
199
200 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
202 where
203 O: DeserializeOwned,
204 {
205 let request_maker = || async {
206 Ok(self
207 .http_client
208 .get(self.config.url(path))
209 .query(&self.config.query())
210 .headers(self.config.headers())
211 .build()?)
212 };
213
214 self.execute(request_maker).await
215 }
216
217 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
219 where
220 O: DeserializeOwned,
221 Q: Serialize + ?Sized,
222 {
223 let request_maker = || async {
224 Ok(self
225 .http_client
226 .get(self.config.url(path))
227 .query(&self.config.query())
228 .query(query)
229 .headers(self.config.headers())
230 .build()?)
231 };
232
233 self.execute(request_maker).await
234 }
235
236 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
238 where
239 O: DeserializeOwned,
240 {
241 let request_maker = || async {
242 Ok(self
243 .http_client
244 .delete(self.config.url(path))
245 .query(&self.config.query())
246 .headers(self.config.headers())
247 .build()?)
248 };
249
250 self.execute(request_maker).await
251 }
252
253 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
255 let request_maker = || async {
256 Ok(self
257 .http_client
258 .get(self.config.url(path))
259 .query(&self.config.query())
260 .headers(self.config.headers())
261 .build()?)
262 };
263
264 self.execute_raw(request_maker).await
265 }
266
267 pub(crate) async fn get_raw_with_query<Q>(
268 &self,
269 path: &str,
270 query: &Q,
271 ) -> Result<Bytes, OpenAIError>
272 where
273 Q: Serialize + ?Sized,
274 {
275 let request_maker = || async {
276 Ok(self
277 .http_client
278 .get(self.config.url(path))
279 .query(&self.config.query())
280 .query(query)
281 .headers(self.config.headers())
282 .build()?)
283 };
284
285 self.execute_raw(request_maker).await
286 }
287
288 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
290 where
291 I: Serialize,
292 {
293 let request_maker = || async {
294 Ok(self
295 .http_client
296 .post(self.config.url(path))
297 .query(&self.config.query())
298 .headers(self.config.headers())
299 .json(&request)
300 .build()?)
301 };
302
303 self.execute_raw(request_maker).await
304 }
305
306 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
308 where
309 I: Serialize,
310 O: DeserializeOwned,
311 {
312 let request_maker = || async {
313 Ok(self
314 .http_client
315 .post(self.config.url(path))
316 .query(&self.config.query())
317 .headers(self.config.headers())
318 .json(&request)
319 .build()?)
320 };
321
322 self.execute(request_maker).await
323 }
324
325 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
327 where
328 Form: AsyncTryFrom<F, Error = OpenAIError>,
329 F: Clone,
330 {
331 let request_maker = || async {
332 Ok(self
333 .http_client
334 .post(self.config.url(path))
335 .query(&self.config.query())
336 .headers(self.config.headers())
337 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
338 .build()?)
339 };
340
341 self.execute_raw(request_maker).await
342 }
343
344 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
346 where
347 O: DeserializeOwned,
348 Form: AsyncTryFrom<F, Error = OpenAIError>,
349 F: Clone,
350 {
351 let request_maker = || async {
352 Ok(self
353 .http_client
354 .post(self.config.url(path))
355 .query(&self.config.query())
356 .headers(self.config.headers())
357 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
358 .build()?)
359 };
360
361 self.execute(request_maker).await
362 }
363
364 pub(crate) async fn post_form_stream<O, F>(
365 &self,
366 path: &str,
367 form: F,
368 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
369 where
370 F: Clone,
371 Form: AsyncTryFrom<F, Error = OpenAIError>,
372 O: DeserializeOwned + std::marker::Send + 'static,
373 {
374 let response = self
377 .http_client
378 .post(self.config.url(path))
379 .query(&self.config.query())
380 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
381 .headers(self.config.headers())
382 .send()
383 .await
384 .map_err(OpenAIError::Reqwest)?;
385
386 if !response.status().is_success() {
388 return Err(read_response(response).await.unwrap_err());
389 }
390
391 let stream = response
393 .bytes_stream()
394 .map(|result| result.map_err(std::io::Error::other));
395 let event_stream = eventsource_stream::EventStream::new(stream);
396
397 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
399
400 tokio::spawn(async move {
401 use futures::StreamExt;
402 let mut event_stream = std::pin::pin!(event_stream);
403
404 while let Some(event_result) = event_stream.next().await {
405 match event_result {
406 Err(e) => {
407 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
408 StreamError::EventStream(e.to_string()),
409 )))) {
410 break;
411 }
412 }
413 Ok(event) => {
414 if event.data == "[DONE]" {
416 break;
417 }
418
419 let response = match serde_json::from_str::<O>(&event.data) {
420 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
421 Ok(output) => Ok(output),
422 };
423
424 if let Err(_e) = tx.send(response) {
425 break;
426 }
427 }
428 }
429 }
430 });
431
432 Ok(Box::pin(
433 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
434 ))
435 }
436
437 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
443 where
444 M: Fn() -> Fut,
445 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
446 {
447 let client = self.http_client.clone();
448
449 backoff::future::retry(self.backoff.clone(), || async {
450 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
451 let response = client
452 .execute(request)
453 .await
454 .map_err(OpenAIError::Reqwest)
455 .map_err(backoff::Error::Permanent)?;
456
457 let status = response.status();
458
459 match read_response(response).await {
460 Ok(bytes) => Ok(bytes),
461 Err(e) => {
462 match e {
463 OpenAIError::ApiError(api_error) => {
464 if status.is_server_error() {
465 Err(backoff::Error::Transient {
466 err: OpenAIError::ApiError(api_error),
467 retry_after: None,
468 })
469 } else if status.as_u16() == 429
470 && api_error.r#type != Some("insufficient_quota".to_string())
471 {
472 tracing::warn!("Rate limited: {}", api_error.message);
474 Err(backoff::Error::Transient {
475 err: OpenAIError::ApiError(api_error),
476 retry_after: None,
477 })
478 } else {
479 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
480 }
481 }
482 _ => Err(backoff::Error::Permanent(e)),
483 }
484 }
485 }
486 })
487 .await
488 }
489
490 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
496 where
497 O: DeserializeOwned,
498 M: Fn() -> Fut,
499 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
500 {
501 let bytes = self.execute_raw(request_maker).await?;
502
503 let response: O = serde_json::from_slice(bytes.as_ref())
504 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
505
506 Ok(response)
507 }
508
509 pub(crate) async fn post_stream<I, O>(
511 &self,
512 path: &str,
513 request: I,
514 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
515 where
516 I: Serialize,
517 O: DeserializeOwned + std::marker::Send + 'static,
518 {
519 let event_source = self
520 .http_client
521 .post(self.config.url(path))
522 .query(&self.config.query())
523 .headers(self.config.headers())
524 .json(&request)
525 .eventsource()
526 .unwrap();
527
528 stream(event_source).await
529 }
530
531 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
532 &self,
533 path: &str,
534 request: I,
535 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
536 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
537 where
538 I: Serialize,
539 O: DeserializeOwned + std::marker::Send + 'static,
540 {
541 let event_source = self
542 .http_client
543 .post(self.config.url(path))
544 .query(&self.config.query())
545 .headers(self.config.headers())
546 .json(&request)
547 .eventsource()
548 .unwrap();
549
550 stream_mapped_raw_events(event_source, event_mapper).await
551 }
552
553 pub(crate) async fn _get_stream<Q, O>(
555 &self,
556 path: &str,
557 query: &Q,
558 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
559 where
560 Q: Serialize + ?Sized,
561 O: DeserializeOwned + std::marker::Send + 'static,
562 {
563 let event_source = self
564 .http_client
565 .get(self.config.url(path))
566 .query(query)
567 .query(&self.config.query())
568 .headers(self.config.headers())
569 .eventsource()
570 .unwrap();
571
572 stream(event_source).await
573 }
574}
575
576async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
577 let status = response.status();
578 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
579
580 if status.is_server_error() {
581 let message: String = String::from_utf8_lossy(&bytes).into_owned();
583 tracing::warn!("Server error: {status} - {message}");
584 return Err(OpenAIError::ApiError(ApiError {
585 message,
586 r#type: None,
587 param: None,
588 code: None,
589 }));
590 }
591
592 if !status.is_success() {
594 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
595 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
596
597 return Err(OpenAIError::ApiError(wrapped_error.error));
598 }
599
600 Ok(bytes)
601}
602
603async fn map_stream_error(value: EventSourceError) -> OpenAIError {
604 match value {
605 EventSourceError::InvalidStatusCode(status_code, response) => {
606 read_response(response).await.expect_err(&format!(
607 "Unreachable because read_response returns err when status_code {status_code} is invalid"
608 ))
609 }
610 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
611 }
612}
613
614pub(crate) async fn stream<O>(
617 mut event_source: EventSource,
618) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
619where
620 O: DeserializeOwned + std::marker::Send + 'static,
621{
622 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
623
624 tokio::spawn(async move {
625 while let Some(ev) = event_source.next().await {
626 match ev {
627 Err(e) => {
628 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
629 break;
631 }
632 }
633 Ok(event) => match event {
634 Event::Message(message) => {
635 if message.data == "[DONE]" {
636 break;
637 }
638
639 let response = match serde_json::from_str::<O>(&message.data) {
640 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
641 Ok(output) => Ok(output),
642 };
643
644 if let Err(_e) = tx.send(response) {
645 break;
647 }
648 }
649 Event::Open => continue,
650 },
651 }
652 }
653
654 event_source.close();
655 });
656
657 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
658}
659
660pub(crate) async fn stream_mapped_raw_events<O>(
661 mut event_source: EventSource,
662 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
663) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
664where
665 O: DeserializeOwned + std::marker::Send + 'static,
666{
667 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
668
669 tokio::spawn(async move {
670 while let Some(ev) = event_source.next().await {
671 match ev {
672 Err(e) => {
673 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
674 break;
676 }
677 }
678 Ok(event) => match event {
679 Event::Message(message) => {
680 let mut done = false;
681
682 if message.data == "[DONE]" {
683 done = true;
684 }
685
686 let response = event_mapper(message);
687
688 if let Err(_e) = tx.send(response) {
689 break;
691 }
692
693 if done {
694 break;
695 }
696 }
697 Event::Open => continue,
698 },
699 }
700 }
701
702 event_source.close();
703 });
704
705 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
706}