client_util/
response.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5
6use crate::util::ok;
7use bytes::Buf;
8use bytes::Bytes;
9use http::header::CONTENT_TYPE;
10pub use http::response::Builder;
11pub use http::response::Response;
12use http::HeaderValue;
13use http_body_util::BodyDataStream;
14use http_body_util::BodyExt;
15#[cfg(feature = "serde")]
16use serde::de::DeserializeOwned;
17use std::str::FromStr;
18
19/// Extension trait for [`http::Response`].
20pub trait ResponseExt<B>: Sized {
21    #[cfg(feature = "json")]
22    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
23    fn json<T: DeserializeOwned>(
24        self,
25    ) -> impl Future<Output = Result<Response<T>, ResponseError>> + Send;
26    fn text(self) -> impl Future<Output = Result<Response<String>, ResponseError>> + Send;
27    fn bytes(self) -> impl Future<Output = Result<Response<Bytes>, ResponseError>> + Send;
28    fn data_stream(self) -> Response<BodyDataStream<B>>;
29    fn buffer(self) -> impl Future<Output = Result<Response<impl Buf>, ResponseError>> + Send;
30    #[cfg(feature = "hyper")]
31    #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
32    fn hyper_upgrade(
33        self,
34    ) -> impl Future<Output = Result<hyper::upgrade::Upgraded, ResponseError>> + Send;
35}
36
/// Signature of a custom text decoder registered in [`Decoders`].
///
/// The function receives the raw body bytes and returns either the decoded text or a
/// boxed error explaining why the bytes were rejected.
pub type TextDecodeFn =
    fn(Vec<u8>) -> Result<String, Box<dyn std::error::Error + Send + 'static>>;
38
39#[derive(Debug, thiserror::Error)]
40pub enum ResponseError {
41    #[error("collect body error: {0}")]
42    CollectBody(#[source] Box<dyn std::error::Error + Send>),
43    #[cfg(feature = "json")]
44    #[error("json deserialize error: {0}")]
45    JsonDeserialize(#[from] serde_json::Error),
46    #[error("text decode error for charset {charset}: {error}")]
47    TextDecode {
48        #[source]
49        error: Box<dyn std::error::Error + Send>,
50        charset: String,
51    },
52}
53/// A collection of text decoders.
54#[derive(Debug, Default, Clone)]
55pub struct Decoders {
56    inner: Arc<HashMap<Cow<'static, str>, TextDecodeFn>>,
57}
58
59impl Decoders {
60    pub fn new(map: HashMap<Cow<'static, str>, TextDecodeFn>) -> Self {
61        Decoders {
62            inner: Arc::new(map),
63        }
64    }
65}
66
67impl<B> ResponseExt<B> for Response<B>
68where
69    B: http_body::Body + Send,
70    B::Data: Send,
71    B::Error: std::error::Error + Send + 'static,
72{
73    /// Deserialize the response body as json.
74    #[cfg(feature = "json")]
75    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
76    async fn json<T: DeserializeOwned>(self) -> Result<Response<T>, ResponseError> {
77        use bytes::Buf;
78        let (parts, body) = self.into_parts();
79        let body = body
80            .collect()
81            .await
82            .map_err(|e| ResponseError::CollectBody(Box::new(e)))?
83            .aggregate();
84        let body = serde_json::from_reader::<_, T>(body.reader())
85            .map_err(ResponseError::JsonDeserialize)?;
86        Ok(Response::from_parts(parts, body))
87    }
88
89    /// Deserialize the response body as text.
90    ///
91    /// This function will try to decode the body with the charset specified in the `Content-Type` header.
92    ///
93    /// In most cases, the charset is `utf-8`. If the charset is not `utf-8`, you should enable the `charset` feature.
94    async fn text(self) -> Result<Response<String>, ResponseError> {
95        use mime::Mime;
96        let (parts, body) = self.into_parts();
97        let body = body
98            .collect()
99            .await
100            .map_err(|e| ResponseError::CollectBody(Box::new(e)))?
101            .to_bytes();
102        let mut string_body: Option<String> = None;
103        'decode: {
104            if let Some(mime_type) = parts
105                .headers
106                .get(CONTENT_TYPE)
107                .and_then(ok(HeaderValue::to_str))
108                .and_then(ok(Mime::from_str))
109            {
110                let charset = mime_type.get_param(mime::CHARSET);
111                let custom_charset = match charset {
112                    Some(mime::UTF_8) | None => break 'decode,
113                    Some(custom_charset) => custom_charset,
114                };
115                #[cfg(feature = "charset")]
116                {
117                    use encoding_rs::Encoding;
118                    if let Some(encoding) = Encoding::for_label(custom_charset.as_str().as_bytes())
119                    {
120                        string_body.replace(encoding.decode(&body).0.to_string());
121                        break 'decode;
122                    }
123                }
124                let Some(decoders) = parts.extensions.get::<Decoders>() else {
125                    break 'decode;
126                };
127                let Some(decoder_fn) = decoders.inner.get(custom_charset.as_str()) else {
128                    break 'decode;
129                };
130                string_body = Some((decoder_fn)(body.to_vec()).map_err(|error| {
131                    ResponseError::TextDecode {
132                        error,
133                        charset: custom_charset.to_string(),
134                    }
135                })?);
136            }
137        }
138
139        let string_body = match string_body {
140            Some(string_body) => string_body,
141            None => {
142                String::from_utf8(body.to_vec()).map_err(|error| ResponseError::TextDecode {
143                    error: Box::new(error),
144                    charset: mime::TEXT_PLAIN_UTF_8.to_string(),
145                })?
146            }
147        };
148
149        Ok(Response::from_parts(parts, string_body))
150    }
151
152    /// Wrap the response body as a data stream.
153    #[inline]
154    fn data_stream(self) -> Response<BodyDataStream<B>> {
155        let (parts, body) = self.into_parts();
156        let body = BodyDataStream::new(body);
157        Response::from_parts(parts, body)
158    }
159
160    /// Collect the response body as bytes.
161    async fn bytes(self) -> Result<Response<Bytes>, ResponseError> {
162        let (parts, body) = self.into_parts();
163        let body = body
164            .collect()
165            .await
166            .map_err(|error| ResponseError::CollectBody(Box::new(error)))?
167            .to_bytes();
168        Ok(Response::from_parts(parts, body))
169    }
170
171    /// Collect the response body as buffer.
172    ///
173    /// This function is useful when you want to deserialize the body in various ways.
174    async fn buffer(self) -> Result<Response<impl Buf>, ResponseError> {
175        let (parts, body) = self.into_parts();
176        let body = body
177            .collect()
178            .await
179            .map_err(|error| ResponseError::CollectBody(Box::new(error)))?
180            .aggregate();
181        Ok(Response::from_parts(parts, body))
182    }
183
184    #[cfg(feature = "hyper")]
185    #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
186    /// Upgrade the connection to a different protocol with hyper.
187    ///
188    /// This function yield a asynchronous io. You can use this to create a websocket connection by using some websocket lib.
189    async fn hyper_upgrade(self) -> Result<hyper::upgrade::Upgraded, ResponseError> {
190        hyper::upgrade::on(self)
191            .await
192            .map_err(|error| ResponseError::CollectBody(Box::new(error)))
193    }
194}