use std::hash::{Hash, Hasher};
use std::mem;
use std::ops::{RangeFrom, RangeInclusive};
use bstr::ByteSlice;
use regex_syntax::hir::Class;
use regex_syntax::hir::ClassBytes;
use regex_syntax::hir::ClassBytesRange;
use regex_syntax::hir::ClassUnicode;
use regex_syntax::hir::ClassUnicodeRange;
use regex_syntax::hir::Dot;
use regex_syntax::hir::HirKind;
use regex_syntax::hir::Repetition;
use serde::{Deserialize, Serialize};
use yara_x_parser::ast;
use crate::utils::cast;
/// A byte value paired with a bit mask.
///
/// NOTE(review): judging by how `class_to_masked_byte` in this file builds
/// these values (`value` is the smallest byte of the class and `mask` is the
/// complement of the bits that vary), bits set to 1 in `mask` appear to be the
/// ones that must equal the corresponding bits of `value`, while bits set to
/// 0 act as wildcards. Confirm against the code that consumes this type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct HexByte {
/// Byte value the masked input is compared against.
pub value: u8,
/// Bit mask that accompanies `value`.
pub mask: u8,
}
impl From<ast::HexByte> for HexByte {
fn from(hex_byte: ast::HexByte) -> Self {
Self { value: hex_byte.value, mask: hex_byte.mask }
}
}
/// Size of the gap that separates a chained pattern from the piece that
/// precedes it.
///
/// `Hir::split_at_large_gaps` builds these values from the `min` and `max`
/// bounds of the repetition at which the pattern was split.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub(crate) enum ChainedPatternGap {
/// The gap has both a lower and an upper bound (inclusive).
Bounded(RangeInclusive<u32>),
/// The gap has a lower bound only.
Unbounded(RangeFrom<u32>),
}
/// One of the pieces produced by `Hir::split_at_large_gaps`.
#[derive(Debug, PartialEq)]
pub(crate) struct ChainedPattern {
/// Gap recorded for this piece. In `split_at_large_gaps` it is taken from
/// the repetition that caused the split right before this piece.
pub gap: ChainedPatternGap,
/// The sub-pattern itself.
pub hir: Hir,
}
/// Thin wrapper around `regex_syntax::hir::Hir` that carries an additional
/// optional greediness flag.
///
/// Equality and hashing are implemented by hand in this file and only look at
/// `inner`; the `greedy` flag takes no part in either.
#[derive(Clone, Eq, Debug)]
pub(crate) struct Hir {
/// The wrapped high-level intermediate representation.
pub(super) inner: regex_syntax::hir::Hir,
/// Greediness attached to the pattern as a whole. It is `None` when the
/// value is created through `From<regex_syntax::hir::Hir>` and can be
/// changed with `set_greedy`.
/// NOTE(review): the exact meaning of `None` versus `Some(_)` is decided by
/// code outside this file -- confirm before relying on it.
pub(super) greedy: Option<bool>,
}
/// Hashes the pattern by walking the wrapped HIR with a [`HirHasher`].
///
/// Only `inner` is visited, so the `greedy` flag does not influence the hash;
/// this is consistent with the hand-written `PartialEq`, which also compares
/// `inner` only.
impl Hash for Hir {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let visitor = HirHasher { state };
        // `HirHasher` never returns an error, so unwrapping cannot fail.
        regex_syntax::hir::visit(&self.inner, visitor).unwrap();
    }
}
/// Two patterns are equal when their wrapped HIRs are equal. The `greedy`
/// flag is deliberately left out of the comparison.
impl PartialEq for Hir {
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
impl From<regex_syntax::hir::Hir> for Hir {
fn from(value: regex_syntax::hir::Hir) -> Self {
Self { inner: value, greedy: None }
}
}
impl Hir {
const PATTERN_CHAINING_THRESHOLD: u32 = 200;
const MIN_PATTERN_LENGTH_IN_CHAIN: usize = 2;
pub fn split_at_large_gaps(self) -> (Self, Vec<ChainedPattern>) {
if !matches!(self.kind(), HirKind::Concat(_)) {
return (self, vec![]);
}
let greedy = self.greedy;
let mut gap_min = 0;
let mut gap_max = None;
let mut gap_greedy = false;
let mut chunks = Vec::new();
let mut chain = Vec::new();
for item in cast!(self.into_kind(), HirKind::Concat) {
if let HirKind::Repetition(rep) = item.kind() {
let num_repetitions =
rep.max.unwrap_or(u32::MAX).saturating_sub(rep.min);
if !chunks.is_empty()
&& num_repetitions > Self::PATTERN_CHAINING_THRESHOLD
&& any_byte(rep.sub.as_ref().kind())
{
let hir: Hir = Hir::concat(chunks).set_greedy(greedy);
if hir.minimum_len().unwrap_or(0)
>= Self::MIN_PATTERN_LENGTH_IN_CHAIN
{
chain.push(ChainedPattern {
gap: if let Some(gap_max) = gap_max {
ChainedPatternGap::Bounded(gap_min..=gap_max)
} else {
ChainedPatternGap::Unbounded(gap_min..)
},
hir,
});
gap_min = rep.min;
gap_max = rep.max;
gap_greedy = rep.greedy;
chunks = Vec::new();
} else {
chunks = vec![hir, item.into()];
}
} else {
chunks.push(item.into());
}
} else {
chunks.push(item.into())
}
}
if chunks.is_empty() {
return (chain.remove(0).hir, chain);
}
let hir = Hir::concat(chunks).set_greedy(greedy);
if chain.is_empty()
|| hir.minimum_len().unwrap_or(0)
>= Self::MIN_PATTERN_LENGTH_IN_CHAIN
{
chain.push(ChainedPattern {
gap: if let Some(gap_max) = gap_max {
ChainedPatternGap::Bounded(gap_min..=gap_max)
} else {
ChainedPatternGap::Unbounded(gap_min..)
},
hir,
});
} else {
let mut last = chain.pop().unwrap();
last.hir = Hir::concat(vec![
last.hir,
Hir::any_byte_repetition(gap_min, gap_max, gap_greedy),
hir,
])
.set_greedy(greedy);
chain.push(last);
}
(chain.remove(0).hir, chain)
}
pub fn set_greedy(mut self, greediness: Option<bool>) -> Self {
self.greedy = greediness;
self
}
#[inline]
pub fn is_greedy(&self) -> Option<bool> {
self.greedy
}
#[inline]
pub fn kind(&self) -> &HirKind {
self.inner.kind()
}
#[inline]
pub fn into_kind(self) -> HirKind {
self.inner.into_kind()
}
#[inline]
pub fn into_inner(self) -> regex_syntax::hir::Hir {
self.inner
}
#[inline]
pub fn minimum_len(&self) -> Option<usize> {
self.inner.properties().minimum_len()
}
#[inline]
pub fn is_alternation_literal(&self) -> bool {
if self.inner.properties().is_alternation_literal()
&& !matches!(self.inner.kind(), HirKind::Concat(_))
{
return true;
}
match self.inner.kind() {
HirKind::Capture(cap) => {
cap.sub.properties().is_alternation_literal()
&& !matches!(cap.sub.kind(), HirKind::Concat(_))
}
_ => false,
}
}
pub fn as_literal_bytes(&self) -> Option<&[u8]> {
match self.inner.kind() {
HirKind::Literal(literal) => Some(literal.0.as_bytes()),
_ => None,
}
}
}
impl Hir {
#[cfg(test)]
pub fn literal<B: Into<Box<[u8]>>>(lit: B) -> Hir {
regex_syntax::hir::Hir::literal(lit).into()
}
pub fn concat(subs: Vec<Hir>) -> Hir {
regex_syntax::hir::Hir::concat(
subs.into_iter().map(|s| s.inner).collect(),
)
.into()
}
pub fn any_byte_repetition(
min: u32,
max: Option<u32>,
greedy: bool,
) -> Hir {
regex_syntax::hir::Hir::repetition(Repetition {
min,
max,
greedy,
sub: Box::new(regex_syntax::hir::Hir::dot(Dot::AnyByte)),
})
.into()
}
}
/// Visitor used by `impl Hash for Hir`: it feeds the nodes of a
/// `regex_syntax` HIR into the borrowed hasher.
struct HirHasher<'a, H: Hasher> {
/// Hasher that receives the data of every visited node.
state: &'a mut H,
}
/// Feeds every node handed to `visit_pre` into the hasher.
///
/// For each node the discriminant of its kind is hashed, followed by the
/// data that is specific to some kinds:
///
/// * literals: the literal bytes;
/// * classes: the class flavour (Unicode or bytes) and the start and end of
///   every range;
/// * repetitions: `min`, `max` and `greedy`.
///
/// Empty, look-around, capture, concatenation and alternation nodes
/// contribute only their discriminant.
impl<H: Hasher> regex_syntax::hir::Visitor for HirHasher<'_, H> {
    type Output = ();
    type Err = ();

    /// Nothing to produce: the result of the visit is the state accumulated
    /// in the hasher.
    fn finish(self) -> Result<Self::Output, Self::Err> {
        Ok(())
    }

    /// Hashes a single node. Never fails.
    fn visit_pre(
        &mut self,
        hir: &regex_syntax::hir::Hir,
    ) -> Result<(), Self::Err> {
        mem::discriminant(hir.kind()).hash(self.state);
        match hir.kind() {
            HirKind::Literal(lit) => {
                lit.0.hash(self.state);
            }
            HirKind::Class(class) => {
                // Distinguish a Unicode class from a byte class that happens
                // to have the same numeric ranges.
                mem::discriminant(class).hash(self.state);
                match class {
                    Class::Unicode(class) => {
                        for range in class.ranges() {
                            range.start().hash(self.state);
                            range.end().hash(self.state);
                        }
                    }
                    Class::Bytes(class) => {
                        for range in class.ranges() {
                            range.start().hash(self.state);
                            range.end().hash(self.state);
                        }
                    }
                }
            }
            HirKind::Repetition(rep) => {
                rep.min.hash(self.state);
                rep.max.hash(self.state);
                rep.greedy.hash(self.state);
            }
            // For these kinds the discriminant hashed above is all that is
            // fed to the hasher at this node.
            HirKind::Empty => {}
            HirKind::Look(_) => {}
            HirKind::Capture(_) => {}
            HirKind::Concat(_) => {}
            HirKind::Alternation(_) => {}
        }
        Ok(())
    }
}
/// Returns true if `hir_kind` is a class whose first range spans the whole
/// domain: `0x00..=0xFF` for byte classes, `'\0'..=char::MAX` for Unicode
/// classes. Any other kind, and any class without ranges, yields false.
pub(crate) fn any_byte(hir_kind: &HirKind) -> bool {
    match hir_kind {
        HirKind::Class(Class::Bytes(class)) => matches!(
            class.ranges().first(),
            Some(range) if range.start() == 0 && range.end() == u8::MAX
        ),
        HirKind::Class(Class::Unicode(class)) => matches!(
            class.ranges().first(),
            Some(range) if range.start() == '\0' && range.end() == char::MAX
        ),
        _ => false,
    }
}
/// Returns true if `hir_kind` is a class that contains everything except the
/// newline character (`0x0A`), that is, a class made of exactly the two
/// ranges `0x00..=0x09` and `0x0B..=MAX`, where `MAX` is `0xFF` for byte
/// classes and `char::MAX` for Unicode classes.
pub(crate) fn any_byte_except_newline(hir_kind: &HirKind) -> bool {
    match hir_kind {
        HirKind::Class(Class::Bytes(class)) => match class.ranges() {
            [below, above] => {
                below.start() == 0x00
                    && below.end() == 0x09
                    && above.start() == 0x0B
                    && above.end() == 0xFF
            }
            _ => false,
        },
        HirKind::Class(Class::Unicode(class)) => match class.ranges() {
            [below, above] => {
                below.start() == '\u{0}'
                    && below.end() == '\u{9}'
                    && above.start() == '\u{B}'
                    && above.end() == char::MAX
            }
            _ => false,
        },
        _ => false,
    }
}
/// Tries to express a class of bytes as a single masked byte.
///
/// Returns `Some(HexByte { value, mask })` when the class contains exactly
/// the bytes `b` for which `b & mask == value`, and `None` otherwise
/// (including when the class is empty).
///
/// For instance, a class with the 16 bytes `0x30..=0x3F` is equivalent to
/// the masked byte with value `0x30` and mask `0xF0`.
pub(crate) fn class_to_masked_byte(c: &ClassBytes) -> Option<HexByte> {
    // Ranges are sorted, so the smallest byte starts the first range and the
    // largest byte ends the last one. `?` handles the empty class.
    let smallest_byte = c.ranges().first()?.start();
    let largest_byte = c.ranges().last()?.end();

    // In a masked byte the smallest member has every wildcard bit set to 0
    // and the largest one has them all set to 1, so XORing both yields the
    // wildcard bits.
    let neg_mask = largest_byte ^ smallest_byte;
    let mask = !neg_mask;

    let mut num_bytes: u32 = 0;

    for range in c.ranges() {
        // Every byte must agree with the smallest one on all the bits that
        // are not wildcards. Checking only that the 1-bits of the smallest
        // byte are present is not enough: the class {00, 01, 02, 06} would
        // pass that check and the size check below, yet it is not the
        // masked byte value=00 mask=F9, which stands for {00, 02, 04, 06}.
        if (range.start()..=range.end()).any(|b| b & mask != smallest_byte) {
            return None;
        }
        // A range holds at most 256 bytes, so the cast cannot truncate.
        num_bytes += range.len() as u32;
    }

    // The masked byte stands for exactly 2^N bytes, where N is the number of
    // wildcard bits. Since every byte in the class already matches the
    // mask, having that many bytes means the class is exactly that set.
    if 1_u32 << neg_mask.count_ones() != num_bytes {
        return None;
    }

    Some(HexByte { value: smallest_byte, mask })
}
/// Tries to express a class of bytes as an alternation of masked bytes, one
/// per range in the class.
///
/// Returns `None` if the class is empty or if any of its ranges cannot be
/// expressed as a masked byte. A range can be expressed as a masked byte
/// when every bit set in its start is also set in its end and it contains
/// exactly 2^N bytes, where N is the number of bits in which start and end
/// differ.
pub(crate) fn class_to_masked_bytes_alternation(
    c: &ClassBytes,
) -> Option<Vec<HexByte>> {
    if c.ranges().is_empty() {
        return None;
    }

    c.ranges()
        .iter()
        .map(|range| {
            let (start, end) = (range.start(), range.end());

            // The start of the range has the wildcard bits set to 0, so all
            // of its 1-bits must be present in the end of the range as well.
            if start & end != start {
                return None;
            }

            let neg_mask = start ^ end;

            // The comparison is done in `usize` on purpose. With `u8`
            // arithmetic the range 0x00..=0xFF overflows both sides
            // (`255 + 1` and `1 << 8`), which panics in debug builds and
            // produces a wrong answer in release builds.
            if 1_usize << neg_mask.count_ones() != range.len() {
                return None;
            }

            Some(HexByte { value: start, mask: !neg_mask })
        })
        .collect()
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::{ChainedPatternGap, Hir};
use crate::re::hir::ChainedPattern;
// Exercises `Hir::split_at_large_gaps` with patterns built by hand.
#[test]
fn split() {
// A lone literal is returned untouched, with no chained patterns.
assert_eq!(
Hir::literal([0x01, 0x02, 0x03]).split_at_large_gaps(),
(Hir::literal([0x01, 0x02, 0x03]), vec![])
);
// Two literals and no gap at all: nothing to split.
assert_eq!(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::literal([0x06, 0x07]),
])
.split_at_large_gaps(),
(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::literal([0x06, 0x07])
]),
vec![]
)
);
// The gap spans exactly `PATTERN_CHAINING_THRESHOLD` repetitions, which
// is not strictly greater than the threshold, so the pattern is not
// split.
assert_eq!(
Hir::concat(vec![
Hir::literal([0x01]),
Hir::any_byte_repetition(
0,
Some(Hir::PATTERN_CHAINING_THRESHOLD),
false
),
Hir::literal([0x02, 0x03]),
])
.split_at_large_gaps(),
(
Hir::concat(vec![
Hir::literal([0x01]),
Hir::any_byte_repetition(
0,
Some(Hir::PATTERN_CHAINING_THRESHOLD),
false
),
Hir::literal([0x02, 0x03]),
]),
vec![]
)
);
// Two large gaps. The pattern is split at the first one, but the piece
// between both gaps is a single byte, shorter than
// `MIN_PATTERN_LENGTH_IN_CHAIN`, so it is kept together with the second
// gap and the literal that follows it.
assert_eq!(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::any_byte_repetition(0, None, false),
Hir::literal([0x05]),
Hir::any_byte_repetition(
10,
Some(11 + Hir::PATTERN_CHAINING_THRESHOLD),
false
),
Hir::literal([0x06, 0x07]),
])
.split_at_large_gaps(),
(
Hir::literal([0x01, 0x02, 0x03]),
vec![ChainedPattern {
gap: ChainedPatternGap::Unbounded(0..),
hir: Hir::concat(vec![
Hir::literal([0x05]),
Hir::any_byte_repetition(
10,
Some(11 + Hir::PATTERN_CHAINING_THRESHOLD),
false
),
Hir::literal([0x06, 0x07])
])
}]
)
);
// The piece after the gap is a single byte, too short to stand on its
// own, so it is merged back and the original pattern is returned with no
// chained patterns.
assert_eq!(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::any_byte_repetition(0, None, false),
Hir::literal([0x05]),
])
.split_at_large_gaps(),
(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::any_byte_repetition(0, None, false),
Hir::literal([0x05]),
]),
vec![]
)
);
// Both pieces are long enough: the pattern is split in two and the
// second piece carries the unbounded gap.
assert_eq!(
Hir::concat(vec![
Hir::literal([0x01, 0x02, 0x03]),
Hir::any_byte_repetition(0, None, true),
Hir::literal([0x04, 0x05]),
])
.split_at_large_gaps(),
(
Hir::literal([0x01, 0x02, 0x03]),
vec![ChainedPattern {
gap: ChainedPatternGap::Unbounded(0..),
hir: Hir::literal([0x04, 0x05])
},]
)
);
// A gap at the very start of the pattern is never a split point.
assert_eq!(
Hir::concat(vec![
Hir::any_byte_repetition(0, None, true),
Hir::literal([0x04, 0x05]),
])
.split_at_large_gaps(),
(
Hir::concat(vec![
Hir::any_byte_repetition(0, None, true),
Hir::literal([0x04, 0x05]),
]),
vec![]
)
);
}
}