Skip to main content

async_openai/
client.rs

1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8#[cfg(not(target_family = "wasm"))]
9use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::error::StreamError;
14use crate::{
15    config::{Config, OpenAIConfig},
16    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
17    traits::AsyncTryFrom,
18    RequestOptions,
19};
20
21#[cfg(feature = "administration")]
22use crate::admin::Admin;
23#[cfg(feature = "chatkit")]
24use crate::chatkit::Chatkit;
25#[cfg(feature = "file")]
26use crate::file::Files;
27#[cfg(feature = "image")]
28use crate::image::Images;
29#[cfg(feature = "moderation")]
30use crate::moderation::Moderations;
31#[cfg(feature = "assistant")]
32#[allow(deprecated)]
33use crate::Assistants;
34#[cfg(feature = "audio")]
35use crate::Audio;
36#[cfg(feature = "batch")]
37use crate::Batches;
38#[cfg(feature = "chat-completion")]
39use crate::Chat;
40#[cfg(feature = "completions")]
41use crate::Completions;
42#[cfg(feature = "container")]
43use crate::Containers;
44#[cfg(feature = "responses")]
45use crate::Conversations;
46#[cfg(feature = "embedding")]
47use crate::Embeddings;
48#[cfg(feature = "evals")]
49use crate::Evals;
50#[cfg(feature = "finetuning")]
51use crate::FineTuning;
52#[cfg(feature = "model")]
53use crate::Models;
54#[cfg(feature = "realtime")]
55use crate::Realtime;
56#[cfg(feature = "responses")]
57use crate::Responses;
58#[cfg(feature = "skill")]
59use crate::Skills;
60#[cfg(feature = "assistant")]
61#[allow(deprecated)]
62use crate::Threads;
63#[cfg(feature = "upload")]
64use crate::Uploads;
65#[cfg(feature = "vectorstore")]
66use crate::VectorStores;
67#[cfg(feature = "video")]
68use crate::Videos;
69
70#[derive(Debug, Clone)]
71/// Client is a container for config, backoff and http_client
72/// used to make API calls.
73pub struct Client<C: Config> {
74    http_client: reqwest::Client,
75    config: C,
76    #[cfg(not(target_family = "wasm"))]
77    backoff: backoff::ExponentialBackoff,
78}
79
80impl<C: Config> Default for Client<C>
81where
82    C: Default,
83{
84    fn default() -> Self {
85        Self {
86            http_client: reqwest::Client::new(),
87            config: C::default(),
88            #[cfg(not(target_family = "wasm"))]
89            backoff: Default::default(),
90        }
91    }
92}
93
94impl Client<OpenAIConfig> {
95    /// Client with default [OpenAIConfig]
96    pub fn new() -> Self {
97        Self::default()
98    }
99}
100
101impl<C: Config> Client<C> {
102    /// Create client with a custom HTTP client, OpenAI config, and backoff.
103    #[cfg(not(target_family = "wasm"))]
104    pub fn build(
105        http_client: reqwest::Client,
106        config: C,
107        backoff: backoff::ExponentialBackoff,
108    ) -> Self {
109        Self {
110            http_client,
111            config,
112            backoff,
113        }
114    }
115
116    /// Create client with a custom HTTP client and config (WASM version without backoff).
117    #[cfg(target_family = "wasm")]
118    pub fn build(http_client: reqwest::Client, config: C) -> Self {
119        Self {
120            http_client,
121            config,
122        }
123    }
124
125    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
126    pub fn with_config(config: C) -> Self {
127        Self {
128            http_client: reqwest::Client::new(),
129            config,
130            #[cfg(not(target_family = "wasm"))]
131            backoff: Default::default(),
132        }
133    }
134
135    /// Provide your own [client] to make HTTP requests with.
136    ///
137    /// [client]: reqwest::Client
138    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
139        self.http_client = http_client;
140        self
141    }
142
143    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
144    #[cfg(not(target_family = "wasm"))]
145    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
146        self.backoff = backoff;
147        self
148    }
149
150    // API groups
151
152    /// To call [Models] group related APIs using this client.
153    #[cfg(feature = "model")]
154    pub fn models(&self) -> Models<'_, C> {
155        Models::new(self)
156    }
157
158    /// To call [Completions] group related APIs using this client.
159    #[cfg(feature = "completions")]
160    pub fn completions(&self) -> Completions<'_, C> {
161        Completions::new(self)
162    }
163
164    /// To call [Chat] group related APIs using this client.
165    #[cfg(feature = "chat-completion")]
166    pub fn chat(&self) -> Chat<'_, C> {
167        Chat::new(self)
168    }
169
170    /// To call [Images] group related APIs using this client.
171    #[cfg(feature = "image")]
172    pub fn images(&self) -> Images<'_, C> {
173        Images::new(self)
174    }
175
176    /// To call [Moderations] group related APIs using this client.
177    #[cfg(feature = "moderation")]
178    pub fn moderations(&self) -> Moderations<'_, C> {
179        Moderations::new(self)
180    }
181
182    /// To call [Files] group related APIs using this client.
183    #[cfg(feature = "file")]
184    pub fn files(&self) -> Files<'_, C> {
185        Files::new(self)
186    }
187
188    /// To call [Uploads] group related APIs using this client.
189    #[cfg(feature = "upload")]
190    pub fn uploads(&self) -> Uploads<'_, C> {
191        Uploads::new(self)
192    }
193
194    /// To call [FineTuning] group related APIs using this client.
195    #[cfg(feature = "finetuning")]
196    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
197        FineTuning::new(self)
198    }
199
200    /// To call [Embeddings] group related APIs using this client.
201    #[cfg(feature = "embedding")]
202    pub fn embeddings(&self) -> Embeddings<'_, C> {
203        Embeddings::new(self)
204    }
205
206    /// To call [Audio] group related APIs using this client.
207    #[cfg(feature = "audio")]
208    pub fn audio(&self) -> Audio<'_, C> {
209        Audio::new(self)
210    }
211
212    /// To call [Videos] group related APIs using this client.
213    #[cfg(feature = "video")]
214    pub fn videos(&self) -> Videos<'_, C> {
215        Videos::new(self)
216    }
217
218    /// To call [Assistants] group related APIs using this client.
219    #[cfg(feature = "assistant")]
220    #[deprecated(
221        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
222    )]
223    #[allow(deprecated)]
224    pub fn assistants(&self) -> Assistants<'_, C> {
225        Assistants::new(self)
226    }
227
228    /// To call [Threads] group related APIs using this client.
229    #[cfg(feature = "assistant")]
230    #[deprecated(
231        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
232    )]
233    #[allow(deprecated)]
234    pub fn threads(&self) -> Threads<'_, C> {
235        Threads::new(self)
236    }
237
238    /// To call [VectorStores] group related APIs using this client.
239    #[cfg(feature = "vectorstore")]
240    pub fn vector_stores(&self) -> VectorStores<'_, C> {
241        VectorStores::new(self)
242    }
243
244    /// To call [Batches] group related APIs using this client.
245    #[cfg(feature = "batch")]
246    pub fn batches(&self) -> Batches<'_, C> {
247        Batches::new(self)
248    }
249
250    /// To call [Admin] group related APIs using this client.
251    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
252    #[cfg(feature = "administration")]
253    pub fn admin(&self) -> Admin<'_, C> {
254        Admin::new(self)
255    }
256
257    /// To call [Responses] group related APIs using this client.
258    #[cfg(feature = "responses")]
259    pub fn responses(&self) -> Responses<'_, C> {
260        Responses::new(self)
261    }
262
263    /// To call [Conversations] group related APIs using this client.
264    #[cfg(feature = "responses")]
265    pub fn conversations(&self) -> Conversations<'_, C> {
266        Conversations::new(self)
267    }
268
269    /// To call [Containers] group related APIs using this client.
270    #[cfg(feature = "container")]
271    pub fn containers(&self) -> Containers<'_, C> {
272        Containers::new(self)
273    }
274
275    /// To call [Skills] group related APIs using this client.
276    #[cfg(feature = "skill")]
277    pub fn skills(&self) -> Skills<'_, C> {
278        Skills::new(self)
279    }
280
281    /// To call [Evals] group related APIs using this client.
282    #[cfg(feature = "evals")]
283    pub fn evals(&self) -> Evals<'_, C> {
284        Evals::new(self)
285    }
286
287    #[cfg(feature = "chatkit")]
288    pub fn chatkit(&self) -> Chatkit<'_, C> {
289        Chatkit::new(self)
290    }
291
292    /// To call [Realtime] group related APIs using this client.
293    #[cfg(feature = "realtime")]
294    pub fn realtime(&self) -> Realtime<'_, C> {
295        Realtime::new(self)
296    }
297
298    pub fn config(&self) -> &C {
299        &self.config
300    }
301
302    /// Helper function to build a request builder with common configuration
303    fn build_request_builder(
304        &self,
305        method: reqwest::Method,
306        path: &str,
307        request_options: &RequestOptions,
308    ) -> reqwest::RequestBuilder {
309        let mut request_builder = if let Some(path) = request_options.path() {
310            self.http_client
311                .request(method, self.config.url(path.as_str()))
312        } else {
313            self.http_client.request(method, self.config.url(path))
314        };
315
316        request_builder = request_builder
317            .query(&self.config.query())
318            .headers(self.config.headers());
319
320        if let Some(headers) = request_options.headers() {
321            request_builder = request_builder.headers(headers.clone());
322        }
323
324        if !request_options.query().is_empty() {
325            request_builder = request_builder.query(request_options.query());
326        }
327
328        request_builder
329    }
330
331    /// Make a GET request to {path} and deserialize the response body
332    #[allow(unused)]
333    pub(crate) async fn get<O>(
334        &self,
335        path: &str,
336        request_options: &RequestOptions,
337    ) -> Result<O, OpenAIError>
338    where
339        O: DeserializeOwned,
340    {
341        let request_maker = || async {
342            Ok(self
343                .build_request_builder(reqwest::Method::GET, path, request_options)
344                .build()?)
345        };
346
347        self.execute(request_maker).await
348    }
349
350    /// Make a DELETE request to {path} and deserialize the response body
351    #[allow(unused)]
352    pub(crate) async fn delete<O>(
353        &self,
354        path: &str,
355        request_options: &RequestOptions,
356    ) -> Result<O, OpenAIError>
357    where
358        O: DeserializeOwned,
359    {
360        let request_maker = || async {
361            Ok(self
362                .build_request_builder(reqwest::Method::DELETE, path, request_options)
363                .build()?)
364        };
365
366        self.execute(request_maker).await
367    }
368
369    /// Make a GET request to {path} and return the response body
370    #[allow(unused)]
371    pub(crate) async fn get_raw(
372        &self,
373        path: &str,
374        request_options: &RequestOptions,
375    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
376        let request_maker = || async {
377            Ok(self
378                .build_request_builder(reqwest::Method::GET, path, request_options)
379                .build()?)
380        };
381
382        self.execute_raw(request_maker).await
383    }
384
385    /// Make a POST request to {path} and return the response body
386    #[allow(unused)]
387    pub(crate) async fn post_raw<I>(
388        &self,
389        path: &str,
390        request: I,
391        request_options: &RequestOptions,
392    ) -> Result<(Bytes, HeaderMap), OpenAIError>
393    where
394        I: Serialize,
395    {
396        let request_maker = || async {
397            Ok(self
398                .build_request_builder(reqwest::Method::POST, path, request_options)
399                .json(&request)
400                .build()?)
401        };
402
403        self.execute_raw(request_maker).await
404    }
405
406    /// Make a POST request to {path} and deserialize the response body
407    #[allow(unused)]
408    pub(crate) async fn post<I, O>(
409        &self,
410        path: &str,
411        request: I,
412        request_options: &RequestOptions,
413    ) -> Result<O, OpenAIError>
414    where
415        I: Serialize,
416        O: DeserializeOwned,
417    {
418        let request_maker = || async {
419            Ok(self
420                .build_request_builder(reqwest::Method::POST, path, request_options)
421                .json(&request)
422                .build()?)
423        };
424
425        self.execute(request_maker).await
426    }
427
428    /// POST a form at {path} and return the response body
429    #[allow(unused)]
430    pub(crate) async fn post_form_raw<F>(
431        &self,
432        path: &str,
433        form: F,
434        request_options: &RequestOptions,
435    ) -> Result<(Bytes, HeaderMap), OpenAIError>
436    where
437        Form: AsyncTryFrom<F, Error = OpenAIError>,
438        F: Clone,
439    {
440        let request_maker = || async {
441            Ok(self
442                .build_request_builder(reqwest::Method::POST, path, request_options)
443                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
444                .build()?)
445        };
446
447        self.execute_raw(request_maker).await
448    }
449
450    /// POST a form at {path} and deserialize the response body
451    #[allow(unused)]
452    pub(crate) async fn post_form<O, F>(
453        &self,
454        path: &str,
455        form: F,
456        request_options: &RequestOptions,
457    ) -> Result<O, OpenAIError>
458    where
459        O: DeserializeOwned,
460        Form: AsyncTryFrom<F, Error = OpenAIError>,
461        F: Clone,
462    {
463        let request_maker = || async {
464            Ok(self
465                .build_request_builder(reqwest::Method::POST, path, request_options)
466                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
467                .build()?)
468        };
469
470        self.execute(request_maker).await
471    }
472
473    #[allow(unused)]
474    #[cfg(not(target_family = "wasm"))]
475    pub(crate) async fn post_form_stream<O, F>(
476        &self,
477        path: &str,
478        form: F,
479        request_options: &RequestOptions,
480    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
481    where
482        F: Clone,
483        Form: AsyncTryFrom<F, Error = OpenAIError>,
484        O: DeserializeOwned + std::marker::Send + 'static,
485    {
486        // Build and execute request manually since multipart::Form is not Clone
487        // and .eventsource() requires cloneability
488        let request_builder = self
489            .build_request_builder(reqwest::Method::POST, path, request_options)
490            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
491
492        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
493
494        // Check for error status
495        if !response.status().is_success() {
496            return Err(read_response(response).await.unwrap_err());
497        }
498
499        // Convert response body to EventSource stream
500        let stream = response
501            .bytes_stream()
502            .map(|result| result.map_err(std::io::Error::other));
503        let event_stream = eventsource_stream::EventStream::new(stream);
504
505        // Convert EventSource stream to our expected format
506        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
507
508        tokio::spawn(async move {
509            use futures::StreamExt;
510            let mut event_stream = std::pin::pin!(event_stream);
511
512            while let Some(event_result) = event_stream.next().await {
513                match event_result {
514                    Err(e) => {
515                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
516                            StreamError::EventStream(e.to_string()),
517                        )))) {
518                            break;
519                        }
520                    }
521                    Ok(event) => {
522                        // eventsource_stream::Event is a struct with data field
523                        if event.data == "[DONE]" {
524                            break;
525                        }
526
527                        let response = match serde_json::from_str::<O>(&event.data) {
528                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
529                            Ok(output) => Ok(output),
530                        };
531
532                        if let Err(_e) = tx.send(response) {
533                            break;
534                        }
535                    }
536                }
537            }
538        });
539
540        Ok(Box::pin(
541            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
542        ))
543    }
544
545    /// Execute a HTTP request and retry on rate limit (non-WASM version with backoff)
546    ///
547    /// request_maker serves one purpose: to be able to create request again
548    /// to retry API call after getting rate limited. request_maker is async because
549    /// reqwest::multipart::Form is created by async calls to read files for uploads.
550    #[cfg(not(target_family = "wasm"))]
551    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
552    where
553        M: Fn() -> Fut,
554        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
555    {
556        let client = self.http_client.clone();
557
558        backoff::future::retry(self.backoff.clone(), || async {
559            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
560            let response = client
561                .execute(request)
562                .await
563                .map_err(OpenAIError::Reqwest)
564                .map_err(backoff::Error::Permanent)?;
565
566            let status = response.status();
567
568            match read_response(response).await {
569                Ok((bytes, headers)) => Ok((bytes, headers)),
570                Err(e) => {
571                    match e {
572                        OpenAIError::ApiError(api_error) => {
573                            if status.is_server_error() {
574                                Err(backoff::Error::Transient {
575                                    err: OpenAIError::ApiError(api_error),
576                                    retry_after: None,
577                                })
578                            } else if status.as_u16() == 429
579                                && api_error.r#type != Some("insufficient_quota".to_string())
580                            {
581                                // Rate limited retry...
582                                tracing::warn!("Rate limited: {}", api_error.message);
583                                Err(backoff::Error::Transient {
584                                    err: OpenAIError::ApiError(api_error),
585                                    retry_after: None,
586                                })
587                            } else {
588                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
589                            }
590                        }
591                        _ => Err(backoff::Error::Permanent(e)),
592                    }
593                }
594            }
595        })
596        .await
597    }
598
599    /// Execute a HTTP request (WASM version - single attempt, no retry)
600    #[cfg(target_family = "wasm")]
601    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
602    where
603        M: Fn() -> Fut,
604        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
605    {
606        let request = request_maker().await?;
607        let response = self
608            .http_client
609            .execute(request)
610            .await
611            .map_err(OpenAIError::Reqwest)?;
612
613        read_response(response).await
614    }
615
616    /// Execute a HTTP request and retry on rate limit
617    ///
618    /// request_maker serves one purpose: to be able to create request again
619    /// to retry API call after getting rate limited. request_maker is async because
620    /// reqwest::multipart::Form is created by async calls to read files for uploads.
621    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
622    where
623        O: DeserializeOwned,
624        M: Fn() -> Fut,
625        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
626    {
627        let (bytes, _headers) = self.execute_raw(request_maker).await?;
628
629        let response: O = serde_json::from_slice(bytes.as_ref())
630            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
631
632        Ok(response)
633    }
634
635    /// Make HTTP POST request to receive SSE
636    #[allow(unused)]
637    #[cfg(not(target_family = "wasm"))]
638    pub(crate) async fn post_stream<I, O>(
639        &self,
640        path: &str,
641        request: I,
642        request_options: &RequestOptions,
643    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
644    where
645        I: Serialize,
646        O: DeserializeOwned + std::marker::Send + 'static,
647    {
648        let request_builder = self
649            .build_request_builder(reqwest::Method::POST, path, request_options)
650            .json(&request);
651
652        let event_source = request_builder.eventsource().unwrap();
653
654        stream(event_source).await
655    }
656
657    #[allow(unused)]
658    #[cfg(not(target_family = "wasm"))]
659    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
660        &self,
661        path: &str,
662        request: I,
663        request_options: &RequestOptions,
664        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
665    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
666    where
667        I: Serialize,
668        O: DeserializeOwned + std::marker::Send + 'static,
669    {
670        let request_builder = self
671            .build_request_builder(reqwest::Method::POST, path, request_options)
672            .json(&request);
673
674        let event_source = request_builder.eventsource().unwrap();
675
676        stream_mapped_raw_events(event_source, event_mapper).await
677    }
678
679    /// Make HTTP GET request to receive SSE
680    #[allow(unused)]
681    #[cfg(not(target_family = "wasm"))]
682    pub(crate) async fn get_stream<O>(
683        &self,
684        path: &str,
685        request_options: &RequestOptions,
686    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
687    where
688        O: DeserializeOwned + std::marker::Send + 'static,
689    {
690        let request_builder =
691            self.build_request_builder(reqwest::Method::GET, path, request_options);
692
693        let event_source = request_builder.eventsource().unwrap();
694
695        stream(event_source).await
696    }
697}
698
699async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
700    let status = response.status();
701    let headers = response.headers().clone();
702    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
703
704    if status.is_server_error() {
705        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
706        let message: String = String::from_utf8_lossy(&bytes).into_owned();
707        tracing::warn!("Server error: {status} - {message}");
708        return Err(OpenAIError::ApiError(ApiError {
709            message,
710            r#type: None,
711            param: None,
712            code: None,
713        }));
714    }
715
716    // Deserialize response body from either error object or actual response object
717    if !status.is_success() {
718        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
719            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
720
721        return Err(OpenAIError::ApiError(wrapped_error.error));
722    }
723
724    Ok((bytes, headers))
725}
726
727#[cfg(not(target_family = "wasm"))]
728async fn map_stream_error(value: EventSourceError) -> OpenAIError {
729    match value {
730        EventSourceError::InvalidStatusCode(status_code, response) => {
731            read_response(response).await.expect_err(&format!(
732                "Unreachable because read_response returns err when status_code {status_code} is invalid"
733            ))
734        }
735        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
736    }
737}
738
739/// Request which responds with SSE.
740/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
741#[cfg(not(target_family = "wasm"))]
742pub(crate) async fn stream<O>(
743    mut event_source: EventSource,
744) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
745where
746    O: DeserializeOwned + std::marker::Send + 'static,
747{
748    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
749
750    tokio::spawn(async move {
751        while let Some(ev) = event_source.next().await {
752            match ev {
753                Err(e) => {
754                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
755                    // https://github.com/64bit/async-openai/issues/456
756                    match &e {
757                        EventSourceError::StreamEnded => {
758                            break;
759                        }
760                        _ => {
761                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
762                                // rx dropped
763                                break;
764                            }
765                        }
766                    }
767                }
768                Ok(event) => match event {
769                    Event::Message(message) => {
770                        if message.data == "[DONE]" {
771                            break;
772                        }
773
774                        if message.event == "keepalive" {
775                            continue;
776                        }
777
778                        let response = match serde_json::from_str::<O>(&message.data) {
779                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
780                            Ok(output) => Ok(output),
781                        };
782
783                        if let Err(_e) = tx.send(response) {
784                            // rx dropped
785                            break;
786                        }
787                    }
788                    Event::Open => continue,
789                },
790            }
791        }
792
793        event_source.close();
794    });
795
796    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
797}
798
799#[cfg(not(target_family = "wasm"))]
800pub(crate) async fn stream_mapped_raw_events<O>(
801    mut event_source: EventSource,
802    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
803) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
804where
805    O: DeserializeOwned + std::marker::Send + 'static,
806{
807    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
808
809    tokio::spawn(async move {
810        while let Some(ev) = event_source.next().await {
811            match ev {
812                Err(e) => {
813                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
814                    // https://github.com/64bit/async-openai/issues/456
815                    match &e {
816                        EventSourceError::StreamEnded => {
817                            break;
818                        }
819                        _ => {
820                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
821                                // rx dropped
822                                break;
823                            }
824                        }
825                    }
826                }
827                Ok(event) => match event {
828                    Event::Message(message) => {
829                        let mut done = false;
830
831                        if message.data == "[DONE]" {
832                            done = true;
833                        }
834
835                        if message.event == "keepalive" {
836                            continue;
837                        }
838
839                        let response = event_mapper(message);
840
841                        if let Err(_e) = tx.send(response) {
842                            // rx dropped
843                            break;
844                        }
845
846                        if done {
847                            break;
848                        }
849                    }
850                    Event::Open => continue,
851                },
852            }
853        }
854
855        event_source.close();
856    });
857
858    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
859}