use std::borrow::Borrow;
use std::cell::RefCell;
use std::collections::BTreeSet;
use std::iter;
use std::ops::RangeInclusive;
use std::rc::Rc;
use bstr::{BString, ByteSlice};
use itertools::Itertools;
use yara_x_parser::Span;
use yara_x_parser::ast;
use yara_x_parser::ast::WithSpan;
use crate::compiler::context::VarStack;
use crate::compiler::errors::{
ArbitraryRegexpPrefix, AssignmentMismatch, DuplicateModifier,
DuplicatePattern, EmptyPatternSet, EntrypointUnsupported,
InvalidBase64Alphabet, InvalidModifier, InvalidModifierCombination,
InvalidPattern, InvalidRange, InvalidRegexp, MismatchingTypes,
MixedGreediness, NumberOutOfRange, SyntaxError, TooManyPatterns,
UnexpectedNegativeNumber, WrongArguments, WrongType,
};
use crate::compiler::ir::hex2hir::hex_pattern_hir_from_ast;
use crate::compiler::ir::{
Error, Expr, ExprId, Iterable, LiteralPattern, MatchAnchor, Pattern,
PatternFlags, PatternInRule, Quantifier, Range, RegexpPattern,
};
use crate::compiler::report::{Level, ReportBuilder};
use crate::compiler::{
CompileContext, CompileError, FilesizeBounds, ForVars, PatternIdx,
TextPatternAsHex, warnings,
};
use crate::errors::CustomError;
use crate::errors::{MethodNotAllowedInWith, PotentiallySlowLoop};
use crate::re;
use crate::symbols::{Symbol, SymbolLookup, SymbolTable};
use crate::types::Value::Const;
use crate::types::{
IntegerConstraint, Map, Regexp, StringConstraint, Type, TypeValue,
};
use crate::warnings::UnsatisfiableExpression;
/// Maximum number of patterns a single rule may declare. `patterns_from_ast`
/// returns a `TooManyPatterns` error when a rule tries to go beyond it.
const MAX_PATTERNS_PER_RULE: usize = 100_000;
/// Threshold for the estimated total number of iterations of nested
/// `for .. in` loops. `for_in_expr_from_ast` emits a `TooManyIterations`
/// warning when the product of the iteration counts exceeds this value.
const MAX_LOOP_ITERATIONS: i64 = 1_000_000;
/// Translates every pattern declared in `rule` and appends it to
/// `ctx.current_rule_patterns`.
///
/// # Errors
///
/// * any error produced while translating an individual pattern;
/// * `DuplicatePattern` when a named pattern reuses the identifier of a
///   pattern collected earlier (the anonymous identifier `$` may repeat);
/// * `TooManyPatterns` when the rule already holds
///   `MAX_PATTERNS_PER_RULE` patterns and another one is added.
pub(in crate::compiler) fn patterns_from_ast<'src>(
    ctx: &mut CompileContext<'_, 'src>,
    rule: &ast::Rule<'src>,
) -> Result<(), CompileError> {
    // Rules without a `strings`/patterns section have nothing to do here.
    let Some(declared) = rule.patterns.as_ref() else {
        return Ok(());
    };

    for pattern_ast in declared {
        let pattern = pattern_from_ast(ctx, pattern_ast)?;

        // Anonymous patterns (`$`) are allowed to appear several times.
        let clash = if pattern.identifier().name == "$" {
            None
        } else {
            ctx.current_rule_patterns
                .iter()
                .find(|p| p.identifier.name == pattern.identifier.name)
        };

        if let Some(existing) = clash {
            return Err(DuplicatePattern::build(
                ctx.report_builder,
                pattern.identifier().name.to_string(),
                ctx.report_builder
                    .span_to_code_loc(pattern.identifier().span()),
                ctx.report_builder
                    .span_to_code_loc(existing.identifier.span()),
            ));
        }

        if ctx.current_rule_patterns.len() == MAX_PATTERNS_PER_RULE {
            return Err(TooManyPatterns::build(
                ctx.report_builder,
                MAX_PATTERNS_PER_RULE,
                ctx.report_builder.span_to_code_loc(rule.identifier.span()),
            ));
        }

        ctx.current_rule_patterns.push(pattern);
    }

    Ok(())
}
/// Translates a single pattern of any kind (text, hex or regexp).
///
/// Before dispatching to the kind-specific translator, it checks that no
/// modifier is repeated. Modifiers are compared by their textual form, and
/// the second occurrence of a repeated modifier is reported as
/// `DuplicateModifier`.
fn pattern_from_ast<'src>(
    ctx: &mut CompileContext,
    pattern: &ast::Pattern<'src>,
) -> Result<PatternInRule<'src>, CompileError> {
    let mut seen = BTreeSet::new();

    // `insert` returns false for an already-seen modifier; the first such
    // modifier is the one reported.
    if let Some(repeated) = pattern
        .modifiers()
        .iter()
        .find(|modifier| !seen.insert(modifier.as_text()))
    {
        return Err(DuplicateModifier::build(
            ctx.report_builder,
            ctx.report_builder.span_to_code_loc(repeated.span()),
        ));
    }

    match pattern {
        ast::Pattern::Text(text) => text_pattern_from_ast(ctx, text),
        ast::Pattern::Hex(hex) => hex_pattern_from_ast(ctx, hex),
        ast::Pattern::Regexp(regexp) => regexp_pattern_from_ast(ctx, regexp),
    }
}
/// Converts a text pattern (a literal string plus modifiers) into a
/// `PatternInRule` holding a `Pattern::Text`.
///
/// Validation runs in this order, and the first failure is returned:
/// 1. mutually exclusive modifiers: `xor`/`base64`/`base64wide` with
///    `nocase`, and `base64`/`base64wide` with `fullword` or `xor`;
/// 2. the `xor` range, whose lower bound must not exceed its upper bound;
/// 3. custom `base64` / `base64wide` alphabets, which must be accepted by
///    `base64::alphabet::Alphabet::new`;
/// 4. the text length, which must be at least 3 bytes with
///    `base64`/`base64wide` and at least 1 byte otherwise.
///
/// The `Ascii` flag is set when `ascii` is given or when `wide` is absent.
pub(in crate::compiler) fn text_pattern_from_ast<'src>(
    ctx: &mut CompileContext,
    pattern: &ast::TextPattern<'src>,
) -> Result<PatternInRule<'src>, CompileError> {
    let mods = &pattern.modifiers;
    let ascii = mods.ascii();
    let xor = mods.xor();
    let nocase = mods.nocase();
    let fullword = mods.fullword();
    let base64 = mods.base64();
    let base64wide = mods.base64wide();
    let wide = mods.wide();
    let private = mods.private();

    // The first incompatible pair where both modifiers are present wins.
    let conflict = [
        ("xor", xor, "nocase", nocase),
        ("base64", base64, "nocase", nocase),
        ("base64wide", base64wide, "nocase", nocase),
        ("base64", base64, "fullword", fullword),
        ("base64wide", base64wide, "fullword", fullword),
        ("base64", base64, "xor", xor),
        ("base64wide", base64wide, "xor", xor),
    ]
    .into_iter()
    .find_map(|(first_name, first, second_name, second)| {
        Some((first_name, first?, second_name, second?))
    });

    if let Some((first_name, first, second_name, second)) = conflict {
        return Err(InvalidModifierCombination::build(
            ctx.report_builder,
            first_name.to_string(),
            second_name.to_string(),
            ctx.report_builder.span_to_code_loc(first.span()),
            ctx.report_builder.span_to_code_loc(second.span()),
            Some("these two modifiers can't be used together".to_string()),
        ));
    }

    let mut flags = PatternFlags::empty();
    for (enabled, flag) in [
        (ascii.is_some() || wide.is_none(), PatternFlags::Ascii),
        (wide.is_some(), PatternFlags::Wide),
        (nocase.is_some(), PatternFlags::Nocase),
        (fullword.is_some(), PatternFlags::Fullword),
        (private.is_some(), PatternFlags::Private),
    ] {
        if enabled {
            flags.insert(flag);
        }
    }

    let xor_range = match xor {
        Some(modifier @ ast::PatternModifier::Xor { start, end, .. })
            if *end < *start =>
        {
            return Err(InvalidRange::build(
                ctx.report_builder,
                format!(
                    "lower bound ({start}) is greater than upper bound ({end})"
                ),
                ctx.report_builder.span_to_code_loc(modifier.span()),
            ));
        }
        Some(ast::PatternModifier::Xor { start, end, .. }) => {
            flags.insert(PatternFlags::Xor);
            Some(*start..=*end)
        }
        _ => None,
    };

    // Returns the custom alphabet (if any) once the `base64` crate has
    // accepted it.
    let check_alphabet = |alphabet: &Option<ast::LiteralString>| {
        let Some(alphabet) = alphabet else {
            return Ok(None);
        };
        let alphabet_str = alphabet.as_str().unwrap();
        base64::alphabet::Alphabet::new(alphabet_str)
            .map(|_| Some(String::from(alphabet_str)))
            .map_err(|err| {
                InvalidBase64Alphabet::build(
                    ctx.report_builder,
                    err.to_string().to_lowercase(),
                    ctx.report_builder.span_to_code_loc(alphabet.span()),
                )
            })
    };

    let base64_alphabet =
        if let Some(ast::PatternModifier::Base64 { alphabet, .. }) = base64 {
            flags.insert(PatternFlags::Base64);
            check_alphabet(alphabet)?
        } else {
            None
        };

    let base64wide_alphabet = if let Some(
        ast::PatternModifier::Base64Wide { alphabet, .. },
    ) = base64wide
    {
        flags.insert(PatternFlags::Base64Wide);
        check_alphabet(alphabet)?
    } else {
        None
    };

    // Base64 variants need at least 3 input bytes to produce a meaningful
    // encoding; every other pattern just needs to be non-empty.
    let length_note = if base64.is_some() {
        Some("`base64` requires that pattern is at least 3 bytes long")
    } else if base64wide.is_some() {
        Some("`base64wide` requires that pattern is at least 3 bytes long")
    } else {
        None
    };
    let min_len = if length_note.is_some() { 3 } else { 1 };

    let text: BString = pattern.text.value.as_ref().into();

    if text.len() < min_len {
        return Err(InvalidPattern::build(
            ctx.report_builder,
            pattern.identifier.name.to_string(),
            "this pattern is too short".to_string(),
            ctx.report_builder.span_to_code_loc(pattern.text.span()),
            length_note.map(String::from),
        ));
    }

    Ok(PatternInRule {
        identifier: pattern.identifier.clone(),
        in_use: false,
        span: pattern.span(),
        pattern: Pattern::Text(LiteralPattern {
            flags,
            text,
            xor_range,
            base64_alphabet,
            base64wide_alphabet,
            anchored_at: None,
            filesize_bounds: FilesizeBounds::default(),
        }),
    })
}
/// Converts a hex pattern into a `PatternInRule` holding a `Pattern::Hex`.
///
/// `private` is the only modifier accepted; any other one is reported as
/// `InvalidModifier`. When the hex pattern turns out to be a plain literal
/// whose bytes are printable ASCII (or tab/newline/carriage return), a
/// `TextPatternAsHex` warning is emitted with a patch that rewrites it as an
/// escaped text string.
pub(in crate::compiler) fn hex_pattern_from_ast<'src>(
    ctx: &mut CompileContext,
    pattern: &ast::HexPattern<'src>,
) -> Result<PatternInRule<'src>, CompileError> {
    if let Some(modifier) = pattern
        .modifiers
        .iter()
        .find(|m| !matches!(m, ast::PatternModifier::Private { .. }))
    {
        return Err(InvalidModifier::build(
            ctx.report_builder,
            "this modifier can't be applied to a hex pattern".to_string(),
            ctx.report_builder.span_to_code_loc(modifier.span()),
        ));
    }

    let hir = re::hir::Hir::from(hex_pattern_hir_from_ast(ctx, pattern)?);

    let is_text_char = |c: char| matches!(c, ' '..='~' | '\t' | '\n' | '\r');

    let text_literal = hir
        .as_literal_bytes()
        .and_then(|lit| lit.to_str().ok())
        .filter(|lit| lit.chars().all(is_text_char));

    if let Some(literal) = text_literal {
        let code_loc = ctx.report_builder.span_to_code_loc(pattern.span());
        let mut warning =
            TextPatternAsHex::build(ctx.report_builder, code_loc.clone());
        warning.report_mut().patch(code_loc, escape(literal));
        ctx.warnings.add(|| warning);
    }

    Ok(PatternInRule {
        identifier: pattern.identifier.clone(),
        in_use: false,
        span: pattern.span(),
        pattern: Pattern::Hex(RegexpPattern {
            hir,
            flags: PatternFlags::Ascii,
            anchored_at: None,
            filesize_bounds: FilesizeBounds::default(),
        }),
    })
}
/// Wraps `s` in double quotes, escaping carriage returns, newlines, tabs,
/// backslashes and double quotes with backslash sequences. Every other
/// character is copied unchanged.
fn escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len());
    quoted.push('"');
    for ch in s.chars() {
        let sequence = match ch {
            '\r' => Some("\\r"),
            '\n' => Some("\\n"),
            '\t' => Some("\\t"),
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            _ => None,
        };
        if let Some(seq) = sequence {
            quoted.push_str(seq);
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
/// Converts a regexp pattern into a `PatternInRule` holding a
/// `Pattern::Regexp`.
///
/// `base64`, `base64wide` and `xor` are rejected with `InvalidModifier`.
/// Case insensitivity comes from the `nocase` modifier or from the regexp's
/// own case-insensitive flag. When both are present, a
/// `RedundantCaseModifier` warning points at the modifier and at the last
/// `i` in the regexp literal. The regexp is parsed with mixed greediness
/// disallowed, and parse errors are mapped through
/// `re_error_to_compile_error`.
pub(in crate::compiler) fn regexp_pattern_from_ast<'src>(
    ctx: &mut CompileContext,
    pattern: &ast::RegexpPattern<'src>,
) -> Result<PatternInRule<'src>, CompileError> {
    if let Some(modifier) = pattern.modifiers.iter().find(|m| {
        matches!(
            m,
            ast::PatternModifier::Base64 { .. }
                | ast::PatternModifier::Base64Wide { .. }
                | ast::PatternModifier::Xor { .. }
        )
    }) {
        return Err(InvalidModifier::build(
            ctx.report_builder,
            "this modifier can't be applied to a regexp".to_string(),
            ctx.report_builder.span_to_code_loc(modifier.span()),
        ));
    }

    let modifiers = &pattern.modifiers;
    let is_wide = modifiers.wide().is_some();
    let nocase_modifier = modifiers.nocase();
    let inline_nocase = pattern.regexp.case_insensitive;

    let mut flags = PatternFlags::empty();
    if modifiers.ascii().is_some() || !is_wide {
        flags |= PatternFlags::Ascii;
    }
    if is_wide {
        flags |= PatternFlags::Wide;
    }
    if modifiers.fullword().is_some() {
        flags |= PatternFlags::Fullword;
    }
    if nocase_modifier.is_some() || inline_nocase {
        flags |= PatternFlags::Nocase;
    }

    // `nocase` is redundant when the regexp already carries the `i` flag.
    if let (Some(nocase), true) = (nocase_modifier, inline_nocase) {
        let i_pos = pattern.regexp.literal.rfind('i').unwrap();
        ctx.warnings.add(|| {
            warnings::RedundantCaseModifier::build(
                ctx.report_builder,
                ctx.report_builder.span_to_code_loc(nocase.span()),
                ctx.report_builder.span_to_code_loc(
                    pattern.regexp.span().subspan(i_pos, i_pos + 1),
                ),
            )
        });
    }

    let hir = re::parser::Parser::new()
        .force_case_insensitive(flags.contains(PatternFlags::Nocase))
        .allow_mixed_greediness(false)
        .relaxed_re_syntax(ctx.relaxed_re_syntax)
        .parse(&pattern.regexp)
        .map_err(|err| {
            re_error_to_compile_error(ctx.report_builder, &pattern.regexp, err)
        })?;

    Ok(PatternInRule {
        identifier: pattern.identifier.clone(),
        in_use: false,
        span: pattern.span(),
        pattern: Pattern::Regexp(RegexpPattern {
            flags,
            hir,
            anchored_at: None,
            filesize_bounds: FilesizeBounds::default(),
        }),
    })
}
/// Translates an arbitrary AST expression into the IR and returns the id of
/// the resulting node.
///
/// Besides the direct translation, this function:
/// * rejects the unsupported `entrypoint` keyword, suggesting module fields
///   instead;
/// * rewrites `<bool> == 0` / `0 == <bool>` into `not <bool>`, and
///   `<bool> == 1` / `1 == <bool>` into `<bool>`, emitting a
///   `BooleanIntegerComparison` warning with the suggested replacement;
/// * enforces the access-control list attached to fields, and warns about
///   deprecated fields and about references to global rules;
/// * rejects the anonymous forms `$`, `#`, `@` and `!` outside the body of a
///   `for .. of` expression;
/// * marks referenced patterns as used. A pattern matched with `at <const>`
///   is anchored at that offset; any other reference makes it
///   non-anchorable.
///
/// # Errors
///
/// Returns the first `CompileError` produced while translating the
/// expression or any of its sub-expressions.
fn expr_from_ast(
    ctx: &mut CompileContext,
    expr: &ast::Expr,
) -> Result<ExprId, CompileError> {
    let expr = match expr {
        ast::Expr::Entrypoint { span } => {
            let code_loc = ctx.report_builder.span_to_code_loc(span.clone());
            let mut err = EntrypointUnsupported::build(
                ctx.report_builder,
                code_loc.clone(),
            );
            // The help text lists the module fields that replace
            // `entrypoint`; the automatic patch suggests the PE one.
            err.report_mut()
                .new_section(
                    Level::HELP,
                    "use `pe.entry_point`, `elf.entry_point` or `macho.entry_point`",
                )
                .patch(code_loc, "pe.entry_point");
            return Err(err);
        }
        ast::Expr::Filesize { .. } => ctx.ir.filesize(),
        ast::Expr::True { .. } => {
            ctx.ir.constant(TypeValue::const_bool_from(true))
        }
        ast::Expr::False { .. } => {
            ctx.ir.constant(TypeValue::const_bool_from(false))
        }
        ast::Expr::LiteralInteger(lit) => {
            ctx.ir.constant(TypeValue::const_integer_from(lit.value))
        }
        ast::Expr::LiteralFloat(lit) => {
            ctx.ir.constant(TypeValue::const_float_from(lit.value))
        }
        ast::Expr::LiteralString(lit) => {
            ctx.ir.constant(TypeValue::const_string_from(lit.value.as_bytes()))
        }
        ast::Expr::Regexp(regexp) => {
            // The regexp is parsed only to validate it and report syntax
            // errors; the constant keeps the original literal.
            re::parser::Parser::new()
                .relaxed_re_syntax(ctx.relaxed_re_syntax)
                .parse(regexp.as_ref())
                .map_err(|err| {
                    re_error_to_compile_error(ctx.report_builder, regexp, err)
                })?;
            ctx.ir
                .constant(TypeValue::Regexp(Some(Regexp::new(regexp.literal))))
        }
        ast::Expr::Defined(expr) => defined_expr_from_ast(ctx, expr)?,
        ast::Expr::Not(expr) => not_expr_from_ast(ctx, expr)?,
        ast::Expr::And(operands) => and_expr_from_ast(ctx, operands)?,
        ast::Expr::Or(operands) => or_expr_from_ast(ctx, operands)?,
        ast::Expr::Minus(expr) => minus_expr_from_ast(ctx, expr)?,
        ast::Expr::Add(expr) => add_expr_from_ast(ctx, expr)?,
        ast::Expr::Sub(expr) => sub_expr_from_ast(ctx, expr)?,
        ast::Expr::Mul(expr) => mul_expr_from_ast(ctx, expr)?,
        ast::Expr::Div(expr) => div_expr_from_ast(ctx, expr)?,
        ast::Expr::Mod(expr) => mod_expr_from_ast(ctx, expr)?,
        ast::Expr::Shl(expr) => shl_expr_from_ast(ctx, expr)?,
        ast::Expr::Shr(expr) => shr_expr_from_ast(ctx, expr)?,
        ast::Expr::BitwiseNot(expr) => bitwise_not_expr_from_ast(ctx, expr)?,
        ast::Expr::BitwiseAnd(expr) => bitwise_and_expr_from_ast(ctx, expr)?,
        ast::Expr::BitwiseOr(expr) => bitwise_or_expr_from_ast(ctx, expr)?,
        ast::Expr::BitwiseXor(expr) => bitwise_xor_expr_from_ast(ctx, expr)?,
        ast::Expr::Eq(expr) => {
            let eq_expr = eq_expr_from_ast(ctx, expr)?;
            let (lhs, rhs) = match ctx.ir.get(eq_expr) {
                Expr::Eq { lhs, rhs } => (*lhs, *rhs),
                _ => unreachable!(),
            };
            let span = expr.span();
            let lhs_span = expr.lhs.span();
            let rhs_span = expr.rhs.span();
            let lhs_expr = ctx.ir.get(lhs);
            let rhs_expr = ctx.ir.get(rhs);
            // Comparing a boolean with the constants 0 or 1 is replaced by
            // the equivalent boolean expression, and the user is warned.
            let replacement =
                match (lhs_expr.type_value(), rhs_expr.type_value()) {
                    (
                        TypeValue::Bool { .. },
                        TypeValue::Integer { value: Const(0), .. },
                    ) => Some((
                        ctx.ir.not(lhs),
                        format!(
                            "not {}",
                            ctx.report_builder.get_snippet(lhs_span)
                        ),
                    )),
                    (
                        TypeValue::Integer { value: Const(0), .. },
                        TypeValue::Bool { .. },
                    ) => Some((
                        ctx.ir.not(rhs),
                        format!(
                            "not {}",
                            ctx.report_builder.get_snippet(rhs_span)
                        ),
                    )),
                    (
                        TypeValue::Bool { .. },
                        TypeValue::Integer { value: Const(1), .. },
                    ) => Some((lhs, ctx.report_builder.get_snippet(lhs_span))),
                    (
                        TypeValue::Integer { value: Const(1), .. },
                        TypeValue::Bool { .. },
                    ) => Some((rhs, ctx.report_builder.get_snippet(rhs_span))),
                    _ => None,
                };
            if let Some((replacement_expr, replacement)) = replacement {
                let code_loc = ctx.report_builder.span_to_code_loc(span);
                let mut warning = warnings::BooleanIntegerComparison::build(
                    ctx.report_builder,
                    code_loc.clone(),
                );
                warning.report_mut().patch(code_loc, replacement);
                ctx.warnings.add(|| warning);
                replacement_expr
            } else {
                eq_expr
            }
        }
        ast::Expr::Ne(expr) => ne_expr_from_ast(ctx, expr)?,
        ast::Expr::Gt(expr) => gt_expr_from_ast(ctx, expr)?,
        ast::Expr::Ge(expr) => ge_expr_from_ast(ctx, expr)?,
        ast::Expr::Lt(expr) => lt_expr_from_ast(ctx, expr)?,
        ast::Expr::Le(expr) => le_expr_from_ast(ctx, expr)?,
        ast::Expr::Contains(expr) => contains_expr_from_ast(ctx, expr)?,
        ast::Expr::IContains(expr) => icontains_expr_from_ast(ctx, expr)?,
        ast::Expr::StartsWith(expr) => startswith_expr_from_ast(ctx, expr)?,
        ast::Expr::IStartsWith(expr) => istartswith_expr_from_ast(ctx, expr)?,
        ast::Expr::EndsWith(expr) => endswith_expr_from_ast(ctx, expr)?,
        ast::Expr::IEndsWith(expr) => iendswith_expr_from_ast(ctx, expr)?,
        ast::Expr::IEquals(expr) => iequals_expr_from_ast(ctx, expr)?,
        ast::Expr::Matches(expr) => matches_expr_from_ast(ctx, expr)?,
        ast::Expr::Of(of) => of_expr_from_ast(ctx, of)?,
        ast::Expr::ForOf(for_of) => for_of_expr_from_ast(ctx, for_of)?,
        ast::Expr::ForIn(for_in) => for_in_expr_from_ast(ctx, for_in)?,
        ast::Expr::With(with) => with_expr_from_ast(ctx, with)?,
        ast::Expr::FuncCall(func_call) => func_call_from_ast(ctx, func_call)?,
        ast::Expr::FieldAccess(expr) => {
            let mut operands = Vec::with_capacity(expr.operands.len());
            // Every operand but the last one must be a struct. Its symbol
            // table is stored as the one-shot table, presumably consulted
            // when the following operand is resolved.
            for operand in expr.operands.iter().dropping_back(1) {
                let expr = expr_from_ast(ctx, operand)?;
                check_type(ctx, expr, operand.span(), &[Type::Struct])?;
                operands.push(expr);
                ctx.one_shot_symbol_table =
                    ctx.ir.get(expr).type_value().symbol_table();
            }
            operands.push(expr_from_ast(ctx, expr.operands.last().unwrap())?);
            ctx.ir.field_access(operands)
        }
        ast::Expr::Ident(ident) => {
            let symbol = ctx.lookup(ident)?;
            // Fields may carry an access-control list. An entry is accepted
            // when its `accept_if` list is empty or one of its features is
            // enabled, and rejected when any of its `reject_if` features is
            // enabled.
            if let Symbol::Field { acl: Some(ref acl), .. } = symbol {
                for entry in acl {
                    let accepted = entry.accept_if.is_empty()
                        || entry
                            .accept_if
                            .iter()
                            .any(|accepted| ctx.features.contains(accepted));
                    let rejected = entry
                        .reject_if
                        .iter()
                        .any(|rejected| ctx.features.contains(rejected));
                    if !accepted || rejected {
                        return Err(CustomError::build(
                            ctx.report_builder,
                            entry.error_title.clone(),
                            entry.error_label.clone(),
                            ctx.report_builder.span_to_code_loc(ident.span()),
                        ));
                    }
                }
            }
            match symbol {
                Symbol::Field {
                    deprecation_notice: Some(ref notice), ..
                } => {
                    let code_loc =
                        ctx.report_builder.span_to_code_loc(ident.span());
                    let mut warning = warnings::DeprecatedField::build(
                        ctx.report_builder,
                        ident.name.to_string(),
                        code_loc.clone(),
                        notice.text.clone(),
                    );
                    if let Some(replacement) = &notice.replacement {
                        warning
                            .report_mut()
                            .new_section(
                                Level::HELP,
                                notice.help.clone().unwrap_or(
                                    "apply the following changes".to_owned(),
                                ),
                            )
                            .patch(code_loc, replacement);
                    }
                    ctx.warnings.add(|| warning);
                }
                Symbol::Rule { is_global: true, .. } => {
                    ctx.warnings.add(|| {
                        warnings::GlobalRuleMisuse::build(
                            ctx.report_builder,
                            ctx.report_builder.span_to_code_loc(ident.span()),
                            Some("referencing a global rule in a condition is redundant, and may result in an unsatisfiable condition".to_string()),
                        )
                    });
                }
                _ => {}
            }
            ctx.ir.ident(symbol)
        }
        ast::Expr::PatternMatch(p) => {
            let anchor = anchor_from_ast(ctx, &p.anchor)?;
            match p.identifier.name {
                "$" => {
                    if ctx.for_of_depth == 0 {
                        return Err(SyntaxError::build(
                            ctx.report_builder,
                            "this `$` is outside of the body of a `for .. of` statement".to_string(),
                            ctx.report_builder.span_to_code_loc(p.identifier.span()),
                        ));
                    }
                    ctx.ir.pattern_match_var(
                        ctx.symbol_table.lookup("$").unwrap(),
                        anchor,
                    )
                }
                _ => {
                    // A match at a constant offset lets the pattern be
                    // anchored there; anything else forbids anchoring.
                    let at = match anchor {
                        MatchAnchor::At(expr) => {
                            let value = ctx.ir.get(expr).type_value();
                            if value.is_const() {
                                value.try_as_integer()
                            } else {
                                None
                            }
                        }
                        _ => None,
                    };
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.mark_as_used();
                    if let Some(offset) = at {
                        pattern.anchor_at(offset as usize);
                    } else {
                        pattern.make_non_anchorable();
                    }
                    ctx.ir.pattern_match(pattern_idx, anchor)
                }
            }
        }
        ast::Expr::PatternCount(p) => {
            if p.identifier.name == "#" && ctx.for_of_depth == 0 {
                return Err(SyntaxError::build(
                    ctx.report_builder,
                    "this `#` is outside of the body of a `for .. of` statement".to_string(),
                    ctx.report_builder.span_to_code_loc(p.identifier.span()),
                ));
            }
            match (p.identifier.name, &p.range) {
                ("#", Some(range)) => {
                    let range = range_from_ast(ctx, range)?;
                    ctx.ir.pattern_count_var(
                        ctx.symbol_table.lookup("$").unwrap(),
                        Some(range),
                    )
                }
                ("#", None) => ctx.ir.pattern_count_var(
                    ctx.symbol_table.lookup("$").unwrap(),
                    None,
                ),
                (_, Some(range)) => {
                    let range = range_from_ast(ctx, range)?;
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_count(pattern_idx, Some(range))
                }
                (_, None) => {
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_count(pattern_idx, None)
                }
            }
        }
        ast::Expr::PatternOffset(p) => {
            if p.identifier.name == "@" && ctx.for_of_depth == 0 {
                return Err(SyntaxError::build(
                    ctx.report_builder,
                    "this `@` is outside of the body of a `for .. of` statement".to_string(),
                    ctx.report_builder.span_to_code_loc(p.identifier.span()),
                ));
            }
            match (p.identifier.name, &p.index) {
                ("@", Some(index)) => {
                    let range =
                        integer_in_range_from_ast(ctx, index, 1..=i64::MAX)?;
                    ctx.ir.pattern_offset_var(
                        ctx.symbol_table.lookup("$").unwrap(),
                        Some(range),
                    )
                }
                ("@", None) => ctx.ir.pattern_offset_var(
                    ctx.symbol_table.lookup("$").unwrap(),
                    None,
                ),
                (_, Some(index)) => {
                    let range =
                        integer_in_range_from_ast(ctx, index, 1..=i64::MAX)?;
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_offset(pattern_idx, Some(range))
                }
                (_, None) => {
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_offset(pattern_idx, None)
                }
            }
        }
        ast::Expr::PatternLength(p) => {
            if p.identifier.name == "!" && ctx.for_of_depth == 0 {
                return Err(SyntaxError::build(
                    ctx.report_builder,
                    "this `!` is outside of the body of a `for .. of` statement".to_string(),
                    ctx.report_builder.span_to_code_loc(p.identifier.span()),
                ));
            }
            match (p.identifier.name, &p.index) {
                ("!", Some(index)) => {
                    let index =
                        integer_in_range_from_ast(ctx, index, 1..=i64::MAX)?;
                    ctx.ir.pattern_length_var(
                        ctx.symbol_table.lookup("$").unwrap(),
                        Some(index),
                    )
                }
                ("!", None) => ctx.ir.pattern_length_var(
                    ctx.symbol_table.lookup("$").unwrap(),
                    None,
                ),
                (_, Some(index)) => {
                    let index =
                        integer_in_range_from_ast(ctx, index, 1..=i64::MAX)?;
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_length(pattern_idx, Some(index))
                }
                (_, None) => {
                    let (pattern_idx, pattern) =
                        ctx.get_pattern_mut(&p.identifier)?;
                    pattern.make_non_anchorable().mark_as_used();
                    ctx.ir.pattern_length(pattern_idx, None)
                }
            }
        }
        ast::Expr::Lookup(expr) => {
            let primary = expr_from_ast(ctx, &expr.primary)?;
            match ctx.ir.get(primary).type_value() {
                // Arrays are indexed by non-negative integers.
                TypeValue::Array(array) => {
                    let index =
                        non_negative_integer_from_ast(ctx, &expr.index)?;
                    ctx.ir.lookup(array.deputy(), primary, index)
                }
                // Maps are indexed by a key whose type must match the map's
                // key type.
                TypeValue::Map(map) => {
                    let (key_ty, deputy_value) = match map.borrow() {
                        Map::IntegerKeys { deputy: Some(value), .. } => {
                            (Type::Integer, value)
                        }
                        Map::StringKeys { deputy: Some(value), .. } => {
                            (Type::String, value)
                        }
                        _ => unreachable!(),
                    };
                    let index = expr_from_ast(ctx, &expr.index)?;
                    check_type(ctx, index, expr.index.span(), &[key_ty])?;
                    ctx.ir.lookup(deputy_value.clone(), primary, index)
                }
                type_value => {
                    return Err(WrongType::build(
                        ctx.report_builder,
                        format!("`{}` or `{}`", Type::Array, Type::Map),
                        format!("`{}`", type_value.ty()),
                        ctx.report_builder
                            .span_to_code_loc(expr.primary.span()),
                        None,
                    ));
                }
            }
        }
    };
    Ok(expr)
}
/// Builds the IR for the condition of `rule` and records it as the root of
/// the IR tree.
///
/// The IR is cleared first, so each rule starts from an empty tree. When
/// the resulting condition is a constant (after casting to bool), an
/// `InvariantBooleanExpression` warning reports that the rule is always
/// true or always false.
pub(in crate::compiler) fn rule_condition_from_ast(
    ctx: &mut CompileContext,
    rule: &ast::Rule,
) -> Result<ExprId, CompileError> {
    ctx.ir.clear();

    let condition = bool_expr_from_ast(ctx, &rule.condition)?;
    ctx.ir.root = Some(condition);

    if let Some(always) =
        ctx.ir.get(condition).type_value().cast_to_bool().try_as_bool()
    {
        ctx.warnings.add(|| {
            warnings::InvariantBooleanExpression::build(
                ctx.report_builder,
                always,
                ctx.report_builder.span_to_code_loc(rule.condition.span()),
                Some(format!(
                    "rule `{}` is always `{}`",
                    rule.identifier.name, always
                )),
            )
        });
    }

    Ok(condition)
}
/// Translates `ast` into the IR and checks that the result can be used
/// where a boolean is expected.
///
/// Functions, maps, structs, arrays and regexps are rejected with a
/// `WrongType` error. For a function, the error carries a hint suggesting
/// the call form `<expr>()` when one of its signatures takes no arguments
/// and returns `bool`. Any other type is accepted and passed to
/// `warn_if_not_bool`, which presumably warns when the type is not
/// `bool` itself.
///
/// # Errors
///
/// Errors from translating `ast`, or `WrongType` as described above.
fn bool_expr_from_ast(
    ctx: &mut CompileContext,
    ast: &ast::Expr,
) -> Result<ExprId, CompileError> {
    let expr = expr_from_ast(ctx, ast)?;
    let found = match ctx.ir.get(expr).type_value() {
        TypeValue::Func(func) => {
            // Only suggest `<expr>()` when calling the function without
            // arguments actually yields a boolean. A signature that needs
            // arguments, or one that returns something else, would make the
            // suggestion wrong.
            let help = func
                .signatures()
                .iter()
                .find(|f| f.args.is_empty() && f.result.ty() == Type::Bool)
                .map(|_| {
                    let style = ctx.report_builder.green_style();
                    format!(
                        "you probably meant {style}{}(){style:#}",
                        ctx.report_builder.get_snippet(ast.span())
                    )
                });
            return Err(WrongType::build(
                ctx.report_builder,
                "`bool`".to_string(),
                "a function".to_string(),
                ctx.report_builder.span_to_code_loc(ast.span()),
                help,
            ));
        }
        TypeValue::Map(_) => "a map",
        TypeValue::Struct(_) => "a struct",
        TypeValue::Array(_) => "an array",
        TypeValue::Regexp(_) => "a regexp",
        type_value => {
            warn_if_not_bool(ctx, type_value.ty(), ast.span());
            return Ok(expr);
        }
    };
    Err(WrongType::build(
        ctx.report_builder,
        "`bool`".to_string(),
        found.to_string(),
        ctx.report_builder.span_to_code_loc(ast.span()),
        None,
    ))
}
/// The items an `of` expression ranges over.
enum OfItems {
    /// A tuple of boolean expressions, one IR node per item.
    BoolExprTuple(Vec<ExprId>),
    /// A set of patterns, as produced by `pattern_set_from_ast`.
    PatternSet(Vec<PatternIdx>),
}

impl OfItems {
    /// Returns the number of items, whatever their kind.
    fn len(&self) -> usize {
        match self {
            Self::BoolExprTuple(exprs) => exprs.len(),
            Self::PatternSet(patterns) => patterns.len(),
        }
    }
}
/// Builds the IR for an `of` expression whose items are either a tuple of
/// boolean expressions or a pattern set (e.g. `2 of ($a, $b, $c)`).
///
/// A stack frame of `VarStack::OF_FRAME_SIZE` slots is reserved for the
/// loop bookkeeping variables (`ForVars`), plus one variable holding the
/// item being evaluated. While building the node, warnings may be emitted
/// for:
/// * a constant `0` quantifier (suggesting `none`);
/// * a constant quantifier larger than the number of items (the
///   expression is always false);
/// * an `at` anchor combined with a quantifier that needs two or more
///   items to match.
///
/// # Errors
///
/// Errors from translating the quantifier, the items or the anchor.
fn of_expr_from_ast(
    ctx: &mut CompileContext,
    of: &ast::Of,
) -> Result<ExprId, CompileError> {
    let quantifier = quantifier_from_ast(ctx, &of.quantifier)?;
    // NOTE(review): variables are allocated from the frame in call order;
    // the slot each one gets presumably depends on that order, so keep it.
    let mut stack_frame = ctx.vars.new_frame(VarStack::OF_FRAME_SIZE);
    let for_vars = ForVars {
        n: stack_frame.new_var(Type::Integer),
        i: stack_frame.new_var(Type::Integer),
        max_count: stack_frame.new_var(Type::Integer),
        count: stack_frame.new_var(Type::Integer),
    };
    let (items, next_item_var) = match &of.items {
        // Tuple of boolean expressions: the per-item variable is a bool.
        ast::OfItems::BoolExprTuple(tuple) => {
            let next_item_var = stack_frame.new_var(Type::Bool);
            (
                OfItems::BoolExprTuple(
                    tuple
                        .iter()
                        .map(|e| {
                            let expr = bool_expr_from_ast(ctx, e)?;
                            Ok(expr)
                        })
                        .collect::<Result<Vec<ExprId>, CompileError>>()?,
                ),
                next_item_var,
            )
        }
        // Pattern set: the per-item variable is an integer (presumably
        // identifying the pattern being checked — TODO confirm).
        ast::OfItems::PatternSet(patterns) => {
            let next_item_var = stack_frame.new_var(Type::Integer);
            (
                OfItems::PatternSet(pattern_set_from_ast(ctx, patterns)?),
                next_item_var,
            )
        }
    };
    // Checks that only apply when the quantifier is a constant integer.
    if let Quantifier::Expr(expr) = &quantifier
        && let Some(value) = ctx.ir.get(*expr).try_as_const_integer()
    {
        // `0 of (...)` is accepted but ambiguous; suggest `none` instead.
        if value == 0 {
            let mut warning = warnings::AmbiguousExpression::build(
                ctx.report_builder,
                ctx.report_builder.span_to_code_loc(of.span()),
            );
            warning
                .report_mut()
                .new_section(
                    Level::HELP,
                    "consider using `none` instead of `0`",
                )
                .patch(
                    ctx.report_builder.span_to_code_loc(of.quantifier.span()),
                    "none",
                );
            ctx.warnings.add(|| warning)
        }
        // Requiring more matches than there are items is always false.
        if value > items.len() as i64 {
            ctx.warnings.add(|| warnings::InvariantBooleanExpression::build(
                ctx.report_builder,
                false,
                ctx.report_builder.span_to_code_loc(of.span()),
                Some(format!(
                    "the expression requires {} matching patterns out of {}",
                    value, items.len()
                )),
            ));
        }
    }
    // With an `at` anchor, a quantifier that needs two or more items
    // (`all` over several items, a constant >= 2, or a percentage that
    // amounts to >= 2 items) is flagged as potentially unsatisfiable,
    // presumably because all of them would have to match at that offset.
    if matches!(of.anchor, Some(ast::MatchAnchor::At(_))) {
        let raise_warning = match &quantifier {
            Quantifier::All => items.len() > 1,
            Quantifier::Expr(expr) => match ctx.ir.get(*expr).type_value() {
                TypeValue::Integer { value: Const(value), .. } => value >= 2,
                _ => false,
            },
            Quantifier::Percentage(expr) => {
                match ctx.ir.get(*expr).type_value() {
                    TypeValue::Integer {
                        value: Const(percentage), ..
                    } => items.len() as f64 * percentage as f64 / 100.0 >= 2.0,
                    _ => false,
                }
            }
            Quantifier::None | Quantifier::Any => false,
        };
        if raise_warning {
            ctx.warnings.add(|| {
                warnings::PotentiallyUnsatisfiableExpression::build(
                    ctx.report_builder,
                    ctx.report_builder.span_to_code_loc(of.quantifier.span()),
                    ctx.report_builder
                        .span_to_code_loc(of.anchor.as_ref().unwrap().span()),
                )
            });
        }
    }
    let anchor = anchor_from_ast(ctx, &of.anchor)?;
    // Release the variables reserved for this expression.
    ctx.vars.unwind(&stack_frame);
    let expr = match items {
        OfItems::BoolExprTuple(exprs) => ctx.ir.of_expr_tuple(
            quantifier,
            for_vars,
            next_item_var,
            exprs,
            anchor,
        ),
        OfItems::PatternSet(pattern_set) => ctx.ir.of_pattern_set(
            quantifier,
            for_vars,
            next_item_var,
            pattern_set,
            anchor,
        ),
    };
    Ok(expr)
}
/// Builds the IR for a `for <quantifier> of <pattern set> : (<body>)`
/// expression.
///
/// While the body is translated, a symbol table that binds `$` to an
/// integer loop variable is pushed, and `ctx.for_of_depth` is incremented.
/// That counter is what allows the anonymous forms `$`, `#`, `@` and `!`
/// (see the `for_of_depth == 0` checks in `expr_from_ast`). A warning is
/// emitted when a constant quantifier exceeds the number of patterns in the
/// set.
///
/// # Errors
///
/// Errors from translating the quantifier, the pattern set or the body.
fn for_of_expr_from_ast(
    ctx: &mut CompileContext,
    for_of: &ast::ForOf,
) -> Result<ExprId, CompileError> {
    let quantifier = quantifier_from_ast(ctx, &for_of.quantifier)?;
    let pattern_set = pattern_set_from_ast(ctx, &for_of.pattern_set)?;
    // NOTE(review): allocation order from the frame presumably fixes each
    // variable's slot; keep it stable.
    let mut stack_frame = ctx.vars.new_frame(VarStack::FOR_OF_FRAME_SIZE);
    let for_vars = ForVars {
        n: stack_frame.new_var(Type::Integer),
        i: stack_frame.new_var(Type::Integer),
        max_count: stack_frame.new_var(Type::Integer),
        count: stack_frame.new_var(Type::Integer),
    };
    // Variable that `$` resolves to inside the body.
    let next_pattern_id = stack_frame.new_var(Type::Integer);
    let mut loop_vars = SymbolTable::new();
    loop_vars.insert(
        "$",
        Symbol::Var {
            var: next_pattern_id,
            type_value: TypeValue::unknown_integer(),
        },
    );
    ctx.symbol_table.push(Rc::new(RefCell::new(loop_vars)));
    ctx.for_of_depth += 1;
    // NOTE(review): if the body fails, `?` returns before `for_of_depth`
    // and the symbol table are restored; presumably harmless because the
    // error aborts compilation of this rule — confirm.
    let body = bool_expr_from_ast(ctx, &for_of.body)?;
    ctx.for_of_depth -= 1;
    ctx.symbol_table.pop();
    ctx.vars.unwind(&stack_frame);
    // A constant quantifier larger than the pattern set can never be met.
    if let Quantifier::Expr(expr) = &quantifier
        && let Some(value) = ctx.ir.get(*expr).try_as_const_integer()
    {
        if value > pattern_set.len() as i64 {
            ctx.warnings.add(|| warnings::InvariantBooleanExpression::build(
                ctx.report_builder,
                false,
                ctx.report_builder.span_to_code_loc(for_of.span()),
                Some(format!(
                    "the expression requires {} matching patterns out of {}",
                    value, pattern_set.len()
                )),
            ));
        }
    }
    Ok(ctx.ir.for_of(quantifier, next_pattern_id, for_vars, pattern_set, body))
}
/// Tells whether a `for .. in (<lower>..<upper>)` range may produce a very
/// large number of iterations.
///
/// This is the case when the lower bound is a constant and the upper bound
/// depends on `filesize` or on a pattern count. The second closure passed
/// to `dfs_find` matches calls to `math.min(int, int)`; presumably it
/// stops the search from descending into such calls, because `math.min`
/// can cap the value — TODO confirm `dfs_find` semantics.
fn is_potentially_large_range(ctx: &CompileContext, range: &Range) -> bool {
    let lower_is_const =
        ctx.ir.get(range.lower_bound).type_value().is_const();

    lower_is_const
        && ctx
            .ir
            .dfs_find(
                range.upper_bound,
                |node| {
                    matches!(node, Expr::Filesize | Expr::PatternCount { .. })
                },
                |node| match node {
                    Expr::FuncCall(func) => {
                        func.signature.mangled_name.as_str()
                            == "math.min@a:i,b:i@i"
                    }
                    _ => false,
                },
            )
            .is_some()
}
/// Builds the IR for a `for <quantifier> <vars> in <iterable> : (<body>)`
/// expression.
///
/// Besides translating the quantifier, iterable and body, this:
/// * multiplies `ctx.loop_iteration_multiplier` by the iterable's number of
///   iterations (when known), so nested loops accumulate a total estimate,
///   and warns with `TooManyIterations` when it exceeds
///   `MAX_LOOP_ITERATIONS`. The previous multiplier is restored after the
///   body is translated;
/// * flags ranges for which `is_potentially_large_range` is true, with an
///   error or a warning depending on `ctx.error_on_slow_loop`;
/// * checks that the number of loop variables matches what the iterable
///   yields: one for ranges, tuples and arrays, two (key, value) for maps.
///
/// # Errors
///
/// `PotentiallySlowLoop` (when `error_on_slow_loop` is set),
/// `AssignmentMismatch`, or errors from the sub-expressions.
fn for_in_expr_from_ast(
    ctx: &mut CompileContext,
    for_in: &ast::ForIn,
) -> Result<ExprId, CompileError> {
    let quantifier = quantifier_from_ast(ctx, &for_in.quantifier)?;
    let iterable = iterable_from_ast(ctx, &for_in.iterable)?;
    let parent_multiplier = ctx.loop_iteration_multiplier;
    // Only iterables with a known iteration count update the multiplier;
    // otherwise the enclosing loop's value carries over unchanged.
    if let Some(loop_iterations) = iterable.num_iterations(ctx.ir) {
        let combined_iterations =
            parent_multiplier.saturating_mul(loop_iterations);
        if combined_iterations > MAX_LOOP_ITERATIONS {
            ctx.warnings.add(|| {
                warnings::TooManyIterations::build(
                    ctx.report_builder,
                    combined_iterations,
                    ctx.report_builder.span_to_code_loc(for_in.span()),
                )
            });
        }
        ctx.loop_iteration_multiplier = combined_iterations;
    }
    // Type of each loop variable, plus the type of the variable that holds
    // the iterable itself.
    let (expected_vars, iterable_ty) = match &iterable {
        Iterable::Range(range) => {
            if is_potentially_large_range(ctx, range) {
                if ctx.error_on_slow_loop {
                    return Err(PotentiallySlowLoop::build(
                        ctx.report_builder,
                        ctx.report_builder
                            .span_to_code_loc(for_in.iterable.span()),
                    ));
                } else {
                    ctx.warnings.add(|| {
                        warnings::PotentiallySlowLoop::build(
                            ctx.report_builder,
                            ctx.report_builder
                                .span_to_code_loc(for_in.iterable.span()),
                        )
                    })
                }
            }
            (vec![TypeValue::unknown_integer()], Type::Unknown)
        }
        // All tuple items share one type (enforced by `iterable_from_ast`),
        // so the first item determines the variable's type. The `unwrap`
        // assumes tuples are never empty — presumably guaranteed by the
        // parser.
        Iterable::ExprTuple(expressions) => {
            (
                vec![
                    expressions
                        .first()
                        .map(|node_idx| ctx.ir.get(*node_idx).type_value())
                        .unwrap()
                        .clone_without_value(),
                ],
                Type::Unknown,
            )
        }
        // `iterable_from_ast` only admits arrays and maps here.
        Iterable::Expr(expr) => match ctx.ir.get(*expr).type_value() {
            TypeValue::Array(array) => (vec![array.deputy()], Type::Array),
            TypeValue::Map(map) => match map.as_ref() {
                Map::IntegerKeys { .. } => (
                    vec![TypeValue::unknown_integer(), map.deputy()],
                    Type::Map,
                ),
                Map::StringKeys { .. } => (
                    vec![TypeValue::unknown_string(), map.deputy()],
                    Type::Map,
                ),
            },
            _ => unreachable!(),
        },
    };
    let loop_vars = &for_in.variables;
    if loop_vars.len() != expected_vars.len() {
        // Highlight the whole list of declared loop variables.
        let span = loop_vars.first().unwrap().span();
        let span = span.combine(&loop_vars.last().unwrap().span());
        return Err(AssignmentMismatch::build(
            ctx.report_builder,
            loop_vars.len() as u8,
            expected_vars.len() as u8,
            ctx.report_builder.span_to_code_loc(for_in.iterable.span()),
            ctx.report_builder.span_to_code_loc(span),
        ));
    }
    // NOTE(review): allocation order from the frame presumably fixes each
    // variable's slot; keep it stable.
    let mut stack_frame = ctx.vars.new_frame(VarStack::FOR_IN_FRAME_SIZE);
    let iterable_var = stack_frame.new_var(iterable_ty);
    let for_vars = ForVars {
        n: stack_frame.new_var(Type::Integer),
        i: stack_frame.new_var(Type::Integer),
        max_count: stack_frame.new_var(Type::Integer),
        count: stack_frame.new_var(Type::Integer),
    };
    // Bind every loop variable in a new scope visible only to the body.
    let mut symbols = SymbolTable::new();
    let mut variables = Vec::new();
    for (loop_var, type_value) in iter::zip(loop_vars, expected_vars) {
        let var = stack_frame.new_var(type_value.ty());
        variables.push(var);
        symbols.insert(loop_var.name, Symbol::Var { var, type_value });
    }
    ctx.symbol_table.push(Rc::new(RefCell::new(symbols)));
    let body = bool_expr_from_ast(ctx, &for_in.body)?;
    ctx.symbol_table.pop();
    ctx.vars.unwind(&stack_frame);
    ctx.loop_iteration_multiplier = parent_multiplier;
    Ok(ctx.ir.for_in(
        quantifier,
        variables,
        for_vars,
        iterable_var,
        iterable,
        body,
    ))
}
/// Builds the IR for a `with` expression, which binds identifiers to
/// expressions and then evaluates a boolean body.
///
/// Declarations are translated in order and bound in a fresh symbol table
/// pushed onto `ctx.symbol_table`:
/// * a function value is bound as `Symbol::Func` and uses no stack slot
///   (methods are rejected with `MethodNotAllowedInWith`);
/// * any other value gets a variable in a new stack frame, and the
///   `(variable, expression)` pair is passed on to `ctx.ir.with`.
///
/// After the body is translated, an `UnusedIdentifier` warning is emitted
/// for each declared identifier that the body never used.
fn with_expr_from_ast(
    ctx: &mut CompileContext,
    with: &ast::With,
) -> Result<ExprId, CompileError> {
    let mut stack_frame = ctx.vars.new_frame(with.declarations.len() as i32);
    let mut declarations = Vec::new();
    let scope = ctx.symbol_table.push_new();

    for decl in with.declarations.iter() {
        let value_expr = expr_from_ast(ctx, &decl.expression)?;
        let symbol = match ctx.ir.get(value_expr).type_value() {
            TypeValue::Func(func) if func.is_method() => {
                return Err(MethodNotAllowedInWith::build(
                    ctx.report_builder,
                    ctx.report_builder
                        .span_to_code_loc(decl.expression.span()),
                ));
            }
            TypeValue::Func(func) => Symbol::Func(func),
            type_value => {
                let var = stack_frame.new_var(type_value.ty());
                declarations.push((var, value_expr));
                Symbol::Var { var, type_value }
            }
        };
        scope.borrow_mut().insert(decl.identifier.name, symbol);
    }

    let body = bool_expr_from_ast(ctx, &with.body)?;
    let scope_symbols = ctx.symbol_table.pop().unwrap().take();

    for decl in with.declarations.iter() {
        if scope_symbols.used(decl.identifier.name) {
            continue;
        }
        ctx.warnings.add(|| {
            warnings::UnusedIdentifier::build(
                ctx.report_builder,
                ctx.report_builder.span_to_code_loc(decl.identifier.span()),
            )
        });
    }

    ctx.vars.unwind(&stack_frame);
    Ok(ctx.ir.with(declarations, body))
}
/// Converts the iterable part of a `for ... in` expression into IR.
///
/// * A range is converted with `range_from_ast`.
/// * A single expression must be an array or a map.
/// * A tuple of expressions may contain integers, floats, strings or
///   booleans, and each element must have the same type as the element
///   before it.
fn iterable_from_ast(
    ctx: &mut CompileContext,
    iter: &ast::Iterable,
) -> Result<Iterable, CompileError> {
    match iter {
        ast::Iterable::Range(range) => {
            range_from_ast(ctx, range).map(Iterable::Range)
        }
        ast::Iterable::Expr(expr) => {
            let expr_span = expr.span();
            let id = expr_from_ast(ctx, expr)?;
            check_type(ctx, id, expr_span, &[Type::Array, Type::Map])?;
            Ok(Iterable::Expr(id))
        }
        ast::Iterable::ExprTuple(items) => {
            let mut ids = Vec::with_capacity(items.len());
            let mut previous: Option<(Type, Span)> = None;

            for item in items {
                let item_span = item.span();
                let id = expr_from_ast(ctx, item)?;
                check_type(
                    ctx,
                    id,
                    item_span.clone(),
                    &[Type::Integer, Type::Float, Type::String, Type::Bool],
                )?;
                let item_ty = ctx.ir.get(id).ty();

                match previous {
                    Some((previous_ty, previous_span))
                        if previous_ty != item_ty =>
                    {
                        return Err(MismatchingTypes::build(
                            ctx.report_builder,
                            previous_ty.to_string(),
                            item_ty.to_string(),
                            ctx.report_builder.span_to_code_loc(previous_span),
                            ctx.report_builder.span_to_code_loc(item_span),
                        ));
                    }
                    _ => {}
                }

                previous = Some((item_ty, item_span));
                ids.push(id);
            }

            Ok(Iterable::ExprTuple(ids))
        }
    }
}
/// Converts an optional match anchor (`at <expr>` or `in <range>`) into IR.
///
/// A missing anchor becomes `MatchAnchor::None`. The `at` offset must be a
/// non-negative integer, and the `in` range is validated by `range_from_ast`.
fn anchor_from_ast(
    ctx: &mut CompileContext,
    anchor: &Option<ast::MatchAnchor>,
) -> Result<MatchAnchor, CompileError> {
    let Some(anchor) = anchor else {
        return Ok(MatchAnchor::None);
    };
    match anchor {
        ast::MatchAnchor::At(at_) => {
            non_negative_integer_from_ast(ctx, &at_.expr).map(MatchAnchor::At)
        }
        ast::MatchAnchor::In(in_) => {
            range_from_ast(ctx, &in_.range).map(MatchAnchor::In)
        }
    }
}
/// Converts a range expression into IR.
///
/// Both bounds must be non-negative integers. When both bounds are constant,
/// the lower bound must not be greater than the upper bound.
///
/// # Errors
///
/// Returns `InvalidRange` when both bounds are constant and `lower > upper`.
/// Also propagates any error from converting the bounds.
fn range_from_ast(
    ctx: &mut CompileContext,
    range: &ast::Range,
) -> Result<Range, CompileError> {
    let lower_bound = non_negative_integer_from_ast(ctx, &range.lower_bound)?;
    let upper_bound = non_negative_integer_from_ast(ctx, &range.upper_bound)?;

    let bounds = (
        ctx.ir.get(lower_bound).type_value(),
        ctx.ir.get(upper_bound).type_value(),
    );

    match bounds {
        // Inside this arm `lower_bound` / `upper_bound` are the constant
        // integer values, shadowing the expression ids.
        (
            TypeValue::Integer { value: Const(lower_bound), .. },
            TypeValue::Integer { value: Const(upper_bound), .. },
        ) if lower_bound > upper_bound => Err(InvalidRange::build(
            ctx.report_builder,
            format!(
                "lower bound ({lower_bound}) is greater than upper bound ({upper_bound})"
            ),
            ctx.report_builder.span_to_code_loc(range.span()),
        )),
        _ => Ok(Range { lower_bound, upper_bound }),
    }
}
/// Converts `expr` into IR, requiring an integer that is not a negative
/// constant.
///
/// Only constant values can be checked at compile time. Non-constant
/// integers are accepted as they are.
///
/// # Errors
///
/// Returns `WrongType` for non-integers and `UnexpectedNegativeNumber` for
/// negative constants.
fn non_negative_integer_from_ast(
    ctx: &mut CompileContext,
    expr: &ast::Expr,
) -> Result<ExprId, CompileError> {
    let span = expr.span();
    let id = expr_from_ast(ctx, expr)?;
    check_type(ctx, id, span.clone(), &[Type::Integer])?;

    let is_negative_const = matches!(
        ctx.ir.get(id).type_value(),
        TypeValue::Integer { value: Const(value), .. } if value < 0
    );

    if is_negative_const {
        return Err(UnexpectedNegativeNumber::build(
            ctx.report_builder,
            ctx.report_builder.span_to_code_loc(span),
        ));
    }
    Ok(id)
}
/// Converts `expr` into IR, requiring an integer that lies within `range`
/// when it is a constant.
///
/// Non-constant integers are accepted without a range check.
///
/// # Errors
///
/// Returns `WrongType` for non-integers and `NumberOutOfRange` for constants
/// outside `range`.
fn integer_in_range_from_ast(
    ctx: &mut CompileContext,
    expr: &ast::Expr,
    range: RangeInclusive<i64>,
) -> Result<ExprId, CompileError> {
    let span = expr.span();
    let id = expr_from_ast(ctx, expr)?;
    check_type(ctx, id, span.clone(), &[Type::Integer])?;

    match ctx.ir.get(id).type_value() {
        TypeValue::Integer { value: Const(value), .. }
            if !range.contains(&value) =>
        {
            Err(NumberOutOfRange::build(
                ctx.report_builder,
                *range.start(),
                *range.end(),
                ctx.report_builder.span_to_code_loc(span),
            ))
        }
        _ => Ok(id),
    }
}
/// Converts a loop/pattern-set quantifier into IR.
///
/// `none`, `all` and `any` map directly. A percentage must be an integer in
/// `0..=100`, and a count expression must be a non-negative integer.
fn quantifier_from_ast(
    ctx: &mut CompileContext,
    quantifier: &ast::Quantifier,
) -> Result<Quantifier, CompileError> {
    let quantifier = match quantifier {
        ast::Quantifier::None { .. } => Quantifier::None,
        ast::Quantifier::All { .. } => Quantifier::All,
        ast::Quantifier::Any { .. } => Quantifier::Any,
        ast::Quantifier::Percentage(expr) => Quantifier::Percentage(
            integer_in_range_from_ast(ctx, expr, 0..=100)?,
        ),
        ast::Quantifier::Expr(expr) => {
            Quantifier::Expr(non_negative_integer_from_ast(ctx, expr)?)
        }
    };
    Ok(quantifier)
}
/// Resolves a pattern set into the indexes of the selected patterns in the
/// current rule.
///
/// The set is either `them` or an explicit list of pattern identifiers,
/// where an item may end with a `*` wildcard. Every selected pattern is made
/// non-anchorable and marked as used.
///
/// # Errors
///
/// Returns `EmptyPatternSet` in either of these cases:
/// * `them` is used in a rule without patterns.
/// * An item of an explicit set matches no pattern identifier.
fn pattern_set_from_ast(
    ctx: &mut CompileContext,
    pattern_set: &ast::PatternSet,
) -> Result<Vec<PatternIdx>, CompileError> {
    match pattern_set {
        ast::PatternSet::Them { span } => {
            if ctx.current_rule_patterns.is_empty() {
                return Err(EmptyPatternSet::build(
                    ctx.report_builder,
                    ctx.report_builder.span_to_code_loc(span.clone()),
                    Some("this rule doesn't define any patterns".to_string()),
                ));
            }
            ctx.current_rule_patterns.iter_mut().for_each(|pattern| {
                pattern.make_non_anchorable().mark_as_used();
            });
            let indexes: Vec<PatternIdx> = (0..ctx.current_rule_patterns.len())
                .map(Into::into)
                .collect();
            Ok(indexes)
        }
        ast::PatternSet::Set(set) => {
            // Every item in the set must select at least one pattern.
            let unmatched = set.iter().find(|item| {
                !ctx.current_rule_patterns
                    .iter()
                    .any(|pattern| item.matches(pattern.identifier()))
            });

            if let Some(item) = unmatched {
                let message = if item.wildcard {
                    format!(
                        "`{}*` doesn't match any pattern identifier",
                        item.identifier,
                    )
                } else {
                    format!(
                        "`{}` doesn't match any pattern identifier",
                        item.identifier,
                    )
                };
                return Err(EmptyPatternSet::build(
                    ctx.report_builder,
                    ctx.report_builder.span_to_code_loc(item.span()),
                    Some(message),
                ));
            }

            let selected = ctx
                .current_rule_patterns
                .iter_mut()
                .enumerate()
                .filter(|(_, pattern)| {
                    set.iter().any(|item| item.matches(pattern.identifier()))
                });

            let mut indexes: Vec<PatternIdx> = Vec::new();
            for (index, pattern) in selected {
                indexes.push(index.into());
                pattern.make_non_anchorable().mark_as_used();
            }
            Ok(indexes)
        }
    }
}
/// Converts a function or method call into IR.
///
/// If the call has an object (`obj.func(...)`), the object is converted
/// first. Its symbol table is installed as the one-shot table used to look
/// up the function identifier. The arguments are then converted, and their
/// types are compared against each signature of the function. For method
/// signatures, the first declared argument (the object) is not part of the
/// comparison. The first signature that matches exactly is selected. The
/// object is only passed to the IR when that signature is a method.
///
/// # Errors
///
/// * `WrongType` if the identifier doesn't refer to a function.
/// * `WrongArguments`, listing every accepted argument combination, if no
///   signature matches.
fn func_call_from_ast(
    ctx: &mut CompileContext,
    func_call: &ast::FuncCall,
) -> Result<ExprId, CompileError> {
    let object = match &func_call.object {
        Some(obj) => {
            let obj_expr = expr_from_ast(ctx, obj)?;
            ctx.one_shot_symbol_table =
                ctx.ir.get(obj_expr).type_value().symbol_table();
            Some(obj_expr)
        }
        None => None,
    };

    let symbol = ctx.lookup(&func_call.identifier)?;

    let func = match symbol {
        Symbol::Func(func)
        | Symbol::Field { type_value: TypeValue::Func(func), .. }
        | Symbol::Var { type_value: TypeValue::Func(func), .. } => func,
        _ => {
            return Err(WrongType::build(
                ctx.report_builder,
                "`function`".to_string(),
                format!("`{}`", symbol.ty()),
                ctx.report_builder
                    .span_to_code_loc(func_call.identifier.span()),
                None,
            ));
        }
    };

    let mut args = Vec::new();
    for arg in func_call.args.iter() {
        args.push(expr_from_ast(ctx, arg)?);
    }

    let arg_types: Vec<Type> =
        args.iter().map(|arg| ctx.ir.get(*arg).ty()).collect();

    // Argument types written explicitly at the call site for each signature.
    // Methods declare the object as their first argument, so it's skipped.
    let signatures = func.signatures();
    let expected_args: Vec<Vec<Type>> = signatures
        .iter()
        .map(|signature| {
            let implicit_args =
                if signature.method_of().is_some() { 1 } else { 0 };
            signature
                .args
                .iter()
                .skip(implicit_args)
                .map(|(_, arg)| arg.ty())
                .collect()
        })
        .collect();

    let Some(matching_signature) = signatures
        .iter()
        .zip(expected_args.iter())
        .find(|(_, expected)| **expected == arg_types)
        .map(|(signature, _)| signature)
    else {
        let combinations = expected_args
            .iter()
            .map(|types| {
                let listed = types
                    .iter()
                    .map(|ty| ty.to_string())
                    .collect::<Vec<String>>()
                    .join(", ");
                format!("({})", listed)
            })
            .collect::<Vec<String>>()
            .join("\n");
        return Err(WrongArguments::build(
            ctx.report_builder,
            ctx.report_builder.span_to_code_loc(func_call.args_span()),
            Some(format!(
                "accepted argument combinations:\n\n{}",
                combinations
            )),
        ));
    };

    // Plain functions don't receive the object as an implicit argument.
    let object =
        if matching_signature.method_of().is_some() { object } else { None };

    Ok(ctx.ir.func_call(object, args, matching_signature.clone()))
}
/// Converts a `matches` expression into IR.
///
/// Both operands are converted before any type checking. The left operand
/// must be a string and the right operand a regular expression.
fn matches_expr_from_ast(
    ctx: &mut CompileContext,
    expr: &ast::BinaryExpr,
) -> Result<ExprId, CompileError> {
    let (lhs_ast, rhs_ast) = (&expr.lhs, &expr.rhs);

    let lhs = expr_from_ast(ctx, lhs_ast)?;
    let rhs = expr_from_ast(ctx, rhs_ast)?;

    check_type(ctx, lhs, lhs_ast.span(), &[Type::String])?;
    check_type(ctx, rhs, rhs_ast.span(), &[Type::Regexp])?;

    Ok(ctx.ir.matches(lhs, rhs))
}
/// Ensures the type of `expr` is one of `accepted_types`.
///
/// # Errors
///
/// Returns `WrongType`, pointing at `span`, that lists the accepted types and
/// the actual one.
fn check_type(
    ctx: &CompileContext,
    expr: ExprId,
    span: Span,
    accepted_types: &[Type],
) -> Result<(), CompileError> {
    let ty = ctx.ir.get(expr).ty();

    if !accepted_types.contains(&ty) {
        return Err(WrongType::build(
            ctx.report_builder,
            CompileError::join_with_or(accepted_types, true),
            format!("`{ty}`"),
            ctx.report_builder.span_to_code_loc(span),
            None,
        ));
    }

    Ok(())
}
/// Validates both operands of a binary operation.
///
/// Each operand must have one of `accepted_types`. The two operands must
/// either share the same type or form a pair listed in `compatible_types`,
/// where the pairs are treated as unordered.
///
/// # Panics
///
/// Panics if either operand has type `Type::Unknown`.
///
/// # Errors
///
/// Returns `WrongType` for an operand with an unaccepted type, or
/// `MismatchingTypes` when the operand types are incompatible.
fn check_operands(
    ctx: &CompileContext,
    lhs: ExprId,
    rhs: ExprId,
    lhs_span: Span,
    rhs_span: Span,
    accepted_types: &[Type],
    compatible_types: &[(Type, Type)],
) -> Result<(), CompileError> {
    let lhs_ty = ctx.ir.get(lhs).ty();
    let rhs_ty = ctx.ir.get(rhs).ty();

    assert!(!matches!(lhs_ty, Type::Unknown));
    assert!(!matches!(rhs_ty, Type::Unknown));

    check_type(ctx, lhs, lhs_span.clone(), accepted_types)?;
    check_type(ctx, rhs, rhs_span.clone(), accepted_types)?;

    let is_listed = |a: Type, b: Type| compatible_types.contains(&(a, b));

    if lhs_ty == rhs_ty
        || is_listed(lhs_ty, rhs_ty)
        || is_listed(rhs_ty, lhs_ty)
    {
        return Ok(());
    }

    Err(MismatchingTypes::build(
        ctx.report_builder,
        lhs_ty.to_string(),
        rhs_ty.to_string(),
        ctx.report_builder.span_to_code_loc(lhs_span),
        ctx.report_builder.span_to_code_loc(rhs_span),
    ))
}
/// Translates a regexp parser error into a `CompileError`.
///
/// Spans reported by the regexp parser are offsets within the regexp. They
/// are mapped onto the source by taking the matching subspan of `regexp`'s
/// span and shifting it by one. That shift presumably skips the opening
/// `/` (TODO: confirm).
fn re_error_to_compile_error(
    report_builder: &ReportBuilder,
    regexp: &ast::Regexp,
    err: re::parser::Error,
) -> CompileError {
    // Maps a (start, end) offset pair inside the regexp to a code location.
    let code_loc = |start, end| {
        report_builder
            .span_to_code_loc(regexp.span().subspan(start, end).offset(1))
    };
    let greediness = |is_greedy: bool| {
        String::from(if is_greedy { "greedy" } else { "non-greedy" })
    };

    match err {
        re::parser::Error::WrongSyntax { msg, span, note } => {
            InvalidRegexp::build(
                report_builder,
                msg,
                code_loc(span.start.offset, span.end.offset),
                note,
            )
        }
        re::parser::Error::ArbitraryPrefix { span } => {
            ArbitraryRegexpPrefix::build(
                report_builder,
                code_loc(span.start.offset, span.end.offset),
            )
        }
        re::parser::Error::UnsupportedInUnicode { span } => {
            InvalidRegexp::build(
                report_builder,
                err.to_string(),
                code_loc(span.start.offset, span.end.offset),
                None,
            )
        }
        re::parser::Error::MixedGreediness {
            is_greedy_1,
            is_greedy_2,
            span_1,
            span_2,
        } => MixedGreediness::build(
            report_builder,
            greediness(is_greedy_1),
            greediness(is_greedy_2),
            code_loc(span_1.start.offset, span_1.end.offset),
            code_loc(span_2.start.offset, span_2.end.offset),
        ),
    }
}
/// Emits a `NonBooleanAsBoolean` warning when a value of type `ty` is used
/// where a boolean is expected.
///
/// For integers, floats and strings, the warning carries a note explaining
/// how the value is interpreted as a boolean. Nothing happens when `ty` is
/// `Type::Bool`.
pub(in crate::compiler) fn warn_if_not_bool(
    ctx: &mut CompileContext,
    ty: Type,
    span: Span,
) {
    if matches!(ty, Type::Bool) {
        return;
    }

    ctx.warnings.add(|| {
        let note = match ty {
            Type::Integer => Some(
                "non-zero integers are considered `true`, while zero is `false`",
            ),
            Type::Float => Some(
                "non-zero floats are considered `true`, while zero is `false`",
            ),
            Type::String => Some(
                r#"non-empty strings are considered `true`, while the empty string ("") is `false`"#,
            ),
            _ => None,
        }
        .map(str::to_string);

        warnings::NonBooleanAsBoolean::build(
            ctx.report_builder,
            ty.to_string(),
            ctx.report_builder.span_to_code_loc(span),
            note,
        )
    });
}
/// Generates a function `$name` that converts a unary AST expression into IR.
///
/// The generated function works in four steps:
/// 1. It converts the operand.
/// 2. It checks that the operand's type is one of `$accepted_types`
///    (`|`-separated).
/// 3. It runs the optional `$check_fn` on the operand.
/// 4. It builds the IR node with `ctx.ir.$variant(operand)`.
macro_rules! gen_unary_op {
    ($name:ident, $variant:ident, $( $accepted_types:path )|+, $check_fn:expr) => {
        fn $name(
            ctx: &mut CompileContext,
            expr: &ast::UnaryExpr,
        ) -> Result<ExprId, CompileError> {
            let operand = expr_from_ast(ctx, &expr.operand)?;
            check_type(
                ctx,
                operand,
                expr.operand.span(),
                &[$( $accepted_types ),+],
            )?;
            // `$check_fn` is `None` or `Some(..)`. Binding it to a local with
            // an explicit `fn` pointer type gives the closure/fn item passed
            // by the invocation a concrete type.
            let check_fn:
                Option<fn(&mut CompileContext, ExprId, Span) -> Result<(), CompileError>>
                = $check_fn;
            if let Some(check_fn) = check_fn {
                check_fn(ctx, operand, expr.operand.span())?;
            }
            Ok(ctx.ir.$variant(operand))
        }
    };
}
/// Generates a function `$name` that converts a binary AST expression into
/// IR.
///
/// The generated function works in four steps:
/// 1. It converts both operands, left first.
/// 2. It validates them with `check_operands` against `$accepted_types` and
///    the `$compatible_types` pairs.
/// 3. It runs the optional `$check_fn` on the pair.
/// 4. It builds the IR node with `ctx.ir.$variant(lhs, rhs)`.
macro_rules! gen_binary_op {
    ($name:ident, $variant:ident, $( $accepted_types:path )|+, $compatible_types:expr, $check_fn:expr) => {
        fn $name(
            ctx: &mut CompileContext,
            expr: &ast::BinaryExpr,
        ) -> Result<ExprId, CompileError> {
            let lhs_span = expr.lhs.span();
            let rhs_span = expr.rhs.span();
            let lhs = expr_from_ast(ctx, &expr.lhs)?;
            let rhs = expr_from_ast(ctx, &expr.rhs)?;
            check_operands(
                ctx,
                lhs,
                rhs,
                lhs_span.clone(),
                rhs_span.clone(),
                &[$( $accepted_types ),+],
                $compatible_types,
            )?;
            // Typed local so that `Some(closure)` / `Some(fn_item)` from the
            // invocation gets a concrete `fn` pointer type.
            let check_fn:
                Option<fn(&mut CompileContext, ExprId, ExprId, Span, Span) -> Result<(), CompileError>>
                = $check_fn;
            if let Some(check_fn) = check_fn {
                check_fn(ctx, lhs, rhs, lhs_span, rhs_span)?;
            }
            Ok(ctx.ir.$variant(lhs, rhs))
        }
    };
}
/// Generates a function `$name` that converts a binary string operation
/// (e.g. `contains`, `startswith`) into IR.
///
/// Both operands are converted, left first, and must be strings. No mixed
/// type pairs are allowed. The IR node is built with
/// `ctx.ir.$variant(lhs, rhs)`.
macro_rules! gen_string_op {
    ($name:ident, $variant:ident) => {
        fn $name(
            ctx: &mut CompileContext,
            expr: &ast::BinaryExpr,
        ) -> Result<ExprId, CompileError> {
            let lhs_span = expr.lhs.span();
            let rhs_span = expr.rhs.span();
            let lhs = expr_from_ast(ctx, &expr.lhs)?;
            let rhs = expr_from_ast(ctx, &expr.rhs)?;
            check_operands(
                ctx,
                lhs,
                rhs,
                lhs_span.clone(),
                rhs_span.clone(),
                &[Type::String],
                &[],
            )?;
            Ok(ctx.ir.$variant(lhs, rhs))
        }
    };
}
/// Generates a function `$name` that converts an n-ary AST expression
/// (e.g. `a + b + c` or `a and b and c`) into IR.
///
/// Parameters:
/// * `$variant`: the IR builder method that receives the list of operands.
/// * `$accepted_types`: `|`-separated types that every operand must have.
/// * `$compatible_types`: `&[(Type, Type)]` pairs of distinct types that may
///   appear next to each other. The pairs are checked in both orders.
/// * `$check_fn`: an optional extra check run on every operand.
///
/// Adjacent operands must be type-compatible. On a mismatch, the error
/// labels two spans:
/// * the span from the first operand through the left operand of the
///   offending pair, labelled with the left type;
/// * the right operand, labelled with the right type.
///
/// An `Error::NumberOutOfRange` from the IR builder becomes a
/// `NumberOutOfRange` compile error covering the whole expression.
macro_rules! gen_n_ary_operation {
    ($name:ident, $variant:ident, $( $accepted_types:path )|+, $compatible_types:expr, $check_fn:expr) => {
        fn $name(
            ctx: &mut CompileContext,
            expr: &ast::NAryExpr,
        ) -> Result<ExprId, CompileError> {
            let span = expr.span();
            let accepted_types = &[$( $accepted_types ),+];
            let compatible_types = $compatible_types;

            let operands_hir: Vec<ExprId> = expr
                .operands()
                .map(|expr| expr_from_ast(ctx, expr))
                .collect::<Result<Vec<ExprId>, CompileError>>()?;

            // Typed local so that `Some(closure)` from the invocation gets a
            // concrete `fn` pointer type.
            let check_fn:
                Option<fn(&mut CompileContext, ExprId, Span) -> Result<(), CompileError>>
                = $check_fn;

            for (hir, ast) in iter::zip(operands_hir.iter(), expr.operands()) {
                check_type(ctx, *hir, ast.span(), accepted_types)?;
                if let Some(check_fn) = check_fn {
                    check_fn(ctx, *hir, ast.span())?;
                }
            }

            // Each window yields the (IR, AST) pair of the left operand
            // followed by the (IR, AST) pair of the right operand. The AST
            // nodes must be named after the side they belong to, otherwise the
            // error labels below point at the wrong operands.
            for ((lhs_hir, lhs_ast), (rhs_hir, rhs_ast)) in
                iter::zip(operands_hir.iter(), expr.operands()).tuple_windows()
            {
                let lhs_ty = ctx.ir.get(*lhs_hir).ty();
                let rhs_ty = ctx.ir.get(*rhs_hir).ty();

                let types_are_compatible = {
                    (lhs_ty == rhs_ty) ||
                    compatible_types.contains(&(lhs_ty, rhs_ty))
                    || compatible_types.contains(&(rhs_ty, lhs_ty))
                };

                if !types_are_compatible {
                    return Err(MismatchingTypes::build(
                        ctx.report_builder,
                        lhs_ty.to_string(),
                        rhs_ty.to_string(),
                        // Everything up to and including the left operand.
                        ctx.report_builder.span_to_code_loc(expr.first().span().combine(&lhs_ast.span())),
                        // The right operand alone.
                        ctx.report_builder.span_to_code_loc(rhs_ast.span()),
                    ));
                }
            }

            ctx.ir.$variant(operands_hir).map_err(|err| {
                match err {
                    Error::NumberOutOfRange => {
                        NumberOutOfRange::build(
                            ctx.report_builder,
                            i64::MIN,
                            i64::MAX,
                            ctx.report_builder.span_to_code_loc(span),
                        )
                    }
                }
            })
        }
    };
}
// `defined <expr>`: accepts booleans, integers, floats and strings, with no
// extra check.
gen_unary_op!(
    defined_expr_from_ast,
    defined,
    Type::Bool | Type::Integer | Type::Float | Type::String,
    None
);
// `not <expr>`: same accepted types as `defined`, but warns when the operand
// isn't a boolean.
gen_unary_op!(
    not_expr_from_ast,
    not,
    Type::Bool | Type::Integer | Type::Float | Type::String,
    Some(|ctx, operand, span| {
        let ty = ctx.ir.get(operand).ty();
        warn_if_not_bool(ctx, ty, span);
        Ok(())
    })
);
// `a and b and ...`: any mix of booleans, integers, floats and strings is
// allowed. Each non-boolean operand triggers a warning.
gen_n_ary_operation!(
    and_expr_from_ast,
    and,
    Type::Bool | Type::Integer | Type::Float | Type::String,
    &[
        (Type::Integer, Type::Bool),
        (Type::Integer, Type::Float),
        (Type::Integer, Type::String),
        (Type::String, Type::Bool),
        (Type::String, Type::Float),
        (Type::Float, Type::Bool)
    ],
    Some(|ctx, operand, span| {
        let ty = ctx.ir.get(operand).ty();
        warn_if_not_bool(ctx, ty, span);
        Ok(())
    })
);
// `a or b or ...`: same typing rules and warnings as `and`.
gen_n_ary_operation!(
    or_expr_from_ast,
    or,
    Type::Bool | Type::Integer | Type::Float | Type::String,
    &[
        (Type::Integer, Type::Bool),
        (Type::Integer, Type::Float),
        (Type::Integer, Type::String),
        (Type::String, Type::Bool),
        (Type::String, Type::Float),
        (Type::Float, Type::Bool)
    ],
    Some(|ctx, operand, span| {
        let ty = ctx.ir.get(operand).ty();
        warn_if_not_bool(ctx, ty, span);
        Ok(())
    })
);
// Arithmetic: integers and floats may be mixed, except in `%` (modulus),
// which only accepts integers.
gen_unary_op!(minus_expr_from_ast, minus, Type::Integer | Type::Float, None);
gen_n_ary_operation!(
    add_expr_from_ast,
    add,
    Type::Integer | Type::Float,
    &[(Type::Integer, Type::Float)],
    None
);
gen_n_ary_operation!(
    sub_expr_from_ast,
    sub,
    Type::Integer | Type::Float,
    &[(Type::Integer, Type::Float)],
    None
);
gen_n_ary_operation!(
    mul_expr_from_ast,
    mul,
    Type::Integer | Type::Float,
    &[(Type::Integer, Type::Float)],
    None
);
gen_n_ary_operation!(
    div_expr_from_ast,
    div,
    Type::Integer | Type::Float,
    &[(Type::Integer, Type::Float)],
    None
);
gen_n_ary_operation!(mod_expr_from_ast, modulus, Type::Integer, &[], None);
// Bitwise operations only accept integers. Shifts also reject a constant
// negative shift amount (see `shx_check`).
gen_unary_op!(bitwise_not_expr_from_ast, bitwise_not, Type::Integer, None);
gen_binary_op!(shl_expr_from_ast, shl, Type::Integer, &[], Some(shx_check));
gen_binary_op!(shr_expr_from_ast, shr, Type::Integer, &[], Some(shx_check));
gen_binary_op!(bitwise_or_expr_from_ast, bitwise_or, Type::Integer, &[], None);
gen_binary_op!(
    bitwise_and_expr_from_ast,
    bitwise_and,
    Type::Integer,
    &[],
    None
);
gen_binary_op!(
    bitwise_xor_expr_from_ast,
    bitwise_xor,
    Type::Integer,
    &[],
    None
);
// `==`: also accepts booleans, allows integer/float and integer/bool pairs,
// and warns about comparisons that can never be true (see `eq_check`).
gen_binary_op!(
    eq_expr_from_ast,
    eq,
    Type::Bool | Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float), (Type::Integer, Type::Bool)],
    Some(eq_check)
);
// `!=`: integers, floats and strings. NOTE(review): unlike `==`, this doesn't
// accept `Type::Bool` operands or the (Integer, Bool) pairing, and it runs no
// constraint check. Confirm whether the asymmetry is intentional.
gen_binary_op!(
    ne_expr_from_ast,
    ne,
    Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float)],
    None
);
// Ordering comparisons: integers, floats (mixable with integers) and strings.
gen_binary_op!(
    gt_expr_from_ast,
    gt,
    Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float)],
    None
);
gen_binary_op!(
    ge_expr_from_ast,
    ge,
    Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float)],
    None
);
gen_binary_op!(
    lt_expr_from_ast,
    lt,
    Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float)],
    None
);
gen_binary_op!(
    le_expr_from_ast,
    le,
    Type::Integer | Type::Float | Type::String,
    &[(Type::Integer, Type::Float)],
    None
);
// String operators: both operands must be strings. The `i`-prefixed
// variants map to the case-insensitive IR builders.
gen_string_op!(contains_expr_from_ast, contains);
gen_string_op!(icontains_expr_from_ast, icontains);
gen_string_op!(startswith_expr_from_ast, starts_with);
gen_string_op!(istartswith_expr_from_ast, istarts_with);
gen_string_op!(endswith_expr_from_ast, ends_with);
gen_string_op!(iendswith_expr_from_ast, iends_with);
gen_string_op!(iequals_expr_from_ast, iequals);
/// Extra check for `==` expressions, run after both operands have been
/// type-checked.
///
/// When one operand is a constant (string or integer) and the other carries
/// constraints on its possible values, an `UnsatisfiableExpression` warning
/// is emitted if the constant can't satisfy a constraint. This function only
/// produces warnings and always returns `Ok(())`.
fn eq_check(
    ctx: &mut CompileContext,
    lhs: ExprId,
    rhs: ExprId,
    lhs_span: Span,
    rhs_span: Span,
) -> Result<(), CompileError> {
    let lhs = ctx.ir.get(lhs).type_value();
    let rhs = ctx.ir.get(rhs).type_value();
    // Checks a constant string against the constraints of the other operand.
    // Each arm returns right after emitting its warning, so at most one
    // warning is produced per comparison.
    let check_string_constraints =
        |ctx: &mut CompileContext,
         const_string: Rc<BString>,
         const_string_span: Span,
         constraints: Vec<StringConstraint>,
         constrained_string_span: Span| {
            for constraint in constraints {
                match constraint {
                    // An uppercase string can't equal a constant containing
                    // lowercase characters. The warning suggests the
                    // uppercased constant as a patch.
                    // NOTE(review): the patch wraps the value in `"` without
                    // escaping embedded quotes or backslashes. Confirm this
                    // is acceptable.
                    StringConstraint::Uppercase
                        if const_string.chars().any(|c| c.is_lowercase()) =>
                    {
                        let mut warning = UnsatisfiableExpression::build(
                            ctx.report_builder,
                            "this is an uppercase string".to_string(),
                            "this contains lowercase characters".to_string(),
                            ctx.report_builder.span_to_code_loc(
                                constrained_string_span.clone()
                            ),
                            ctx.report_builder.span_to_code_loc(
                                const_string_span.clone()
                            ),
                            Some(
                                "an uppercase string can't be equal to a string containing lowercase characters"
                                    .to_string()),
                        );
                        warning.report_mut().patch(
                            ctx.report_builder
                                .span_to_code_loc(const_string_span.clone()),
                            format!(
                                "\"{}\"",
                                const_string.to_string().to_uppercase()
                            ),
                        );
                        ctx.warnings.add(|| warning);
                        return;
                    }
                    // Mirror of the case above: a lowercase string vs. a
                    // constant containing uppercase characters.
                    StringConstraint::Lowercase
                        if const_string.chars().any(|c| c.is_uppercase()) =>
                    {
                        let mut warning = UnsatisfiableExpression::build(
                            ctx.report_builder,
                            "this is a lowercase string".to_string(),
                            "this contains uppercase characters".to_string(),
                            ctx.report_builder.span_to_code_loc(
                                constrained_string_span.clone()
                            ),
                            ctx.report_builder.span_to_code_loc(
                                const_string_span.clone()
                            ),
                            Some(
                                "a lowercase string can't be equal to a string containing uppercase characters"
                                    .to_string()),
                        );
                        warning.report_mut().patch(
                            ctx.report_builder
                                .span_to_code_loc(const_string_span.clone()),
                            format!(
                                "\"{}\"",
                                const_string.to_string().to_lowercase()
                            ),
                        );
                        ctx.warnings.add(|| warning);
                        return;
                    }
                    // The constant's length (`len()`, a byte count for
                    // `BString`) differs from the required exact length.
                    StringConstraint::ExactLength(n)
                        if const_string.len() != n =>
                    {
                        ctx.warnings.add(|| {
                            UnsatisfiableExpression::build(
                                ctx.report_builder,
                                format!("the length of this string is {n}"),
                                format!(
                                    "the length of this string is {}",
                                    const_string.len()
                                ),
                                ctx.report_builder.span_to_code_loc(
                                    constrained_string_span.clone(),
                                ),
                                ctx.report_builder.span_to_code_loc(
                                    const_string_span.clone(),
                                ),
                                None,
                            )
                        });
                        return;
                    }
                    _ => {}
                }
            }
        };
    // Checks a constant integer against the range constraints of the other
    // operand. Unlike the string check, there is no early return, so every
    // violated range constraint produces its own warning.
    let check_integer_constraints =
        |ctx: &mut CompileContext,
         const_integer: i64,
         const_integer_span: Span,
         constraints: Vec<IntegerConstraint>,
         constrained_integer_span: Span| {
            for constraint in constraints {
                match constraint {
                    IntegerConstraint::Range(min, max)
                        if !(min..=max).contains(&const_integer) =>
                    {
                        ctx.warnings.add(|| {
                            UnsatisfiableExpression::build(
                                ctx.report_builder,
                                format!(
                                    "this expression is an integer in the range [{min},{max}]",
                                ),
                                format!(
                                    "this integer is outside the range [{min},{max}]",
                                ),
                                ctx.report_builder.span_to_code_loc(
                                    constrained_integer_span.clone(),
                                ),
                                ctx.report_builder.span_to_code_loc(
                                    const_integer_span.clone(),
                                ),
                                None,
                            )
                        });
                    }
                    _ => {}
                }
            }
        };
    // The constant may be on either side of `==`. Each or-pattern pair binds
    // the spans so that the `const_*` span always refers to the constant
    // operand and the `constrained_*` span to the constrained one.
    match (lhs, rhs, lhs_span, rhs_span) {
        (
            TypeValue::String { value: Const(const_string), .. },
            TypeValue::String { constraints: Some(constraints), .. },
            const_string_span,
            constrained_string_span,
        )
        | (
            TypeValue::String { constraints: Some(constraints), .. },
            TypeValue::String { value: Const(const_string), .. },
            constrained_string_span,
            const_string_span,
        ) => check_string_constraints(
            ctx,
            const_string,
            const_string_span,
            constraints,
            constrained_string_span,
        ),
        (
            TypeValue::Integer { value: Const(const_integer), .. },
            TypeValue::Integer { constraints: Some(constraints), .. },
            const_integer_span,
            constrained_integer_span,
        )
        | (
            TypeValue::Integer { constraints: Some(constraints), .. },
            TypeValue::Integer { value: Const(const_integer), .. },
            constrained_integer_span,
            const_integer_span,
        ) => check_integer_constraints(
            ctx,
            const_integer,
            const_integer_span,
            constraints,
            constrained_integer_span,
        ),
        _ => {}
    };
    Ok(())
}
/// Extra check for the shift operators (`<<`, `>>`).
///
/// Rejects a right operand (the shift amount) that is a negative constant.
/// The left operand and its span are unused.
///
/// # Errors
///
/// Returns `UnexpectedNegativeNumber`, pointing at `rhs_span`.
fn shx_check(
    ctx: &mut CompileContext,
    _lhs: ExprId,
    rhs: ExprId,
    _lhs_span: Span,
    rhs_span: Span,
) -> Result<(), CompileError> {
    match ctx.ir.get(rhs).type_value() {
        TypeValue::Integer { value: Const(amount), .. } if amount < 0 => {
            Err(UnexpectedNegativeNumber::build(
                ctx.report_builder,
                ctx.report_builder.span_to_code_loc(rhs_span),
            ))
        }
        _ => Ok(()),
    }
}