use std::collections::HashMap;
use stellar_xdr::curr::{ScSpecEntry, ScSpecUdtUnionCaseV0};
use walrus::ir::Value;
use crate::ir::{Expr, MethodCall, Statement};
use crate::wasm_analysis::{AnalyzedModule, StackValue, TrackedHostCall};
use super::super::val_decoding::{
strip_val_boilerplate, extract_u32_val, try_decode_symbol_small,
decode_keys_from_linear_memory, resolve_arg,
};
use super::types::{find_struct_by_fields, to_snake_case};
/// Lifts a `vec_new_from_linear_memory` host call into a `let` statement.
///
/// Resolution order:
/// 1. With tracked element values for this call site (`vec_contents`): if the first
///    element, after `strip_val_boilerplate`, decodes to a name (a small-symbol constant,
///    or a string recorded in `memory_strings` for the producing call) that
///    `try_match_enum_variant` recognises, that enum binding is returned.
/// 2. Otherwise, still with tracked elements, emits `let args = vec![&env, ...]`. Adjacent
///    element pairs that `try_merge_i128_pair` accepts are rendered as one argument.
/// 3. Without tracked elements: if the second call argument reads as the length `1`, the
///    strings recorded for the five preceding call-site ids (nearest first) are tried as
///    variant names. Failing that, a `/* vec from memory */` placeholder is emitted.
///
/// Returns `None` only on the untracked path when `call.args` has no second argument.
pub(super) fn recognize_vec_new_from_linear_memory(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
    vec_contents: &HashMap<usize, Vec<StackValue>>,
    all_entries: &[ScSpecEntry],
    memory_strings: &HashMap<usize, String>,
) -> Option<Statement> {
    let elements = match vec_contents.get(&call.call_site_id) {
        Some(tracked) => tracked,
        None => {
            // Element values are unknown. A one-element vec may still be a variant whose
            // name string was recorded for one of the calls just before this one.
            if extract_u32_val(call.args.get(1)?) == Some(1) {
                let site = call.call_site_id;
                let guessed = (site.saturating_sub(5)..site)
                    .rev()
                    .filter_map(|prior| memory_strings.get(&prior))
                    .find_map(|sym_name| {
                        try_match_enum_variant(
                            sym_name,
                            &[StackValue::Unknown],
                            param_names,
                            crn,
                            all_entries,
                        )
                    });
                if guessed.is_some() {
                    return guessed;
                }
            }
            return Some(Statement::Let {
                name: "args".into(),
                mutable: false,
                value: Expr::Raw("/* vec from memory */".into()),
            });
        }
    };

    // A leading name that matches a union case means the vec is really an enum value.
    if let Some(head) = elements.first() {
        let variant_name = match &strip_val_boilerplate(head) {
            StackValue::Const(Value::I64(bits)) => try_decode_symbol_small(*bits),
            StackValue::CallResult(producer) => memory_strings.get(producer).cloned(),
            _ => None,
        };
        if let Some(stmt) = variant_name.and_then(|vname| {
            try_match_enum_variant(&vname, elements, param_names, crn, all_entries)
        }) {
            return Some(stmt);
        }
    }

    // Plain vec: `&env` first, then every element (or merged 128-bit pair).
    let mut args = Vec::with_capacity(elements.len() + 1);
    args.push(Expr::Var("&env".into()));
    let mut idx = 0;
    while let Some(current) = elements.get(idx) {
        let merged = elements
            .get(idx + 1)
            .and_then(|next| try_merge_i128_pair(current, next));
        let consumed = match merged {
            Some(wide) => {
                args.push(resolve_arg(&strip_val_boilerplate(&wide), param_names, crn));
                2
            }
            None => {
                args.push(resolve_arg(&strip_val_boilerplate(current), param_names, crn));
                1
            }
        };
        idx += consumed;
    }
    Some(Statement::Let {
        name: "args".into(),
        mutable: false,
        value: Expr::MacroCall {
            name: "vec".into(),
            args,
        },
    })
}
pub(super) fn try_match_enum_variant(
variant_name: &str,
elements: &[StackValue],
param_names: &[String],
crn: &HashMap<usize, String>,
all_entries: &[ScSpecEntry],
) -> Option<Statement> {
for entry in all_entries {
if let ScSpecEntry::UdtUnionV0(union_spec) = entry {
for case in union_spec.cases.iter() {
let case_name = match case {
ScSpecUdtUnionCaseV0::VoidV0(v) => v.name.to_utf8_string_lossy(),
ScSpecUdtUnionCaseV0::TupleV0(t) => t.name.to_utf8_string_lossy(),
};
if case_name == variant_name {
let enum_name = union_spec.name.to_utf8_string_lossy();
let is_void = matches!(case, ScSpecUdtUnionCaseV0::VoidV0(_));
let field_exprs: Vec<Expr> = if is_void {
vec![]
} else {
elements.iter().skip(1)
.map(|el| {
let stripped = strip_val_boilerplate(el);
resolve_arg(&stripped, param_names, crn)
})
.collect()
};
return Some(Statement::Let {
name: to_snake_case(variant_name),
mutable: false,
value: Expr::EnumVariant {
enum_name,
variant_name: variant_name.to_string(),
fields: field_exprs,
},
});
}
}
}
}
None
}
/// Lifts a `map_new_from_linear_memory` host call into a `let` statement.
///
/// Tracked call site (`map_contents` has `(keys, values)`): every value is passed
/// through `strip_val_boilerplate` + `resolve_arg` and paired with its key. The resulting
/// literal is named after the contract struct whose fields match the keys
/// (`find_struct_by_fields`) and bound to its snake-cased name. With no matching struct,
/// the literal is named `map` and bound as `map_val`.
///
/// Untracked call site: `args[0]` is read as the keys pointer and `args[2]` as the length.
/// The key names are decoded from linear memory. If they match a struct, a literal with
/// `/* value */` placeholders is emitted; otherwise a `/* map from memory */` placeholder.
///
/// Returns `None` when the untracked path cannot read the pointer or length argument.
pub(super) fn recognize_map_new_from_linear_memory(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
    map_contents: &HashMap<usize, (Vec<String>, Vec<StackValue>)>,
    all_entries: &[ScSpecEntry],
    analyzed: &AnalyzedModule,
) -> Option<Statement> {
    if let Some((keys, values)) = map_contents.get(&call.call_site_id) {
        let struct_name = find_struct_by_fields(keys, all_entries)
            .map(|spec| spec.name.to_utf8_string_lossy());

        // Field list is identical for the struct and the generic-map rendering, so build
        // it once. NOTE: `zip` stops at the shorter side; surplus keys or values are dropped.
        let fields: Vec<(String, Expr)> = keys
            .iter()
            .zip(values.iter())
            .map(|(key, value)| {
                let stripped = strip_val_boilerplate(value);
                (key.clone(), resolve_arg(&stripped, param_names, crn))
            })
            .collect();

        let (binding, literal_name) = match struct_name {
            Some(name) => (to_snake_case(&name), name),
            None => ("map_val".into(), "map".into()),
        };
        return Some(Statement::Let {
            name: binding,
            mutable: false,
            value: Expr::StructLiteral {
                name: literal_name,
                fields,
            },
        });
    }

    // Untracked call site: only the key names can be recovered, from linear memory.
    let keys_ptr = extract_u32_val(call.args.first()?)?;
    let len = extract_u32_val(call.args.get(2)?)?;
    if let Some(keys) = decode_keys_from_linear_memory(keys_ptr, len, analyzed) {
        if let Some(struct_spec) = find_struct_by_fields(&keys, all_entries) {
            let struct_name = struct_spec.name.to_utf8_string_lossy();
            let fields: Vec<(String, Expr)> = keys
                .iter()
                .map(|k| (k.clone(), Expr::Raw("/* value */".into())))
                .collect();
            return Some(Statement::Let {
                name: to_snake_case(&struct_name),
                mutable: false,
                value: Expr::StructLiteral {
                    name: struct_name,
                    fields,
                },
            });
        }
    }
    Some(Statement::Let {
        name: "map_val".into(),
        mutable: false,
        value: Expr::Raw("/* map from memory */".into()),
    })
}
/// Lifts a `map_unpack_to_linear_memory` host call.
///
/// Renders `args[0]` as the map operand, reads `args[1]` as the keys pointer and `args[3]`
/// as the length, and decodes the key names from linear memory. It then looks for a
/// contract struct whose fields match those keys.
///
/// * If `unpack_field_ids` has an entry for this call site, the map operand is bound
///   directly under the matching struct's snake-cased name. With no matching struct the
///   call is left unrecognised (`None`).
/// * Otherwise a descriptive placeholder comment is emitted. It names the struct when one
///   matches, and lists the raw keys (bound as `unpacked`) when none does.
///
/// Returns `None` whenever an argument is missing or the keys cannot be decoded.
pub(super) fn recognize_map_unpack(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
    all_entries: &[ScSpecEntry],
    analyzed: &AnalyzedModule,
    unpack_field_ids: &HashMap<usize, Vec<usize>>,
) -> Option<Statement> {
    // Argument extraction, key decoding and struct lookup are shared by both renderings.
    let map_arg = resolve_arg(call.args.first()?, param_names, crn);
    let keys_ptr = extract_u32_val(call.args.get(1)?)?;
    let len = extract_u32_val(call.args.get(3)?)?;
    let keys = decode_keys_from_linear_memory(keys_ptr, len, analyzed)?;
    let struct_name =
        find_struct_by_fields(&keys, all_entries).map(|spec| spec.name.to_utf8_string_lossy());

    if unpack_field_ids.contains_key(&call.call_site_id) {
        // Bind the map value itself; only meaningful when the keys identify a struct.
        let struct_name = struct_name?;
        return Some(Statement::Let {
            name: to_snake_case(&struct_name),
            mutable: false,
            value: map_arg,
        });
    }

    let statement = match struct_name {
        Some(struct_name) => Statement::Let {
            name: to_snake_case(&struct_name),
            mutable: false,
            value: Expr::Raw(format!("/* unpack {} from {:?} */", struct_name, map_arg)),
        },
        None => Statement::Let {
            name: "unpacked".into(),
            mutable: false,
            value: Expr::Raw(format!(
                "/* map_unpack keys=[{}] from {:?} */",
                keys.join(", "),
                map_arg,
            )),
        },
    };
    Some(statement)
}
/// Lifts a `symbol_index_in_linear_memory` host call into a descriptive placeholder.
///
/// `args[0]` is rendered as the symbol being looked up. `args[1]` and `args[2]` are read
/// as the pointer and length of the candidate strings, which are decoded from linear
/// memory. Emits `let variant_idx = /* match <sym> against [..] */`.
///
/// Returns `None` if an argument is missing or the candidates cannot be decoded.
pub(super) fn recognize_symbol_index(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
    analyzed: &AnalyzedModule,
) -> Option<Statement> {
    let needle = resolve_arg(call.args.first()?, param_names, crn);
    let table_ptr = extract_u32_val(call.args.get(1)?)?;
    let table_len = extract_u32_val(call.args.get(2)?)?;

    decode_keys_from_linear_memory(table_ptr, table_len, analyzed).map(|candidates| {
        let listing = candidates.join(", ");
        Statement::Let {
            name: "variant_idx".into(),
            mutable: false,
            value: Expr::Raw(format!("/* match {:?} against [{}] */", needle, listing)),
        }
    })
}
/// Recognises two adjacent stack values that together encode one 128-bit value and
/// returns the single value that should stand in for the pair.
///
/// Both recognised shapes require `hi_raw` to be `<x> >> 63`:
/// * `lo_raw` is `<x>` itself or `<x> >> 8` (base compared by `Debug` rendering). The
///   result is `lo_raw` with Val boilerplate stripped.
/// * `lo_raw` is the constant `2`. The result is the constant `0`.
///
/// Returns `None` for any other pair.
pub(super) fn try_merge_i128_pair(lo_raw: &StackValue, hi_raw: &StackValue) -> Option<StackValue> {
    use crate::ir::BinOp as B;

    // Common precondition: the high half must be `<base> >> 63`.
    let (hi_base, hi_shift) = match hi_raw {
        StackValue::BinOp { op: B::Shr, left, right } => (left.as_ref(), right.as_ref()),
        _ => return None,
    };
    if !matches!(hi_shift, StackValue::Const(Value::I64(63))) {
        return None;
    }

    // Peel an optional `>> 8` off the low half before comparing bases.
    let lo_base = match lo_raw {
        StackValue::BinOp { op: B::Shr, left, right }
            if matches!(right.as_ref(), StackValue::Const(Value::I64(8))) =>
        {
            left.as_ref()
        }
        other => other,
    };
    if format!("{:?}", hi_base) == format!("{:?}", lo_base) {
        return Some(strip_val_boilerplate(lo_raw));
    }

    if matches!(lo_raw, StackValue::Const(Value::I64(2))) {
        return Some(StackValue::Const(Value::I64(0)));
    }
    None
}
/// Collapses a `<base>_hi` / `<base>_lo` variable pair back into the single variable
/// `<base>`.
///
/// Returns `Some(Expr::Var(base))` only when `hi` is a variable named `<base>_hi` and `lo`
/// is a variable named `<base>_lo` with the same `<base>`; otherwise `None`.
///
/// An empty base (variables named exactly `_hi` / `_lo`) is rejected. Collapsing it would
/// produce `Expr::Var("")`, which is not a valid identifier.
pub(super) fn detect_128_roundtrip(hi: &Expr, lo: &Expr) -> Option<Expr> {
    match (hi, lo) {
        (Expr::Var(hi_name), Expr::Var(lo_name)) => {
            let base = hi_name.strip_suffix("_hi")?;
            let lo_base = lo_name.strip_suffix("_lo")?;
            (!base.is_empty() && base == lo_base).then(|| Expr::Var(base.to_string()))
        }
        _ => None,
    }
}
/// Lifts a two-operand host call into `let result = <lhs> <op> <rhs>;`. Judging by the
/// name, it is used for the u256 arithmetic family.
///
/// The operands are `call.args[0]` and `call.args[1]`, rendered via `resolve_arg`. The
/// operator is supplied by the caller.
///
/// Returns `None` if either argument is missing.
pub(super) fn recognize_u256_binop(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
    op: crate::ir::BinOp,
) -> Option<Statement> {
    let lhs_raw = call.args.first()?;
    let rhs_raw = call.args.get(1)?;
    let expr = Expr::BinOp {
        left: Box::new(resolve_arg(lhs_raw, param_names, crn)),
        op,
        right: Box::new(resolve_arg(rhs_raw, param_names, crn)),
    };
    Some(Statement::Let {
        name: "result".into(),
        mutable: false,
        value: expr,
    })
}
/// Lifts an exponentiation host call into `let result = <base>.pow(<exp>);`.
///
/// `call.args[0]` is the base and `call.args[1]` the exponent, both rendered via
/// `resolve_arg`.
///
/// Returns `None` if either argument is missing.
pub(super) fn recognize_u256_pow(
    call: &TrackedHostCall,
    param_names: &[String],
    crn: &HashMap<usize, String>,
) -> Option<Statement> {
    let base_raw = call.args.first()?;
    let exp_raw = call.args.get(1)?;
    let receiver = Box::new(resolve_arg(base_raw, param_names, crn));
    let pow_call = MethodCall {
        name: "pow".into(),
        args: vec![resolve_arg(exp_raw, param_names, crn)],
    };
    Some(Statement::Let {
        name: "result".into(),
        mutable: false,
        value: Expr::MethodChain {
            receiver,
            calls: vec![pow_call],
        },
    })
}