use std::collections::HashMap;
use super::types::HttpMethod;
/// Longest URL, in bytes (`String::len`), that `HttpRequest::new` accepts.
const MAX_URL_LENGTH: usize = 8192;
/// Prefixes a URL must start with for `HttpRequest::new` to accept it; matched
/// exactly (case-sensitively) against the beginning of the URL.
const ALLOWED_SCHEMES: &[&str] = &["http://", "https://"];
/// A request description (method, URL, headers, body, query parameters)
/// assembled with builder-style methods that take and return `Self`.
///
/// NOTE(review): the derived `Default` produces an empty `url` that has not been
/// through `HttpRequest::new`'s length and scheme checks; prefer `new`.
#[derive(Clone, Debug, Default)]
pub struct HttpRequest {
/// Request method; `HttpMethod::default()` when built via `new`.
pub method: HttpMethod,
/// Base URL as validated by `new`, without the parameters in `params`.
/// Crate-visible only (presumably so outside code cannot swap in an unvalidated
/// URL — TODO confirm); read it through `url()`.
pub(crate) url: String,
/// Header name -> value. Names are compared exactly, so names differing only in
/// letter case are stored as separate entries.
pub headers: HashMap<String, String>,
/// Request body as text.
pub body: String,
/// Query parameters (name -> single value) that `full_url` percent-encodes and
/// appends to `url`.
pub params: HashMap<String, String>,
}
impl HttpRequest {
    /// Creates a request for `url` after validating it.
    ///
    /// Returns `None` if `url` is longer than `MAX_URL_LENGTH` bytes or does not
    /// start with one of `ALLOWED_SCHEMES`. The prefix match is exact, so for
    /// example `HTTP://...` is rejected. All other fields take their `Default`
    /// values.
    pub fn new(url: impl Into<String>) -> Option<Self> {
        let url = url.into();
        if url.len() > MAX_URL_LENGTH {
            return None;
        }
        // Scheme allow-list: rejects `file:`, `javascript:`, `data:`,
        // scheme-relative `//host`, and anything else not explicitly listed.
        if !ALLOWED_SCHEMES
            .iter()
            .any(|scheme| url.starts_with(*scheme))
        {
            return None;
        }
        Some(Self {
            url,
            ..Default::default()
        })
    }

    /// Returns the base URL, without the query parameters added via `param`.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Sets the request method.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn method(mut self, method: HttpMethod) -> Self {
        self.method = method;
        self
    }

    /// Sets header `key` to `value`, replacing any previous value for the same key.
    ///
    /// Keys are compared exactly, so names that differ only in letter case are
    /// kept as separate entries.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Replaces the request body.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn body(mut self, body: impl Into<String>) -> Self {
        self.body = body.into();
        self
    }

    /// Sets query parameter `key` to `value`, replacing any previous value for
    /// the same key (repeated keys such as `a=1&a=2` are not representable).
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.params.insert(key.into(), value.into());
        self
    }

    /// Returns `url` with the parameters set via `param` appended as a query string.
    ///
    /// Keys and values are percent-encoded with `percent_encode` and emitted in
    /// sorted order, so the same request always yields the same URL. If `url`
    /// already carries a query, the new parameters are joined onto it with `&`;
    /// if it carries a `#fragment`, the parameters are inserted before the
    /// fragment. With no parameters, `url` is returned unchanged.
    ///
    /// The result is not re-checked against `MAX_URL_LENGTH`.
    pub fn full_url(&self) -> String {
        if self.params.is_empty() {
            return self.url.clone();
        }

        // `HashMap` iteration order is randomized per process; sort for a stable URL.
        let mut pairs: Vec<(&String, &String)> = self.params.iter().collect();
        pairs.sort_unstable();
        let query = pairs
            .into_iter()
            .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
            .collect::<Vec<_>>()
            .join("&");

        // The first `#` starts the fragment; the query must come before it or it
        // would be parsed as part of the fragment.
        let (base, fragment) = match self.url.find('#') {
            Some(idx) => self.url.split_at(idx),
            None => (self.url.as_str(), ""),
        };
        let separator = if !base.contains('?') {
            "?"
        } else if base.ends_with('?') || base.ends_with('&') {
            // Existing query is empty or already ends with a delimiter.
            ""
        } else {
            "&"
        };
        format!("{}{}{}{}", base, separator, query, fragment)
    }
}
/// Percent-encodes `s` for use as a query-string key or value.
///
/// ASCII letters, ASCII digits and the punctuation `-`, `_`, `.`, `~` pass
/// through unchanged; a space becomes `+`; every other byte of the UTF-8
/// encoding is written as `%XX` with uppercase hexadecimal digits.
fn percent_encode(s: &str) -> String {
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else if byte == b' ' {
                // `+` for space is the application/x-www-form-urlencoded
                // convention; strict RFC 3986 decoders expect `%20` instead.
                out.push('+');
            } else {
                out.push('%');
                out.push(hex_char(byte >> 4));
                out.push(hex_char(byte & 0x0F));
            }
            out
        })
}

/// Returns the uppercase hexadecimal digit (`0`-`9`, `A`-`F`) for `nibble`.
///
/// # Panics
///
/// Panics if `nibble` is greater than 15; callers only pass 4-bit values.
fn hex_char(nibble: u8) -> char {
    match char::from_digit(u32::from(nibble), 16) {
        Some(digit) => digit.to_ascii_uppercase(),
        None => unreachable!(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_percent_encode() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("hello world"), "hello+world");
        assert_eq!(percent_encode("a&b=c"), "a%26b%3Dc");
        assert_eq!(percent_encode("<script>"), "%3Cscript%3E");
        assert_eq!(percent_encode("a/b"), "a%2Fb");
    }

    #[test]
    fn test_percent_encode_edge_cases() {
        // Empty input, the unreserved punctuation set, and a multi-byte UTF-8
        // character (U+00E9 is encoded as the two bytes C3 A9).
        assert_eq!(percent_encode(""), "");
        assert_eq!(percent_encode("-_.~"), "-_.~");
        assert_eq!(percent_encode("é"), "%C3%A9");
    }

    #[test]
    fn test_new_valid_urls() {
        assert!(HttpRequest::new("http://example.com").is_some());
        assert!(HttpRequest::new("https://example.com").is_some());
        assert!(HttpRequest::new("https://example.com/path?query=value").is_some());
    }

    #[test]
    fn test_new_rejects_invalid_schemes() {
        assert!(HttpRequest::new("file:///etc/passwd").is_none());
        assert!(HttpRequest::new("ftp://example.com").is_none());
        assert!(HttpRequest::new("javascript:alert(1)").is_none());
        assert!(HttpRequest::new("data:text/html,<script>alert(1)</script>").is_none());
        assert!(HttpRequest::new("//example.com").is_none());
    }

    #[test]
    fn test_new_rejects_input_shorter_than_scheme() {
        assert!(HttpRequest::new("").is_none());
        assert!(HttpRequest::new("http").is_none());
    }

    #[test]
    fn test_new_rejects_long_urls() {
        let long_url = format!("https://example.com/?{}", "a".repeat(MAX_URL_LENGTH));
        assert!(HttpRequest::new(long_url).is_none());
    }

    #[test]
    fn test_new_accepts_url_at_max_length() {
        let prefix = "https://example.com/";
        let url = format!("{}{}", prefix, "a".repeat(MAX_URL_LENGTH - prefix.len()));
        assert_eq!(url.len(), MAX_URL_LENGTH);
        assert!(HttpRequest::new(url).is_some());
    }

    #[test]
    fn test_full_url_encoding() {
        let req = HttpRequest::new("https://example.com").unwrap();
        let req = req.param("q", "hello world").param("filter", "a&b=c");
        let full = req.full_url();
        assert!(full.contains("q=hello+world"));
        assert!(full.contains("filter=a%26b%3Dc"));
    }

    #[test]
    fn test_full_url_single_param() {
        let req = HttpRequest::new("https://example.com/search")
            .unwrap()
            .param("q", "rust lang");
        assert_eq!(req.full_url(), "https://example.com/search?q=rust+lang");
    }

    #[test]
    fn test_full_url_no_params() {
        let req = HttpRequest::new("https://example.com/path").unwrap();
        assert_eq!(req.full_url(), "https://example.com/path");
    }
}