use crate::ir::{BinOp, Expr, Statement};
pub fn reconstruct_struct_mutation(stmts: Vec<Statement>) -> Vec<Statement> {
let mut result = Vec::with_capacity(stmts.len());
let mut i = 0;
while i < stmts.len() {
if let Some((src_var, struct_idx, struct_var, fields)) =
find_struct_mutation_pair(&stmts, i)
{
let has_mutation = fields.iter().any(|(_, kind)| !matches!(kind, FieldKind::Identity));
if !has_mutation {
result.push(stmts[i].clone());
i += 1;
continue;
}
if let Statement::Let { value, .. } = &stmts[i] {
result.push(Statement::Let {
name: "state".into(),
mutable: true,
value: value.clone(),
});
}
for j in (i + 1)..struct_idx {
result.push(rename_in_stmt(stmts[j].clone(), &src_var, "state"));
}
for (field_name, kind) in &fields {
match kind {
FieldKind::Identity => {
}
FieldKind::AddAssign(addend) => {
result.push(Statement::Expr(Expr::BinOp {
op: BinOp::AddAssign,
left: Box::new(Expr::Var(format!("state.{}", field_name))),
right: Box::new(addend.clone()),
}));
}
FieldKind::Assign(value) => {
result.push(Statement::Assign {
target: Expr::Var(format!("state.{}", field_name)),
value: value.clone(),
});
}
}
}
for j in (struct_idx + 1)..stmts.len() {
let mut s = rename_in_stmt(stmts[j].clone(), &struct_var, "state");
s = rewrite_field_exprs(s, &src_var, &fields);
result.push(s);
}
return result;
}
result.push(stmts[i].clone());
i += 1;
}
result
}
/// How one field of the rebuilt struct literal relates to the source binding
/// (`src_var`), as decided by `classify_field`.
#[derive(Debug)]
enum FieldKind {
    /// `field: src.field` — copied unchanged; no mutation statement is emitted.
    Identity,
    /// `field: src.field + e` or `field: e + src.field`; holds `e`, which is
    /// emitted as `state.field += e`.
    AddAssign(Expr),
    /// Any other initialiser; holds the whole expression, which is emitted as
    /// `state.field = expr`.
    Assign(Expr),
}
/// Checks whether `stmts[get_idx]` starts a "look up, rebuild struct" pattern.
///
/// The anchor must be `let <src> = <method chain>` whose chain calls both
/// `get` and `unwrap_or`. The following three statements (fewer near the end
/// of `stmts`) are searched for `let <tmp> = Struct { .. }`. A struct literal
/// qualifies when:
/// * at least one of its fields copies or adds to the matching `src` field,
///   per `classify_field`, and
/// * `tmp` is referenced by some later statement.
///
/// On success returns `(src, index of the struct literal, tmp, per-field
/// classification in literal order)`; otherwise `None`.
fn find_struct_mutation_pair(
    stmts: &[Statement],
    get_idx: usize,
) -> Option<(String, usize, String, Vec<(String, FieldKind)>)> {
    let src_var = match &stmts[get_idx] {
        Statement::Let { name, value: Expr::MethodChain { calls, .. }, .. }
            if calls.iter().any(|c| c.name == "get")
                && calls.iter().any(|c| c.name == "unwrap_or") =>
        {
            name.clone()
        }
        _ => return None,
    };

    let window_end = stmts.len().min(get_idx + 4);
    for j in (get_idx + 1)..window_end {
        let (struct_var, literal_fields) = match &stmts[j] {
            Statement::Let { name, value: Expr::StructLiteral { fields, .. }, .. } => {
                (name, fields)
            }
            _ => continue,
        };

        let analysis: Vec<(String, FieldKind)> = literal_fields
            .iter()
            .map(|(field_name, field_value)| {
                let kind = classify_field(field_value, &src_var, field_name);
                (field_name.clone(), kind)
            })
            .collect();

        // An empty literal has no fields and therefore never reads `src`.
        let reads_src = analysis
            .iter()
            .any(|(_, kind)| matches!(kind, FieldKind::Identity | FieldKind::AddAssign(_)));
        if !reads_src {
            continue;
        }

        let used_later = stmts[j + 1..].iter().any(|later| refs_var(later, struct_var));
        if used_later {
            return Some((src_var, j, struct_var.clone(), analysis));
        }
    }
    None
}
/// Classifies how the struct-literal initialiser `value` of `field_name`
/// relates to the same field on `src_var`:
/// * exactly `src.field` → [`FieldKind::Identity`];
/// * `src.field + e` or `e + src.field` → [`FieldKind::AddAssign`] holding `e`;
/// * anything else → [`FieldKind::Assign`] holding a copy of `value`.
fn classify_field(value: &Expr, src_var: &str, field_name: &str) -> FieldKind {
    let expected = format!("{}.{}", src_var, field_name);
    let is_src_field = |candidate: &Expr| matches!(candidate, Expr::Var(n) if *n == expected);

    if is_src_field(value) {
        return FieldKind::Identity;
    }
    if let Expr::BinOp { op: BinOp::Add, left, right } = value {
        if is_src_field(left.as_ref()) {
            return FieldKind::AddAssign((**right).clone());
        }
        if is_src_field(right.as_ref()) {
            return FieldKind::AddAssign((**left).clone());
        }
    }
    FieldKind::Assign(value.clone())
}
fn refs_var(stmt: &Statement, var_name: &str) -> bool {
match stmt {
Statement::Let { value, .. } => expr_refs_var(value, var_name),
Statement::Expr(e) => expr_refs_var(e, var_name),
Statement::Return(Some(e)) => expr_refs_var(e, var_name),
Statement::Assign { target, value } => expr_refs_var(target, var_name) || expr_refs_var(value, var_name),
Statement::If { condition, then_body, else_body } => {
expr_refs_var(condition, var_name)
|| then_body.iter().any(|s| refs_var(s, var_name))
|| else_body.iter().any(|s| refs_var(s, var_name))
}
_ => false,
}
}
/// Returns `true` if `expr` mentions `var_name`, either as the bare variable
/// or as the root of a dotted field path (`var_name.field...`).
///
/// Names that merely share a textual prefix (e.g. `var_name2`) do not count.
/// Only the expression kinds matched below are traversed; any other kind is
/// treated as not referencing the variable.
fn expr_refs_var(expr: &Expr, var_name: &str) -> bool {
    match expr {
        // Check the prefix in place instead of allocating "<var_name>." with
        // `format!` for every visited variable.
        Expr::Var(n) => matches!(
            n.strip_prefix(var_name),
            Some(rest) if rest.is_empty() || rest.starts_with('.')
        ),
        Expr::BinOp { left, right, .. } => {
            expr_refs_var(left, var_name) || expr_refs_var(right, var_name)
        }
        Expr::UnOp { operand, .. } => expr_refs_var(operand, var_name),
        Expr::Ref(inner) => expr_refs_var(inner, var_name),
        Expr::MethodChain { receiver, calls } => {
            expr_refs_var(receiver, var_name)
                || calls
                    .iter()
                    .any(|c| c.args.iter().any(|a| expr_refs_var(a, var_name)))
        }
        Expr::HostCall { args, .. } | Expr::MacroCall { args, .. } => {
            args.iter().any(|a| expr_refs_var(a, var_name))
        }
        Expr::StructLiteral { fields, .. } => {
            fields.iter().any(|(_, v)| expr_refs_var(v, var_name))
        }
        _ => false,
    }
}
fn rename_in_stmt(stmt: Statement, old: &str, new: &str) -> Statement {
match stmt {
Statement::Let { name, mutable, value } => Statement::Let {
name: if name == old { new.into() } else { name },
mutable,
value: rename_in_expr(value, old, new),
},
Statement::Expr(e) => Statement::Expr(rename_in_expr(e, old, new)),
Statement::Return(Some(e)) => Statement::Return(Some(rename_in_expr(e, old, new))),
Statement::Assign { target, value } => Statement::Assign {
target: rename_in_expr(target, old, new),
value: rename_in_expr(value, old, new),
},
Statement::If { condition, then_body, else_body } => Statement::If {
condition: rename_in_expr(condition, old, new),
then_body: then_body.into_iter().map(|s| rename_in_stmt(s, old, new)).collect(),
else_body: else_body.into_iter().map(|s| rename_in_stmt(s, old, new)).collect(),
},
other => other,
}
}
/// Renames variable `old` to `new` throughout `expr`.
///
/// Variable names may be dotted paths, so both the bare name `old` and any
/// `old.<path>` are rewritten (`old.a.b` becomes `new.a.b`). Names that merely
/// share a textual prefix, such as `older`, are left alone.
///
/// Macro-call arguments are traversed as well. This matches `expr_refs_var`,
/// which counts them as uses; otherwise a renamed variable could survive
/// inside a macro call. Expression kinds not listed below are returned
/// unchanged.
fn rename_in_expr(expr: Expr, old: &str, new: &str) -> Expr {
    match expr {
        Expr::Var(name) => {
            let renamed = if name == old {
                Some(new.to_string())
            } else {
                name.strip_prefix(old)
                    .filter(|rest| rest.starts_with('.'))
                    .map(|rest| format!("{}{}", new, rest))
            };
            Expr::Var(renamed.unwrap_or(name))
        }
        Expr::BinOp { op, left, right } => Expr::BinOp {
            op,
            left: Box::new(rename_in_expr(*left, old, new)),
            right: Box::new(rename_in_expr(*right, old, new)),
        },
        Expr::UnOp { op, operand } => Expr::UnOp {
            op,
            operand: Box::new(rename_in_expr(*operand, old, new)),
        },
        Expr::Ref(inner) => Expr::Ref(Box::new(rename_in_expr(*inner, old, new))),
        Expr::MethodChain { receiver, calls } => Expr::MethodChain {
            receiver: Box::new(rename_in_expr(*receiver, old, new)),
            calls: calls
                .into_iter()
                .map(|c| crate::ir::MethodCall {
                    name: c.name,
                    args: c.args.into_iter().map(|a| rename_in_expr(a, old, new)).collect(),
                })
                .collect(),
        },
        Expr::HostCall { module, name, args } => Expr::HostCall {
            module,
            name,
            args: args.into_iter().map(|a| rename_in_expr(a, old, new)).collect(),
        },
        Expr::StructLiteral { name, fields } => Expr::StructLiteral {
            name,
            fields: fields
                .into_iter()
                .map(|(f, v)| (f, rename_in_expr(v, old, new)))
                .collect(),
        },
        // Only `args` of a macro call is known here, so rewrite it in place
        // and keep every other part of the node as-is. The temporary
        // `Var("")` only fills the slot while its argument is being renamed.
        mut call @ Expr::MacroCall { .. } => {
            if let Expr::MacroCall { args, .. } = &mut call {
                for arg in args.iter_mut() {
                    let owned = std::mem::replace(arg, Expr::Var(String::new()));
                    *arg = rename_in_expr(owned, old, new);
                }
            }
            call
        }
        other => other,
    }
}
fn rewrite_field_exprs(
stmt: Statement,
src_var: &str,
fields: &[(String, FieldKind)],
) -> Statement {
match stmt {
Statement::Return(Some(e)) => Statement::Return(Some(rewrite_field_expr(e, src_var, fields))),
Statement::Expr(e) => Statement::Expr(rewrite_field_expr(e, src_var, fields)),
Statement::Let { name, mutable, value } => Statement::Let {
name, mutable,
value: rewrite_field_expr(value, src_var, fields),
},
other => other,
}
}
/// Rewrites references to the removed source binding `src_var` inside `expr`.
///
/// * `src.f + <anything>`, where `f` was classified as `AddAssign`, collapses
///   to `state.f`. The assumption is that it recomputes the value already
///   stored by the emitted `state.f += ..`. NOTE(review): the right operand is
///   not compared with the recorded addend, and other reads of a mutated
///   `src.f` observe the new value. Confirm this heuristic is intended.
/// * Any other `src.<path>` becomes `state.<path>`.
/// * All other traversed nodes are rebuilt with their children rewritten,
///   including unary ops, host calls, struct literals and macro-call
///   arguments, so no `src.<path>` is left nested inside them.
fn rewrite_field_expr(expr: Expr, src_var: &str, fields: &[(String, FieldKind)]) -> Expr {
    if let Expr::BinOp { op: BinOp::Add, left, .. } = &expr {
        if let Expr::Var(name) = left.as_ref() {
            if let Some(field_name) = field_path(name, src_var) {
                let is_accumulated = fields.iter().any(|(fname, kind)| {
                    fname == field_name && matches!(kind, FieldKind::AddAssign(_))
                });
                if is_accumulated {
                    return Expr::Var(format!("state.{}", field_name));
                }
            }
        }
    }
    match expr {
        Expr::Var(name) => {
            let rewritten = field_path(&name, src_var).map(|path| format!("state.{}", path));
            Expr::Var(rewritten.unwrap_or(name))
        }
        Expr::BinOp { op, left, right } => Expr::BinOp {
            op,
            left: Box::new(rewrite_field_expr(*left, src_var, fields)),
            right: Box::new(rewrite_field_expr(*right, src_var, fields)),
        },
        Expr::UnOp { op, operand } => Expr::UnOp {
            op,
            operand: Box::new(rewrite_field_expr(*operand, src_var, fields)),
        },
        Expr::Ref(inner) => Expr::Ref(Box::new(rewrite_field_expr(*inner, src_var, fields))),
        Expr::MethodChain { receiver, calls } => Expr::MethodChain {
            receiver: Box::new(rewrite_field_expr(*receiver, src_var, fields)),
            calls: calls
                .into_iter()
                .map(|c| crate::ir::MethodCall {
                    name: c.name,
                    args: c
                        .args
                        .into_iter()
                        .map(|a| rewrite_field_expr(a, src_var, fields))
                        .collect(),
                })
                .collect(),
        },
        Expr::HostCall { module, name, args } => Expr::HostCall {
            module,
            name,
            args: args
                .into_iter()
                .map(|a| rewrite_field_expr(a, src_var, fields))
                .collect(),
        },
        Expr::StructLiteral { name, fields: members } => Expr::StructLiteral {
            name,
            fields: members
                .into_iter()
                .map(|(f, v)| (f, rewrite_field_expr(v, src_var, fields)))
                .collect(),
        },
        // Only `args` of a macro call is known here, so rewrite it in place
        // and keep every other part of the node as-is. The temporary
        // `Var("")` only fills the slot while its argument is being rewritten.
        mut call @ Expr::MacroCall { .. } => {
            if let Expr::MacroCall { args, .. } = &mut call {
                for arg in args.iter_mut() {
                    let owned = std::mem::replace(arg, Expr::Var(String::new()));
                    *arg = rewrite_field_expr(owned, src_var, fields);
                }
            }
            call
        }
        other => other,
    }
}

/// Returns `path` when `name` is exactly `<src_var>.<path>`, otherwise `None`.
fn field_path<'a>(name: &'a str, src_var: &str) -> Option<&'a str> {
    name.strip_prefix(src_var)?.strip_prefix('.')
}