Skip to main content

rig/http_client/
mod.rs

1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::Body;
6pub mod multipart;
7pub mod retry;
8pub mod sse;
9use crate::wasm_compat::*;
10pub use multipart::MultipartForm;
11pub use reqwest::Client as ReqwestClient;
12use std::pin::Pin;
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16    #[error("Http error: {0}")]
17    Protocol(#[from] http::Error),
18    #[error("Invalid status code: {0}")]
19    InvalidStatusCode(StatusCode),
20    #[error("Invalid status code {0} with message: {1}")]
21    InvalidStatusCodeWithMessage(StatusCode, String),
22    #[error("Header value outside of legal range: {0}")]
23    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
24    #[error("Request in error state, cannot access headers")]
25    NoHeaders,
26    #[error("Stream ended")]
27    StreamEnded,
28    #[error("Invalid content type was returned: {0:?}")]
29    InvalidContentType(HeaderValue),
30    #[cfg(not(target_family = "wasm"))]
31    #[error("Http client error: {0}")]
32    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[cfg(target_family = "wasm")]
35    #[error("Http client error: {0}")]
36    Instance(#[from] Box<dyn std::error::Error + 'static>),
37}
38
39pub type Result<T> = std::result::Result<T, Error>;
40
41#[cfg(not(target_family = "wasm"))]
42pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
43    Error::Instance(error.into())
44}
45
46#[cfg(target_family = "wasm")]
47fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
48    Error::Instance(error.into())
49}
50
51pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
52pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
53
54pub type StreamingResponse = Response<BoxedStream>;
55
56#[derive(Debug, Clone, Copy)]
57pub struct NoBody;
58
59impl From<NoBody> for Bytes {
60    fn from(_: NoBody) -> Self {
61        Bytes::new()
62    }
63}
64
65impl From<NoBody> for Body {
66    fn from(_: NoBody) -> Self {
67        reqwest::Body::default()
68    }
69}
70
71pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
72    let text = response.into_body().await?;
73    Ok(String::from(String::from_utf8_lossy(&text)))
74}
75
76pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
77    Ok((
78        http::header::AUTHORIZATION,
79        HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
80    ))
81}
82
83pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
84    let (k, v) = make_auth_header(key)?;
85
86    headers.insert(k, v);
87
88    Ok(())
89}
90
91pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
92    bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
93
94    Ok(req)
95}
96
97/// A helper trait to make generic requests (both regular and SSE) possible.
98pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
99    /// Send a HTTP request, get a response back (as bytes). Response must be able to be turned back into Bytes.
100    fn send<T, U>(
101        &self,
102        req: Request<T>,
103    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
104    where
105        T: Into<Bytes>,
106        T: WasmCompatSend,
107        U: From<Bytes>,
108        U: WasmCompatSend + 'static;
109
110    /// Send a HTTP request with a multipart body, get a response back (as bytes). Response must be able to be turned back into Bytes (although usually for the response, you will probably want to specify Bytes anyway).
111    fn send_multipart<U>(
112        &self,
113        req: Request<MultipartForm>,
114    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
115    where
116        U: From<Bytes>,
117        U: WasmCompatSend + 'static;
118
119    /// Send a HTTP request, get a streamed response back (as a stream of [`bytes::Bytes`].)
120    fn send_streaming<T>(
121        &self,
122        req: Request<T>,
123    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
124    where
125        T: Into<Bytes>;
126}
127
128impl HttpClientExt for reqwest::Client {
129    fn send<T, U>(
130        &self,
131        req: Request<T>,
132    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
133    where
134        T: Into<Bytes>,
135        U: From<Bytes> + WasmCompatSend,
136    {
137        let (parts, body) = req.into_parts();
138        let req = self
139            .request(parts.method, parts.uri.to_string())
140            .headers(parts.headers)
141            .body(body.into());
142
143        async move {
144            let response = req.send().await.map_err(instance_error)?;
145            if !response.status().is_success() {
146                return Err(Error::InvalidStatusCodeWithMessage(
147                    response.status(),
148                    response.text().await.unwrap(),
149                ));
150            }
151
152            let mut res = Response::builder().status(response.status());
153
154            if let Some(hs) = res.headers_mut() {
155                *hs = response.headers().clone();
156            }
157
158            let body: LazyBody<U> = Box::pin(async {
159                let bytes = response
160                    .bytes()
161                    .await
162                    .map_err(|e| Error::Instance(e.into()))?;
163
164                let body = U::from(bytes);
165                Ok(body)
166            });
167
168            res.body(body).map_err(Error::Protocol)
169        }
170    }
171
172    fn send_multipart<U>(
173        &self,
174        req: Request<MultipartForm>,
175    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
176    where
177        U: From<Bytes>,
178        U: WasmCompatSend + 'static,
179    {
180        let (parts, body) = req.into_parts();
181        let body = reqwest::multipart::Form::from(body);
182
183        let req = self
184            .request(parts.method, parts.uri.to_string())
185            .headers(parts.headers)
186            .multipart(body);
187
188        async move {
189            let response = req.send().await.map_err(instance_error)?;
190            if !response.status().is_success() {
191                return Err(Error::InvalidStatusCodeWithMessage(
192                    response.status(),
193                    response.text().await.unwrap(),
194                ));
195            }
196
197            let mut res = Response::builder().status(response.status());
198
199            if let Some(hs) = res.headers_mut() {
200                *hs = response.headers().clone();
201            }
202
203            let body: LazyBody<U> = Box::pin(async {
204                let bytes = response
205                    .bytes()
206                    .await
207                    .map_err(|e| Error::Instance(e.into()))?;
208
209                let body = U::from(bytes);
210                Ok(body)
211            });
212
213            res.body(body).map_err(Error::Protocol)
214        }
215    }
216
217    fn send_streaming<T>(
218        &self,
219        req: Request<T>,
220    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
221    where
222        T: Into<Bytes>,
223    {
224        let (parts, body) = req.into_parts();
225
226        let req = self
227            .request(parts.method, parts.uri.to_string())
228            .headers(parts.headers)
229            .body(body.into())
230            .build()
231            .map_err(|x| Error::Instance(x.into()))
232            .unwrap();
233
234        let client = self.clone();
235
236        async move {
237            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
238            if !response.status().is_success() {
239                return Err(Error::InvalidStatusCodeWithMessage(
240                    response.status(),
241                    response.text().await.unwrap(),
242                ));
243            }
244
245            #[cfg(not(target_family = "wasm"))]
246            let mut res = Response::builder()
247                .status(response.status())
248                .version(response.version());
249
250            #[cfg(target_family = "wasm")]
251            let mut res = Response::builder().status(response.status());
252
253            if let Some(hs) = res.headers_mut() {
254                *hs = response.headers().clone();
255            }
256
257            use futures::StreamExt;
258
259            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
260                Box::pin(
261                    response
262                        .bytes_stream()
263                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
264                );
265
266            res.body(mapped_stream).map_err(Error::Protocol)
267        }
268    }
269}
270
/// [`HttpClientExt`] backed by a [`reqwest_middleware::ClientWithMiddleware`].
///
/// A non-success status is reported as [`Error::InvalidStatusCodeWithMessage`] carrying the
/// response text. If that text cannot be read, it falls back to [`Error::InvalidStatusCode`].
#[cfg(feature = "reqwest-middleware")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-middleware")))]
impl HttpClientExt for reqwest_middleware::ClientWithMiddleware {
    fn send<T, U>(
        &self,
        req: Request<T>,
    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        T: Into<Bytes>,
        U: From<Bytes> + WasmCompatSend,
    {
        let (parts, body) = req.into_parts();
        let req = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .body(body.into());

        async move {
            let response = req.send().await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail (e.g. a dropped connection);
                // keep the status code instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            // The body is only downloaded once the caller awaits it.
            let body: LazyBody<U> = Box::pin(async move {
                let bytes = response.bytes().await.map_err(instance_error)?;
                Ok(U::from(bytes))
            });

            res.body(body).map_err(Error::Protocol)
        }
    }

    fn send_multipart<U>(
        &self,
        req: Request<MultipartForm>,
    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        U: From<Bytes>,
        U: WasmCompatSend + 'static,
    {
        let (parts, body) = req.into_parts();
        let body = reqwest::multipart::Form::from(body);

        let req = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .multipart(body);

        async move {
            let response = req.send().await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail; keep the status code instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            // The body is only downloaded once the caller awaits it.
            let body: LazyBody<U> = Box::pin(async move {
                let bytes = response.bytes().await.map_err(instance_error)?;
                Ok(U::from(bytes))
            });

            res.body(body).map_err(Error::Protocol)
        }
    }

    fn send_streaming<T>(
        &self,
        req: Request<T>,
    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
    where
        T: Into<Bytes>,
    {
        let (parts, body) = req.into_parts();

        // A request that cannot be built (e.g. a relative or otherwise unparseable URI) is
        // reported through the returned future instead of panicking at call time.
        let req = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .body(body.into())
            .build()
            .map_err(instance_error);

        let client = self.clone();

        async move {
            let response: reqwest::Response = client.execute(req?).await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail; keep the status code instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            // The wasm reqwest `Response` does not expose the HTTP version.
            #[cfg(not(target_family = "wasm"))]
            let mut res = Response::builder()
                .status(status)
                .version(response.version());

            #[cfg(target_family = "wasm")]
            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            use futures::StreamExt;

            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
                Box::pin(
                    response
                        .bytes_stream()
                        .map(|chunk| chunk.map_err(instance_error)),
                );

            res.body(mapped_stream).map_err(Error::Protocol)
        }
    }
}
415
/// Test utilities for mocking HTTP clients.
#[cfg(test)]
pub(crate) mod mock {
    use super::*;
    use bytes::Bytes;

    /// A mock HTTP client whose `send_streaming` replays a fixed SSE payload.
    ///
    /// `send` and `send_multipart` always fail with [`Error::InvalidStatusCode`] carrying
    /// `NOT_IMPLEMENTED`.
    #[derive(Clone)]
    pub struct MockStreamingClient {
        /// Bytes yielded as the single chunk of every streaming response.
        pub sse_bytes: Bytes,
    }

    /// An immediately-ready future that fails with a `NOT_IMPLEMENTED` status error.
    fn not_implemented<R>() -> std::future::Ready<Result<R>> {
        std::future::ready(Err(Error::InvalidStatusCode(
            http::StatusCode::NOT_IMPLEMENTED,
        )))
    }

    impl HttpClientExt for MockStreamingClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes>,
            T: WasmCompatSend,
            U: From<Bytes>,
            U: WasmCompatSend + 'static,
        {
            not_implemented()
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes>,
            U: WasmCompatSend + 'static,
        {
            not_implemented()
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes>,
        {
            let payload = self.sse_bytes.clone();
            async move {
                // Exactly one chunk: the whole canned payload.
                let chunks = std::iter::once(Ok::<Bytes, Error>(payload));
                let body: sse::BoxedStream = Box::pin(futures::stream::iter(chunks));

                let builder = Response::builder()
                    .status(http::StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream");
                builder.body(body).map_err(Error::Protocol)
            }
        }
    }
}
475                    .map_err(Error::Protocol)
476            }
477        }
478    }
479}