use std::error::Error;
use std::fmt::{self, Display};
/// Case-sensitive wildcard matching in the style of SQL `LIKE`.
///
/// The implementations in this file interpret `%` as "any run of characters
/// (possibly empty)", `_` as "exactly one character" and `\` as an escape that
/// makes the following pattern byte literal.
pub trait Like {
    /// Error reported when the pattern cannot be interpreted.
    type Err;

    /// Returns `Ok(true)` when the whole of `self` matches `pattern`.
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err>;

    /// Logical negation of [`Like::like`]; errors are passed through untouched.
    #[inline]
    fn not_like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let matched = self.like(pattern)?;
        Ok(!matched)
    }
}
/// Case-insensitive counterpart of `LIKE`-style wildcard matching.
///
/// The implementations in this file fold ASCII letters to lower case on both
/// sides before comparing; the wildcard syntax is the same (`%`, `_`, `\`).
pub trait ILike {
    /// Error reported when the pattern cannot be interpreted.
    type Err;

    /// Returns `Ok(true)` when the whole of `self` matches `pattern`,
    /// ignoring case.
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err>;

    /// Logical negation of [`ILike::ilike`]; errors are passed through untouched.
    #[inline]
    fn not_ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let matched = self.ilike(pattern)?;
        Ok(!matched)
    }
}
/// Rewrites a pattern that was written with a custom escape character into
/// one that uses backslash as its escape character.
///
/// In the implementations in this file: an empty `esc` means "no escape
/// character", so every backslash in the pattern is doubled; an `esc` of `\`
/// returns the pattern unchanged; any other single character is replaced by
/// `\` wherever it acts as an escape, and literal backslashes are doubled.
pub trait Escape {
/// Error reported when `esc` is not a valid escape specification.
type Err;
/// Owned type holding the rewritten pattern.
type Output;
/// Returns `self` rewritten so that `\` plays the role of `esc`.
fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err>;
}
/// Forward-only cursor over the unconsumed tail of an input or a pattern.
///
/// "Raw" accessors return bytes exactly as stored. The non-raw accessors are
/// the ones used for comparisons and may normalise the byte first (the
/// case-insensitive traversers fold ASCII to lower case).
trait Traverser {
    /// Number of bytes that have not been consumed yet.
    fn len(&self) -> usize;

    /// Consumes exactly one byte. The caller must ensure `len() > 0`.
    fn advance_byte(&mut self);

    /// Consumes one logical character.
    ///
    /// Defaults to a single byte; the text traversers override this to consume
    /// a whole UTF-8 sequence.
    #[inline]
    fn advance_char(&mut self) {
        self.advance_byte()
    }

    /// Returns the unmodified byte `index` positions past the cursor.
    /// The caller must ensure `index < len()`.
    fn raw_byte_at(&self, index: usize) -> u8;

    /// Returns the unmodified byte under the cursor.
    #[inline]
    fn next_raw_byte(&self) -> u8 {
        self.raw_byte_at(0)
    }

    /// Returns the comparison form of the byte `index` positions past the cursor.
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        self.raw_byte_at(index)
    }

    /// Returns the comparison form of the byte under the cursor.
    #[inline]
    fn next_byte(&self) -> u8 {
        self.byte_at(0)
    }

    /// Returns the character under the cursor.
    ///
    /// Defaults to widening the next byte; the text traversers override this
    /// to decode a full UTF-8 character.
    #[inline]
    fn next_raw_char(&self) -> char {
        self.next_raw_byte() as char
    }
}

/// Outcome of one matching attempt.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
enum Matched {
    /// The whole input matched the whole pattern.
    True,
    /// Mismatch at this starting position; a later starting position might
    /// still match.
    False,
    /// The input ran out before the pattern was satisfied, so no later
    /// starting position can match either and the search stops.
    Abort,
}

/// Error returned when a pattern ends with a dangling escape character.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidPatternError;

impl Display for InvalidPatternError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid pattern")
    }
}

impl Error for InvalidPatternError {}

/// Matches `input` against `pattern` using `%`, `_` and `\` as metacharacters.
///
/// * `%` matches any run of characters, including the empty run.
/// * `_` matches exactly one character (as defined by `advance_char`).
/// * `\` makes the following pattern byte literal.
///
/// Both cursors are consumed as matching proceeds. Returns
/// `Err(InvalidPatternError)` when a `\` is the last byte of the pattern at
/// the point where it is examined.
///
/// Candidate positions after a `%` are tried on *copies* of the cursors, so a
/// failed candidate does not disturb the state used for the next candidate.
fn like<T: Traverser + Clone>(
    input: &mut T,
    pattern: &mut T,
) -> Result<Matched, InvalidPatternError> {
    // Fast path: a lone `%` matches everything, including the empty input.
    if pattern.len() == 1 && pattern.next_raw_byte() == b'%' {
        return Ok(Matched::True);
    }

    while input.len() > 0 && pattern.len() > 0 {
        if pattern.next_raw_byte() == b'\\' {
            // Escaped byte: skip the backslash and compare the next byte literally.
            pattern.advance_byte();
            if pattern.len() == 0 {
                return Err(InvalidPatternError);
            }
            if input.next_byte() != pattern.next_byte() {
                return Ok(Matched::False);
            }
        } else if pattern.next_raw_byte() == b'%' {
            pattern.advance_byte();

            // Collapse any following `%` and consume one input character for
            // every `_` that directly follows the wildcard.
            while pattern.len() > 0 {
                let pattern_raw_byte = pattern.next_raw_byte();
                if pattern_raw_byte == b'%' {
                    pattern.advance_byte();
                } else if pattern_raw_byte == b'_' {
                    if input.len() == 0 {
                        return Ok(Matched::Abort);
                    }
                    input.advance_char();
                    pattern.advance_byte();
                } else {
                    break;
                }
            }

            // A trailing `%` swallows whatever input is left.
            if pattern.len() == 0 {
                return Ok(Matched::True);
            }

            // First literal byte after the wildcard; used to skip input
            // positions that cannot possibly start a match.
            let first_pat = if pattern.next_raw_byte() == b'\\' {
                if pattern.len() < 2 {
                    return Err(InvalidPatternError);
                }
                pattern.byte_at(1)
            } else {
                pattern.next_byte()
            };

            while input.len() > 0 {
                if input.next_byte() == first_pat {
                    // Try this candidate on copies: the recursive call
                    // consumes its cursors, and a failed attempt must leave
                    // ours untouched so the next candidate starts correctly.
                    let matched = like(&mut input.clone(), &mut pattern.clone())?;
                    if matched != Matched::False {
                        return Ok(matched);
                    }
                }
                input.advance_char();
            }

            // Input exhausted without satisfying the rest of the pattern.
            return Ok(Matched::Abort);
        } else if pattern.next_raw_byte() == b'_' {
            input.advance_char();
            pattern.advance_byte();
            continue;
        } else if pattern.next_byte() != input.next_byte() {
            return Ok(Matched::False);
        }
        input.advance_byte();
        pattern.advance_byte();
    }

    // Pattern ended but input did not.
    if input.len() > 0 {
        return Ok(Matched::False);
    }

    // Input ended: only trailing `%` may remain in the pattern.
    while pattern.len() > 0 && pattern.next_raw_byte() == b'%' {
        pattern.advance_byte();
    }
    if pattern.len() == 0 {
        return Ok(Matched::True);
    }
    Ok(Matched::Abort)
}

/// Borrowed byte window shared by all traversers; the slice always starts at
/// the cursor.
#[derive(Clone, Copy)]
struct Bytes<'a> {
    bytes: &'a [u8],
}

impl<'a> Bytes<'a> {
    /// Wraps the UTF-8 bytes of `s`.
    #[inline]
    const fn from_str(s: &'a str) -> Self {
        Self {
            bytes: s.as_bytes(),
        }
    }

    /// Wraps an arbitrary byte slice.
    #[inline]
    const fn from_bytes(bytes: &'a [u8]) -> Self {
        Self { bytes }
    }

    /// Number of bytes left.
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    /// Drops the first byte. Panics if no bytes are left.
    #[inline]
    fn advance_byte(&mut self) {
        self.bytes = &self.bytes[1..];
    }

    /// Drops the first byte and every UTF-8 continuation byte (`10xxxxxx`)
    /// that follows it.
    #[inline]
    fn advance_char(&mut self) {
        self.advance_byte();
        while !self.bytes.is_empty() && (self.raw_byte_at(0) & 0xC0) == 0x80 {
            self.advance_byte();
        }
    }

    /// Returns the byte at `index`. Panics if `index` is out of range.
    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes[index]
    }

    /// Decodes the first character of the remaining bytes.
    ///
    /// Must only be called on data that originated from a `&str`, with the
    /// cursor on a character boundary and at least one byte left.
    #[inline]
    fn next_raw_char(&self) -> char {
        // SAFETY: callers (the `str`-backed traversers) hand us bytes taken
        // from a `&str` and keep the cursor on a character boundary whenever
        // this is called, so the remaining bytes are valid UTF-8.
        let str = unsafe { std::str::from_utf8_unchecked(self.bytes) };
        str.chars()
            .next()
            .expect("next_raw_char called with no bytes left")
    }

    /// Copies the remaining bytes into a `String`.
    ///
    /// Must only be called on data that originated from a `&str`, with the
    /// cursor on a character boundary.
    #[inline]
    fn to_str(&self) -> String {
        // SAFETY: same contract as `next_raw_char` — the remaining bytes are
        // the tail of a `&str` starting on a character boundary.
        let str = unsafe { std::str::from_utf8_unchecked(self.bytes) };
        str.to_string()
    }

    /// Copies the remaining bytes into a `Vec<u8>`.
    #[inline]
    fn to_vec(&self) -> Vec<u8> {
        self.bytes.into()
    }
}

/// Case-sensitive traverser over UTF-8 text.
#[derive(Clone, Copy)]
struct StrTraverser<'a> {
    bytes: Bytes<'a>,
}

impl<'a> StrTraverser<'a> {
    /// Creates a traverser positioned at the start of `s`.
    #[inline]
    const fn new(s: &'a str) -> Self {
        Self {
            bytes: Bytes::from_str(s),
        }
    }
}

impl<'a> Traverser for StrTraverser<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.bytes.advance_byte();
    }

    /// Consumes one whole UTF-8 character.
    #[inline]
    fn advance_char(&mut self) {
        self.bytes.advance_char()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes.raw_byte_at(index)
    }

    /// Decodes the UTF-8 character under the cursor.
    #[inline]
    fn next_raw_char(&self) -> char {
        self.bytes.next_raw_char()
    }
}

/// Case-sensitive traverser over raw bytes; every byte is one character.
#[derive(Clone, Copy)]
struct BytesTraverser<'a> {
    bytes: Bytes<'a>,
}

impl<'a> BytesTraverser<'a> {
    /// Creates a traverser positioned at the start of `bytes`.
    #[inline]
    const fn new(bytes: &'a [u8]) -> Self {
        Self {
            bytes: Bytes::from_bytes(bytes),
        }
    }
}

impl<'a> Traverser for BytesTraverser<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.bytes.advance_byte();
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes.raw_byte_at(index)
    }
}

/// ASCII-case-insensitive traverser over UTF-8 text.
#[derive(Clone, Copy)]
struct CiStrTraverser<'a> {
    bytes: Bytes<'a>,
}

impl<'a> CiStrTraverser<'a> {
    /// Creates a traverser positioned at the start of `s`.
    #[inline]
    const fn new(s: &'a str) -> Self {
        Self {
            bytes: Bytes::from_str(s),
        }
    }
}

impl<'a> Traverser for CiStrTraverser<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.bytes.advance_byte();
    }

    /// Consumes one whole UTF-8 character.
    #[inline]
    fn advance_char(&mut self) {
        self.bytes.advance_char()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes.raw_byte_at(index)
    }

    /// Folds ASCII upper-case letters to lower case; other bytes are unchanged.
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        self.raw_byte_at(index).to_ascii_lowercase()
    }

    /// Decodes the UTF-8 character under the cursor (no case folding).
    #[inline]
    fn next_raw_char(&self) -> char {
        self.bytes.next_raw_char()
    }
}

/// ASCII-case-insensitive traverser over raw bytes.
#[derive(Clone, Copy)]
struct CiBytesTraverser<'a> {
    bytes: Bytes<'a>,
}

impl<'a> CiBytesTraverser<'a> {
    /// Creates a traverser positioned at the start of `bytes`.
    #[inline]
    const fn new(bytes: &'a [u8]) -> Self {
        Self {
            bytes: Bytes::from_bytes(bytes),
        }
    }
}

impl<'a> Traverser for CiBytesTraverser<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.bytes.advance_byte();
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes.raw_byte_at(index)
    }

    /// Folds ASCII upper-case letters to lower case; other bytes are unchanged.
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        self.raw_byte_at(index).to_ascii_lowercase()
    }
}
impl Like for str {
    type Err = InvalidPatternError;

    /// Case-sensitive match of this string against `pattern`.
    ///
    /// Only a full match counts; every other matcher outcome maps to `false`.
    #[inline]
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = StrTraverser::new(self);
        let mut pat = StrTraverser::new(pattern);
        like(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
impl Like for [u8] {
    type Err = InvalidPatternError;

    /// Case-sensitive match of this byte slice against `pattern`, treating
    /// every byte as one character.
    ///
    /// Only a full match counts; every other matcher outcome maps to `false`.
    #[inline]
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut data = BytesTraverser::new(self);
        let mut pat = BytesTraverser::new(pattern);
        like(&mut data, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
impl ILike for str {
    type Err = InvalidPatternError;

    /// ASCII-case-insensitive match of this string against `pattern`.
    ///
    /// Only a full match counts; every other matcher outcome maps to `false`.
    #[inline]
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = CiStrTraverser::new(self);
        let mut pat = CiStrTraverser::new(pattern);
        like(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
impl ILike for [u8] {
    type Err = InvalidPatternError;

    /// ASCII-case-insensitive match of this byte slice against `pattern`,
    /// treating every byte as one character.
    ///
    /// Only a full match counts; every other matcher outcome maps to `false`.
    #[inline]
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut data = CiBytesTraverser::new(self);
        let mut pat = CiBytesTraverser::new(pattern);
        like(&mut data, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
/// Growable output buffer that the escape rewriter appends characters to.
trait Owned {
/// Creates an empty buffer; `size` is passed on as a capacity hint.
fn new(size: usize) -> Self;
/// Appends one character to the end of the buffer.
fn append(&mut self, ch: char);
}
/// Conversion from a traverser to an owned copy of the data it still holds.
///
/// NOTE(review): this file-local trait shares its name with the standard
/// library's `ToOwned`; within this module the name refers to this trait.
trait ToOwned {
/// Owned buffer type produced by `to_owned`.
type Owned: Owned;
/// Returns an owned copy of the traverser's remaining data.
fn to_owned(&self) -> Self::Owned;
}
/// Error returned when the escape specification is longer than one character.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidEscapeError;

impl Error for InvalidEscapeError {}

impl Display for InvalidEscapeError {
    /// Writes the fixed message `invalid escape`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid escape")
    }
}
/// Rewrites `pat`, written with escape character `esc`, into a pattern whose
/// escape character is `\`.
///
/// * Empty `esc`: there is no escape character, so each backslash in `pat` is
///   doubled and everything else is copied.
/// * `esc` longer than one character: `Err(InvalidEscapeError)`.
/// * `esc` is `\`: `pat` is returned as-is.
/// * Otherwise: an unescaped occurrence of `esc` becomes `\`; a backslash is
///   doubled unless it directly follows the escape character; every other
///   character is copied.
fn escape<T, R>(pat: &mut T, esc: &mut T) -> Result<R, InvalidEscapeError>
where
    T: Traverser + ToOwned<Owned = R>,
    R: Owned,
{
    // Worst case every character needs a backslash in front of it.
    let mut escaped = R::new(pat.len() * 2);

    if esc.len() == 0 {
        while pat.len() > 0 {
            if pat.next_raw_byte() == b'\\' {
                escaped.append('\\');
            }
            escaped.append(pat.next_raw_char());
            pat.advance_char();
        }
        return Ok(escaped);
    }

    let escape_char = esc.next_raw_char();
    esc.advance_char();
    if esc.len() != 0 {
        return Err(InvalidEscapeError);
    }

    // Backslash is already the escape character: nothing to rewrite.
    if escape_char == '\\' {
        return Ok(pat.to_owned());
    }

    // True while the previous character was an (unescaped) escape character.
    let mut after_escape = false;
    while pat.len() > 0 {
        let current = pat.next_raw_char();
        let starts_escape = current == escape_char && !after_escape;

        if starts_escape {
            escaped.append('\\');
        } else if pat.next_raw_byte() == b'\\' {
            escaped.append('\\');
            if !after_escape {
                escaped.append('\\');
            }
        } else {
            escaped.append(current);
        }

        pat.advance_char();
        after_escape = starts_escape;
    }

    Ok(escaped)
}
impl Owned for String {
    /// Creates an empty string with room for `size` bytes.
    fn new(size: usize) -> Self {
        Self::with_capacity(size)
    }

    /// Pushes `ch` onto the end of the string.
    fn append(&mut self, ch: char) {
        self.push(ch);
    }
}
impl Owned for Vec<u8> {
    /// Creates an empty vector with room for `size` bytes.
    fn new(size: usize) -> Self {
        Self::with_capacity(size)
    }

    /// Pushes `ch` as a single byte; the cast keeps only the low 8 bits of
    /// the code point.
    fn append(&mut self, ch: char) {
        let byte = ch as u8;
        self.push(byte);
    }
}
impl<'a> ToOwned for StrTraverser<'a> {
    type Owned = String;

    /// Copies the text that has not been consumed yet into a new `String`.
    fn to_owned(&self) -> String {
        let remaining = &self.bytes;
        remaining.to_str()
    }
}
impl<'a> ToOwned for BytesTraverser<'a> {
    type Owned = Vec<u8>;

    /// Copies the bytes that have not been consumed yet into a new `Vec<u8>`.
    fn to_owned(&self) -> Vec<u8> {
        let remaining = &self.bytes;
        remaining.to_vec()
    }
}
impl Escape for str {
    type Err = InvalidEscapeError;
    type Output = String;

    /// Rewrites this pattern so that `\` takes the place of the escape
    /// character given in `esc` (empty `esc` means "no escape character").
    #[inline]
    fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err> {
        let mut pattern_cursor = StrTraverser::new(self);
        let mut escape_cursor = StrTraverser::new(esc);
        escape(&mut pattern_cursor, &mut escape_cursor)
    }
}
impl Escape for [u8] {
    type Err = InvalidEscapeError;
    type Output = Vec<u8>;

    /// Rewrites this byte pattern so that `\` takes the place of the escape
    /// byte given in `esc` (empty `esc` means "no escape byte").
    #[inline]
    fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err> {
        let mut pattern_cursor = BytesTraverser::new(self);
        let mut escape_cursor = BytesTraverser::new(esc);
        escape(&mut pattern_cursor, &mut escape_cursor)
    }
}
// Unit tests for the public `Like`, `ILike` and `Escape` traits, exercised on
// both `str` and `[u8]`.
#[cfg(test)]
mod tests {
use super::*;
use std::fmt::Debug;
// Asserts that `like` yields `result` and `not_like` yields its negation.
fn like_test<T: Like + ?Sized>(input: &T, pattern: &T, result: bool)
where
T::Err: Debug,
{
assert_eq!(input.like(pattern).unwrap(), result);
assert_eq!(input.not_like(pattern).unwrap(), !result);
}
// Asserts that `ilike` yields `result` and `not_ilike` yields its negation.
fn ilike_test<T: ILike + ?Sized>(input: &T, pattern: &T, result: bool)
where
T::Err: Debug,
{
assert_eq!(input.ilike(pattern).unwrap(), result);
assert_eq!(input.not_ilike(pattern).unwrap(), !result);
}
// Asserts that all four matching entry points reject `error_pattern`.
fn pattern_error_test<T: Like + ILike + ?Sized>(input: &T, error_pattern: &T) {
assert!(input.like(error_pattern).is_err());
assert!(input.not_like(error_pattern).is_err());
assert!(input.ilike(error_pattern).is_err());
assert!(input.not_ilike(error_pattern).is_err());
}
// A pattern ending in a dangling backslash is reported as an error.
#[test]
fn test_pattern_error() {
pattern_error_test("H", "\\");
pattern_error_test("Hello,世界!", "Hello,世界\\");
pattern_error_test(&b"H"[..], b"\\");
pattern_error_test(&b"Hello!"[..], b"Hello\\");
}
// Patterns made only of `%` match regardless of how many `%` there are.
#[test]
fn test_percent_sign() {
let str = "Hello,世界!";
like_test(str, "%", true);
like_test(str, "%%%%%%%%%", true);
like_test(str, "%%%%%%%%%%", true);
ilike_test(str, "%", true);
ilike_test(str, "%%%%%%%%%", true);
ilike_test(str, "%%%%%%%%%%", true);
let bytes = &b"Hello!"[..];
like_test(bytes, b"%", true);
like_test(bytes, b"%%%%%%", true);
like_test(bytes, b"%%%%%%%", true);
ilike_test(bytes, b"%", true);
ilike_test(bytes, b"%%%%%%", true);
ilike_test(bytes, b"%%%%%%%", true);
}
// `_` matches exactly one character: for `str` the count is in characters
// (9 for the sample text), for `[u8]` it is in bytes.
#[test]
fn test_underscore() {
let str: &str = "Hello,世界!";
like_test(str, "_", false);
like_test(str, "________", false);
like_test(str, "_________", true);
like_test(str, "__________", false);
ilike_test(str, "_", false);
ilike_test(str, "________", false);
ilike_test(str, "_________", true);
ilike_test(str, "__________", false);
let bytes: &[u8] = b"Hello!";
like_test(bytes, b"_", false);
like_test(bytes, b"_____", false);
like_test(bytes, b"______", true);
like_test(bytes, b"_______", false);
ilike_test(bytes, b"_", false);
ilike_test(bytes, b"_____", false);
ilike_test(bytes, b"______", true);
ilike_test(bytes, b"_______", false);
}
// Without wildcards the pattern must equal the whole input; `ilike` ignores
// ASCII case while `like` does not.
#[test]
fn test_pattern_without_sign() {
let str_lower: &str = "hello,世界!";
let str_upper: &str = "Hello,世界!";
like_test(str_upper, "Hello", false);
like_test(str_upper, "Hello,!", false);
like_test(str_upper, "Hello,世界!", true);
like_test(str_upper, "hello,世界!", false);
ilike_test(str_upper, "hello,世界!", true);
ilike_test(str_lower, "Hello,世界!", true);
let bytes_lower: &[u8] = b"hello!";
let bytes_upper: &[u8] = b"Hello!";
like_test(bytes_upper, b"Hello", false);
like_test(bytes_upper, b"Hello!", true);
like_test(bytes_upper, b"hello!", false);
ilike_test(bytes_upper, b"hello!", true);
ilike_test(bytes_lower, b"Hello!", true);
}
// Combinations of a literal, `%` and `_` in every order.
#[test]
fn test_mixed_pattern() {
let str: &str = "Abc";
like_test(str, "A%_", true);
like_test(str, "A_%", true);
like_test(str, "%b_", true);
like_test(str, "%_c", true);
like_test(str, "_b%", true);
like_test(str, "_%c", true);
ilike_test(str, "a%_", true);
ilike_test(str, "a_%", true);
ilike_test(str, "%B_", true);
ilike_test(str, "%_C", true);
ilike_test(str, "_B%", true);
ilike_test(str, "_%C", true);
let bytes: &[u8] = b"Abc";
like_test(bytes, b"A%_", true);
like_test(bytes, b"A_%", true);
like_test(bytes, b"%b_", true);
like_test(bytes, b"%_c", true);
like_test(bytes, b"_b%", true);
like_test(bytes, b"_%c", true);
ilike_test(bytes, b"a%_", true);
ilike_test(bytes, b"a_%", true);
ilike_test(bytes, b"%B_", true);
ilike_test(bytes, b"%_C", true);
ilike_test(bytes, b"_B%", true);
ilike_test(bytes, b"_%C", true);
}
// Asserts that escaping `input` with `escape` produces `result`.
fn escape_test<T: Escape + ?Sized, R: Into<T::Output>>(input: &T, escape: &T, result: R)
where
T::Output: Debug + PartialEq,
T::Err: Debug,
{
assert_eq!(input.escape(escape).unwrap(), result.into());
}
// Asserts that `error_escape` is rejected as an escape specification.
fn escape_error_test<T: Escape + ?Sized>(input: &T, error_escape: &T) {
assert!(input.escape(error_escape).is_err());
}
// An escape specification of more than one character is an error.
#[test]
fn test_escape_error() {
escape_error_test("Hello,世界!", ",!");
escape_error_test(&b"Hello,World!"[..], b",!");
}
// Covers: empty escape (backslashes doubled), backslash escape (unchanged),
// an escape character absent from the input, escape characters at the start,
// middle and end, a doubled escape character, and escaped wildcards.
#[test]
fn test_escape() {
let str: &str = "Hello,世界!";
escape_test(str, "", str);
escape_test(str, "\\", str);
escape_test(str, "?", "Hello,世界!");
escape_test(str, "H", "\\ello,世界!");
escape_test(str, ",", "Hello\\世界!");
escape_test(str, "!", "Hello,世界\\");
escape_test(str, "世", "Hello,\\界!");
escape_test("Hello,,世界!", ",", "Hello\\,世界!");
escape_test("Hello!世界!", "!", "Hello\\世界\\");
escape_test("Hello\\!世界!", "", "Hello\\\\!世界!");
escape_test("Hello$%$_世界!", "$", "Hello\\%\\_世界!");
let bytes: &[u8] = b"Hello,World!";
escape_test(bytes, b"", bytes);
escape_test(bytes, b"\\", bytes);
escape_test(bytes, b"?", &b"Hello,World!"[..]);
escape_test(bytes, b"H", &b"\\ello,World!"[..]);
escape_test(bytes, b",", &b"Hello\\World!"[..]);
escape_test(bytes, b"!", &b"Hello,World\\"[..]);
escape_test(&b"Hello,,World!"[..], b",", &b"Hello\\,World!"[..]);
escape_test(&b"Hello!World!"[..], b"!", &b"Hello\\World\\"[..]);
escape_test(&b"Hello\\World!"[..], b"", &b"Hello\\\\World!"[..]);
escape_test(&b"Hello$%$_World!"[..], b"$", &b"Hello\\%\\_World!"[..]);
}
}