1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{
9 config::{Config, OpenAIConfig},
10 error::{map_deserialization_error, OpenAIError, WrappedError},
11 file::Files,
12 image::Images,
13 moderation::Moderations,
14 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
15 Models, Projects, Threads, Users, VectorStores,
16};
17
/// API client that sends HTTP requests built from a configuration `C`.
///
/// Every request helper on this type reads its URL, query parameters and headers from `config`.
/// Non-streaming requests are retried according to `backoff` when the API rate-limits them.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Supplies the request URL (`url`), query parameters (`query`) and headers (`headers`).
    config: C,
    /// Retry policy used by the non-streaming request path for rate-limited (HTTP 429) responses.
    backoff: backoff::ExponentialBackoff,
}
26
27impl Client<OpenAIConfig> {
28 pub fn new() -> Self {
30 Self::default()
31 }
32}
33
34impl<C: Config> Client<C> {
35 pub fn build(
37 http_client: reqwest::Client,
38 config: C,
39 backoff: backoff::ExponentialBackoff,
40 ) -> Self {
41 Self {
42 http_client,
43 config,
44 backoff,
45 }
46 }
47
48 pub fn with_config(config: C) -> Self {
50 Self {
51 http_client: reqwest::Client::new(),
52 config,
53 backoff: Default::default(),
54 }
55 }
56
57 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
61 self.http_client = http_client;
62 self
63 }
64
65 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
67 self.backoff = backoff;
68 self
69 }
70
71 pub fn models(&self) -> Models<C> {
75 Models::new(self)
76 }
77
78 pub fn completions(&self) -> Completions<C> {
80 Completions::new(self)
81 }
82
83 pub fn chat(&self) -> Chat<C> {
85 Chat::new(self)
86 }
87
88 pub fn images(&self) -> Images<C> {
90 Images::new(self)
91 }
92
93 pub fn moderations(&self) -> Moderations<C> {
95 Moderations::new(self)
96 }
97
98 pub fn files(&self) -> Files<C> {
100 Files::new(self)
101 }
102
103 pub fn fine_tuning(&self) -> FineTuning<C> {
105 FineTuning::new(self)
106 }
107
108 pub fn embeddings(&self) -> Embeddings<C> {
110 Embeddings::new(self)
111 }
112
113 pub fn audio(&self) -> Audio<C> {
115 Audio::new(self)
116 }
117
118 pub fn assistants(&self) -> Assistants<C> {
120 Assistants::new(self)
121 }
122
123 pub fn threads(&self) -> Threads<C> {
125 Threads::new(self)
126 }
127
128 pub fn vector_stores(&self) -> VectorStores<C> {
130 VectorStores::new(self)
131 }
132
133 pub fn batches(&self) -> Batches<C> {
135 Batches::new(self)
136 }
137
138 pub fn audit_logs(&self) -> AuditLogs<C> {
140 AuditLogs::new(self)
141 }
142
143 pub fn invites(&self) -> Invites<C> {
145 Invites::new(self)
146 }
147
148 pub fn users(&self) -> Users<C> {
150 Users::new(self)
151 }
152
153 pub fn projects(&self) -> Projects<C> {
155 Projects::new(self)
156 }
157
158 pub fn config(&self) -> &C {
159 &self.config
160 }
161
162 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
164 where
165 O: DeserializeOwned,
166 {
167 let request_maker = || async {
168 Ok(self
169 .http_client
170 .get(self.config.url(path))
171 .query(&self.config.query())
172 .headers(self.config.headers())
173 .build()?)
174 };
175
176 self.execute(request_maker).await
177 }
178
179 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
181 where
182 O: DeserializeOwned,
183 Q: Serialize + ?Sized,
184 {
185 let request_maker = || async {
186 Ok(self
187 .http_client
188 .get(self.config.url(path))
189 .query(&self.config.query())
190 .query(query)
191 .headers(self.config.headers())
192 .build()?)
193 };
194
195 self.execute(request_maker).await
196 }
197
198 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
200 where
201 O: DeserializeOwned,
202 {
203 let request_maker = || async {
204 Ok(self
205 .http_client
206 .delete(self.config.url(path))
207 .query(&self.config.query())
208 .headers(self.config.headers())
209 .build()?)
210 };
211
212 self.execute(request_maker).await
213 }
214
215 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
217 let request_maker = || async {
218 Ok(self
219 .http_client
220 .get(self.config.url(path))
221 .query(&self.config.query())
222 .headers(self.config.headers())
223 .build()?)
224 };
225
226 self.execute_raw(request_maker).await
227 }
228
229 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
231 where
232 I: Serialize,
233 {
234 let request_maker = || async {
235 Ok(self
236 .http_client
237 .post(self.config.url(path))
238 .query(&self.config.query())
239 .headers(self.config.headers())
240 .json(&request)
241 .build()?)
242 };
243
244 self.execute_raw(request_maker).await
245 }
246
247 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
249 where
250 I: Serialize,
251 O: DeserializeOwned,
252 {
253 let request_maker = || async {
254 Ok(self
255 .http_client
256 .post(self.config.url(path))
257 .query(&self.config.query())
258 .headers(self.config.headers())
259 .json(&request)
260 .build()?)
261 };
262
263 self.execute(request_maker).await
264 }
265
266 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
268 where
269 reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
270 F: Clone,
271 {
272 let request_maker = || async {
273 Ok(self
274 .http_client
275 .post(self.config.url(path))
276 .query(&self.config.query())
277 .headers(self.config.headers())
278 .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
279 .build()?)
280 };
281
282 self.execute_raw(request_maker).await
283 }
284
285 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
287 where
288 O: DeserializeOwned,
289 reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
290 F: Clone,
291 {
292 let request_maker = || async {
293 Ok(self
294 .http_client
295 .post(self.config.url(path))
296 .query(&self.config.query())
297 .headers(self.config.headers())
298 .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
299 .build()?)
300 };
301
302 self.execute(request_maker).await
303 }
304
305 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
311 where
312 M: Fn() -> Fut,
313 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
314 {
315 let client = self.http_client.clone();
316
317 backoff::future::retry(self.backoff.clone(), || async {
318 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
319 let response = client
320 .execute(request)
321 .await
322 .map_err(OpenAIError::Reqwest)
323 .map_err(backoff::Error::Permanent)?;
324
325 let status = response.status();
326 let bytes = response
327 .bytes()
328 .await
329 .map_err(OpenAIError::Reqwest)
330 .map_err(backoff::Error::Permanent)?;
331
332 if !status.is_success() {
334 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
335 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
336 .map_err(backoff::Error::Permanent)?;
337
338 if status.as_u16() == 429
339 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
342 {
343 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
345 return Err(backoff::Error::Transient {
346 err: OpenAIError::ApiError(wrapped_error.error),
347 retry_after: None,
348 });
349 } else {
350 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
351 wrapped_error.error,
352 )));
353 }
354 }
355
356 Ok(bytes)
357 })
358 .await
359 }
360
361 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
367 where
368 O: DeserializeOwned,
369 M: Fn() -> Fut,
370 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
371 {
372 let bytes = self.execute_raw(request_maker).await?;
373
374 let response: O = serde_json::from_slice(bytes.as_ref())
375 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
376
377 Ok(response)
378 }
379
380 pub(crate) async fn post_stream<I, O>(
382 &self,
383 path: &str,
384 request: I,
385 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
386 where
387 I: Serialize,
388 O: DeserializeOwned + std::marker::Send + 'static,
389 {
390 let event_source = self
391 .http_client
392 .post(self.config.url(path))
393 .query(&self.config.query())
394 .headers(self.config.headers())
395 .json(&request)
396 .eventsource()
397 .unwrap();
398
399 stream(event_source).await
400 }
401
402 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
403 &self,
404 path: &str,
405 request: I,
406 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
407 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
408 where
409 I: Serialize,
410 O: DeserializeOwned + std::marker::Send + 'static,
411 {
412 let event_source = self
413 .http_client
414 .post(self.config.url(path))
415 .query(&self.config.query())
416 .headers(self.config.headers())
417 .json(&request)
418 .eventsource()
419 .unwrap();
420
421 stream_mapped_raw_events(event_source, event_mapper).await
422 }
423
424 pub(crate) async fn _get_stream<Q, O>(
426 &self,
427 path: &str,
428 query: &Q,
429 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
430 where
431 Q: Serialize + ?Sized,
432 O: DeserializeOwned + std::marker::Send + 'static,
433 {
434 let event_source = self
435 .http_client
436 .get(self.config.url(path))
437 .query(query)
438 .query(&self.config.query())
439 .headers(self.config.headers())
440 .eventsource()
441 .unwrap();
442
443 stream(event_source).await
444 }
445}
446
447pub(crate) async fn stream<O>(
450 mut event_source: EventSource,
451) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
452where
453 O: DeserializeOwned + std::marker::Send + 'static,
454{
455 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
456
457 tokio::spawn(async move {
458 while let Some(ev) = event_source.next().await {
459 match ev {
460 Err(e) => {
461 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
462 break;
464 }
465 }
466 Ok(event) => match event {
467 Event::Message(message) => {
468 if message.data == "[DONE]" {
469 break;
470 }
471
472 let response = match serde_json::from_str::<O>(&message.data) {
473 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
474 Ok(output) => Ok(output),
475 };
476
477 if let Err(_e) = tx.send(response) {
478 break;
480 }
481 }
482 Event::Open => continue,
483 },
484 }
485 }
486
487 event_source.close();
488 });
489
490 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
491}
492
493pub(crate) async fn stream_mapped_raw_events<O>(
494 mut event_source: EventSource,
495 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
496) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
497where
498 O: DeserializeOwned + std::marker::Send + 'static,
499{
500 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
501
502 tokio::spawn(async move {
503 while let Some(ev) = event_source.next().await {
504 match ev {
505 Err(e) => {
506 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
507 break;
509 }
510 }
511 Ok(event) => match event {
512 Event::Message(message) => {
513 let mut done = false;
514
515 if message.data == "[DONE]" {
516 done = true;
517 }
518
519 let response = event_mapper(message);
520
521 if let Err(_e) = tx.send(response) {
522 break;
524 }
525
526 if done {
527 break;
528 }
529 }
530 Event::Open => continue,
531 },
532 }
533 }
534
535 event_source.close();
536 });
537
538 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
539}