use stellar_xdr::curr::{ScSpecEntry, ScSpecFunctionV0, ScSpecTypeDef};
use crate::ir::{Expr, Statement};
/// Rewrites struct literals in `stmts` so that fields which merely pass a
/// struct-typed parameter of `spec` through (directly, or via a renamed copy)
/// refer to that parameter by name.
///
/// Only inputs of `spec` whose type is a user-defined type
/// (`ScSpecTypeDef::Udt`) are candidates. Of those, only the ones whose UDT
/// name is declared as a struct (`ScSpecEntry::UdtStructV0`) in `all_entries`
/// can ever be substituted. When there is no such input, `stmts` is returned
/// unchanged without being traversed.
pub(super) fn substitute_param_pass_through(
    stmts: Vec<Statement>,
    spec: &ScSpecFunctionV0,
    all_entries: &[ScSpecEntry],
) -> Vec<Statement> {
    // (parameter name, UDT type name) for every UDT-typed input, in spec order.
    let param_types: Vec<(String, String)> = spec
        .inputs
        .iter()
        .filter_map(|input| match &input.type_ {
            ScSpecTypeDef::Udt(udt) => Some((
                input.name.to_utf8_string_lossy(),
                udt.name.to_utf8_string_lossy(),
            )),
            _ => None,
        })
        .collect();
    if param_types.is_empty() {
        return stmts;
    }

    // Names of every struct UDT declared in the contract spec.
    let struct_names: std::collections::HashSet<String> = all_entries
        .iter()
        .filter_map(|entry| match entry {
            ScSpecEntry::UdtStructV0(s) => Some(s.name.to_utf8_string_lossy()),
            _ => None,
        })
        .collect();

    // A substitution always requires the parameter's type to be a known
    // struct; if none is, the rewrite could only rebuild the tree unchanged.
    if !param_types
        .iter()
        .any(|(_, type_name)| struct_names.contains(type_name))
    {
        return stmts;
    }

    stmts
        .into_iter()
        .map(|stmt| substitute_pass_through_stmt(stmt, &param_types, &struct_names))
        .collect()
}
fn substitute_pass_through_stmt(
stmt: Statement,
param_types: &[(String, String)],
struct_names: &std::collections::HashSet<String>,
) -> Statement {
match stmt {
Statement::Let { name, mutable, value } => Statement::Let {
name,
mutable,
value: substitute_pass_through_expr(value, param_types, struct_names),
},
Statement::Expr(e) => Statement::Expr(
substitute_pass_through_expr(e, param_types, struct_names),
),
Statement::If { condition, then_body, else_body } => Statement::If {
condition,
then_body: then_body.into_iter()
.map(|s| substitute_pass_through_stmt(s, param_types, struct_names))
.collect(),
else_body: else_body.into_iter()
.map(|s| substitute_pass_through_stmt(s, param_types, struct_names))
.collect(),
},
other => other,
}
}
/// Rewrites the fields of a struct literal so that struct-typed parameters are
/// passed through by name instead of via a local copy.
///
/// For each `(field_name, field_val)` of an `Expr::StructLiteral`, the value is
/// replaced by `Expr::Var(param_name)` for the *first* entry of `param_types`
/// whose type name is in `struct_names` and for which either
/// - `field_name` equals the parameter name (whatever `field_val` is), or
/// - `field_val` is a variable that is the parameter itself or a renamed copy
///   of it (see [`is_renamed_copy_of`]).
///
/// Any other field value is processed recursively, so nested struct literals
/// are rewritten too. Expressions that are not struct literals are returned
/// unchanged.
fn substitute_pass_through_expr(
    expr: Expr,
    param_types: &[(String, String)],
    struct_names: &std::collections::HashSet<String>,
) -> Expr {
    match expr {
        Expr::StructLiteral { name, fields } => {
            let new_fields: Vec<(String, Expr)> = fields
                .into_iter()
                .map(|(field_name, field_val)| {
                    let pass_through = param_types.iter().find(|(param_name, param_type)| {
                        struct_names.contains(param_type)
                            && (field_name == *param_name
                                || matches!(
                                    &field_val,
                                    Expr::Var(var_name) if is_renamed_copy_of(var_name, param_name)
                                ))
                    });
                    let new_val = match pass_through {
                        Some((param_name, _)) => Expr::Var(param_name.clone()),
                        None => substitute_pass_through_expr(field_val, param_types, struct_names),
                    };
                    (field_name, new_val)
                })
                .collect();
            Expr::StructLiteral { name, fields: new_fields }
        }
        other => other,
    }
}

/// Returns `true` when `var_name` is `param_name` itself, or `param_name`
/// followed only by `_` and ASCII digits (e.g. `owner`, `owner_1`, `owner2`
/// for parameter `owner`; `arg1_2` for parameter `arg1`).
///
/// The suffix is matched after the full parameter name rather than trimmed
/// from the variable. Trimming would also strip digits/underscores that belong
/// to the parameter name itself, so a parameter such as `arg1` would never
/// match its copies.
fn is_renamed_copy_of(var_name: &str, param_name: &str) -> bool {
    matches!(
        var_name.strip_prefix(param_name),
        Some(suffix) if suffix.chars().all(|c| c == '_' || c.is_ascii_digit())
    )
}