use std::fmt;
use std::mem;
use std::ops::RangeInclusive;
use bitflags::bitflags;
/// A parsed bracket expression such as `[a-z_]` or `[^[:digit:]]`.
///
/// A byte belongs to the class when it is listed explicitly, falls inside one
/// of the inclusive ranges, or satisfies one of the named-class predicates.
/// `negated` inverts that outcome.
#[derive(Clone, Debug, Default)]
struct CharacterClass {
    /// `true` when the final result of the membership test is inverted.
    negated: bool,
    /// Predicates contributed by named classes.
    named: Vec<fn(u8) -> bool>,
    /// Bytes that were listed one by one.
    listed: Vec<u8>,
    /// Inclusive byte ranges.
    ranges: Vec<RangeInclusive<u8>>,
}

impl CharacterClass {
    /// Returns whether `ch` is accepted by this class, honouring negation.
    pub fn matches(&self, ch: u8) -> bool {
        // XOR flips the raw membership result exactly when the class is negated.
        self.matches_do(ch) ^ self.negated
    }

    /// Raw membership test that ignores `negated`.
    fn matches_do(&self, ch: u8) -> bool {
        if self.listed.contains(&ch) {
            return true;
        }
        if self.ranges.iter().any(|span| span.contains(&ch)) {
            return true;
        }
        self.named.iter().any(|predicate| predicate(ch))
    }
}
/// One parsed element of a pattern.
#[derive(Clone, Debug)]
enum Component {
    /// A run of bytes that has to appear verbatim in the text.
    Literal(Vec<u8>),
    /// `?`: consumes exactly one byte.
    QuestionMark,
    /// `*`: consumes any number of bytes.
    Star,
    /// A run of two or more `*` that the parser recognised as a `**` wildcard.
    StarStar,
    /// A bracket expression.
    Class(CharacterClass),
}

impl Component {
    /// Returns `true` only for a literal whose last byte is `/`.
    fn ends_with_slash(&self) -> bool {
        if let Component::Literal(bytes) = self {
            bytes.ends_with(b"/")
        } else {
            false
        }
    }

    /// Returns `true` only for a literal that begins with `/`, or with a
    /// backslash immediately followed by `/`.
    fn starts_with_slash(&self) -> bool {
        if let Component::Literal(bytes) = self {
            bytes.starts_with(b"/") || bytes.starts_with(b"\\/")
        } else {
            false
        }
    }
}
bitflags! {
/// Options that change how a `Pattern` is parsed and matched.
pub struct PatternFlag: u8 {
/// Compare ASCII letters without regard to case.
const IGNORE_CASE = 0x01;
/// Path-name mode: `?` and a single `*` never match a `/` byte.
const PATH_NAME = 0x02;
}
}
/// Reasons a pattern can be rejected while it is being parsed.
#[derive(Clone, Debug)]
pub enum ParseError {
    /// The pattern contained no bytes at all.
    EmptyPattern,
    /// The pattern contained a NUL byte.
    NulByteError,
    /// The pattern ended right after a backslash.
    TrailingBackslash,
    /// A character class was opened but not closed; the payload is the byte
    /// offset of the opening `[`.
    UnclosedCharacterClass(usize),
    /// A `[:name:]` element was not recognised; the payload is the byte
    /// offset of the `[` that opened the enclosing character class.
    MalformedNamedCharacterClass(usize),
}

impl std::error::Error for ParseError {}

impl fmt::Display for ParseError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            ParseError::EmptyPattern => f.write_str("empty pattern"),
            ParseError::NulByteError => f.write_str("null-byte in pattern"),
            ParseError::TrailingBackslash => f.write_str("trailing backslash in pattern"),
            ParseError::UnclosedCharacterClass(offset) => write!(
                f,
                "unclosed character class in pattern, starting at byte {}",
                offset
            ),
            ParseError::MalformedNamedCharacterClass(offset) => write!(
                f,
                "malformed named character class in pattern, starting at byte {}",
                offset
            ),
        }
    }
}
/// Outcome of one matching attempt inside `Pattern::do_matches`.
enum MatchResult {
// The remaining components matched the remaining text.
Match,
// This attempt failed; a wildcard further up may retry from a later text position.
NoMatch,
// The text ran out while components were still pending. Wildcard loops
// return this unchanged instead of retrying.
AbortAll,
// Returned by `*` in path-name mode when the rest of the pattern fails at a
// `/`. The `*`/`**` loop outside path-name handling treats it like
// `NoMatch` and keeps scanning.
AbortToStarStar,
}
/// A compiled pattern: the source text plus its parsed components.
#[derive(Clone, Debug)]
pub struct Pattern {
// Source pattern with trailing slashes removed, kept as a C string.
pattern: std::ffi::CString,
// Parsed form of the pattern, in source order.
components: Vec<Component>,
// Flags the pattern was compiled with; consulted again while matching.
flags: PatternFlag,
}
impl Pattern {
/// Returns the stored source pattern as a C string.
pub fn pattern(&self) -> &std::ffi::CStr {
    self.pattern.as_c_str()
}
/// Compiles `pattern` with the given `flags`.
///
/// # Errors
///
/// Returns a [`ParseError`] when the pattern cannot be parsed.
pub fn new<T: AsRef<[u8]>>(pattern: T, flags: PatternFlag) -> Result<Self, ParseError> {
    let bytes: &[u8] = pattern.as_ref();
    Self::new_do(bytes, flags)
}
/// Compiles `pattern` in path-name mode (`PatternFlag::PATH_NAME` only).
///
/// # Errors
///
/// Returns a [`ParseError`] when the pattern cannot be parsed.
pub fn path<T: AsRef<[u8]>>(pattern: T) -> Result<Self, ParseError> {
    Self::new(pattern, PatternFlag::PATH_NAME)
}
/// Parses `pattern` into components and builds the `Pattern`.
///
/// Parsing rules, as implemented below:
/// * trailing `/` bytes are dropped first; a pattern made only of slashes
///   becomes `/`;
/// * `\x` adds `x` to the current literal;
/// * `?` and `[` start a question-mark / character-class component;
/// * a single `*` is a `Star`; a run of two or more `*` is a `StarStar` only
///   when it is delimited by `/` (or the pattern edge) on both sides,
///   otherwise it is a plain `Star`;
/// * with `IGNORE_CASE`, literal bytes are stored lowercased.
///
/// # Errors
///
/// `EmptyPattern`, `NulByteError`, `TrailingBackslash`, or whatever
/// `parse_char_class` reports.
fn new_do(pattern: &[u8], flags: PatternFlag) -> Result<Self, ParseError> {
    /// Moves the pending literal bytes, if any, into `out` as one literal
    /// component, lowercasing them first when case is ignored.
    fn flush_literal(pending: &mut Vec<u8>, out: &mut Vec<Component>, ignore_case: bool) {
        if pending.is_empty() {
            return;
        }
        if ignore_case {
            for byte in pending.iter_mut() {
                *byte = byte.to_ascii_lowercase();
            }
        }
        out.push(Component::Literal(mem::take(pending)));
    }

    if pattern.is_empty() {
        return Err(ParseError::EmptyPattern);
    }

    // Cut the pattern after its last non-slash byte.
    let pattern: &[u8] = pattern
        .iter()
        .rposition(|&byte| byte != b'/')
        .map_or(&b"/"[..], |last| &pattern[..=last]);

    let c_pattern = std::ffi::CString::new(pattern).map_err(|_| ParseError::NulByteError)?;

    let ignore_case = flags.intersects(PatternFlag::IGNORE_CASE);
    let mut components: Vec<Component> = Vec::new();
    let mut literal: Vec<u8> = Vec::new();

    let mut i = 0;
    while i != pattern.len() {
        match pattern[i] {
            0 => return Err(ParseError::NulByteError),
            b'\\' => {
                // The byte after the backslash is taken literally.
                i += 1;
                let escaped = *pattern.get(i).ok_or(ParseError::TrailingBackslash)?;
                literal.push(if ignore_case {
                    escaped.to_ascii_lowercase()
                } else {
                    escaped
                });
            }
            b'?' => {
                flush_literal(&mut literal, &mut components, ignore_case);
                components.push(Component::QuestionMark);
            }
            b'*' => {
                flush_literal(&mut literal, &mut components, ignore_case);
                // Swallow the whole run of stars; `i` ends on the last one.
                let run_start = i;
                while pattern.get(i + 1) == Some(&b'*') {
                    i += 1;
                }
                let is_run = i > run_start;
                let opens_segment = run_start == 0 || pattern[run_start - 1] == b'/';
                let closes_segment = i + 1 == pattern.len() || pattern[i + 1] == b'/';
                components.push(if is_run && opens_segment && closes_segment {
                    Component::StarStar
                } else {
                    Component::Star
                });
            }
            b'[' => {
                flush_literal(&mut literal, &mut components, ignore_case);
                let (class, resume_at) = Self::parse_char_class(pattern, i, flags)?;
                components.push(class);
                i = resume_at;
            }
            plain => literal.push(if ignore_case {
                plain.to_ascii_lowercase()
            } else {
                plain
            }),
        }
        i += 1;
    }
    flush_literal(&mut literal, &mut components, ignore_case);

    Ok(Self {
        pattern: c_pattern,
        components,
        flags,
    })
}
/// Parses the bracket expression whose opening `[` is at `pattern[begin_i]`.
///
/// Supported syntax inside the brackets:
/// * a leading `^` negates the class;
/// * `[:name:]` adds a named ASCII class (`alnum`, `alpha`, `blank`, `cntrl`,
///   `digit`, `graph`, `lower`, `print`, `punct`, `space`, `upper`, `xdigit`);
/// * `\x` adds `x` literally;
/// * `a-b` adds an inclusive range (the bounds are swapped if reversed);
///   a `-` with no single byte before it is a literal `-`;
/// * any other byte is added as itself.
///
/// With `IGNORE_CASE`, listed bytes and range bounds are stored lowercased,
/// because the matcher lowercases the candidate byte before testing it.
///
/// Returns the class component together with the index of the closing `]`.
///
/// # Errors
///
/// * `UnclosedCharacterClass(begin_i)` when the pattern ends before `]`;
/// * `MalformedNamedCharacterClass(begin_i)` for an unknown `[:name:]`;
/// * `TrailingBackslash` when the pattern ends right after a backslash;
/// * `NulByteError` when a NUL byte is found.
fn parse_char_class(
    pattern: &[u8],
    begin_i: usize,
    flags: PatternFlag,
) -> Result<(Component, usize), ParseError> {
    // Named classes, spelled the way they appear right after the inner `[`.
    const NAMED_CLASSES: [(&str, fn(u8) -> bool); 12] = [
        (":alnum:]", |b: u8| b.is_ascii_alphanumeric()),
        (":alpha:]", |b: u8| b.is_ascii_alphabetic()),
        (":blank:]", |b: u8| b == b' ' || b == b'\t'),
        (":cntrl:]", |b: u8| b.is_ascii_control()),
        (":digit:]", |b: u8| b.is_ascii_digit()),
        (":graph:]", |b: u8| b.is_ascii_graphic()),
        (":lower:]", |b: u8| b.is_ascii_lowercase()),
        // Printable means space plus the graphic characters; DEL (0x7f) is a
        // control character and must not be included.
        (":print:]", |b: u8| b == b' ' || b.is_ascii_graphic()),
        (":punct:]", |b: u8| b.is_ascii_punctuation()),
        // `is_ascii_whitespace` leaves out vertical tab, which the POSIX
        // `space` class contains.
        (":space:]", |b: u8| b.is_ascii_whitespace() || b == 0x0b),
        (":upper:]", |b: u8| b.is_ascii_uppercase()),
        (":xdigit:]", |b: u8| b.is_ascii_hexdigit()),
    ];

    let ignore_case = flags.intersects(PatternFlag::IGNORE_CASE);
    let fold = |b: u8| if ignore_case { b.to_ascii_lowercase() } else { b };

    let mut i = begin_i + 1;
    let negated = pattern.get(i) == Some(&b'^');
    if negated {
        i += 1;
    }

    let mut class = CharacterClass {
        negated,
        ..Default::default()
    };
    // The single byte listed by the previous step, if any: it becomes the
    // lower bound when the next byte is `-`.
    let mut prev: Option<u8> = None;

    loop {
        // Reaching the end of the pattern here means no closing `]` was
        // found. Reporting it keeps the caller from resuming past the end of
        // the pattern.
        let byte = *pattern
            .get(i)
            .ok_or(ParseError::UnclosedCharacterClass(begin_i))?;
        let mut new_prev = None;
        match byte {
            0 => return Err(ParseError::NulByteError),
            b']' => break,
            b'[' if pattern.get(i + 1) == Some(&b':') => {
                let rest = &pattern[i + 1..];
                let (name, predicate) = NAMED_CLASSES
                    .iter()
                    .copied()
                    .find(|(name, _)| rest.starts_with(name.as_bytes()))
                    .ok_or(ParseError::MalformedNamedCharacterClass(begin_i))?;
                // Candidate bytes are lowercased before the test when case is
                // ignored, so `upper` has to accept lowercase letters then.
                let predicate: fn(u8) -> bool = if ignore_case && name == ":upper:]" {
                    |b| b.is_ascii_lowercase()
                } else {
                    predicate
                };
                class.named.push(predicate);
                // Land on the `]` that ends the name; the shared increment
                // at the bottom of the loop steps past it.
                i += name.len();
            }
            b'\\' => {
                i += 1;
                // Escaped bytes are folded like any other listed byte,
                // otherwise `[\A]` could never match under IGNORE_CASE.
                let ch = fold(*pattern.get(i).ok_or(ParseError::TrailingBackslash)?);
                class.listed.push(ch);
                new_prev = Some(ch);
            }
            b'-' => match prev {
                None => {
                    class.listed.push(b'-');
                    new_prev = Some(b'-');
                }
                Some(beg) => {
                    // The lower bound was pushed as a listed byte; it is now
                    // covered by the range instead.
                    class.listed.pop();
                    i += 1;
                    let mut end = *pattern
                        .get(i)
                        .ok_or(ParseError::UnclosedCharacterClass(begin_i))?;
                    if end == b'\\' {
                        i += 1;
                        end = *pattern
                            .get(i)
                            .ok_or(ParseError::UnclosedCharacterClass(begin_i))?;
                    }
                    let end = fold(end);
                    class
                        .ranges
                        .push(if beg <= end { beg..=end } else { end..=beg });
                }
            },
            other => {
                let ch = fold(other);
                class.listed.push(ch);
                new_prev = Some(ch);
            }
        }
        prev = new_prev;
        i += 1;
    }

    Ok((Component::Class(class), i))
}
/// Returns `true` when the whole of `text` matches this pattern.
pub fn matches<T: AsRef<[u8]>>(&self, text: T) -> bool {
    let outcome = self.do_matches(0, text.as_ref(), false);
    matches!(outcome, MatchResult::Match)
}
/// Matches the components from index `ci` onwards against `text`.
///
/// `skip_slash_in_literal` applies to the first component visited only: if
/// that component is a literal, its first byte is ignored. The `**` handling
/// uses this to try "zero directories", where the `/` that follows `**` in
/// the pattern has no counterpart in the text.
///
/// Wildcards recurse on the rest of the pattern for each candidate split
/// point; see `MatchResult` for how the individual outcomes steer that search.
fn do_matches(&self, ci: usize, mut text: &[u8], skip_slash_in_literal: bool) -> MatchResult {
    let components = self.components.as_slice();
    let path_mode = self.flags.intersects(PatternFlag::PATH_NAME);
    let fold_case = self.flags.intersects(PatternFlag::IGNORE_CASE);

    for idx in ci..components.len() {
        let next = idx + 1;
        let is_last = next == components.len();
        // Only the component this call started on may have its leading byte dropped.
        let drop_leading_byte = skip_slash_in_literal && idx == ci;

        match &components[idx] {
            Component::Literal(bytes) => {
                if text.is_empty() {
                    return MatchResult::AbortAll;
                }
                let wanted: &[u8] = if drop_leading_byte {
                    &bytes[1..]
                } else {
                    bytes
                };
                if !starts_with(text, wanted, self.flags) {
                    return MatchResult::NoMatch;
                }
                text = &text[wanted.len()..];
            }
            Component::QuestionMark => match text.split_first() {
                None => return MatchResult::AbortAll,
                Some((&b'/', _)) if path_mode => return MatchResult::NoMatch,
                Some((_, rest)) => text = rest,
            },
            Component::Class(class) => {
                let (first, rest) = match text.split_first() {
                    Some((&first, rest)) => (first, rest),
                    None => return MatchResult::AbortAll,
                };
                let probe = if fold_case {
                    first.to_ascii_lowercase()
                } else {
                    first
                };
                if !class.matches(probe) {
                    return MatchResult::NoMatch;
                }
                text = rest;
            }
            Component::Star if path_mode => {
                // A trailing `*` takes everything, provided no `/` is left.
                if is_last && !text.contains(&b'/') {
                    return MatchResult::Match;
                }
                // Try the rest of the pattern at every position up to the
                // first `/` that the rest cannot start on.
                while let Some((&head, rest)) = text.split_first() {
                    match self.do_matches(next, text, false) {
                        MatchResult::NoMatch if head == b'/' => {
                            return MatchResult::AbortToStarStar;
                        }
                        MatchResult::NoMatch => text = rest,
                        other => return other,
                    }
                }
                return MatchResult::AbortAll;
            }
            Component::Star | Component::StarStar => {
                if is_last {
                    return MatchResult::Match;
                }
                if let Component::StarStar = components[idx] {
                    let after_slash = idx == 0 || components[idx - 1].ends_with_slash();
                    if after_slash && components[next].starts_with_slash() {
                        // Let `**/` stand for nothing at all: retry with the
                        // following literal's leading byte ignored.
                        if let MatchResult::Match = self.do_matches(next, text, true) {
                            return MatchResult::Match;
                        }
                    }
                }
                while !text.is_empty() {
                    match self.do_matches(next, text, false) {
                        MatchResult::NoMatch | MatchResult::AbortToStarStar => text = &text[1..],
                        other => return other,
                    }
                }
                return MatchResult::AbortAll;
            }
        }
    }

    // Every component was consumed; the text has to be used up as well.
    if text.is_empty() {
        MatchResult::Match
    } else {
        MatchResult::NoMatch
    }
}
}
/// Returns whether `text` begins with `with`, comparing ASCII letters
/// case-insensitively when `flags` contains `IGNORE_CASE`.
fn starts_with(text: &[u8], with: &[u8], flags: PatternFlag) -> bool {
    if !flags.intersects(PatternFlag::IGNORE_CASE) {
        return text.starts_with(with);
    }
    starts_with_caseless(text, with)
}
/// Returns whether `text` begins with `with` when ASCII letter case is ignored.
///
/// An empty `with` is a prefix of every text.
fn starts_with_caseless(text: &[u8], with: &[u8]) -> bool {
    // `get` yields `None` when `text` is shorter than `with`.
    match text.get(..with.len()) {
        Some(head) => head.eq_ignore_ascii_case(with),
        None => false,
    }
}
// End-to-end checks of pattern compilation and matching, one table per pattern.
#[test]
fn test() {
    /// Asserts that `pattern` accepts every text in `accepted` and rejects
    /// every text in `rejected`.
    fn check(pattern: &Pattern, accepted: &[&str], rejected: &[&str]) {
        for text in accepted {
            assert!(pattern.matches(*text));
        }
        for text in rejected {
            assert!(!pattern.matches(*text));
        }
    }

    // Single star, path-name mode: the star stays within one path segment.
    let pattern = Pattern::new("/hey/*/you", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &["/hey/asdf/you", "/hey//you"],
        &[
            "/hey/asdf/more/you",
            "/heyasdf/you",
            "/heyasdfyou",
            "/hey/asdfyou",
            "/hey/you",
        ],
    );

    // Single star without path-name mode: the star may cross slashes.
    let pattern = Pattern::new("/hey/*/you", PatternFlag::empty()).unwrap();
    check(
        &pattern,
        &["/hey/asdf/you", "/hey/asdf/more/you", "/hey//you"],
        &["/heyasdf/you", "/heyasdfyou", "/hey/asdfyou", "/hey/you"],
    );

    // Double star spans any number of segments, including none.
    let pattern = Pattern::new("/hey/**/you", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &[
            "/hey/asdf/you",
            "/hey/asdf/more/you",
            "/hey/you",
            "/hey//you",
        ],
        &["/heyasdf/you", "/heyasdfyou", "/hey/asdfyou"],
    );

    // Character class in front of a double star.
    let pattern = Pattern::new("/he[yx]/**/you", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &[
            "/hey/asdf/you",
            "/hey/asdf/more/you",
            "/hey/you",
            "/hey//you",
            "/hex/asdf/you",
            "/hex/asdf/more/you",
            "/hex/you",
            "/hex//you",
        ],
        &[
            "/heyasdf/you",
            "/heyasdfyou",
            "/hey/asdfyou",
            "/hexasdf/you",
            "/hexasdfyou",
            "/hex/asdfyou",
            "/hez/asdf/you",
            "/hez/asdf/more/you",
            "/hezasdf/you",
            "/hezasdfyou",
            "/hez/asdfyou",
            "/hez/you",
            "/hez//you",
        ],
    );

    // Negated character class.
    let pattern = Pattern::new("/he[^yx]/**/you", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &[
            "/hez/asdf/you",
            "/hez/asdf/more/you",
            "/hez/you",
            "/hez//you",
        ],
        &[
            "/hey/asdf/you",
            "/hey/asdf/more/you",
            "/heyasdf/you",
            "/heyasdfyou",
            "/hey/asdfyou",
            "/hey/you",
            "/hey//you",
            "/hex/asdf/you",
            "/hex/asdf/more/you",
            "/hexasdf/you",
            "/hexasdfyou",
            "/hex/asdfyou",
            "/hex/you",
            "/hex//you",
            "/hezasdf/you",
            "/hezasdfyou",
            "/hez/asdfyou",
        ],
    );
    // No proper prefix of "/hez/" may match either.
    let truncated = b"/hez/";
    for end in 0..truncated.len() {
        assert!(!pattern.matches(&truncated[..end]));
    }

    // Range at the very end of the pattern.
    let pattern = Pattern::new("/tes[a-t]", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &["/tesa", "/test"],
        &["/testoolong", "/tes", "/t", "/", "", "/tesu"],
    );

    // Question mark with and without path-name mode.
    let pattern_path = Pattern::new("/tes[a-t]/a?a", PatternFlag::PATH_NAME).unwrap();
    let pattern_nopath = Pattern::new("/tes[a-t]/a?a", PatternFlag::empty()).unwrap();
    let accepted_by_both = ["/test/aba", "/test/aaa", "/test/aba"];
    let rejected_by_both = [
        "/tesu",
        "/tesu/aaa",
        "/tesu/xax",
        "/test/xax",
        "/test/a",
        "/test/ab",
    ];
    check(&pattern_path, &accepted_by_both, &rejected_by_both);
    check(&pattern_nopath, &accepted_by_both, &rejected_by_both);
    // Only path-name mode stops `?` from matching a slash.
    check(&pattern_path, &[], &["/test/a/a"]);
    check(&pattern_nopath, &["/test/a/a"], &[]);

    // Several stars, anchored at the end.
    let pattern = Pattern::new("a*b*c", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &["axxbxxc", "axxbxxbxxc", "axxbxxbxxcxxc"],
        &["axxbxxcxx", "axxbxxbxxcxx", "axxbxxbxxcxxcxx"],
    );

    // Several stars with a trailing star.
    let pattern = Pattern::new("a*b*c*", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &[
            "axxbxxc",
            "axxbxxcxx",
            "axxbxxbxxc",
            "axxbxxbxxcxx",
            "axxbxxbxxcxxc",
            "axxbxxbxxcxxcxx",
        ],
        &[],
    );

    // Case-insensitive literals, ranges and named classes.
    let pattern = Pattern::new(
        "aB[c-fX-Z][[:upper:]][[:lower:]][[:digit:]k]",
        PatternFlag::PATH_NAME | PatternFlag::IGNORE_CASE,
    )
    .unwrap();
    check(
        &pattern,
        &["aBcUl3", "AbCuL9", "abculk", "abxulk"],
        &["aBgUl3", "aBgUl3", "aBcUlx", "abxul"],
    );

    // A star run that is not slash-delimited on both sides acts as one star.
    let pattern = Pattern::new("a/b**/c", PatternFlag::PATH_NAME).unwrap();
    check(&pattern, &["a/bxx/c"], &["a/bxx/yy/c"]);

    // Leading double star.
    let pattern = Pattern::new("**/lost+found", PatternFlag::PATH_NAME).unwrap();
    check(
        &pattern,
        &[
            "/foo/lost+found",
            "foo/lost+found",
            "/lost+found",
            "///lost+found",
            "lost+found",
        ],
        &[
            "lost+found2",
            "lost+found/",
            "xlost+found",
            "xlost+found/",
        ],
    );
}