athene/
request.rs

1#[cfg(feature = "multipart")]
2use crate::multipart::FilePart;
3use crate::{
4    body::HttpBody,
5    error::Error,
6    responder::{Form, FromBytes, Json, Query},
7};
8#[cfg(feature = "cookie")]
9use cookie::{Cookie, CookieJar};
10use headers::{Header, HeaderMapExt};
11use http_body_util::BodyExt;
12use hyper::body::Bytes;
13#[cfg(feature = "multipart")]
14use multimap::MultiMap;
15use serde::de::DeserializeOwned;
16use std::{
17    collections::HashMap,
18    net::SocketAddr,
19    ops::{Deref, DerefMut},
20};
21
22#[derive(Debug)]
23pub struct Request {
24    #[doc(hidden)]
25    inner: hyper::http::Request<HttpBody>,
26    #[doc(hidden)]
27    params: HashMap<String, String>,
28    #[doc(hidden)]
29    addr: SocketAddr,
30}
31
32impl Request {
33    /// Get inner Request
34    #[inline]
35    pub fn request(&mut self) -> &mut hyper::Request<HttpBody> {
36        &mut self.inner
37    }
38
39    #[doc(hidden)]
40    #[inline]
41    pub(crate) fn new(inner: hyper::http::Request<HttpBody>, addr: SocketAddr) -> Self {
42        Request {
43            inner,
44            params: HashMap::new(),
45            addr,
46        }
47    }
48
49    /// Get `Cookie` from Cookie header
50    #[cfg(feature = "cookie")]
51    #[inline]
52    pub fn cookie(&mut self, name: &str) -> Option<Cookie<'static>> {
53        if let Some(mut cookie_iter) = self
54            .inner
55            .headers()
56            .get("Cookie")
57            .and_then(|cookies| cookies.to_str().ok())
58            .map(|cookies_str| cookies_str.split(';').map(|s| s.trim()))
59            .map(|cookie_iter| {
60                cookie_iter.filter_map(|cookie_s| Cookie::parse(cookie_s.to_string()).ok())
61            })
62        {
63            cookie_iter.find(|cookie| cookie.name() == name)
64        } else {
65            None
66        }
67    }
68
69    /// Get Cookies from the Cookie header
70    #[cfg(feature = "cookie")]
71    #[inline]
72    pub fn cookies(&mut self) -> CookieJar {
73        let mut jar = CookieJar::new();
74        if let Some(cookie_iter) = self
75            .inner
76            .headers()
77            .get("Cookie")
78            .and_then(|cookies| cookies.to_str().ok())
79            .map(|cookies_str| cookies_str.split(';').map(|s| s.trim()))
80            .map(|cookie_iter| {
81                cookie_iter.filter_map(|cookie_s| Cookie::parse(cookie_s.to_string()).ok())
82            })
83        {
84            cookie_iter.for_each(|c| jar.add_original(c))
85        }
86        jar
87    }
88
89    /// Get accept.
90    #[inline]
91    pub fn accept(&self) -> Vec<mime::Mime> {
92        let mut list: Vec<mime::Mime> = vec![];
93        if let Some(accept) = self
94            .inner
95            .headers()
96            .get("accept")
97            .and_then(|h| h.to_str().ok())
98        {
99            let parts: Vec<&str> = accept.split(',').collect();
100            for part in parts {
101                if let Ok(mt) = part.parse() {
102                    list.push(mt);
103                }
104            }
105        }
106        list
107    }
108
109    /// Get first accept.
110    #[inline]
111    pub fn first_accept(&self) -> Option<mime::Mime> {
112        let mut accept = self.accept();
113        if !accept.is_empty() {
114            Some(accept.remove(0))
115        } else {
116            None
117        }
118    }
119
120    #[inline]
121    pub fn header<T: Header>(&self) -> Option<T> {
122        self.inner.headers().typed_get()
123    }
124
125    // peer_addr
126    #[inline]
127    pub fn remote_addr(self) -> SocketAddr {
128        self.addr
129    }
130
131    #[inline]
132    pub fn params(&self) -> &HashMap<String, String> {
133        &self.params
134    }
135
136    #[inline]
137    pub fn params_mut(&mut self) -> &mut HashMap<String, String> {
138        &mut self.params
139    }
140
141    #[inline]
142    pub fn param<T: std::str::FromStr>(&mut self, key: &str) -> Result<T, Error> {
143        let value = self
144            .params_mut()
145            .remove(key)
146            .ok_or_else(|| Error::MissingParameter(key.to_string(), false))?;
147        Ok(value
148            .parse::<T>()
149            .map_err(|_| Error::InvalidParameter(key.to_string(), false))?)
150    }
151
152    #[inline]
153    pub fn query<'de, B>(&'de self) -> Result<B, Error>
154    where
155        B: serde::Deserialize<'de>,
156    {
157        let query = self.inner.uri().query().unwrap_or("");
158        serde_urlencoded::from_str::<B>(query).map_err(Error::SerdeUrlDe)
159    }
160
161    #[inline]
162    pub fn content_type(&self) -> Option<&str> {
163        let content_type = self.inner.headers().get("content-type")?;
164        let content_type = content_type.to_str().ok()?;
165        Some(content_type)
166    }
167
168    #[cfg(feature = "multipart")]
169    #[inline]
170    async fn file_part(&mut self) -> Result<MultiMap<String, FilePart>, Error> {
171        let c_type = self.content_type().expect("bad request");
172        let boundary = multer::parse_boundary(c_type)?;
173        let mut multipart =
174            multer::Multipart::new(std::mem::take(self.inner.body_mut()), &boundary);
175        let mut file_parts = MultiMap::new();
176        while let Some(mut field) = multipart.next_field().await? {
177            if let Some(name) = field.name().map(|s| s.to_owned()) {
178                if field.headers().get("content-type").is_some() {
179                    file_parts.insert(name, FilePart::new(&mut field).await?);
180                }
181            }
182        }
183        Ok(file_parts)
184    }
185
186    #[cfg(feature = "multipart")]
187    #[inline]
188    pub async fn file(&mut self, key: &str) -> Result<FilePart, Error> {
189        let file_part = self.file_part().await?;
190        let file_part = file_part.get(key).unwrap();
191        Ok(file_part.to_owned())
192    }
193
194    #[cfg(feature = "multipart")]
195    #[inline]
196    pub async fn files(&mut self, key: &str) -> Result<Vec<FilePart>, Error> {
197        let file_part = self.file_part().await?;
198        let file_part = file_part.get_vec(key).unwrap();
199        Ok(file_part.to_owned())
200    }
201
202    #[cfg(feature = "multipart")]
203    #[inline]
204    pub async fn upload(&mut self, key: &str, save_path: &str) -> Result<u64, Error> {
205        let file = self.file(key).await?;
206        std::fs::create_dir_all(save_path)?;
207        let dest = format!("{}/{}", save_path, file.name().unwrap());
208        Ok(std::fs::copy(file.path(), std::path::Path::new(&dest))?)
209    }
210
211    #[cfg(feature = "multipart")]
212    #[inline]
213    pub async fn uploads(&mut self, key: &str, save_path: &str) -> Result<String, Error> {
214        let files = self.files(key).await?;
215        std::fs::create_dir_all(save_path)?;
216        let mut msgs = Vec::with_capacity(files.len());
217        for file in files {
218            let dest = format!("{}/{}", save_path, file.name().unwrap());
219            if let Err(e) = std::fs::copy(file.path(), std::path::Path::new(&dest)) {
220                return Ok(format!("file not found in request: {e}"));
221            } else {
222                msgs.push(dest);
223            }
224        }
225        Ok(format!("Files uploaded:\n\n{}", msgs.join("\n")))
226    }
227
228    #[inline]
229    pub async fn parse<T>(&mut self) -> Result<T, Error>
230    where
231        T: DeserializeOwned,
232    {
233        let body = self.inner.body_mut().collect().await?.to_bytes();
234        let essence = self.content_type();
235        match essence {
236            Some("application/json") => serde_json::from_slice(&body).map_err(Error::SerdeJson),
237            Some("application/x-www-form-urlencoded") => {
238                serde_urlencoded::from_bytes(&body).map_err(Error::SerdeUrlDe)
239            }
240            #[cfg(feature = "cbor")]
241            Some("application/cbor") => {
242                ciborium::de::from_reader(&body[..]).map_err(|e| Error::Other(e.to_string()))
243            }
244            #[cfg(feature = "msgpack")]
245            Some("application/msgpack") => {
246                rmp_serde::from_slice(&body).map_err(Error::MsgpackDeserialization)
247            }
248            _ => Err(Error::Other(String::from("Invalid Context-Type"))),
249        }
250    }
251
252    #[inline]
253    pub async fn parse_body<T: FromBytes>(&mut self) -> Result<T::Output, Error> {
254        let bytes = self.inner.body_mut().collect().await?.to_bytes();
255        let value = T::from_bytes(bytes)?;
256        Ok(value)
257    }
258
259    #[inline]
260    pub async fn parse_query<T: FromBytes>(&mut self) -> Result<Query<T::Output>, Error> {
261        let query = String::from(self.uri().query().unwrap_or(""));
262        let bytes = Bytes::from(query);
263        let value = T::from_bytes(bytes)?;
264        Ok(Query(value))
265    }
266
267    #[inline]
268    pub async fn parse_json<T: FromBytes>(&mut self) -> Result<Json<T::Output>, Error> {
269        let bytes = self.inner.body_mut().collect().await?.to_bytes();
270        let value = T::from_bytes(bytes)?;
271        Ok(Json(value))
272    }
273
274    #[inline]
275    pub async fn parse_form<T: FromBytes>(&mut self) -> Result<Form<T::Output>, Error> {
276        let bytes = self.inner.body_mut().collect().await?.to_bytes();
277        let value = T::from_bytes(bytes)?;
278        Ok(Form(value))
279    }
280}
281
282impl Deref for Request {
283    type Target = hyper::Request<HttpBody>;
284    #[inline]
285    fn deref(&self) -> &Self::Target {
286        &self.inner
287    }
288}
289
290impl DerefMut for Request {
291    #[inline]
292    fn deref_mut(&mut self) -> &mut Self::Target {
293        &mut self.inner
294    }
295}