// async_openai — client.rs
// (Recovered from a rendered source listing; navigation text removed.)

1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3use std::sync::Arc;
4#[cfg(not(target_family = "wasm"))]
5use std::sync::Mutex;
6
7use bytes::Bytes;
8#[cfg(not(target_family = "wasm"))]
9use futures::{stream::StreamExt, Stream};
10use reqwest::{header::HeaderMap, multipart::Form, Response};
11use serde::{de::DeserializeOwned, Serialize};
12
13#[cfg(not(target_family = "wasm"))]
14use crate::error::StreamError;
15#[cfg(feature = "middleware")]
16use crate::executor::TowerExecutor;
17use crate::{
18    config::{Config, OpenAIConfig},
19    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
20    executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
21    traits::AsyncTryFrom,
22    RequestOptions,
23};
24
/// Snapshot of everything needed to (re)build an HTTP request.
///
/// Held behind an `Arc` by the request factories so the same request can be
/// rebuilt more than once (e.g. when a middleware layer replays it).
struct RequestParts {
    // Underlying reqwest client used to create the builder.
    request_client: reqwest::Client,
    // HTTP method of the request.
    method: reqwest::Method,
    // Fully resolved URL (config base combined with the request path).
    url: String,
    // Merged headers: config-level plus any per-request additions.
    headers: HeaderMap,
    // Merged query parameters: config-level plus per-request ones.
    query: Vec<(String, String)>,
}

impl RequestParts {
    /// Create a fresh `RequestBuilder` from the stored parts.
    ///
    /// Clones method, URL, and headers so this can be invoked repeatedly.
    fn build_request_builder(&self) -> reqwest::RequestBuilder {
        self.request_client
            .request(self.method.clone(), self.url.clone())
            .query(&self.query)
            .headers(self.headers.clone())
    }
}
41
42#[cfg(feature = "administration")]
43use crate::admin::Admin;
44#[cfg(feature = "chatkit")]
45use crate::chatkit::Chatkit;
46#[cfg(feature = "file")]
47use crate::file::Files;
48#[cfg(feature = "image")]
49use crate::image::Images;
50#[cfg(feature = "moderation")]
51use crate::moderation::Moderations;
52#[cfg(feature = "assistant")]
53#[allow(deprecated)]
54use crate::Assistants;
55#[cfg(feature = "audio")]
56use crate::Audio;
57#[cfg(feature = "batch")]
58use crate::Batches;
59#[cfg(feature = "chat-completion")]
60use crate::Chat;
61#[cfg(feature = "completions")]
62use crate::Completions;
63#[cfg(feature = "container")]
64use crate::Containers;
65#[cfg(feature = "responses")]
66use crate::Conversations;
67#[cfg(feature = "embedding")]
68use crate::Embeddings;
69#[cfg(feature = "evals")]
70use crate::Evals;
71#[cfg(feature = "finetuning")]
72use crate::FineTuning;
73#[cfg(feature = "model")]
74use crate::Models;
75#[cfg(feature = "realtime")]
76use crate::Realtime;
77#[cfg(feature = "responses")]
78use crate::Responses;
79#[cfg(feature = "skill")]
80use crate::Skills;
81#[cfg(feature = "assistant")]
82#[allow(deprecated)]
83use crate::Threads;
84#[cfg(feature = "upload")]
85use crate::Uploads;
86#[cfg(feature = "vectorstore")]
87use crate::VectorStores;
88#[cfg(feature = "video")]
89use crate::Videos;
90
#[derive(Clone)]
/// Client is a container for config and HTTP execution
/// used to make API calls.
pub struct Client<C: Config> {
    // Used to construct `reqwest::Request` values inside request factories.
    request_client: reqwest::Client,
    // Pluggable execution layer: plain reqwest, or a tower stack when the
    // "middleware" feature is enabled.
    executor: SharedExecutor,
    // API configuration (base URL, default headers, default query params).
    config: C,
}
99
100impl<C> std::fmt::Debug for Client<C>
101where
102    C: Config + std::fmt::Debug,
103{
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("Client")
106            .field("request_client", &self.request_client)
107            .field("config", &self.config)
108            .finish()
109    }
110}
111
112impl<C: Config> Default for Client<C>
113where
114    C: Default,
115{
116    fn default() -> Self {
117        let request_client = reqwest::Client::new();
118        Self {
119            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
120            request_client,
121            config: C::default(),
122        }
123    }
124}
125
126impl Client<OpenAIConfig> {
127    /// Client with default [OpenAIConfig]
128    pub fn new() -> Self {
129        Self::default()
130    }
131}
132
133impl<C: Config> Client<C> {
134    /// Create client with a custom HTTP client and config.
135    pub fn build(http_client: reqwest::Client, config: C) -> Self {
136        Self {
137            executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
138            request_client: http_client,
139            config,
140        }
141    }
142
143    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
144    pub fn with_config(config: C) -> Self {
145        let request_client = reqwest::Client::new();
146        Self {
147            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
148            request_client,
149            config,
150        }
151    }
152
153    /// Provide your own [client] to make HTTP requests with.
154    ///
155    /// [client]: reqwest::Client
156    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
157        self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
158        self.request_client = http_client;
159        self
160    }
161
    /// Provide your own tower-compatible service to execute HTTP requests.
    ///
    /// The service consumes an [`HttpRequestFactory`] and must produce a
    /// [`reqwest::Response`]; its error type must convert into [`OpenAIError`].
    #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
    pub fn with_http_service<S>(mut self, service: S) -> Self
    where
        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
        S::Future: Send + 'static,
        S::Error: Into<OpenAIError> + Send + Sync + 'static,
    {
        // This is the public middleware escape hatch. We erase the concrete
        // tower stack here so the rest of the client does not become generic
        // over the service type, which would otherwise leak through every API
        // group and make the crate much harder to use.
        self.executor = Arc::new(TowerExecutor::new(service));
        self
    }

    /// Provide your own tower-compatible service to execute HTTP requests.
    ///
    /// wasm counterpart of the method above: same behavior, relaxed bounds.
    #[cfg(all(feature = "middleware", target_family = "wasm"))]
    pub fn with_http_service<S>(mut self, service: S) -> Self
    where
        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
        S::Future: 'static,
        S::Error: Into<OpenAIError> + 'static,
    {
        // wasm futures produced by reqwest are not `Send`, so the wasm version
        // intentionally avoids native thread-safety bounds. Users are still
        // responsible for choosing tower layers that work in their wasm
        // runtime.
        self.executor = Arc::new(TowerExecutor::new(service));
        self
    }
193
    // API groups
    //
    // Each accessor below is a zero-cost view that borrows this client; the
    // returned group issues requests through the same executor and config.

    /// To call [Models] group related APIs using this client.
    #[cfg(feature = "model")]
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// To call [Completions] group related APIs using this client.
    #[cfg(feature = "completions")]
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// To call [Chat] group related APIs using this client.
    #[cfg(feature = "chat-completion")]
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// To call [Images] group related APIs using this client.
    #[cfg(feature = "image")]
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// To call [Moderations] group related APIs using this client.
    #[cfg(feature = "moderation")]
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// To call [Files] group related APIs using this client.
    #[cfg(feature = "file")]
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// To call [Uploads] group related APIs using this client.
    #[cfg(feature = "upload")]
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// To call [FineTuning] group related APIs using this client.
    #[cfg(feature = "finetuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// To call [Embeddings] group related APIs using this client.
    #[cfg(feature = "embedding")]
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// To call [Audio] group related APIs using this client.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// To call [Videos] group related APIs using this client.
    #[cfg(feature = "video")]
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// To call [Assistants] group related APIs using this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// To call [Threads] group related APIs using this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// To call [VectorStores] group related APIs using this client.
    #[cfg(feature = "vectorstore")]
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// To call [Batches] group related APIs using this client.
    #[cfg(feature = "batch")]
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// To call [Admin] group related APIs using this client.
    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
    #[cfg(feature = "administration")]
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    /// To call [Responses] group related APIs using this client.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// To call [Conversations] group related APIs using this client.
    #[cfg(feature = "responses")]
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// To call [Containers] group related APIs using this client.
    #[cfg(feature = "container")]
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// To call [Skills] group related APIs using this client.
    #[cfg(feature = "skill")]
    pub fn skills(&self) -> Skills<'_, C> {
        Skills::new(self)
    }

    /// To call [Evals] group related APIs using this client.
    #[cfg(feature = "evals")]
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// To call [Chatkit] group related APIs using this client.
    #[cfg(feature = "chatkit")]
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    /// To call [Realtime] group related APIs using this client.
    #[cfg(feature = "realtime")]
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }
341
    /// Access the configuration this client was built with.
    pub fn config(&self) -> &C {
        &self.config
    }
345
346    fn build_request_parts(
347        &self,
348        method: reqwest::Method,
349        path: &str,
350        request_options: &RequestOptions,
351    ) -> Arc<RequestParts> {
352        let url = if let Some(path) = request_options.path() {
353            self.config.url(path.as_str())
354        } else {
355            self.config.url(path)
356        };
357        let mut headers = self.config.headers();
358        if let Some(request_headers) = request_options.headers() {
359            headers.extend(request_headers.clone());
360        }
361
362        let mut query = self
363            .config
364            .query()
365            .into_iter()
366            .map(|(key, value)| (key.to_string(), value.to_string()))
367            .collect::<Vec<_>>();
368        query.extend_from_slice(request_options.query());
369
370        Arc::new(RequestParts {
371            request_client: self.request_client.clone(),
372            method,
373            url,
374            headers,
375            query,
376        })
377    }
378
379    fn build_request_factory(
380        &self,
381        method: reqwest::Method,
382        path: &str,
383        request_options: &RequestOptions,
384    ) -> HttpRequestFactory {
385        let request_parts = self.build_request_parts(method, path, request_options);
386
387        HttpRequestFactory::new(move || {
388            let request_parts = request_parts.clone();
389
390            async move {
391                let request = request_parts.build_request_builder().build()?;
392                Ok(request)
393            }
394        })
395    }
396
397    fn build_request_factory_with_json<I>(
398        &self,
399        method: reqwest::Method,
400        path: &str,
401        request: I,
402        request_options: &RequestOptions,
403    ) -> Result<HttpRequestFactory, OpenAIError>
404    where
405        I: Serialize,
406    {
407        // JSON bodies are materialized once so the base BYOT path can keep
408        // accepting borrowed inputs. Middleware-enabled BYOT still adds owned
409        // replay bounds in the macro, but the core client does not force those
410        // bounds onto non-middleware users.
411        let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
412            OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
413        })?);
414        let request_parts = self.build_request_parts(method, path, request_options);
415
416        Ok(HttpRequestFactory::new(move || {
417            let request_parts = request_parts.clone();
418            let request = request.clone();
419
420            async move {
421                let request_builder = request_parts
422                    .build_request_builder()
423                    .header(reqwest::header::CONTENT_TYPE, "application/json")
424                    .body(request.clone());
425
426                Ok(request_builder.build()?)
427            }
428        }))
429    }
430
    /// Build a factory for a multipart-form request (native targets).
    ///
    /// Always returns `Ok` today; the `Result` mirrors the JSON factory's
    /// signature. Form conversion errors surface when the factory is invoked.
    #[cfg(not(target_family = "wasm"))]
    fn build_request_factory_with_form<F>(
        &self,
        method: reqwest::Method,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<HttpRequestFactory, OpenAIError>
    where
        F: Clone + Send + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        // Multipart is the reason the factory exists.
        //
        // `Mutex` is only here to make the captured state `Sync`
        let form = Arc::new(Mutex::new(form));
        let request_parts = self.build_request_parts(method, path, request_options);

        Ok(HttpRequestFactory::new(move || {
            let request_parts = request_parts.clone();
            let form = form.clone();

            async move {
                // Clone the source value out under the lock, then convert it
                // into a fresh `Form` for this invocation.
                let form = form
                    .lock()
                    .expect("multipart request factory mutex poisoned")
                    .clone();
                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
                let request_builder = request_parts.build_request_builder().multipart(form);

                Ok(request_builder.build()?)
            }
        }))
    }

    /// Build a factory for a multipart-form request (wasm targets).
    ///
    /// wasm counterpart of the method above: no `Send` bound and no `Mutex`
    /// wrapper, since wasm execution is single-threaded here.
    #[cfg(target_family = "wasm")]
    fn build_request_factory_with_form<F>(
        &self,
        method: reqwest::Method,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<HttpRequestFactory, OpenAIError>
    where
        F: Clone + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        let request_parts = self.build_request_parts(method, path, request_options);

        Ok(HttpRequestFactory::new(move || {
            let request_parts = request_parts.clone();
            let form = form.clone();

            async move {
                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
                let request_builder = request_parts.build_request_builder().multipart(form);

                Ok(request_builder.build()?)
            }
        }))
    }
492
493    /// Make a GET request to {path} and deserialize the response body
494    #[allow(unused)]
495    pub(crate) async fn get<O>(
496        &self,
497        path: &str,
498        request_options: &RequestOptions,
499    ) -> Result<O, OpenAIError>
500    where
501        O: DeserializeOwned,
502    {
503        let request_factory =
504            self.build_request_factory(reqwest::Method::GET, path, request_options);
505        self.execute(request_factory).await
506    }
507
508    /// Make a DELETE request to {path} and deserialize the response body
509    #[allow(unused)]
510    pub(crate) async fn delete<O>(
511        &self,
512        path: &str,
513        request_options: &RequestOptions,
514    ) -> Result<O, OpenAIError>
515    where
516        O: DeserializeOwned,
517    {
518        let request_factory =
519            self.build_request_factory(reqwest::Method::DELETE, path, request_options);
520        self.execute(request_factory).await
521    }
522
523    /// Make a GET request to {path} and return the response body
524    #[allow(unused)]
525    pub(crate) async fn get_raw(
526        &self,
527        path: &str,
528        request_options: &RequestOptions,
529    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
530        let request_factory =
531            self.build_request_factory(reqwest::Method::GET, path, request_options);
532        self.execute_raw(request_factory).await
533    }
534
535    /// Make a POST request to {path} and return the response body
536    #[allow(unused)]
537    pub(crate) async fn post_raw<I>(
538        &self,
539        path: &str,
540        request: I,
541        request_options: &RequestOptions,
542    ) -> Result<(Bytes, HeaderMap), OpenAIError>
543    where
544        I: Serialize,
545    {
546        let request_factory = self.build_request_factory_with_json(
547            reqwest::Method::POST,
548            path,
549            request,
550            request_options,
551        )?;
552        self.execute_raw(request_factory).await
553    }
554
555    /// Make a POST request to {path} and deserialize the response body
556    #[allow(unused)]
557    pub(crate) async fn post<I, O>(
558        &self,
559        path: &str,
560        request: I,
561        request_options: &RequestOptions,
562    ) -> Result<O, OpenAIError>
563    where
564        I: Serialize,
565        O: DeserializeOwned,
566    {
567        let request_factory = self.build_request_factory_with_json(
568            reqwest::Method::POST,
569            path,
570            request,
571            request_options,
572        )?;
573        self.execute(request_factory).await
574    }
575
    /// POST a form at {path} and return the response body
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn post_form_raw<F>(
        &self,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<(Bytes, HeaderMap), OpenAIError>
    where
        F: Clone + Send + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        let request_factory = self.build_request_factory_with_form(
            reqwest::Method::POST,
            path,
            form,
            request_options,
        )?;
        self.execute_raw(request_factory).await
    }

    /// POST a form at {path} and return the response body
    ///
    /// wasm counterpart: no `Send` bound on the form source.
    #[allow(unused)]
    #[cfg(target_family = "wasm")]
    pub(crate) async fn post_form_raw<F>(
        &self,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<(Bytes, HeaderMap), OpenAIError>
    where
        F: Clone + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        let request_factory = self.build_request_factory_with_form(
            reqwest::Method::POST,
            path,
            form,
            request_options,
        )?;
        self.execute_raw(request_factory).await
    }

    /// POST a form at {path} and deserialize the response body
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn post_form<O, F>(
        &self,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        F: Clone + Send + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        let request_factory = self.build_request_factory_with_form(
            reqwest::Method::POST,
            path,
            form,
            request_options,
        )?;
        self.execute(request_factory).await
    }

    /// POST a form at {path} and deserialize the response body
    ///
    /// wasm counterpart: no `Send` bound on the form source.
    #[allow(unused)]
    #[cfg(target_family = "wasm")]
    pub(crate) async fn post_form<O, F>(
        &self,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        F: Clone + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
    {
        let request_factory = self.build_request_factory_with_form(
            reqwest::Method::POST,
            path,
            form,
            request_options,
        )?;
        self.execute(request_factory).await
    }

    /// POST a form at {path} and stream the SSE response, deserializing each
    /// event's data as `O`.
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn post_form_stream<O, F>(
        &self,
        path: &str,
        form: F,
        request_options: &RequestOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        F: Clone + Send + 'static,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let request_factory = self.build_request_factory_with_form(
            reqwest::Method::POST,
            path,
            form,
            request_options,
        )?;

        self.execute_stream(request_factory).await
    }
688
689    async fn execute_raw(
690        &self,
691        request_factory: HttpRequestFactory,
692    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
693        let response = self.execute_response(request_factory).await?;
694        read_response(response).await
695    }
696
697    async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
698    where
699        O: DeserializeOwned,
700    {
701        let (bytes, _headers) = self.execute_raw(request_factory).await?;
702
703        let response: O = serde_json::from_slice(bytes.as_ref())
704            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
705
706        Ok(response)
707    }
708
709    async fn execute_response(
710        &self,
711        request_factory: HttpRequestFactory,
712    ) -> Result<Response, OpenAIError> {
713        self.executor.execute(request_factory).await
714    }
715
    /// Execute the request and expose the response as an SSE stream of `O`.
    #[cfg(not(target_family = "wasm"))]
    async fn execute_stream<O>(
        &self,
        request_factory: HttpRequestFactory,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let response = self.execute_response(request_factory).await?;
        Ok(stream(response).await)
    }

    /// Like `execute_stream`, but applies `event_mapper` to each raw SSE
    /// event instead of JSON-deserializing the event data directly.
    #[cfg(not(target_family = "wasm"))]
    async fn execute_stream_mapped_raw_events<O>(
        &self,
        request_factory: HttpRequestFactory,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let response = self.execute_response(request_factory).await?;
        Ok(stream_mapped_raw_events(response, event_mapper).await)
    }
740
    /// Make HTTP POST request to receive SSE
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
        request_options: &RequestOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let request_factory = self.build_request_factory_with_json(
            reqwest::Method::POST,
            path,
            request,
            request_options,
        )?;
        // Stream setup is still request/response first. We only create the SSE
        // stream after the HTTP layer has returned a response object.
        self.execute_stream(request_factory).await
    }

    /// Make HTTP POST request to receive SSE, mapping each raw event through
    /// `event_mapper` instead of deserializing its data directly.
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
        &self,
        path: &str,
        request: I,
        request_options: &RequestOptions,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let request_factory = self.build_request_factory_with_json(
            reqwest::Method::POST,
            path,
            request,
            request_options,
        )?;
        self.execute_stream_mapped_raw_events(request_factory, event_mapper)
            .await
    }

    /// Make HTTP GET request to receive SSE
    #[allow(unused)]
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn get_stream<O>(
        &self,
        path: &str,
        request_options: &RequestOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let request_factory =
            self.build_request_factory(reqwest::Method::GET, path, request_options);
        self.execute_stream(request_factory).await
    }
803}
804
805async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
806    let status = response.status();
807    let headers = response.headers().clone();
808    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
809
810    if status.is_server_error() {
811        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
812        let message: String = String::from_utf8_lossy(&bytes).into_owned();
813        tracing::warn!("Server error: {status} - {message}");
814        return Err(OpenAIError::ApiError(ApiError {
815            message,
816            r#type: None,
817            param: None,
818            code: None,
819        }));
820    }
821
822    // Deserialize response body from either error object or actual response object
823    if !status.is_success() {
824        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
825            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
826
827        return Err(OpenAIError::ApiError(wrapped_error.error));
828    }
829
830    Ok((bytes, headers))
831}
832
/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
///
/// Each SSE event's `data` payload is JSON-deserialized into `O`. The
/// parsing loop runs on a spawned task and feeds the returned stream through
/// an unbounded channel.
#[cfg(not(target_family = "wasm"))]
pub(crate) async fn stream<O>(
    response: Response,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        // Non-2xx responses become a single error item on the stream.
        if !response.status().is_success() {
            if let Err(e) = read_response(response).await {
                let _ = tx.send(Err(e));
            }
            return;
        }
        let byte_stream = response
            .bytes_stream()
            .map(|r| r.map_err(std::io::Error::other));
        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));

        while let Some(ev) = event_stream.next().await {
            let event = match ev {
                Ok(e) => e,
                Err(e) => {
                    // SSE parse errors are reported once, then the stream ends.
                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
                        StreamError::EventStream(e.to_string()),
                    ))));
                    break;
                }
            };
            // "[DONE]" terminates the stream and is NOT forwarded to the caller.
            if event.data == "[DONE]" {
                break;
            }
            // Keepalive events carry no payload; skip them.
            if event.event == "keepalive" {
                continue;
            }

            let response = serde_json::from_str::<O>(&event.data)
                .map_err(|e| map_deserialization_error(e, event.data.as_bytes()));

            // A send failure means the receiver was dropped; stop parsing.
            if tx.send(response).is_err() {
                break;
            }
        }
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
884
/// Like `stream`, but applies `event_mapper` to each raw SSE event.
///
/// Unlike `stream`, the terminating `[DONE]` event IS passed through the
/// mapper (and sent to the caller) before the loop ends, so mappers can
/// observe stream termination.
#[cfg(not(target_family = "wasm"))]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        // Non-2xx responses become a single error item on the stream.
        if !response.status().is_success() {
            if let Err(e) = read_response(response).await {
                let _ = tx.send(Err(e));
            }
            return;
        }
        let byte_stream = response
            .bytes_stream()
            .map(|r| r.map_err(std::io::Error::other));
        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));

        while let Some(ev) = event_stream.next().await {
            let event = match ev {
                Ok(e) => e,
                Err(e) => {
                    // SSE parse errors are reported once, then the stream ends.
                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
                        StreamError::EventStream(e.to_string()),
                    ))));
                    break;
                }
            };
            // Remember termination, but still map and forward this event below.
            let done = event.data == "[DONE]";

            // Keepalive events carry no payload; skip them.
            if event.event == "keepalive" {
                continue;
            }

            let response = event_mapper(event);

            // A send failure means the receiver was dropped; stop parsing.
            if tx.send(response).is_err() {
                break;
            }

            if done {
                break;
            }
        }
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
937
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// Config shared by every test: a dummy base URL plus a fake API key.
    /// Hoisted here because all four tests previously repeated this exact
    /// builder chain inline.
    fn test_config() -> OpenAIConfig {
        OpenAIConfig::new()
            .with_api_base("http://example.test")
            .with_api_key("test-key")
    }

    /// A unary GET must be dispatched through the installed tower service
    /// exactly once, and its JSON body deserialized for the caller.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// An SSE request must also open through the middleware service; the
    /// resulting stream yields each `data:` payload until `[DONE]`.
    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "text/event-stream")
                                .body(reqwest::Body::from(
                                    "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// With `SimpleRetryPolicy` installed, a 429 response is retried and the
    /// second (successful) attempt's body is returned — exactly two requests.
    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited; second attempt: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                                ))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input whose `AsyncTryFrom` conversion counts how many times
    /// the form is (re)built — used to prove retries rebuild the body.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    /// Multipart forms are consumed when sent, so every retry attempt must
    /// rebuild the form from scratch: two requests ⇒ two conversions.
    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited; second attempt: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(r#"{"ok":true}"#))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: conversion_count.clone(),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}