Skip to main content

async_openai/
client.rs

1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14    config::{Config, OpenAIConfig},
15    error::{map_deserialization_error, ApiError, ApiErrorResponse, OpenAIError, WrappedError},
16    executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21struct RequestParts {
22    request_client: reqwest::Client,
23    method: reqwest::Method,
24    url: String,
25    headers: HeaderMap,
26    query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30    fn build_request_builder(&self) -> reqwest::RequestBuilder {
31        self.request_client
32            .request(self.method.clone(), self.url.clone())
33            .query(&self.query)
34            .headers(self.headers.clone())
35    }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88/// Client is a container for config and HTTP execution
89/// used to make API calls.
90pub struct Client<C: Config> {
91    request_client: reqwest::Client,
92    executor: SharedExecutor,
93    config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98    C: Config + std::fmt::Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Client")
102            .field("request_client", &self.request_client)
103            .field("config", &self.config)
104            .finish()
105    }
106}
107
108impl<C: Config> Default for Client<C>
109where
110    C: Default,
111{
112    fn default() -> Self {
113        let request_client = reqwest::Client::new();
114        Self {
115            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116            request_client,
117            config: C::default(),
118        }
119    }
120}
121
122impl Client<OpenAIConfig> {
123    /// Client with default [OpenAIConfig]
124    pub fn new() -> Self {
125        Self::default()
126    }
127}
128
129impl<C: Config> Client<C> {
130    /// Create client with a custom HTTP client and config.
131    pub fn build(http_client: reqwest::Client, config: C) -> Self {
132        Self {
133            executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134            request_client: http_client,
135            config,
136        }
137    }
138
139    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
140    pub fn with_config(config: C) -> Self {
141        let request_client = reqwest::Client::new();
142        Self {
143            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144            request_client,
145            config,
146        }
147    }
148
149    /// Provide your own [client] to make HTTP requests with.
150    ///
151    /// [client]: reqwest::Client
152    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153        self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154        self.request_client = http_client;
155        self
156    }
157
158    /// Provide your own tower-compatible service to execute HTTP requests.
159    #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160    pub fn with_http_service<S>(mut self, service: S) -> Self
161    where
162        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163        S::Future: Send + 'static,
164        S::Error: Into<OpenAIError> + Send + Sync + 'static,
165    {
166        // This is the public middleware escape hatch. We erase the concrete
167        // tower stack here so the rest of the client does not become generic
168        // over the service type, which would otherwise leak through every API
169        // group and make the crate much harder to use.
170        self.executor = Arc::new(TowerExecutor::new(service));
171        self
172    }
173
174    /// Provide your own tower-compatible service to execute HTTP requests.
175    #[cfg(all(feature = "middleware", target_family = "wasm"))]
176    pub fn with_http_service<S>(mut self, service: S) -> Self
177    where
178        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179        S::Future: 'static,
180        S::Error: Into<OpenAIError> + 'static,
181    {
182        // wasm futures produced by reqwest are not `Send`, so the wasm version
183        // intentionally avoids native thread-safety bounds. Users are still
184        // responsible for choosing tower layers that work in their wasm
185        // runtime.
186        self.executor = Arc::new(TowerExecutor::new(service));
187        self
188    }
189
190    // API groups
191
192    /// To call [Models] group related APIs using this client.
193    #[cfg(feature = "model")]
194    pub fn models(&self) -> Models<'_, C> {
195        Models::new(self)
196    }
197
198    /// To call [Completions] group related APIs using this client.
199    #[cfg(feature = "completions")]
200    pub fn completions(&self) -> Completions<'_, C> {
201        Completions::new(self)
202    }
203
204    /// To call [Chat] group related APIs using this client.
205    #[cfg(feature = "chat-completion")]
206    pub fn chat(&self) -> Chat<'_, C> {
207        Chat::new(self)
208    }
209
210    /// To call [Images] group related APIs using this client.
211    #[cfg(feature = "image")]
212    pub fn images(&self) -> Images<'_, C> {
213        Images::new(self)
214    }
215
216    /// To call [Moderations] group related APIs using this client.
217    #[cfg(feature = "moderation")]
218    pub fn moderations(&self) -> Moderations<'_, C> {
219        Moderations::new(self)
220    }
221
222    /// To call [Files] group related APIs using this client.
223    #[cfg(feature = "file")]
224    pub fn files(&self) -> Files<'_, C> {
225        Files::new(self)
226    }
227
228    /// To call [Uploads] group related APIs using this client.
229    #[cfg(feature = "upload")]
230    pub fn uploads(&self) -> Uploads<'_, C> {
231        Uploads::new(self)
232    }
233
234    /// To call [FineTuning] group related APIs using this client.
235    #[cfg(feature = "finetuning")]
236    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237        FineTuning::new(self)
238    }
239
240    /// To call [Embeddings] group related APIs using this client.
241    #[cfg(feature = "embedding")]
242    pub fn embeddings(&self) -> Embeddings<'_, C> {
243        Embeddings::new(self)
244    }
245
246    /// To call [Audio] group related APIs using this client.
247    #[cfg(feature = "audio")]
248    pub fn audio(&self) -> Audio<'_, C> {
249        Audio::new(self)
250    }
251
252    /// To call [Videos] group related APIs using this client.
253    #[cfg(feature = "video")]
254    pub fn videos(&self) -> Videos<'_, C> {
255        Videos::new(self)
256    }
257
258    /// To call [Assistants] group related APIs using this client.
259    #[cfg(feature = "assistant")]
260    #[deprecated(
261        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262    )]
263    #[allow(deprecated)]
264    pub fn assistants(&self) -> Assistants<'_, C> {
265        Assistants::new(self)
266    }
267
268    /// To call [Threads] group related APIs using this client.
269    #[cfg(feature = "assistant")]
270    #[deprecated(
271        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272    )]
273    #[allow(deprecated)]
274    pub fn threads(&self) -> Threads<'_, C> {
275        Threads::new(self)
276    }
277
278    /// To call [VectorStores] group related APIs using this client.
279    #[cfg(feature = "vectorstore")]
280    pub fn vector_stores(&self) -> VectorStores<'_, C> {
281        VectorStores::new(self)
282    }
283
284    /// To call [Batches] group related APIs using this client.
285    #[cfg(feature = "batch")]
286    pub fn batches(&self) -> Batches<'_, C> {
287        Batches::new(self)
288    }
289
290    /// To call [Admin] group related APIs using this client.
291    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
292    #[cfg(feature = "administration")]
293    pub fn admin(&self) -> Admin<'_, C> {
294        Admin::new(self)
295    }
296
297    /// To call [Responses] group related APIs using this client.
298    #[cfg(feature = "responses")]
299    pub fn responses(&self) -> Responses<'_, C> {
300        Responses::new(self)
301    }
302
303    /// To call [Conversations] group related APIs using this client.
304    #[cfg(feature = "responses")]
305    pub fn conversations(&self) -> Conversations<'_, C> {
306        Conversations::new(self)
307    }
308
309    /// To call [Containers] group related APIs using this client.
310    #[cfg(feature = "container")]
311    pub fn containers(&self) -> Containers<'_, C> {
312        Containers::new(self)
313    }
314
315    /// To call [Skills] group related APIs using this client.
316    #[cfg(feature = "skill")]
317    pub fn skills(&self) -> Skills<'_, C> {
318        Skills::new(self)
319    }
320
321    /// To call [Evals] group related APIs using this client.
322    #[cfg(feature = "evals")]
323    pub fn evals(&self) -> Evals<'_, C> {
324        Evals::new(self)
325    }
326
327    #[cfg(feature = "chatkit")]
328    pub fn chatkit(&self) -> Chatkit<'_, C> {
329        Chatkit::new(self)
330    }
331
332    /// To call [Realtime] group related APIs using this client.
333    #[cfg(feature = "realtime")]
334    pub fn realtime(&self) -> Realtime<'_, C> {
335        Realtime::new(self)
336    }
337
338    pub fn config(&self) -> &C {
339        &self.config
340    }
341
342    fn build_request_parts(
343        &self,
344        method: reqwest::Method,
345        path: &str,
346        request_options: &RequestOptions,
347    ) -> Arc<RequestParts> {
348        let url = if let Some(path) = request_options.path() {
349            self.config.url(path.as_str())
350        } else {
351            self.config.url(path)
352        };
353        let mut headers = self.config.headers();
354        if let Some(request_headers) = request_options.headers() {
355            headers.extend(request_headers.clone());
356        }
357
358        let mut query = self
359            .config
360            .query()
361            .into_iter()
362            .map(|(key, value)| (key.to_string(), value.to_string()))
363            .collect::<Vec<_>>();
364        query.extend_from_slice(request_options.query());
365
366        Arc::new(RequestParts {
367            request_client: self.request_client.clone(),
368            method,
369            url,
370            headers,
371            query,
372        })
373    }
374
375    fn build_request_factory(
376        &self,
377        method: reqwest::Method,
378        path: &str,
379        request_options: &RequestOptions,
380    ) -> HttpRequestFactory {
381        let request_parts = self.build_request_parts(method, path, request_options);
382
383        HttpRequestFactory::new(move || {
384            let request_parts = request_parts.clone();
385
386            async move {
387                let request = request_parts.build_request_builder().build()?;
388                Ok(request)
389            }
390        })
391    }
392
393    fn build_request_factory_with_json<I>(
394        &self,
395        method: reqwest::Method,
396        path: &str,
397        request: I,
398        request_options: &RequestOptions,
399    ) -> Result<HttpRequestFactory, OpenAIError>
400    where
401        I: Serialize,
402    {
403        // JSON bodies are materialized once so the base BYOT path can keep
404        // accepting borrowed inputs.
405        let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406            OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407        })?);
408        let request_parts = self.build_request_parts(method, path, request_options);
409
410        Ok(HttpRequestFactory::new(move || {
411            let request_parts = request_parts.clone();
412            let request = request.clone();
413
414            async move {
415                let request_builder = request_parts
416                    .build_request_builder()
417                    .header(reqwest::header::CONTENT_TYPE, "application/json")
418                    .body(request.clone());
419
420                Ok(request_builder.build()?)
421            }
422        }))
423    }
424
425    fn build_request_factory_with_form<F>(
426        &self,
427        method: reqwest::Method,
428        path: &str,
429        form: F,
430        request_options: &RequestOptions,
431    ) -> Result<HttpRequestFactory, OpenAIError>
432    where
433        F: Clone + crate::traits::MaybeSend + 'static,
434        Form: AsyncTryFrom<F, Error = OpenAIError>,
435    {
436        // Multipart is the reason the factory exists.
437        //
438        // `Mutex` is only here to make the captured state `Sync` on native targets.
439        #[cfg(not(target_family = "wasm"))]
440        let form = Arc::new(Mutex::new(form));
441        let request_parts = self.build_request_parts(method, path, request_options);
442
443        Ok(HttpRequestFactory::new(move || {
444            let request_parts = request_parts.clone();
445            let form = form.clone();
446
447            async move {
448                #[cfg(not(target_family = "wasm"))]
449                let form = form
450                    .lock()
451                    .expect("multipart request factory mutex poisoned")
452                    .clone();
453                #[cfg(target_family = "wasm")]
454                let form = form.clone();
455                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456                let request_builder = request_parts.build_request_builder().multipart(form);
457
458                Ok(request_builder.build()?)
459            }
460        }))
461    }
462
463    /// Make a GET request to {path} and deserialize the response body
464    #[allow(unused)]
465    pub(crate) async fn get<O>(
466        &self,
467        path: &str,
468        request_options: &RequestOptions,
469    ) -> Result<O, OpenAIError>
470    where
471        O: DeserializeOwned,
472    {
473        let request_factory =
474            self.build_request_factory(reqwest::Method::GET, path, request_options);
475        self.execute(request_factory).await
476    }
477
478    /// Make a DELETE request to {path} and deserialize the response body
479    #[allow(unused)]
480    pub(crate) async fn delete<O>(
481        &self,
482        path: &str,
483        request_options: &RequestOptions,
484    ) -> Result<O, OpenAIError>
485    where
486        O: DeserializeOwned,
487    {
488        let request_factory =
489            self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490        self.execute(request_factory).await
491    }
492
493    /// Make a GET request to {path} and return the response body
494    #[allow(unused)]
495    pub(crate) async fn get_raw(
496        &self,
497        path: &str,
498        request_options: &RequestOptions,
499    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500        let request_factory =
501            self.build_request_factory(reqwest::Method::GET, path, request_options);
502        self.execute_raw(request_factory).await
503    }
504
505    /// Make a POST request to {path} and return the response body
506    #[allow(unused)]
507    pub(crate) async fn post_raw<I>(
508        &self,
509        path: &str,
510        request: I,
511        request_options: &RequestOptions,
512    ) -> Result<(Bytes, HeaderMap), OpenAIError>
513    where
514        I: Serialize,
515    {
516        let request_factory = self.build_request_factory_with_json(
517            reqwest::Method::POST,
518            path,
519            request,
520            request_options,
521        )?;
522        self.execute_raw(request_factory).await
523    }
524
525    /// Make a POST request to {path} and deserialize the response body
526    #[allow(unused)]
527    pub(crate) async fn post<I, O>(
528        &self,
529        path: &str,
530        request: I,
531        request_options: &RequestOptions,
532    ) -> Result<O, OpenAIError>
533    where
534        I: Serialize,
535        O: DeserializeOwned,
536    {
537        let request_factory = self.build_request_factory_with_json(
538            reqwest::Method::POST,
539            path,
540            request,
541            request_options,
542        )?;
543        self.execute(request_factory).await
544    }
545
546    /// POST a form at {path} and return the response body
547    #[allow(unused)]
548    pub(crate) async fn post_form_raw<F>(
549        &self,
550        path: &str,
551        form: F,
552        request_options: &RequestOptions,
553    ) -> Result<(Bytes, HeaderMap), OpenAIError>
554    where
555        F: Clone + crate::traits::MaybeSend + 'static,
556        Form: AsyncTryFrom<F, Error = OpenAIError>,
557    {
558        let request_factory = self.build_request_factory_with_form(
559            reqwest::Method::POST,
560            path,
561            form,
562            request_options,
563        )?;
564        self.execute_raw(request_factory).await
565    }
566
567    /// POST a form at {path} and deserialize the response body
568    #[allow(unused)]
569    pub(crate) async fn post_form<O, F>(
570        &self,
571        path: &str,
572        form: F,
573        request_options: &RequestOptions,
574    ) -> Result<O, OpenAIError>
575    where
576        O: DeserializeOwned,
577        F: Clone + crate::traits::MaybeSend + 'static,
578        Form: AsyncTryFrom<F, Error = OpenAIError>,
579    {
580        let request_factory = self.build_request_factory_with_form(
581            reqwest::Method::POST,
582            path,
583            form,
584            request_options,
585        )?;
586        self.execute(request_factory).await
587    }
588
589    #[allow(unused)]
590    pub(crate) async fn post_form_stream<O, F>(
591        &self,
592        path: &str,
593        form: F,
594        request_options: &RequestOptions,
595    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596    where
597        F: Clone + crate::traits::MaybeSend + 'static,
598        Form: AsyncTryFrom<F, Error = OpenAIError>,
599        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600    {
601        let request_factory = self.build_request_factory_with_form(
602            reqwest::Method::POST,
603            path,
604            form,
605            request_options,
606        )?;
607
608        self.execute_stream(request_factory).await
609    }
610
611    async fn execute_raw(
612        &self,
613        request_factory: HttpRequestFactory,
614    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615        let response = self.execute_response(request_factory).await?;
616        read_response(response).await
617    }
618
619    async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620    where
621        O: DeserializeOwned,
622    {
623        let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625        let response: O = serde_json::from_slice(bytes.as_ref())
626            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628        Ok(response)
629    }
630
631    async fn execute_response(
632        &self,
633        request_factory: HttpRequestFactory,
634    ) -> Result<Response, OpenAIError> {
635        let response = self.executor.execute(request_factory).await?;
636        if !response.status().is_success() {
637            return Err(read_error_response(response).await);
638        }
639        Ok(response)
640    }
641
642    async fn execute_stream<O>(
643        &self,
644        request_factory: HttpRequestFactory,
645    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
646    where
647        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
648    {
649        let response = self.execute_response(request_factory).await?;
650        Ok(stream(response).await)
651    }
652
653    async fn execute_stream_mapped_raw_events<O>(
654        &self,
655        request_factory: HttpRequestFactory,
656        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
657            + crate::traits::MaybeSend
658            + 'static,
659    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
660    where
661        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
662    {
663        let response = self.execute_response(request_factory).await?;
664        Ok(stream_mapped_raw_events(response, event_mapper).await)
665    }
666
667    /// Make HTTP POST request to receive SSE
668    #[allow(unused)]
669    pub(crate) async fn post_stream<I, O>(
670        &self,
671        path: &str,
672        request: I,
673        request_options: &RequestOptions,
674    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
675    where
676        I: Serialize,
677        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
678    {
679        let request_factory = self.build_request_factory_with_json(
680            reqwest::Method::POST,
681            path,
682            request,
683            request_options,
684        )?;
685        // Stream setup is still request/response first. We only create the SSE
686        // stream after the HTTP layer has returned a response object.
687        self.execute_stream(request_factory).await
688    }
689
690    #[allow(unused)]
691    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
692        &self,
693        path: &str,
694        request: I,
695        request_options: &RequestOptions,
696        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
697            + crate::traits::MaybeSend
698            + 'static,
699    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
700    where
701        I: Serialize,
702        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
703    {
704        let request_factory = self.build_request_factory_with_json(
705            reqwest::Method::POST,
706            path,
707            request,
708            request_options,
709        )?;
710        self.execute_stream_mapped_raw_events(request_factory, event_mapper)
711            .await
712    }
713
714    /// Make HTTP GET request to receive SSE
715    #[allow(unused)]
716    pub(crate) async fn get_stream<O>(
717        &self,
718        path: &str,
719        request_options: &RequestOptions,
720    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
721    where
722        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
723    {
724        let request_factory =
725            self.build_request_factory(reqwest::Method::GET, path, request_options);
726        self.execute_stream(request_factory).await
727    }
728}
729
730async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
731    let headers = response.headers().clone();
732    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
733    Ok((bytes, headers))
734}
735
736async fn read_error_response(response: Response) -> OpenAIError {
737    let status = response.status();
738    let bytes = match response.bytes().await {
739        Ok(b) => b,
740        Err(e) => return OpenAIError::Reqwest(e),
741    };
742
743    if status.is_server_error() {
744        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
745        let message: String = String::from_utf8_lossy(&bytes).into_owned();
746        tracing::warn!("Server error: {status} - {message}");
747        return OpenAIError::ApiError(ApiErrorResponse {
748            status_code: status,
749            api_error: ApiError {
750                message,
751                r#type: None,
752                param: None,
753                code: None,
754            },
755        });
756    }
757
758    // Deserialize response body from the error object
759    match serde_json::from_slice::<WrappedError>(bytes.as_ref()) {
760        Ok(wrapped) => OpenAIError::ApiError(ApiErrorResponse {
761            status_code: status,
762            api_error: wrapped.error,
763        }),
764        Err(e) => map_deserialization_error(e, bytes.as_ref()),
765    }
766}
767
768/// Request which responds with SSE.
769/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
770pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
771where
772    O: DeserializeOwned + crate::traits::MaybeSend + 'static,
773{
774    stream_mapped_raw_events(response, |event| {
775        serde_json::from_str::<O>(&event.data)
776            .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
777    })
778    .await
779}
780
/// wasm implementation: turns an SSE response into a stream of mapped events.
///
/// The stream ends at `[DONE]` or end of body, and skips `keepalive` events.
/// After an event-stream error is yielded it also ends, matching the native
/// implementation, instead of continuing to poll a failed stream.
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // The unfold state becomes `None` once an error has been yielded, which
    // terminates the stream on the following poll.
    Box::pin(futures::stream::unfold(
        Some((event_stream, event_mapper)),
        |state| async move {
            let Some((mut event_stream, event_mapper)) = state else {
                return None;
            };
            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            None,
                        ));
                    }
                    None => return None,
                };

                if event.data == "[DONE]" {
                    return None;
                }

                if event.event == "keepalive" {
                    continue;
                }

                let response = event_mapper(event);
                return Some((response, Some((event_stream, event_mapper))));
            }
        },
    ))
}
825
826#[cfg(not(target_family = "wasm"))]
827pub(crate) async fn stream_mapped_raw_events<O>(
828    response: Response,
829    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
830) -> crate::types::stream::StreamResponse<O>
831where
832    O: DeserializeOwned + std::marker::Send + 'static,
833{
834    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
835
836    tokio::spawn(async move {
837        let byte_stream = response
838            .bytes_stream()
839            .map(|r| r.map_err(std::io::Error::other));
840        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
841
842        while let Some(ev) = event_stream.next().await {
843            let event = match ev {
844                Ok(e) => e,
845                Err(e) => {
846                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
847                        StreamError::EventStream(e.to_string()),
848                    ))));
849                    break;
850                }
851            };
852            if event.data == "[DONE]" {
853                break;
854            }
855
856            if event.event == "keepalive" {
857                continue;
858            }
859
860            let response = event_mapper(event);
861
862            if tx.send(response).is_err() {
863                break;
864            }
865        }
866    });
867
868    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
869}
870
871#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
872mod tests {
873    use std::sync::{
874        atomic::{AtomicUsize, Ordering},
875        Arc,
876    };
877
878    use futures::StreamExt;
879    use http::Response as HttpResponse;
880    use serde_json::json;
881    use tower::{service_fn, ServiceBuilder};
882
883    use super::Client;
884    use crate::{
885        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
886        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
887    };
888
889    #[tokio::test]
890    async fn unary_requests_dispatch_through_middleware_service() {
891        let request_count = Arc::new(AtomicUsize::new(0));
892        let service = {
893            let request_count = request_count.clone();
894            ServiceBuilder::new()
895                .concurrency_limit(1)
896                .service(service_fn(move |factory: HttpRequestFactory| {
897                    let request_count = request_count.clone();
898                    async move {
899                        let request = factory.build().await?;
900                        assert_eq!(request.url().path(), "/models");
901                        request_count.fetch_add(1, Ordering::SeqCst);
902                        Ok::<reqwest::Response, OpenAIError>(
903                            HttpResponse::builder()
904                                .status(200)
905                                .header("content-type", "application/json")
906                                .body(reqwest::Body::from(
907                                    "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
908                                ))
909                                .unwrap()
910                                .into(),
911                        )
912                    }
913                }))
914        };
915
916        let client = Client::with_config(
917            OpenAIConfig::new()
918                .with_api_base("http://example.test")
919                .with_api_key("test-key"),
920        )
921        .with_http_service(service);
922
923        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();
924
925        assert_eq!(value["object"], "list");
926        assert_eq!(request_count.load(Ordering::SeqCst), 1);
927    }
928
929    #[tokio::test]
930    async fn stream_requests_open_through_middleware_service() {
931        let request_count = Arc::new(AtomicUsize::new(0));
932        let service = {
933            let request_count = request_count.clone();
934            ServiceBuilder::new()
935                .concurrency_limit(1)
936                .service(service_fn(move |factory: HttpRequestFactory| {
937                    let request_count = request_count.clone();
938                    async move {
939                        let request = factory.build().await?;
940                        assert_eq!(request.url().path(), "/responses");
941                        request_count.fetch_add(1, Ordering::SeqCst);
942                        Ok::<reqwest::Response, OpenAIError>(
943                            HttpResponse::builder()
944                                .status(200)
945                                .header("content-type", "text/event-stream")
946                                .body(reqwest::Body::from(
947                                    "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
948                                ))
949                                .unwrap()
950                                .into(),
951                        )
952                    }
953                }))
954        };
955
956        let client = Client::with_config(
957            OpenAIConfig::new()
958                .with_api_base("http://example.test")
959                .with_api_key("test-key"),
960        )
961        .with_http_service(service);
962
963        let mut stream = client
964            .post_stream::<_, serde_json::Value>(
965                "/responses",
966                json!({ "stream": true }),
967                &RequestOptions::new(),
968            )
969            .await
970            .unwrap();
971
972        let first = stream.next().await.unwrap().unwrap();
973
974        assert_eq!(first, json!({ "ok": true }));
975        assert_eq!(request_count.load(Ordering::SeqCst), 1);
976    }
977
978    #[tokio::test]
979    async fn middleware_retry_policy_retries_429_responses() {
980        let request_count = Arc::new(AtomicUsize::new(0));
981        let service = {
982            let request_count = request_count.clone();
983            ServiceBuilder::new()
984                .retry(SimpleRetryPolicy::default())
985                .service(service_fn(move |factory: HttpRequestFactory| {
986                    let request_count = request_count.clone();
987                    async move {
988                        let request = factory.build().await?;
989                        assert_eq!(request.url().path(), "/models");
990                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);
991
992                        let response = if attempt == 0 {
993                            HttpResponse::builder()
994                                .status(429)
995                                .header("content-type", "application/json")
996                                .body(reqwest::Body::from(
997                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
998                                ))
999                                .unwrap()
1000                        } else {
1001                            HttpResponse::builder()
1002                                .status(200)
1003                                .header("content-type", "application/json")
1004                                .body(reqwest::Body::from(
1005                                    r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
1006                                ))
1007                                .unwrap()
1008                        };
1009
1010                        Ok::<reqwest::Response, OpenAIError>(response.into())
1011                    }
1012                }))
1013        };
1014
1015        let client = Client::with_config(
1016            OpenAIConfig::new()
1017                .with_api_base("http://example.test")
1018                .with_api_key("test-key"),
1019        )
1020        .with_http_service(service);
1021
1022        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();
1023
1024        assert_eq!(value["data"][0]["id"], "retry-model");
1025        assert_eq!(request_count.load(Ordering::SeqCst), 2);
1026    }
1027
1028    #[derive(Clone)]
1029    struct RetryableMultipartInput {
1030        conversions: Arc<AtomicUsize>,
1031    }
1032
1033    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
1034        type Error = OpenAIError;
1035
1036        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
1037            value.conversions.fetch_add(1, Ordering::SeqCst);
1038            Ok(reqwest::multipart::Form::new().text("field", "value"))
1039        }
1040    }
1041
1042    #[tokio::test]
1043    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
1044        let request_count = Arc::new(AtomicUsize::new(0));
1045        let conversion_count = Arc::new(AtomicUsize::new(0));
1046
1047        let service = {
1048            let request_count = request_count.clone();
1049            ServiceBuilder::new()
1050                .retry(SimpleRetryPolicy::default())
1051                .service(service_fn(move |factory: HttpRequestFactory| {
1052                    let request_count = request_count.clone();
1053                    async move {
1054                        let request = factory.build().await?;
1055                        assert_eq!(request.method(), reqwest::Method::POST);
1056                        assert_eq!(request.url().path(), "/files");
1057                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);
1058
1059                        let response = if attempt == 0 {
1060                            HttpResponse::builder()
1061                                .status(429)
1062                                .header("content-type", "application/json")
1063                                .body(reqwest::Body::from(
1064                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
1065                                ))
1066                                .unwrap()
1067                        } else {
1068                            HttpResponse::builder()
1069                                .status(200)
1070                                .header("content-type", "application/json")
1071                                .body(reqwest::Body::from(r#"{"ok":true}"#))
1072                                .unwrap()
1073                        };
1074
1075                        Ok::<reqwest::Response, OpenAIError>(response.into())
1076                    }
1077                }))
1078        };
1079
1080        let client = Client::with_config(
1081            OpenAIConfig::new()
1082                .with_api_base("http://example.test")
1083                .with_api_key("test-key"),
1084        )
1085        .with_http_service(service);
1086
1087        let value: serde_json::Value = client
1088            .post_form(
1089                "/files",
1090                RetryableMultipartInput {
1091                    conversions: conversion_count.clone(),
1092                },
1093                &RequestOptions::new(),
1094            )
1095            .await
1096            .unwrap();
1097
1098        assert_eq!(value, json!({ "ok": true }));
1099        assert_eq!(request_count.load(Ordering::SeqCst), 2);
1100        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
1101    }
1102}