#[cfg(test)]
mod tests;
use crate::application::Application;
use crate::header::Header;
use crate::middleware::Middleware;
use crate::request::Request;
use crate::response::Response;
use crate::server::ConnectionInfo;
/// A single request-side rewrite recorded by the [`RewriteLayer`] builder.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RequestRule {
    /// Replace every header named `name` (ASCII case-insensitive) with one `name: value`.
    SetHeader { name: String, value: String },
    /// Remove every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Replace the request URI outright.
    SetUri(String),
    /// Strip this string prefix from the request URI, if present.
    StripUriPrefix(String),
    /// Prepend this string to the request URI.
    AddUriPrefix(String),
}

/// A single response-side rewrite recorded by the [`RewriteLayer`] builder.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ResponseRule {
    /// Replace every header named `name` (ASCII case-insensitive) with one `name: value`.
    SetHeader { name: String, value: String },
    /// Remove every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Override the status code and reason phrase.
    SetStatus { code: i16, reason: String },
    /// Replace every occurrence of `from` with `to` in the response body bytes.
    BodyReplace { from: Vec<u8>, to: Vec<u8> },
}

/// Middleware that rewrites requests before, and responses after, the wrapped application.
///
/// Configure it with the chained builder methods; rules of each kind are kept in the
/// order they were added.
#[derive(Debug, Clone)]
pub struct RewriteLayer {
    /// Rules applied, in order, to a copy of each incoming request.
    request_rules: Vec<RequestRule>,
    /// Rules applied, in order, to each response produced by the inner application.
    response_rules: Vec<ResponseRule>,
}
impl RewriteLayer {
    /// Creates a layer with no request or response rules.
    pub fn new() -> Self {
        RewriteLayer { request_rules: Vec::new(), response_rules: Vec::new() }
    }

    /// Adds a rule that sets request header `name` to `value`.
    ///
    /// Any existing headers whose name matches `name` ASCII case-insensitively are
    /// removed first, so exactly one `name: value` header remains (appended last).
    #[must_use]
    pub fn request_header_set(mut self, name: &str, value: &str) -> Self {
        self.request_rules.push(RequestRule::SetHeader {
            name: name.to_string(),
            value: value.to_string(),
        });
        self
    }

    /// Adds a rule that removes every request header named `name`
    /// (ASCII case-insensitive).
    #[must_use]
    pub fn request_header_remove(mut self, name: &str) -> Self {
        self.request_rules.push(RequestRule::RemoveHeader(name.to_string()));
        self
    }

    /// Adds a rule that replaces the request URI with `uri`.
    #[must_use]
    pub fn request_uri_set(mut self, uri: &str) -> Self {
        self.request_rules.push(RequestRule::SetUri(uri.to_string()));
        self
    }

    /// Adds a rule that strips `prefix` from the start of the request URI.
    ///
    /// URIs that do not start with `prefix` are left alone. After stripping, a
    /// leading `/` is re-added if the remainder lacks one. Matching is a plain
    /// string prefix, not path-segment aware: `/api` also matches `/apiary`,
    /// producing `/ary`.
    #[must_use]
    pub fn request_uri_strip_prefix(mut self, prefix: &str) -> Self {
        self.request_rules.push(RequestRule::StripUriPrefix(prefix.to_string()));
        self
    }

    /// Adds a rule that prepends `prefix` verbatim to the request URI.
    ///
    /// No slash normalization is performed, e.g. `/api/` + `/x` becomes `/api//x`.
    #[must_use]
    pub fn request_uri_add_prefix(mut self, prefix: &str) -> Self {
        self.request_rules.push(RequestRule::AddUriPrefix(prefix.to_string()));
        self
    }

    /// Adds a rule that sets response header `name` to `value`.
    ///
    /// Any existing headers whose name matches `name` ASCII case-insensitively are
    /// removed first, so exactly one `name: value` header remains (appended last).
    #[must_use]
    pub fn response_header_set(mut self, name: &str, value: &str) -> Self {
        self.response_rules.push(ResponseRule::SetHeader {
            name: name.to_string(),
            value: value.to_string(),
        });
        self
    }

    /// Adds a rule that removes every response header named `name`
    /// (ASCII case-insensitive).
    #[must_use]
    pub fn response_header_remove(mut self, name: &str) -> Self {
        self.response_rules.push(ResponseRule::RemoveHeader(name.to_string()));
        self
    }

    /// Adds a rule that overrides the response status code and reason phrase.
    #[must_use]
    pub fn response_status(mut self, code: i16, reason: &str) -> Self {
        self.response_rules.push(ResponseRule::SetStatus { code, reason: reason.to_string() });
        self
    }

    /// Adds a rule that replaces every non-overlapping occurrence of `from` with
    /// `to` in each body chunk of the response (byte-wise, left to right).
    ///
    /// An empty `from` matches nothing.
    #[must_use]
    pub fn response_body_replace(mut self, from: &str, to: &str) -> Self {
        self.response_rules.push(ResponseRule::BodyReplace {
            from: from.as_bytes().to_vec(),
            to: to.as_bytes().to_vec(),
        });
        self
    }
}

impl Default for RewriteLayer {
    /// Equivalent to [`RewriteLayer::new`]: a layer with no rules.
    fn default() -> Self {
        Self::new()
    }
}
impl Middleware for RewriteLayer {
fn handle(
&self,
request: &Request,
connection: &ConnectionInfo,
next: &dyn Application,
) -> Result<Response, String> {
let mut req = request.clone();
for rule in &self.request_rules {
match rule {
RequestRule::SetHeader { name, value } => {
req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
req.headers.push(Header { name: name.clone(), value: value.clone() });
}
RequestRule::RemoveHeader(name) => {
req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
}
RequestRule::SetUri(uri) => {
req.request_uri = uri.clone();
}
RequestRule::StripUriPrefix(prefix) => {
if let Some(stripped) = req.request_uri.strip_prefix(prefix.as_str()) {
req.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
format!("/{}", stripped)
} else {
stripped.to_string()
};
}
}
RequestRule::AddUriPrefix(prefix) => {
req.request_uri = format!("{}{}", prefix, req.request_uri);
}
}
}
let mut response = next.execute(&req, connection)?;
for rule in &self.response_rules {
match rule {
ResponseRule::SetHeader { name, value } => {
response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
response.headers.push(Header { name: name.clone(), value: value.clone() });
}
ResponseRule::RemoveHeader(name) => {
response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
}
ResponseRule::SetStatus { code, reason } => {
response.status_code = *code;
response.reason_phrase = reason.clone();
}
ResponseRule::BodyReplace { from, to } => {
for cr in &mut response.content_range_list {
cr.body = replace_bytes(&cr.body, from, to);
}
}
}
}
Ok(response)
}
}
/// Returns a copy of `haystack` in which every occurrence of `needle` is replaced
/// by `replacement`.
///
/// Matches are found left to right and never overlap: after a match is replaced,
/// searching resumes just past it (so replacing `aa` in `aaa` yields the
/// replacement followed by the trailing `a`). An empty `needle` matches nothing,
/// and the input is returned unchanged.
fn replace_bytes(haystack: &[u8], needle: &[u8], replacement: &[u8]) -> Vec<u8> {
    // `windows(0)` would panic, and an empty pattern is defined as "no match".
    if needle.is_empty() {
        return haystack.to_vec();
    }

    let mut out = Vec::with_capacity(haystack.len());
    let mut remaining = haystack;
    while let Some(offset) = remaining.windows(needle.len()).position(|w| w == needle) {
        out.extend_from_slice(&remaining[..offset]);
        out.extend_from_slice(replacement);
        remaining = &remaining[offset + needle.len()..];
    }
    // Whatever follows the last match (or the whole input if none) is copied as-is.
    out.extend_from_slice(remaining);
    out
}