rig/http_client/mod.rs

1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::{Body, multipart::Form};
6
7pub mod retry;
8pub mod sse;
9
10use std::pin::Pin;
11
12use crate::wasm_compat::*;
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16    #[error("Http error: {0}")]
17    Protocol(#[from] http::Error),
18    #[error("Invalid status code: {0}")]
19    InvalidStatusCode(StatusCode),
20    #[error("Invalid status code {0} with message: {1}")]
21    InvalidStatusCodeWithMessage(StatusCode, String),
22    #[error("Header value outside of legal range: {0}")]
23    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
24    #[error("Request in error state, cannot access headers")]
25    NoHeaders,
26    #[error("Stream ended")]
27    StreamEnded,
28    #[error("Invalid content type was returned: {0:?}")]
29    InvalidContentType(HeaderValue),
30    #[cfg(not(target_family = "wasm"))]
31    #[error("Http client error: {0}")]
32    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[cfg(target_family = "wasm")]
35    #[error("Http client error: {0}")]
36    Instance(#[from] Box<dyn std::error::Error + 'static>),
37}
38
39pub type Result<T> = std::result::Result<T, Error>;
40
41#[cfg(not(target_family = "wasm"))]
42pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
43    Error::Instance(error.into())
44}
45
46#[cfg(target_family = "wasm")]
47fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
48    Error::Instance(error.into())
49}
50
51pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
52pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
53
54pub type StreamingResponse = Response<BoxedStream>;
55
56#[derive(Debug, Clone, Copy)]
57pub struct NoBody;
58
59impl From<NoBody> for Bytes {
60    fn from(_: NoBody) -> Self {
61        Bytes::new()
62    }
63}
64
65impl From<NoBody> for Body {
66    fn from(_: NoBody) -> Self {
67        reqwest::Body::default()
68    }
69}
70
71pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
72    let text = response.into_body().await?;
73    Ok(String::from(String::from_utf8_lossy(&text)))
74}
75
76pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
77    Ok((
78        http::header::AUTHORIZATION,
79        HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
80    ))
81}
82
83pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
84    let (k, v) = make_auth_header(key)?;
85
86    headers.insert(k, v);
87
88    Ok(())
89}
90
91pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
92    bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
93
94    Ok(req)
95}
96
97/// A helper trait to make generic requests (both regular and SSE) possible.
98pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
99    /// Send a HTTP request, get a response back (as bytes). Response must be able to be turned back into Bytes.
100    fn send<T, U>(
101        &self,
102        req: Request<T>,
103    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
104    where
105        T: Into<Bytes>,
106        T: WasmCompatSend,
107        U: From<Bytes>,
108        U: WasmCompatSend + 'static;
109
110    /// Send a HTTP request with a multipart body, get a response back (as bytes). Response must be able to be turned back into Bytes (although usually for the response, you will probably want to specify Bytes anyway).
111    fn send_multipart<U>(
112        &self,
113        req: Request<Form>,
114    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
115    where
116        U: From<Bytes>,
117        U: WasmCompatSend + 'static;
118
119    /// Send a HTTP request, get a streamed response back (as a stream of [`bytes::Bytes`].)
120    fn send_streaming<T>(
121        &self,
122        req: Request<T>,
123    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
124    where
125        T: Into<Bytes>;
126}
127
128impl HttpClientExt for reqwest::Client {
129    fn send<T, U>(
130        &self,
131        req: Request<T>,
132    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
133    where
134        T: Into<Bytes>,
135        U: From<Bytes> + WasmCompatSend,
136    {
137        let (parts, body) = req.into_parts();
138        let req = self
139            .request(parts.method, parts.uri.to_string())
140            .headers(parts.headers)
141            .body(body.into());
142
143        async move {
144            let response = req.send().await.map_err(instance_error)?;
145            if !response.status().is_success() {
146                return Err(Error::InvalidStatusCodeWithMessage(
147                    response.status(),
148                    response.text().await.unwrap(),
149                ));
150            }
151
152            let mut res = Response::builder().status(response.status());
153
154            if let Some(hs) = res.headers_mut() {
155                *hs = response.headers().clone();
156            }
157
158            let body: LazyBody<U> = Box::pin(async {
159                let bytes = response
160                    .bytes()
161                    .await
162                    .map_err(|e| Error::Instance(e.into()))?;
163
164                let body = U::from(bytes);
165                Ok(body)
166            });
167
168            res.body(body).map_err(Error::Protocol)
169        }
170    }
171
172    fn send_multipart<U>(
173        &self,
174        req: Request<Form>,
175    ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
176    where
177        U: From<Bytes>,
178        U: WasmCompatSend + 'static,
179    {
180        let (parts, body) = req.into_parts();
181        let req = self
182            .request(parts.method, parts.uri.to_string())
183            .headers(parts.headers)
184            .multipart(body);
185
186        async move {
187            let response = req.send().await.map_err(instance_error)?;
188            if !response.status().is_success() {
189                return Err(Error::InvalidStatusCodeWithMessage(
190                    response.status(),
191                    response.text().await.unwrap(),
192                ));
193            }
194
195            let mut res = Response::builder().status(response.status());
196
197            if let Some(hs) = res.headers_mut() {
198                *hs = response.headers().clone();
199            }
200
201            let body: LazyBody<U> = Box::pin(async {
202                let bytes = response
203                    .bytes()
204                    .await
205                    .map_err(|e| Error::Instance(e.into()))?;
206
207                let body = U::from(bytes);
208                Ok(body)
209            });
210
211            res.body(body).map_err(Error::Protocol)
212        }
213    }
214
215    fn send_streaming<T>(
216        &self,
217        req: Request<T>,
218    ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
219    where
220        T: Into<Bytes>,
221    {
222        let (parts, body) = req.into_parts();
223
224        let req = self
225            .request(parts.method, parts.uri.to_string())
226            .headers(parts.headers)
227            .body(body.into())
228            .build()
229            .map_err(|x| Error::Instance(x.into()))
230            .unwrap();
231
232        let client = self.clone();
233
234        async move {
235            let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
236            if !response.status().is_success() {
237                return Err(Error::InvalidStatusCodeWithMessage(
238                    response.status(),
239                    response.text().await.unwrap(),
240                ));
241            }
242
243            #[cfg(not(target_family = "wasm"))]
244            let mut res = Response::builder()
245                .status(response.status())
246                .version(response.version());
247
248            #[cfg(target_family = "wasm")]
249            let mut res = Response::builder().status(response.status());
250
251            if let Some(hs) = res.headers_mut() {
252                *hs = response.headers().clone();
253            }
254
255            use futures::StreamExt;
256
257            let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
258                Box::pin(
259                    response
260                        .bytes_stream()
261                        .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
262                );
263
264            res.body(mapped_stream).map_err(Error::Protocol)
265        }
266    }
267}