rig/http_client/
mod.rs

1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::Body;
6
7pub mod multipart;
8pub mod retry;
9pub mod sse;
10
11pub use multipart::MultipartForm;
12
13use std::pin::Pin;
14
15use crate::wasm_compat::*;
16
17#[derive(Debug, thiserror::Error)]
18pub enum Error {
19    #[error("Http error: {0}")]
20    Protocol(#[from] http::Error),
21    #[error("Invalid status code: {0}")]
22    InvalidStatusCode(StatusCode),
23    #[error("Invalid status code {0} with message: {1}")]
24    InvalidStatusCodeWithMessage(StatusCode, String),
25    #[error("Header value outside of legal range: {0}")]
26    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
27    #[error("Request in error state, cannot access headers")]
28    NoHeaders,
29    #[error("Stream ended")]
30    StreamEnded,
31    #[error("Invalid content type was returned: {0:?}")]
32    InvalidContentType(HeaderValue),
33    #[cfg(not(target_family = "wasm"))]
34    #[error("Http client error: {0}")]
35    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
36
37    #[cfg(target_family = "wasm")]
38    #[error("Http client error: {0}")]
39    Instance(#[from] Box<dyn std::error::Error + 'static>),
40}
41
42pub type Result<T> = std::result::Result<T, Error>;
43
44#[cfg(not(target_family = "wasm"))]
45pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
46    Error::Instance(error.into())
47}
48
/// Wraps an arbitrary client error in [`Error::Instance`] (wasm targets).
///
/// Visibility is `pub(crate)` so that it matches the non-wasm definition;
/// otherwise a caller elsewhere in the crate would compile on native targets
/// and fail to resolve this function on wasm.
#[cfg(target_family = "wasm")]
pub(crate) fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
    Error::Instance(error.into())
}
53
54pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
55pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
56
57pub type StreamingResponse = Response<BoxedStream>;
58
59#[derive(Debug, Clone, Copy)]
60pub struct NoBody;
61
62impl From<NoBody> for Bytes {
63    fn from(_: NoBody) -> Self {
64        Bytes::new()
65    }
66}
67
68impl From<NoBody> for Body {
69    fn from(_: NoBody) -> Self {
70        reqwest::Body::default()
71    }
72}
73
74pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
75    let text = response.into_body().await?;
76    Ok(String::from(String::from_utf8_lossy(&text)))
77}
78
79pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
80    Ok((
81        http::header::AUTHORIZATION,
82        HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
83    ))
84}
85
86pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
87    let (k, v) = make_auth_header(key)?;
88
89    headers.insert(k, v);
90
91    Ok(())
92}
93
94pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
95    bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
96
97    Ok(req)
98}
99
100/// A helper trait to make generic requests (both regular and SSE) possible.
101pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
102    /// Send a HTTP request, get a response back (as bytes). Response must be able to be turned back into Bytes.
103    fn send<T, U>(
104        &self,
105        req: Request<T>,
106    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
107    where
108        T: Into<Bytes>,
109        T: WasmCompatSend,
110        U: From<Bytes>,
111        U: WasmCompatSend + 'static;
112
113    /// Send a HTTP request with a multipart body, get a response back (as bytes). Response must be able to be turned back into Bytes (although usually for the response, you will probably want to specify Bytes anyway).
114    fn send_multipart<U>(
115        &self,
116        req: Request<MultipartForm>,
117    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
118    where
119        U: From<Bytes>,
120        U: WasmCompatSend + 'static;
121
122    /// Send a HTTP request, get a streamed response back (as a stream of [`bytes::Bytes`].)
123    fn send_streaming<T>(
124        &self,
125        req: Request<T>,
126    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
127    where
128        T: Into<Bytes>;
129}
130
131impl HttpClientExt for reqwest::Client {
132    fn send<T, U>(
133        &self,
134        req: Request<T>,
135    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
136    where
137        T: Into<Bytes>,
138        U: From<Bytes> + WasmCompatSend,
139    {
140        let (parts, body) = req.into_parts();
141        let req = self
142            .request(parts.method, parts.uri.to_string())
143            .headers(parts.headers)
144            .body(body.into());
145
146        async move {
147            let response = req.send().await.map_err(instance_error)?;
148            if !response.status().is_success() {
149                return Err(Error::InvalidStatusCodeWithMessage(
150                    response.status(),
151                    response.text().await.unwrap(),
152                ));
153            }
154
155            let mut res = Response::builder().status(response.status());
156
157            if let Some(hs) = res.headers_mut() {
158                *hs = response.headers().clone();
159            }
160
161            let body: LazyBody<U> = Box::pin(async {
162                let bytes = response
163                    .bytes()
164                    .await
165                    .map_err(|e| Error::Instance(e.into()))?;
166
167                let body = U::from(bytes);
168                Ok(body)
169            });
170
171            res.body(body).map_err(Error::Protocol)
172        }
173    }
174
175    fn send_multipart<U>(
176        &self,
177        req: Request<MultipartForm>,
178    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
179    where
180        U: From<Bytes>,
181        U: WasmCompatSend + 'static,
182    {
183        let (parts, body) = req.into_parts();
184        let body = reqwest::multipart::Form::from(body);
185
186        let req = self
187            .request(parts.method, parts.uri.to_string())
188            .headers(parts.headers)
189            .multipart(body);
190
191        async move {
192            let response = req.send().await.map_err(instance_error)?;
193            if !response.status().is_success() {
194                return Err(Error::InvalidStatusCodeWithMessage(
195                    response.status(),
196                    response.text().await.unwrap(),
197                ));
198            }
199
200            let mut res = Response::builder().status(response.status());
201
202            if let Some(hs) = res.headers_mut() {
203                *hs = response.headers().clone();
204            }
205
206            let body: LazyBody<U> = Box::pin(async {
207                let bytes = response
208                    .bytes()
209                    .await
210                    .map_err(|e| Error::Instance(e.into()))?;
211
212                let body = U::from(bytes);
213                Ok(body)
214            });
215
216            res.body(body).map_err(Error::Protocol)
217        }
218    }
219
220    fn send_streaming<T>(
221        &self,
222        req: Request<T>,
223    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
224    where
225        T: Into<Bytes>,
226    {
227        let (parts, body) = req.into_parts();
228
229        let req = self
230            .request(parts.method, parts.uri.to_string())
231            .headers(parts.headers)
232            .body(body.into())
233            .build()
234            .map_err(|x| Error::Instance(x.into()))
235            .unwrap();
236
237        let client = self.clone();
238
239        async move {
240            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
241            if !response.status().is_success() {
242                return Err(Error::InvalidStatusCodeWithMessage(
243                    response.status(),
244                    response.text().await.unwrap(),
245                ));
246            }
247
248            #[cfg(not(target_family = "wasm"))]
249            let mut res = Response::builder()
250                .status(response.status())
251                .version(response.version());
252
253            #[cfg(target_family = "wasm")]
254            let mut res = Response::builder().status(response.status());
255
256            if let Some(hs) = res.headers_mut() {
257                *hs = response.headers().clone();
258            }
259
260            use futures::StreamExt;
261
262            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
263                Box::pin(
264                    response
265                        .bytes_stream()
266                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
267                );
268
269            res.body(mapped_stream).map_err(Error::Protocol)
270        }
271    }
272}
273
/// [`HttpClientExt`] implementation backed by
/// [`reqwest_middleware::ClientWithMiddleware`].
#[cfg(feature = "reqwest-middleware")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-middleware")))]
impl HttpClientExt for reqwest_middleware::ClientWithMiddleware {
    /// Sends `req` and returns the response with a lazily-read body.
    ///
    /// # Errors
    /// * [`Error::Instance`] when sending the request or reading the body fails.
    /// * [`Error::InvalidStatusCodeWithMessage`] for a non-success status; the
    ///   message is the response body text.
    /// * [`Error::InvalidStatusCode`] for a non-success status whose body could
    ///   not be read.
    /// * [`Error::Protocol`] when the response cannot be assembled.
    fn send<T, U>(
        &self,
        req: Request<T>,
    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        T: Into<Bytes>,
        U: From<Bytes> + WasmCompatSend,
    {
        let (parts, body) = req.into_parts();
        let req = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .body(body.into());

        async move {
            let response = req.send().await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail; report the status
                // on its own in that case instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            // The body is only read once the caller awaits the returned future.
            let body: LazyBody<U> = Box::pin(async move {
                let bytes = response.bytes().await.map_err(instance_error)?;
                Ok(U::from(bytes))
            });

            res.body(body).map_err(Error::Protocol)
        }
    }

    /// Sends `req` with a multipart form body and returns the response with a
    /// lazily-read body.
    ///
    /// # Errors
    /// Same error cases as [`HttpClientExt::send`] on this type.
    fn send_multipart<U>(
        &self,
        req: Request<MultipartForm>,
    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        U: From<Bytes>,
        U: WasmCompatSend + 'static,
    {
        let (parts, body) = req.into_parts();
        let body = reqwest::multipart::Form::from(body);

        let req = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .multipart(body);

        async move {
            let response = req.send().await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail; report the status
                // on its own in that case instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            // The body is only read once the caller awaits the returned future.
            let body: LazyBody<U> = Box::pin(async move {
                let bytes = response.bytes().await.map_err(instance_error)?;
                Ok(U::from(bytes))
            });

            res.body(body).map_err(Error::Protocol)
        }
    }

    /// Sends `req` and returns the response with its body exposed as a stream
    /// of byte chunks.
    ///
    /// # Errors
    /// * [`Error::Instance`] when the request cannot be built or executed; a
    ///   build failure is reported through the returned future rather than by
    ///   panicking. Each stream item may also carry this error.
    /// * [`Error::InvalidStatusCodeWithMessage`] / [`Error::InvalidStatusCode`]
    ///   for a non-success status (the latter when the body could not be read).
    /// * [`Error::Protocol`] when the response cannot be assembled.
    fn send_streaming<T>(
        &self,
        req: Request<T>,
    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
    where
        T: Into<Bytes>,
    {
        let (parts, body) = req.into_parts();

        // Keep the build outcome as a `Result` so that a failure surfaces as
        // an `Err` from the future instead of a panic at call time.
        let built = self
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers)
            .body(body.into())
            .build()
            .map_err(instance_error);

        let client = self.clone();

        async move {
            let req = built?;
            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
            let status = response.status();
            if !status.is_success() {
                // Reading the error body can itself fail; report the status
                // on its own in that case instead of panicking.
                return Err(match response.text().await {
                    Ok(message) => Error::InvalidStatusCodeWithMessage(status, message),
                    Err(_) => Error::InvalidStatusCode(status),
                });
            }

            #[cfg(not(target_family = "wasm"))]
            let mut res = Response::builder()
                .status(status)
                .version(response.version());

            #[cfg(target_family = "wasm")]
            let mut res = Response::builder().status(status);

            if let Some(hs) = res.headers_mut() {
                *hs = response.headers().clone();
            }

            use futures::StreamExt;

            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
                Box::pin(
                    response
                        .bytes_stream()
                        .map(|chunk| chunk.map_err(instance_error)),
                );

            res.body(mapped_stream).map_err(Error::Protocol)
        }
    }
}
417}