use std::str::FromStr;
use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
use regex_automata::{
MatchKind,
meta::{self, Regex},
util,
};
use super::error::RuleError;
/// Bytes that get percent-encoded when captured text is substituted into a
/// rewrite target: everything in `CONTROLS` plus the punctuation listed below.
///
/// `AsciiSet::add` builds a set, so the order of the calls is irrelevant.
const ESCAPE: &AsciiSet = &CONTROLS
    // Whitespace and quote characters.
    .add(b' ')
    .add(b'"')
    .add(b'\'')
    .add(b'`')
    // Separators and sigils.
    .add(b'/')
    .add(b'?')
    .add(b'#')
    .add(b':')
    .add(b';')
    .add(b'=')
    .add(b'@')
    .add(b'&')
    .add(b'+')
    .add(b',')
    .add(b'$')
    // Brackets and braces.
    .add(b'<')
    .add(b'>')
    .add(b'[')
    .add(b']')
    .add(b'{')
    .add(b'}')
    // Remaining punctuation.
    .add(b'^')
    .add(b'~');
/// A single rewrite rule: a compiled pattern, a replacement template and the
/// flags parsed from the rule's optional bracketed flag list.
#[derive(Clone, Debug)]
pub struct Rule {
/// Compiled matcher built from the first whitespace-separated token of the rule text.
pattern: Regex,
/// Replacement template; capture references in it are expanded by `try_rewrite`.
rewrite: String,
/// Flags from the bracketed list; empty when the rule text has no flag list.
flags: Vec<RuleFlag>,
}
impl Rule {
    /// Applies this rule to `uri`.
    ///
    /// Returns `None` when the pattern does not match. On a match, returns the
    /// rewrite template with each capture reference replaced by the text of
    /// the corresponding group; a group that did not participate in the match
    /// contributes nothing. Substituted text is percent-encoded with `ESCAPE`
    /// unless the rule carries the `NoEscape` modifier.
    #[inline]
    pub fn try_rewrite(&self, uri: &str) -> Option<String> {
        let mut captures = self.pattern.create_captures();
        self.pattern.captures(uri, &mut captures);
        if !captures.is_match() {
            return None;
        }

        let escape = !self
            .flags
            .iter()
            .any(|flag| matches!(flag, RuleFlag::Mod(RuleMod::NoEscape)));

        let mut output = String::new();
        util::interpolate::string(
            &self.rewrite,
            |group, out| {
                if let Some(span) = captures.get_group(group) {
                    let text = &uri[span];
                    if escape {
                        // The encoder yields `&str` chunks; append them directly.
                        out.extend(utf8_percent_encode(text, ESCAPE));
                    } else {
                        out.push_str(text);
                    }
                }
            },
            |name| captures.group_info().to_index(captures.pattern()?, name),
            &mut output,
        );
        Some(output)
    }

    /// Returns the first `Shift` flag attached to this rule, if any.
    #[inline]
    pub(crate) fn shift(&self) -> Option<&RuleShift> {
        for flag in &self.flags {
            if let RuleFlag::Shift(shift) = flag {
                return Some(shift);
            }
        }
        None
    }

    /// Returns the first `Resolve` flag attached to this rule, if any.
    #[inline]
    pub(crate) fn resolve(&self) -> Option<&RuleResolve> {
        for flag in &self.flags {
            if let RuleFlag::Resolve(resolve) = flag {
                return Some(resolve);
            }
        }
        None
    }
}
impl FromStr for Rule {
    type Err = RuleError;

    /// Parses rule text of the form `PATTERN REWRITE [FLAGS]`, where the
    /// bracketed flag list is optional and tokens are separated by whitespace.
    ///
    /// # Errors
    ///
    /// * `MissingPattern` / `MissingRewrite` when the respective token is absent.
    /// * Any error produced while parsing the flag list.
    /// * `InvalidSuffix` when a fourth token follows the flag list.
    /// * `InvalidRegex` when the pattern fails to compile.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `split_whitespace` never yields empty tokens, so no extra filtering is needed.
        let mut tokens = s.split_whitespace();

        let pattern = tokens.next().ok_or(RuleError::MissingPattern)?;
        let rewrite = tokens.next().ok_or(RuleError::MissingRewrite)?;
        let flags = tokens
            .next()
            .map(RuleFlagList::from_str)
            .transpose()?
            .map_or_else(Vec::new, |list| list.0);

        if let Some(extra) = tokens.next() {
            return Err(RuleError::InvalidSuffix(extra.to_owned()));
        }

        let case_insensitive = flags
            .iter()
            .any(|flag| matches!(flag, RuleFlag::Mod(RuleMod::NoCase)));

        // 10 MiB NFA size limit and 2 MiB lazy-DFA cache.
        let engine = meta::Config::new()
            .nfa_size_limit(Some(10 * (1 << 20)))
            .hybrid_cache_capacity(2 * (1 << 20))
            .match_kind(MatchKind::LeftmostFirst)
            .utf8_empty(true);
        let syntax = util::syntax::Config::new().case_insensitive(case_insensitive);

        let pattern = Regex::builder()
            .configure(engine)
            .syntax(syntax)
            .build(pattern)
            .map_err(|err| RuleError::InvalidRegex(err.to_string()))?;

        Ok(Self {
            pattern,
            rewrite: rewrite.to_string(),
            flags,
        })
    }
}
/// Parsed contents of a rule's bracketed, comma-separated flag list.
struct RuleFlagList(Vec<RuleFlag>);

impl FromStr for RuleFlagList {
    type Err = RuleError;

    /// Parses text of the form `[flag,flag,...]`.
    ///
    /// Entries are trimmed and blank entries are skipped.
    ///
    /// # Errors
    ///
    /// * `FlagsMissingBrackets` when `s` is not wrapped in `[` and `]`.
    /// * Any error from parsing an individual flag.
    /// * `FlagsEmpty` when no flag remains after skipping blank entries.
    /// * `FlagsMutuallyExclusive` when more than one `Shift`/`Resolve` flag is present.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let inner = s
            .strip_prefix('[')
            .and_then(|rest| rest.strip_suffix(']'))
            .ok_or_else(|| RuleError::FlagsMissingBrackets(s.to_owned()))?;

        let mut flags = Vec::new();
        for entry in inner.split(',') {
            let entry = entry.trim();
            if entry.is_empty() {
                continue;
            }
            flags.push(RuleFlag::from_str(entry)?);
        }

        if flags.is_empty() {
            return Err(RuleError::FlagsEmpty);
        }

        // A rule may carry at most one flag from the Shift and Resolve groups combined.
        let exclusive = flags
            .iter()
            .filter(|flag| matches!(flag, RuleFlag::Shift(_) | RuleFlag::Resolve(_)))
            .count();
        if exclusive > 1 {
            return Err(RuleError::FlagsMutuallyExclusive);
        }

        Ok(Self(flags))
    }
}
#[inline]
fn parse_int(s: &str, default: u16) -> Result<u16, RuleError> {
match s.is_empty() {
true => Ok(default),
false => Ok(u16::from_str(s)?),
}
}
/// Parses `s` as a status code, using `default` when `s` is empty.
///
/// The resulting value must lie in `100..=599`; otherwise
/// `RuleError::InvalidFlagStatus` is returned carrying the original text.
#[inline]
fn parse_status(s: &str, default: u16) -> Result<u16, RuleError> {
    let code = parse_int(s, default)?;
    if (100..=599).contains(&code) {
        Ok(code)
    } else {
        Err(RuleError::InvalidFlagStatus(s.to_owned()))
    }
}
/// Flow-control flag of a rule, exposed through `Rule::shift`.
///
/// NOTE(review): the variant names mirror Apache mod_rewrite's END / L / N / S
/// flags; how each one steers rule evaluation is decided by the code that
/// consumes `Rule::shift` — confirm there.
#[derive(Clone, Debug)]
pub enum RuleShift {
/// Parsed from `e` or `end`.
End,
/// Parsed from `l` or `last`.
Last,
/// Parsed from `n` or `next`.
Next,
/// Parsed from `s` or `skip`, optionally followed by `=N`; the count defaults to 1.
Skip(u16),
}
/// Modifier flags that change how a rule matches or substitutes.
#[derive(Clone, Debug)]
pub enum RuleMod {
/// Compile the pattern case-insensitively. Parsed from `i`, `insensitive`, `nc` or `nocase`.
NoCase,
/// Insert captured text verbatim instead of percent-encoding it. Parsed from `ne` or `noescape`.
NoEscape,
}
/// Status-carrying flag of a rule, exposed through `Rule::resolve`.
///
/// Codes parsed from text are checked to lie in `100..600`.
/// NOTE(review): presumably these are HTTP status codes used to answer the
/// request directly — confirm against the consumer of `Rule::resolve`.
#[derive(Clone, Debug)]
pub enum RuleResolve {
/// Parsed from `r` or `redirect`, optionally followed by `=CODE`; the code defaults to 302.
Redirect(u16),
/// Parsed from `f`/`forbidden` (403), `g`/`gone` (410), or a bare `=CODE` entry (403 when the code is empty).
Status(u16),
}
/// One parsed entry of a rule's flag list.
#[derive(Clone, Debug)]
pub enum RuleFlag {
/// Flow-control flag; the first one on a rule is returned by `Rule::shift`.
Shift(RuleShift),
/// Matching or substitution modifier.
Mod(RuleMod),
/// Status-carrying flag; the first one on a rule is returned by `Rule::resolve`.
Resolve(RuleResolve),
}
impl FromStr for RuleFlag {
    type Err = RuleError;

    /// Parses one flag written as `name` or `name=value`.
    ///
    /// The name is matched case-insensitively. Recognised names:
    ///
    /// * `e`/`end`, `l`/`last`, `n`/`next` — flow-control flags.
    /// * `s`/`skip[=N]` — skip count, defaulting to 1.
    /// * `i`/`insensitive`/`nc`/`nocase`, `ne`/`noescape` — modifiers.
    /// * `r`/`redirect[=CODE]` — redirect status, defaulting to 302.
    /// * `f`/`forbidden` (403) and `g`/`gone` (410).
    /// * an empty name (`=CODE`) — explicit status, defaulting to 403.
    ///
    /// Names that take no value ignore any `=value` part.
    ///
    /// # Errors
    ///
    /// * `InvalidFlag` carrying the flag exactly as written when the name is unknown.
    /// * An integer parse error or `InvalidFlagStatus` when the value is malformed
    ///   or outside `100..600`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Keep `s` bound to the full input so the fallback arm can report it.
        let (name, value) = s.split_once('=').unwrap_or((s, ""));
        match name.to_lowercase().as_str() {
            "e" | "end" => Ok(Self::Shift(RuleShift::End)),
            "l" | "last" => Ok(Self::Shift(RuleShift::Last)),
            "n" | "next" => Ok(Self::Shift(RuleShift::Next)),
            "s" | "skip" => Ok(Self::Shift(RuleShift::Skip(parse_int(value, 1)?))),
            "i" | "insensitive" | "nc" | "nocase" => Ok(Self::Mod(RuleMod::NoCase)),
            "ne" | "noescape" => Ok(Self::Mod(RuleMod::NoEscape)),
            "r" | "redirect" => Ok(Self::Resolve(RuleResolve::Redirect(parse_status(
                value, 302,
            )?))),
            "f" | "forbidden" => Ok(Self::Resolve(RuleResolve::Status(403))),
            "g" | "gone" => Ok(Self::Resolve(RuleResolve::Status(410))),
            "" => Ok(Self::Resolve(RuleResolve::Status(parse_status(
                value, 403,
            )?))),
            // Report the whole flag: the value part alone (usually empty) would
            // not tell the user which flag was rejected.
            _ => Err(RuleError::InvalidFlag(s.to_owned())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A rule with a flag list parses into the expected template and flags.
    #[test]
    fn test_compile() {
        let rule = " ^/replace/[A-Z]+/$ - [I,F]".parse::<Rule>().unwrap();
        assert_eq!(rule.rewrite, "-");
        // The slice pattern checks both the length and each flag in order.
        assert!(matches!(
            rule.flags.as_slice(),
            [
                RuleFlag::Mod(RuleMod::NoCase),
                RuleFlag::Resolve(RuleResolve::Status(403)),
            ]
        ));
    }

    /// A numbered capture is substituted verbatim when `NE` is set.
    #[test]
    fn test_simple_replace() {
        let rule = r" ^/file/(.*)$ /new/$1 [NE]".parse::<Rule>().unwrap();
        assert!(rule.try_rewrite("/no/match").is_none());
        assert_eq!(
            rule.try_rewrite("/file/match").as_deref(),
            Some("/new/match")
        );
        assert_eq!(
            rule.try_rewrite("/file/multiple/match").as_deref(),
            Some("/new/multiple/match")
        );
    }

    /// Several numbered captures can be reordered in the template.
    #[test]
    fn test_multi_replace() {
        let rule = r" ^/file/(\w+)/break/(\w+)$ /new/$2/$1 "
            .parse::<Rule>()
            .unwrap();
        for miss in [
            "/file/partial/",
            "/file/partial/break/",
            "/file/partial/break/test ",
        ] {
            assert_eq!(rule.try_rewrite(miss), None);
        }
        assert_eq!(
            rule.try_rewrite("/file/one/break/two").as_deref(),
            Some("/new/two/one")
        );
    }

    /// Named captures are resolved by name in the template.
    #[test]
    fn test_named_replace() {
        let rule = r" ^/file/(?P<name>\w+)$ /$name ".parse::<Rule>().unwrap();
        assert!(rule.try_rewrite("/file/").is_none());
        assert_eq!(
            rule.try_rewrite("/file/named_file").as_deref(),
            Some("/named_file")
        );
    }
}