use crate::ast::*;
use crate::types::Type;
use std::collections::{HashMap, HashSet};
/// Context consulted while emitting an expression to decide whether a
/// variable reference may be used as-is or needs a `.clone()`.
///
/// See [`EmitCtx::skip_clone`] for how the four fields combine.
#[derive(Clone)]
pub struct EmitCtx {
/// Names still referenced after the current emission point; `can_move`
/// returns `false` for any name in this set.
pub used_after: HashSet<String>,
/// Known types of local names, looked up by `is_copy`.
pub local_types: HashMap<String, Type>,
/// Names for which `skip_clone` always returns `false`.
/// NOTE(review): presumably values emitted behind an `Rc` — confirm with the emitter.
pub rc_wrapped: HashSet<String>,
/// Parameter names whose type satisfies `should_borrow_param` (filled in by
/// `for_fn`); `skip_clone` always returns `false` for them.
pub borrowed_params: HashSet<String>,
}
impl EmitCtx {
pub fn empty() -> Self {
EmitCtx {
used_after: HashSet::new(),
local_types: HashMap::new(),
rc_wrapped: HashSet::new(),
borrowed_params: HashSet::new(),
}
}
pub fn for_fn(param_types: HashMap<String, Type>) -> Self {
let borrowed_params = param_types
.iter()
.filter(|(_, ty)| should_borrow_param(ty))
.map(|(name, _)| name.clone())
.collect();
EmitCtx {
used_after: HashSet::new(),
local_types: param_types,
rc_wrapped: HashSet::new(),
borrowed_params,
}
}
pub fn for_fn_no_borrow(param_types: HashMap<String, Type>) -> Self {
EmitCtx {
used_after: HashSet::new(),
local_types: param_types,
rc_wrapped: HashSet::new(),
borrowed_params: HashSet::new(),
}
}
pub fn can_move(&self, name: &str) -> bool {
!self.used_after.contains(name)
}
pub fn is_copy(&self, name: &str) -> bool {
self.local_types.get(name).is_some_and(is_copy_type)
}
pub fn skip_clone(&self, name: &str) -> bool {
if self.rc_wrapped.contains(name) {
return false;
}
if self.borrowed_params.contains(name) {
return false;
}
self.is_copy(name) || self.can_move(name)
}
pub fn is_rc_wrapped(&self, name: &str) -> bool {
self.rc_wrapped.contains(name)
}
pub fn is_borrowed_param(&self, name: &str) -> bool {
self.borrowed_params.contains(name)
}
pub fn with_used_after(&self, extra: &HashSet<String>) -> Self {
let mut ua = self.used_after.clone();
ua.extend(extra.iter().cloned());
EmitCtx {
used_after: ua,
local_types: self.local_types.clone(),
rc_wrapped: self.rc_wrapped.clone(),
borrowed_params: self.borrowed_params.clone(),
}
}
pub fn with_rc_wrapped(&self, rc: HashSet<String>) -> Self {
EmitCtx {
used_after: self.used_after.clone(),
local_types: self.local_types.clone(),
rc_wrapped: rc,
borrowed_params: self.borrowed_params.clone(),
}
}
}
/// Returns `true` for the types this module treats as copyable:
/// `Int`, `Float`, `Bool` and `Unit`. Every other type yields `false`.
pub fn is_copy_type(ty: &Type) -> bool {
    let is_numeric = matches!(ty, Type::Int | Type::Float);
    let is_trivial = matches!(ty, Type::Bool | Type::Unit);
    is_numeric || is_trivial
}
/// Returns `true` when a parameter of type `ty` should be recorded as a
/// borrowed parameter: `Map`, `List`, `Vector`, `Result`, `Option`, `Tuple`
/// and `Named` types. Every other type yields `false`.
pub fn should_borrow_param(ty: &Type) -> bool {
    let is_collection = matches!(ty, Type::Map(_, _) | Type::List(_) | Type::Vector(_));
    let is_wrapper = matches!(ty, Type::Result(_, _) | Type::Option(_) | Type::Tuple(_));
    let is_named = matches!(ty, Type::Named(_));
    is_collection || is_wrapper || is_named
}
/// Returns the set of identifier names referenced anywhere inside `expr`.
///
/// Names introduced by a match-arm pattern are not reported for that arm's
/// body (see `collect_vars_inner`).
pub fn collect_vars(expr: &Expr) -> HashSet<String> {
    let mut found = HashSet::new();
    collect_vars_inner(expr, &mut found);
    found
}
/// Recursively walks `expr` and inserts every referenced identifier name
/// into `vars`.
///
/// For `Match` expressions, names bound by an arm's pattern are excluded
/// from the names collected in that arm's body.
fn collect_vars_inner(expr: &Expr, vars: &mut HashSet<String>) {
    match expr {
        // --- leaves -------------------------------------------------------
        Expr::Ident(name) => {
            vars.insert(name.clone());
        }
        // These forms contribute no names.
        Expr::Resolved(_) | Expr::Literal(_) | Expr::Constructor(_, None) => {}

        // --- single sub-expression ----------------------------------------
        Expr::Attr(object, _) => collect_vars_inner(&object.node, vars),
        Expr::ErrorProp(operand) => collect_vars_inner(&operand.node, vars),
        Expr::Constructor(_, Some(payload)) => collect_vars_inner(&payload.node, vars),

        // --- fixed number of sub-expressions ------------------------------
        Expr::BinOp(_, lhs, rhs) => {
            collect_vars_inner(&lhs.node, vars);
            collect_vars_inner(&rhs.node, vars);
        }

        // --- calls --------------------------------------------------------
        Expr::FnCall(callee, call_args) => {
            collect_vars_inner(&callee.node, vars);
            for arg in call_args {
                collect_vars_inner(&arg.node, vars);
            }
        }
        Expr::TailCall(boxed) => {
            // Only the arguments are walked; the first tuple element is ignored.
            let (_, call_args) = boxed.as_ref();
            for arg in call_args {
                collect_vars_inner(&arg.node, vars);
            }
        }

        // --- match --------------------------------------------------------
        Expr::Match { subject, arms, .. } => {
            collect_vars_inner(&subject.node, vars);
            for arm in arms {
                // Collect the body's names separately so that the names the
                // pattern introduces can be filtered out before merging.
                let mut body_names = HashSet::new();
                collect_vars_inner(&arm.body.node, &mut body_names);
                let bound = pattern_bindings(&arm.pattern);
                vars.extend(body_names.into_iter().filter(|name| !bound.contains(name)));
            }
        }

        // --- aggregates ---------------------------------------------------
        Expr::InterpolatedStr(parts) => {
            for part in parts {
                if let StrPart::Parsed(embedded) = part {
                    collect_vars_inner(&embedded.node, vars);
                }
            }
        }
        Expr::List(elements) => {
            for element in elements {
                collect_vars_inner(&element.node, vars);
            }
        }
        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
            for item in items {
                collect_vars_inner(&item.node, vars);
            }
        }
        Expr::MapLiteral(entries) => {
            for (key, value) in entries {
                collect_vars_inner(&key.node, vars);
                collect_vars_inner(&value.node, vars);
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, value) in fields {
                collect_vars_inner(&value.node, vars);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            collect_vars_inner(&base.node, vars);
            for (_, value) in updates {
                collect_vars_inner(&value.node, vars);
            }
        }
    }
}
/// Returns the identifier names referenced by the expression of `stmt`.
///
/// For a binding only the bound expression is inspected; the name being
/// bound is not part of the result unless the expression itself mentions it.
pub fn collect_vars_stmt(stmt: &Stmt) -> HashSet<String> {
    match stmt {
        Stmt::Binding(_, _, value) | Stmt::Expr(value) => collect_vars(&value.node),
    }
}
/// Returns the names a pattern introduces into scope.
///
/// The placeholder name `_` is never reported. Tuple patterns contribute the
/// union of their sub-patterns' bindings; wildcard, literal and empty-list
/// patterns bind nothing.
pub fn pattern_bindings(pat: &Pattern) -> HashSet<String> {
    // Records `name` unless it is the `_` placeholder.
    fn record(name: &str, out: &mut HashSet<String>) {
        if name != "_" {
            out.insert(name.to_owned());
        }
    }

    // Accumulates the bindings of `pat` (and nested tuple members) into `out`.
    fn walk(pat: &Pattern, out: &mut HashSet<String>) {
        match pat {
            Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
            Pattern::Ident(name) => record(name, out),
            Pattern::Cons(head, tail) => {
                record(head, out);
                record(tail, out);
            }
            Pattern::Constructor(_, fields) => {
                for field in fields {
                    record(field, out);
                }
            }
            Pattern::Tuple(members) => {
                for member in members {
                    walk(member, out);
                }
            }
        }
    }

    let mut names = HashSet::new();
    walk(pat, &mut names);
    names
}
/// Computes one [`EmitCtx`] per statement of a block, with no rc-wrapped
/// names and no borrowed parameters.
///
/// Convenience wrapper around [`compute_block_used_after_full`].
pub fn compute_block_used_after(
    stmts: &[Stmt],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
) -> Vec<EmitCtx> {
    let no_names: HashSet<String> = HashSet::new();
    compute_block_used_after_full(stmts, parent_used_after, local_types, &no_names, &no_names)
}
/// Computes one [`EmitCtx`] per statement of a block, carrying the given
/// `rc_wrapped` set and no borrowed parameters.
///
/// Convenience wrapper around [`compute_block_used_after_full`].
pub fn compute_block_used_after_with_rc(
    stmts: &[Stmt],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
    rc_wrapped: &HashSet<String>,
) -> Vec<EmitCtx> {
    let no_borrowed_params: HashSet<String> = HashSet::new();
    compute_block_used_after_full(
        stmts,
        parent_used_after,
        local_types,
        rc_wrapped,
        &no_borrowed_params,
    )
}
/// Computes, for every statement of a block, the [`EmitCtx`] to use while
/// emitting that statement.
///
/// The result has the same length and order as `stmts`. Entry `i` has
/// `used_after` set to the names that are referenced by statements after
/// `i` (and not rebound before that reference), plus everything in
/// `parent_used_after`. `local_types`, `rc_wrapped` and `borrowed_params`
/// are cloned unchanged into every entry.
///
/// * `stmts` - the block's statements in source order.
/// * `parent_used_after` - names still needed after the whole block.
/// * `local_types` - known types of local names.
/// * `rc_wrapped` - names to record as rc-wrapped in every context.
/// * `borrowed_params` - names to record as borrowed parameters.
pub fn compute_block_used_after_full(
    stmts: &[Stmt],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
    rc_wrapped: &HashSet<String>,
    borrowed_params: &HashSet<String>,
) -> Vec<EmitCtx> {
    // Names live *after* the statement currently being visited.
    let mut live = parent_used_after.clone();
    let mut contexts: Vec<EmitCtx> = Vec::with_capacity(stmts.len());

    // Walk backwards so `live` always describes the suffix of the block.
    for stmt in stmts.iter().rev() {
        contexts.push(EmitCtx {
            used_after: live.clone(),
            local_types: local_types.clone(),
            rc_wrapped: rc_wrapped.clone(),
            borrowed_params: borrowed_params.clone(),
        });

        // Standard backward liveness step: live_in = use ∪ (live_out − def).
        // The binding must be killed BEFORE the statement's own reads are
        // added; otherwise a rebinding such as `x = f(x)` would erase the
        // read of the previous `x`, and earlier statements would wrongly
        // treat `x` as movable.
        if let Stmt::Binding(name, _, _) = stmt {
            live.remove(name);
        }
        live.extend(collect_vars_stmt(stmt));
    }

    // Contexts were produced last-to-first; restore source order.
    contexts.reverse();
    contexts
}
/// Test-only helper: computes one [`EmitCtx`] per spanned argument, carrying
/// the given `rc_wrapped` set and no borrowed parameters.
///
/// Strips the span wrapper from each argument and delegates to
/// [`compute_args_used_after_full_refs`].
#[cfg(test)]
pub fn compute_args_used_after_with_rc(
    args: &[Spanned<Expr>],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
    rc_wrapped: &HashSet<String>,
) -> Vec<EmitCtx> {
    let no_borrowed_params: HashSet<String> = HashSet::new();

    let mut arg_refs: Vec<&Expr> = Vec::with_capacity(args.len());
    for spanned in args {
        arg_refs.push(&spanned.node);
    }

    compute_args_used_after_full_refs(
        &arg_refs,
        parent_used_after,
        local_types,
        rc_wrapped,
        &no_borrowed_params,
    )
}
/// Computes, for every argument of a call, the [`EmitCtx`] to use while
/// emitting that argument.
///
/// The result has the same length and order as `args`. Entry `i` has
/// `used_after` set to the names referenced by the arguments after `i`,
/// plus everything in `parent_used_after`. The remaining fields are cloned
/// unchanged from the corresponding parameters.
pub fn compute_args_used_after_full(
    args: &[Expr],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
    rc_wrapped: &HashSet<String>,
    borrowed_params: &HashSet<String>,
) -> Vec<EmitCtx> {
    // Names referenced by the arguments that follow the one being visited.
    let mut later_names = parent_used_after.clone();
    let mut contexts = Vec::with_capacity(args.len());

    for arg in args.iter().rev() {
        contexts.push(EmitCtx {
            used_after: later_names.clone(),
            local_types: local_types.clone(),
            rc_wrapped: rc_wrapped.clone(),
            borrowed_params: borrowed_params.clone(),
        });
        later_names.extend(collect_vars(arg));
    }

    // Built back-to-front; flip into argument order.
    contexts.reverse();
    contexts
}
/// Variant of `compute_args_used_after_full` that accepts borrowed
/// expressions instead of a slice of owned ones.
///
/// The result has the same length and order as `args`. Entry `i` has
/// `used_after` set to the names referenced by the arguments after `i`,
/// plus everything in `parent_used_after`.
// The only caller visible in this file is compiled under `cfg(test)`.
#[allow(dead_code)]
pub fn compute_args_used_after_full_refs(
    args: &[&Expr],
    parent_used_after: &HashSet<String>,
    local_types: &HashMap<String, Type>,
    rc_wrapped: &HashSet<String>,
    borrowed_params: &HashSet<String>,
) -> Vec<EmitCtx> {
    let mut seen_later = parent_used_after.clone();

    // Visit arguments right-to-left, snapshotting the names seen so far
    // before folding in the current argument's own names.
    let mut per_arg: Vec<EmitCtx> = args
        .iter()
        .rev()
        .map(|arg| {
            let snapshot = EmitCtx {
                used_after: seen_later.clone(),
                local_types: local_types.clone(),
                rc_wrapped: rc_wrapped.clone(),
                borrowed_params: borrowed_params.clone(),
            };
            seen_later.extend(collect_vars(arg));
            snapshot
        })
        .collect();

    per_arg.reverse();
    per_arg
}
// Unit tests for the helpers defined above: copy-type classification,
// variable collection, clone skipping and used-after computation.
#[cfg(test)]
mod tests {
use super::*;
// Int, Float, Bool and Unit are copy types; Str, List and Named are not.
#[test]
fn test_is_copy_type() {
assert!(is_copy_type(&Type::Int));
assert!(is_copy_type(&Type::Float));
assert!(is_copy_type(&Type::Bool));
assert!(is_copy_type(&Type::Unit));
assert!(!is_copy_type(&Type::Str));
assert!(!is_copy_type(&Type::List(Box::new(Type::Int))));
assert!(!is_copy_type(&Type::Named("Foo".to_string())));
}
// A bare identifier yields exactly its own name.
#[test]
fn test_collect_vars_simple() {
let expr = Expr::Ident("x".to_string());
let vars = collect_vars(&expr);
assert_eq!(vars, HashSet::from(["x".to_string()]));
}
// Both operands of a binary operation are collected.
#[test]
fn test_collect_vars_binop() {
let expr = Expr::BinOp(
BinOp::Add,
Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
Box::new(Spanned::bare(Expr::Ident("y".to_string()))),
);
let vars = collect_vars(&expr);
assert_eq!(vars, HashSet::from(["x".to_string(), "y".to_string()]));
}
// A call contributes the callee identifier as well as every argument.
#[test]
fn test_collect_vars_fn_call() {
let expr = Expr::FnCall(
Box::new(Spanned::bare(Expr::Ident("f".to_string()))),
vec![
Spanned::bare(Expr::Ident("a".to_string())),
Spanned::bare(Expr::Ident("b".to_string())),
],
);
let vars = collect_vars(&expr);
assert_eq!(
vars,
HashSet::from(["f".to_string(), "a".to_string(), "b".to_string()])
);
}
// Names bound by an arm's pattern (`x`) are not reported, while the
// subject (`val`) and other names used in the arm body (`y`) are.
#[test]
fn test_collect_vars_match_excludes_pattern_bindings() {
let expr = Expr::Match {
subject: Box::new(Spanned::bare(Expr::Ident("val".to_string()))),
arms: vec![MatchArm {
pattern: Pattern::Ident("x".to_string()),
body: Box::new(Spanned::bare(Expr::BinOp(
BinOp::Add,
Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
Box::new(Spanned::bare(Expr::Ident("y".to_string()))),
))),
}],
};
let vars = collect_vars(&expr);
assert!(vars.contains("val"));
assert!(vars.contains("y"));
assert!(!vars.contains("x"));
}
// With both names still used afterwards, only the copy-typed one (`n`)
// may skip the clone; the `Str` one (`s`) may not.
#[test]
fn test_skip_clone_copy_type() {
let mut lt = HashMap::new();
lt.insert("n".to_string(), Type::Int);
lt.insert("s".to_string(), Type::Str);
let ectx = EmitCtx {
used_after: HashSet::from(["n".to_string(), "s".to_string()]),
local_types: lt,
rc_wrapped: HashSet::new(),
borrowed_params: HashSet::new(),
};
assert!(ectx.skip_clone("n"));
assert!(!ectx.skip_clone("s"));
}
// A non-copy name that is not used afterwards can be moved, so the clone
// is skipped.
#[test]
fn test_skip_clone_last_use() {
let mut lt = HashMap::new();
lt.insert("s".to_string(), Type::Str);
let ectx = EmitCtx {
used_after: HashSet::new(), local_types: lt,
rc_wrapped: HashSet::new(),
borrowed_params: HashSet::new(),
};
assert!(ectx.skip_clone("s"));
}
// For the block `a = 1; a + b`: the first statement's context sees both
// `a` and `b` as used later; the last statement's context sees nothing.
#[test]
fn test_compute_block_used_after() {
let stmts = vec![
Stmt::Binding(
"a".to_string(),
None,
Spanned::bare(Expr::Literal(Literal::Int(1))),
),
Stmt::Expr(Spanned::bare(Expr::BinOp(
BinOp::Add,
Box::new(Spanned::bare(Expr::Ident("a".to_string()))),
Box::new(Spanned::bare(Expr::Ident("b".to_string()))),
))),
];
let parent = HashSet::new();
let lt = HashMap::new();
let ctxs = compute_block_used_after(&stmts, &parent, &lt);
assert!(ctxs[0].used_after.contains("a"));
assert!(ctxs[0].used_after.contains("b"));
assert!(ctxs[1].used_after.is_empty());
}
// For the arguments `(x, y, x)`: each context lists only the names that
// appear in later arguments.
#[test]
fn test_compute_args_used_after() {
let args = vec![
Spanned::bare(Expr::Ident("x".to_string())),
Spanned::bare(Expr::Ident("y".to_string())),
Spanned::bare(Expr::Ident("x".to_string())),
];
let parent = HashSet::new();
let lt = HashMap::new();
let ctxs = compute_args_used_after_with_rc(&args, &parent, &lt, &HashSet::new());
assert!(ctxs[0].used_after.contains("x"));
assert!(ctxs[0].used_after.contains("y"));
assert!(ctxs[1].used_after.contains("x"));
assert!(!ctxs[1].used_after.contains("y"));
assert!(ctxs[2].used_after.is_empty());
}
}