use std::collections::TryReserveError;
use std::error::Error;
use std::fmt::{self, Display};
/// SQL `LIKE` pattern matching.
///
/// In a pattern, `%` matches any run of characters (including none) and `_` matches exactly
/// one character. When `HAS_ESCAPE` is `true`, a backslash (`\`) makes the next pattern
/// character match literally.
pub trait Like<const HAS_ESCAPE: bool> {
    /// Error returned for a malformed pattern (the built-in `str` and `[u8]` impls use
    /// [`InvalidPatternError`]).
    type Err;

    /// Returns `Ok(true)` if `self` matches `pattern`.
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err>;

    /// Negation of [`Like::like`]; a pattern error is returned unchanged.
    #[inline]
    fn not_like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let matched = self.like(pattern)?;
        Ok(!matched)
    }
}
/// Case-insensitive SQL `ILIKE` pattern matching.
///
/// Wildcards and escaping behave as for [`Like`]. The built-in `str` and `[u8]`
/// implementations fold only ASCII letters; every other byte must match exactly.
pub trait ILike<const HAS_ESCAPE: bool> {
    /// Error returned for a malformed pattern.
    type Err;

    /// Returns `Ok(true)` if `self` matches `pattern` case-insensitively.
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err>;

    /// Negation of [`ILike::ilike`]; a pattern error is returned unchanged.
    #[inline]
    fn not_ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let matched = self.ilike(pattern)?;
        Ok(!matched)
    }
}
/// Conversion of a pattern written with a caller-chosen escape character (as in SQL
/// `LIKE ... ESCAPE 'c'`) into the backslash-escaped form understood by `Like<true>` and
/// `ILike<true>`.
pub trait Escape {
    /// Owned buffer holding the rewritten pattern.
    type Output;
    /// Error for an unusable escape string or escape sequence.
    type Err;

    /// Rewrites `self` using `esc` as the escape character (for the built-in impls, an
    /// empty `esc` means that no escape character is in effect).
    fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err>;
}
/// Forward-only cursor over the bytes of an input or a pattern.
///
/// Implementors provide the three byte-level primitives (`len`, `raw_byte_at`,
/// `advance_byte`). The provided defaults treat each byte as one character
/// (`advance_char`, `next_raw_char`) and compare bytes unchanged (`byte_at`); implementors
/// override them for UTF-8 or case-insensitive traversal.
trait Traverser {
    /// Number of bytes not yet consumed.
    fn len(&self) -> usize;

    /// Byte `index` positions past the cursor, exactly as stored.
    fn raw_byte_at(&self, index: usize) -> u8;

    /// Consumes one byte.
    fn advance_byte(&mut self);

    /// Consumes one character; by default a character is a single byte.
    #[inline]
    fn advance_char(&mut self) {
        self.advance_byte();
    }

    /// Byte `index` positions past the cursor in the form used for comparisons; by default
    /// identical to [`Traverser::raw_byte_at`].
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        self.raw_byte_at(index)
    }

    /// The byte at the cursor, exactly as stored.
    #[inline]
    fn next_raw_byte(&self) -> u8 {
        self.raw_byte_at(0)
    }

    /// The byte at the cursor in comparison form.
    #[inline]
    fn next_byte(&self) -> u8 {
        self.byte_at(0)
    }

    /// The character at the cursor; by default the current byte read as a Latin-1 code point.
    #[inline]
    fn next_raw_char(&self) -> char {
        char::from(self.next_raw_byte())
    }
}
/// Outcome of one `like` matching attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Matched {
    /// Input and pattern match completely.
    True,
    /// This alignment of input and pattern does not match.
    False,
    /// The input ran out before the pattern was satisfied; an enclosing `%` search returns
    /// this immediately instead of trying later start positions.
    Abort,
}
/// Error returned by the built-in [`Like`] / [`ILike`] implementations when, with escaping
/// enabled, the matcher reaches an escape character that has nothing after it.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidPatternError;

impl Display for InvalidPatternError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("missing or illegal character following the escape character")
    }
}

impl Error for InvalidPatternError {}
/// Core matcher shared by every `Like` / `ILike` implementation in this file.
///
/// Walks `input` and `pattern` together. `%` matches any run of characters (including
/// none), `_` matches exactly one character, and, when `HAS_ESCAPE` is `true`, a `\` makes
/// the following pattern byte match literally. Literal bytes are compared through
/// `Traverser::byte_at`, so the traverser type decides whether case is ignored.
///
/// Returns:
/// * `Matched::True` when the whole input matches the whole pattern;
/// * `Matched::False` when this alignment fails;
/// * `Matched::Abort` when the input ran out while the pattern still needed characters.
///   An enclosing `%` search returns this at once rather than trying later start positions.
///
/// # Errors
/// Returns `InvalidPatternError` if matching reaches an escape character with nothing
/// after it. A trailing escape that is never reached (because the input ran out first) is
/// not reported.
///
/// NOTE(review): the `%` search recurses into `like`, and the nesting depth grows with the
/// number of `%` groups in the pattern. Consider bounding pattern length if patterns can
/// come from untrusted input.
fn like<T, const HAS_ESCAPE: bool>(
    input: &mut T,
    pattern: &mut T,
) -> Result<Matched, InvalidPatternError>
where
    T: Traverser + Clone,
{
    // Fast path: a pattern that is exactly `%` matches any input, including an empty one.
    if pattern.len() == 1 && pattern.next_raw_byte() == b'%' {
        return Ok(Matched::True);
    }
    while input.len() > 0 && pattern.len() > 0 {
        if HAS_ESCAPE && pattern.next_raw_byte() == b'\\' {
            // Escaped literal: skip the `\`; the byte after it must exist and must match.
            pattern.advance_byte();
            if pattern.len() == 0 {
                return Err(InvalidPatternError);
            }
            if input.next_byte() != pattern.next_byte() {
                return Ok(Matched::False);
            }
        } else if pattern.next_raw_byte() == b'%' {
            pattern.advance_byte();
            // Absorb the run of `%` / `_` that follows this `%`. Extra `%`s are redundant,
            // and each `_` consumes one input character up front. The search below then
            // starts at a (possibly escaped) literal.
            while pattern.len() > 0 {
                let pattern_raw_char = pattern.next_raw_byte();
                if pattern_raw_char == b'%' {
                    pattern.advance_byte();
                } else if pattern_raw_char == b'_' {
                    if input.len() == 0 {
                        return Ok(Matched::Abort);
                    }
                    input.advance_char();
                    pattern.advance_byte();
                } else {
                    break; }
            }
            // A trailing `%` matches whatever input is left.
            if pattern.len() == 0 {
                return Ok(Matched::True);
            }
            // First literal byte of the rest of the pattern. It lets the search skip input
            // positions that cannot start a match without recursing.
            let first_pat = if HAS_ESCAPE && pattern.next_raw_byte() == b'\\' {
                if pattern.len() < 2 {
                    return Err(InvalidPatternError);
                }
                pattern.byte_at(1)
            } else {
                pattern.next_byte()
            };
            // Try each character-aligned input position whose first byte matches.
            // A `True` or `Abort` from the recursive attempt is final.
            while input.len() > 0 {
                if input.next_byte() == first_pat {
                    let mut i = input.clone();
                    let mut p = pattern.clone();
                    let matched = like::<T, HAS_ESCAPE>(&mut i, &mut p)?;
                    if matched != Matched::False {
                        return Ok(matched); }
                }
                input.advance_char();
            }
            // Input exhausted without matching the rest of the pattern.
            return Ok(Matched::Abort);
        } else if pattern.next_raw_byte() == b'_' {
            // `_` consumes one whole input character but only one pattern byte.
            input.advance_char();
            pattern.advance_byte();
            continue;
        } else if pattern.next_byte() != input.next_byte() {
            return Ok(Matched::False);
        }
        // Both cursors just consumed equal bytes, so stepping one byte at a time keeps them
        // aligned. UTF-8 continuation bytes (>= 0x80) are never `%`, `_` or `\`, so they
        // always take the literal-comparison path.
        input.advance_byte();
        pattern.advance_byte();
    }
    // Pattern exhausted but input remains: this alignment does not match.
    if input.len() > 0 {
        return Ok(Matched::False); }
    // Input exhausted: the remaining pattern matches the empty string only if it is all `%`.
    while pattern.len() > 0 && pattern.next_raw_byte() == b'%' {
        pattern.advance_byte();
    }
    if pattern.len() == 0 {
        return Ok(Matched::True);
    }
    Ok(Matched::Abort)
}
/// Byte cursor over a borrowed slice; the shared storage behind every traverser.
#[derive(Clone)]
struct Bytes<'a> {
    /// The bytes that have not been consumed yet.
    bytes: &'a [u8],
}

impl<'a> Bytes<'a> {
    /// Cursor over the UTF-8 encoding of `s`.
    #[inline]
    const fn from_str(s: &'a str) -> Self {
        Self {
            bytes: s.as_bytes(),
        }
    }

    /// Cursor over an arbitrary byte slice.
    #[inline]
    const fn from_bytes(bytes: &'a [u8]) -> Self {
        Self { bytes }
    }

    /// Number of bytes not yet consumed.
    #[inline]
    fn len(&self) -> usize {
        self.bytes.len()
    }

    /// Consumes one byte.
    ///
    /// # Panics
    /// If no bytes remain.
    #[inline]
    fn advance_byte(&mut self) {
        self.bytes = &self.bytes[1..];
    }

    /// Consumes one byte plus any following UTF-8 continuation bytes (`0b10xx_xxxx`). At a
    /// UTF-8 character boundary this steps over exactly one character.
    #[inline]
    fn advance_char(&mut self) {
        self.advance_byte();
        while !self.bytes.is_empty() && (self.raw_byte_at(0) & 0xC0) == 0x80 {
            self.advance_byte();
        }
    }

    /// Byte `index` positions past the cursor.
    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.bytes[index]
    }

    /// Decodes the character at the cursor.
    ///
    /// # Panics
    /// If no bytes remain, or the cursor is not at the start of a valid UTF-8 character.
    /// Cursors built with `from_str` and moved only by `advance_char` always are; the
    /// previous unchecked conversion would have been undefined behaviour otherwise.
    #[inline]
    fn next_raw_char(&self) -> char {
        // A UTF-8 character is at most 4 bytes, so validating this short prefix is O(1).
        let prefix = &self.bytes[..self.bytes.len().min(4)];
        let valid = match std::str::from_utf8(prefix) {
            Ok(s) => s,
            // The prefix may cut the *following* character short; keep the valid part.
            Err(e) => std::str::from_utf8(&prefix[..e.valid_up_to()]).unwrap_or_default(),
        };
        valid
            .chars()
            .next()
            .expect("Bytes::next_raw_char: cursor is not at the start of a UTF-8 character")
    }
}
/// Case-sensitive traverser over the UTF-8 encoding of a `str`.
///
/// `advance_char` steps over a whole UTF-8 sequence and `next_raw_char` decodes it, so `_`
/// consumes one character rather than one byte. Comparison bytes are the stored bytes.
#[derive(Clone)]
struct StrTraverser<'a> {
    cursor: Bytes<'a>,
}

impl<'a> StrTraverser<'a> {
    /// Positions a new traverser at the start of `s`.
    #[inline]
    const fn new(s: &'a str) -> Self {
        let cursor = Bytes::from_str(s);
        Self { cursor }
    }
}

impl Traverser for StrTraverser<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.cursor.len()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.cursor.raw_byte_at(index)
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.cursor.advance_byte();
    }

    #[inline]
    fn advance_char(&mut self) {
        self.cursor.advance_char();
    }

    #[inline]
    fn next_raw_char(&self) -> char {
        self.cursor.next_raw_char()
    }
}
/// Case-sensitive traverser over raw bytes.
///
/// Only the byte primitives are implemented. The `Traverser` defaults make every byte its
/// own character (so `_` consumes exactly one byte) and compare bytes unchanged.
#[derive(Clone)]
struct BytesTraverser<'a> {
    cursor: Bytes<'a>,
}

impl<'a> BytesTraverser<'a> {
    /// Positions a new traverser at the start of `bytes`.
    #[inline]
    const fn new(bytes: &'a [u8]) -> Self {
        let cursor = Bytes::from_bytes(bytes);
        Self { cursor }
    }
}

impl Traverser for BytesTraverser<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.cursor.len()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.cursor.raw_byte_at(index)
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.cursor.advance_byte();
    }
}
/// Case-insensitive traverser over the UTF-8 encoding of a `str`.
///
/// Advances and decodes whole UTF-8 characters like the case-sensitive `str` traverser.
/// Comparison bytes have ASCII letters folded to lowercase; non-ASCII bytes (including
/// every byte of a multi-byte character) are compared unchanged.
#[derive(Clone)]
struct CiStrTraverser<'a> {
    cursor: Bytes<'a>,
}

impl<'a> CiStrTraverser<'a> {
    /// Positions a new traverser at the start of `s`.
    #[inline]
    const fn new(s: &'a str) -> Self {
        let cursor = Bytes::from_str(s);
        Self { cursor }
    }
}

impl Traverser for CiStrTraverser<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.cursor.len()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.cursor.raw_byte_at(index)
    }

    /// Stored byte with ASCII uppercase letters mapped to lowercase.
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        let raw = self.cursor.raw_byte_at(index);
        raw.to_ascii_lowercase()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.cursor.advance_byte();
    }

    #[inline]
    fn advance_char(&mut self) {
        self.cursor.advance_char();
    }

    #[inline]
    fn next_raw_char(&self) -> char {
        self.cursor.next_raw_char()
    }
}
/// Case-insensitive traverser over raw bytes.
///
/// Every byte is its own character (via the `Traverser` defaults). Comparison bytes have
/// ASCII letters folded to lowercase.
#[derive(Clone)]
struct CiBytesTraverser<'a> {
    cursor: Bytes<'a>,
}

impl<'a> CiBytesTraverser<'a> {
    /// Positions a new traverser at the start of `bytes`.
    #[inline]
    const fn new(bytes: &'a [u8]) -> Self {
        let cursor = Bytes::from_bytes(bytes);
        Self { cursor }
    }
}

impl Traverser for CiBytesTraverser<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.cursor.len()
    }

    #[inline]
    fn raw_byte_at(&self, index: usize) -> u8 {
        self.cursor.raw_byte_at(index)
    }

    /// Stored byte with ASCII uppercase letters mapped to lowercase.
    #[inline]
    fn byte_at(&self, index: usize) -> u8 {
        let raw = self.cursor.raw_byte_at(index);
        raw.to_ascii_lowercase()
    }

    #[inline]
    fn advance_byte(&mut self) {
        self.cursor.advance_byte();
    }
}
impl<const HAS_ESCAPE: bool> Like<HAS_ESCAPE> for str {
    type Err = InvalidPatternError;

    /// Case-sensitive LIKE over UTF-8 text; `_` matches one whole character.
    #[inline]
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = StrTraverser::new(self);
        let mut pat = StrTraverser::new(pattern);
        like::<_, HAS_ESCAPE>(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}

impl<const HAS_ESCAPE: bool> Like<HAS_ESCAPE> for [u8] {
    type Err = InvalidPatternError;

    /// Case-sensitive LIKE over raw bytes; `_` matches exactly one byte.
    #[inline]
    fn like(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = BytesTraverser::new(self);
        let mut pat = BytesTraverser::new(pattern);
        like::<_, HAS_ESCAPE>(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
impl<const HAS_ESCAPE: bool> ILike<HAS_ESCAPE> for str {
    type Err = InvalidPatternError;

    /// LIKE over UTF-8 text with ASCII letters compared case-insensitively; `_` matches one
    /// whole character.
    #[inline]
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = CiStrTraverser::new(self);
        let mut pat = CiStrTraverser::new(pattern);
        like::<_, HAS_ESCAPE>(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}

impl<const HAS_ESCAPE: bool> ILike<HAS_ESCAPE> for [u8] {
    type Err = InvalidPatternError;

    /// LIKE over raw bytes with ASCII letters compared case-insensitively; `_` matches
    /// exactly one byte.
    #[inline]
    fn ilike(&self, pattern: &Self) -> Result<bool, Self::Err> {
        let mut text = CiBytesTraverser::new(self);
        let mut pat = CiBytesTraverser::new(pattern);
        like::<_, HAS_ESCAPE>(&mut text, &mut pat).map(|outcome| outcome == Matched::True)
    }
}
/// Growable output buffer that `escape` writes the rewritten pattern into.
trait Owned: Sized {
    /// Appends one character to the buffer.
    fn append(&mut self, ch: char);

    /// Creates an empty buffer with room reserved for `size` units. Allocation failure is
    /// returned as `TryReserveError`.
    fn try_new(size: usize) -> Result<Self, TryReserveError>;
}
/// Error returned by [`Escape::escape`] for the built-in `str` and `[u8]` implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvalidEscapeError {
    /// The escape string is longer than one character.
    MultiChars,
    /// The escape character ends the pattern, or is followed by something other than `%`,
    /// `_`, or the escape character itself.
    InvalidEscape,
    /// The output buffer could not be reserved.
    TryReserveError(TryReserveError),
}

impl From<TryReserveError> for InvalidEscapeError {
    #[inline]
    fn from(err: TryReserveError) -> Self {
        Self::TryReserveError(err)
    }
}

impl Display for InvalidEscapeError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MultiChars => "escape must be one character",
            Self::InvalidEscape => "missing or illegal character following the escape character",
            // Reuse the allocation error's own message.
            Self::TryReserveError(inner) => return write!(f, "{}", inner),
        };
        f.write_str(message)
    }
}

impl Error for InvalidEscapeError {}
/// Rewrites `pat`, written with escape character `esc`, into the backslash-escaped form the
/// LIKE matcher expects when escaping is enabled.
///
/// * Empty `esc`: no escape character is in effect, so every `\` is doubled to make it
///   literal.
/// * One-character `esc`: each occurrence must be followed by `%`, `_`, or the escape
///   character itself. The pair is emitted as `\` followed by that character. Other
///   backslashes are doubled and everything else is copied unchanged.
///
/// # Errors
/// * `MultiChars` if `esc` is longer than one character.
/// * `InvalidEscape` if the escape character is last or is followed by anything else.
/// * `TryReserveError` if the output buffer cannot be allocated.
fn escape<T, R>(pat: &mut T, esc: &mut T) -> Result<R, InvalidEscapeError>
where
    T: Traverser,
    R: Owned,
{
    // Reserve twice the input length: a doubled backslash is the worst-case growth.
    let mut out = R::try_new(pat.len() * 2)?;

    if esc.len() == 0 {
        while pat.len() > 0 {
            let ch = pat.next_raw_char();
            if ch == '\\' {
                out.append('\\');
            }
            out.append(ch);
            pat.advance_char();
        }
        return Ok(out);
    }

    let escape_char = esc.next_raw_char();
    esc.advance_char();
    if esc.len() != 0 {
        return Err(InvalidEscapeError::MultiChars);
    }

    while pat.len() > 0 {
        let ch = pat.next_raw_char();
        pat.advance_char();
        if ch == escape_char {
            // Consume the escaped character here and emit it behind a backslash, so the
            // matcher treats it literally.
            if pat.len() == 0 {
                return Err(InvalidEscapeError::InvalidEscape);
            }
            let escaped = pat.next_raw_char();
            if escaped != '%' && escaped != '_' && escaped != escape_char {
                return Err(InvalidEscapeError::InvalidEscape);
            }
            out.append('\\');
            out.append(escaped);
            pat.advance_char();
        } else if ch == '\\' {
            // A backslash that is not the escape character must be doubled to stay literal.
            out.append('\\');
            out.append('\\');
        } else {
            out.append(ch);
        }
    }
    Ok(out)
}
impl Owned for String {
    /// Empty `String` with at least `size` bytes reserved.
    #[inline]
    fn try_new(size: usize) -> Result<Self, TryReserveError> {
        let mut buffer = Self::new();
        buffer.try_reserve(size).map(|()| buffer)
    }

    /// Pushes `ch` encoded as UTF-8.
    #[inline]
    fn append(&mut self, ch: char) {
        self.push(ch);
    }
}

impl Owned for Vec<u8> {
    /// Empty byte vector with at least `size` bytes reserved.
    #[inline]
    fn try_new(size: usize) -> Result<Self, TryReserveError> {
        let mut buffer = Self::new();
        buffer.try_reserve(size).map(|()| buffer)
    }

    /// Pushes the low byte of `ch`.
    ///
    /// NOTE(review): `as u8` truncates. This relies on callers passing only chars that came
    /// from single bytes (U+0000..=U+00FF), as the byte traverser produces.
    #[inline]
    fn append(&mut self, ch: char) {
        let byte = ch as u8;
        self.push(byte);
    }
}
impl Escape for str {
    type Output = String;
    type Err = InvalidEscapeError;

    /// Rewrites `self`, using `esc` as its escape character, into a backslash-escaped pattern.
    #[inline]
    fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err> {
        let mut pattern_cursor = StrTraverser::new(self);
        let mut escape_cursor = StrTraverser::new(esc);
        escape(&mut pattern_cursor, &mut escape_cursor)
    }
}

impl Escape for [u8] {
    type Output = Vec<u8>;
    type Err = InvalidEscapeError;

    /// Byte-wise counterpart of the `str` implementation; every byte is one character.
    #[inline]
    fn escape(&self, esc: &Self) -> Result<Self::Output, Self::Err> {
        let mut pattern_cursor = BytesTraverser::new(self);
        let mut escape_cursor = BytesTraverser::new(esc);
        escape(&mut pattern_cursor, &mut escape_cursor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Asserts that `input LIKE pattern` is `result` and that `NOT LIKE` is its negation.
    fn like_test<T, const HAS_ESCAPE: bool>(input: &T, pattern: &T, result: bool)
    where
        T: Like<HAS_ESCAPE> + ?Sized,
        T::Err: Debug,
    {
        assert_eq!(input.like(pattern).unwrap(), result);
        assert_eq!(input.not_like(pattern).unwrap(), !result);
    }

    /// Same as `like_test`, for `ILIKE` / `NOT ILIKE`.
    fn ilike_test<T, const HAS_ESCAPE: bool>(input: &T, pattern: &T, result: bool)
    where
        T: ILike<HAS_ESCAPE> + ?Sized,
        T::Err: Debug,
    {
        assert_eq!(input.ilike(pattern).unwrap(), result);
        assert_eq!(input.not_ilike(pattern).unwrap(), !result);
    }

    /// Asserts that every escaping LIKE/ILIKE variant reports `error_pattern` as invalid.
    fn pattern_error_test<T: Like<true> + ILike<true> + ?Sized>(input: &T, error_pattern: &T) {
        assert!(input.like(error_pattern).is_err());
        assert!(input.not_like(error_pattern).is_err());
        assert!(input.ilike(error_pattern).is_err());
        assert!(input.not_ilike(error_pattern).is_err());
    }

    #[test]
    fn test_pattern_error() {
        pattern_error_test("H", "\\");
        pattern_error_test("Hello,世界!", "Hello,世界\\");
        pattern_error_test(&b"H"[..], b"\\");
        pattern_error_test(&b"Hello!"[..], b"Hello\\");
    }

    #[test]
    fn test_percent_sign() {
        let str = "Hello,世界!";
        like_test::<_, false>(str, "%", true);
        like_test::<_, false>(str, "%%%%%%%%%", true);
        like_test::<_, false>(str, "%%%%%%%%%%", true);
        ilike_test::<_, false>(str, "%", true);
        ilike_test::<_, false>(str, "%%%%%%%%%", true);
        ilike_test::<_, false>(str, "%%%%%%%%%%", true);
        let bytes = &b"Hello!"[..];
        like_test::<_, false>(bytes, b"%", true);
        like_test::<_, false>(bytes, b"%%%%%%", true);
        like_test::<_, false>(bytes, b"%%%%%%%", true);
        ilike_test::<_, false>(bytes, b"%", true);
        ilike_test::<_, false>(bytes, b"%%%%%%", true);
        ilike_test::<_, false>(bytes, b"%%%%%%%", true);
        let str2: &str = "601618";
        let bytes2: &[u8] = b"601618";
        like_test::<_, false>(str2, "%618", true);
        ilike_test::<_, false>(str2, "%618", true);
        like_test::<_, false>(bytes2, b"%618", true);
        ilike_test::<_, false>(bytes2, b"%618", true);
        let str3 = "中文测试";
        like_test::<_, false>(str3, "%测试", true);
        like_test::<_, false>(str3, "中文测试%%", true);
        ilike_test::<_, false>(str3, "%测试", true);
        like_test::<_, false>("中文测试%\\", "中文测试%\\", true);
        ilike_test::<_, false>("中文测试%\\", "中文测试%\\", true);
    }

    #[test]
    fn test_underscore() {
        let str: &str = "Hello,世界!";
        like_test::<_, false>(str, "_", false);
        like_test::<_, false>(str, "________", false);
        like_test::<_, false>(str, "_________", true);
        like_test::<_, false>(str, "__________", false);
        ilike_test::<_, false>(str, "_", false);
        ilike_test::<_, false>(str, "________", false);
        ilike_test::<_, false>(str, "_________", true);
        ilike_test::<_, false>(str, "__________", false);
        let bytes: &[u8] = b"Hello!";
        like_test::<_, false>(bytes, b"_", false);
        like_test::<_, false>(bytes, b"_____", false);
        like_test::<_, false>(bytes, b"______", true);
        like_test::<_, false>(bytes, b"_______", false);
        ilike_test::<_, false>(bytes, b"_", false);
        ilike_test::<_, false>(bytes, b"_____", false);
        ilike_test::<_, false>(bytes, b"______", true);
        ilike_test::<_, false>(bytes, b"_______", false);
    }

    #[test]
    fn test_pattern_without_sign() {
        let str_lower: &str = "hello,世界!";
        let str_upper: &str = "Hello,世界!";
        like_test::<_, false>(str_upper, "Hello", false);
        like_test::<_, false>(str_upper, "Hello,!", false);
        like_test::<_, false>(str_upper, "Hello,世界!", true);
        like_test::<_, false>(str_upper, "hello,世界!", false);
        ilike_test::<_, false>(str_upper, "hello,世界!", true);
        ilike_test::<_, false>(str_lower, "Hello,世界!", true);
        let bytes_lower: &[u8] = b"hello!";
        let bytes_upper: &[u8] = b"Hello!";
        like_test::<_, false>(bytes_upper, b"Hello", false);
        like_test::<_, false>(bytes_upper, b"Hello!", true);
        like_test::<_, false>(bytes_upper, b"hello!", false);
        ilike_test::<_, false>(bytes_upper, b"hello!", true);
        ilike_test::<_, false>(bytes_lower, b"Hello!", true);
    }

    #[test]
    fn test_mixed_pattern() {
        let str: &str = "Abc";
        like_test::<_, false>(str, "A%_", true);
        like_test::<_, false>(str, "A_%", true);
        like_test::<_, false>(str, "%b_", true);
        like_test::<_, false>(str, "%_c", true);
        like_test::<_, false>(str, "_b%", true);
        like_test::<_, false>(str, "_%c", true);
        ilike_test::<_, false>(str, "a%_", true);
        ilike_test::<_, false>(str, "a_%", true);
        ilike_test::<_, false>(str, "%B_", true);
        ilike_test::<_, false>(str, "%_C", true);
        ilike_test::<_, false>(str, "_B%", true);
        ilike_test::<_, false>(str, "_%C", true);
        let bytes: &[u8] = b"Abc";
        like_test::<_, false>(bytes, b"A%_", true);
        like_test::<_, false>(bytes, b"A_%", true);
        like_test::<_, false>(bytes, b"%b_", true);
        like_test::<_, false>(bytes, b"%_c", true);
        like_test::<_, false>(bytes, b"_b%", true);
        like_test::<_, false>(bytes, b"_%c", true);
        ilike_test::<_, false>(bytes, b"a%_", true);
        ilike_test::<_, false>(bytes, b"a_%", true);
        ilike_test::<_, false>(bytes, b"%B_", true);
        ilike_test::<_, false>(bytes, b"%_C", true);
        ilike_test::<_, false>(bytes, b"_B%", true);
        ilike_test::<_, false>(bytes, b"_%C", true);
    }

    #[test]
    fn test_like_escape() {
        let escape = ",";
        let pattern = "Hello,,世界!".escape(escape).unwrap();
        like_test::<_, true>("Hello,世界!", pattern.as_str(), true);
        ilike_test::<_, true>("hello,世界!", pattern.as_str(), true);
        let escape = "?";
        let pattern = "Hello\\世界!".escape(escape).unwrap();
        like_test::<_, true>("Hello\\世界!", pattern.as_str(), true);
        ilike_test::<_, true>("hello\\世界!", pattern.as_str(), true);
        let pattern = "Hello,世界!\\".escape(escape).unwrap();
        like_test::<_, true>("Hello,世界!\\", pattern.as_str(), true);
        ilike_test::<_, true>("hello,世界!\\", pattern.as_str(), true);
        let escape = "\\";
        let pattern = "Hello,世界!".escape(escape).unwrap();
        like_test::<_, true>("Hello,世界!", pattern.as_str(), true);
        ilike_test::<_, true>("hello,世界!", pattern.as_str(), true);
    }

    /// Asserts that escaping `input` with `escape` succeeds and yields `result`.
    fn escape_test<T: Escape + ?Sized, R: Into<T::Output>>(input: &T, escape: &T, result: R)
    where
        T::Output: Debug + PartialEq,
        T::Err: Debug,
    {
        assert_eq!(input.escape(escape).unwrap(), result.into());
    }

    /// Asserts that escaping `input` with `error_escape` fails.
    fn escape_error_test<T: Escape + ?Sized>(input: &T, error_escape: &T) {
        assert!(input.escape(error_escape).is_err());
    }

    #[test]
    fn test_escape_error() {
        escape_error_test("Hello,世界!", ",!");
        escape_error_test("Hello,世界!", "!");
        escape_error_test("Hello,世界\\", "\\");
        escape_error_test("Hello,世界\\1", "\\");
        escape_error_test("Hello,世界", "界");
        escape_error_test(&b"Hello,World!"[..], b",!");
    }

    #[test]
    fn test_escape() {
        let str: &str = "Hello,世界!";
        escape_test(str, "", str);
        escape_test(str, "\\", str);
        escape_test(str, "?", "Hello,世界!");
        escape_test("HHello,世界!", "H", "\\Hello,世界!");
        escape_test("Hello,,世界!", ",", "Hello\\,世界!");
        escape_test("Hello,世界!!", "!", "Hello,世界\\!");
        escape_test("Hello,世世界!", "世", "Hello,\\世界!");
        escape_test("Hello,,,,世界!", ",", "Hello\\,\\,世界!");
        escape_test("Hello!!世界!!", "!", "Hello\\!世界\\!");
        escape_test("Hello\\!世界!", "", "Hello\\\\!世界!");
        escape_test("Hello$$%$$_世界!", "$", "Hello\\$%\\$_世界!");
        let bytes: &[u8] = b"Hello,World!";
        escape_test(bytes, b"", bytes);
        escape_test(bytes, b"\\", bytes);
        escape_test(bytes, b"?", &b"Hello,World!"[..]);
        escape_test(&b"HHello,World!"[..], b"H", &b"\\Hello,World!"[..]);
        escape_test(&b"Hello,,World!"[..], b",", &b"Hello\\,World!"[..]);
        escape_test(&b"Hello,World!!"[..], b"!", &b"Hello,World\\!"[..]);
        escape_test(&b"Hello,,,,World!"[..], b",", &b"Hello\\,\\,World!"[..]);
        escape_test(&b"Hello!!World!!"[..], b"!", &b"Hello\\!World\\!"[..]);
        escape_test(&b"Hello\\World!"[..], b"", &b"Hello\\\\World!"[..]);
        escape_test(&b"Hello$$%$$_World!"[..], b"$", &b"Hello\\$%\\$_World!"[..]);
    }

    /// Empty inputs and patterns: only an empty or all-`%` pattern matches empty input.
    #[test]
    fn test_empty_operands() {
        like_test::<_, false>("", "", true);
        like_test::<_, false>("", "%", true);
        like_test::<_, false>("", "%%", true);
        like_test::<_, false>("", "_", false);
        like_test::<_, false>("a", "", false);
        like_test::<_, false>(&b""[..], b"", true);
        ilike_test::<_, false>("", "%", true);
    }

    /// `_` absorbed after a `%` must consume whole UTF-8 characters.
    #[test]
    fn test_wildcards_after_percent_on_multibyte() {
        like_test::<_, false>("世界!", "%_!", true);
        like_test::<_, false>("世", "%__", false);
        like_test::<_, false>("世界", "_%", true);
        ilike_test::<_, false>("世界ABC", "%_abc", true);
    }

    /// A `%` immediately followed by an escaped wildcard searches for that literal.
    #[test]
    fn test_percent_before_escaped_literal() {
        like_test::<_, true>("a%b", "%\\%b", true);
        like_test::<_, true>("ab", "%\\%b", false);
        ilike_test::<_, true>("A_B", "%\\_b", true);
        ilike_test::<_, true>("AxB", "%\\_b", false);
    }

    /// Patterns produced by `escape` are accepted by the escaping matcher.
    #[test]
    fn test_escape_then_like_round_trip() {
        let pattern = "100$%".escape("$").unwrap();
        assert_eq!(pattern, "100\\%");
        like_test::<_, true>("100%", pattern.as_str(), true);
        like_test::<_, true>("1000", pattern.as_str(), false);
        let bytes: &[u8] = b"a$_c";
        let bytes_pattern = bytes.escape(b"$").unwrap();
        like_test::<_, true>(&b"a_c"[..], bytes_pattern.as_slice(), true);
        like_test::<_, true>(&b"abc"[..], bytes_pattern.as_slice(), false);
    }

    /// `escape` reports the specific error variant for each failure.
    #[test]
    fn test_escape_error_kinds() {
        assert_eq!("a".escape(",!").unwrap_err(), InvalidEscapeError::MultiChars);
        assert_eq!("a,".escape(",").unwrap_err(), InvalidEscapeError::InvalidEscape);
        assert_eq!("a,b".escape(",").unwrap_err(), InvalidEscapeError::InvalidEscape);
        assert_eq!("a".escape("").unwrap(), "a");
    }
}