eve_rs/
request.rs

1use std::{
2	collections::HashMap,
3	fmt::{Debug, Formatter, Result as FmtResult},
4	net::{IpAddr, SocketAddr},
5	str::{self, Utf8Error},
6};
7
8use hyper::{body, Body, Request as HyperRequestInternal, Uri, Version};
9
10use crate::{cookie::Cookie, HttpMethod};
11
12pub type HyperRequest = HyperRequestInternal<Body>;
13
14pub struct Request {
15	pub(crate) socket_addr: SocketAddr,
16	pub(crate) body: Vec<u8>,
17	pub(crate) method: HttpMethod,
18	pub(crate) uri: Uri,
19	pub(crate) version: (u8, u8),
20	pub(crate) headers: HashMap<String, Vec<String>>,
21	pub(crate) query: HashMap<String, String>,
22	pub(crate) params: HashMap<String, String>,
23	pub(crate) cookies: Vec<Cookie>,
24	pub(crate) hyper_request: HyperRequest,
25}
26
27impl Request {
28	pub async fn from_hyper(
29		socket_addr: SocketAddr,
30		req: HyperRequest,
31	) -> Self {
32		let (parts, hyper_body) = req.into_parts();
33		let mut headers = HashMap::<String, Vec<String>>::new();
34		parts.headers.iter().for_each(|(key, value)| {
35			let key = key.to_string();
36			let value = value.to_str();
37
38			if value.is_err() {
39				return;
40			}
41			let value = value.unwrap().to_string();
42
43			if let Some(values) = headers.get_mut(&key) {
44				values.push(value);
45			} else {
46				headers.insert(key.to_string(), vec![value]);
47			}
48		});
49		let body = body::to_bytes(hyper_body).await.unwrap().to_vec();
50		Request {
51			socket_addr,
52			body: body.clone(),
53			method: HttpMethod::from(parts.method.clone()),
54			uri: parts.uri.clone(),
55			version: match parts.version {
56				Version::HTTP_09 => (0, 9),
57				Version::HTTP_10 => (1, 0),
58				Version::HTTP_11 => (1, 1),
59				Version::HTTP_2 => (2, 0),
60				Version::HTTP_3 => (3, 0),
61				_ => (0, 0),
62			},
63			headers: headers.clone(),
64			query: if let Some(query) = parts.uri.query() {
65				serde_urlencoded::from_str(query)
66					.unwrap_or_else(|_| HashMap::new())
67			} else {
68				HashMap::new()
69			},
70			params: HashMap::new(),
71			cookies: vec![],
72			hyper_request: HyperRequest::from_parts(parts, Body::from(body)),
73		}
74	}
75
76	pub fn get_body_bytes(&self) -> &[u8] {
77		&self.body
78	}
79
80	pub fn get_body(&self) -> Result<String, Utf8Error> {
81		Ok(str::from_utf8(&self.body)?.to_string())
82	}
83
84	pub fn get_method(&self) -> &HttpMethod {
85		&self.method
86	}
87
88	pub fn get_length(&self) -> u128 {
89		if let Some(length) = self.headers.get("Content-Length") {
90			if let Some(value) = length.get(0) {
91				if let Ok(value) = value.parse::<u128>() {
92					return value;
93				}
94			}
95		}
96		self.body.len() as u128
97	}
98
99	pub fn get_path(&self) -> String {
100		self.uri.path().to_string()
101	}
102
103	pub fn get_full_url(&self) -> String {
104		self.uri.to_string()
105	}
106
107	pub fn get_origin(&self) -> Option<String> {
108		Some(format!(
109			"{}://{}",
110			self.uri.scheme_str()?,
111			self.uri.authority()?
112		))
113	}
114
115	pub fn get_query_string(&self) -> String {
116		self.uri.query().unwrap_or("").to_string()
117	}
118
119	pub fn get_host(&self) -> String {
120		self.uri
121			.host()
122			.map(String::from)
123			.unwrap_or_else(|| self.get_header("host").unwrap_or_default())
124	}
125
126	pub fn get_host_and_port(&self) -> String {
127		format!(
128			"{}{}",
129			self.uri.host().unwrap(),
130			if let Some(port) = self.uri.port_u16() {
131				format!(":{}", port)
132			} else {
133				String::new()
134			}
135		)
136	}
137
138	pub fn get_content_type(&self) -> String {
139		let c_type = self
140			.get_header("Content-Type")
141			.unwrap_or_else(|| "text/plain".to_string());
142		c_type.split(';').next().unwrap_or("").to_string()
143	}
144
145	pub fn get_charset(&self) -> Option<String> {
146		let header = self.get_header("Content-Type")?;
147		let charset_index = header.find("charset=")?;
148		let data = &header[charset_index..];
149		Some(
150			data[(charset_index + 8)..data.find(';').unwrap_or(data.len())]
151				.to_string(),
152		)
153	}
154
155	pub fn get_protocol(&self) -> String {
156		// TODO support X-Forwarded-Proto
157		self.uri.scheme_str().unwrap_or("http").to_string()
158	}
159
160	pub fn is_secure(&self) -> bool {
161		self.get_protocol() == "https"
162	}
163
164	pub fn get_ip(&self) -> IpAddr {
165		self.socket_addr.ip()
166	}
167
168	pub fn is(&self, mimes: &[&str]) -> bool {
169		let given = self.get_content_type();
170		mimes.iter().any(|mime| mime == &given)
171	}
172
173	// TODO content negotiation
174	// See: https://koajs.com/#request content negotiation
175
176	pub fn get_version(&self) -> String {
177		format!("{}.{}", self.version.0, self.version.1)
178	}
179
180	pub fn get_version_major(&self) -> u8 {
181		self.version.0
182	}
183
184	pub fn get_version_minor(&self) -> u8 {
185		self.version.1
186	}
187
188	pub fn get_header(&self, field: &str) -> Option<String> {
189		self.headers.iter().find_map(|(key, value)| {
190			if key.to_lowercase() == field.to_lowercase() {
191				Some(value.join("\n"))
192			} else {
193				None
194			}
195		})
196	}
197	pub fn get_headers(&self) -> &HashMap<String, Vec<String>> {
198		&self.headers
199	}
200	pub fn set_header(&mut self, field: &str, value: &str) {
201		self.headers
202			.insert(field.to_string(), vec![value.to_string()]);
203	}
204	pub fn append_header(&mut self, key: String, value: String) {
205		if let Some(headers) = self.headers.get_mut(&key) {
206			headers.push(value);
207		} else {
208			self.headers.insert(key, vec![value]);
209		}
210	}
211	pub fn remove_header(&mut self, field: &str) {
212		self.headers.remove(field);
213	}
214
215	pub fn get_query(&self) -> &HashMap<String, String> {
216		&self.query
217	}
218
219	pub fn get_params(&self) -> &HashMap<String, String> {
220		&self.params
221	}
222
223	pub fn get_cookies(&self) -> &Vec<Cookie> {
224		&self.cookies
225	}
226
227	pub fn get_cookie(&self, name: &str) -> Option<&Cookie> {
228		self.cookies.iter().find(|cookie| cookie.key == name)
229	}
230
231	pub fn get_hyper_request(&self) -> &HyperRequest {
232		&self.hyper_request
233	}
234
235	pub fn get_hyper_request_mut(&mut self) -> &mut HyperRequest {
236		&mut self.hyper_request
237	}
238}
239
240#[cfg(debug_assertions)]
241impl Debug for Request {
242	fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
243		f.debug_struct("Request")
244			.field("socket_addr", &self.socket_addr)
245			.field("body", &self.body)
246			.field("method", &self.method)
247			.field("uri", &self.uri)
248			.field("version", &self.version)
249			.field("headers", &self.headers)
250			.field("query", &self.query)
251			.field("params", &self.params)
252			.field("cookies", &self.cookies)
253			.finish()
254	}
255}
256
/// Compact `Debug` output for builds without debug assertions: prints only
/// the method and the path.
#[cfg(not(debug_assertions))]
impl Debug for Request {
	fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
		let path = self.get_path();
		f.write_fmt(format_args!("[Request {} {}]", self.method, path))
	}
}