1use crate::http_client::sse::BoxedStream;
2use bytes::Bytes;
3pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder};
4use http::{HeaderName, StatusCode};
5use reqwest::Body;
6pub mod multipart;
7pub mod retry;
8pub mod sse;
9use crate::wasm_compat::*;
10pub use multipart::MultipartForm;
11pub use reqwest::Client as ReqwestClient;
12use std::pin::Pin;
13
/// Errors produced by the HTTP client layer in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Error reported by the `http` crate, e.g. when assembling a response with
    /// `http::response::Builder::body`.
    #[error("Http error: {0}")]
    Protocol(#[from] http::Error),
    /// Unexpected status code with no further detail.
    #[error("Invalid status code: {0}")]
    InvalidStatusCode(StatusCode),
    /// Non-success status code. When built by this module the `String` holds the
    /// response body text, or a description of why the body could not be read.
    #[error("Invalid status code {0} with message: {1}")]
    InvalidStatusCodeWithMessage(StatusCode, String),
    /// A string could not be turned into a legal HTTP header value.
    #[error("Header value outside of legal range: {0}")]
    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
    /// A request builder was already in an error state, so its headers could not
    /// be accessed.
    #[error("Request in error state, cannot access headers")]
    NoHeaders,
    /// The response stream ended.
    /// NOTE(review): not constructed in this file; presumably raised by the SSE layer.
    #[error("Stream ended")]
    StreamEnded,
    /// The response carried an unexpected `Content-Type` value.
    #[error("Invalid content type was returned: {0:?}")]
    InvalidContentType(HeaderValue),
    /// Error from the underlying client implementation (e.g. `reqwest`), boxed.
    /// Native builds require the boxed error to be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("Http client error: {0}")]
    Instance(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error from the underlying client implementation (e.g. `reqwest`), boxed.
    /// Wasm builds drop the `Send + Sync` requirement.
    #[cfg(target_family = "wasm")]
    #[error("Http client error: {0}")]
    Instance(#[from] Box<dyn std::error::Error + 'static>),
}
38
39pub type Result<T> = std::result::Result<T, Error>;
40
41#[cfg(not(target_family = "wasm"))]
42pub(crate) fn instance_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
43 Error::Instance(error.into())
44}
45
46#[cfg(target_family = "wasm")]
47fn instance_error<E: std::error::Error + 'static>(error: E) -> Error {
48 Error::Instance(error.into())
49}
50
51async fn non_success_status_error(response: reqwest::Response) -> Error {
52 let status = response.status();
53 let message = response
54 .text()
55 .await
56 .unwrap_or_else(|error| format!("failed to read error response body: {error}"));
57 Error::InvalidStatusCodeWithMessage(status, message)
58}
59
60pub type LazyBytes = WasmBoxedFuture<'static, Result<Bytes>>;
61pub type LazyBody<T> = WasmBoxedFuture<'static, Result<T>>;
62
63pub type StreamingResponse = Response<BoxedStream>;
64
65#[derive(Debug, Clone, Copy)]
66pub struct NoBody;
67
68impl From<NoBody> for Bytes {
69 fn from(_: NoBody) -> Self {
70 Bytes::new()
71 }
72}
73
74impl From<NoBody> for Body {
75 fn from(_: NoBody) -> Self {
76 reqwest::Body::default()
77 }
78}
79
80pub async fn text(response: Response<LazyBody<Vec<u8>>>) -> Result<String> {
81 let text = response.into_body().await?;
82 Ok(String::from(String::from_utf8_lossy(&text)))
83}
84
85pub fn make_auth_header(key: impl AsRef<str>) -> Result<(HeaderName, HeaderValue)> {
86 Ok((
87 http::header::AUTHORIZATION,
88 HeaderValue::from_str(&format!("Bearer {}", key.as_ref()))?,
89 ))
90}
91
92pub fn bearer_auth_header(headers: &mut HeaderMap, key: impl AsRef<str>) -> Result<()> {
93 let (k, v) = make_auth_header(key)?;
94
95 headers.insert(k, v);
96
97 Ok(())
98}
99
100pub fn with_bearer_auth(mut req: Builder, auth: &str) -> Result<Builder> {
101 bearer_auth_header(req.headers_mut().ok_or(Error::NoHeaders)?, auth)?;
102
103 Ok(req)
104}
105
106pub trait HttpClientExt: WasmCompatSend + WasmCompatSync {
108 fn send<T, U>(
110 &self,
111 req: Request<T>,
112 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
113 where
114 T: Into<Bytes>,
115 T: WasmCompatSend,
116 U: From<Bytes>,
117 U: WasmCompatSend + 'static;
118
119 fn send_multipart<U>(
121 &self,
122 req: Request<MultipartForm>,
123 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
124 where
125 U: From<Bytes>,
126 U: WasmCompatSend + 'static;
127
128 fn send_streaming<T>(
130 &self,
131 req: Request<T>,
132 ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
133 where
134 T: Into<Bytes> + WasmCompatSend;
135}
136
137impl HttpClientExt for reqwest::Client {
138 fn send<T, U>(
139 &self,
140 req: Request<T>,
141 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
142 where
143 T: Into<Bytes>,
144 U: From<Bytes> + WasmCompatSend,
145 {
146 let (parts, body) = req.into_parts();
147 let req = self
148 .request(parts.method, parts.uri.to_string())
149 .headers(parts.headers)
150 .body(body.into());
151
152 async move {
153 let response = req.send().await.map_err(instance_error)?;
154 if !response.status().is_success() {
155 return Err(non_success_status_error(response).await);
156 }
157
158 let mut res = Response::builder().status(response.status());
159
160 if let Some(hs) = res.headers_mut() {
161 *hs = response.headers().clone();
162 }
163
164 let body: LazyBody<U> = Box::pin(async {
165 let bytes = response
166 .bytes()
167 .await
168 .map_err(|e| Error::Instance(e.into()))?;
169
170 let body = U::from(bytes);
171 Ok(body)
172 });
173
174 res.body(body).map_err(Error::Protocol)
175 }
176 }
177
178 fn send_multipart<U>(
179 &self,
180 req: Request<MultipartForm>,
181 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
182 where
183 U: From<Bytes>,
184 U: WasmCompatSend + 'static,
185 {
186 let (parts, body) = req.into_parts();
187 let body = reqwest::multipart::Form::from(body);
188
189 let req = self
190 .request(parts.method, parts.uri.to_string())
191 .headers(parts.headers)
192 .multipart(body);
193
194 async move {
195 let response = req.send().await.map_err(instance_error)?;
196 if !response.status().is_success() {
197 return Err(non_success_status_error(response).await);
198 }
199
200 let mut res = Response::builder().status(response.status());
201
202 if let Some(hs) = res.headers_mut() {
203 *hs = response.headers().clone();
204 }
205
206 let body: LazyBody<U> = Box::pin(async {
207 let bytes = response
208 .bytes()
209 .await
210 .map_err(|e| Error::Instance(e.into()))?;
211
212 let body = U::from(bytes);
213 Ok(body)
214 });
215
216 res.body(body).map_err(Error::Protocol)
217 }
218 }
219
220 fn send_streaming<T>(
221 &self,
222 req: Request<T>,
223 ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
224 where
225 T: Into<Bytes> + WasmCompatSend,
226 {
227 let (parts, body) = req.into_parts();
228
229 let client = self.clone();
230
231 async move {
232 let req = self
233 .request(parts.method, parts.uri.to_string())
234 .headers(parts.headers)
235 .body(body.into())
236 .build()
237 .map_err(|error| Error::Instance(error.into()))?;
238 let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
239 if !response.status().is_success() {
240 return Err(non_success_status_error(response).await);
241 }
242
243 #[cfg(not(target_family = "wasm"))]
244 let mut res = Response::builder()
245 .status(response.status())
246 .version(response.version());
247
248 #[cfg(target_family = "wasm")]
249 let mut res = Response::builder().status(response.status());
250
251 if let Some(hs) = res.headers_mut() {
252 *hs = response.headers().clone();
253 }
254
255 use futures::StreamExt;
256
257 let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
258 Box::pin(
259 response
260 .bytes_stream()
261 .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
262 );
263
264 res.body(mapped_stream).map_err(Error::Protocol)
265 }
266 }
267}
268
269#[cfg(feature = "reqwest-middleware")]
270#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-middleware")))]
271impl HttpClientExt for reqwest_middleware::ClientWithMiddleware {
272 fn send<T, U>(
273 &self,
274 req: Request<T>,
275 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
276 where
277 T: Into<Bytes>,
278 U: From<Bytes> + WasmCompatSend,
279 {
280 let (parts, body) = req.into_parts();
281 let req = self
282 .request(parts.method, parts.uri.to_string())
283 .headers(parts.headers)
284 .body(body.into());
285
286 async move {
287 let response = req.send().await.map_err(instance_error)?;
288 if !response.status().is_success() {
289 return Err(non_success_status_error(response).await);
290 }
291
292 let mut res = Response::builder().status(response.status());
293
294 if let Some(hs) = res.headers_mut() {
295 *hs = response.headers().clone();
296 }
297
298 let body: LazyBody<U> = Box::pin(async {
299 let bytes = response
300 .bytes()
301 .await
302 .map_err(|e| Error::Instance(e.into()))?;
303
304 let body = U::from(bytes);
305 Ok(body)
306 });
307
308 res.body(body).map_err(Error::Protocol)
309 }
310 }
311
312 fn send_multipart<U>(
313 &self,
314 req: Request<MultipartForm>,
315 ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
316 where
317 U: From<Bytes>,
318 U: WasmCompatSend + 'static,
319 {
320 let (parts, body) = req.into_parts();
321 let body = reqwest::multipart::Form::from(body);
322
323 let req = self
324 .request(parts.method, parts.uri.to_string())
325 .headers(parts.headers)
326 .multipart(body);
327
328 async move {
329 let response = req.send().await.map_err(instance_error)?;
330 if !response.status().is_success() {
331 return Err(non_success_status_error(response).await);
332 }
333
334 let mut res = Response::builder().status(response.status());
335
336 if let Some(hs) = res.headers_mut() {
337 *hs = response.headers().clone();
338 }
339
340 let body: LazyBody<U> = Box::pin(async {
341 let bytes = response
342 .bytes()
343 .await
344 .map_err(|e| Error::Instance(e.into()))?;
345
346 let body = U::from(bytes);
347 Ok(body)
348 });
349
350 res.body(body).map_err(Error::Protocol)
351 }
352 }
353
354 fn send_streaming<T>(
355 &self,
356 req: Request<T>,
357 ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
358 where
359 T: Into<Bytes> + WasmCompatSend,
360 {
361 let (parts, body) = req.into_parts();
362
363 let client = self.clone();
364
365 async move {
366 let req = self
367 .request(parts.method, parts.uri.to_string())
368 .headers(parts.headers)
369 .body(body.into())
370 .build()
371 .map_err(|error| Error::Instance(error.into()))?;
372 let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?;
373 if !response.status().is_success() {
374 return Err(non_success_status_error(response).await);
375 }
376
377 #[cfg(not(target_family = "wasm"))]
378 let mut res = Response::builder()
379 .status(response.status())
380 .version(response.version());
381
382 #[cfg(target_family = "wasm")]
383 let mut res = Response::builder().status(response.status());
384
385 if let Some(hs) = res.headers_mut() {
386 *hs = response.headers().clone();
387 }
388
389 use futures::StreamExt;
390
391 let mapped_stream: Pin<Box<dyn WasmCompatSendStream<InnerItem = Result<Bytes>>>> =
392 Box::pin(
393 response
394 .bytes_stream()
395 .map(|chunk| chunk.map_err(|e| Error::Instance(Box::new(e)))),
396 );
397
398 res.body(mapped_stream).map_err(Error::Protocol)
399 }
400 }
401}
402
#[cfg(test)]
pub(crate) mod mock {
    //! Test doubles for [`HttpClientExt`].
    use super::*;
    use bytes::Bytes;

    /// Ready future failing with `501 Not Implemented`; used by every mock
    /// method that is not supported.
    fn not_implemented<R>() -> std::future::Ready<Result<R>> {
        std::future::ready(Err(Error::InvalidStatusCode(
            http::StatusCode::NOT_IMPLEMENTED,
        )))
    }

    /// Client that only supports streaming.
    ///
    /// `send_streaming` replays `sse_bytes` as a single `text/event-stream`
    /// chunk. `send` and `send_multipart` fail with `501 Not Implemented`.
    #[derive(Clone)]
    pub struct MockStreamingClient {
        /// Bytes emitted as the one and only chunk of the streamed body.
        pub sse_bytes: Bytes,
    }

    impl HttpClientExt for MockStreamingClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes>,
            T: WasmCompatSend,
            U: From<Bytes>,
            U: WasmCompatSend + 'static,
        {
            not_implemented::<Response<LazyBody<U>>>()
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes>,
            U: WasmCompatSend + 'static,
        {
            not_implemented::<Response<LazyBody<U>>>()
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = Result<StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            let chunk = self.sse_bytes.clone();
            async move {
                let single_chunk: sse::BoxedStream = Box::pin(futures::stream::iter(
                    std::iter::once(Ok::<Bytes, Error>(chunk)),
                ));

                Response::builder()
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .status(http::StatusCode::OK)
                    .body(single_chunk)
                    .map_err(Error::Protocol)
            }
        }
    }
}