1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5
6use crate::util::ok;
7use bytes::Buf;
8use bytes::Bytes;
9use http::header::CONTENT_TYPE;
10pub use http::response::Builder;
11pub use http::response::Response;
12use http::HeaderValue;
13use http_body_util::BodyDataStream;
14use http_body_util::BodyExt;
15#[cfg(feature = "serde")]
16use serde::de::DeserializeOwned;
17use std::str::FromStr;
18
19pub trait ResponseExt<B>: Sized {
21 #[cfg(feature = "json")]
22 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
23 fn json<T: DeserializeOwned>(
24 self,
25 ) -> impl Future<Output = Result<Response<T>, ResponseError>> + Send;
26 fn text(self) -> impl Future<Output = Result<Response<String>, ResponseError>> + Send;
27 fn bytes(self) -> impl Future<Output = Result<Response<Bytes>, ResponseError>> + Send;
28 fn data_stream(self) -> Response<BodyDataStream<B>>;
29 fn buffer(self) -> impl Future<Output = Result<Response<impl Buf>, ResponseError>> + Send;
30 #[cfg(feature = "hyper")]
31 #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
32 fn hyper_upgrade(
33 self,
34 ) -> impl Future<Output = Result<hyper::upgrade::Upgraded, ResponseError>> + Send;
35}
36
/// Signature of a custom text decoder stored in [`Decoders`].
///
/// The decoder receives the complete, already-collected body bytes and returns
/// the decoded text, or a boxed `Send` error describing the failure.
pub type TextDecodeFn =
    fn(Vec<u8>) -> Result<String, Box<dyn std::error::Error + Send>>;
38
/// Errors returned by the [`ResponseExt`] methods.
///
/// NOTE(review): the boxed inner errors are `Send` but not `Sync`, which makes
/// `ResponseError` itself `!Sync`; it therefore cannot be converted with `?`
/// into a `Box<dyn std::error::Error + Send + Sync>`.
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
    /// Reading the response body failed; carries the body's own error.
    ///
    /// `hyper_upgrade` also reports upgrade failures through this variant.
    // NOTE(review): the message embeds `{0}` *and* `#[source]` exposes the same
    // error, so reporters that walk `source()` will print its text twice.
    #[error("collect body error: {0}")]
    CollectBody(#[source] Box<dyn std::error::Error + Send>),
    /// The collected body could not be deserialized as JSON.
    #[cfg(feature = "json")]
    #[error("json deserialize error: {0}")]
    JsonDeserialize(#[from] serde_json::Error),
    /// Decoding the collected body as text failed.
    #[error("text decode error for charset {charset}: {error}")]
    TextDecode {
        /// The error produced by the decoder.
        #[source]
        error: Box<dyn std::error::Error + Send>,
        /// Identifies the charset that failed to decode.
        charset: String,
    },
}
53#[derive(Debug, Default, Clone)]
55pub struct Decoders {
56 inner: Arc<HashMap<Cow<'static, str>, TextDecodeFn>>,
57}
58
59impl Decoders {
60 pub fn new(map: HashMap<Cow<'static, str>, TextDecodeFn>) -> Self {
61 Decoders {
62 inner: Arc::new(map),
63 }
64 }
65}
66
67impl<B> ResponseExt<B> for Response<B>
68where
69 B: http_body::Body + Send,
70 B::Data: Send,
71 B::Error: std::error::Error + Send + 'static,
72{
73 #[cfg(feature = "json")]
75 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
76 async fn json<T: DeserializeOwned>(self) -> Result<Response<T>, ResponseError> {
77 use bytes::Buf;
78 let (parts, body) = self.into_parts();
79 let body = body
80 .collect()
81 .await
82 .map_err(|e| ResponseError::CollectBody(Box::new(e)))?
83 .aggregate();
84 let body = serde_json::from_reader::<_, T>(body.reader())
85 .map_err(ResponseError::JsonDeserialize)?;
86 Ok(Response::from_parts(parts, body))
87 }
88
89 async fn text(self) -> Result<Response<String>, ResponseError> {
95 use mime::Mime;
96 let (parts, body) = self.into_parts();
97 let body = body
98 .collect()
99 .await
100 .map_err(|e| ResponseError::CollectBody(Box::new(e)))?
101 .to_bytes();
102 let mut string_body: Option<String> = None;
103 'decode: {
104 if let Some(mime_type) = parts
105 .headers
106 .get(CONTENT_TYPE)
107 .and_then(ok(HeaderValue::to_str))
108 .and_then(ok(Mime::from_str))
109 {
110 let charset = mime_type.get_param(mime::CHARSET);
111 let custom_charset = match charset {
112 Some(mime::UTF_8) | None => break 'decode,
113 Some(custom_charset) => custom_charset,
114 };
115 #[cfg(feature = "charset")]
116 {
117 use encoding_rs::Encoding;
118 if let Some(encoding) = Encoding::for_label(custom_charset.as_str().as_bytes())
119 {
120 string_body.replace(encoding.decode(&body).0.to_string());
121 break 'decode;
122 }
123 }
124 let Some(decoders) = parts.extensions.get::<Decoders>() else {
125 break 'decode;
126 };
127 let Some(decoder_fn) = decoders.inner.get(custom_charset.as_str()) else {
128 break 'decode;
129 };
130 string_body = Some((decoder_fn)(body.to_vec()).map_err(|error| {
131 ResponseError::TextDecode {
132 error,
133 charset: custom_charset.to_string(),
134 }
135 })?);
136 }
137 }
138
139 let string_body = match string_body {
140 Some(string_body) => string_body,
141 None => {
142 String::from_utf8(body.to_vec()).map_err(|error| ResponseError::TextDecode {
143 error: Box::new(error),
144 charset: mime::TEXT_PLAIN_UTF_8.to_string(),
145 })?
146 }
147 };
148
149 Ok(Response::from_parts(parts, string_body))
150 }
151
152 #[inline]
154 fn data_stream(self) -> Response<BodyDataStream<B>> {
155 let (parts, body) = self.into_parts();
156 let body = BodyDataStream::new(body);
157 Response::from_parts(parts, body)
158 }
159
160 async fn bytes(self) -> Result<Response<Bytes>, ResponseError> {
162 let (parts, body) = self.into_parts();
163 let body = body
164 .collect()
165 .await
166 .map_err(|error| ResponseError::CollectBody(Box::new(error)))?
167 .to_bytes();
168 Ok(Response::from_parts(parts, body))
169 }
170
171 async fn buffer(self) -> Result<Response<impl Buf>, ResponseError> {
175 let (parts, body) = self.into_parts();
176 let body = body
177 .collect()
178 .await
179 .map_err(|error| ResponseError::CollectBody(Box::new(error)))?
180 .aggregate();
181 Ok(Response::from_parts(parts, body))
182 }
183
184 #[cfg(feature = "hyper")]
185 #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
186 async fn hyper_upgrade(self) -> Result<hyper::upgrade::Upgraded, ResponseError> {
190 hyper::upgrade::on(self)
191 .await
192 .map_err(|error| ResponseError::CollectBody(Box::new(error)))
193 }
194}