Skip to main content

async_openai/
client.rs

1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14    config::{Config, OpenAIConfig},
15    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
16    executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21struct RequestParts {
22    request_client: reqwest::Client,
23    method: reqwest::Method,
24    url: String,
25    headers: HeaderMap,
26    query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30    fn build_request_builder(&self) -> reqwest::RequestBuilder {
31        self.request_client
32            .request(self.method.clone(), self.url.clone())
33            .query(&self.query)
34            .headers(self.headers.clone())
35    }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88/// Client is a container for config and HTTP execution
89/// used to make API calls.
90pub struct Client<C: Config> {
91    request_client: reqwest::Client,
92    executor: SharedExecutor,
93    config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98    C: Config + std::fmt::Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Client")
102            .field("request_client", &self.request_client)
103            .field("config", &self.config)
104            .finish()
105    }
106}
107
108impl<C: Config> Default for Client<C>
109where
110    C: Default,
111{
112    fn default() -> Self {
113        let request_client = reqwest::Client::new();
114        Self {
115            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116            request_client,
117            config: C::default(),
118        }
119    }
120}
121
122impl Client<OpenAIConfig> {
123    /// Client with default [OpenAIConfig]
124    pub fn new() -> Self {
125        Self::default()
126    }
127}
128
129impl<C: Config> Client<C> {
130    /// Create client with a custom HTTP client and config.
131    pub fn build(http_client: reqwest::Client, config: C) -> Self {
132        Self {
133            executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134            request_client: http_client,
135            config,
136        }
137    }
138
139    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
140    pub fn with_config(config: C) -> Self {
141        let request_client = reqwest::Client::new();
142        Self {
143            executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144            request_client,
145            config,
146        }
147    }
148
149    /// Provide your own [client] to make HTTP requests with.
150    ///
151    /// [client]: reqwest::Client
152    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153        self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154        self.request_client = http_client;
155        self
156    }
157
158    /// Provide your own tower-compatible service to execute HTTP requests.
159    #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160    pub fn with_http_service<S>(mut self, service: S) -> Self
161    where
162        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163        S::Future: Send + 'static,
164        S::Error: Into<OpenAIError> + Send + Sync + 'static,
165    {
166        // This is the public middleware escape hatch. We erase the concrete
167        // tower stack here so the rest of the client does not become generic
168        // over the service type, which would otherwise leak through every API
169        // group and make the crate much harder to use.
170        self.executor = Arc::new(TowerExecutor::new(service));
171        self
172    }
173
174    /// Provide your own tower-compatible service to execute HTTP requests.
175    #[cfg(all(feature = "middleware", target_family = "wasm"))]
176    pub fn with_http_service<S>(mut self, service: S) -> Self
177    where
178        S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179        S::Future: 'static,
180        S::Error: Into<OpenAIError> + 'static,
181    {
182        // wasm futures produced by reqwest are not `Send`, so the wasm version
183        // intentionally avoids native thread-safety bounds. Users are still
184        // responsible for choosing tower layers that work in their wasm
185        // runtime.
186        self.executor = Arc::new(TowerExecutor::new(service));
187        self
188    }
189
190    // API groups
191
192    /// To call [Models] group related APIs using this client.
193    #[cfg(feature = "model")]
194    pub fn models(&self) -> Models<'_, C> {
195        Models::new(self)
196    }
197
198    /// To call [Completions] group related APIs using this client.
199    #[cfg(feature = "completions")]
200    pub fn completions(&self) -> Completions<'_, C> {
201        Completions::new(self)
202    }
203
204    /// To call [Chat] group related APIs using this client.
205    #[cfg(feature = "chat-completion")]
206    pub fn chat(&self) -> Chat<'_, C> {
207        Chat::new(self)
208    }
209
210    /// To call [Images] group related APIs using this client.
211    #[cfg(feature = "image")]
212    pub fn images(&self) -> Images<'_, C> {
213        Images::new(self)
214    }
215
216    /// To call [Moderations] group related APIs using this client.
217    #[cfg(feature = "moderation")]
218    pub fn moderations(&self) -> Moderations<'_, C> {
219        Moderations::new(self)
220    }
221
222    /// To call [Files] group related APIs using this client.
223    #[cfg(feature = "file")]
224    pub fn files(&self) -> Files<'_, C> {
225        Files::new(self)
226    }
227
228    /// To call [Uploads] group related APIs using this client.
229    #[cfg(feature = "upload")]
230    pub fn uploads(&self) -> Uploads<'_, C> {
231        Uploads::new(self)
232    }
233
234    /// To call [FineTuning] group related APIs using this client.
235    #[cfg(feature = "finetuning")]
236    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237        FineTuning::new(self)
238    }
239
240    /// To call [Embeddings] group related APIs using this client.
241    #[cfg(feature = "embedding")]
242    pub fn embeddings(&self) -> Embeddings<'_, C> {
243        Embeddings::new(self)
244    }
245
246    /// To call [Audio] group related APIs using this client.
247    #[cfg(feature = "audio")]
248    pub fn audio(&self) -> Audio<'_, C> {
249        Audio::new(self)
250    }
251
252    /// To call [Videos] group related APIs using this client.
253    #[cfg(feature = "video")]
254    pub fn videos(&self) -> Videos<'_, C> {
255        Videos::new(self)
256    }
257
258    /// To call [Assistants] group related APIs using this client.
259    #[cfg(feature = "assistant")]
260    #[deprecated(
261        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262    )]
263    #[allow(deprecated)]
264    pub fn assistants(&self) -> Assistants<'_, C> {
265        Assistants::new(self)
266    }
267
268    /// To call [Threads] group related APIs using this client.
269    #[cfg(feature = "assistant")]
270    #[deprecated(
271        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272    )]
273    #[allow(deprecated)]
274    pub fn threads(&self) -> Threads<'_, C> {
275        Threads::new(self)
276    }
277
278    /// To call [VectorStores] group related APIs using this client.
279    #[cfg(feature = "vectorstore")]
280    pub fn vector_stores(&self) -> VectorStores<'_, C> {
281        VectorStores::new(self)
282    }
283
284    /// To call [Batches] group related APIs using this client.
285    #[cfg(feature = "batch")]
286    pub fn batches(&self) -> Batches<'_, C> {
287        Batches::new(self)
288    }
289
290    /// To call [Admin] group related APIs using this client.
291    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
292    #[cfg(feature = "administration")]
293    pub fn admin(&self) -> Admin<'_, C> {
294        Admin::new(self)
295    }
296
297    /// To call [Responses] group related APIs using this client.
298    #[cfg(feature = "responses")]
299    pub fn responses(&self) -> Responses<'_, C> {
300        Responses::new(self)
301    }
302
303    /// To call [Conversations] group related APIs using this client.
304    #[cfg(feature = "responses")]
305    pub fn conversations(&self) -> Conversations<'_, C> {
306        Conversations::new(self)
307    }
308
309    /// To call [Containers] group related APIs using this client.
310    #[cfg(feature = "container")]
311    pub fn containers(&self) -> Containers<'_, C> {
312        Containers::new(self)
313    }
314
315    /// To call [Skills] group related APIs using this client.
316    #[cfg(feature = "skill")]
317    pub fn skills(&self) -> Skills<'_, C> {
318        Skills::new(self)
319    }
320
321    /// To call [Evals] group related APIs using this client.
322    #[cfg(feature = "evals")]
323    pub fn evals(&self) -> Evals<'_, C> {
324        Evals::new(self)
325    }
326
327    #[cfg(feature = "chatkit")]
328    pub fn chatkit(&self) -> Chatkit<'_, C> {
329        Chatkit::new(self)
330    }
331
332    /// To call [Realtime] group related APIs using this client.
333    #[cfg(feature = "realtime")]
334    pub fn realtime(&self) -> Realtime<'_, C> {
335        Realtime::new(self)
336    }
337
338    pub fn config(&self) -> &C {
339        &self.config
340    }
341
342    fn build_request_parts(
343        &self,
344        method: reqwest::Method,
345        path: &str,
346        request_options: &RequestOptions,
347    ) -> Arc<RequestParts> {
348        let url = if let Some(path) = request_options.path() {
349            self.config.url(path.as_str())
350        } else {
351            self.config.url(path)
352        };
353        let mut headers = self.config.headers();
354        if let Some(request_headers) = request_options.headers() {
355            headers.extend(request_headers.clone());
356        }
357
358        let mut query = self
359            .config
360            .query()
361            .into_iter()
362            .map(|(key, value)| (key.to_string(), value.to_string()))
363            .collect::<Vec<_>>();
364        query.extend_from_slice(request_options.query());
365
366        Arc::new(RequestParts {
367            request_client: self.request_client.clone(),
368            method,
369            url,
370            headers,
371            query,
372        })
373    }
374
375    fn build_request_factory(
376        &self,
377        method: reqwest::Method,
378        path: &str,
379        request_options: &RequestOptions,
380    ) -> HttpRequestFactory {
381        let request_parts = self.build_request_parts(method, path, request_options);
382
383        HttpRequestFactory::new(move || {
384            let request_parts = request_parts.clone();
385
386            async move {
387                let request = request_parts.build_request_builder().build()?;
388                Ok(request)
389            }
390        })
391    }
392
393    fn build_request_factory_with_json<I>(
394        &self,
395        method: reqwest::Method,
396        path: &str,
397        request: I,
398        request_options: &RequestOptions,
399    ) -> Result<HttpRequestFactory, OpenAIError>
400    where
401        I: Serialize,
402    {
403        // JSON bodies are materialized once so the base BYOT path can keep
404        // accepting borrowed inputs.
405        let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406            OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407        })?);
408        let request_parts = self.build_request_parts(method, path, request_options);
409
410        Ok(HttpRequestFactory::new(move || {
411            let request_parts = request_parts.clone();
412            let request = request.clone();
413
414            async move {
415                let request_builder = request_parts
416                    .build_request_builder()
417                    .header(reqwest::header::CONTENT_TYPE, "application/json")
418                    .body(request.clone());
419
420                Ok(request_builder.build()?)
421            }
422        }))
423    }
424
425    fn build_request_factory_with_form<F>(
426        &self,
427        method: reqwest::Method,
428        path: &str,
429        form: F,
430        request_options: &RequestOptions,
431    ) -> Result<HttpRequestFactory, OpenAIError>
432    where
433        F: Clone + crate::traits::MaybeSend + 'static,
434        Form: AsyncTryFrom<F, Error = OpenAIError>,
435    {
436        // Multipart is the reason the factory exists.
437        //
438        // `Mutex` is only here to make the captured state `Sync` on native targets.
439        #[cfg(not(target_family = "wasm"))]
440        let form = Arc::new(Mutex::new(form));
441        let request_parts = self.build_request_parts(method, path, request_options);
442
443        Ok(HttpRequestFactory::new(move || {
444            let request_parts = request_parts.clone();
445            let form = form.clone();
446
447            async move {
448                #[cfg(not(target_family = "wasm"))]
449                let form = form
450                    .lock()
451                    .expect("multipart request factory mutex poisoned")
452                    .clone();
453                #[cfg(target_family = "wasm")]
454                let form = form.clone();
455                let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456                let request_builder = request_parts.build_request_builder().multipart(form);
457
458                Ok(request_builder.build()?)
459            }
460        }))
461    }
462
463    /// Make a GET request to {path} and deserialize the response body
464    #[allow(unused)]
465    pub(crate) async fn get<O>(
466        &self,
467        path: &str,
468        request_options: &RequestOptions,
469    ) -> Result<O, OpenAIError>
470    where
471        O: DeserializeOwned,
472    {
473        let request_factory =
474            self.build_request_factory(reqwest::Method::GET, path, request_options);
475        self.execute(request_factory).await
476    }
477
478    /// Make a DELETE request to {path} and deserialize the response body
479    #[allow(unused)]
480    pub(crate) async fn delete<O>(
481        &self,
482        path: &str,
483        request_options: &RequestOptions,
484    ) -> Result<O, OpenAIError>
485    where
486        O: DeserializeOwned,
487    {
488        let request_factory =
489            self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490        self.execute(request_factory).await
491    }
492
493    /// Make a GET request to {path} and return the response body
494    #[allow(unused)]
495    pub(crate) async fn get_raw(
496        &self,
497        path: &str,
498        request_options: &RequestOptions,
499    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500        let request_factory =
501            self.build_request_factory(reqwest::Method::GET, path, request_options);
502        self.execute_raw(request_factory).await
503    }
504
505    /// Make a POST request to {path} and return the response body
506    #[allow(unused)]
507    pub(crate) async fn post_raw<I>(
508        &self,
509        path: &str,
510        request: I,
511        request_options: &RequestOptions,
512    ) -> Result<(Bytes, HeaderMap), OpenAIError>
513    where
514        I: Serialize,
515    {
516        let request_factory = self.build_request_factory_with_json(
517            reqwest::Method::POST,
518            path,
519            request,
520            request_options,
521        )?;
522        self.execute_raw(request_factory).await
523    }
524
525    /// Make a POST request to {path} and deserialize the response body
526    #[allow(unused)]
527    pub(crate) async fn post<I, O>(
528        &self,
529        path: &str,
530        request: I,
531        request_options: &RequestOptions,
532    ) -> Result<O, OpenAIError>
533    where
534        I: Serialize,
535        O: DeserializeOwned,
536    {
537        let request_factory = self.build_request_factory_with_json(
538            reqwest::Method::POST,
539            path,
540            request,
541            request_options,
542        )?;
543        self.execute(request_factory).await
544    }
545
546    /// POST a form at {path} and return the response body
547    #[allow(unused)]
548    pub(crate) async fn post_form_raw<F>(
549        &self,
550        path: &str,
551        form: F,
552        request_options: &RequestOptions,
553    ) -> Result<(Bytes, HeaderMap), OpenAIError>
554    where
555        F: Clone + crate::traits::MaybeSend + 'static,
556        Form: AsyncTryFrom<F, Error = OpenAIError>,
557    {
558        let request_factory = self.build_request_factory_with_form(
559            reqwest::Method::POST,
560            path,
561            form,
562            request_options,
563        )?;
564        self.execute_raw(request_factory).await
565    }
566
567    /// POST a form at {path} and deserialize the response body
568    #[allow(unused)]
569    pub(crate) async fn post_form<O, F>(
570        &self,
571        path: &str,
572        form: F,
573        request_options: &RequestOptions,
574    ) -> Result<O, OpenAIError>
575    where
576        O: DeserializeOwned,
577        F: Clone + crate::traits::MaybeSend + 'static,
578        Form: AsyncTryFrom<F, Error = OpenAIError>,
579    {
580        let request_factory = self.build_request_factory_with_form(
581            reqwest::Method::POST,
582            path,
583            form,
584            request_options,
585        )?;
586        self.execute(request_factory).await
587    }
588
589    #[allow(unused)]
590    pub(crate) async fn post_form_stream<O, F>(
591        &self,
592        path: &str,
593        form: F,
594        request_options: &RequestOptions,
595    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596    where
597        F: Clone + crate::traits::MaybeSend + 'static,
598        Form: AsyncTryFrom<F, Error = OpenAIError>,
599        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600    {
601        let request_factory = self.build_request_factory_with_form(
602            reqwest::Method::POST,
603            path,
604            form,
605            request_options,
606        )?;
607
608        self.execute_stream(request_factory).await
609    }
610
611    async fn execute_raw(
612        &self,
613        request_factory: HttpRequestFactory,
614    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615        let response = self.execute_response(request_factory).await?;
616        read_response(response).await
617    }
618
619    async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620    where
621        O: DeserializeOwned,
622    {
623        let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625        let response: O = serde_json::from_slice(bytes.as_ref())
626            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628        Ok(response)
629    }
630
631    async fn execute_response(
632        &self,
633        request_factory: HttpRequestFactory,
634    ) -> Result<Response, OpenAIError> {
635        self.executor.execute(request_factory).await
636    }
637
638    async fn execute_stream<O>(
639        &self,
640        request_factory: HttpRequestFactory,
641    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
642    where
643        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
644    {
645        let response = self.execute_response(request_factory).await?;
646        Ok(stream(response).await)
647    }
648
649    async fn execute_stream_mapped_raw_events<O>(
650        &self,
651        request_factory: HttpRequestFactory,
652        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
653            + crate::traits::MaybeSend
654            + 'static,
655    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
656    where
657        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
658    {
659        let response = self.execute_response(request_factory).await?;
660        Ok(stream_mapped_raw_events(response, event_mapper).await)
661    }
662
663    /// Make HTTP POST request to receive SSE
664    #[allow(unused)]
665    pub(crate) async fn post_stream<I, O>(
666        &self,
667        path: &str,
668        request: I,
669        request_options: &RequestOptions,
670    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
671    where
672        I: Serialize,
673        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
674    {
675        let request_factory = self.build_request_factory_with_json(
676            reqwest::Method::POST,
677            path,
678            request,
679            request_options,
680        )?;
681        // Stream setup is still request/response first. We only create the SSE
682        // stream after the HTTP layer has returned a response object.
683        self.execute_stream(request_factory).await
684    }
685
686    #[allow(unused)]
687    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
688        &self,
689        path: &str,
690        request: I,
691        request_options: &RequestOptions,
692        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
693            + crate::traits::MaybeSend
694            + 'static,
695    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
696    where
697        I: Serialize,
698        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
699    {
700        let request_factory = self.build_request_factory_with_json(
701            reqwest::Method::POST,
702            path,
703            request,
704            request_options,
705        )?;
706        self.execute_stream_mapped_raw_events(request_factory, event_mapper)
707            .await
708    }
709
710    /// Make HTTP GET request to receive SSE
711    #[allow(unused)]
712    pub(crate) async fn get_stream<O>(
713        &self,
714        path: &str,
715        request_options: &RequestOptions,
716    ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
717    where
718        O: DeserializeOwned + crate::traits::MaybeSend + 'static,
719    {
720        let request_factory =
721            self.build_request_factory(reqwest::Method::GET, path, request_options);
722        self.execute_stream(request_factory).await
723    }
724}
725
726async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
727    let status = response.status();
728    let headers = response.headers().clone();
729    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
730
731    if status.is_server_error() {
732        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
733        let message: String = String::from_utf8_lossy(&bytes).into_owned();
734        tracing::warn!("Server error: {status} - {message}");
735        return Err(OpenAIError::ApiError(ApiError {
736            message,
737            r#type: None,
738            param: None,
739            code: None,
740        }));
741    }
742
743    // Deserialize response body from either error object or actual response object
744    if !status.is_success() {
745        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
746            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
747
748        return Err(OpenAIError::ApiError(wrapped_error.error));
749    }
750
751    Ok((bytes, headers))
752}
753
754/// Request which responds with SSE.
755/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
756pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
757where
758    O: DeserializeOwned + crate::traits::MaybeSend + 'static,
759{
760    stream_mapped_raw_events(response, |event| {
761        serde_json::from_str::<O>(&event.data)
762            .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
763    })
764    .await
765}
766
/// wasm variant: turns a response into a stream by pulling SSE events lazily
/// and passing each one through `event_mapper`.
///
/// A non-success response produces a single-item stream holding the error.
/// A `[DONE]` payload ends the stream, `keepalive` events are skipped, and an
/// event-stream error is yielded as [`StreamError::EventStream`].
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    if !response.status().is_success() {
        let failure = async move {
            read_response(response).await.and_then(|_| {
                Err::<O, _>(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                ))
            })
        };
        return Box::pin(futures::stream::once(failure));
    }

    let chunks = response
        .bytes_stream()
        .map(|chunk| chunk.map_err(std::io::Error::other));
    let events = Box::pin(eventsource_stream::EventStream::new(chunks));

    Box::pin(futures::stream::unfold(
        (events, event_mapper),
        |(mut events, mapper)| async move {
            loop {
                let item = match events.next().await {
                    None => return None,
                    Some(Err(error)) => Err(OpenAIError::StreamError(Box::new(
                        StreamError::EventStream(error.to_string()),
                    ))),
                    // The `[DONE]` sentinel terminates the stream.
                    Some(Ok(event)) if event.data == "[DONE]" => return None,
                    // Keepalive events carry no payload for the caller.
                    Some(Ok(event)) if event.event == "keepalive" => continue,
                    Some(Ok(event)) => mapper(event),
                };

                return Some((item, (events, mapper)));
            }
        },
    ))
}
822
823#[cfg(not(target_family = "wasm"))]
824pub(crate) async fn stream_mapped_raw_events<O>(
825    response: Response,
826    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
827) -> crate::types::stream::StreamResponse<O>
828where
829    O: DeserializeOwned + std::marker::Send + 'static,
830{
831    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
832
833    tokio::spawn(async move {
834        if !response.status().is_success() {
835            if let Err(e) = read_response(response).await {
836                let _ = tx.send(Err(e));
837            }
838            return;
839        }
840        let byte_stream = response
841            .bytes_stream()
842            .map(|r| r.map_err(std::io::Error::other));
843        let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
844
845        while let Some(ev) = event_stream.next().await {
846            let event = match ev {
847                Ok(e) => e,
848                Err(e) => {
849                    let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
850                        StreamError::EventStream(e.to_string()),
851                    ))));
852                    break;
853                }
854            };
855            if event.data == "[DONE]" {
856                break;
857            }
858
859            if event.event == "keepalive" {
860                continue;
861            }
862
863            let response = event_mapper(event);
864
865            if tx.send(response).is_err() {
866                break;
867            }
868        }
869    });
870
871    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
872}
873
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    //! Tests for routing client requests through a caller-supplied tower service
    //! installed with `Client::with_http_service`.
    //!
    //! Each stub service builds the outgoing request only to inspect it, then answers
    //! with a canned `http::Response` converted into a `reqwest::Response`.

    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// Error payload returned with the 429 status on the first attempt of the retry tests.
    const RATE_LIMITED_BODY: &str =
        r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#;

    /// Configuration shared by every test: a placeholder base URL and API key.
    fn test_config() -> OpenAIConfig {
        OpenAIConfig::new()
            .with_api_base("http://example.test")
            .with_api_key("test-key")
    }

    /// Builds a canned response with the given status, `content-type` header and body.
    ///
    /// # Panics
    ///
    /// Panics if the `http` response builder rejects the status or header value.
    fn stub_response(
        status: u16,
        content_type: &'static str,
        body: &'static str,
    ) -> reqwest::Response {
        HttpResponse::builder()
            .status(status)
            .header("content-type", content_type)
            .body(reqwest::Body::from(body))
            .expect("static status, header and body form a valid response")
            .into()
    }

    /// Returns a 429 rate-limit error for attempt `0` and a 200 JSON response carrying
    /// `success_body` for every later attempt.
    fn rate_limited_then_ok(attempt: usize, success_body: &'static str) -> reqwest::Response {
        if attempt == 0 {
            stub_response(429, "application/json", RATE_LIMITED_BODY)
        } else {
            stub_response(200, "application/json", success_body)
        }
    }

    /// A plain `get` call reaches the installed service exactly once and its JSON body
    /// is deserialized for the caller.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(stub_response(
                            200,
                            "application/json",
                            r#"{"object":"list","data":[{"id":"model"}]}"#,
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// `post_stream` opens its connection through the installed service exactly once and
    /// yields the first server-sent event as a parsed JSON value.
    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(stub_response(
                            200,
                            "text/event-stream",
                            "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// With a retry layer using `SimpleRetryPolicy::default()`, a 429 on the first attempt
    /// is followed by a second attempt whose 200 body is what the caller receives.
    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        // `fetch_add` returns the previous value, i.e. the zero-based attempt.
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        Ok::<reqwest::Response, OpenAIError>(rate_limited_then_ok(
                            attempt,
                            r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input that counts how many times it is converted into a form.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        /// Incremented on every `AsyncTryFrom` conversion.
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        /// Records the conversion and returns a one-field text form.
        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    /// A retried `post_form` converts its input into a multipart form once per attempt:
    /// two requests reach the service and two conversions are recorded.
    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = Arc::clone(&request_count);
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = Arc::clone(&request_count);
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        // `fetch_add` returns the previous value, i.e. the zero-based attempt.
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        Ok::<reqwest::Response, OpenAIError>(rate_limited_then_ok(
                            attempt,
                            r#"{"ok":true}"#,
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: Arc::clone(&conversion_count),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}