use erg_common::config::ErgConfig;
use erg_common::log;
use erg_common::traits::{Locational, Stream};
use erg_common::Str;
use erg_parser::token::TokenKind;
use crate::error::{EffectError, EffectErrors};
use crate::hir::{Array, Def, Dict, Expr, Params, Set, Signature, Tuple, HIR};
use crate::ty::{HasType, Visibility};
/// The kind of block the checker is currently inside.
///
/// A value is pushed onto `SideEffectChecker::block_stack` when a definition,
/// lambda, record or the module itself is entered, and popped when it is left.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum BlockKind {
    /// Body of a non-procedural, non-constant subroutine (or a non-`!` lambda).
    Func,
    /// Body of a constant, non-procedural subroutine.
    ConstFunc,
    /// Body of a constant definition that is not a subroutine.
    ConstInstant,
    /// Body of a procedural subroutine (or a procedural lambda).
    Proc,
    /// Body of a non-constant definition that is not a subroutine, or a record literal.
    Instant,
    /// The module top level.
    Module,
}
use BlockKind::*;
/// Walks an `HIR` and collects side-effect violations.
///
/// While descending into definitions, lambdas and records, `path_stack` and
/// `block_stack` are pushed and popped together; `errs` accumulates every
/// violation that was found.
#[derive(Debug)]
pub struct SideEffectChecker {
    /// Compiler configuration; its `input` is attached to every reported error.
    cfg: ErgConfig,
    /// Namespaces entered so far (outermost first), used to build error paths.
    path_stack: Vec<Visibility>,
    /// Kinds of the blocks entered so far; the innermost block is last.
    block_stack: Vec<BlockKind>,
    /// Violations collected during the walk.
    errs: EffectErrors,
}
impl SideEffectChecker {
    /// Creates a checker for `cfg` with empty stacks and no recorded errors.
    pub fn new(cfg: ErgConfig) -> Self {
        Self {
            cfg,
            path_stack: vec![],
            block_stack: vec![],
            errs: EffectErrors::empty(),
        }
    }

    /// Renders `path_stack` as the namespace path shown in error messages.
    ///
    /// Every segment is preceded by `.` when its visibility is public and by `::`
    /// otherwise, so a non-empty result always starts with a separator.
    fn full_path(&self) -> String {
        let mut path = String::new();
        for vis in self.path_stack.iter() {
            path.push_str(if vis.is_public() { "." } else { "::" });
            path.push_str(&vis.def_namespace[..]);
        }
        path
    }

    /// Returns whether side effects are permitted in the innermost block.
    ///
    /// With at most one block on the stack (the module top level) effects are
    /// always allowed. Otherwise the innermost block decides: `Func` and
    /// `ConstInstant` forbid effects, `Proc` allows them, and `Instant` allows them
    /// only when directly nested in a `Proc`, `Module` or another `Instant`.
    /// Every other combination (e.g. a `ConstFunc` body) forbids them.
    fn in_context_effects_allowed(&self) -> bool {
        // Slice patterns avoid index arithmetic; an empty stack is treated like the
        // top level instead of underflowing.
        match self.block_stack.as_slice() {
            [] | [_] => true,
            [.., _, Func | ConstInstant] => false,
            [.., _, Proc] => true,
            [.., Proc | Module | Instant, Instant] => true,
            _ => false,
        }
    }

    /// Runs the side-effect check over every top-level expression of `hir`.
    ///
    /// The module is the outermost block. At that level
    /// `in_context_effects_allowed` is `true`, so top-level expressions never
    /// report effects themselves. Checking them with `check_expr` only recurses
    /// into nested blocks, where purity violations are reported.
    ///
    /// # Errors
    /// Returns `hir` unchanged together with every collected error when at least
    /// one violation was found.
    pub fn check(mut self, hir: HIR) -> Result<HIR, (HIR, EffectErrors)> {
        self.path_stack.push(Visibility::private(hir.name.clone()));
        self.block_stack.push(Module);
        log!(info "the side-effects checking process has started.{RESET}");
        for expr in hir.module.iter() {
            self.check_expr(expr);
        }
        log!(info "the side-effects checking process has completed, found errors: {}{RESET}", self.errs.len());
        if self.errs.is_empty() {
            Ok(hir)
        } else {
            Err((hir, self.errs))
        }
    }

    /// Reports procedure-typed parameters bound to a name without the `!` suffix,
    /// and checks the default-value expression of every default parameter.
    ///
    /// Parameters whose pattern binds no name (`inspect()` returns `None`, e.g. a
    /// discarded `_`) cannot violate the naming rule and are skipped.
    fn check_params(&mut self, params: &Params) {
        for nd_param in params.non_defaults.iter() {
            if nd_param.vi.t.is_procedure()
                && matches!(nd_param.inspect(), Some(name) if !name.ends_with('!'))
            {
                self.errs.push(EffectError::proc_assign_error(
                    self.cfg.input.clone(),
                    line!() as usize,
                    nd_param.raw.pat.loc(),
                    self.full_path(),
                ));
            }
        }
        if let Some(var_arg) = params.var_params.as_deref() {
            if var_arg.vi.t.is_procedure()
                && matches!(var_arg.inspect(), Some(name) if !name.ends_with('!'))
            {
                self.errs.push(EffectError::proc_assign_error(
                    self.cfg.input.clone(),
                    line!() as usize,
                    var_arg.raw.pat.loc(),
                    self.full_path(),
                ));
            }
        }
        for d_param in params.defaults.iter() {
            if d_param.sig.vi.t.is_procedure()
                && matches!(d_param.inspect(), Some(name) if !name.ends_with('!'))
            {
                self.errs.push(EffectError::proc_assign_error(
                    self.cfg.input.clone(),
                    line!() as usize,
                    d_param.sig.raw.pat.loc(),
                    self.full_path(),
                ));
            }
            self.check_expr(&d_param.default_val);
        }
    }

    /// Checks a definition.
    ///
    /// Pushes the definition's name and block kind, checks the subroutine
    /// parameters (if any) and every body expression, then pops both stacks.
    /// It also reports a non-procedural, non-subroutine, non-constant definition
    /// (an `Instant` block) whose final body expression has a procedure type.
    ///
    /// # Panics
    /// Panics if the signature is procedural, a subroutine and constant at once.
    fn check_def(&mut self, def: &Def) {
        let name_and_vis = Visibility::new(def.sig.vis().clone(), def.sig.inspect().clone());
        self.path_stack.push(name_and_vis);
        let is_procedural = def.sig.is_procedural();
        let is_subr = def.sig.is_subr();
        let is_const = def.sig.is_const();
        let kind = match (is_procedural, is_subr, is_const) {
            (true, true, true) => {
                panic!("user-defined constant procedures are not allowed");
            }
            (true, true, false) => Proc,
            (_, false, false) => Instant,
            (false, true, true) => ConstFunc,
            (false, true, false) => Func,
            (_, false, true) => ConstInstant,
        };
        self.block_stack.push(kind);
        if let Signature::Subr(sig) = &def.sig {
            self.check_params(&sig.params);
        }
        let last_idx = def.body.block.len().saturating_sub(1);
        for (i, chunk) in def.body.block.iter().enumerate() {
            self.check_expr(chunk);
            // `check_expr` leaves the block stack balanced, so `kind` is still the
            // innermost block here. The value of the last expression is what gets
            // bound, and binding a procedure to a non-procedural name is an error.
            if i == last_idx && kind == Instant && !is_procedural && chunk.t().is_procedure() {
                self.errs.push(EffectError::proc_assign_error(
                    self.cfg.input.clone(),
                    line!() as usize,
                    def.sig.loc(),
                    self.full_path(),
                ));
            }
        }
        self.path_stack.pop();
        self.block_stack.pop();
    }

    /// Recursively checks `expr`, and everything nested in it, in the current
    /// block context, recording violations in `self.errs`.
    ///
    /// The following are reported where `in_context_effects_allowed` is `false`:
    /// - calling a procedure, or calling a method whose name is procedural;
    /// - `IsOp` / `IsNotOp` comparisons;
    /// - accessing a value whose type is mutable.
    ///
    /// # Panics
    /// Panics (`todo!`) on `Dict` variants other than `Dict::Normal`.
    fn check_expr(&mut self, expr: &Expr) {
        match expr {
            Expr::Literal(_) => {}
            Expr::Def(def) => {
                self.check_def(def);
            }
            Expr::ClassDef(class_def) => {
                if let Some(req_sup) = &class_def.require_or_sup {
                    self.check_expr(req_sup);
                }
                for def in class_def.methods.iter() {
                    self.check_expr(def);
                }
            }
            Expr::PatchDef(patch_def) => {
                self.check_expr(patch_def.base.as_ref());
                for def in patch_def.methods.iter() {
                    self.check_expr(def);
                }
            }
            Expr::Array(array) => match array {
                Array::Normal(arr) => {
                    for elem in arr.elems.pos_args.iter() {
                        self.check_expr(&elem.expr);
                    }
                }
                Array::WithLength(arr) => {
                    self.check_expr(&arr.elem);
                    self.check_expr(&arr.len);
                }
                Array::Comprehension(arr) => {
                    self.check_expr(&arr.elem);
                    self.check_expr(&arr.guard);
                }
            },
            Expr::Tuple(tuple) => match tuple {
                Tuple::Normal(tup) => {
                    for arg in tup.elems.pos_args.iter() {
                        self.check_expr(&arg.expr);
                    }
                }
            },
            Expr::Record(record) => {
                // Record attributes are checked as definitions inside an instant block.
                self.path_stack
                    .push(Visibility::private(Str::ever("<record>")));
                self.block_stack.push(Instant);
                for attr in record.attrs.iter() {
                    self.check_def(attr);
                }
                self.path_stack.pop();
                self.block_stack.pop();
            }
            Expr::Set(set) => match set {
                Set::Normal(set) => {
                    for elem in set.elems.pos_args.iter() {
                        self.check_expr(&elem.expr);
                    }
                }
                Set::WithLength(set) => {
                    self.check_expr(&set.elem);
                    self.check_expr(&set.len);
                }
            },
            Expr::Dict(dict) => match dict {
                Dict::Normal(dict) => {
                    for kv in dict.kvs.iter() {
                        self.check_expr(&kv.key);
                        self.check_expr(&kv.value);
                    }
                }
                other => todo!("{other}"),
            },
            Expr::Call(call) => {
                let calls_procedure = call.obj.t().is_procedure()
                    || matches!(&call.attr_name, Some(name) if name.is_procedural());
                if calls_procedure && !self.in_context_effects_allowed() {
                    self.errs.push(EffectError::has_effect(
                        self.cfg.input.clone(),
                        line!() as usize,
                        expr,
                        self.full_path(),
                    ));
                }
                for parg in call.args.pos_args.iter() {
                    self.check_expr(&parg.expr);
                }
                // The variadic argument can hide effects just like positional and
                // keyword arguments (`is_impure` inspects it as well).
                for varg in call.args.var_args.iter() {
                    self.check_expr(&varg.expr);
                }
                for kwarg in call.args.kw_args.iter() {
                    self.check_expr(&kwarg.expr);
                }
            }
            Expr::UnaryOp(unary) => {
                self.check_expr(&unary.expr);
            }
            Expr::BinOp(bin) => {
                self.check_expr(&bin.lhs);
                self.check_expr(&bin.rhs);
                if matches!(bin.op.kind, TokenKind::IsOp | TokenKind::IsNotOp)
                    && !self.in_context_effects_allowed()
                {
                    self.errs.push(EffectError::has_effect(
                        self.cfg.input.clone(),
                        line!() as usize,
                        expr,
                        self.full_path(),
                    ));
                }
            }
            Expr::Lambda(lambda) => {
                let (name, kind) = if lambda.is_procedural() {
                    ("<lambda!>", Proc)
                } else {
                    ("<lambda>", Func)
                };
                self.path_stack.push(Visibility::private(Str::ever(name)));
                self.block_stack.push(kind);
                self.check_params(&lambda.params);
                for chunk in lambda.body.iter() {
                    self.check_expr(chunk);
                }
                self.path_stack.pop();
                self.block_stack.pop();
            }
            Expr::TypeAsc(type_asc) => {
                self.check_expr(&type_asc.expr);
            }
            Expr::Accessor(acc) => {
                if !self.in_context_effects_allowed() && acc.ref_t().is_mut_type() {
                    self.errs.push(EffectError::touch_mut_error(
                        self.cfg.input.clone(),
                        line!() as usize,
                        expr,
                        self.full_path(),
                    ));
                }
            }
            Expr::ReDef(_)
            | Expr::Code(_)
            | Expr::Compound(_)
            | Expr::Import(_)
            | Expr::Dummy(_) => {}
        }
    }

    /// Returns `true` if `expr` is considered impure.
    ///
    /// `expr` is impure when any subexpression inspected here is one of:
    /// - a call whose type is a procedure;
    /// - a procedural lambda or definition;
    /// - something that contains one of the above.
    ///
    /// Expression kinds not listed below (accessors, literals, class definitions,
    /// ...) are treated as pure.
    ///
    /// # Panics
    /// Panics (`todo!`) on `Dict` variants other than `Dict::Normal`.
    pub(crate) fn is_impure(expr: &Expr) -> bool {
        match expr {
            Expr::Call(call) => {
                call.ref_t().is_procedure()
                    || call
                        .args
                        .pos_args
                        .iter()
                        .any(|parg| Self::is_impure(&parg.expr))
                    || call
                        .args
                        .var_args
                        .iter()
                        .any(|varg| Self::is_impure(&varg.expr))
                    || call
                        .args
                        .kw_args
                        .iter()
                        .any(|kwarg| Self::is_impure(&kwarg.expr))
            }
            Expr::BinOp(bin) => Self::is_impure(&bin.lhs) || Self::is_impure(&bin.rhs),
            Expr::UnaryOp(unary) => Self::is_impure(&unary.expr),
            Expr::Array(arr) => match arr {
                Array::Normal(arr) => arr
                    .elems
                    .pos_args
                    .iter()
                    .any(|elem| Self::is_impure(&elem.expr)),
                Array::WithLength(arr) => Self::is_impure(&arr.elem) || Self::is_impure(&arr.len),
                // Mirrors `check_expr`: a comprehension evaluates its element and guard.
                Array::Comprehension(arr) => {
                    Self::is_impure(&arr.elem) || Self::is_impure(&arr.guard)
                }
            },
            Expr::Tuple(tup) => match tup {
                Tuple::Normal(tup) => tup
                    .elems
                    .pos_args
                    .iter()
                    .any(|elem| Self::is_impure(&elem.expr)),
            },
            Expr::Set(set) => match set {
                Set::Normal(set) => set
                    .elems
                    .pos_args
                    .iter()
                    .any(|elem| Self::is_impure(&elem.expr)),
                Set::WithLength(set) => Self::is_impure(&set.elem) || Self::is_impure(&set.len),
            },
            Expr::Dict(dict) => match dict {
                Dict::Normal(dict) => dict
                    .kvs
                    .iter()
                    .any(|kv| Self::is_impure(&kv.key) || Self::is_impure(&kv.value)),
                _ => todo!(),
            },
            // Record attributes are definitions; apply the same rule as `Expr::Def`.
            Expr::Record(rec) => rec.attrs.iter().any(|attr| {
                attr.sig.is_procedural() || attr.body.block.iter().any(Self::is_impure)
            }),
            // A type ascription is transparent: only the ascribed expression is evaluated.
            Expr::TypeAsc(tasc) => Self::is_impure(&tasc.expr),
            // Classify lambdas the same way `check_expr` does.
            Expr::Lambda(lambda) => {
                lambda.is_procedural() || lambda.body.iter().any(Self::is_impure)
            }
            Expr::Def(def) => def.sig.is_procedural() || def.body.block.iter().any(Self::is_impure),
            Expr::Code(block) | Expr::Compound(block) => block.iter().any(Self::is_impure),
            _ => false,
        }
    }

    /// Negation of [`Self::is_impure`].
    pub(crate) fn is_pure(expr: &Expr) -> bool {
        !Self::is_impure(expr)
    }
}