// rig/http_client/mod.rs

1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::Body;
6
7pub mod multipart;
8pub mod retry;
9pub mod sse;
10
11pub use multipart::MultipartForm;
12
13use std::pin::Pin;
14
15use crate::wasm_compat::*;
16
17pub use reqwest::Client as ReqwestClient;
18
19#[derive(Debug, thiserror::Error)]
20pub enum Error {
21    #[error("Http error: {0}")]
22    Protocol(#[from] http::Error),
23    #[error("Invalid status code: {0}")]
24    InvalidStatusCode(StatusCode),
25    #[error("Invalid status code {0} with message: {1}")]
26    InvalidStatusCodeWithMessage(StatusCode, String),
27    #[error("Header value outside of legal range: {0}")]
28    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
29    #[error("Request in error state, cannot access headers")]
30    NoHeaders,
31    #[error("Stream ended")]
32    StreamEnded,
33    #[error("Invalid content type was returned: {0:?}")]
34    InvalidContentType(HeaderValue),
35    #[cfg(not(target_family = "wasm"))]
36    #[error("Http client error: {0}")]
37    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
38
39    #[cfg(target_family = "wasm")]
40    #[error("Http client error: {0}")]
41    Instance(#[from] Box<dyn std::error::Error + 'static>),
42}
43
44pub type Result<T> = std::result::Result<T, Error>;
45
46#[cfg(not(target_family = "wasm"))]
47pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
48    Error::Instance(error.into())
49}
50
51#[cfg(target_family = "wasm")]
52fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
53    Error::Instance(error.into())
54}
55
56pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
57pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
58
59pub type StreamingResponse = Response<BoxedStream>;
60
61#[derive(Debug, Clone, Copy)]
62pub struct NoBody;
63
64impl From<NoBody> for Bytes {
65    fn from(_: NoBody) -> Self {
66        Bytes::new()
67    }
68}
69
70impl From<NoBody> for Body {
71    fn from(_: NoBody) -> Self {
72        reqwest::Body::default()
73    }
74}
75
76pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
77    let text = response.into_body().await?;
78    Ok(String::from(String::from_utf8_lossy(&text)))
79}
80
81pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
82    Ok((
83        http::header::AUTHORIZATION,
84        HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
85    ))
86}
87
88pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
89    let (k, v) = make_auth_header(key)?;
90
91    headers.insert(k, v);
92
93    Ok(())
94}
95
96pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
97    bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
98
99    Ok(req)
100}
101
102/// A helper trait to make generic requests (both regular and SSE) possible.
103pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
104    /// Send a HTTP request, get a response back (as bytes). Response must be able to be turned back into Bytes.
105    fn send<T, U>(
106        &self,
107        req: Request<T>,
108    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
109    where
110        T: Into<Bytes>,
111        T: WasmCompatSend,
112        U: From<Bytes>,
113        U: WasmCompatSend + 'static;
114
115    /// Send a HTTP request with a multipart body, get a response back (as bytes). Response must be able to be turned back into Bytes (although usually for the response, you will probably want to specify Bytes anyway).
116    fn send_multipart<U>(
117        &self,
118        req: Request<MultipartForm>,
119    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
120    where
121        U: From<Bytes>,
122        U: WasmCompatSend + 'static;
123
124    /// Send a HTTP request, get a streamed response back (as a stream of [`bytes::Bytes`].)
125    fn send_streaming<T>(
126        &self,
127        req: Request<T>,
128    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
129    where
130        T: Into<Bytes>;
131}
132
133impl HttpClientExt for reqwest::Client {
134    fn send<T, U>(
135        &self,
136        req: Request<T>,
137    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
138    where
139        T: Into<Bytes>,
140        U: From<Bytes> + WasmCompatSend,
141    {
142        let (parts, body) = req.into_parts();
143        let req = self
144            .request(parts.method, parts.uri.to_string())
145            .headers(parts.headers)
146            .body(body.into());
147
148        async move {
149            let response = req.send().await.map_err(instance_error)?;
150            if !response.status().is_success() {
151                return Err(Error::InvalidStatusCodeWithMessage(
152                    response.status(),
153                    response.text().await.unwrap(),
154                ));
155            }
156
157            let mut res = Response::builder().status(response.status());
158
159            if let Some(hs) = res.headers_mut() {
160                *hs = response.headers().clone();
161            }
162
163            let body: LazyBody<U> = Box::pin(async {
164                let bytes = response
165                    .bytes()
166                    .await
167                    .map_err(|e| Error::Instance(e.into()))?;
168
169                let body = U::from(bytes);
170                Ok(body)
171            });
172
173            res.body(body).map_err(Error::Protocol)
174        }
175    }
176
177    fn send_multipart<U>(
178        &self,
179        req: Request<MultipartForm>,
180    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
181    where
182        U: From<Bytes>,
183        U: WasmCompatSend + 'static,
184    {
185        let (parts, body) = req.into_parts();
186        let body = reqwest::multipart::Form::from(body);
187
188        let req = self
189            .request(parts.method, parts.uri.to_string())
190            .headers(parts.headers)
191            .multipart(body);
192
193        async move {
194            let response = req.send().await.map_err(instance_error)?;
195            if !response.status().is_success() {
196                return Err(Error::InvalidStatusCodeWithMessage(
197                    response.status(),
198                    response.text().await.unwrap(),
199                ));
200            }
201
202            let mut res = Response::builder().status(response.status());
203
204            if let Some(hs) = res.headers_mut() {
205                *hs = response.headers().clone();
206            }
207
208            let body: LazyBody<U> = Box::pin(async {
209                let bytes = response
210                    .bytes()
211                    .await
212                    .map_err(|e| Error::Instance(e.into()))?;
213
214                let body = U::from(bytes);
215                Ok(body)
216            });
217
218            res.body(body).map_err(Error::Protocol)
219        }
220    }
221
222    fn send_streaming<T>(
223        &self,
224        req: Request<T>,
225    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
226    where
227        T: Into<Bytes>,
228    {
229        let (parts, body) = req.into_parts();
230
231        let req = self
232            .request(parts.method, parts.uri.to_string())
233            .headers(parts.headers)
234            .body(body.into())
235            .build()
236            .map_err(|x| Error::Instance(x.into()))
237            .unwrap();
238
239        let client = self.clone();
240
241        async move {
242            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
243            if !response.status().is_success() {
244                return Err(Error::InvalidStatusCodeWithMessage(
245                    response.status(),
246                    response.text().await.unwrap(),
247                ));
248            }
249
250            #[cfg(not(target_family = "wasm"))]
251            let mut res = Response::builder()
252                .status(response.status())
253                .version(response.version());
254
255            #[cfg(target_family = "wasm")]
256            let mut res = Response::builder().status(response.status());
257
258            if let Some(hs) = res.headers_mut() {
259                *hs = response.headers().clone();
260            }
261
262            use futures::StreamExt;
263
264            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
265                Box::pin(
266                    response
267                        .bytes_stream()
268                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
269                );
270
271            res.body(mapped_stream).map_err(Error::Protocol)
272        }
273    }
274}
275
276#[cfg(feature = "reqwest-middleware")]
277#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-middleware")))]
278impl HttpClientExt for reqwest_middleware::ClientWithMiddleware {
279    fn send<T, U>(
280        &self,
281        req: Request<T>,
282    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
283    where
284        T: Into<Bytes>,
285        U: From<Bytes> + WasmCompatSend,
286    {
287        let (parts, body) = req.into_parts();
288        let req = self
289            .request(parts.method, parts.uri.to_string())
290            .headers(parts.headers)
291            .body(body.into());
292
293        async move {
294            let response = req.send().await.map_err(instance_error)?;
295            if !response.status().is_success() {
296                return Err(Error::InvalidStatusCodeWithMessage(
297                    response.status(),
298                    response.text().await.unwrap(),
299                ));
300            }
301
302            let mut res = Response::builder().status(response.status());
303
304            if let Some(hs) = res.headers_mut() {
305                *hs = response.headers().clone();
306            }
307
308            let body: LazyBody<U> = Box::pin(async {
309                let bytes = response
310                    .bytes()
311                    .await
312                    .map_err(|e| Error::Instance(e.into()))?;
313
314                let body = U::from(bytes);
315                Ok(body)
316            });
317
318            res.body(body).map_err(Error::Protocol)
319        }
320    }
321
322    fn send_multipart<U>(
323        &self,
324        req: Request<MultipartForm>,
325    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
326    where
327        U: From<Bytes>,
328        U: WasmCompatSend + 'static,
329    {
330        let (parts, body) = req.into_parts();
331        let body = reqwest::multipart::Form::from(body);
332
333        let req = self
334            .request(parts.method, parts.uri.to_string())
335            .headers(parts.headers)
336            .multipart(body);
337
338        async move {
339            let response = req.send().await.map_err(instance_error)?;
340            if !response.status().is_success() {
341                return Err(Error::InvalidStatusCodeWithMessage(
342                    response.status(),
343                    response.text().await.unwrap(),
344                ));
345            }
346
347            let mut res = Response::builder().status(response.status());
348
349            if let Some(hs) = res.headers_mut() {
350                *hs = response.headers().clone();
351            }
352
353            let body: LazyBody<U> = Box::pin(async {
354                let bytes = response
355                    .bytes()
356                    .await
357                    .map_err(|e| Error::Instance(e.into()))?;
358
359                let body = U::from(bytes);
360                Ok(body)
361            });
362
363            res.body(body).map_err(Error::Protocol)
364        }
365    }
366
367    fn send_streaming<T>(
368        &self,
369        req: Request<T>,
370    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
371    where
372        T: Into<Bytes>,
373    {
374        let (parts, body) = req.into_parts();
375
376        let req = self
377            .request(parts.method, parts.uri.to_string())
378            .headers(parts.headers)
379            .body(body.into())
380            .build()
381            .map_err(|x| Error::Instance(x.into()))
382            .unwrap();
383
384        let client = self.clone();
385
386        async move {
387            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
388            if !response.status().is_success() {
389                return Err(Error::InvalidStatusCodeWithMessage(
390                    response.status(),
391                    response.text().await.unwrap(),
392                ));
393            }
394
395            #[cfg(not(target_family = "wasm"))]
396            let mut res = Response::builder()
397                .status(response.status())
398                .version(response.version());
399
400            #[cfg(target_family = "wasm")]
401            let mut res = Response::builder().status(response.status());
402
403            if let Some(hs) = res.headers_mut() {
404                *hs = response.headers().clone();
405            }
406
407            use futures::StreamExt;
408
409            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
410                Box::pin(
411                    response
412                        .bytes_stream()
413                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
414                );
415
416            res.body(mapped_stream).map_err(Error::Protocol)
417        }
418    }
419}