1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17 Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21pub struct Client<C: Config> {
24 http_client: reqwest::Client,
25 config: C,
26 backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36impl<C: Config> Client<C> {
37 pub fn build(
39 http_client: reqwest::Client,
40 config: C,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn with_config(config: C) -> Self {
52 Self {
53 http_client: reqwest::Client::new(),
54 config,
55 backoff: Default::default(),
56 }
57 }
58
59 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63 self.http_client = http_client;
64 self
65 }
66
67 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69 self.backoff = backoff;
70 self
71 }
72
73 pub fn models(&self) -> Models<C> {
77 Models::new(self)
78 }
79
80 pub fn completions(&self) -> Completions<C> {
82 Completions::new(self)
83 }
84
85 pub fn chat(&self) -> Chat<C> {
87 Chat::new(self)
88 }
89
90 pub fn images(&self) -> Images<C> {
92 Images::new(self)
93 }
94
95 pub fn moderations(&self) -> Moderations<C> {
97 Moderations::new(self)
98 }
99
100 pub fn files(&self) -> Files<C> {
102 Files::new(self)
103 }
104
105 pub fn uploads(&self) -> Uploads<C> {
107 Uploads::new(self)
108 }
109
110 pub fn fine_tuning(&self) -> FineTuning<C> {
112 FineTuning::new(self)
113 }
114
115 pub fn embeddings(&self) -> Embeddings<C> {
117 Embeddings::new(self)
118 }
119
120 pub fn audio(&self) -> Audio<C> {
122 Audio::new(self)
123 }
124
125 pub fn assistants(&self) -> Assistants<C> {
127 Assistants::new(self)
128 }
129
130 pub fn threads(&self) -> Threads<C> {
132 Threads::new(self)
133 }
134
135 pub fn vector_stores(&self) -> VectorStores<C> {
137 VectorStores::new(self)
138 }
139
140 pub fn batches(&self) -> Batches<C> {
142 Batches::new(self)
143 }
144
145 pub fn audit_logs(&self) -> AuditLogs<C> {
147 AuditLogs::new(self)
148 }
149
150 pub fn invites(&self) -> Invites<C> {
152 Invites::new(self)
153 }
154
155 pub fn users(&self) -> Users<C> {
157 Users::new(self)
158 }
159
160 pub fn projects(&self) -> Projects<C> {
162 Projects::new(self)
163 }
164
165 pub fn responses(&self) -> Responses<C> {
167 Responses::new(self)
168 }
169
170 pub fn config(&self) -> &C {
171 &self.config
172 }
173
174 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
176 where
177 O: DeserializeOwned,
178 {
179 let request_maker = || async {
180 Ok(self
181 .http_client
182 .get(self.config.url(path))
183 .query(&self.config.query())
184 .headers(self.config.headers())
185 .build()?)
186 };
187
188 self.execute(request_maker).await
189 }
190
191 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
193 where
194 O: DeserializeOwned,
195 Q: Serialize + ?Sized,
196 {
197 let request_maker = || async {
198 Ok(self
199 .http_client
200 .get(self.config.url(path))
201 .query(&self.config.query())
202 .query(query)
203 .headers(self.config.headers())
204 .build()?)
205 };
206
207 self.execute(request_maker).await
208 }
209
210 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
212 where
213 O: DeserializeOwned,
214 {
215 let request_maker = || async {
216 Ok(self
217 .http_client
218 .delete(self.config.url(path))
219 .query(&self.config.query())
220 .headers(self.config.headers())
221 .build()?)
222 };
223
224 self.execute(request_maker).await
225 }
226
227 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
229 let request_maker = || async {
230 Ok(self
231 .http_client
232 .get(self.config.url(path))
233 .query(&self.config.query())
234 .headers(self.config.headers())
235 .build()?)
236 };
237
238 self.execute_raw(request_maker).await
239 }
240
241 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
243 where
244 I: Serialize,
245 {
246 let request_maker = || async {
247 Ok(self
248 .http_client
249 .post(self.config.url(path))
250 .query(&self.config.query())
251 .headers(self.config.headers())
252 .json(&request)
253 .build()?)
254 };
255
256 self.execute_raw(request_maker).await
257 }
258
259 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
261 where
262 I: Serialize,
263 O: DeserializeOwned,
264 {
265 let request_maker = || async {
266 Ok(self
267 .http_client
268 .post(self.config.url(path))
269 .query(&self.config.query())
270 .headers(self.config.headers())
271 .json(&request)
272 .build()?)
273 };
274
275 self.execute(request_maker).await
276 }
277
278 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
280 where
281 Form: AsyncTryFrom<F, Error = OpenAIError>,
282 F: Clone,
283 {
284 let request_maker = || async {
285 Ok(self
286 .http_client
287 .post(self.config.url(path))
288 .query(&self.config.query())
289 .headers(self.config.headers())
290 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
291 .build()?)
292 };
293
294 self.execute_raw(request_maker).await
295 }
296
297 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
299 where
300 O: DeserializeOwned,
301 Form: AsyncTryFrom<F, Error = OpenAIError>,
302 F: Clone,
303 {
304 let request_maker = || async {
305 Ok(self
306 .http_client
307 .post(self.config.url(path))
308 .query(&self.config.query())
309 .headers(self.config.headers())
310 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
311 .build()?)
312 };
313
314 self.execute(request_maker).await
315 }
316
317 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
323 where
324 M: Fn() -> Fut,
325 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
326 {
327 let client = self.http_client.clone();
328
329 backoff::future::retry(self.backoff.clone(), || async {
330 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
331 let response = client
332 .execute(request)
333 .await
334 .map_err(OpenAIError::Reqwest)
335 .map_err(backoff::Error::Permanent)?;
336
337 let status = response.status();
338
339 match read_response(response).await {
340 Ok(bytes) => Ok(bytes),
341 Err(e) => {
342 match e {
343 OpenAIError::ApiError(api_error) => {
344 if status.is_server_error() {
345 Err(backoff::Error::Transient {
346 err: OpenAIError::ApiError(api_error),
347 retry_after: None,
348 })
349 } else if status.as_u16() == 429
350 && api_error.r#type != Some("insufficient_quota".to_string())
351 {
352 tracing::warn!("Rate limited: {}", api_error.message);
354 Err(backoff::Error::Transient {
355 err: OpenAIError::ApiError(api_error),
356 retry_after: None,
357 })
358 } else {
359 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
360 }
361 }
362 _ => Err(backoff::Error::Permanent(e)),
363 }
364 }
365 }
366 })
367 .await
368 }
369
370 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
376 where
377 O: DeserializeOwned,
378 M: Fn() -> Fut,
379 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
380 {
381 let bytes = self.execute_raw(request_maker).await?;
382
383 let response: O = serde_json::from_slice(bytes.as_ref())
384 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
385
386 Ok(response)
387 }
388
389 pub(crate) async fn post_stream<I, O>(
391 &self,
392 path: &str,
393 request: I,
394 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
395 where
396 I: Serialize,
397 O: DeserializeOwned + std::marker::Send + 'static,
398 {
399 let event_source = self
400 .http_client
401 .post(self.config.url(path))
402 .query(&self.config.query())
403 .headers(self.config.headers())
404 .json(&request)
405 .eventsource()
406 .unwrap();
407
408 stream(event_source).await
409 }
410
411 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
412 &self,
413 path: &str,
414 request: I,
415 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
416 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
417 where
418 I: Serialize,
419 O: DeserializeOwned + std::marker::Send + 'static,
420 {
421 let event_source = self
422 .http_client
423 .post(self.config.url(path))
424 .query(&self.config.query())
425 .headers(self.config.headers())
426 .json(&request)
427 .eventsource()
428 .unwrap();
429
430 stream_mapped_raw_events(event_source, event_mapper).await
431 }
432
433 pub(crate) async fn _get_stream<Q, O>(
435 &self,
436 path: &str,
437 query: &Q,
438 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
439 where
440 Q: Serialize + ?Sized,
441 O: DeserializeOwned + std::marker::Send + 'static,
442 {
443 let event_source = self
444 .http_client
445 .get(self.config.url(path))
446 .query(query)
447 .query(&self.config.query())
448 .headers(self.config.headers())
449 .eventsource()
450 .unwrap();
451
452 stream(event_source).await
453 }
454}
455
456async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
457 let status = response.status();
458 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
459
460 if status.is_server_error() {
461 let message: String = String::from_utf8_lossy(&bytes).into_owned();
463 tracing::warn!("Server error: {status} - {message}");
464 return Err(OpenAIError::ApiError(ApiError {
465 message,
466 r#type: None,
467 param: None,
468 code: None,
469 }));
470 }
471
472 if !status.is_success() {
474 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
475 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
476
477 return Err(OpenAIError::ApiError(wrapped_error.error));
478 }
479
480 Ok(bytes)
481}
482
483async fn map_stream_error(value: EventSourceError) -> OpenAIError {
484 match value {
485 EventSourceError::InvalidStatusCode(status_code, response) => {
486 read_response(response).await.expect_err(&format!(
487 "Unreachable because read_response returns err when status_code {status_code} is invalid"
488 ))
489 }
490 _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
491 }
492}
493
494pub(crate) async fn stream<O>(
497 mut event_source: EventSource,
498) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
499where
500 O: DeserializeOwned + std::marker::Send + 'static,
501{
502 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
503
504 tokio::spawn(async move {
505 while let Some(ev) = event_source.next().await {
506 match ev {
507 Err(e) => {
508 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
509 break;
511 }
512 }
513 Ok(event) => match event {
514 Event::Message(message) => {
515 if message.data == "[DONE]" {
516 break;
517 }
518
519 let response = match serde_json::from_str::<O>(&message.data) {
520 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
521 Ok(output) => Ok(output),
522 };
523
524 if let Err(_e) = tx.send(response) {
525 break;
527 }
528 }
529 Event::Open => continue,
530 },
531 }
532 }
533
534 event_source.close();
535 });
536
537 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
538}
539
540pub(crate) async fn stream_mapped_raw_events<O>(
541 mut event_source: EventSource,
542 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
543) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
544where
545 O: DeserializeOwned + std::marker::Send + 'static,
546{
547 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
548
549 tokio::spawn(async move {
550 while let Some(ev) = event_source.next().await {
551 match ev {
552 Err(e) => {
553 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
554 break;
556 }
557 }
558 Ok(event) => match event {
559 Event::Message(message) => {
560 let mut done = false;
561
562 if message.data == "[DONE]" {
563 done = true;
564 }
565
566 let response = event_mapper(message);
567
568 if let Err(_e) = tx.send(response) {
569 break;
571 }
572
573 if done {
574 break;
575 }
576 }
577 Event::Open => continue,
578 },
579 }
580 }
581
582 event_source.close();
583 });
584
585 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
586}