use std::collections::Bound;
use std::fmt::{Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::mem;
use std::mem::discriminant;
use std::ops::RangeInclusive;
use std::ops::{Add, Index};
use std::rc::Rc;
use bitflags::bitflags;
use bstr::BString;
use rustc_hash::FxHasher;
use serde::{Deserialize, Serialize};
use yara_x_parser::Span;
use yara_x_parser::ast::Ident;
use crate::compiler::context::{Var, VarStack};
use crate::compiler::ir::dfs::{
DFSIter, DFSWithScopeIter, Event, EventContext, dfs_common,
};
use crate::compiler::FilesizeBounds;
use crate::re;
use crate::symbols::Symbol;
use crate::types::Value::Const;
use crate::types::{FuncSignature, Type, TypeValue};
pub(in crate::compiler) use ast2ir::patterns_from_ast;
pub(in crate::compiler) use ast2ir::rule_condition_from_ast;
mod ast2ir;
mod dfs;
mod hex2hir;
#[cfg(test)]
mod tests;
bitflags! {
    /// Set of flags attached to a pattern.
    ///
    /// Every flag occupies its own bit of the underlying `u16`, so flags can
    /// be freely combined.
    #[derive(Debug, Clone, Copy, Hash, Serialize, Deserialize, PartialEq, Eq)]
    pub struct PatternFlags: u16 {
        const Ascii = 1 << 0;
        const Wide = 1 << 1;
        const Nocase = 1 << 2;
        const Base64 = 1 << 3;
        const Base64Wide = 1 << 4;
        const Xor = 1 << 5;
        const Fullword = 1 << 6;
        const Private = 1 << 7;
        /// Inserted by `Pattern::make_non_anchorable`, and by
        /// `Pattern::anchor_at` when two different offsets are requested.
        const NonAnchorable = 1 << 8;
    }
}
/// Bundles a [`Pattern`] with the identifier it is known by, its span in the
/// source code, and a flag telling whether the pattern has been used.
pub(crate) struct PatternInRule<'src> {
/// Identifier associated with the pattern.
identifier: Ident<'src>,
/// The pattern itself.
pattern: Pattern,
/// Span associated with the pattern.
span: Span,
/// Becomes `true` once `mark_as_used` is called.
in_use: bool,
}
impl<'src> PatternInRule<'src> {
    /// Identifier of the pattern.
    #[inline]
    pub fn identifier(&self) -> &Ident<'src> {
        let Self { identifier, .. } = self;
        identifier
    }

    /// Consumes `self`, returning only the wrapped [`Pattern`].
    #[inline]
    pub fn into_pattern(self) -> Pattern {
        let Self { pattern, .. } = self;
        pattern
    }

    /// Shared access to the wrapped [`Pattern`].
    #[inline]
    pub fn pattern(&self) -> &Pattern {
        let Self { pattern, .. } = self;
        pattern
    }

    /// Exclusive access to the wrapped [`Pattern`].
    #[inline]
    pub fn pattern_mut(&mut self) -> &mut Pattern {
        let Self { pattern, .. } = self;
        pattern
    }

    /// Span associated with the pattern.
    #[inline]
    pub fn span(&self) -> &Span {
        let Self { span, .. } = self;
        span
    }

    /// Offset the pattern is anchored at, if any. Delegates to
    /// [`Pattern::anchored_at`].
    #[inline]
    pub fn anchored_at(&self) -> Option<usize> {
        self.pattern().anchored_at()
    }

    /// Whether [`PatternInRule::mark_as_used`] has been called.
    #[inline]
    pub fn in_use(&self) -> bool {
        let Self { in_use, .. } = self;
        *in_use
    }

    /// Delegates to [`Pattern::anchor_at`] and returns `self` for chaining.
    pub fn anchor_at(&mut self, offset: usize) -> &mut Self {
        self.pattern_mut().anchor_at(offset);
        self
    }

    /// Delegates to [`Pattern::make_non_anchorable`] and returns `self` for
    /// chaining.
    pub fn make_non_anchorable(&mut self) -> &mut Self {
        self.pattern_mut().make_non_anchorable();
        self
    }

    /// Flags the pattern as used and returns `self` for chaining.
    pub fn mark_as_used(&mut self) -> &mut Self {
        self.in_use = true;
        self
    }
}
/// A pattern in one of its three forms.
///
/// Note that `Regexp` and `Hex` carry the same payload type
/// ([`RegexpPattern`]); only the variant tells them apart.
#[derive(Clone, Eq, Hash, PartialEq)]
pub(crate) enum Pattern {
/// Pattern described by a [`LiteralPattern`].
Text(LiteralPattern),
/// Pattern described by a [`RegexpPattern`].
Regexp(RegexpPattern),
/// Hex pattern, also described by a [`RegexpPattern`].
Hex(RegexpPattern),
}
impl Pattern {
    /// Returns the flags of this pattern, whatever its variant.
    #[inline]
    pub fn flags(&self) -> &PatternFlags {
        match self {
            Pattern::Text(literal) => &literal.flags,
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => &regexp.flags,
        }
    }

    /// Returns a mutable reference to the flags of this pattern.
    #[inline]
    pub fn flags_mut(&mut self) -> &mut PatternFlags {
        match self {
            Pattern::Text(literal) => &mut literal.flags,
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => {
                &mut regexp.flags
            }
        }
    }

    /// Returns the offset this pattern is anchored at, or `None` when it is
    /// not anchored.
    #[inline]
    pub fn anchored_at(&self) -> Option<usize> {
        match self {
            Pattern::Text(literal) => literal.anchored_at,
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => {
                regexp.anchored_at
            }
        }
    }

    /// Requests that the pattern be anchored at `offset`.
    ///
    /// * If the pattern is already anchored at a *different* offset, the
    ///   anchor is removed and [`PatternFlags::NonAnchorable`] is set.
    /// * If the pattern is not anchored and is not flagged as
    ///   non-anchorable, it becomes anchored at `offset`.
    /// * Otherwise nothing changes.
    pub fn anchor_at(&mut self, offset: usize) {
        // Must be read before `anchored_at` mutably borrows a field of `self`.
        let is_anchorable =
            !self.flags().contains(PatternFlags::NonAnchorable);

        let anchored_at = match self {
            Pattern::Text(literal) => &mut literal.anchored_at,
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => {
                &mut regexp.anchored_at
            }
        };

        match anchored_at {
            // Two conflicting offsets: the pattern can't be anchored at all.
            Some(o) if *o != offset => {
                *anchored_at = None;
                self.flags_mut().insert(PatternFlags::NonAnchorable);
            }
            None => {
                if is_anchorable {
                    *anchored_at = Some(offset);
                }
            }
            // Already anchored at exactly `offset`.
            _ => {}
        }
    }

    /// Removes any existing anchor and sets
    /// [`PatternFlags::NonAnchorable`], so that later calls to
    /// [`Pattern::anchor_at`] leave the pattern unanchored.
    pub fn make_non_anchorable(&mut self) {
        match self {
            Pattern::Text(literal) => literal.anchored_at = None,
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => {
                regexp.anchored_at = None
            }
        };
        self.flags_mut().insert(PatternFlags::NonAnchorable);
    }

    /// Stores a copy of `bounds` as the file size bounds of this pattern.
    pub fn set_filesize_bounds(&mut self, bounds: &FilesizeBounds) {
        match self {
            Pattern::Text(literal) => {
                literal.filesize_bounds = bounds.clone();
            }
            Pattern::Regexp(regexp) | Pattern::Hex(regexp) => {
                regexp.filesize_bounds = bounds.clone();
            }
        }
    }
}
/// Payload of [`Pattern::Text`].
#[derive(Clone, Eq, Hash, PartialEq)]
pub(crate) struct LiteralPattern {
/// Flags associated with the pattern.
pub flags: PatternFlags,
/// The pattern's text, as a byte string.
pub text: BString,
/// Offset the pattern is anchored at, or `None` if it is not anchored.
pub anchored_at: Option<usize>,
// NOTE(review): presumably the key range used with `PatternFlags::Xor` — confirm.
pub xor_range: Option<RangeInclusive<u8>>,
// NOTE(review): presumably a custom alphabet used with `PatternFlags::Base64` — confirm.
pub base64_alphabet: Option<String>,
// NOTE(review): presumably a custom alphabet used with `PatternFlags::Base64Wide` — confirm.
pub base64wide_alphabet: Option<String>,
/// File size bounds, as stored by `Pattern::set_filesize_bounds`.
pub filesize_bounds: FilesizeBounds,
}
/// Payload shared by [`Pattern::Regexp`] and [`Pattern::Hex`].
#[derive(Clone, Eq, Hash, PartialEq)]
pub(crate) struct RegexpPattern {
/// Flags associated with the pattern.
pub flags: PatternFlags,
/// The pattern expressed as a `re::hir::Hir` value.
pub hir: re::hir::Hir,
/// Offset the pattern is anchored at, or `None` if it is not anchored.
pub anchored_at: Option<usize>,
/// File size bounds, as stored by `Pattern::set_filesize_bounds`.
pub filesize_bounds: FilesizeBounds,
}
/// Newtype around a `usize` index that identifies a pattern.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct PatternIdx(usize);

impl PatternIdx {
    /// Returns the wrapped index.
    #[inline]
    pub fn as_usize(&self) -> usize {
        let Self(index) = *self;
        index
    }
}

impl From<usize> for PatternIdx {
    /// Wraps a plain index.
    #[inline]
    fn from(value: usize) -> Self {
        PatternIdx(value)
    }
}

impl From<&PatternIdx> for i64 {
    /// Converts the index to `i64` with a plain `as` cast.
    #[inline]
    fn from(value: &PatternIdx) -> Self {
        value.as_usize() as i64
    }
}
/// Identifier of an expression: a `u32` index, where `u32::MAX` is reserved
/// as the "no expression" marker returned by [`ExprId::none`].
#[derive(Clone, Copy, PartialEq, Eq, Ord, Hash, PartialOrd)]
pub(crate) struct ExprId(u32);

impl ExprId {
    /// Returns the marker value that stands for "no expression".
    pub const fn none() -> Self {
        Self(u32::MAX)
    }
}

impl Debug for ExprId {
    /// Prints `None` for the marker value and the bare number otherwise.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            u32::MAX => write!(f, "None"),
            id => write!(f, "{}", id),
        }
    }
}

impl From<usize> for ExprId {
    /// Converts an index into an `ExprId`.
    ///
    /// NOTE(review): the `as u32` cast silently truncates values that do not
    /// fit in 32 bits, and `u32::MAX` collides with [`ExprId::none`].
    #[inline]
    fn from(value: usize) -> Self {
        let raw = value as u32;
        ExprId(raw)
    }
}
/// Errors returned by the fallible methods of [`IR`].
#[derive(Debug)]
pub(crate) enum Error {
// NOTE(review): presumably raised when folding constants yields a number
// that can't be represented — confirm against `fold_arithmetic`.
NumberOutOfRange,
}
/// Intermediate representation of an expression tree.
///
/// Nodes live in a flat vector and are referred to by [`ExprId`], which is
/// the node's index in that vector.
pub(crate) struct IR {
/// When `true`, the builder methods try to replace expressions whose value
/// is known with constant nodes.
constant_folding: bool,
/// Identifier of the root node, if one has been set.
root: Option<ExprId>,
/// All the expressions in the tree, indexed by `ExprId`.
nodes: Vec<Expr>,
/// Parent of each node, parallel to `nodes`. Holds `ExprId::none()` for
/// nodes without a parent.
parents: Vec<ExprId>,
}
impl IR {
/// Creates an empty tree: no nodes, no root, constant folding disabled.
pub fn new() -> Self {
    Self {
        constant_folding: false,
        root: None,
        nodes: vec![],
        parents: vec![],
    }
}
/// Enables (`yes == true`) or disables constant folding in the builder
/// methods. Returns `self` so calls can be chained.
pub fn constant_folding(&mut self, yes: bool) -> &mut Self {
self.constant_folding = yes;
self
}
/// Removes every node and every parent link from the tree.
///
/// NOTE(review): `root` is not reset here, so a previously stored root id
/// would refer to a node that no longer exists — confirm callers set a new
/// root before using it.
pub fn clear(&mut self) {
self.nodes.clear();
self.parents.clear();
}
/// Returns a shared reference to the expression identified by `expr_id`.
///
/// # Panics
///
/// Panics if `expr_id` does not refer to a node stored in this tree.
#[inline]
pub fn get(&self, expr_id: ExprId) -> &Expr {
    let index = expr_id.0 as usize;
    self.nodes.get(index).unwrap()
}
/// Returns an exclusive reference to the expression identified by
/// `expr_id`.
///
/// # Panics
///
/// Panics if `expr_id` does not refer to a node stored in this tree.
#[inline]
pub fn get_mut(&mut self, expr_id: ExprId) -> &mut Expr {
    let index = expr_id.0 as usize;
    self.nodes.get_mut(index).unwrap()
}
/// Stores `expr` in the slot identified by `expr_id` and returns the
/// expression that was there before. Parent links are left untouched.
///
/// # Panics
///
/// Panics if `expr_id` is out of bounds.
pub fn replace(&mut self, expr_id: ExprId, expr: Expr) -> Expr {
    let slot = &mut self.nodes[expr_id.0 as usize];
    mem::replace(slot, expr)
}
/// Returns the parent of `expr_id`, or `None` when the recorded parent is
/// the `ExprId::none()` marker.
///
/// # Panics
///
/// Panics if `expr_id` is out of bounds.
#[inline]
pub fn get_parent(&self, expr_id: ExprId) -> Option<ExprId> {
    match self.parents[expr_id.0 as usize] {
        parent if parent == ExprId::none() => None,
        parent => Some(parent),
    }
}
/// Records `parent_id` as the parent of `expr_id`.
///
/// Only the parent table is updated; the parent expression itself is not
/// modified.
///
/// # Panics
///
/// Panics if `expr_id` is out of bounds.
#[inline]
pub fn set_parent(&mut self, expr_id: ExprId, parent_id: ExprId) {
    let slot = &mut self.parents[expr_id.0 as usize];
    *slot = parent_id;
}
/// Appends an already built expression to the tree and returns its id.
///
/// The new node starts without a parent. Every node yielded by
/// `self.children(..)` for the new node gets the new node recorded as its
/// parent.
pub fn push(&mut self, expr: Expr) -> ExprId {
    let new_id = ExprId::from(self.nodes.len());
    self.parents.push(ExprId::none());
    self.nodes.push(expr);
    // Collected up front because `children` borrows `self` immutably while
    // the loop below needs to mutate `self.parents`.
    let children: Vec<ExprId> = self.children(new_id).collect();
    for child in children {
        self.parents[child.0 as usize] = new_id;
    }
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    new_id
}
/// Calls `Expr::shift_vars(from_index, shift_amount)` on every expression
/// entered during a depth-first traversal that starts at `expr_id`.
pub fn shift_vars(
    &mut self,
    expr_id: ExprId,
    from_index: i32,
    shift_amount: i32,
) {
    self.dfs_mut(expr_id, |event| {
        // Each node is handled once, when entering it; leaving is a no-op.
        if let Event::Enter((_, node, _)) = event {
            node.shift_vars(from_index, shift_amount)
        }
    });
}
/// Returns a [`DFSIter`] over this tree that starts at `start`.
pub fn dfs_iter(&self, start: ExprId) -> DFSIter<'_> {
DFSIter::new(start, self)
}
/// Returns a [`DFSWithScopeIter`] over this tree that starts at `start`.
pub fn dfs_with_scope(&self, start: ExprId) -> DFSWithScopeIter<'_> {
DFSWithScopeIter::new(start, self)
}
/// Traverses the tree starting at `start`, calling `f` with mutable access
/// to each visited expression.
///
/// `f` receives an `Event::Enter` when a node is first reached and an
/// `Event::Leave` for the same node later on. Each event carries the node's
/// id, a mutable reference to the node, and the `EventContext` it was
/// pushed with.
pub fn dfs_mut<F>(&mut self, start: ExprId, mut f: F)
where
F: FnMut(Event<(ExprId, &mut Expr, EventContext)>),
{
// Explicit stack of pending events, seeded with the starting node.
let mut stack = vec![Event::Enter((start, EventContext::None))];
while let Some(evt) = stack.pop() {
// Schedule the matching `Leave` before anything `dfs_common` pushes
// below, so it is popped only after those entries.
if let Event::Enter(expr) = evt {
stack.push(Event::Leave(expr));
}
// Turn the (id, context) pair into the richer event handed to `f`.
f(match &evt {
Event::Enter((expr_id, ctx)) => {
Event::Enter((*expr_id, self.get_mut(*expr_id), *ctx))
}
Event::Leave((expr_id, ctx)) => {
Event::Leave((*expr_id, self.get_mut(*expr_id), *ctx))
}
});
// The node is read again *after* `f` ran, so the entries pushed here
// reflect any change `f` made to it.
// NOTE(review): `dfs_common` presumably pushes the node's children —
// confirm in the `dfs` module.
if let Event::Enter((expr, _)) = evt {
dfs_common(&self.nodes[expr.0 as usize], &mut stack);
}
}
}
/// Searches the tree starting at `start` and returns the first entered
/// expression for which `predicate` returns `true`.
///
/// Whenever `prune_if` returns `true` for an entered expression (that did
/// not satisfy `predicate`), the iterator's `prune` method is called
/// before continuing. Returns `None` if the traversal ends without a
/// match.
pub fn dfs_find<P, C>(
    &self,
    start: ExprId,
    predicate: P,
    prune_if: C,
) -> Option<&Expr>
where
    P: Fn(&Expr) -> bool,
    C: Fn(&Expr) -> bool,
{
    let mut walker = self.dfs_iter(start);
    while let Some(event) = walker.next() {
        // Only `Enter` events are of interest.
        let Event::Enter((_, candidate, _)) = event else {
            continue;
        };
        if predicate(candidate) {
            return Some(candidate);
        }
        if prune_if(candidate) {
            walker.prune();
        }
    }
    None
}
/// Returns an [`Ancestors`] value for this tree positioned at `expr`.
pub fn ancestors(&self, expr: ExprId) -> Ancestors<'_> {
Ancestors { ir: self, current: expr }
}
/// Returns a [`Children`] value for `expr`, built on a DFS iterator that
/// starts at `expr`.
pub fn children(&self, expr: ExprId) -> Children<'_> {
let mut dfs = self.dfs_iter(expr);
// Consume the first event so the iterator is already past `expr` itself.
dfs.next();
Children { dfs }
}
pub fn compute_expr_hashes<F>(&self, start: ExprId, mut f: F)
where
F: FnMut(ExprId, u64),
{
let mut hashers = Vec::new();
let ignore = |expr: &Expr| {
matches!(expr, Expr::Const(_) | Expr::Filesize | Expr::Symbol(_))
};
for evt in self.dfs_iter(start) {
match evt {
Event::Enter((_, expr, _)) => {
if !ignore(expr) {
hashers.push(FxHasher::default());
}
for h in hashers.iter_mut() {
expr.hash(h);
}
}
Event::Leave((expr_id, expr, _)) => {
if !ignore(expr) {
let hasher = hashers.pop().unwrap();
f(expr_id, hasher.finish());
}
}
}
}
}
/// Finds expressions that can be moved out of a loop.
///
/// Returns `(expr, loop)` pairs, where `expr` is the expression to move and
/// `loop` is the scope it was paired with.
///
/// # Panics
///
/// Panics if the tree has no root.
pub fn find_hoisting_candidates(&self) -> Vec<(ExprId, ExprId)> {
// For each node: the statement that declares the variable the node depends
// on, together with that variable's index. `None` means no dependency found.
let mut depends_on = vec![None; self.nodes.len()];
let mut result = Vec::new();
// First pass: find every use of a variable and propagate the dependency to
// the ancestors of the expression using it.
let mut dfs = self.dfs_with_scope(self.root.unwrap());
while let Some(event) = dfs.next() {
let current_expr_id = match event {
Event::Enter((expr_id, _)) => expr_id,
Event::Leave(_) => continue,
};
// Only `Symbol` expressions that refer to a variable matter here.
let symbol = match self.get(current_expr_id) {
Expr::Symbol(symbol) => symbol,
_ => continue,
};
let var = match symbol.as_ref() {
Symbol::Var { var, .. } => var.index(),
_ => continue,
};
// Look among the current scopes for the `with` or `for .. in` statement
// that declares the variable.
let stmt_declaring_var =
match dfs.scopes().find(|expr_id| match self.get(*expr_id) {
Expr::With(with) => {
with.declarations.iter().any(|(v, _)| v.index() == var)
}
Expr::ForIn(for_in) => {
for_in.variables.iter().any(|v| v.index() == var)
}
_ => false,
}) {
Some(stmt) => stmt,
None => continue,
};
// Mark every ancestor up to (but excluding) the declaring statement.
for ancestor in self
.ancestors(current_expr_id)
.take_while(|ancestor| ancestor.ne(&stmt_declaring_var))
{
match &mut depends_on[ancestor.0 as usize] {
// Keep the existing entry when its variable index is not lower.
Some((_, v)) if *v >= var => {}
entry => *entry = Some((stmt_declaring_var, var)),
}
}
}
// Second pass: decide, when leaving each node, whether it is a candidate.
let mut dfs = self.dfs_with_scope(self.root.unwrap());
while let Some(event) = dfs.next() {
let (current_expr_id, ctx) = match event {
Event::Enter(_) => continue,
Event::Leave((expr_id, ctx)) => (expr_id, ctx),
};
match self.get(current_expr_id) {
// Constants, `filesize` and symbols are never candidates.
Expr::Const(_) | Expr::Filesize | Expr::Symbol(_) => {}
// Nodes visited in a field-access context are skipped as well.
_ if !matches!(ctx, EventContext::FieldAccess) => {
let current_depends_on =
depends_on[current_expr_id.0 as usize];
let parent_depends_on = self
.get_parent(current_expr_id)
.and_then(|e| depends_on[e.0 as usize]);
// When the parent depends on the same statement, this node is skipped
// (presumably so that the larger enclosing expression is moved — confirm).
match (current_depends_on, parent_depends_on) {
(Some((c, _)), Some((p, _))) if c == p => continue,
_ => {}
}
match current_depends_on {
// No dependency: pair the node with the first scope reported by
// `for_scopes` (named `outermost` here).
None => {
if let Some(outermost) = dfs.for_scopes().next() {
result.push((current_expr_id, outermost));
}
}
Some((defining_expr, _)) => {
// Only dependencies on a `for .. in` statement are handled.
match self.get(defining_expr) {
Expr::ForIn(_) => {}
_ => continue,
}
// Advance past the defining loop; the scope that follows it, if
// any, is the one the node is paired with.
let mut scopes = dfs.for_scopes();
for expr_id in scopes.by_ref() {
if expr_id == defining_expr {
break;
}
}
if let Some(inner_loop) = scopes.next() {
result.push((current_expr_id, inner_loop));
}
// Early exit once more than 100 candidates were collected. Note that
// this cap is only checked in this branch.
if result.len() > 100 {
return result;
}
}
}
}
_ => {}
}
}
result
}
/// Applies every candidate returned by `find_hoisting_candidates`, and
/// returns the id of the (possibly new) root.
///
/// For each `(expr, loop)` pair, `expr` is replaced in place by a variable,
/// and the loop is wrapped in a `with` statement that initializes that
/// variable with the original expression.
///
/// # Panics
///
/// Panics if the tree has no root.
pub fn hoisting(&mut self) -> ExprId {
for (expr_id, loop_expr_id) in self.find_hoisting_candidates() {
// Must be read before the loop is re-parented by `self.with` below.
let loop_parent = self.get_parent(loop_expr_id);
// Index of the new variable: the sum of the stack frame sizes of all
// the ancestors of the loop.
let var_index = self
.ancestors(loop_expr_id)
.map(|expr_id| self.get(expr_id).stack_frame_size())
.sum::<i32>();
// Shift by one the variables inside the loop starting at `var_index`
// (presumably to make room for the new variable — confirm).
self.shift_vars(loop_expr_id, var_index, 1);
let type_value = self.get(expr_id).type_value();
let var = Var::new(0, type_value.ty(), var_index);
// Put a reference to the new variable where the expression was...
let replaced = self.replace(
expr_id,
Expr::Symbol(Box::new(Symbol::Var {
var,
type_value: type_value.clone(),
})),
);
// ...and re-add the original expression as the variable's initializer.
let var_init_stmt = self.push(replaced);
let with_stmt =
self.with(vec![(var, var_init_stmt)], loop_expr_id);
// The `with` statement takes the place the loop had in the tree.
if let Some(loop_parent) = loop_parent {
self.set_parent(with_stmt, loop_parent);
self.get_mut(loop_parent)
.replace_child(loop_expr_id, with_stmt);
} else {
self.root = Some(with_stmt);
}
}
self.root.unwrap()
}
/// Derives bounds for the file size from the comparisons found in the
/// tree.
///
/// Only `>`, `>=`, `<` and `<=` comparisons between `Expr::Filesize` and a
/// constant are taken into account, in either operand order. The traversal
/// only continues below `Expr::And` nodes; `prune` is called for any other
/// expression.
///
/// # Panics
///
/// Panics if the tree has no root.
pub fn filesize_bounds(&self) -> FilesizeBounds {
    let mut bounds = FilesizeBounds::default();
    let mut walker = self.dfs_iter(self.root.unwrap());
    while let Some(event) = walker.next() {
        let Event::Enter((_, expr, _)) = event else {
            continue;
        };
        match expr {
            // `const > filesize` caps the end, `filesize > const` raises
            // the start. Both exclusive.
            Expr::Gt { lhs, rhs } => match (self.get(*lhs), self.get(*rhs)) {
                (Expr::Const(c), Expr::Filesize) => {
                    bounds.min_end(Bound::Excluded(c.as_integer()));
                }
                (Expr::Filesize, Expr::Const(c)) => {
                    bounds.max_start(Bound::Excluded(c.as_integer()));
                }
                _ => {}
            },
            // Same as `>`, but inclusive.
            Expr::Ge { lhs, rhs } => match (self.get(*lhs), self.get(*rhs)) {
                (Expr::Const(c), Expr::Filesize) => {
                    bounds.min_end(Bound::Included(c.as_integer()));
                }
                (Expr::Filesize, Expr::Const(c)) => {
                    bounds.max_start(Bound::Included(c.as_integer()));
                }
                _ => {}
            },
            // `const < filesize` raises the start, `filesize < const` caps
            // the end. Both exclusive.
            Expr::Lt { lhs, rhs } => match (self.get(*lhs), self.get(*rhs)) {
                (Expr::Const(c), Expr::Filesize) => {
                    bounds.max_start(Bound::Excluded(c.as_integer()));
                }
                (Expr::Filesize, Expr::Const(c)) => {
                    bounds.min_end(Bound::Excluded(c.as_integer()));
                }
                _ => {}
            },
            // Same as `<`, but inclusive.
            Expr::Le { lhs, rhs } => match (self.get(*lhs), self.get(*rhs)) {
                (Expr::Const(c), Expr::Filesize) => {
                    bounds.max_start(Bound::Included(c.as_integer()));
                }
                (Expr::Filesize, Expr::Const(c)) => {
                    bounds.min_end(Bound::Included(c.as_integer()));
                }
                _ => {}
            },
            _ => {}
        }
        // Anything that is not an `and` is not explored further.
        let is_and = matches!(expr, Expr::And { .. });
        if !is_and {
            walker.prune();
        }
    }
    bounds
}
}
impl IR {
/// Appends an [`Expr::Filesize`] node without parent and returns its id.
pub fn filesize(&mut self) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    self.nodes.push(Expr::Filesize);
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Const`] node holding `type_value` and returns its
/// id. The new node has no parent.
pub fn constant(&mut self, type_value: TypeValue) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    self.nodes.push(Expr::Const(type_value));
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends a node for `symbol` and returns its id.
///
/// With constant folding enabled, a symbol whose type value is constant
/// produces an [`Expr::Const`] node instead of an [`Expr::Symbol`] node.
pub fn ident(&mut self, symbol: Symbol) -> ExprId {
    if self.constant_folding {
        let known = symbol.type_value();
        if known.is_const() {
            let folded = known.clone();
            return self.constant(folded);
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    self.nodes.push(Expr::Symbol(Box::new(symbol)));
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Lookup`] node with `primary` and `index` as its
/// children, and returns its id.
///
/// # Panics
///
/// Panics if `primary` or `index` is out of bounds.
pub fn lookup(
    &mut self,
    type_value: TypeValue,
    primary: ExprId,
    index: ExprId,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [primary, index] {
        self.parents[child.0 as usize] = node_id;
    }
    let lookup = Lookup { type_value, primary, index };
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Lookup(Box::new(lookup)));
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Not`] node for `operand` and returns its id.
///
/// With constant folding enabled, an operand that is a constant boolean
/// yields a constant node with the negated value instead.
pub fn not(&mut self, operand: ExprId) -> ExprId {
    if self.constant_folding {
        if let Some(value) = self.get(operand).try_as_const_bool() {
            return self.constant(TypeValue::const_bool_from(!value));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    self.parents[operand.0 as usize] = node_id;
    self.nodes.push(Expr::Not { operand });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::And`] node with the given operands and returns its
/// id.
///
/// With constant folding enabled:
/// * operands known to be `true` are dropped, as they can't change the
///   result;
/// * if no operand remains the result is the constant `true`;
/// * if any remaining operand is constant the result is the constant
///   `false`.
pub fn and(&mut self, mut operands: Vec<ExprId>) -> Result<ExprId, Error> {
    if self.constant_folding {
        operands.retain(|op| {
            let as_bool = self.get(*op).type_value().cast_to_bool();
            !(as_bool.is_const() && as_bool.as_bool())
        });
        if operands.is_empty() {
            return Ok(self.constant(TypeValue::const_bool_from(true)));
        }
        let has_const = operands
            .iter()
            .any(|op| self.get(*op).type_value().is_const());
        if has_const {
            return Ok(self.constant(TypeValue::const_bool_from(false)));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::And { operands });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Or`] node with the given operands and returns its
/// id.
///
/// With constant folding enabled:
/// * operands known to be `false` are dropped, as they can't change the
///   result;
/// * if no operand remains the result is the constant `false`;
/// * if any remaining operand is constant the result is the constant
///   `true`.
pub fn or(&mut self, mut operands: Vec<ExprId>) -> Result<ExprId, Error> {
    if self.constant_folding {
        operands.retain(|op| {
            let as_bool = self.get(*op).type_value().cast_to_bool();
            !(as_bool.is_const() && !as_bool.as_bool())
        });
        if operands.is_empty() {
            return Ok(self.constant(TypeValue::const_bool_from(false)));
        }
        let has_const = operands
            .iter()
            .any(|op| self.get(*op).type_value().is_const());
        if has_const {
            return Ok(self.constant(TypeValue::const_bool_from(true)));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Or { operands });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Minus`] node for `operand` and returns its id.
///
/// With constant folding enabled, a constant integer or float operand
/// yields a constant node with the negated value instead. An integer whose
/// negation overflows (`i64::MIN`) is not folded; a regular `Minus` node
/// is created for it, rather than overflowing while compiling.
///
/// # Panics
///
/// Panics if `operand` is out of bounds.
pub fn minus(&mut self, operand: ExprId) -> ExprId {
    if self.constant_folding {
        match self.get(operand).type_value() {
            TypeValue::Integer { value: Const(v), .. } => {
                // A plain `-v` overflows for `i64::MIN`, which panics in
                // builds with overflow checks enabled.
                if let Some(negated) = v.checked_neg() {
                    return self
                        .constant(TypeValue::const_integer_from(negated));
                }
            }
            TypeValue::Float { value: Const(v), .. } => {
                return self.constant(TypeValue::const_float_from(-v));
            }
            _ => {}
        }
    }
    let is_float = matches!(self.get(operand).ty(), Type::Float);
    let expr_id = ExprId::from(self.nodes.len());
    self.parents[operand.0 as usize] = expr_id;
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Minus { operand, is_float });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    expr_id
}
/// Appends an [`Expr::Defined`] node for `operand` and returns its id.
///
/// # Panics
///
/// Panics if `operand` is out of bounds.
pub fn defined(&mut self, operand: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    self.parents[operand.0 as usize] = node_id;
    self.nodes.push(Expr::Defined { operand });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::BitwiseNot`] node for `operand` and returns its id.
///
/// # Panics
///
/// Panics if `operand` is out of bounds.
pub fn bitwise_not(&mut self, operand: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    self.parents[operand.0 as usize] = node_id;
    self.nodes.push(Expr::BitwiseNot { operand });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::BitwiseAnd`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn bitwise_and(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::BitwiseAnd { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::BitwiseOr`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn bitwise_or(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::BitwiseOr { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::BitwiseXor`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn bitwise_xor(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::BitwiseXor { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Shl`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn shl(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Shl { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Shr`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn shr(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Shr { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Add`] node with the given operands and returns its
/// id.
///
/// The node is flagged as a float operation when at least one operand has
/// type `Type::Float`. With constant folding enabled, the operands are
/// first handed to `fold_arithmetic`; if it produces a value, a constant
/// node is returned instead.
///
/// # Errors
///
/// Propagates any error returned by `fold_arithmetic`.
pub fn add(&mut self, operands: Vec<ExprId>) -> Result<ExprId, Error> {
    let is_float = operands
        .iter()
        .map(|op| self.get(*op).ty())
        .any(|ty| matches!(ty, Type::Float));
    if self.constant_folding {
        let folded = self.fold_arithmetic(
            operands.as_slice(),
            is_float,
            |acc, x| acc + x,
        )?;
        if let Some(value) = folded {
            return Ok(self.constant(value));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Add { operands, is_float });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Sub`] node with the given operands and returns its
/// id.
///
/// The node is flagged as a float operation when at least one operand has
/// type `Type::Float`. With constant folding enabled, the operands are
/// first handed to `fold_arithmetic`; if it produces a value, a constant
/// node is returned instead.
///
/// # Errors
///
/// Propagates any error returned by `fold_arithmetic`.
pub fn sub(&mut self, operands: Vec<ExprId>) -> Result<ExprId, Error> {
    let is_float = operands
        .iter()
        .map(|op| self.get(*op).ty())
        .any(|ty| matches!(ty, Type::Float));
    if self.constant_folding {
        let folded = self.fold_arithmetic(
            operands.as_slice(),
            is_float,
            |acc, x| acc - x,
        )?;
        if let Some(value) = folded {
            return Ok(self.constant(value));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Sub { operands, is_float });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Mul`] node with the given operands and returns its
/// id.
///
/// The node is flagged as a float operation when at least one operand has
/// type `Type::Float`. With constant folding enabled, the operands are
/// first handed to `fold_arithmetic`; if it produces a value, a constant
/// node is returned instead.
///
/// # Errors
///
/// Propagates any error returned by `fold_arithmetic`.
pub fn mul(&mut self, operands: Vec<ExprId>) -> Result<ExprId, Error> {
    let is_float = operands
        .iter()
        .map(|op| self.get(*op).ty())
        .any(|ty| matches!(ty, Type::Float));
    if self.constant_folding {
        let folded = self.fold_arithmetic(
            operands.as_slice(),
            is_float,
            |acc, x| acc * x,
        )?;
        if let Some(value) = folded {
            return Ok(self.constant(value));
        }
    }
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::Mul { operands, is_float });
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Div`] node with the given operands and returns its
/// id.
///
/// The node is flagged as a float operation when at least one operand has
/// type `Type::Float`. Unlike `add`, `sub` and `mul`, no constant folding
/// is attempted here, so this method always returns `Ok`.
pub fn div(&mut self, operands: Vec<ExprId>) -> Result<ExprId, Error> {
    let is_float = operands
        .iter()
        .map(|op| self.get(*op).ty())
        .any(|ty| matches!(ty, Type::Float));
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Div { operands, is_float });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::Mod`] node with the given operands and returns its
/// id. No constant folding is attempted, so this method always returns
/// `Ok`.
pub fn modulus(&mut self, operands: Vec<ExprId>) -> Result<ExprId, Error> {
    let node_id = ExprId::from(self.nodes.len());
    for child in &operands {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Mod { operands });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    Ok(node_id)
}
/// Appends an [`Expr::FieldAccess`] node with the given operands and
/// returns its id.
///
/// The type and value of the node are those of the last operand. With
/// constant folding enabled, if that value is constant a constant node is
/// returned instead.
///
/// # Panics
///
/// Panics if `operands` is empty or contains an out-of-bounds id.
pub fn field_access(&mut self, operands: Vec<ExprId>) -> ExprId {
    let last_operand = *operands
        .last()
        .expect("field access requires at least one operand");
    let type_value = self.get(last_operand).type_value();
    if self.constant_folding && type_value.is_const() {
        // `type_value` is owned and not needed afterwards, so it is moved
        // instead of cloned.
        return self.constant(type_value);
    }
    let expr_id = ExprId::from(self.nodes.len());
    for operand in operands.iter() {
        self.parents[operand.0 as usize] = expr_id;
    }
    self.parents.push(ExprId::none());
    self.nodes.push(Expr::FieldAccess(Box::new(FieldAccess {
        operands,
        type_value,
    })));
    // Same invariant check every other builder method performs.
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    expr_id
}
/// Appends an [`Expr::Eq`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn eq(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Eq { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Ne`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn ne(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Ne { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Ge`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn ge(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Ge { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Gt`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn gt(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Gt { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Le`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn le(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Le { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Lt`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn lt(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Lt { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Contains`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn contains(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Contains { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::IContains`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn icontains(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::IContains { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::StartsWith`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn starts_with(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::StartsWith { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::IStartsWith`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn istarts_with(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::IStartsWith { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::EndsWith`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn ends_with(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::EndsWith { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::IEndsWith`] node with `lhs` and `rhs` as children
/// and returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn iends_with(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::IEndsWith { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::IEquals`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn iequals(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::IEquals { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::Matches`] node with `lhs` and `rhs` as children and
/// returns its id.
///
/// # Panics
///
/// Panics if `lhs` or `rhs` is out of bounds.
pub fn matches(&mut self, lhs: ExprId, rhs: ExprId) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for child in [lhs, rhs] {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::Matches { lhs, rhs });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternMatch`] node for `pattern` and returns its
/// id.
///
/// The expressions referenced by `anchor` (the offset of an `At` anchor,
/// or both bounds of an `In` anchor) become children of the new node.
pub fn pattern_match(
    &mut self,
    pattern: PatternIdx,
    anchor: MatchAnchor,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    match &anchor {
        MatchAnchor::At(offset) => {
            self.parents[offset.0 as usize] = node_id;
        }
        MatchAnchor::In(range) => {
            for bound in [range.lower_bound, range.upper_bound] {
                self.parents[bound.0 as usize] = node_id;
            }
        }
        MatchAnchor::None => {}
    }
    self.nodes.push(Expr::PatternMatch { pattern, anchor });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternMatchVar`] node for `symbol` and returns its
/// id.
///
/// The expressions referenced by `anchor` (the offset of an `At` anchor,
/// or both bounds of an `In` anchor) become children of the new node.
pub fn pattern_match_var(
    &mut self,
    symbol: Symbol,
    anchor: MatchAnchor,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    match &anchor {
        MatchAnchor::At(offset) => {
            self.parents[offset.0 as usize] = node_id;
        }
        MatchAnchor::In(range) => {
            for bound in [range.lower_bound, range.upper_bound] {
                self.parents[bound.0 as usize] = node_id;
            }
        }
        MatchAnchor::None => {}
    }
    let symbol = Box::new(symbol);
    self.nodes.push(Expr::PatternMatchVar { symbol, anchor });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternLength`] node for `pattern` and returns its
/// id. When `index` is present it becomes a child of the new node.
pub fn pattern_length(
    &mut self,
    pattern: PatternIdx,
    index: Option<ExprId>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(child) = index {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::PatternLength { pattern, index });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternLengthVar`] node for `symbol` and returns its
/// id. When `index` is present it becomes a child of the new node.
pub fn pattern_length_var(
    &mut self,
    symbol: Symbol,
    index: Option<ExprId>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(child) = index {
        self.parents[child.0 as usize] = node_id;
    }
    let symbol = Box::new(symbol);
    self.nodes.push(Expr::PatternLengthVar { symbol, index });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternOffset`] node for `pattern` and returns its
/// id. When `index` is present it becomes a child of the new node.
pub fn pattern_offset(
    &mut self,
    pattern: PatternIdx,
    index: Option<ExprId>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(child) = index {
        self.parents[child.0 as usize] = node_id;
    }
    self.nodes.push(Expr::PatternOffset { pattern, index });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternOffsetVar`] node for `symbol` and returns its
/// id. When `index` is present it becomes a child of the new node.
pub fn pattern_offset_var(
    &mut self,
    symbol: Symbol,
    index: Option<ExprId>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(child) = index {
        self.parents[child.0 as usize] = node_id;
    }
    let symbol = Box::new(symbol);
    self.nodes.push(Expr::PatternOffsetVar { symbol, index });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternCount`] node for `pattern` and returns its
/// id. When `range` is present, both of its bounds become children of the
/// new node.
pub fn pattern_count(
    &mut self,
    pattern: PatternIdx,
    range: Option<Range>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(bounds) = &range {
        for bound in [bounds.lower_bound, bounds.upper_bound] {
            self.parents[bound.0 as usize] = node_id;
        }
    }
    self.nodes.push(Expr::PatternCount { pattern, range });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::PatternCountVar`] node for `symbol` and returns its
/// id. When `range` is present, both of its bounds become children of the
/// new node.
pub fn pattern_count_var(
    &mut self,
    symbol: Symbol,
    range: Option<Range>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Some(bounds) = &range {
        for bound in [bounds.lower_bound, bounds.upper_bound] {
            self.parents[bound.0 as usize] = node_id;
        }
    }
    let symbol = Box::new(symbol);
    self.nodes.push(Expr::PatternCountVar { symbol, range });
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::FuncCall`] node and returns its id.
///
/// Every argument, and `object` when present, becomes a child of the new
/// node.
pub fn func_call(
    &mut self,
    object: Option<ExprId>,
    args: Vec<ExprId>,
    signature: Rc<FuncSignature>,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    for arg in &args {
        self.parents[arg.0 as usize] = node_id;
    }
    if let Some(obj) = object {
        self.parents[obj.0 as usize] = node_id;
    }
    let call = FuncCall { object, args, signature };
    self.nodes.push(Expr::FuncCall(Box::new(call)));
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::OfExprTuple`] node and returns its id.
///
/// The children of the new node are: the expression carried by the
/// quantifier (for the `Percentage` and `Expr` variants), every item, and
/// the expressions referenced by `anchor`.
pub fn of_expr_tuple(
    &mut self,
    quantifier: Quantifier,
    for_vars: ForVars,
    next_expr_var: Var,
    items: Vec<ExprId>,
    anchor: MatchAnchor,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Quantifier::Percentage(expr) | Quantifier::Expr(expr) = quantifier
    {
        self.parents[expr.0 as usize] = node_id;
    }
    for item in &items {
        self.parents[item.0 as usize] = node_id;
    }
    match &anchor {
        MatchAnchor::At(offset) => {
            self.parents[offset.0 as usize] = node_id;
        }
        MatchAnchor::In(range) => {
            for bound in [range.lower_bound, range.upper_bound] {
                self.parents[bound.0 as usize] = node_id;
            }
        }
        MatchAnchor::None => {}
    }
    let of_expr = OfExprTuple {
        quantifier,
        items,
        anchor,
        for_vars,
        next_expr_var,
    };
    self.nodes.push(Expr::OfExprTuple(Box::new(of_expr)));
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::OfPatternSet`] node and returns its id.
///
/// The children of the new node are the expression carried by the
/// quantifier (for the `Percentage` and `Expr` variants) and the
/// expressions referenced by `anchor`. The items are pattern indexes, not
/// expressions, so they get no parent link.
pub fn of_pattern_set(
    &mut self,
    quantifier: Quantifier,
    for_vars: ForVars,
    next_pattern_var: Var,
    items: Vec<PatternIdx>,
    anchor: MatchAnchor,
) -> ExprId {
    let node_id = ExprId::from(self.nodes.len());
    if let Quantifier::Percentage(expr) | Quantifier::Expr(expr) = quantifier
    {
        self.parents[expr.0 as usize] = node_id;
    }
    match &anchor {
        MatchAnchor::At(offset) => {
            self.parents[offset.0 as usize] = node_id;
        }
        MatchAnchor::In(range) => {
            for bound in [range.lower_bound, range.upper_bound] {
                self.parents[bound.0 as usize] = node_id;
            }
        }
        MatchAnchor::None => {}
    }
    let of_patterns = OfPatternSet {
        quantifier,
        items,
        anchor,
        for_vars,
        next_pattern_var,
    };
    self.nodes.push(Expr::OfPatternSet(Box::new(of_patterns)));
    self.parents.push(ExprId::none());
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    node_id
}
/// Appends an [`Expr::ForOf`] node to the IR and returns its id.
///
/// The expression carried by the quantifier (if any) and the `body`
/// expression get the new node as their parent.
pub fn for_of(
    &mut self,
    quantifier: Quantifier,
    variable: Var,
    for_vars: ForVars,
    pattern_set: Vec<PatternIdx>,
    body: ExprId,
) -> ExprId {
    let new_id = ExprId::from(self.nodes.len());
    if let Quantifier::Percentage(expr) | Quantifier::Expr(expr) =
        &quantifier
    {
        self.parents[expr.0 as usize] = new_id;
    }
    self.parents[body.0 as usize] = new_id;
    // The node itself starts without a parent.
    self.parents.push(ExprId::none());
    let node = ForOf { quantifier, variable, pattern_set, body, for_vars };
    self.nodes.push(Expr::ForOf(Box::new(node)));
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    new_id
}
/// Appends an [`Expr::ForIn`] node to the IR and returns its id.
///
/// The expression carried by the quantifier (if any), every expression
/// referenced by `iterable`, and the `body` expression get the new node as
/// their parent.
pub fn for_in(
    &mut self,
    quantifier: Quantifier,
    variables: Vec<Var>,
    for_vars: ForVars,
    iterable_var: Var,
    iterable: Iterable,
    body: ExprId,
) -> ExprId {
    let new_id = ExprId::from(self.nodes.len());
    // Records the new node as the parent of `child`.
    let mut adopt = |child: ExprId| {
        self.parents[child.0 as usize] = new_id;
    };
    if let Quantifier::Percentage(expr) | Quantifier::Expr(expr) =
        &quantifier
    {
        adopt(*expr);
    }
    match &iterable {
        Iterable::Range(range) => {
            adopt(range.lower_bound);
            adopt(range.upper_bound);
        }
        Iterable::ExprTuple(exprs) => {
            for expr in exprs {
                adopt(*expr);
            }
        }
        Iterable::Expr(expr) => adopt(*expr),
    }
    adopt(body);
    // The node itself starts without a parent.
    self.parents.push(ExprId::none());
    let node = ForIn {
        quantifier,
        variables,
        for_vars,
        iterable_var,
        iterable,
        body,
    };
    self.nodes.push(Expr::ForIn(Box::new(node)));
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    new_id
}
/// Appends an [`Expr::With`] node to the IR and returns its id.
///
/// The node's type/value is the one reported by `body`. The expression of
/// every declaration and the `body` itself get the new node as their
/// parent.
pub fn with(
    &mut self,
    declarations: Vec<(Var, ExprId)>,
    body: ExprId,
) -> ExprId {
    // Taken from the body before anything is modified.
    let type_value = self.get(body).type_value();
    let new_id = ExprId::from(self.nodes.len());
    let declared = declarations.iter().map(|(_, expr)| *expr);
    for child in declared.chain(std::iter::once(body)) {
        self.parents[child.0 as usize] = new_id;
    }
    // The node itself starts without a parent.
    self.parents.push(ExprId::none());
    let node = With { type_value, declarations, body };
    self.nodes.push(Expr::With(Box::new(node)));
    debug_assert_eq!(self.parents.len(), self.nodes.len());
    new_id
}
}
impl IR {
    /// Tries to fold an arithmetic operation over constant operands.
    ///
    /// Returns `Ok(None)` when at least one operand is not constant.
    /// Otherwise every operand is converted to `f64` and the values are
    /// combined left to right with `f`. With `is_float` the result is
    /// returned as a constant float; without it the result is converted to
    /// a constant integer, or `Error::NumberOutOfRange` is returned when it
    /// falls outside the `i64::MIN as f64 ..= i64::MAX as f64` interval.
    ///
    /// # Panics
    ///
    /// Panics if `operands` is empty, or if a constant operand is neither an
    /// integer nor a float.
    fn fold_arithmetic<F>(
        &mut self,
        operands: &[ExprId],
        is_float: bool,
        f: F,
    ) -> Result<Option<TypeValue>, Error>
    where
        F: FnMut(f64, f64) -> f64,
    {
        debug_assert!(!operands.is_empty());

        // Nothing can be folded unless every operand is a constant.
        let has_non_const =
            operands.iter().any(|id| !self.get(*id).type_value().is_const());
        if has_non_const {
            return Ok(None);
        }

        let mut values =
            operands.iter().map(|id| match self.get(*id).type_value() {
                TypeValue::Integer { value: Const(int), .. } => int as f64,
                TypeValue::Float { value: Const(float) } => float,
                _ => unreachable!(),
            });

        // The first value seeds the accumulator, the rest are folded in.
        let first = values.next().unwrap();
        let folded = values.fold(first, f);

        if is_float {
            return Ok(Some(TypeValue::const_float_from(folded)));
        }

        if (i64::MIN as f64..=i64::MAX as f64).contains(&folded) {
            Ok(Some(TypeValue::const_integer_from(folded as i64)))
        } else {
            Err(Error::NumberOutOfRange)
        }
    }
}
/// Textual dump of the IR tree, one line per expression.
///
/// Each line shows the expression id, a label for the expression kind, the
/// expression hash (for every kind except constants, `filesize` and plain
/// symbols) and the id of the parent expression.
impl Debug for IR {
/// Writes the tree rooted at `self.root`.
///
/// # Panics
///
/// Panics if `self.root` is `None`.
#[rustfmt::skip]
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
// Nesting depth of the expression being printed; one `" "` is written per level.
let mut level = 1;
// Suffix that tells whether a pattern match is anchored (`AT`/`IN`).
let anchor_str = |anchor: &MatchAnchor| match anchor {
MatchAnchor::None => "",
MatchAnchor::At(_) => " AT",
MatchAnchor::In(_) => " IN",
};
// Suffix used when a pattern count is restricted to a range.
let range_str = |range: &Option<_>| {
if range.is_some() { " IN" } else { "" }
};
// Suffix used when a pattern offset/length has an explicit index.
let index_str = |index: &Option<_>| {
if index.is_some() { " INDEX" } else { "" }
};
// One hash slot per node, indexed by expression id and filled by the
// callback passed to `compute_expr_hashes`.
let mut expr_hashes = vec![0; self.nodes.len()];
self.compute_expr_hashes(self.root.unwrap(), |expr_id, hash| {
expr_hashes[expr_id.0 as usize] = hash;
});
for event in self.dfs_iter(self.root.unwrap()) {
match event {
// Leaving an expression takes us one level up.
Event::Leave(_) => level -= 1,
Event::Enter((expr_id, expr,_)) => {
for _ in 0..level {
write!(f, " ")?;
}
level += 1;
write!(f, "{expr_id:?}: ")?;
let expr_hash = expr_hashes[expr_id.0 as usize];
match expr {
Expr::Const(c) => write!(f, "CONST {c}")?,
Expr::Filesize => write!(f, "FILESIZE")?,
Expr::Not { .. } => write!(f, "NOT -- hash: {expr_hash:#08x}")?,
Expr::And { .. } => write!(f, "AND -- hash: {expr_hash:#08x}")?,
Expr::Or { .. } => write!(f, "OR -- hash: {expr_hash:#08x}")?,
Expr::Minus { .. } => write!(f, "MINUS -- hash: {expr_hash:#08x}")?,
Expr::Add { .. } => write!(f, "ADD -- hash: {expr_hash:#08x}")?,
Expr::Sub { .. } => write!(f, "SUB -- hash: {expr_hash:#08x}")?,
Expr::Mul { .. } => write!(f, "MUL -- hash: {expr_hash:#08x}")?,
Expr::Div { .. } => write!(f, "DIV -- hash: {expr_hash:#08x}")?,
Expr::Mod { .. } => write!(f, "MOD -- hash: {expr_hash:#08x}")?,
Expr::Shl { .. } => write!(f, "SHL -- hash: {expr_hash:#08x}")?,
Expr::Shr { .. } => write!(f, "SHR -- hash: {expr_hash:#08x}")?,
Expr::Eq { .. } => write!(f, "EQ -- hash: {expr_hash:#08x}")?,
Expr::Ne { .. } => write!(f, "NE -- hash: {expr_hash:#08x}")?,
Expr::Lt { .. } => write!(f, "LT -- hash: {expr_hash:#08x}")?,
Expr::Gt { .. } => write!(f, "GT -- hash: {expr_hash:#08x}")?,
Expr::Le { .. } => write!(f, "LE -- hash: {expr_hash:#08x}")?,
Expr::Ge { .. } => write!(f, "GE -- hash: {expr_hash:#08x}")?,
Expr::BitwiseNot { .. } => write!(f, "BITWISE_NOT -- hash: {expr_hash:#08x}")?,
Expr::BitwiseAnd { .. } => write!(f, "BITWISE_AND -- hash: {expr_hash:#08x}")?,
Expr::BitwiseOr { .. } => write!(f, "BITWISE_OR -- hash: {expr_hash:#08x}")?,
Expr::BitwiseXor { .. } => write!(f, "BITWISE_XOR -- hash: {expr_hash:#08x}")?,
Expr::Contains { .. } => write!(f, "CONTAINS -- hash: {expr_hash:#08x}")?,
Expr::IContains { .. } => write!(f, "ICONTAINS -- hash: {expr_hash:#08x}")?,
Expr::StartsWith { .. } => write!(f, "STARTS_WITH -- hash: {expr_hash:#08x}")?,
Expr::IStartsWith { .. } => write!(f, "ISTARTS_WITH -- hash: {expr_hash:#08x}")?,
Expr::EndsWith { .. } => write!(f, "ENDS_WITH -- hash: {expr_hash:#08x}")?,
Expr::IEndsWith { .. } => write!(f, "IENDS_WITH -- hash: {expr_hash:#08x}")?,
Expr::IEquals { .. } => write!(f, "IEQUALS -- hash: {expr_hash:#08x}")?,
Expr::Matches { .. } => write!(f, "MATCHES -- hash: {expr_hash:#08x}")?,
Expr::Defined { .. } => write!(f, "DEFINED -- hash: {expr_hash:#08x}")?,
Expr::FieldAccess { .. } => write!(f, "FIELD_ACCESS -- hash: {expr_hash:#08x}")?,
Expr::With { .. } => write!(f, "WITH -- hash: {expr_hash:#08x}")?,
Expr::Symbol(symbol) => write!(f, "SYMBOL {symbol:?}")?,
// Both flavors of `of` share the same label.
Expr::OfExprTuple(_) => write!(f, "OF -- hash: {expr_hash:#08x}")?,
Expr::OfPatternSet(_) => write!(f, "OF -- hash: {expr_hash:#08x}")?,
Expr::ForOf(_) => write!(f, "FOR_OF -- hash: {expr_hash:#08x}")?,
Expr::ForIn(_) => write!(f, "FOR_IN -- hash: {expr_hash:#08x}")?,
Expr::Lookup(_) => write!(f, "LOOKUP -- hash: {expr_hash:#08x}")?,
Expr::FuncCall(func_call) => write!(f,
"FN_CALL {} -- hash: {:#08x}",
func_call.mangled_name(),
expr_hash
)?,
// The `*Var` variants below use the same label as their
// non-`Var` counterparts; only the printed operand differs
// (a symbol instead of a pattern index).
Expr::PatternMatch { pattern, anchor } => write!(
f,
"PATTERN_MATCH {:?}{} -- hash: {:#08x}",
pattern,
anchor_str(anchor),
expr_hash
)?,
Expr::PatternMatchVar { symbol, anchor } => write!(
f,
"PATTERN_MATCH {:?}{} -- hash: {:#08x}",
symbol,
anchor_str(anchor),
expr_hash
)?,
Expr::PatternCount { pattern, range } => write!(
f,
"PATTERN_COUNT {:?}{} -- hash: {:#08x}",
pattern,
range_str(range),
expr_hash
)?,
Expr::PatternCountVar { symbol, range } => write!(
f,
"PATTERN_COUNT {:?}{} -- hash: {:#08x}",
symbol,
range_str(range),
expr_hash
)?,
Expr::PatternOffset { pattern, index } => write!(
f,
"PATTERN_OFFSET {:?}{} -- hash: {:#08x}",
pattern,
index_str(index),
expr_hash
)?,
Expr::PatternOffsetVar { symbol, index } => write!(
f,
"PATTERN_OFFSET {:?}{} -- hash: {:#08x}",
symbol,
index_str(index),
expr_hash
)?,
Expr::PatternLength { pattern, index } => write!(
f,
"PATTERN_LENGTH {:?}{} -- hash: {:#08x}",
pattern,
index_str(index),
expr_hash
)?,
Expr::PatternLengthVar { symbol, index } => write!(
f,
"PATTERN_LENGTH {:?}{} -- hash: {:#08x}",
symbol,
index_str(index),
expr_hash
)?,
}
// Every line ends with the id of the parent expression.
writeln!(f, " -- parent: {:?} ", self.parents[expr_id.0 as usize])?;
}
}
}
Ok(())
}
}
/// Iterator over the ancestors of an expression, obtained by repeatedly
/// following the parent links stored in the IR.
pub(crate) struct Ancestors<'a> {
// IR whose `parents` table is consulted.
ir: &'a IR,
// Expression whose parent is returned by the next call to `next`.
current: ExprId,
}
impl Iterator for Ancestors<'_> {
    type Item = ExprId;

    /// Moves to the parent of the current expression and returns it, or
    /// returns `None` once an expression without parent has been reached.
    fn next(&mut self) -> Option<Self::Item> {
        let none = ExprId::none();
        if self.current == none {
            return None;
        }
        let parent = self.ir.parents[self.current.0 as usize];
        self.current = parent;
        if parent == none { None } else { Some(parent) }
    }
}
/// Iterator that yields expression ids taken from a depth-first traversal,
/// pruning the traversal at each yielded expression.
// NOTE(review): presumably `dfs` has already stepped into the parent
// expression, so that only direct children are yielded; confirm where
// `Children` is constructed.
pub(crate) struct Children<'a> {
// Underlying depth-first traversal.
dfs: DFSIter<'a>,
}
impl Iterator for Children<'_> {
    type Item = ExprId;

    /// Returns the id of the next expression entered by the traversal,
    /// pruning the traversal at that expression. `Leave` events are skipped.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(event) = self.dfs.next() {
            if let Event::Enter((expr_id, _, _)) = event {
                // Prune at the expression being yielded.
                self.dfs.prune();
                return Some(expr_id);
            }
        }
        None
    }
}
/// A node in the IR expression tree.
///
/// Nodes never own their operands; they refer to them through [`ExprId`]
/// values. Larger payloads are boxed.
pub(crate) enum Expr {
/// A constant value.
Const(TypeValue),
/// The file size (integer-typed).
Filesize,
/// Boolean negation of `operand`.
Not { operand: ExprId },
/// Boolean `and` over any number of operands.
And { operands: Vec<ExprId> },
/// Boolean `or` over any number of operands.
Or { operands: Vec<ExprId> },
/// Arithmetic negation; `is_float` selects a float or integer result.
Minus { is_float: bool, operand: ExprId },
/// Arithmetic operations over any number of operands; `is_float` selects
/// a float or integer result.
Add { is_float: bool, operands: Vec<ExprId> },
Sub { is_float: bool, operands: Vec<ExprId> },
Mul { is_float: bool, operands: Vec<ExprId> },
Div { is_float: bool, operands: Vec<ExprId> },
/// Modulus (integer-typed).
Mod { operands: Vec<ExprId> },
/// Bitwise and shift operations (integer-typed).
BitwiseNot { operand: ExprId },
BitwiseAnd { rhs: ExprId, lhs: ExprId },
Shl { rhs: ExprId, lhs: ExprId },
Shr { rhs: ExprId, lhs: ExprId },
BitwiseOr { rhs: ExprId, lhs: ExprId },
BitwiseXor { rhs: ExprId, lhs: ExprId },
/// Comparison operations (boolean-typed).
Eq { rhs: ExprId, lhs: ExprId },
Ne { rhs: ExprId, lhs: ExprId },
Lt { rhs: ExprId, lhs: ExprId },
Gt { rhs: ExprId, lhs: ExprId },
Le { rhs: ExprId, lhs: ExprId },
Ge { rhs: ExprId, lhs: ExprId },
/// Binary operations with a boolean result.
// NOTE(review): the names suggest string operators where the `I` prefix
// means case-insensitive; confirm against the parser.
Contains { rhs: ExprId, lhs: ExprId },
IContains { rhs: ExprId, lhs: ExprId },
StartsWith { rhs: ExprId, lhs: ExprId },
IStartsWith { rhs: ExprId, lhs: ExprId },
EndsWith { rhs: ExprId, lhs: ExprId },
IEndsWith { rhs: ExprId, lhs: ExprId },
IEquals { rhs: ExprId, lhs: ExprId },
Matches { rhs: ExprId, lhs: ExprId },
/// `defined` applied to `operand` (boolean-typed).
Defined { operand: ExprId },
/// Pattern operations. The plain variants identify the pattern by its
/// index, the `*Var` variants by a symbol.
PatternMatch { pattern: PatternIdx, anchor: MatchAnchor },
PatternMatchVar { symbol: Box<Symbol>, anchor: MatchAnchor },
PatternCount { pattern: PatternIdx, range: Option<Range> },
PatternCountVar { symbol: Box<Symbol>, range: Option<Range> },
PatternOffset { pattern: PatternIdx, index: Option<ExprId> },
PatternOffsetVar { symbol: Box<Symbol>, index: Option<ExprId> },
PatternLength { pattern: PatternIdx, index: Option<ExprId> },
PatternLengthVar { symbol: Box<Symbol>, index: Option<ExprId> },
/// A symbol; its type is the type of the symbol.
Symbol(Box<Symbol>),
/// A `with` expression: declarations followed by a body.
With(Box<With>),
/// A field access.
FieldAccess(Box<FieldAccess>),
/// A function call.
FuncCall(Box<FuncCall>),
/// An `of` expression whose items are expressions.
OfExprTuple(Box<OfExprTuple>),
/// An `of` expression whose items are patterns.
OfPatternSet(Box<OfPatternSet>),
/// A `for .. of` expression over a set of patterns.
ForOf(Box<ForOf>),
/// A `for .. in` expression over an [`Iterable`].
ForIn(Box<ForIn>),
/// A lookup of `index` in `primary`.
Lookup(Box<Lookup>),
}
/// Payload of [`Expr::Lookup`].
pub(crate) struct Lookup {
/// Type and value reported for the lookup expression.
pub type_value: TypeValue,
/// Expression being indexed.
// NOTE(review): presumably an array or map; confirm against ast2ir.
pub primary: ExprId,
/// Expression used as the index.
pub index: ExprId,
}
/// Payload of [`Expr::FieldAccess`].
pub(crate) struct FieldAccess {
/// Type and value reported for the field access expression.
pub type_value: TypeValue,
/// Expressions that take part in the field access, all of them children
/// of the field access node.
pub operands: Vec<ExprId>,
}
/// Payload of [`Expr::FuncCall`].
pub(crate) struct FuncCall {
/// Optional expression the function is called on.
// NOTE(review): presumably the receiver of a method-like call; confirm.
pub object: Option<ExprId>,
/// Signature of the called function. Its `result` field provides the type
/// of the call expression.
pub signature: Rc<FuncSignature>,
/// Expressions passed as arguments.
pub args: Vec<ExprId>,
}
impl FuncCall {
    /// Returns the signature of the function being called.
    pub fn signature(&self) -> &FuncSignature {
        &self.signature
    }

    /// Returns the mangled name stored in the function's signature.
    pub fn mangled_name(&self) -> &str {
        &self.signature.mangled_name
    }
}
/// Payload of [`Expr::OfExprTuple`]: an `of` expression whose items are
/// expressions.
pub(crate) struct OfExprTuple {
/// Quantifier of the `of` expression.
pub quantifier: Quantifier,
/// Expressions the quantifier applies to.
pub items: Vec<ExprId>,
/// Internal loop variables.
pub for_vars: ForVars,
/// Variable associated with the expression being processed.
// NOTE(review): exact role inferred from the name; confirm in the emitter.
pub next_expr_var: Var,
/// Optional anchor (`at` / `in`) of the `of` expression.
pub anchor: MatchAnchor,
}
/// Payload of [`Expr::OfPatternSet`]: an `of` expression whose items are
/// patterns.
pub(crate) struct OfPatternSet {
/// Quantifier of the `of` expression.
pub quantifier: Quantifier,
/// Indexes of the patterns the quantifier applies to.
pub items: Vec<PatternIdx>,
/// Internal loop variables.
pub for_vars: ForVars,
/// Variable associated with the pattern being processed.
// NOTE(review): exact role inferred from the name; confirm in the emitter.
pub next_pattern_var: Var,
/// Optional anchor (`at` / `in`) of the `of` expression.
pub anchor: MatchAnchor,
}
/// Payload of [`Expr::ForOf`]: a `for .. of` expression over a set of
/// patterns.
pub(crate) struct ForOf {
/// Quantifier of the loop.
pub quantifier: Quantifier,
/// Loop variable.
// NOTE(review): presumably holds the pattern for the current iteration.
pub variable: Var,
/// Internal loop variables.
pub for_vars: ForVars,
/// Indexes of the patterns the loop iterates over.
pub pattern_set: Vec<PatternIdx>,
/// Expression used as the loop body.
pub body: ExprId,
}
/// Payload of [`Expr::ForIn`]: a `for .. in` expression over an
/// [`Iterable`].
pub(crate) struct ForIn {
/// Quantifier of the loop.
pub quantifier: Quantifier,
/// Variables declared by the loop.
pub variables: Vec<Var>,
/// Internal loop variables.
pub for_vars: ForVars,
/// Variable associated with the iterable.
// NOTE(review): exact role inferred from the name; confirm in the emitter.
pub iterable_var: Var,
/// What the loop iterates over.
pub iterable: Iterable,
/// Expression used as the loop body.
pub body: ExprId,
}
/// Quantifier used by `of` and `for` expressions.
///
/// The expression carried by `Percentage` and `Expr` is a child of the node
/// that owns the quantifier.
pub(crate) enum Quantifier {
None,
All,
Any,
/// Quantifier given by an expression.
// NOTE(review): presumably interpreted as a percentage of the items.
Percentage(ExprId),
/// Quantifier given by an expression.
Expr(ExprId),
}
/// Internal variables shared by the loop-like expressions (`of`, `for .. of`
/// and `for .. in`).
// NOTE(review): the role of each variable is only suggested by its name;
// confirm against the code that emits loops.
#[derive(PartialEq, Eq)]
pub(crate) struct ForVars {
pub n: Var,
pub i: Var,
pub max_count: Var,
pub count: Var,
}
impl ForVars {
    /// Applies [`Var::shift`] with the given arguments to each of the four
    /// variables, in the order `n`, `i`, `max_count`, `count`.
    pub fn shift(&mut self, after: i32, amount: i32) {
        let vars =
            [&mut self.n, &mut self.i, &mut self.max_count, &mut self.count];
        for var in vars {
            var.shift(after, amount);
        }
    }
}
/// Payload of [`Expr::With`].
pub(crate) struct With {
/// Type and value of the `with` expression, taken from its body when the
/// node is created.
pub type_value: TypeValue,
/// Declared variables, each paired with the expression associated with it.
pub declarations: Vec<(Var, ExprId)>,
/// Expression used as the body.
pub body: ExprId,
}
/// Anchor that can accompany a pattern match or an `of` expression.
pub(crate) enum MatchAnchor {
/// No anchor.
None,
/// Anchored at the position given by an expression (`at`).
At(ExprId),
/// Anchored within a range (`in`).
In(Range),
}
/// A range delimited by two expressions.
pub(crate) struct Range {
/// Expression that produces the lower bound.
pub lower_bound: ExprId,
/// Expression that produces the upper bound.
pub upper_bound: ExprId,
}
/// What a `for .. in` expression iterates over.
pub(crate) enum Iterable {
/// A range delimited by two expressions.
Range(Range),
/// An explicit list of expressions.
ExprTuple(Vec<ExprId>),
/// A single expression.
// NOTE(review): presumably one that evaluates to an array or map.
Expr(ExprId),
}
impl Iterable {
    /// Returns the number of iterations, when it can be computed from the
    /// IR alone.
    ///
    /// * For a range whose two bounds are constant integers the result is
    ///   `upper - lower + 1`, or `None` if that value does not fit in an
    ///   `i64`.
    /// * For a tuple of expressions the result is the number of items.
    /// * In every other case the result is `None`.
    pub(crate) fn num_iterations(&self, ir: &IR) -> Option<i64> {
        match self {
            Iterable::Range(range) => {
                let lower = ir.get(range.lower_bound).try_as_const_integer()?;
                let upper = ir.get(range.upper_bound).try_as_const_integer()?;
                // Computed in `i128` because `upper + 1` overflows `i64`
                // when the upper bound is `i64::MAX`. The intermediate
                // value always fits in an `i128`; the result is narrowed
                // back to `i64` only when representable.
                let count =
                    i128::from(upper).add(1_i128) - i128::from(lower);
                i64::try_from(count).ok()
            }
            Iterable::ExprTuple(exprs) => Some(exprs.len() as i64),
            Iterable::Expr(_) => None,
        }
    }
}
/// Allows indexing a slice of expressions directly with an [`ExprId`].
impl Index<ExprId> for [Expr] {
    type Output = Expr;

    /// Returns the expression stored at the position given by `index`.
    ///
    /// # Panics
    ///
    /// Panics if the position is out of bounds.
    fn index(&self, index: ExprId) -> &Self::Output {
        let position = index.0 as usize;
        self.get(position).unwrap()
    }
}
/// Shallow hash of an expression.
///
/// The hash covers the kind of expression plus the data that is stored in
/// the node itself (constants, symbols, pattern indexes, function
/// signatures). Child expressions are not followed: where a child may or
/// may not be present, or comes in several flavors, only the discriminant
/// of the enclosing `Option`/enum is hashed.
impl Hash for Expr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // The kind of expression always goes first.
        discriminant(self).hash(state);
        match self {
            Expr::Const(value) => value.hash(state),
            Expr::Symbol(sym) => sym.hash(state),
            Expr::FuncCall(call) => call.signature.hash(state),
            Expr::PatternMatch { pattern, anchor } => {
                pattern.hash(state);
                discriminant(anchor).hash(state);
            }
            Expr::PatternMatchVar { symbol, anchor } => {
                symbol.hash(state);
                discriminant(anchor).hash(state);
            }
            Expr::PatternCount { pattern, range } => {
                pattern.hash(state);
                discriminant(range).hash(state);
            }
            Expr::PatternCountVar { symbol, range } => {
                symbol.hash(state);
                discriminant(range).hash(state);
            }
            Expr::PatternOffset { pattern, index }
            | Expr::PatternLength { pattern, index } => {
                pattern.hash(state);
                discriminant(index).hash(state);
            }
            Expr::PatternOffsetVar { symbol, index }
            | Expr::PatternLengthVar { symbol, index } => {
                symbol.hash(state);
                discriminant(index).hash(state);
            }
            Expr::OfExprTuple(of) => {
                discriminant(&of.quantifier).hash(state);
                discriminant(&of.anchor).hash(state);
            }
            Expr::OfPatternSet(of) => {
                discriminant(&of.quantifier).hash(state);
                discriminant(&of.anchor).hash(state);
                // Items are hashed one by one, without a length prefix.
                of.items.iter().for_each(|item| item.hash(state));
            }
            Expr::ForOf(for_of) => {
                discriminant(&for_of.quantifier).hash(state);
                // Items are hashed one by one, without a length prefix.
                for_of.pattern_set.iter().for_each(|item| item.hash(state));
            }
            Expr::ForIn(for_in) => {
                discriminant(&for_in.quantifier).hash(state);
                discriminant(&for_in.iterable).hash(state);
            }
            // For every other expression the kind is all that is hashed.
            _ => {}
        }
    }
}
impl Expr {
    /// Returns the size of the variable stack frame associated with this
    /// expression.
    ///
    /// A `with` expression needs one slot per declaration, while `for` and
    /// `of` expressions use the fixed sizes defined by [`VarStack`]. Every
    /// other expression returns 0.
    pub fn stack_frame_size(&self) -> i32 {
        match self {
            Expr::With(with) => with.declarations.len() as i32,
            Expr::ForOf(_) => VarStack::FOR_OF_FRAME_SIZE,
            Expr::ForIn(_) => VarStack::FOR_IN_FRAME_SIZE,
            Expr::OfExprTuple(_) => VarStack::OF_FRAME_SIZE,
            Expr::OfPatternSet(_) => VarStack::OF_FRAME_SIZE,
            _ => 0,
        }
    }

    /// Applies [`Var::shift`] with `from_index` and `shift_amount` to every
    /// variable stored in this expression.
    ///
    /// This covers the variable behind a symbol (when the symbol is a
    /// `Symbol::Var`), the variables declared by `with`, and all the
    /// variables held by `of`, `for .. of` and `for .. in` expressions.
    /// Child expressions are not visited.
    pub fn shift_vars(&mut self, from_index: i32, shift_amount: i32) {
        match self {
            Expr::Symbol(symbol)
            | Expr::PatternMatchVar { symbol, .. }
            | Expr::PatternCountVar { symbol, .. }
            | Expr::PatternOffsetVar { symbol, .. }
            | Expr::PatternLengthVar { symbol, .. } => {
                if let Symbol::Var { var, .. } = symbol.as_mut() {
                    var.shift(from_index, shift_amount)
                }
            }
            Expr::With(with) => {
                for (v, _) in with.declarations.iter_mut() {
                    v.shift(from_index, shift_amount)
                }
            }
            Expr::OfExprTuple(of) => {
                of.next_expr_var.shift(from_index, shift_amount);
                of.for_vars.shift(from_index, shift_amount);
            }
            Expr::OfPatternSet(of) => {
                of.next_pattern_var.shift(from_index, shift_amount);
                of.for_vars.shift(from_index, shift_amount);
            }
            Expr::ForOf(for_of) => {
                // The loop variable must be shifted together with the
                // internal loop variables, as done for every other
                // loop-like expression.
                for_of.variable.shift(from_index, shift_amount);
                for_of.for_vars.shift(from_index, shift_amount);
            }
            Expr::ForIn(for_in) => {
                for_in.iterable_var.shift(from_index, shift_amount);
                for v in for_in.variables.iter_mut() {
                    v.shift(from_index, shift_amount)
                }
                for_in.for_vars.shift(from_index, shift_amount);
            }
            _ => {}
        }
    }

    /// Replaces every reference to `child` held by this expression with
    /// `replacement`.
    ///
    /// All the places where an expression can refer to a child are covered:
    /// operands, anchors, ranges, indexes, quantifiers, iterables, loop
    /// bodies, declarations, call objects and arguments. Nothing happens if
    /// `child` is not referenced by this expression. Only this node is
    /// modified; any parent bookkeeping is left to the caller.
    pub fn replace_child(&mut self, child: ExprId, replacement: ExprId) {
        let replace_in_expr = |expr: &mut ExprId| {
            if *expr == child {
                *expr = replacement;
            }
        };

        let replace_in_slice = |exprs: &mut [ExprId]| {
            for expr in exprs {
                replace_in_expr(expr);
            }
        };

        let replace_in_quantifier =
            |quantifier: &mut Quantifier| match quantifier {
                Quantifier::None | Quantifier::All | Quantifier::Any => {}
                Quantifier::Percentage(expr) | Quantifier::Expr(expr) => {
                    replace_in_expr(expr)
                }
            };

        let replace_in_range = |range: &mut Range| {
            replace_in_expr(&mut range.lower_bound);
            replace_in_expr(&mut range.upper_bound);
        };

        let replace_in_anchor = |anchor: &mut MatchAnchor| match anchor {
            MatchAnchor::None => {}
            MatchAnchor::At(expr) => replace_in_expr(expr),
            MatchAnchor::In(range) => replace_in_range(range),
        };

        match self {
            // Expressions without children.
            Expr::Const(_) => {}
            Expr::Filesize => {}
            Expr::Symbol(_) => {}

            Expr::Not { operand }
            | Expr::Minus { operand, .. }
            | Expr::Defined { operand }
            | Expr::BitwiseNot { operand } => replace_in_expr(operand),

            Expr::And { operands }
            | Expr::Or { operands }
            | Expr::Add { operands, .. }
            | Expr::Sub { operands, .. }
            | Expr::Mul { operands, .. }
            | Expr::Div { operands, .. }
            | Expr::Mod { operands, .. } => {
                replace_in_slice(operands.as_mut_slice());
            }

            Expr::BitwiseAnd { lhs, rhs }
            | Expr::Shl { lhs, rhs }
            | Expr::Shr { lhs, rhs }
            | Expr::BitwiseOr { lhs, rhs }
            | Expr::BitwiseXor { lhs, rhs }
            | Expr::Eq { lhs, rhs }
            | Expr::Ne { lhs, rhs }
            | Expr::Lt { lhs, rhs }
            | Expr::Gt { lhs, rhs }
            | Expr::Le { lhs, rhs }
            | Expr::Ge { lhs, rhs }
            | Expr::Contains { lhs, rhs }
            | Expr::IContains { lhs, rhs }
            | Expr::StartsWith { lhs, rhs }
            | Expr::IStartsWith { lhs, rhs }
            | Expr::EndsWith { lhs, rhs }
            | Expr::IEndsWith { lhs, rhs }
            | Expr::IEquals { lhs, rhs }
            | Expr::Matches { lhs, rhs } => {
                replace_in_expr(lhs);
                replace_in_expr(rhs);
            }

            Expr::PatternMatch { anchor, .. }
            | Expr::PatternMatchVar { anchor, .. } => {
                replace_in_anchor(anchor)
            }

            Expr::PatternCount { range, .. }
            | Expr::PatternCountVar { range, .. } => {
                if let Some(range) = range {
                    replace_in_range(range)
                }
            }

            Expr::PatternOffset { index, .. }
            | Expr::PatternOffsetVar { index, .. }
            | Expr::PatternLength { index, .. }
            | Expr::PatternLengthVar { index, .. } => {
                if let Some(index) = index {
                    replace_in_expr(index)
                }
            }

            Expr::With(with) => {
                for (_, expr) in with.declarations.iter_mut() {
                    replace_in_expr(expr);
                }
                replace_in_expr(&mut with.body);
            }

            Expr::FieldAccess(field_access) => {
                replace_in_slice(field_access.operands.as_mut_slice());
            }

            Expr::FuncCall(func_call) => {
                if let Some(object) = &mut func_call.object {
                    replace_in_expr(object);
                }
                replace_in_slice(func_call.args.as_mut_slice());
            }

            Expr::OfExprTuple(of) => {
                // The quantifier's expression is a child too: the builder
                // registers it as such.
                replace_in_quantifier(&mut of.quantifier);
                replace_in_slice(of.items.as_mut_slice());
                replace_in_anchor(&mut of.anchor);
            }

            Expr::OfPatternSet(of) => {
                // Same as above; the items are pattern indexes and need no
                // replacement.
                replace_in_quantifier(&mut of.quantifier);
                replace_in_anchor(&mut of.anchor);
            }

            Expr::ForOf(for_of) => {
                replace_in_quantifier(&mut for_of.quantifier);
                replace_in_expr(&mut for_of.body);
            }

            Expr::ForIn(for_in) => {
                replace_in_quantifier(&mut for_in.quantifier);
                // The expressions that make up the iterable are children of
                // the loop as well.
                match &mut for_in.iterable {
                    Iterable::Range(range) => replace_in_range(range),
                    Iterable::ExprTuple(exprs) => {
                        replace_in_slice(exprs.as_mut_slice())
                    }
                    Iterable::Expr(expr) => replace_in_expr(expr),
                }
                replace_in_expr(&mut for_in.body);
            }

            Expr::Lookup(lookup) => {
                replace_in_expr(&mut lookup.primary);
                replace_in_expr(&mut lookup.index);
            }
        }
    }

    /// Returns the type of this expression.
    pub fn ty(&self) -> Type {
        match self {
            Expr::Const(type_value) => type_value.ty(),

            Expr::Defined { .. }
            | Expr::Not { .. }
            | Expr::And { .. }
            | Expr::Or { .. }
            | Expr::Eq { .. }
            | Expr::Ne { .. }
            | Expr::Ge { .. }
            | Expr::Gt { .. }
            | Expr::Le { .. }
            | Expr::Lt { .. }
            | Expr::Contains { .. }
            | Expr::IContains { .. }
            | Expr::StartsWith { .. }
            | Expr::IStartsWith { .. }
            | Expr::EndsWith { .. }
            | Expr::IEndsWith { .. }
            | Expr::IEquals { .. }
            | Expr::Matches { .. }
            | Expr::PatternMatch { .. }
            | Expr::PatternMatchVar { .. }
            | Expr::OfExprTuple(_)
            | Expr::OfPatternSet(_)
            | Expr::ForOf(_)
            | Expr::ForIn(_) => Type::Bool,

            // Arithmetic operations are either float or integer, according
            // to the flag stored in the node.
            Expr::Minus { is_float, .. }
            | Expr::Add { is_float, .. }
            | Expr::Sub { is_float, .. }
            | Expr::Mul { is_float, .. }
            | Expr::Div { is_float, .. } => {
                if *is_float {
                    Type::Float
                } else {
                    Type::Integer
                }
            }

            Expr::Filesize
            | Expr::PatternCount { .. }
            | Expr::PatternCountVar { .. }
            | Expr::PatternOffset { .. }
            | Expr::PatternOffsetVar { .. }
            | Expr::PatternLength { .. }
            | Expr::PatternLengthVar { .. }
            | Expr::Mod { .. }
            | Expr::BitwiseNot { .. }
            | Expr::BitwiseAnd { .. }
            | Expr::BitwiseOr { .. }
            | Expr::BitwiseXor { .. }
            | Expr::Shl { .. }
            | Expr::Shr { .. } => Type::Integer,

            Expr::Symbol(symbol) => symbol.ty(),
            Expr::FieldAccess(field_access) => field_access.type_value.ty(),
            Expr::FuncCall(func_call) => func_call.signature.result.ty(),
            Expr::Lookup(lookup) => lookup.type_value.ty(),
            Expr::With(with) => with.type_value.ty(),
        }
    }

    /// Returns the type and value of this expression.
    ///
    /// Constants, symbols, field accesses, function calls, lookups and
    /// `with` expressions return a clone of the [`TypeValue`] they carry.
    /// Every other expression returns a value of the appropriate type whose
    /// actual value is unknown.
    pub fn type_value(&self) -> TypeValue {
        match self {
            Expr::Const(type_value) => type_value.clone(),

            Expr::Defined { .. }
            | Expr::Not { .. }
            | Expr::And { .. }
            | Expr::Or { .. }
            | Expr::Eq { .. }
            | Expr::Ne { .. }
            | Expr::Ge { .. }
            | Expr::Gt { .. }
            | Expr::Le { .. }
            | Expr::Lt { .. }
            | Expr::Contains { .. }
            | Expr::IContains { .. }
            | Expr::StartsWith { .. }
            | Expr::IStartsWith { .. }
            | Expr::EndsWith { .. }
            | Expr::IEndsWith { .. }
            | Expr::IEquals { .. }
            | Expr::Matches { .. }
            | Expr::PatternMatch { .. }
            | Expr::PatternMatchVar { .. }
            | Expr::OfExprTuple(_)
            | Expr::OfPatternSet(_)
            | Expr::ForOf(_)
            | Expr::ForIn(_) => TypeValue::unknown_bool(),

            // Arithmetic operations are either float or integer, according
            // to the flag stored in the node.
            Expr::Minus { is_float, .. }
            | Expr::Add { is_float, .. }
            | Expr::Sub { is_float, .. }
            | Expr::Mul { is_float, .. }
            | Expr::Div { is_float, .. } => {
                if *is_float {
                    TypeValue::unknown_float()
                } else {
                    TypeValue::unknown_integer()
                }
            }

            Expr::Filesize
            | Expr::PatternCount { .. }
            | Expr::PatternCountVar { .. }
            | Expr::PatternOffset { .. }
            | Expr::PatternOffsetVar { .. }
            | Expr::PatternLength { .. }
            | Expr::PatternLengthVar { .. }
            | Expr::Mod { .. }
            | Expr::BitwiseNot { .. }
            | Expr::BitwiseAnd { .. }
            | Expr::BitwiseOr { .. }
            | Expr::BitwiseXor { .. }
            | Expr::Shl { .. }
            | Expr::Shr { .. } => TypeValue::unknown_integer(),

            Expr::Symbol(symbol) => symbol.type_value().clone(),
            Expr::FieldAccess(field_access) => field_access.type_value.clone(),
            Expr::FuncCall(func_call) => func_call.signature.result.clone(),
            Expr::Lookup(lookup) => lookup.type_value.clone(),
            Expr::With(with) => with.type_value.clone(),
        }
    }

    /// Returns the boolean value of this expression if it is a constant
    /// boolean, or `None` otherwise.
    pub fn try_as_const_bool(&self) -> Option<bool> {
        if let TypeValue::Bool { value: Const(v) } = self.type_value() {
            Some(v)
        } else {
            None
        }
    }

    /// Returns the integer value of this expression if it is a constant
    /// integer, or `None` otherwise.
    pub fn try_as_const_integer(&self) -> Option<i64> {
        if let TypeValue::Integer { value: Const(v), .. } = self.type_value() {
            Some(v)
        } else {
            None
        }
    }
}