1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17 Models, Projects, Threads, Uploads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21pub struct Client<C: Config> {
24 http_client: reqwest::Client,
25 config: C,
26 backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36impl<C: Config> Client<C> {
37 pub fn build(
39 http_client: reqwest::Client,
40 config: C,
41 backoff: backoff::ExponentialBackoff,
42 ) -> Self {
43 Self {
44 http_client,
45 config,
46 backoff,
47 }
48 }
49
50 pub fn with_config(config: C) -> Self {
52 Self {
53 http_client: reqwest::Client::new(),
54 config,
55 backoff: Default::default(),
56 }
57 }
58
59 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63 self.http_client = http_client;
64 self
65 }
66
67 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69 self.backoff = backoff;
70 self
71 }
72
73 pub fn models(&self) -> Models<C> {
77 Models::new(self)
78 }
79
80 pub fn completions(&self) -> Completions<C> {
82 Completions::new(self)
83 }
84
85 pub fn chat(&self) -> Chat<C> {
87 Chat::new(self)
88 }
89
90 pub fn images(&self) -> Images<C> {
92 Images::new(self)
93 }
94
95 pub fn moderations(&self) -> Moderations<C> {
97 Moderations::new(self)
98 }
99
100 pub fn files(&self) -> Files<C> {
102 Files::new(self)
103 }
104
105 pub fn uploads(&self) -> Uploads<C> {
107 Uploads::new(self)
108 }
109
110 pub fn fine_tuning(&self) -> FineTuning<C> {
112 FineTuning::new(self)
113 }
114
115 pub fn embeddings(&self) -> Embeddings<C> {
117 Embeddings::new(self)
118 }
119
120 pub fn audio(&self) -> Audio<C> {
122 Audio::new(self)
123 }
124
125 pub fn assistants(&self) -> Assistants<C> {
127 Assistants::new(self)
128 }
129
130 pub fn threads(&self) -> Threads<C> {
132 Threads::new(self)
133 }
134
135 pub fn vector_stores(&self) -> VectorStores<C> {
137 VectorStores::new(self)
138 }
139
140 pub fn batches(&self) -> Batches<C> {
142 Batches::new(self)
143 }
144
145 pub fn audit_logs(&self) -> AuditLogs<C> {
147 AuditLogs::new(self)
148 }
149
150 pub fn invites(&self) -> Invites<C> {
152 Invites::new(self)
153 }
154
155 pub fn users(&self) -> Users<C> {
157 Users::new(self)
158 }
159
160 pub fn projects(&self) -> Projects<C> {
162 Projects::new(self)
163 }
164
165 pub fn config(&self) -> &C {
166 &self.config
167 }
168
169 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
171 where
172 O: DeserializeOwned,
173 {
174 let request_maker = || async {
175 Ok(self
176 .http_client
177 .get(self.config.url(path))
178 .query(&self.config.query())
179 .headers(self.config.headers())
180 .build()?)
181 };
182
183 self.execute(request_maker).await
184 }
185
186 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
188 where
189 O: DeserializeOwned,
190 Q: Serialize + ?Sized,
191 {
192 let request_maker = || async {
193 Ok(self
194 .http_client
195 .get(self.config.url(path))
196 .query(&self.config.query())
197 .query(query)
198 .headers(self.config.headers())
199 .build()?)
200 };
201
202 self.execute(request_maker).await
203 }
204
205 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
207 where
208 O: DeserializeOwned,
209 {
210 let request_maker = || async {
211 Ok(self
212 .http_client
213 .delete(self.config.url(path))
214 .query(&self.config.query())
215 .headers(self.config.headers())
216 .build()?)
217 };
218
219 self.execute(request_maker).await
220 }
221
222 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
224 let request_maker = || async {
225 Ok(self
226 .http_client
227 .get(self.config.url(path))
228 .query(&self.config.query())
229 .headers(self.config.headers())
230 .build()?)
231 };
232
233 self.execute_raw(request_maker).await
234 }
235
236 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
238 where
239 I: Serialize,
240 {
241 let request_maker = || async {
242 Ok(self
243 .http_client
244 .post(self.config.url(path))
245 .query(&self.config.query())
246 .headers(self.config.headers())
247 .json(&request)
248 .build()?)
249 };
250
251 self.execute_raw(request_maker).await
252 }
253
254 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
256 where
257 I: Serialize,
258 O: DeserializeOwned,
259 {
260 let request_maker = || async {
261 Ok(self
262 .http_client
263 .post(self.config.url(path))
264 .query(&self.config.query())
265 .headers(self.config.headers())
266 .json(&request)
267 .build()?)
268 };
269
270 self.execute(request_maker).await
271 }
272
273 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
275 where
276 Form: AsyncTryFrom<F, Error = OpenAIError>,
277 F: Clone,
278 {
279 let request_maker = || async {
280 Ok(self
281 .http_client
282 .post(self.config.url(path))
283 .query(&self.config.query())
284 .headers(self.config.headers())
285 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
286 .build()?)
287 };
288
289 self.execute_raw(request_maker).await
290 }
291
292 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
294 where
295 O: DeserializeOwned,
296 Form: AsyncTryFrom<F, Error = OpenAIError>,
297 F: Clone,
298 {
299 let request_maker = || async {
300 Ok(self
301 .http_client
302 .post(self.config.url(path))
303 .query(&self.config.query())
304 .headers(self.config.headers())
305 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
306 .build()?)
307 };
308
309 self.execute(request_maker).await
310 }
311
312 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
318 where
319 M: Fn() -> Fut,
320 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
321 {
322 let client = self.http_client.clone();
323
324 backoff::future::retry(self.backoff.clone(), || async {
325 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
326 let response = client
327 .execute(request)
328 .await
329 .map_err(OpenAIError::Reqwest)
330 .map_err(backoff::Error::Permanent)?;
331
332 let status = response.status();
333 let bytes = response
334 .bytes()
335 .await
336 .map_err(OpenAIError::Reqwest)
337 .map_err(backoff::Error::Permanent)?;
338
339 if status.is_server_error() {
340 let message: String = String::from_utf8_lossy(&bytes).into_owned();
342 tracing::warn!("Server error: {status} - {message}");
343 return Err(backoff::Error::Transient {
344 err: OpenAIError::ApiError(ApiError { message, r#type: None, param: None, code: None }),
345 retry_after: None,
346 });
347 }
348
349 if !status.is_success() {
351 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
352 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
353 .map_err(backoff::Error::Permanent)?;
354
355 if status.as_u16() == 429
356 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
359 {
360 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
362 return Err(backoff::Error::Transient {
363 err: OpenAIError::ApiError(wrapped_error.error),
364 retry_after: None,
365 });
366 } else {
367 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
368 wrapped_error.error,
369 )));
370 }
371 }
372
373 Ok(bytes)
374 })
375 .await
376 }
377
378 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
384 where
385 O: DeserializeOwned,
386 M: Fn() -> Fut,
387 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
388 {
389 let bytes = self.execute_raw(request_maker).await?;
390
391 let response: O = serde_json::from_slice(bytes.as_ref())
392 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
393
394 Ok(response)
395 }
396
397 pub(crate) async fn post_stream<I, O>(
399 &self,
400 path: &str,
401 request: I,
402 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
403 where
404 I: Serialize,
405 O: DeserializeOwned + std::marker::Send + 'static,
406 {
407 let event_source = self
408 .http_client
409 .post(self.config.url(path))
410 .query(&self.config.query())
411 .headers(self.config.headers())
412 .json(&request)
413 .eventsource()
414 .unwrap();
415
416 stream(event_source).await
417 }
418
419 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
420 &self,
421 path: &str,
422 request: I,
423 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
424 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
425 where
426 I: Serialize,
427 O: DeserializeOwned + std::marker::Send + 'static,
428 {
429 let event_source = self
430 .http_client
431 .post(self.config.url(path))
432 .query(&self.config.query())
433 .headers(self.config.headers())
434 .json(&request)
435 .eventsource()
436 .unwrap();
437
438 stream_mapped_raw_events(event_source, event_mapper).await
439 }
440
441 pub(crate) async fn _get_stream<Q, O>(
443 &self,
444 path: &str,
445 query: &Q,
446 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
447 where
448 Q: Serialize + ?Sized,
449 O: DeserializeOwned + std::marker::Send + 'static,
450 {
451 let event_source = self
452 .http_client
453 .get(self.config.url(path))
454 .query(query)
455 .query(&self.config.query())
456 .headers(self.config.headers())
457 .eventsource()
458 .unwrap();
459
460 stream(event_source).await
461 }
462}
463
464pub(crate) async fn stream<O>(
467 mut event_source: EventSource,
468) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
469where
470 O: DeserializeOwned + std::marker::Send + 'static,
471{
472 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
473
474 tokio::spawn(async move {
475 while let Some(ev) = event_source.next().await {
476 match ev {
477 Err(e) => {
478 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
479 break;
481 }
482 }
483 Ok(event) => match event {
484 Event::Message(message) => {
485 if message.data == "[DONE]" {
486 break;
487 }
488
489 let response = match serde_json::from_str::<O>(&message.data) {
490 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
491 Ok(output) => Ok(output),
492 };
493
494 if let Err(_e) = tx.send(response) {
495 break;
497 }
498 }
499 Event::Open => continue,
500 },
501 }
502 }
503
504 event_source.close();
505 });
506
507 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
508}
509
510pub(crate) async fn stream_mapped_raw_events<O>(
511 mut event_source: EventSource,
512 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
513) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
514where
515 O: DeserializeOwned + std::marker::Send + 'static,
516{
517 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
518
519 tokio::spawn(async move {
520 while let Some(ev) = event_source.next().await {
521 match ev {
522 Err(e) => {
523 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
524 break;
526 }
527 }
528 Ok(event) => match event {
529 Event::Message(message) => {
530 let mut done = false;
531
532 if message.data == "[DONE]" {
533 done = true;
534 }
535
536 let response = event_mapper(message);
537
538 if let Err(_e) = tx.send(response) {
539 break;
541 }
542
543 if done {
544 break;
545 }
546 }
547 Event::Open => continue,
548 },
549 }
550 }
551
552 event_source.close();
553 });
554
555 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
556}