1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use future::Future;
8use futures::stream::Filter;
9use futures::{Stream, stream::StreamExt};
10use pin_project::pin_project;
11use reqwest::{Response, multipart::Form};
12use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
13use serde::{Serialize, de::DeserializeOwned};
14
15use crate::error::{ApiError, StreamError};
16use crate::{
17 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
18 Models, Projects, Responses, Threads, Uploads, Users, VectorStores, Videos,
19 config::{Config, OpenAIConfig},
20 error::{OpenAIError, WrappedError, map_deserialization_error},
21 file::Files,
22 image::Images,
23 moderation::Moderations,
24 traits::AsyncTryFrom,
25};
26
27#[derive(Debug, Clone)]
28pub struct Client<C: Config> {
31 http_client: reqwest::Client,
32 config: C,
33}
34
35impl Client<OpenAIConfig> {
36 pub fn new() -> Self {
38 Self {
39 http_client: reqwest::Client::new(),
40 config: OpenAIConfig::default(),
41 }
42 }
43}
44
45impl<C: Config> Client<C> {
46 pub fn build(http_client: reqwest::Client, config: C) -> Self {
48 Self {
49 http_client,
50 config,
51 }
52 }
53
54 pub fn with_config(config: C) -> Self {
56 Self {
57 http_client: reqwest::Client::new(),
58 config,
59 }
60 }
61
62 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
66 self.http_client = http_client;
67 self
68 }
69
70 pub fn models(&self) -> Models<'_, C> {
74 Models::new(self)
75 }
76
77 pub fn completions(&self) -> Completions<'_, C> {
79 Completions::new(self)
80 }
81
82 pub fn chat(&self) -> Chat<'_, C> {
84 Chat::new(self)
85 }
86
87 pub fn images(&self) -> Images<'_, C> {
89 Images::new(self)
90 }
91
92 pub fn moderations(&self) -> Moderations<'_, C> {
94 Moderations::new(self)
95 }
96
97 pub fn files(&self) -> Files<'_, C> {
99 Files::new(self)
100 }
101
102 pub fn uploads(&self) -> Uploads<'_, C> {
104 Uploads::new(self)
105 }
106
107 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
109 FineTuning::new(self)
110 }
111
112 pub fn embeddings(&self) -> Embeddings<'_, C> {
114 Embeddings::new(self)
115 }
116
117 pub fn audio(&self) -> Audio<'_, C> {
119 Audio::new(self)
120 }
121
122 pub fn videos(&self) -> Videos<'_, C> {
124 Videos::new(self)
125 }
126
127 pub fn assistants(&self) -> Assistants<'_, C> {
129 Assistants::new(self)
130 }
131
132 pub fn threads(&self) -> Threads<'_, C> {
134 Threads::new(self)
135 }
136
137 pub fn vector_stores(&self) -> VectorStores<'_, C> {
139 VectorStores::new(self)
140 }
141
142 pub fn batches(&self) -> Batches<'_, C> {
144 Batches::new(self)
145 }
146
147 pub fn audit_logs(&self) -> AuditLogs<'_, C> {
149 AuditLogs::new(self)
150 }
151
152 pub fn invites(&self) -> Invites<'_, C> {
154 Invites::new(self)
155 }
156
157 pub fn users(&self) -> Users<'_, C> {
159 Users::new(self)
160 }
161
162 pub fn projects(&self) -> Projects<'_, C> {
164 Projects::new(self)
165 }
166
167 pub fn responses(&self) -> Responses<'_, C> {
169 Responses::new(self)
170 }
171
172 pub fn config(&self) -> &C {
173 &self.config
174 }
175
176 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
178 where
179 O: DeserializeOwned,
180 {
181 self.execute(async {
182 Ok(self
183 .http_client
184 .get(self.config.url(path))
185 .query(&self.config.query())
186 .headers(self.config.headers())
187 .build()?)
188 })
189 .await
190 }
191
192 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
194 where
195 O: DeserializeOwned,
196 Q: Serialize + ?Sized,
197 {
198 self.execute(async {
199 Ok(self
200 .http_client
201 .get(self.config.url(path))
202 .query(&self.config.query())
203 .query(query)
204 .headers(self.config.headers())
205 .build()?)
206 })
207 .await
208 }
209
210 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
212 where
213 O: DeserializeOwned,
214 {
215 self.execute(async {
216 Ok(self
217 .http_client
218 .delete(self.config.url(path))
219 .query(&self.config.query())
220 .headers(self.config.headers())
221 .build()?)
222 })
223 .await
224 }
225
226 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
228 self.execute_raw(async {
229 Ok(self
230 .http_client
231 .get(self.config.url(path))
232 .query(&self.config.query())
233 .headers(self.config.headers())
234 .build()?)
235 })
236 .await
237 }
238
239 pub(crate) async fn get_raw_with_query<Q>(
240 &self,
241 path: &str,
242 query: &Q,
243 ) -> Result<Bytes, OpenAIError>
244 where
245 Q: Serialize + ?Sized,
246 {
247 self.execute_raw(async {
248 Ok(self
249 .http_client
250 .get(self.config.url(path))
251 .query(&self.config.query())
252 .query(query)
253 .headers(self.config.headers())
254 .build()?)
255 })
256 .await
257 }
258
259 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
261 where
262 I: Serialize,
263 {
264 self.execute_raw(async {
265 Ok(self
266 .http_client
267 .post(self.config.url(path))
268 .query(&self.config.query())
269 .headers(self.config.headers())
270 .json(&request)
271 .build()?)
272 })
273 .await
274 }
275
276 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
278 where
279 I: Serialize,
280 O: DeserializeOwned,
281 {
282 self.execute(async {
283 Ok(self
284 .http_client
285 .post(self.config.url(path))
286 .query(&self.config.query())
287 .headers(self.config.headers())
288 .json(&request)
289 .build()?)
290 })
291 .await
292 }
293
294 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
296 where
297 Form: AsyncTryFrom<F, Error = OpenAIError>,
298 {
299 self.execute_raw(async {
300 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
301 Ok(self
302 .http_client
303 .post(self.config.url(path))
304 .query(&self.config.query())
305 .headers(self.config.headers())
306 .multipart(form)
307 .build()?)
308 })
309 .await
310 }
311
312 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
314 where
315 O: DeserializeOwned,
316 Form: AsyncTryFrom<F, Error = OpenAIError>,
317 {
318 self.execute(async {
319 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
320 Ok(self
321 .http_client
322 .post(self.config.url(path))
323 .query(&self.config.query())
324 .headers(self.config.headers())
325 .multipart(form)
326 .build()?)
327 })
328 .await
329 }
330
331 async fn execute_raw(
333 &self,
334 request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
335 ) -> Result<Bytes, OpenAIError> {
336 let client = self.http_client.clone();
337 let request = request_future.await?;
338 let response = client
339 .execute(request)
340 .await
341 .map_err(OpenAIError::Reqwest)?;
342
343 let status = response.status();
344 match read_response(response).await {
345 Ok(bytes) => Ok(bytes),
346 Err(e) => match e {
347 OpenAIError::ApiError(api_error) => {
348 if status.as_u16() == 429
349 && api_error.r#type != Some("insufficient_quota".to_string())
350 {
351 tracing::warn!("Rate limited: {}", api_error.message);
353 }
354 Err(OpenAIError::ApiError(api_error))
355 }
356 _ => Err(e),
357 },
358 }
359 }
360
361 async fn execute<O>(
363 &self,
364 request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
365 ) -> Result<O, OpenAIError>
366 where
367 O: DeserializeOwned,
368 {
369 let bytes = self.execute_raw(request_future).await?;
370
371 let response: O = serde_json::from_slice(bytes.as_ref())
372 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
373
374 Ok(response)
375 }
376
377 pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
379 where
380 I: Serialize,
381 O: DeserializeOwned + Send + 'static,
382 {
383 let event_source = self
384 .http_client
385 .post(self.config.url(path))
386 .query(&self.config.query())
387 .headers(self.config.headers())
388 .json(&request)
389 .eventsource()
390 .unwrap();
391
392 OpenAIEventStream::new(event_source)
393 }
394
395 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
396 &self,
397 path: &str,
398 request: I,
399 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
400 ) -> OpenAIEventStream<O>
401 where
402 I: Serialize,
403 O: DeserializeOwned + Send + 'static,
404 {
405 let event_source = self
406 .http_client
407 .post(self.config.url(path))
408 .query(&self.config.query())
409 .headers(self.config.headers())
410 .json(&request)
411 .eventsource()
412 .unwrap();
413
414 OpenAIEventStream::with_event_mapping(event_source, event_mapper)
415 }
416
417 pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
419 where
420 Q: Serialize + ?Sized,
421 O: DeserializeOwned + Send + 'static,
422 {
423 let event_source = self
424 .http_client
425 .get(self.config.url(path))
426 .query(query)
427 .query(&self.config.query())
428 .headers(self.config.headers())
429 .eventsource()
430 .unwrap();
431
432 OpenAIEventStream::new(event_source)
433 }
434}
435
436#[pin_project]
440pub struct OpenAIEventStream<O>
441where
442 O: DeserializeOwned + Send + 'static,
443{
444 #[pin]
445 stream: Filter<
446 EventSource,
447 future::Ready<bool>,
448 fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
449 >,
450 event_mapper:
451 Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
452 done: bool,
453 _phantom_data: PhantomData<O>,
454}
455
456impl<O> OpenAIEventStream<O>
457where
458 O: DeserializeOwned + Send + 'static,
459{
460 pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
461 where
462 M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
463 {
464 Self {
465 stream: event_source.filter(|result|
466 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
468 done: false,
469 event_mapper: Some(Box::new(event_mapper)),
470 _phantom_data: PhantomData,
471 }
472 }
473
474 pub(crate) fn new(event_source: EventSource) -> Self {
475 Self {
476 stream: event_source.filter(|result|
477 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
479 done: false,
480 event_mapper: None,
481 _phantom_data: PhantomData,
482 }
483 }
484}
485
486impl<O> Stream for OpenAIEventStream<O>
487where
488 O: DeserializeOwned + Send + 'static,
489{
490 type Item = Result<O, OpenAIError>;
491
492 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
493 let this = self.project();
494 if *this.done {
495 return Poll::Ready(None);
496 }
497 let stream: Pin<&mut _> = this.stream;
498 match stream.poll_next(cx) {
499 Poll::Ready(response) => {
500 match response {
501 None => Poll::Ready(None), Some(result) => match result {
503 Ok(event) => match event {
504 Event::Open => unreachable!(), Event::Message(message) => {
506 if let Some(event_mapper) = this.event_mapper.as_ref() {
507 if message.data == "[DONE]" {
508 *this.done = true;
509 }
510 let response = event_mapper(message);
511 match response {
512 Ok(output) => Poll::Ready(Some(Ok(output))),
513 Err(_) => Poll::Ready(None),
514 }
515 } else {
516 if message.data == "[DONE]" {
517 *this.done = true;
518 Poll::Ready(None) } else {
520 match serde_json::from_str::<O>(&message.data) {
522 Err(e) => {
523 *this.done = true;
524 Poll::Ready(Some(Err(map_deserialization_error(
525 e,
526 message.data.as_bytes(),
527 ))))
528 }
529 Ok(output) => Poll::Ready(Some(Ok(output))),
530 }
531 }
532 }
533 }
534 },
535 Err(e) => {
536 *this.done = true;
537 Poll::Ready(Some(Err(OpenAIError::StreamError(
538 StreamError::ReqwestEventSource(e),
539 ))))
540 }
541 },
542 }
543 }
544 Poll::Pending => Poll::Pending,
545 }
546 }
547}
548
549async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
550 let status = response.status();
551 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
552
553 if status.is_server_error() {
554 let message: String = String::from_utf8_lossy(&bytes).into_owned();
556 tracing::warn!("Server error: {status} - {message}");
557 return Err(OpenAIError::ApiError(ApiError {
558 message,
559 r#type: None,
560 param: None,
561 code: None,
562 }));
563 }
564
565 if !status.is_success() {
567 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
568 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
569
570 return Err(OpenAIError::ApiError(wrapped_error.error));
571 }
572
573 Ok(bytes)
574}