use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::rc::Rc;
#[cfg(feature = "logging")]
use std::time::Instant;
use std::{env, fmt, fs, io, iter};
use bitflags::bitflags;
use bstr::{BStr, ByteSlice};
use itertools::{Itertools, MinMaxResult, izip};
#[cfg(feature = "logging")]
use log::*;
use regex_syntax::hir;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use walrus::FunctionId;
use yara_x_parser::ast;
use yara_x_parser::ast::{AST, Ident, Import, Include, RuleFlags, WithSpan};
use yara_x_parser::cst::CSTStream;
use yara_x_parser::{Parser, Span};
use crate::compiler::base64::base64_patterns;
use crate::compiler::emit::{EmitContext, emit_rule_condition};
use crate::compiler::errors::{
CompileError, ConflictingRuleIdentifier, CustomError, DuplicateRule,
DuplicateTag, EmitWasmError, InvalidRegexp, InvalidUTF8, UnknownModule,
UnusedPattern,
};
use crate::compiler::report::ReportBuilder;
use crate::compiler::{CompileContext, VarStack};
use crate::modules::BUILTIN_MODULES;
use crate::re::hir::{ChainedPattern, ChainedPatternGap};
use crate::string_pool::{BStringPool, StringPool};
use crate::symbols::{StackedSymbolTable, Symbol, SymbolLookup, SymbolTable};
use crate::types::{Func, Struct, TypeValue};
use crate::utils::cast;
use crate::variables::{Variable, VariableError, is_valid_identifier};
use crate::wasm::builder::WasmModuleBuilder;
use crate::wasm::{WasmSymbols, wasm_exports};
use crate::{re, wasm};
pub(crate) use crate::compiler::atoms::*;
pub(crate) use crate::compiler::context::*;
pub(crate) use crate::compiler::ir::*;
use crate::compiler::wsh::WarningSuppressionHook;
use crate::errors::{
CircularIncludes, IncludeError, IncludeNotAllowed, IncludeNotFound,
InvalidWarningCode,
};
use crate::linters::LinterResult;
use crate::models::PatternKind;
#[doc(inline)]
pub use crate::compiler::report::Patch;
#[doc(inline)]
pub use crate::compiler::rules::*;
#[doc(inline)]
pub use crate::compiler::warnings::*;
mod atoms;
mod context;
mod emit;
mod ir;
mod report;
mod rules;
#[cfg(test)]
mod tests;
pub mod base64;
pub mod errors;
pub mod linters;
pub mod warnings;
pub mod wsh;
/// A piece of YARA source code to be compiled, together with optional
/// metadata about where it came from.
#[derive(Debug, Clone)]
pub struct SourceCode<'src> {
    /// Raw bytes of the source; not guaranteed to be valid UTF-8.
    pub(crate) raw: &'src BStr,
    /// Lazily-populated UTF-8 view of `raw`; `Some` once validation succeeds.
    pub(crate) valid: Option<&'src str>,
    /// Origin of the source (e.g. a file path), used in error reports.
    pub(crate) origin: Option<String>,
}
impl<'src> SourceCode<'src> {
    /// Attaches an origin (typically a file path) to this piece of
    /// source code; the origin shows up in error reports.
    pub fn with_origin<S: Into<String>>(self, origin: S) -> Self {
        Self { origin: Some(origin.into()), ..self }
    }

    /// Returns the source as a UTF-8 string, validating the raw bytes
    /// on first use and caching the result of that validation.
    fn as_str(&mut self) -> Result<&'src str, bstr::Utf8Error> {
        if let Some(validated) = self.valid {
            return Ok(validated);
        }
        let validated = self.raw.to_str()?;
        self.valid = Some(validated);
        Ok(validated)
    }
}
impl<'src> From<&'src str> for SourceCode<'src> {
fn from(src: &'src str) -> Self {
Self { raw: BStr::new(src), valid: Some(src), origin: None }
}
}
impl<'src> From<&'src [u8]> for SourceCode<'src> {
fn from(src: &'src [u8]) -> Self {
Self { raw: BStr::new(src), valid: None, origin: None }
}
}
/// Single-call convenience wrapper: compiles `src` with a freshly
/// created [`Compiler`] (default settings) and returns the resulting
/// [`Rules`].
pub fn compile<'src, S>(src: S) -> Result<Rules, CompileError>
where
    S: Into<SourceCode<'src>>,
{
    let mut c = Compiler::new();
    c.add_source(src)?;
    let rules = c.build();
    Ok(rules)
}
/// A namespace within the compiler: rules are grouped into namespaces,
/// and each namespace has its own symbol table.
struct Namespace {
    /// Sequential numeric id of the namespace.
    id: NamespaceId,
    /// Interned name of the namespace.
    ident_id: IdentId,
    /// Symbols (e.g. rules, imported modules) defined in this namespace.
    symbols: Rc<RefCell<SymbolTable>>,
}
/// Compiles YARA sources into a set of compiled [`Rules`].
pub struct Compiler<'a> {
    /// See [`Compiler::relaxed_re_syntax`].
    relaxed_re_syntax: bool,
    /// When true, condition hoisting is applied (see [`Compiler::hoisting`]).
    hoisting: bool,
    /// Directories searched while resolving `include` statements
    /// (see [`Compiler::add_include_dir`]).
    include_dirs: Option<Vec<PathBuf>>,
    /// See [`Compiler::error_on_slow_pattern`].
    error_on_slow_pattern: bool,
    /// See [`Compiler::error_on_slow_loop`].
    error_on_slow_loop: bool,
    /// Whether `include` statements are allowed (see [`Compiler::enable_includes`]).
    includes_enabled: bool,
    /// Files currently being included; used for detecting circular
    /// includes and for resolving includes relative to the includer.
    include_stack: Vec<PathBuf>,
    /// Builds error and warning reports.
    report_builder: ReportBuilder,
    /// Stack of symbol tables: global symbols at the bottom, the current
    /// namespace's symbols on top.
    symbol_table: StackedSymbolTable,
    /// The symbol table holding global symbols (built-in functions and
    /// variables defined with [`Compiler::define_global`]).
    global_symbols: Rc<RefCell<SymbolTable>>,
    /// The namespace new rules are currently added to.
    current_namespace: Namespace,
    /// Pool of interned identifiers.
    ident_pool: StringPool<IdentId>,
    /// Pool of interned regexp sources.
    regexp_pool: StringPool<RegexpId>,
    /// Pool of interned byte-string literals.
    lit_pool: BStringPool<LiteralId>,
    /// Intermediate representation of rule conditions.
    ir: IR,
    /// Builder for the WASM module that holds compiled conditions.
    wasm_mod: WasmModuleBuilder,
    /// Symbols shared with the WASM module builder.
    wasm_symbols: WasmSymbols,
    /// Functions exported by the WASM module, keyed by name.
    wasm_exports: FxHashMap<String, FunctionId>,
    /// Per-pattern file-size bounds derived from rule conditions.
    filesize_bounds: FxHashMap<PatternId, FilesizeBounds>,
    /// Information about every compiled rule.
    rules: Vec<RuleInfo>,
    /// Id that will be assigned to the next unique pattern.
    next_pattern_id: PatternId,
    /// Deduplication map: identical patterns share one `PatternId`.
    patterns: FxHashMap<Pattern, PatternId>,
    /// All sub-patterns, paired with their owning pattern's id.
    sub_patterns: Vec<(PatternId, SubPattern)>,
    /// Sub-patterns anchored at a fixed offset; these get no atoms
    /// (see `add_sub_pattern`).
    anchored_sub_patterns: Vec<SubPatternId>,
    /// Atoms extracted from sub-patterns; input for the Aho-Corasick
    /// automaton built in [`Compiler::build`].
    atoms: Vec<SubPatternAtom>,
    /// Compiled code for regexp sub-patterns.
    re_code: Vec<u8>,
    /// Identifiers of the imported modules.
    imported_modules: Vec<IdentId>,
    /// Modules that produce a warning instead of an error when unknown
    /// (see [`Compiler::ignore_module`]).
    ignored_modules: FxHashSet<String>,
    /// Modules whose import is an error; maps name to (title, message)
    /// (see [`Compiler::ban_module`]).
    banned_modules: FxHashMap<String, (String, String)>,
    /// Rules ignored because they depend on an ignored module; maps
    /// rule name to the offending module's name.
    ignored_rules: FxHashMap<String, String>,
    /// Root structure holding modules and global variables.
    root_struct: Struct,
    /// Warnings accumulated during compilation.
    warnings: Warnings,
    /// Errors accumulated during compilation.
    errors: Vec<CompileError>,
    /// Features enabled with [`Compiler::enable_feature`].
    features: FxHashSet<String>,
    /// Optional sink where each rule's IR is dumped
    /// (see [`Compiler::set_ir_writer`]).
    ir_writer: Option<Box<dyn Write>>,
    /// Linters run against every rule (see [`Compiler::add_linter`]).
    linters: Vec<Box<dyn linters::Linter + 'a>>,
}
impl<'a> Compiler<'a> {
/// Creates a new compiler with default settings.
pub fn new() -> Self {
    let mut ident_pool = StringPool::new();
    let mut symbol_table = StackedSymbolTable::new();
    // The bottom of the symbol table stack holds global symbols,
    // shared by every namespace.
    let global_symbols = symbol_table.push_new();
    // Expose every public built-in WASM export as a global function
    // symbol, so conditions can call them by name.
    for export in wasm_exports()
        .filter(|e| e.public && e.builtin())
    {
        let func = Rc::new(Func::from(export.mangled_name));
        let symbol = Symbol::Func(func);
        global_symbols.borrow_mut().insert(export.name, symbol);
    }
    // Until `new_namespace` is called, rules go into the "default"
    // namespace, whose symbol table sits on top of the global one.
    let default_namespace = Namespace {
        id: NamespaceId(0),
        ident_id: ident_pool.get_or_intern("default"),
        symbols: symbol_table.push_new(),
    };
    let mut wasm_mod = WasmModuleBuilder::new();
    wasm_mod.namespaces_per_func(20);
    wasm_mod.rules_per_func(10);
    let wasm_symbols = wasm_mod.wasm_symbols();
    let wasm_exports = wasm_mod.wasm_exports();
    let mut ir = IR::new();
    // Constant folding in the IR is controlled by a cargo feature.
    if cfg!(feature = "constant-folding") {
        ir.constant_folding(true);
    }
    Self {
        ir,
        ident_pool,
        global_symbols,
        symbol_table,
        wasm_mod,
        wasm_symbols,
        wasm_exports,
        relaxed_re_syntax: false,
        hoisting: false,
        error_on_slow_pattern: false,
        error_on_slow_loop: false,
        next_pattern_id: PatternId(0),
        current_namespace: default_namespace,
        features: FxHashSet::default(),
        warnings: Warnings::default(),
        errors: Vec::new(),
        rules: Vec::new(),
        sub_patterns: Vec::new(),
        anchored_sub_patterns: Vec::new(),
        atoms: Vec::new(),
        re_code: Vec::new(),
        imported_modules: Vec::new(),
        ignored_modules: FxHashSet::default(),
        banned_modules: FxHashMap::default(),
        ignored_rules: FxHashMap::default(),
        filesize_bounds: FxHashMap::default(),
        root_struct: Struct::new().make_root(),
        report_builder: ReportBuilder::new(),
        lit_pool: BStringPool::new(),
        regexp_pool: StringPool::new(),
        patterns: FxHashMap::default(),
        ir_writer: None,
        linters: Vec::new(),
        include_dirs: None,
        includes_enabled: true,
        include_stack: Vec::new(),
    }
}
/// Adds a directory to the list searched when resolving `include`
/// statements. Directories are tried in the order they were added.
pub fn add_include_dir<P: AsRef<std::path::Path>>(
    &mut self,
    dir: P,
) -> &mut Self {
    let dirs = self.include_dirs.get_or_insert_default();
    dirs.push(dir.as_ref().to_path_buf());
    self
}
/// Adds a piece of YARA source code to be compiled.
///
/// All errors found while compiling are accumulated in `self.errors`;
/// the returned `Err` is the first error produced by this call.
pub fn add_source<'src, S>(
    &mut self,
    src: S,
) -> Result<&mut Self, CompileError>
where
    S: Into<SourceCode<'src>>,
{
    let mut src = src.into();
    // Register the source so error reports can quote it.
    self.report_builder.register_source(&src);
    let ast = match src.as_str() {
        Ok(src) => {
            let cst = Parser::new(src.as_bytes());
            // Intercept parser warnings so that suppressed ones are
            // dropped before reaching the user.
            let cst =
                WarningSuppressionHook::from(cst).hook(|warning, span| {
                    self.warnings.suppress(warning, span);
                });
            AST::from(CSTStream::new(src.as_bytes(), cst))
        }
        Err(err) => {
            // The source is not valid UTF-8: build a span covering the
            // invalid byte sequence and bail out.
            let span_start = err.valid_up_to();
            // NOTE(review): the span end is widened with
            // `next_multiple_of(3)`; presumably this improves the
            // highlighting of multi-byte sequences — confirm.
            let span_end = if let Some(error_len) = err.error_len() {
                span_start + error_len.next_multiple_of(3)
            } else {
                span_start
            };
            let err = InvalidUTF8::build(
                &self.report_builder,
                self.report_builder.span_to_code_loc(Span(
                    span_start as u32..span_end as u32,
                )),
            );
            self.errors.push(err.clone());
            return Err(err);
        }
    };
    // Remember how many errors existed before compiling this source,
    // so we can tell whether this call added any.
    let existing_errors = self.errors.len();
    self.c_items(ast.items());
    self.warnings.clear_suppressed();
    // Parser errors are appended after compilation errors.
    self.errors.extend(
        ast.into_errors()
            .into_iter()
            .map(|err| CompileError::from(&self.report_builder, err)),
    );
    // Return the first error introduced by this source, if any.
    if self.errors.len() > existing_errors {
        return Err(self.errors[existing_errors].clone());
    }
    Ok(self)
}
/// Defines a global variable visible to all rules.
///
/// Returns an error if the identifier is not valid or if a variable
/// (or module) with the same name already exists.
pub fn define_global<T: TryInto<Variable>>(
    &mut self,
    ident: &str,
    value: T,
) -> Result<&mut Self, VariableError>
where
    VariableError: From<<T as TryInto<Variable>>::Error>,
{
    // Reject identifiers that are not syntactically valid.
    if !is_valid_identifier(ident) {
        return Err(VariableError::InvalidIdentifier(ident.to_string()));
    }
    let variable: Variable = value.try_into()?;
    // Adding the field fails (returns `Some`) when it already exists.
    if self.root_struct.add_field(ident, variable.into()).is_some() {
        return Err(VariableError::AlreadyExists(ident.to_string()));
    }
    let symbol = self.root_struct.lookup(ident).unwrap();
    self.global_symbols.borrow_mut().insert(ident, symbol);
    Ok(self)
}
/// Starts a new namespace; rules added from this point on belong to it.
/// Switching to the namespace that is already active is a no-op.
pub fn new_namespace(&mut self, namespace: &str) -> &mut Self {
    let current = self
        .ident_pool
        .get(self.current_namespace.ident_id)
        .expect("expecting a namespace");
    if current == namespace {
        return self;
    }
    // Drop the previous namespace's symbol table and push a fresh one.
    self.symbol_table.pop().expect("expecting a namespace");
    let next_id = NamespaceId(self.current_namespace.id.0 + 1);
    self.current_namespace = Namespace {
        id: next_id,
        ident_id: self.ident_pool.get_or_intern(namespace),
        symbols: self.symbol_table.push_new(),
    };
    // Ignored rules are tracked per namespace.
    self.ignored_rules.clear();
    self.wasm_mod.new_namespace();
    self
}
/// Consumes the compiler and produces the final set of compiled
/// [`Rules`].
pub fn build(self) -> Rules {
    // Emit the WASM module containing all rule conditions and compile
    // it with the runtime engine.
    let wasm_mod = self.wasm_mod.build().emit_wasm();
    #[cfg(feature = "logging")]
    let start = Instant::now();
    let compiled_wasm_mod = wasm::runtime::Module::from_binary(
        wasm::get_engine(),
        wasm_mod.as_slice(),
    )
    .expect("WASM module is not valid");
    #[cfg(feature = "logging")]
    info!("WASM module build time: {:?}", Instant::elapsed(&start));
    // Serialize the root structure that holds the global variables.
    let serialized_globals = bincode::serde::encode_to_vec(
        &self.root_struct,
        bincode::config::standard().with_variable_int_encoding(),
    )
    .expect("failed to serialize global variables");
    let mut rules = Rules {
        serialized_globals,
        wasm_mod,
        compiled_wasm_mod: Some(compiled_wasm_mod),
        relaxed_re_syntax: self.relaxed_re_syntax,
        // The Aho-Corasick automaton is built below.
        ac: None,
        num_patterns: self.next_pattern_id.0 as usize,
        ident_pool: self.ident_pool,
        regexp_pool: self.regexp_pool,
        lit_pool: self.lit_pool,
        imported_modules: self.imported_modules,
        rules: self.rules,
        sub_patterns: self.sub_patterns,
        anchored_sub_patterns: self.anchored_sub_patterns,
        atoms: self.atoms,
        re_code: self.re_code,
        warnings: self.warnings.into(),
        filesize_bounds: self.filesize_bounds,
    };
    // Build the Aho-Corasick automaton from the collected atoms.
    rules.build_ac_automaton();
    rules
}
/// Registers a linter that will be run against every rule.
pub fn add_linter<L: linters::Linter + 'a>(
    &mut self,
    linter: L,
) -> &mut Self {
    let boxed: Box<dyn linters::Linter + 'a> = Box::new(linter);
    self.linters.push(boxed);
    self
}
/// Enables a named compiler feature.
#[doc(hidden)]
pub fn enable_feature<F: Into<String>>(
    &mut self,
    feature: F,
) -> &mut Self {
    let feature = feature.into();
    self.features.insert(feature);
    self
}
/// Marks a module as ignored: importing it raises a warning instead of
/// an "unknown module" error, and rules depending on it are skipped.
pub fn ignore_module<M: Into<String>>(&mut self, module: M) -> &mut Self {
    let name = module.into();
    self.ignored_modules.insert(name);
    self
}
/// Bans a module: importing it produces a custom error built from the
/// given title and message.
pub fn ban_module<M: Into<String>, T: Into<String>, E: Into<String>>(
    &mut self,
    module: M,
    error_title: T,
    error_message: E,
) -> &mut Self {
    let details = (error_title.into(), error_message.into());
    self.banned_modules.insert(module.into(), details);
    self
}
/// Enables or disables colors in error and warning reports.
pub fn colorize_errors(&mut self, yes: bool) -> &mut Self {
    self.report_builder.with_colors(yes);
    self
}
/// Sets the maximum width (in characters) of error and warning reports.
pub fn errors_max_width(&mut self, width: usize) -> &mut Self {
    self.report_builder.max_width(width);
    self
}
/// Enables or disables the warning identified by `code`.
///
/// Returns an error if `code` is not a known warning code.
pub fn switch_warning(
    &mut self,
    code: &str,
    enabled: bool,
) -> Result<&mut Self, InvalidWarningCode> {
    match self.warnings.switch_warning(code, enabled) {
        Ok(_) => Ok(self),
        Err(err) => Err(err.into()),
    }
}
/// Enables or disables all warnings at once.
pub fn switch_all_warnings(&mut self, enabled: bool) -> &mut Self {
    self.warnings.switch_all_warnings(enabled);
    self
}
/// Enables a more relaxed regexp syntax.
///
/// # Panics
///
/// Panics if called after rules have already been added, because the
/// setting affects how patterns are compiled.
pub fn relaxed_re_syntax(&mut self, yes: bool) -> &mut Self {
    assert!(
        self.rules.is_empty(),
        "calling relaxed_re_syntax in non-empty compiler"
    );
    self.relaxed_re_syntax = yes;
    self
}
/// When enabled, patterns considered slow produce an error instead of
/// a warning.
pub fn error_on_slow_pattern(&mut self, yes: bool) -> &mut Self {
    self.error_on_slow_pattern = yes;
    self
}
/// When enabled, loops considered slow produce an error instead of a
/// warning.
pub fn error_on_slow_loop(&mut self, yes: bool) -> &mut Self {
    self.error_on_slow_loop = yes;
    self
}
/// Enables or disables `include` statements; when disabled, any
/// `include` produces an error.
pub fn enable_includes(&mut self, yes: bool) -> &mut Self {
    self.includes_enabled = yes;
    self
}
/// Public (but hidden) alias for [`Compiler::hoisting`].
#[doc(hidden)]
pub fn condition_optimization(&mut self, yes: bool) -> &mut Self {
    self.hoisting(yes)
}
/// Enables or disables hoisting of rule conditions (applied by
/// `c_rule` via `IR::hoisting`).
pub(crate) fn hoisting(&mut self, yes: bool) -> &mut Self {
    self.hoisting = yes;
    self
}
/// Returns the errors accumulated so far.
#[inline]
pub fn errors(&self) -> &[CompileError] {
    &self.errors
}
/// Returns the warnings accumulated so far.
#[inline]
pub fn warnings(&self) -> &[Warning] {
    self.warnings.as_slice()
}
/// Consumes the compiler and writes the generated WASM module to the
/// file at `path`.
pub fn emit_wasm_file<P>(self, path: P) -> Result<(), EmitWasmError>
where
    P: AsRef<Path>,
{
    let mut module = self.wasm_mod.build();
    module.emit_wasm_file(path)?;
    Ok(())
}
/// Sets a writer where the IR of each compiled rule is dumped, for
/// debugging purposes.
#[doc(hidden)]
pub fn set_ir_writer<W: Write + 'static>(&mut self, w: W) -> &mut Self {
    self.ir_writer = Some(Box::new(w));
    self
}
}
impl Compiler<'_> {
/// Registers a sub-pattern belonging to `pattern_id` and returns the
/// new sub-pattern's id.
///
/// `atoms` are the atoms extracted from the sub-pattern and `f`
/// converts each of them into a [`SubPatternAtom`]. Anchored literal
/// sub-patterns are tracked in `anchored_sub_patterns` and get NO
/// atoms — presumably they are verified directly at their anchor
/// offset rather than found via atom scanning (confirm in scanner).
fn add_sub_pattern<I, F, A>(
    &mut self,
    pattern_id: PatternId,
    sub_pattern: SubPattern,
    atoms: I,
    f: F,
) -> SubPatternId
where
    I: Iterator<Item = A>,
    F: Fn(SubPatternId, A) -> SubPatternAtom,
{
    // The id of the new sub-pattern is its index in `sub_patterns`.
    let sub_pattern_id = SubPatternId(self.sub_patterns.len() as u32);
    if let SubPattern::Literal { anchored_at: Some(_), .. } = sub_pattern {
        self.anchored_sub_patterns.push(sub_pattern_id);
    } else {
        self.atoms.extend(atoms.map(|atom| f(sub_pattern_id, atom)));
    }
    self.sub_patterns.push((pattern_id, sub_pattern));
    sub_pattern_id
}
/// Makes sure `ident` does not collide with an existing symbol.
///
/// Returns a `DuplicateRule` error when the identifier names an
/// existing rule, or `ConflictingRuleIdentifier` for any other kind of
/// existing symbol.
fn check_for_existing_identifier(
    &self,
    ident: &Ident,
) -> Result<(), CompileError> {
    let Some(symbol) = self.symbol_table.lookup(ident.name) else {
        // No symbol with that name: nothing to complain about.
        return Ok(());
    };
    match symbol {
        Symbol::Rule { rule_id, .. } => {
            // Point the error at the previous rule's definition.
            let existing = self
                .rules
                .get(rule_id.0 as usize)
                .unwrap()
                .ident_ref
                .clone();
            Err(DuplicateRule::build(
                &self.report_builder,
                ident.name.to_string(),
                self.report_builder.span_to_code_loc(ident.span()),
                existing,
            ))
        }
        _ => Err(ConflictingRuleIdentifier::build(
            &self.report_builder,
            ident.name.to_string(),
            self.report_builder.span_to_code_loc(ident.span()),
        )),
    }
}
/// Makes sure a rule's tag list contains no duplicate names.
///
/// Returns a `DuplicateTag` error pointing at the first repeated tag.
fn check_for_duplicate_tags(
    &self,
    tags: &[Ident],
) -> Result<(), CompileError> {
    // Pre-size the set: each tag is inserted at most once, so this
    // avoids rehashing while scanning the list.
    let mut seen = HashSet::with_capacity(tags.len());
    for tag in tags {
        // `insert` returns false when the name was already present.
        if !seen.insert(tag.name) {
            return Err(DuplicateTag::build(
                &self.report_builder,
                tag.name.to_string(),
                self.report_builder.span_to_code_loc(tag.span()),
            ));
        }
    }
    Ok(())
}
/// Interns a literal in the literal pool and returns its id. When
/// `wide` is true, the wide (zero-interleaved) form of the literal is
/// interned instead of the literal itself.
fn intern_literal(&mut self, literal: &[u8], wide: bool) -> LiteralId {
    if wide {
        let widened = make_wide(literal);
        self.lit_pool.get_or_intern(widened.as_bytes())
    } else {
        self.lit_pool.get_or_intern(literal)
    }
}
/// Captures the current length of every growable structure, so that
/// `restore_snapshot` can roll back a partially-compiled rule.
fn take_snapshot(&self) -> Snapshot {
    Snapshot {
        rules_len: self.rules.len(),
        next_pattern_id: self.next_pattern_id,
        sub_patterns_len: self.sub_patterns.len(),
        atoms_len: self.atoms.len(),
        re_code_len: self.re_code.len(),
        symbol_table_len: self.symbol_table.len(),
    }
}
/// Rolls the compiler state back to a previously taken [`Snapshot`],
/// discarding everything added since then. Used when a rule fails to
/// compile after having already modified the compiler's state.
fn restore_snapshot(&mut self, snapshot: Snapshot) {
    self.next_pattern_id = snapshot.next_pattern_id;
    self.rules.truncate(snapshot.rules_len);
    self.sub_patterns.truncate(snapshot.sub_patterns_len);
    self.re_code.truncate(snapshot.re_code_len);
    self.atoms.truncate(snapshot.atoms_len);
    self.symbol_table.truncate(snapshot.symbol_table_len);
    // These maps are keyed/valued by `PatternId`; ids created after the
    // snapshot are exactly those >= the saved `next_pattern_id`.
    self.patterns
        .retain(|_, pattern_id| *pattern_id < snapshot.next_pattern_id);
    self.filesize_bounds
        .retain(|pattern_id, _| *pattern_id < snapshot.next_pattern_id);
}
/// Returns true if `bytes` is a (possibly empty) repetition of a single
/// byte that is commonly repeated in binary files: 0x00, 0x90 or 0xff.
/// Any other content returns false.
fn common_byte_repetition(bytes: &[u8]) -> bool {
    match bytes.first() {
        // The empty slice is trivially a repetition.
        None => true,
        Some(&b) if b == 0x00 || b == 0x90 || b == 0xff => {
            // Every remaining byte must equal the first one.
            bytes.iter().all(|&x| x == b)
        }
        Some(_) => false,
    }
}
/// Reads the file referenced by an `include` statement.
///
/// Lookup order: first relative to the directory of the file currently
/// being included (if any); then through the configured include
/// directories; or — when no include directories were configured —
/// the path is resolved as-is (relative to the current directory).
fn read_included_file(
    &mut self,
    include: &Include,
) -> Result<(Vec<u8>, PathBuf), CompileError> {
    let read_file =
        |path: PathBuf| -> Result<(Vec<u8>, PathBuf), io::Error> {
            let mut path = path.canonicalize()?;
            let content = fs::read(&path)?;
            // Prefer reporting a path relative to the current
            // directory, for shorter messages.
            if let Ok(cwd) =
                env::current_dir().and_then(|dir| dir.canonicalize())
                && let Ok(relative_path) = path.strip_prefix(cwd)
            {
                path = relative_path.to_path_buf();
            }
            Ok((content, path))
        };
    // 1) Try the directory containing the file that is doing the
    // include, if we are currently inside an included file.
    if let Some(dir) =
        self.include_stack.last().and_then(|path| path.parent())
        && let Ok(result) = read_file(dir.join(include.file_name))
    {
        return Ok(result);
    }
    // 2) Try each configured include directory, in order.
    if let Some(include_dirs) = &self.include_dirs {
        if let Some(result) = include_dirs
            .iter()
            .find_map(|dir| read_file(dir.join(include.file_name)).ok())
        {
            Ok(result)
        } else {
            Err(IncludeNotFound::build(
                &self.report_builder,
                include.file_name.to_string(),
                self.report_builder.span_to_code_loc(include.span()),
            ))
        }
    } else {
        // 3) No include directories configured: resolve as-is.
        read_file(PathBuf::from(include.file_name)).map_err(|err| {
            if err.kind() == io::ErrorKind::NotFound {
                IncludeNotFound::build(
                    &self.report_builder,
                    include.file_name.to_string(),
                    self.report_builder.span_to_code_loc(include.span()),
                )
            } else {
                IncludeError::build(
                    &self.report_builder,
                    self.report_builder.span_to_code_loc(include.span()),
                    err.to_string(),
                )
            }
        })
    }
}
}
impl Compiler<'_> {
/// Compiles the top-level items of a source file: imports, includes
/// and rules. Errors are accumulated in `self.errors` instead of
/// aborting on the first one.
fn c_items<'a, I>(&mut self, items: I)
where
    I: Iterator<Item = &'a ast::Item<'a>>,
{
    // Modules already imported by this file, used for detecting
    // duplicate imports.
    let mut already_imported = FxHashMap::default();
    for item in items {
        match item {
            ast::Item::Import(import) => {
                // Warn about duplicate imports and attach a patch that
                // removes the redundant statement.
                if let Some(existing_import) = already_imported.insert(
                    &import.module_name,
                    self.report_builder.span_to_code_loc(import.span()),
                ) {
                    let duplicated_import = self
                        .report_builder
                        .span_to_code_loc(import.span());
                    let mut warning = warnings::DuplicateImport::build(
                        &self.report_builder,
                        import.module_name.to_string(),
                        duplicated_import.clone(),
                        existing_import,
                    );
                    warning.report_mut().patch(duplicated_import, "");
                    self.warnings.add(|| warning)
                }
                if let Err(err) = self.c_import(import) {
                    self.errors.push(err);
                }
            }
            ast::Item::Include(include) => {
                if !self.includes_enabled {
                    self.errors.push(IncludeNotAllowed::build(
                        &self.report_builder,
                        self.report_builder
                            .span_to_code_loc(include.span()),
                    ));
                    continue;
                }
                let (included_src, included_path) =
                    match self.read_included_file(include) {
                        Ok(included) => included,
                        Err(err) => {
                            self.errors.push(err);
                            continue;
                        }
                    };
                // A file already on the include stack would end up
                // including itself, directly or indirectly.
                if self.include_stack.contains(&included_path) {
                    self.errors.push(CircularIncludes::build(
                        &self.report_builder,
                        self.report_builder
                            .span_to_code_loc(include.span()),
                        Some(format!(
                            "include dependencies:\n{}",
                            self.include_stack
                                .iter()
                                .enumerate()
                                .map(|(i, path)| format!(
                                    "{:>width$}↳ {}",
                                    "",
                                    path.display(),
                                    width = i * 2
                                ))
                                .collect::<Vec<_>>()
                                .join("\n")
                        )),
                    ));
                    continue;
                }
                // Remember the current source id so it can be restored
                // once the included file has been compiled.
                let source_id =
                    self.report_builder.get_current_source_id().unwrap();
                let source_code =
                    SourceCode::from(included_src.as_slice()).with_origin(
                        included_path.to_str().unwrap().replace("\\", "/"),
                    );
                self.include_stack.push(included_path);
                // Any errors from the included file are already stored
                // in `self.errors`, so the result is ignored here.
                let _ = self.add_source(source_code);
                self.report_builder.set_current_source_id(source_id);
                self.include_stack.pop().unwrap();
            }
            ast::Item::Rule(rule) => {
                if let Err(err) = self.c_rule(rule) {
                    self.errors.push(err);
                }
            }
        }
    }
}
/// Compiles a single rule: validates identifiers and tags, runs the
/// linters, compiles patterns and condition into IR, registers the
/// rule's symbols, and emits the condition as WASM.
///
/// On failure the compiler state is rolled back to the snapshot taken
/// before the rule was processed, so a bad rule leaves no residue.
fn c_rule(&mut self, rule: &ast::Rule) -> Result<(), CompileError> {
    // The rule's identifier must not collide with an existing rule,
    // module or variable.
    self.check_for_existing_identifier(&rule.identifier)?;
    if let Some(tags) = &rule.tags {
        self.check_for_duplicate_tags(tags.as_slice())?;
    }
    // Run all registered linters. Only the first linter error is
    // returned; subsequent ones go straight into `self.errors`.
    let mut first_linter_err: Option<CompileError> = None;
    for linter in self.linters.iter() {
        match linter.check(&self.report_builder, rule) {
            LinterResult::Ok => {}
            LinterResult::Warn(warning) => {
                self.warnings.add(|| warning);
            }
            LinterResult::Warns(warnings) => {
                for warning in warnings {
                    self.warnings.add(|| warning);
                }
            }
            LinterResult::Err(err) => {
                if first_linter_err.is_none() {
                    first_linter_err = Some(err);
                } else {
                    self.errors.push(err);
                }
            }
        }
    }
    if let Some(err) = first_linter_err {
        return Err(err);
    }
    // Everything added from this point on is undone if the rule fails
    // to compile.
    let snapshot = self.take_snapshot();
    let tags: Vec<IdentId> = rule
        .tags
        .iter()
        .flatten()
        .map(|t| self.ident_pool.get_or_intern(t.name))
        .collect();
    // Convert AST metadata values into their interned counterparts.
    let mut convert_meta_value = |value: &ast::MetaValue| match value {
        ast::MetaValue::Integer((i, _)) => MetaValue::Integer(*i),
        ast::MetaValue::Float((f, _)) => MetaValue::Float(*f),
        ast::MetaValue::Bool((b, _)) => MetaValue::Bool(*b),
        ast::MetaValue::String((s, _)) => {
            MetaValue::String(self.lit_pool.get_or_intern(s))
        }
        ast::MetaValue::Bytes((s, _)) => {
            MetaValue::Bytes(self.lit_pool.get_or_intern(s))
        }
    };
    let metadata = rule
        .meta
        .iter()
        .flatten()
        .map(|m| {
            (
                self.ident_pool.get_or_intern(m.identifier.name),
                convert_meta_value(&m.value),
            )
        })
        .collect();
    // Compile the rule's patterns and condition into IR.
    let mut rule_patterns = Vec::new();
    let mut ctx = CompileContext {
        ir: &mut self.ir,
        relaxed_re_syntax: self.relaxed_re_syntax,
        error_on_slow_loop: self.error_on_slow_loop,
        one_shot_symbol_table: None,
        symbol_table: &mut self.symbol_table,
        report_builder: &self.report_builder,
        current_rule_patterns: &mut rule_patterns,
        warnings: &mut self.warnings,
        vars: VarStack::new(),
        for_of_depth: 0,
        features: &self.features,
        loop_iteration_multiplier: 1,
    };
    if let Err(err) = patterns_from_ast(&mut ctx, rule) {
        drop(ctx);
        self.restore_snapshot(snapshot);
        return Err(err);
    }
    let condition = rule_condition_from_ast(&mut ctx, rule);
    // `ctx` borrows `self` mutably; it must be dropped before `self`
    // can be used again below.
    drop(ctx);
    // Warn about unanchored patterns that are just a repetition of a
    // byte commonly found in binary files (0x00, 0x90, 0xff).
    for pat in rule_patterns.iter() {
        if pat.anchored_at().is_none()
            && !pat.pattern().flags().intersects(
                PatternFlags::Xor
                    | PatternFlags::Fullword
                    | PatternFlags::Base64
                    | PatternFlags::Base64Wide,
            )
        {
            let literal_bytes = match pat.pattern() {
                Pattern::Text(lit) => Some(lit.text.as_bytes()),
                Pattern::Regexp(re) => re.hir.as_literal_bytes(),
                Pattern::Hex(re) => re.hir.as_literal_bytes(),
            };
            if let Some(literal_bytes) = literal_bytes
                && Self::common_byte_repetition(literal_bytes)
            {
                self.warnings.add(|| {
                    warnings::SlowPattern::build(
                        &self.report_builder,
                        self.report_builder
                            .span_to_code_loc(pat.span().clone()),
                        None,
                    )
                });
            }
        }
    }
    // A condition referencing an ignored module (or a rule that was
    // itself ignored for that reason) causes the whole rule to be
    // skipped with a warning rather than failing the compilation.
    let mut condition = match condition {
        Ok(condition) => condition,
        Err(CompileError::UnknownIdentifier(unknown))
            if self.ignored_rules.contains_key(unknown.identifier())
                || self.ignored_modules.contains(unknown.identifier()) =>
        {
            self.restore_snapshot(snapshot);
            if let Some(module_name) =
                self.ignored_rules.get(unknown.identifier())
            {
                self.warnings.add(|| {
                    warnings::IgnoredRule::build(
                        &self.report_builder,
                        module_name.clone(),
                        rule.identifier.name.to_string(),
                        unknown.identifier_location().clone(),
                    )
                });
                // This rule becomes ignored too, so rules depending on
                // it are also skipped.
                self.ignored_rules.insert(
                    rule.identifier.name.to_string(),
                    module_name.clone(),
                );
            } else {
                self.warnings.add(|| {
                    warnings::IgnoredModule::build(
                        &self.report_builder,
                        unknown.identifier().to_string(),
                        unknown.identifier_location().clone(),
                        Some(format!(
                            "the whole rule `{}` will be ignored",
                            rule.identifier.name
                        )),
                    )
                });
                self.ignored_rules.insert(
                    rule.identifier.name.to_string(),
                    unknown.identifier().to_string(),
                );
            }
            return Ok(());
        }
        Err(err) => {
            self.restore_snapshot(snapshot);
            return Err(err);
        }
    };
    if self.hoisting {
        condition = self.ir.hoisting();
    }
    // If the condition bounds the file size, patterns only need to be
    // matched within those bounds.
    let filesize_bounds = self.ir.filesize_bounds();
    if !filesize_bounds.unbounded() {
        for pattern in &mut rule_patterns {
            pattern.pattern_mut().set_filesize_bounds(&filesize_bounds);
        }
    }
    // Dump the rule's IR when a writer was configured (debugging aid).
    if let Some(w) = &mut self.ir_writer {
        writeln!(w, "RULE {}", rule.identifier.name).unwrap();
        writeln!(w, "{:?}", self.ir).unwrap();
        if !filesize_bounds.unbounded() {
            writeln!(w, "{filesize_bounds:?}\n",).unwrap();
        }
    }
    let mut pattern_ids = Vec::with_capacity(rule_patterns.len());
    let mut patterns = Vec::with_capacity(rule_patterns.len());
    // Ids of patterns seen for the first time; only those need to be
    // compiled below.
    let mut pending_patterns = HashSet::new();
    let mut num_private_patterns = 0;
    for pattern in &rule_patterns {
        // Every pattern must be used in the condition, except those
        // whose identifier starts with "$_".
        if !pattern.in_use() && !pattern.identifier().starts_with("$_") {
            self.restore_snapshot(snapshot);
            return Err(UnusedPattern::build(
                &self.report_builder,
                pattern.identifier().name.to_string(),
                self.report_builder
                    .span_to_code_loc(pattern.identifier().span()),
            ));
        }
        if pattern.pattern().flags().contains(PatternFlags::Private) {
            num_private_patterns += 1;
        }
        // Identical patterns are deduplicated: they share a single
        // `PatternId`, even across rules.
        let pattern_id =
            match self.patterns.entry(pattern.pattern().clone()) {
                Entry::Occupied(entry) => *entry.get(),
                Entry::Vacant(entry) => {
                    let pattern_id = self.next_pattern_id;
                    self.next_pattern_id.incr(1);
                    pending_patterns.insert(pattern_id);
                    entry.insert(pattern_id);
                    pattern_id
                }
            };
        let kind = match pattern.pattern() {
            Pattern::Text(_) => PatternKind::Text,
            Pattern::Regexp(_) => PatternKind::Regexp,
            Pattern::Hex(_) => PatternKind::Hex,
        };
        patterns.push(PatternInfo {
            kind,
            pattern_id,
            ident_id: self
                .ident_pool
                .get_or_intern(pattern.identifier().name),
            is_private: pattern
                .pattern()
                .flags()
                .contains(PatternFlags::Private),
        });
        pattern_ids.push(pattern_id);
    }
    // The rule's id is its index in `self.rules`.
    let rule_id = RuleId::from(self.rules.len());
    self.rules.push(RuleInfo {
        tags,
        metadata,
        patterns,
        num_private_patterns,
        is_global: rule.flags.contains(RuleFlags::Global),
        is_private: rule.flags.contains(RuleFlags::Private),
        namespace_id: self.current_namespace.id,
        namespace_ident_id: self.current_namespace.ident_id,
        ident_id: self.ident_pool.get_or_intern(rule.identifier.name),
        ident_ref: self
            .report_builder
            .span_to_code_loc(rule.identifier.span()),
    });
    // Compile only the patterns that were not already compiled for a
    // previous rule.
    for (pattern_id, pattern) in
        izip!(pattern_ids.iter(), rule_patterns.into_iter())
    {
        if pending_patterns.contains(pattern_id) {
            let pattern_span = pattern.span().clone();
            match pattern.into_pattern() {
                Pattern::Text(pattern) => {
                    self.c_literal_pattern(*pattern_id, pattern);
                }
                Pattern::Regexp(pattern) | Pattern::Hex(pattern) => {
                    if let Err(err) = self.c_regexp_pattern(
                        *pattern_id,
                        pattern,
                        pattern_span,
                    ) {
                        self.restore_snapshot(snapshot);
                        return Err(err);
                    }
                }
            };
            // A pending pattern was just assigned its id, so it can't
            // already have file size bounds recorded.
            if !filesize_bounds.unbounded()
                && self
                    .filesize_bounds
                    .insert(*pattern_id, filesize_bounds.clone())
                    .is_some()
            {
                panic!(
                    "modifying the file size bounds of an existing pattern"
                )
            }
            pending_patterns.remove(pattern_id);
        }
    }
    // Make the rule visible in the current namespace's symbol table.
    let new_symbol = Symbol::Rule {
        rule_id,
        is_global: rule.flags.contains(RuleFlags::Global),
    };
    let existing_symbol = self
        .current_namespace
        .symbols
        .as_ref()
        .borrow_mut()
        .insert(rule.identifier.name, new_symbol);
    // `check_for_existing_identifier` already guaranteed uniqueness.
    assert!(existing_symbol.is_none());
    // Finally, emit the WASM code that evaluates the condition.
    let mut ctx = EmitContext {
        current_rule: self.rules.last_mut().unwrap(),
        lit_pool: &mut self.lit_pool,
        regexp_pool: &mut self.regexp_pool,
        wasm_symbols: &self.wasm_symbols,
        wasm_exports: &self.wasm_exports,
        exception_handler_stack: Vec::new(),
        lookup_list: Vec::new(),
        emit_search_for_pattern_stack: Vec::new(),
    };
    emit_rule_condition(
        &mut ctx,
        &self.ir,
        rule_id,
        condition,
        &mut self.wasm_mod,
    );
    Ok(())
}
/// Processes an `import` statement: validates the module name, adds the
/// module's structure to the root struct (once, globally), and exposes
/// it in the current namespace's symbol table.
fn c_import(&mut self, import: &Import) -> Result<(), CompileError> {
    let module_name = import.module_name;
    let module = BUILTIN_MODULES.get(module_name);
    // An unknown module is an error, unless it was explicitly marked
    // as ignored, in which case only a warning is raised.
    if module.is_none() {
        return if self.ignored_modules.iter().any(|m| m == module_name) {
            self.warnings.add(|| {
                warnings::IgnoredModule::build(
                    &self.report_builder,
                    module_name.to_string(),
                    self.report_builder.span_to_code_loc(import.span()),
                    None,
                )
            });
            Ok(())
        } else {
            Err(UnknownModule::build(
                &self.report_builder,
                module_name.to_string(),
                self.report_builder.span_to_code_loc(import.span()),
            ))
        };
    }
    let module = module.unwrap();
    // Add the module to the root structure only the first time it is
    // imported from any namespace.
    if !self.root_struct.has_field(module_name) {
        self.imported_modules
            .push(self.ident_pool.get_or_intern(module_name));
        let module_struct = Rc::<Struct>::from(module);
        if self
            .root_struct
            .add_field(module_name, TypeValue::Struct(module_struct))
            .is_some()
        {
            panic!("duplicate module `{module_name}`")
        }
    }
    // Make the module reachable from the current namespace.
    let mut symbol_table =
        self.current_namespace.symbols.as_ref().borrow_mut();
    if !symbol_table.contains(module_name) {
        symbol_table.insert(
            module_name,
            self.root_struct.lookup(module_name).unwrap(),
        );
    }
    // NOTE(review): banned modules are rejected only after the symbol
    // was inserted — presumably so later references still resolve while
    // the import itself errors; confirm this ordering is intended.
    if let Some((error_title, error_msg)) =
        self.banned_modules.get(module_name)
    {
        return Err(CustomError::build(
            &self.report_builder,
            error_title.clone(),
            error_msg.clone(),
            self.report_builder.span_to_code_loc(import.span()),
        ));
    }
    Ok(())
}
/// Compiles a text (literal) pattern into one or more sub-patterns,
/// one per combination of the ascii/wide modifiers, further expanded
/// for xor/nocase/base64 variants.
fn c_literal_pattern(
    &mut self,
    pattern_id: PatternId,
    pattern: LiteralPattern,
) {
    let full_word = pattern.flags.contains(PatternFlags::Fullword);
    let mut flags = SubPatternFlags::empty();
    if full_word {
        flags.insert(SubPatternFlags::FullwordLeft);
        flags.insert(SubPatternFlags::FullwordRight);
    }
    // Build the "main" pattern list: the wide variant, the ascii
    // variant, or both, each paired with its best atom.
    let mut main_patterns = Vec::new();
    let wide_pattern;
    if pattern.flags.contains(PatternFlags::Wide) {
        wide_pattern = make_wide(pattern.text.as_bytes());
        main_patterns.push((
            wide_pattern.as_slice(),
            best_atom_in_bytes(wide_pattern.as_slice()),
            flags | SubPatternFlags::Wide,
        ));
    }
    if pattern.flags.contains(PatternFlags::Ascii) {
        main_patterns.push((
            pattern.text.as_bytes(),
            best_atom_in_bytes(pattern.text.as_bytes()),
            flags,
        ));
    }
    for (main_pattern, best_atom, flags) in main_patterns {
        let pattern_lit_id = self.lit_pool.get_or_intern(main_pattern);
        if pattern.flags.contains(PatternFlags::Xor) {
            // `xor` is mutually exclusive with base64/nocase —
            // NOTE(review): presumably enforced upstream; confirm.
            debug_assert!(!pattern.flags.contains(
                PatternFlags::Base64
                    | PatternFlags::Base64Wide
                    | PatternFlags::Nocase,
            ));
            // One atom per xor key in the requested range.
            let xor_range = pattern.xor_range.clone().unwrap();
            self.add_sub_pattern(
                pattern_id,
                SubPattern::Xor { pattern: pattern_lit_id, flags },
                best_atom.xor_combinations(xor_range),
                SubPatternAtom::from_atom,
            );
        } else if pattern.flags.contains(PatternFlags::Nocase) {
            debug_assert!(!pattern.flags.contains(
                PatternFlags::Base64
                    | PatternFlags::Base64Wide
                    | PatternFlags::Xor,
            ));
            // One atom per upper/lower-case combination of the best atom.
            self.add_sub_pattern(
                pattern_id,
                SubPattern::Literal {
                    pattern: pattern_lit_id,
                    flags: flags | SubPatternFlags::Nocase,
                    anchored_at: None,
                },
                best_atom.case_combinations(),
                SubPatternAtom::from_atom,
            );
        }
        // `base64`/`base64wide` produce one sub-pattern per padding
        // variant of the encoded literal.
        else if pattern
            .flags
            .intersects(PatternFlags::Base64 | PatternFlags::Base64Wide)
        {
            debug_assert!(!pattern.flags.contains(
                PatternFlags::Xor
                    | PatternFlags::Fullword
                    | PatternFlags::Nocase,
            ));
            if pattern.flags.contains(PatternFlags::Base64) {
                for (padding, base64_pattern) in base64_patterns(
                    main_pattern,
                    pattern.base64_alphabet.as_deref(),
                ) {
                    // A custom alphabet yields a CustomBase64
                    // sub-pattern that carries the interned alphabet.
                    let sub_pattern = if let Some(alphabet) =
                        pattern.base64_alphabet.as_deref()
                    {
                        SubPattern::CustomBase64 {
                            pattern: pattern_lit_id,
                            alphabet: self
                                .lit_pool
                                .get_or_intern(alphabet),
                            padding,
                        }
                    } else {
                        SubPattern::Base64 {
                            pattern: pattern_lit_id,
                            padding,
                        }
                    };
                    self.add_sub_pattern(
                        pattern_id,
                        sub_pattern,
                        iter::once({
                            // Atoms for base64 patterns are always
                            // inexact: an atom match still requires
                            // verification.
                            let mut atom = best_atom_in_bytes(
                                base64_pattern.as_slice(),
                            );
                            atom.make_inexact();
                            atom
                        }),
                        SubPatternAtom::from_atom,
                    );
                }
            }
            if pattern.flags.contains(PatternFlags::Base64Wide) {
                for (padding, base64_pattern) in base64_patterns(
                    main_pattern,
                    pattern.base64wide_alphabet.as_deref(),
                ) {
                    let sub_pattern = if let Some(alphabet) =
                        pattern.base64wide_alphabet.as_deref()
                    {
                        SubPattern::CustomBase64Wide {
                            pattern: pattern_lit_id,
                            alphabet: self
                                .lit_pool
                                .get_or_intern(alphabet),
                            padding,
                        }
                    } else {
                        SubPattern::Base64Wide {
                            pattern: pattern_lit_id,
                            padding,
                        }
                    };
                    // The atom is extracted from the wide form of the
                    // base64-encoded literal.
                    let wide = make_wide(base64_pattern.as_slice());
                    self.add_sub_pattern(
                        pattern_id,
                        sub_pattern,
                        iter::once({
                            let mut atom =
                                best_atom_in_bytes(wide.as_slice());
                            atom.make_inexact();
                            atom
                        }),
                        SubPatternAtom::from_atom,
                    );
                }
            }
        } else {
            // Plain literal, possibly anchored at a fixed offset.
            self.add_sub_pattern(
                pattern_id,
                SubPattern::Literal {
                    pattern: pattern_lit_id,
                    anchored_at: pattern.anchored_at,
                    flags,
                },
                iter::once(best_atom),
                SubPatternAtom::from_atom,
            );
        }
    }
}
/// Compiles a regexp or hex pattern.
///
/// Patterns containing large gaps are compiled as a chain of
/// sub-patterns; patterns that are really a literal (or an alternation
/// of literals) go through the cheaper literal code path.
fn c_regexp_pattern(
    &mut self,
    pattern_id: PatternId,
    pattern: RegexpPattern,
    span: Span,
) -> Result<(), CompileError> {
    // Split the HIR at large gaps; a non-empty tail means the pattern
    // must be compiled as a chain.
    let (head, tail) = pattern.hir.split_at_large_gaps();
    if !tail.is_empty() {
        return self.c_chain(
            pattern_id,
            &head,
            &tail,
            pattern.flags,
            span,
        );
    }
    if head.is_alternation_literal() {
        return self.c_alternation_literal(
            pattern_id,
            head,
            pattern.anchored_at,
            pattern.flags,
        );
    }
    let mut flags = SubPatternFlags::empty();
    if pattern.flags.contains(PatternFlags::Nocase) {
        flags.insert(SubPatternFlags::Nocase);
    }
    if pattern.flags.contains(PatternFlags::Fullword) {
        flags.insert(SubPatternFlags::FullwordLeft);
        flags.insert(SubPatternFlags::FullwordRight);
    }
    if matches!(head.is_greedy(), Some(true)) {
        flags.insert(SubPatternFlags::GreedyRegexp);
    }
    // Compile the regexp itself; `c_regexp` also reports whether the
    // faster regexp engine can be used for it.
    let (atoms, is_fast_regexp) = self.c_regexp(&head, span)?;
    if is_fast_regexp {
        flags.insert(SubPatternFlags::FastRegexp);
    }
    if pattern.flags.contains(PatternFlags::Wide) {
        self.add_sub_pattern(
            pattern_id,
            SubPattern::Regexp { flags: flags | SubPatternFlags::Wide },
            atoms.iter().cloned().map(|atom| atom.make_wide()),
            SubPatternAtom::from_regexp_atom,
        );
    }
    if pattern.flags.contains(PatternFlags::Ascii) {
        self.add_sub_pattern(
            pattern_id,
            SubPattern::Regexp { flags },
            atoms.into_iter(),
            SubPatternAtom::from_regexp_atom,
        );
    }
    Ok(())
}
/// Compiles a regexp/hex pattern that is actually a single literal or
/// an alternation of literals, producing plain literal sub-patterns.
///
/// The caller must have verified `hir.is_alternation_literal()`.
fn c_alternation_literal(
    &mut self,
    pattern_id: PatternId,
    hir: re::hir::Hir,
    anchored_at: Option<usize>,
    flags: PatternFlags,
) -> Result<(), CompileError> {
    let ascii = flags.contains(PatternFlags::Ascii);
    let wide = flags.contains(PatternFlags::Wide);
    let case_insensitive = flags.contains(PatternFlags::Nocase);
    let full_word = flags.contains(PatternFlags::Fullword);
    let mut flags = SubPatternFlags::empty();
    if case_insensitive {
        flags.insert(SubPatternFlags::Nocase);
    }
    if full_word {
        flags.insert(SubPatternFlags::FullwordLeft);
        flags.insert(SubPatternFlags::FullwordRight);
    }
    // Adds one literal sub-pattern (ascii or wide variant) with its
    // best atom, or all case combinations of it when `nocase`.
    let mut process_literal = |literal: &hir::Literal, wide: bool| {
        let pattern_lit_id =
            self.intern_literal(literal.0.as_bytes(), wide);
        let best_atom = best_atom_in_bytes(
            self.lit_pool.get_bytes(pattern_lit_id).unwrap(),
        );
        let flags =
            if wide { flags | SubPatternFlags::Wide } else { flags };
        let sub_pattern = SubPattern::Literal {
            pattern: pattern_lit_id,
            anchored_at,
            flags,
        };
        if case_insensitive {
            self.add_sub_pattern(
                pattern_id,
                sub_pattern,
                best_atom.case_combinations(),
                SubPatternAtom::from_atom,
            );
        } else {
            self.add_sub_pattern(
                pattern_id,
                sub_pattern,
                iter::once(best_atom),
                SubPatternAtom::from_atom,
            );
        }
    };
    // Unwrap a top-level capture group, if present.
    let inner;
    let hir = if let hir::HirKind::Capture(group) = hir.kind() {
        group.sub.as_ref()
    } else {
        inner = hir.into_inner();
        &inner
    };
    match hir.kind() {
        hir::HirKind::Literal(literal) => {
            if ascii {
                process_literal(literal, false);
            }
            if wide {
                process_literal(literal, true);
            }
        }
        hir::HirKind::Alternation(literals) => {
            // `is_alternation_literal` guarantees every branch of the
            // alternation is a literal, so the cast can't fail.
            let literals = literals
                .iter()
                .map(|l| cast!(l.kind(), hir::HirKind::Literal));
            for literal in literals {
                if ascii {
                    process_literal(literal, false);
                }
                if wide {
                    process_literal(literal, true);
                }
            }
        }
        // No other HIR kind can reach this function.
        _ => unreachable!(),
    }
    Ok(())
}
/// Compiles a pattern that has been split at large gaps into a chain of
/// sub-patterns.
///
/// `leading` is the HIR for the first piece of the chain, and `trailing`
/// contains the remaining pieces, each one carrying the gap that
/// separates it from the previous piece. Every piece is compiled as a
/// literal chain link when its HIR is a plain literal, or as a regexp
/// chain link otherwise, in ASCII and/or wide form depending on the
/// pattern's flags.
fn c_chain(
    &mut self,
    pattern_id: PatternId,
    leading: &re::hir::Hir,
    trailing: &[ChainedPattern],
    flags: PatternFlags,
    span: Span,
) -> Result<(), CompileError> {
    let ascii = flags.contains(PatternFlags::Ascii);
    let wide = flags.contains(PatternFlags::Wide);
    let case_insensitive = flags.contains(PatternFlags::Nocase);
    let full_word = flags.contains(PatternFlags::Fullword);

    // Flags shared by every link in the chain.
    let mut common_flags = SubPatternFlags::empty();

    if case_insensitive {
        common_flags.insert(SubPatternFlags::Nocase);
    }

    if matches!(leading.is_greedy(), Some(true)) {
        common_flags.insert(SubPatternFlags::GreedyRegexp);
    }

    // Placeholder IDs for the previously added link; they are replaced
    // once the head sub-patterns are added below.
    let mut prev_sub_pattern_ascii = SubPatternId(0);
    let mut prev_sub_pattern_wide = SubPatternId(0);

    if let hir::HirKind::Literal(literal) = leading.kind() {
        // The chain starts with a plain literal.
        let mut flags = common_flags;

        // The fullword check is split across the chain: the head
        // checks the left boundary, the last link checks the right one.
        if full_word {
            flags.insert(SubPatternFlags::FullwordLeft);
        }

        if ascii {
            prev_sub_pattern_ascii =
                self.c_literal_chain_head(pattern_id, literal, flags);
        }

        if wide {
            prev_sub_pattern_wide = self.c_literal_chain_head(
                pattern_id,
                literal,
                flags | SubPatternFlags::Wide,
            );
        };
    } else {
        // The chain starts with a regexp.
        let mut flags = common_flags;

        let (atoms, is_fast_regexp) =
            self.c_regexp(leading, span.clone())?;

        if is_fast_regexp {
            flags.insert(SubPatternFlags::FastRegexp);
        }

        if full_word {
            flags.insert(SubPatternFlags::FullwordLeft);
        }

        if wide {
            prev_sub_pattern_wide = self.add_sub_pattern(
                pattern_id,
                SubPattern::RegexpChainHead {
                    flags: flags | SubPatternFlags::Wide,
                },
                // Atoms are converted to their wide form for the wide
                // sub-pattern.
                atoms.iter().cloned().map(|atom| atom.make_wide()),
                SubPatternAtom::from_regexp_atom,
            );
        }

        if ascii {
            prev_sub_pattern_ascii = self.add_sub_pattern(
                pattern_id,
                SubPattern::RegexpChainHead { flags },
                atoms.into_iter(),
                SubPatternAtom::from_regexp_atom,
            );
        }
    }

    // Compile the remaining pieces, each one chained to the previously
    // added link.
    for (i, p) in trailing.iter().enumerate() {
        let mut flags = common_flags;

        // The last link is marked as such, and carries the right-side
        // fullword check when required.
        if i == trailing.len() - 1 {
            flags.insert(SubPatternFlags::LastInChain);
            if full_word {
                flags.insert(SubPatternFlags::FullwordRight);
            }
        }

        if let hir::HirKind::Literal(literal) = p.hir.kind() {
            if wide {
                prev_sub_pattern_wide = self.c_literal_chain_tail(
                    pattern_id,
                    literal,
                    prev_sub_pattern_wide,
                    p.gap.clone(),
                    flags | SubPatternFlags::Wide,
                );
            };

            if ascii {
                prev_sub_pattern_ascii = self.c_literal_chain_tail(
                    pattern_id,
                    literal,
                    prev_sub_pattern_ascii,
                    p.gap.clone(),
                    flags,
                );
            }
        } else {
            if matches!(p.hir.is_greedy(), Some(true)) {
                flags.insert(SubPatternFlags::GreedyRegexp);
            }

            let (atoms, is_fast_regexp) =
                self.c_regexp(&p.hir, span.clone())?;

            if is_fast_regexp {
                flags.insert(SubPatternFlags::FastRegexp);
            }

            if wide {
                prev_sub_pattern_wide = self.add_sub_pattern(
                    pattern_id,
                    SubPattern::RegexpChainTail {
                        chained_to: prev_sub_pattern_wide,
                        gap: p.gap.clone(),
                        flags: flags | SubPatternFlags::Wide,
                    },
                    atoms.iter().cloned().map(|atom| atom.make_wide()),
                    SubPatternAtom::from_regexp_atom,
                )
            }

            if ascii {
                prev_sub_pattern_ascii = self.add_sub_pattern(
                    pattern_id,
                    SubPattern::RegexpChainTail {
                        chained_to: prev_sub_pattern_ascii,
                        gap: p.gap.clone(),
                        flags,
                    },
                    atoms.into_iter(),
                    SubPatternAtom::from_regexp_atom,
                );
            }
        }
    }

    Ok(())
}
/// Compiles the regexp `hir`, appending its code to `self.re_code`, and
/// returns the atoms extracted from it together with a flag indicating
/// whether the fast regexp engine can handle it.
///
/// Returns an error if the regexp can't be compiled or if it can match
/// empty strings. Also produces a "slow pattern" warning (or an error,
/// when `self.error_on_slow_pattern` is set) for regexps whose atoms
/// are poor search anchors.
fn c_regexp(
    &mut self,
    hir: &re::hir::Hir,
    span: Span,
) -> Result<(Vec<re::RegexpAtom>, bool), CompileError> {
    // Try the fast engine's compiler first; when the regexp is not
    // compatible with it, fall back to the Thompson compiler.
    #[cfg(feature = "fast-regexp")]
    let (result, is_fast_regexp) = match re::fast::Compiler::new()
        .compile(hir, &mut self.re_code)
    {
        Err(re::Error::FastIncompatible) => (
            re::thompson::Compiler::new().compile(hir, &mut self.re_code),
            false,
        ),
        result => (result, true),
    };

    // Without the `fast-regexp` feature only the Thompson compiler is
    // available.
    #[cfg(not(feature = "fast-regexp"))]
    let (result, is_fast_regexp) = (
        re::thompson::Compiler::new().compile(hir, &mut self.re_code),
        false,
    );

    let re_atoms = result.map_err(|err| {
        InvalidRegexp::build(
            &self.report_builder,
            err.to_string(),
            self.report_builder.span_to_code_loc(span.clone()),
            None,
        )
    })?;

    // Regexps that can match the empty string are rejected.
    if matches!(hir.minimum_len(), Some(0)) {
        return Err(InvalidRegexp::build(
            &self.report_builder,
            "this regexp can match empty strings".to_string(),
            self.report_builder.span_to_code_loc(span),
            None,
        ));
    }

    // Heuristics for detecting slow patterns, based on the lengths of
    // the extracted atoms: no atoms at all, atoms shorter than 2 bytes,
    // or an excessive number of exactly-2-byte atoms.
    let (slow_pattern, note) =
        match re_atoms.iter().map(|re_atom| re_atom.atom.len()).minmax() {
            // No atoms could be extracted.
            MinMaxResult::NoElements => (true, None),
            // A single zero-length atom: the worst possible case, so a
            // note is attached to the warning/error.
            MinMaxResult::OneElement(0) => (
                true,
                Some(
                    "this is an exceptionally extreme case that may severely degrade scanning throughput"
                        .to_string(),
                ),
            ),
            // A single 1-byte atom (the 0 case was handled above).
            MinMaxResult::OneElement(len) if len < 2 => (true, None),
            // At least one atom is shorter than 2 bytes.
            MinMaxResult::MinMax(min, _) if min < 2 => (true, None),
            // All atoms are exactly 2 bytes, but there are too many.
            MinMaxResult::MinMax(2, 2) if re_atoms.len() > 2700 => {
                (true, None)
            }
            _ => (false, None),
        };

    if slow_pattern {
        if self.error_on_slow_pattern {
            return Err(errors::SlowPattern::build(
                &self.report_builder,
                self.report_builder.span_to_code_loc(span),
                note,
            ));
        } else {
            self.warnings.add(|| {
                warnings::SlowPattern::build(
                    &self.report_builder,
                    self.report_builder.span_to_code_loc(span),
                    note,
                )
            });
        }
    }

    Ok((re_atoms, is_fast_regexp))
}
/// Adds the head sub-pattern of a literal chain and returns its ID.
///
/// The literal is interned (in wide form when `flags` contains `Wide`),
/// and the atoms used by the search phase are extracted from the
/// interned bytes.
fn c_literal_chain_head(
    &mut self,
    pattern_id: PatternId,
    literal: &hir::Literal,
    flags: SubPatternFlags,
) -> SubPatternId {
    let wide = flags.contains(SubPatternFlags::Wide);
    let lit_id = self.intern_literal(literal.0.as_bytes(), wide);
    let atoms =
        extract_atoms(self.lit_pool.get_bytes(lit_id).unwrap(), flags);
    self.add_sub_pattern(
        pattern_id,
        SubPattern::LiteralChainHead { pattern: lit_id, flags },
        atoms,
        SubPatternAtom::from_atom,
    )
}
/// Adds a non-head link of a literal chain and returns its ID.
///
/// `chained_to` identifies the previous link in the chain and `gap`
/// describes the allowed distance between the two.
fn c_literal_chain_tail(
    &mut self,
    pattern_id: PatternId,
    literal: &hir::Literal,
    chained_to: SubPatternId,
    gap: ChainedPatternGap,
    flags: SubPatternFlags,
) -> SubPatternId {
    let wide = flags.contains(SubPatternFlags::Wide);
    let lit_id = self.intern_literal(literal.0.as_bytes(), wide);
    let atoms =
        extract_atoms(self.lit_pool.get_bytes(lit_id).unwrap(), flags);
    self.add_sub_pattern(
        pattern_id,
        SubPattern::LiteralChainTail {
            pattern: lit_id,
            chained_to,
            gap,
            flags,
        },
        atoms,
        SubPatternAtom::from_atom,
    )
}
}
impl fmt::Debug for Compiler<'_> {
    /// Writes an opaque placeholder; the compiler's internals are not
    /// useful in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Compiler")
    }
}
impl Default for Compiler<'_> {
fn default() -> Self {
Self::new()
}
}
/// ID identifying an interned identifier in the compiler's identifier
/// pool (a `StringPool<IdentId>`).
#[derive(Eq, PartialEq, Hash, Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub(crate) struct IdentId(u32);

impl From<u32> for IdentId {
    fn from(v: u32) -> Self {
        Self(v)
    }
}

impl From<IdentId> for u32 {
    fn from(v: IdentId) -> Self {
        v.0
    }
}
/// ID identifying an interned literal in the compiler's literal pool.
#[derive(PartialEq, Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub(crate) struct LiteralId(u32);

impl From<i32> for LiteralId {
    fn from(v: i32) -> Self {
        // NOTE(review): negative values wrap around via `as u32` —
        // presumably callers never pass negative IDs; confirm.
        Self(v as u32)
    }
}

impl From<u32> for LiteralId {
    fn from(v: u32) -> Self {
        Self(v)
    }
}

impl From<LiteralId> for u32 {
    fn from(v: LiteralId) -> Self {
        v.0
    }
}

impl From<LiteralId> for i64 {
    fn from(v: LiteralId) -> Self {
        v.0 as i64
    }
}

impl From<LiteralId> for u64 {
    fn from(v: LiteralId) -> Self {
        v.0 as u64
    }
}
/// ID identifying a namespace during compilation.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub(crate) struct NamespaceId(i32);
/// ID identifying a rule during compilation. The default ID is 0.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
pub(crate) struct RuleId(i32);

impl RuleId {
    /// Returns the ID that follows this one.
    #[allow(dead_code)]
    pub(crate) fn next(&self) -> Self {
        Self(self.0 + 1)
    }
}

impl From<i32> for RuleId {
    #[inline]
    fn from(v: i32) -> Self {
        RuleId(v)
    }
}

impl From<usize> for RuleId {
    #[inline]
    fn from(v: usize) -> Self {
        // Panics if the value doesn't fit in an `i32`.
        RuleId(i32::try_from(v).unwrap())
    }
}

impl From<RuleId> for usize {
    #[inline]
    fn from(id: RuleId) -> Self {
        id.0 as usize
    }
}

impl From<RuleId> for i32 {
    #[inline]
    fn from(id: RuleId) -> Self {
        id.0
    }
}
/// ID identifying a regexp during compilation.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) struct RegexpId(i32);

impl From<i32> for RegexpId {
    #[inline]
    fn from(v: i32) -> Self {
        RegexpId(v)
    }
}

impl From<u32> for RegexpId {
    #[inline]
    fn from(v: u32) -> Self {
        // Panics if the value doesn't fit in an `i32`.
        RegexpId(i32::try_from(v).unwrap())
    }
}

impl From<i64> for RegexpId {
    #[inline]
    fn from(v: i64) -> Self {
        // Panics if the value doesn't fit in an `i32`.
        RegexpId(i32::try_from(v).unwrap())
    }
}

impl From<RegexpId> for usize {
    #[inline]
    fn from(id: RegexpId) -> Self {
        id.0 as usize
    }
}

impl From<RegexpId> for i32 {
    #[inline]
    fn from(id: RegexpId) -> Self {
        id.0
    }
}

impl From<RegexpId> for u32 {
    #[inline]
    fn from(id: RegexpId) -> Self {
        // Panics if the stored value is negative.
        u32::try_from(id.0).unwrap()
    }
}
/// ID identifying a pattern during compilation. IDs are ordered, which
/// allows sorting patterns by the order in which they were declared.
#[derive(
    Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize,
    Deserialize,
)]
#[serde(transparent)]
pub(crate) struct PatternId(i32);

impl PatternId {
    /// Advances this ID by `amount`.
    #[inline]
    fn incr(&mut self, amount: usize) {
        // NOTE(review): `amount as i32` truncates very large values —
        // presumably pattern counts never approach that range; confirm.
        self.0 += amount as i32;
    }
}

impl From<i32> for PatternId {
    #[inline]
    fn from(v: i32) -> Self {
        PatternId(v)
    }
}

impl From<usize> for PatternId {
    #[inline]
    fn from(v: usize) -> Self {
        PatternId(v as i32)
    }
}

impl From<PatternId> for i32 {
    #[inline]
    fn from(id: PatternId) -> Self {
        id.0
    }
}

impl From<PatternId> for i64 {
    #[inline]
    fn from(id: PatternId) -> Self {
        id.0 as i64
    }
}

impl From<PatternId> for usize {
    #[inline]
    fn from(id: PatternId) -> Self {
        id.0 as usize
    }
}
/// ID identifying a sub-pattern within the compiled rules.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub(crate) struct SubPatternId(u32);
/// Iterator over a list of imported identifiers, yielding each one as a
/// `&str` resolved from the identifier pool.
pub struct Imports<'a> {
    iter: std::slice::Iter<'a, IdentId>,
    ident_pool: &'a StringPool<IdentId>,
}

impl<'a> Iterator for Imports<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        // Resolve the next identifier ID against the identifier pool.
        let id = *self.iter.next()?;
        Some(self.ident_pool.get(id).unwrap())
    }
}
bitflags! {
    /// Flags associated with [`SubPattern`] instances.
    #[derive(Debug, Clone, Copy, Hash, Serialize, Deserialize, PartialEq, Eq)]
    pub struct SubPatternFlags: u16 {
        /// The sub-pattern is the wide variant of the pattern.
        const Wide = 0x01;
        /// Matching is case-insensitive (`nocase`).
        const Nocase = 0x02;
        /// The sub-pattern is the last link of a chain.
        const LastInChain = 0x04;
        /// A fullword check is required at the left boundary of a match.
        const FullwordLeft = 0x08;
        /// A fullword check is required at the right boundary of a match.
        const FullwordRight = 0x10;
        /// The underlying regexp is greedy.
        const GreedyRegexp = 0x20;
        /// The regexp can be matched with the fast regexp engine.
        const FastRegexp = 0x40;
    }
}
/// The different kinds of sub-patterns a pattern is decomposed into
/// during compilation.
#[derive(Serialize, Deserialize)]
pub(crate) enum SubPattern {
    /// A plain literal, optionally anchored at a fixed offset.
    Literal {
        pattern: LiteralId,
        anchored_at: Option<usize>,
        flags: SubPatternFlags,
    },
    /// A literal that is the first link of a chain.
    LiteralChainHead {
        pattern: LiteralId,
        flags: SubPatternFlags,
    },
    /// A literal chained to a previous link, separated by `gap`.
    LiteralChainTail {
        pattern: LiteralId,
        chained_to: SubPatternId,
        gap: ChainedPatternGap,
        flags: SubPatternFlags,
    },
    /// A regexp.
    Regexp {
        flags: SubPatternFlags,
    },
    /// A regexp that is the first link of a chain.
    RegexpChainHead {
        flags: SubPatternFlags,
    },
    /// A regexp chained to a previous link, separated by `gap`.
    RegexpChainTail {
        chained_to: SubPatternId,
        gap: ChainedPatternGap,
        flags: SubPatternFlags,
    },
    /// A literal with the `xor` modifier.
    Xor {
        pattern: LiteralId,
        flags: SubPatternFlags,
    },
    /// A base64-encoded literal. NOTE(review): `padding` is presumably
    /// the byte offset within the base64 quantum — confirm at the
    /// construction sites (outside this chunk).
    Base64 {
        pattern: LiteralId,
        padding: u8,
    },
    /// Wide variant of [`SubPattern::Base64`].
    Base64Wide {
        pattern: LiteralId,
        padding: u8,
    },
    /// A base64-encoded literal using a custom alphabet.
    CustomBase64 {
        pattern: LiteralId,
        alphabet: LiteralId,
        padding: u8,
    },
    /// Wide variant of [`SubPattern::CustomBase64`].
    CustomBase64Wide {
        pattern: LiteralId,
        alphabet: LiteralId,
        padding: u8,
    },
}
impl SubPattern {
    /// Returns the ID of the sub-pattern this one is chained to, or
    /// `None` when it is not a chain tail.
    pub fn chained_to(&self) -> Option<SubPatternId> {
        match self {
            SubPattern::LiteralChainTail { chained_to, .. } => {
                Some(*chained_to)
            }
            SubPattern::RegexpChainTail { chained_to, .. } => {
                Some(*chained_to)
            }
            _ => None,
        }
    }
}
/// Snapshot of the compiler's mutable state.
///
/// NOTE(review): presumably used for rolling the compiler back to a
/// previous state when compiling a source fails — confirm at the sites
/// where `Snapshot` is created and restored (outside this chunk).
#[derive(Debug, PartialEq, Eq)]
struct Snapshot {
    // Next pattern ID to be assigned at the time of the snapshot.
    next_pattern_id: PatternId,
    // Lengths of the compiler's growable collections at the time the
    // snapshot was taken.
    rules_len: usize,
    atoms_len: usize,
    re_code_len: usize,
    sub_patterns_len: usize,
    symbol_table_len: usize,
}
/// Collects the warnings produced during compilation.
pub(crate) struct Warnings {
    // Warnings collected so far.
    warnings: Vec<Warning>,
    // Maximum number of warnings that will be collected; any warning
    // raised after the limit is reached is discarded.
    max_warnings: usize,
    // Codes of warnings that are disabled entirely.
    disabled_warnings: HashSet<String>,
    // Maps warning codes to the source spans where warnings with that
    // code are suppressed.
    suppressed_warnings: HashMap<String, Vec<Span>>,
}
impl Default for Warnings {
fn default() -> Self {
Self {
warnings: Vec::new(),
max_warnings: 100,
disabled_warnings: HashSet::default(),
suppressed_warnings: HashMap::default(),
}
}
}
impl Warnings {
    /// Adds the warning produced by `f` to the collection, unless the
    /// maximum number of warnings has been reached, the warning's code
    /// is disabled, or one of the warning's labels falls within a span
    /// where that code is suppressed.
    ///
    /// `f` is invoked lazily, only while the collection is below the
    /// `max_warnings` limit.
    #[inline]
    pub fn add(&mut self, f: impl FnOnce() -> Warning) {
        if self.warnings.len() < self.max_warnings {
            let warning = f();
            let mut warn = !self.disabled_warnings.contains(warning.code());
            // When the warning's code has suppressed spans, drop the
            // warning if any of its labels lies inside one of them.
            if warn
                && let Some(spans) =
                    self.suppressed_warnings.get(warning.code())
            {
                'l: for disabled_span in spans {
                    for label in warning.labels() {
                        if disabled_span.contains(label.span()) {
                            warn = false;
                            break 'l;
                        }
                    }
                }
            }
            if warn {
                self.warnings.push(warning);
            }
        }
    }

    /// Returns `true` if `code` is a valid warning code.
    pub fn is_valid_code(code: &str) -> bool {
        Warning::all_codes().contains(&code)
    }

    /// Enables or disables the warning identified by `code`.
    ///
    /// The returned boolean is `true` when the warning was enabled
    /// before the call. Returns an error if `code` is not a valid
    /// warning code.
    #[inline]
    pub fn switch_warning(
        &mut self,
        code: &str,
        enabled: bool,
    ) -> Result<bool, InvalidWarningCode> {
        if !Self::is_valid_code(code) {
            return Err(InvalidWarningCode::new(code.to_string()));
        }
        if enabled {
            // `remove` returns `true` if the code was disabled, so the
            // negation is the previous enabled state.
            Ok(!self.disabled_warnings.remove(code))
        } else {
            // `insert` returns `true` if the code was not already
            // disabled, i.e. it was enabled before.
            Ok(self.disabled_warnings.insert(code.to_string()))
        }
    }

    /// Enables or disables all warnings at once.
    pub fn switch_all_warnings(&mut self, enabled: bool) {
        if enabled {
            self.disabled_warnings.clear();
        } else {
            for c in Warning::all_codes() {
                self.disabled_warnings.insert(c.to_string());
            }
        }
    }

    /// Removes all per-span warning suppressions.
    pub fn clear_suppressed(&mut self) {
        self.suppressed_warnings.clear();
    }

    /// Suppresses warnings with the given `code` when they occur within
    /// `span`.
    pub fn suppress(&mut self, code: &str, span: Span) {
        self.suppressed_warnings
            .entry(code.to_string())
            .or_default()
            .push(span);
    }

    /// Returns the warnings collected so far.
    #[inline]
    pub fn as_slice(&self) -> &[Warning] {
        self.warnings.as_slice()
    }
}
impl From<Warnings> for Vec<Warning> {
    /// Consumes the collection, yielding the accumulated warnings.
    fn from(w: Warnings) -> Self {
        w.warnings
    }
}