bolt_web/
request.rs

1use bytes::Bytes;
2use futures_util::TryStreamExt;
3use http_body_util::{BodyExt, BodyStream};
4use hyper::header::HeaderName;
5use hyper::{Request, Uri, Version, body::Incoming, header::HeaderValue};
6use mime::Mime;
7use multer::Multipart;
8use serde::de::DeserializeOwned;
9use std::collections::HashMap;
10use tokio::io::AsyncWriteExt;
11use url::form_urlencoded;
12use uuid::Uuid;
13
14use crate::types::{BoltError, FormData, FormFile};
15
16pub struct RequestBody {
17    inner: Option<Request<Incoming>>,
18    params: HashMap<String, String>,
19    form_data_result: Option<Result<FormData, Box<dyn std::error::Error + Send + Sync>>>,
20    temp_paths: Vec<String>,
21    pub log: bool,
22}
23
24#[allow(dead_code)]
25impl RequestBody {
26    pub fn new(req: Request<Incoming>) -> Self {
27        Self {
28            inner: Some(req),
29            params: HashMap::new(),
30            form_data_result: None,
31            temp_paths: Vec::new(),
32            log: false,
33        }
34    }
35
36    pub fn params(&self) -> &HashMap<String, String> {
37        &self.params
38    }
39
40    pub fn param(&self, key: &str) -> String {
41        self.params.get(key).cloned().unwrap_or_default()
42    }
43
44    pub(crate) fn set_params(&mut self, params: HashMap<String, String>) {
45        self.params = params;
46    }
47
48    pub fn method(&self) -> &hyper::Method {
49        self.inner
50            .as_ref()
51            .expect("Cannot access method, request body was consumed.")
52            .method()
53    }
54
55    pub fn path(&self) -> &str {
56        self.inner
57            .as_ref()
58            .expect("Cannot access path, request body was consumed.")
59            .uri()
60            .path()
61    }
62
63    pub fn headers(&self) -> &hyper::HeaderMap {
64        self.inner
65            .as_ref()
66            .expect("Cannot access headers, request body was consumed.")
67            .headers()
68    }
69
70    pub fn set_headers(&mut self, key: &str, value: &str) {
71        let key = HeaderName::from_bytes(key.as_bytes()).expect("Invalid header name");
72        let value = HeaderValue::from_str(value).expect("Invalid header value");
73
74        self.inner
75            .as_mut()
76            .expect("Cannot set headers, request body was consumed.")
77            .headers_mut()
78            .insert(key, value);
79    }
80
81    pub fn get_headers(&mut self, key: &str) -> Option<&HeaderValue> {
82        self.inner
83            .as_ref()
84            .expect("Cannot access headers, request body was consumed.")
85            .headers()
86            .get(key)
87    }
88
89    pub fn uri(&self) -> &Uri {
90        self.inner
91            .as_ref()
92            .expect("Cannot access uri, request body was consumed.")
93            .uri()
94    }
95
96    pub fn version(&self) -> Version {
97        self.inner
98            .as_ref()
99            .expect("Cannot access version, request body was consumed.")
100            .version()
101    }
102
103    pub fn query(&self) -> HashMap<String, String> {
104        self.inner
105            .as_ref()
106            .expect("Cannot access uri, request body was consumed.")
107            .uri()
108            .query()
109            .map(|q| {
110                form_urlencoded::parse(q.as_bytes())
111                    .into_owned()
112                    .collect::<HashMap<String, String>>()
113            })
114            .unwrap_or_default()
115    }
116
117    pub fn query_param(&self, key: &str) -> Option<String> {
118        let query_params = self.query();
119        query_params.get(key).cloned()
120    }
121
122    pub async fn bytes(&mut self) -> Result<Bytes, hyper::Error> {
123        let req = self
124            .inner
125            .take()
126            .expect("Request body has already been consumed.");
127
128        let (_, body) = req.into_parts();
129        let collected = body.collect().await?;
130        Ok(collected.to_bytes())
131    }
132
133    pub async fn text(&mut self) -> Result<String, Box<dyn std::error::Error>> {
134        let bytes = self.bytes().await?;
135        let text = String::from_utf8(bytes.to_vec())?;
136        Ok(text)
137    }
138
139    pub async fn json<T: DeserializeOwned>(&mut self) -> Result<T, Box<dyn std::error::Error>> {
140        let bytes = self.bytes().await?;
141        let value = serde_json::from_slice(&bytes)?;
142        Ok(value)
143    }
144
145    pub fn get_cookie(&self, name: &str) -> Option<String> {
146        self.inner
147            .as_ref()
148            .expect("Request body has already been consumed")
149            .headers()
150            .get(hyper::header::COOKIE)?
151            .to_str()
152            .ok()
153            .and_then(|cookie_header| {
154                cookie_header.split(';').map(|s| s.trim()).find_map(|pair| {
155                    let mut parts = pair.splitn(2, '=');
156                    let key = parts.next()?;
157                    let value = parts.next()?;
158                    if key == name {
159                        Some(value.to_string())
160                    } else {
161                        None
162                    }
163                })
164            })
165    }
166
167    pub async fn form_data(&mut self) -> Result<FormData, BoltError> {
168        if let Some(Ok(fd)) = &self.form_data_result {
169            return Ok(fd.clone());
170        }
171        if let Some(Err(e)) = &self.form_data_result {
172            return Err(Box::new(std::io::Error::new(
173                std::io::ErrorKind::Other,
174                e.to_string(),
175            )));
176        }
177
178        let header_opt = {
179            let req_ref = self
180                .inner
181                .as_ref()
182                .expect("Request body was consumed before form_data call.");
183            req_ref.headers().get(hyper::header::CONTENT_TYPE).cloned()
184        };
185
186        let content_type = match header_opt {
187            Some(header_value) => header_value.to_str()?.parse::<Mime>()?,
188            None => {
189                let err: BoltError = "Missing Content-Type header".into();
190                self.form_data_result = Some(Err(err));
191                return Err("Missing Content-Type header".into());
192            }
193        };
194
195        if content_type.type_() != mime::MULTIPART || content_type.subtype() != mime::FORM_DATA {
196            let err: BoltError = "Content-Type is not multipart/form-data".into();
197            self.form_data_result = Some(Err(err));
198            return Err("Content-Type is not multipart/form-data".into());
199        }
200
201        let boundary = content_type
202            .get_param(mime::BOUNDARY)
203            .ok_or("Missing boundary parameter in Content-Type")?
204            .to_string();
205
206        let (_, body) = self
207            .inner
208            .take()
209            .expect("Request already consumed")
210            .into_parts();
211
212        let stream =
213            BodyStream::new(body).try_filter_map(|frame| async move { Ok(frame.into_data().ok()) });
214
215        let mut multipart = Multipart::new(stream, boundary);
216
217        let mut form_data = FormData {
218            files: Vec::new(),
219            fields: HashMap::new(),
220        };
221
222        while let Ok(Some(mut field)) = multipart.next_field().await {
223            let name = field.name().unwrap_or_default().to_string();
224
225            if let Some(file_name) = field.file_name() {
226                let filename = file_name.to_string();
227                let unique_id = Uuid::new_v4();
228                let temp_path =
229                    std::env::temp_dir().join(format!("bolt_upload_{}_{}", unique_id, filename));
230
231                let mut dest = tokio::fs::File::create(&temp_path).await?;
232
233                while let Some(chunk) = field.chunk().await? {
234                    dest.write_all(&chunk).await?;
235                }
236
237                self.temp_paths.push(temp_path.display().to_string());
238
239                form_data.files.push(FormFile {
240                    field_name: name,
241                    file_name: filename,
242                    content_type: field
243                        .content_type()
244                        .map(|m| m.essence_str().to_string())
245                        .unwrap_or_default(),
246                    temp_path: temp_path.display().to_string(),
247                });
248            } else {
249                form_data.fields.insert(name, field.text().await?);
250            }
251        }
252
253        self.form_data_result = Some(Ok(form_data.clone()));
254        Ok(form_data)
255    }
256
257    pub async fn files(&mut self) -> Result<Vec<FormFile>, BoltError> {
258        let form_data = self.form_data().await?;
259        Ok(form_data.files)
260    }
261
262    pub async fn file(&mut self, name: &str) -> Result<Option<FormFile>, BoltError> {
263        let files = self.files().await?;
264
265        let file = files.iter().find(|f| f.field_name == name);
266        Ok(file.cloned())
267    }
268
269    pub async fn cleanup(&mut self) {
270        for path in self.temp_paths.drain(..) {
271            let _ = tokio::fs::remove_file(&path).await;
272        }
273    }
274}
275
276impl Drop for RequestBody {
277    fn drop(&mut self) {
278        if self.temp_paths.is_empty() {
279            return;
280        }
281
282        let paths = std::mem::take(&mut self.temp_paths);
283
284        if tokio::runtime::Handle::try_current().is_ok() {
285            tokio::spawn(async move {
286                for path in paths {
287                    let _ = tokio::fs::remove_file(&path).await;
288                }
289            });
290        } else {
291            for path in paths {
292                let _ = std::fs::remove_file(&path);
293            }
294        }
295    }
296}