1use std::{
2 collections::HashMap,
3 fmt::{Debug, Formatter, Result as FmtResult},
4 net::{IpAddr, SocketAddr},
5 str::{self, Utf8Error},
6};
7
8use hyper::{body, Body, Request as HyperRequestInternal, Uri, Version};
9
10use crate::{cookie::Cookie, HttpMethod};
11
/// The underlying hyper request type wrapped by [`Request`]: `hyper::Request` carrying a
/// `hyper::Body`.
pub type HyperRequest = HyperRequestInternal<Body>;
13
/// An incoming HTTP request with its body fully buffered and its headers and query string
/// pre-extracted into owned maps.
pub struct Request {
    /// Socket address passed to `from_hyper` (presumably the remote peer — verify against
    /// the server accept loop); `get_ip` returns its IP part.
    pub(crate) socket_addr: SocketAddr,
    /// The complete request body, read into memory by `from_hyper`.
    pub(crate) body: Vec<u8>,
    /// Request method, converted from hyper's method via `HttpMethod::from`.
    pub(crate) method: HttpMethod,
    /// The request URI exactly as hyper parsed it.
    pub(crate) uri: Uri,
    /// HTTP version as `(major, minor)`; `(0, 0)` when the version is not one of the
    /// variants matched in `from_hyper`.
    pub(crate) version: (u8, u8),
    /// Header name -> all values seen for that name, in arrival order. Names come from
    /// `HeaderName::to_string`; values that were not valid `&str` are dropped.
    pub(crate) headers: HashMap<String, Vec<String>>,
    /// Query-string pairs decoded with `serde_urlencoded`; empty when there is no query
    /// string or it fails to decode.
    pub(crate) query: HashMap<String, String>,
    /// Path parameters; always empty after `from_hyper` (presumably filled in by the
    /// router elsewhere in the crate — TODO confirm).
    pub(crate) params: HashMap<String, String>,
    /// Parsed cookies; always empty after `from_hyper` (presumably populated elsewhere in
    /// the crate — TODO confirm).
    pub(crate) cookies: Vec<Cookie>,
    /// The original hyper request, reassembled with the buffered body so it remains
    /// readable after construction.
    pub(crate) hyper_request: HyperRequest,
}
26
27impl Request {
28 pub async fn from_hyper(
29 socket_addr: SocketAddr,
30 req: HyperRequest,
31 ) -> Self {
32 let (parts, hyper_body) = req.into_parts();
33 let mut headers = HashMap::<String, Vec<String>>::new();
34 parts.headers.iter().for_each(|(key, value)| {
35 let key = key.to_string();
36 let value = value.to_str();
37
38 if value.is_err() {
39 return;
40 }
41 let value = value.unwrap().to_string();
42
43 if let Some(values) = headers.get_mut(&key) {
44 values.push(value);
45 } else {
46 headers.insert(key.to_string(), vec![value]);
47 }
48 });
49 let body = body::to_bytes(hyper_body).await.unwrap().to_vec();
50 Request {
51 socket_addr,
52 body: body.clone(),
53 method: HttpMethod::from(parts.method.clone()),
54 uri: parts.uri.clone(),
55 version: match parts.version {
56 Version::HTTP_09 => (0, 9),
57 Version::HTTP_10 => (1, 0),
58 Version::HTTP_11 => (1, 1),
59 Version::HTTP_2 => (2, 0),
60 Version::HTTP_3 => (3, 0),
61 _ => (0, 0),
62 },
63 headers: headers.clone(),
64 query: if let Some(query) = parts.uri.query() {
65 serde_urlencoded::from_str(query)
66 .unwrap_or_else(|_| HashMap::new())
67 } else {
68 HashMap::new()
69 },
70 params: HashMap::new(),
71 cookies: vec![],
72 hyper_request: HyperRequest::from_parts(parts, Body::from(body)),
73 }
74 }
75
76 pub fn get_body_bytes(&self) -> &[u8] {
77 &self.body
78 }
79
80 pub fn get_body(&self) -> Result<String, Utf8Error> {
81 Ok(str::from_utf8(&self.body)?.to_string())
82 }
83
84 pub fn get_method(&self) -> &HttpMethod {
85 &self.method
86 }
87
88 pub fn get_length(&self) -> u128 {
89 if let Some(length) = self.headers.get("Content-Length") {
90 if let Some(value) = length.get(0) {
91 if let Ok(value) = value.parse::<u128>() {
92 return value;
93 }
94 }
95 }
96 self.body.len() as u128
97 }
98
99 pub fn get_path(&self) -> String {
100 self.uri.path().to_string()
101 }
102
103 pub fn get_full_url(&self) -> String {
104 self.uri.to_string()
105 }
106
107 pub fn get_origin(&self) -> Option<String> {
108 Some(format!(
109 "{}://{}",
110 self.uri.scheme_str()?,
111 self.uri.authority()?
112 ))
113 }
114
115 pub fn get_query_string(&self) -> String {
116 self.uri.query().unwrap_or("").to_string()
117 }
118
119 pub fn get_host(&self) -> String {
120 self.uri
121 .host()
122 .map(String::from)
123 .unwrap_or_else(|| self.get_header("host").unwrap_or_default())
124 }
125
126 pub fn get_host_and_port(&self) -> String {
127 format!(
128 "{}{}",
129 self.uri.host().unwrap(),
130 if let Some(port) = self.uri.port_u16() {
131 format!(":{}", port)
132 } else {
133 String::new()
134 }
135 )
136 }
137
138 pub fn get_content_type(&self) -> String {
139 let c_type = self
140 .get_header("Content-Type")
141 .unwrap_or_else(|| "text/plain".to_string());
142 c_type.split(';').next().unwrap_or("").to_string()
143 }
144
145 pub fn get_charset(&self) -> Option<String> {
146 let header = self.get_header("Content-Type")?;
147 let charset_index = header.find("charset=")?;
148 let data = &header[charset_index..];
149 Some(
150 data[(charset_index + 8)..data.find(';').unwrap_or(data.len())]
151 .to_string(),
152 )
153 }
154
155 pub fn get_protocol(&self) -> String {
156 self.uri.scheme_str().unwrap_or("http").to_string()
158 }
159
160 pub fn is_secure(&self) -> bool {
161 self.get_protocol() == "https"
162 }
163
164 pub fn get_ip(&self) -> IpAddr {
165 self.socket_addr.ip()
166 }
167
168 pub fn is(&self, mimes: &[&str]) -> bool {
169 let given = self.get_content_type();
170 mimes.iter().any(|mime| mime == &given)
171 }
172
173 pub fn get_version(&self) -> String {
177 format!("{}.{}", self.version.0, self.version.1)
178 }
179
180 pub fn get_version_major(&self) -> u8 {
181 self.version.0
182 }
183
184 pub fn get_version_minor(&self) -> u8 {
185 self.version.1
186 }
187
188 pub fn get_header(&self, field: &str) -> Option<String> {
189 self.headers.iter().find_map(|(key, value)| {
190 if key.to_lowercase() == field.to_lowercase() {
191 Some(value.join("\n"))
192 } else {
193 None
194 }
195 })
196 }
197 pub fn get_headers(&self) -> &HashMap<String, Vec<String>> {
198 &self.headers
199 }
200 pub fn set_header(&mut self, field: &str, value: &str) {
201 self.headers
202 .insert(field.to_string(), vec![value.to_string()]);
203 }
204 pub fn append_header(&mut self, key: String, value: String) {
205 if let Some(headers) = self.headers.get_mut(&key) {
206 headers.push(value);
207 } else {
208 self.headers.insert(key, vec![value]);
209 }
210 }
211 pub fn remove_header(&mut self, field: &str) {
212 self.headers.remove(field);
213 }
214
215 pub fn get_query(&self) -> &HashMap<String, String> {
216 &self.query
217 }
218
219 pub fn get_params(&self) -> &HashMap<String, String> {
220 &self.params
221 }
222
223 pub fn get_cookies(&self) -> &Vec<Cookie> {
224 &self.cookies
225 }
226
227 pub fn get_cookie(&self, name: &str) -> Option<&Cookie> {
228 self.cookies.iter().find(|cookie| cookie.key == name)
229 }
230
231 pub fn get_hyper_request(&self) -> &HyperRequest {
232 &self.hyper_request
233 }
234
235 pub fn get_hyper_request_mut(&mut self) -> &mut HyperRequest {
236 &mut self.hyper_request
237 }
238}
239
240#[cfg(debug_assertions)]
241impl Debug for Request {
242 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
243 f.debug_struct("Request")
244 .field("socket_addr", &self.socket_addr)
245 .field("body", &self.body)
246 .field("method", &self.method)
247 .field("uri", &self.uri)
248 .field("version", &self.version)
249 .field("headers", &self.headers)
250 .field("query", &self.query)
251 .field("params", &self.params)
252 .field("cookies", &self.cookies)
253 .finish()
254 }
255}
256
#[cfg(not(debug_assertions))]
impl Debug for Request {
    /// Compact release-build representation: `[Request <method> <path>]`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let path = self.uri.path();
        write!(f, "[Request {} {}]", self.method, path)
    }
}