// rig/http_client/mod.rs
1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::Body;
6pub mod multipart;
7pub mod retry;
8pub mod sse;
9use crate::wasm_compat::*;
10pub use multipart::MultipartForm;
11pub use reqwest::Client as ReqwestClient;
12use std::pin::Pin;
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16    #[error("Http error: {0}")]
17    Protocol(#[from] http::Error),
18    #[error("Invalid status code: {0}")]
19    InvalidStatusCode(StatusCode),
20    #[error("Invalid status code {0} with message: {1}")]
21    InvalidStatusCodeWithMessage(StatusCode, String),
22    #[error("Header value outside of legal range: {0}")]
23    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
24    #[error("Request in error state, cannot access headers")]
25    NoHeaders,
26    #[error("Stream ended")]
27    StreamEnded,
28    #[error("Invalid content type was returned: {0:?}")]
29    InvalidContentType(HeaderValue),
30    #[cfg(not(target_family = "wasm"))]
31    #[error("Http client error: {0}")]
32    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[cfg(target_family = "wasm")]
35    #[error("Http client error: {0}")]
36    Instance(#[from] Box<dyn std::error::Error + 'static>),
37}
38
39pub type Result<T> = std::result::Result<T, Error>;
40
41#[cfg(not(target_family = "wasm"))]
42pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
43    Error::Instance(error.into())
44}
45
46#[cfg(target_family = "wasm")]
47fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
48    Error::Instance(error.into())
49}
50
51async fn non_success_status_error(response: reqwest::Response) -> Error {
52    let status = response.status();
53    let message = response
54        .text()
55        .await
56        .unwrap_or_else(|error| format!("failed to read error response body: {error}"));
57    Error::InvalidStatusCodeWithMessage(status, message)
58}
59
60pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
61pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
62
63pub type StreamingResponse = Response<BoxedStream>;
64
65#[derive(Debug, Clone, Copy)]
66pub struct NoBody;
67
68impl From<NoBody> for Bytes {
69    fn from(_: NoBody) -> Self {
70        Bytes::new()
71    }
72}
73
74impl From<NoBody> for Body {
75    fn from(_: NoBody) -> Self {
76        reqwest::Body::default()
77    }
78}
79
80pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
81    let text = response.into_body().await?;
82    Ok(String::from(String::from_utf8_lossy(&text)))
83}
84
85pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
86    Ok((
87        http::header::AUTHORIZATION,
88        HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
89    ))
90}
91
92pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
93    let (k, v) = make_auth_header(key)?;
94
95    headers.insert(k, v);
96
97    Ok(())
98}
99
100pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
101    bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
102
103    Ok(req)
104}
105
106/// A helper trait to make generic requests (both regular and SSE) possible.
107pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
108    /// Send a HTTP request, get a response back (as bytes). Response must be able to be turned back into Bytes.
109    fn send<T, U>(
110        &self,
111        req: Request<T>,
112    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
113    where
114        T: Into<Bytes>,
115        T: WasmCompatSend,
116        U: From<Bytes>,
117        U: WasmCompatSend + 'static;
118
119    /// Send a HTTP request with a multipart body, get a response back (as bytes). Response must be able to be turned back into Bytes (although usually for the response, you will probably want to specify Bytes anyway).
120    fn send_multipart<U>(
121        &self,
122        req: Request<MultipartForm>,
123    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
124    where
125        U: From<Bytes>,
126        U: WasmCompatSend + 'static;
127
128    /// Send a HTTP request, get a streamed response back (as a stream of [`bytes::Bytes`].)
129    fn send_streaming<T>(
130        &self,
131        req: Request<T>,
132    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
133    where
134        T: Into<Bytes> + WasmCompatSend;
135}
136
137impl HttpClientExt for reqwest::Client {
138    fn send<T, U>(
139        &self,
140        req: Request<T>,
141    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
142    where
143        T: Into<Bytes>,
144        U: From<Bytes> + WasmCompatSend,
145    {
146        let (parts, body) = req.into_parts();
147        let req = self
148            .request(parts.method, parts.uri.to_string())
149            .headers(parts.headers)
150            .body(body.into());
151
152        async move {
153            let response = req.send().await.map_err(instance_error)?;
154            if !response.status().is_success() {
155                return Err(non_success_status_error(response).await);
156            }
157
158            let mut res = Response::builder().status(response.status());
159
160            if let Some(hs) = res.headers_mut() {
161                *hs = response.headers().clone();
162            }
163
164            let body: LazyBody<U> = Box::pin(async {
165                let bytes = response
166                    .bytes()
167                    .await
168                    .map_err(|e| Error::Instance(e.into()))?;
169
170                let body = U::from(bytes);
171                Ok(body)
172            });
173
174            res.body(body).map_err(Error::Protocol)
175        }
176    }
177
178    fn send_multipart<U>(
179        &self,
180        req: Request<MultipartForm>,
181    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
182    where
183        U: From<Bytes>,
184        U: WasmCompatSend + 'static,
185    {
186        let (parts, body) = req.into_parts();
187        let body = reqwest::multipart::Form::from(body);
188
189        let req = self
190            .request(parts.method, parts.uri.to_string())
191            .headers(parts.headers)
192            .multipart(body);
193
194        async move {
195            let response = req.send().await.map_err(instance_error)?;
196            if !response.status().is_success() {
197                return Err(non_success_status_error(response).await);
198            }
199
200            let mut res = Response::builder().status(response.status());
201
202            if let Some(hs) = res.headers_mut() {
203                *hs = response.headers().clone();
204            }
205
206            let body: LazyBody<U> = Box::pin(async {
207                let bytes = response
208                    .bytes()
209                    .await
210                    .map_err(|e| Error::Instance(e.into()))?;
211
212                let body = U::from(bytes);
213                Ok(body)
214            });
215
216            res.body(body).map_err(Error::Protocol)
217        }
218    }
219
220    fn send_streaming<T>(
221        &self,
222        req: Request<T>,
223    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
224    where
225        T: Into<Bytes> + WasmCompatSend,
226    {
227        let (parts, body) = req.into_parts();
228
229        let client = self.clone();
230
231        async move {
232            let req = self
233                .request(parts.method, parts.uri.to_string())
234                .headers(parts.headers)
235                .body(body.into())
236                .build()
237                .map_err(|error| Error::Instance(error.into()))?;
238            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
239            if !response.status().is_success() {
240                return Err(non_success_status_error(response).await);
241            }
242
243            #[cfg(not(target_family = "wasm"))]
244            let mut res = Response::builder()
245                .status(response.status())
246                .version(response.version());
247
248            #[cfg(target_family = "wasm")]
249            let mut res = Response::builder().status(response.status());
250
251            if let Some(hs) = res.headers_mut() {
252                *hs = response.headers().clone();
253            }
254
255            use futures::StreamExt;
256
257            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
258                Box::pin(
259                    response
260                        .bytes_stream()
261                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
262                );
263
264            res.body(mapped_stream).map_err(Error::Protocol)
265        }
266    }
267}
268
269#[cfg(feature = "reqwest-middleware")]
270#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-middleware")))]
271impl HttpClientExt for reqwest_middleware::ClientWithMiddleware {
272    fn send<T, U>(
273        &self,
274        req: Request<T>,
275    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
276    where
277        T: Into<Bytes>,
278        U: From<Bytes> + WasmCompatSend,
279    {
280        let (parts, body) = req.into_parts();
281        let req = self
282            .request(parts.method, parts.uri.to_string())
283            .headers(parts.headers)
284            .body(body.into());
285
286        async move {
287            let response = req.send().await.map_err(instance_error)?;
288            if !response.status().is_success() {
289                return Err(non_success_status_error(response).await);
290            }
291
292            let mut res = Response::builder().status(response.status());
293
294            if let Some(hs) = res.headers_mut() {
295                *hs = response.headers().clone();
296            }
297
298            let body: LazyBody<U> = Box::pin(async {
299                let bytes = response
300                    .bytes()
301                    .await
302                    .map_err(|e| Error::Instance(e.into()))?;
303
304                let body = U::from(bytes);
305                Ok(body)
306            });
307
308            res.body(body).map_err(Error::Protocol)
309        }
310    }
311
312    fn send_multipart<U>(
313        &self,
314        req: Request<MultipartForm>,
315    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
316    where
317        U: From<Bytes>,
318        U: WasmCompatSend + 'static,
319    {
320        let (parts, body) = req.into_parts();
321        let body = reqwest::multipart::Form::from(body);
322
323        let req = self
324            .request(parts.method, parts.uri.to_string())
325            .headers(parts.headers)
326            .multipart(body);
327
328        async move {
329            let response = req.send().await.map_err(instance_error)?;
330            if !response.status().is_success() {
331                return Err(non_success_status_error(response).await);
332            }
333
334            let mut res = Response::builder().status(response.status());
335
336            if let Some(hs) = res.headers_mut() {
337                *hs = response.headers().clone();
338            }
339
340            let body: LazyBody<U> = Box::pin(async {
341                let bytes = response
342                    .bytes()
343                    .await
344                    .map_err(|e| Error::Instance(e.into()))?;
345
346                let body = U::from(bytes);
347                Ok(body)
348            });
349
350            res.body(body).map_err(Error::Protocol)
351        }
352    }
353
354    fn send_streaming<T>(
355        &self,
356        req: Request<T>,
357    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
358    where
359        T: Into<Bytes> + WasmCompatSend,
360    {
361        let (parts, body) = req.into_parts();
362
363        let client = self.clone();
364
365        async move {
366            let req = self
367                .request(parts.method, parts.uri.to_string())
368                .headers(parts.headers)
369                .body(body.into())
370                .build()
371                .map_err(|error| Error::Instance(error.into()))?;
372            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
373            if !response.status().is_success() {
374                return Err(non_success_status_error(response).await);
375            }
376
377            #[cfg(not(target_family = "wasm"))]
378            let mut res = Response::builder()
379                .status(response.status())
380                .version(response.version());
381
382            #[cfg(target_family = "wasm")]
383            let mut res = Response::builder().status(response.status());
384
385            if let Some(hs) = res.headers_mut() {
386                *hs = response.headers().clone();
387            }
388
389            use futures::StreamExt;
390
391            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
392                Box::pin(
393                    response
394                        .bytes_stream()
395                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
396                );
397
398            res.body(mapped_stream).map_err(Error::Protocol)
399        }
400    }
401}
402
/// Test utilities for mocking HTTP clients.
#[cfg(test)]
pub(crate) mod mock {
    use super::*;
    use bytes::Bytes;

    /// An [`HttpClientExt`] implementation that never touches the network.
    ///
    /// `send_streaming` answers every request with a `200 OK`
    /// `text/event-stream` response whose body is `sse_bytes`, delivered as a
    /// single chunk. `send` and `send_multipart` always fail with
    /// `NOT_IMPLEMENTED`.
    #[derive(Clone)]
    pub struct MockStreamingClient {
        /// Raw SSE payload returned verbatim by `send_streaming`.
        pub sse_bytes: Bytes,
    }

    /// The error returned by the request methods this mock does not support.
    fn not_implemented() -> Error {
        Error::InvalidStatusCode(http::StatusCode::NOT_IMPLEMENTED)
    }

    impl HttpClientExt for MockStreamingClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes> + WasmCompatSend,
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            std::future::ready(Err(not_implemented()))
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            std::future::ready(Err(not_implemented()))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            let payload = self.sse_bytes.clone();
            async move {
                // Exactly one chunk: the whole configured payload.
                let chunks =
                    futures::stream::iter(std::iter::once(Ok::<Bytes, Error>(payload)));
                let body: sse::BoxedStream = Box::pin(chunks);

                Response::builder()
                    .status(http::StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .body(body)
                    .map_err(Error::Protocol)
            }
        }
    }
}