use std::collections::HashMap;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::ir::{Expr, FunctionIR, Literal, MethodCall, Statement};
/// Identity of one storage slot access pattern.
///
/// Two statements share a key when they go through the same storage tier
/// method and address the same field-less enum variant; the key is used as a
/// `HashMap` key for both counting and rewrite lookup.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct AccessorKey {
    /// Name of the second call in the chain (`env.storage().<tier>()`).
    tier: String,
    /// Enum that owns the storage-key variant.
    enum_name: String,
    /// Variant used as the storage key.
    variant_name: String,
}
/// A storage accessor helper produced by [`extract_helpers`].
///
/// [`gen_helper_tokens`] turns a value of this type into the tokens of a
/// free function whose body is `body_stmt`.
pub struct HelperFunction {
    /// Identifier of the generated function (e.g. `get_<variant>` / `put_<variant>`).
    pub name: String,
    /// The single statement emitted as the function body.
    pub body_stmt: Statement,
    /// `true` for a setter (`fn(e: &Env, amount: i128)`), `false` for a
    /// getter (`fn(e: &Env) -> i128`).
    pub is_writer: bool,
}
/// Finds storage accessor chains that repeat across `functions`, replaces
/// them with calls to small helper functions, and returns those helpers.
///
/// A reader pattern becomes a helper once it occurs at least twice; a writer
/// pattern once it occurs at least three times. When nothing qualifies the
/// function bodies are left untouched and an empty vector is returned.
///
/// The returned helpers are ordered deterministically (readers first, then
/// writers, each sorted by storage key) and every helper has a unique name,
/// even when two keys share a variant name across tiers or enums.
pub fn extract_helpers(functions: &mut [FunctionIR]) -> Vec<HelperFunction> {
    let mut reader_counts: HashMap<AccessorKey, usize> = HashMap::new();
    let mut writer_counts: HashMap<AccessorKey, usize> = HashMap::new();
    for func in functions.iter() {
        count_in_stmts(&func.body, &mut reader_counts, &mut writer_counts);
    }

    // One shared set so reader and writer helpers can never clash either.
    let mut taken: std::collections::HashSet<String> = std::collections::HashSet::new();
    let named_readers = assign_helper_names(reader_counts, 2, "get", &mut taken);
    let named_writers = assign_helper_names(writer_counts, 3, "put", &mut taken);
    if named_readers.is_empty() && named_writers.is_empty() {
        return Vec::new();
    }

    let helpers: Vec<HelperFunction> = named_readers
        .iter()
        .map(|(key, name)| build_reader_helper(key, name))
        .chain(
            named_writers
                .iter()
                .map(|(key, name)| build_writer_helper(key, name)),
        )
        .collect();

    let reader_helpers: HashMap<AccessorKey, String> = named_readers.into_iter().collect();
    let writer_helpers: HashMap<AccessorKey, String> = named_writers.into_iter().collect();
    for func in functions.iter_mut() {
        func.body = rewrite_stmts(
            std::mem::take(&mut func.body),
            &reader_helpers,
            &writer_helpers,
        );
    }
    helpers
}

/// Keeps the keys used at least `min_uses` times and pairs each with a unique
/// helper name of the form `<prefix>_<snake_case_variant>`.
///
/// Keys are processed in sorted order so naming does not depend on `HashMap`
/// iteration order. On a name clash the tier, then the enum name, then a
/// numeric suffix is appended until the name is unused; `taken` records every
/// name handed out.
fn assign_helper_names(
    counts: HashMap<AccessorKey, usize>,
    min_uses: usize,
    prefix: &str,
    taken: &mut std::collections::HashSet<String>,
) -> Vec<(AccessorKey, String)> {
    let mut keys: Vec<AccessorKey> = counts
        .into_iter()
        .filter(|(_, uses)| *uses >= min_uses)
        .map(|(key, _)| key)
        .collect();
    keys.sort_by(|a, b| {
        (&a.variant_name, &a.tier, &a.enum_name).cmp(&(&b.variant_name, &b.tier, &b.enum_name))
    });

    keys.into_iter()
        .map(|key| {
            let base = format!("{}_{}", prefix, to_snake(&key.variant_name));
            let mut name = base.clone();
            if taken.contains(&name) {
                name = format!("{}_{}", base, key.tier);
            }
            if taken.contains(&name) {
                name = format!("{}_{}_{}", base, key.tier, to_snake(&key.enum_name));
            }
            // Last resort: numeric suffix, guaranteed to terminate.
            let stem = name.clone();
            let mut suffix = 2usize;
            while taken.contains(&name) {
                name = format!("{}_{}", stem, suffix);
                suffix += 1;
            }
            taken.insert(name.clone());
            (key, name)
        })
        .collect()
}
/// Tallies every reader and writer accessor pattern found in `stmts`,
/// descending into the bodies of `If`, `While`, `Loop`, `ForEach` and
/// `ForRange` statements.
///
/// `readers` and `writers` are updated in place: each recognised statement
/// bumps the counter stored under its [`AccessorKey`].
fn count_in_stmts(
    stmts: &[Statement],
    readers: &mut HashMap<AccessorKey, usize>,
    writers: &mut HashMap<AccessorKey, usize>,
) {
    for stmt in stmts {
        if let Some(reader_key) = try_match_reader(stmt) {
            *readers.entry(reader_key).or_default() += 1;
        }
        if let Some(writer_key) = try_match_writer(stmt) {
            *writers.entry(writer_key).or_default() += 1;
        }

        // Recurse into whichever nested statement lists this statement owns.
        if let Statement::If { then_body, else_body, .. } = stmt {
            count_in_stmts(then_body, readers, writers);
            count_in_stmts(else_body, readers, writers);
        } else if let Statement::While { body, .. }
        | Statement::Loop { body }
        | Statement::ForEach { body, .. }
        | Statement::ForRange { body, .. } = stmt
        {
            count_in_stmts(body, readers, writers);
        }
    }
}
fn try_match_reader(stmt: &Statement) -> Option<AccessorKey> {
if let Statement::Let { value: Expr::MethodChain { receiver, calls }, .. } = stmt {
if !matches!(receiver.as_ref(), Expr::Var(n) if n == "env") {
return None;
}
if calls.len() < 4 { return None; }
if calls[0].name != "storage" { return None; }
let tier = &calls[1].name;
if calls[2].name != "get" { return None; }
if calls[3].name != "unwrap_or" { return None; }
let key_expr = calls[2].args.first()?;
if let Some((enum_name, variant_name)) = extract_enum_key(key_expr) {
return Some(AccessorKey {
tier: tier.clone(),
enum_name,
variant_name,
});
}
}
None
}
/// Recognises a storage write of the exact form
/// `env.storage().<tier>().set(&Enum::Variant, <value>);` used as an
/// expression statement, and returns the key that identifies the slot.
///
/// Only chains of exactly three calls with exactly two `set` arguments are
/// accepted. A matched statement is replaced by a helper call that forwards
/// just the value argument, so trailing calls or extra arguments would be
/// silently dropped if they were allowed through.
fn try_match_writer(stmt: &Statement) -> Option<AccessorKey> {
    let (receiver, calls) = match stmt {
        Statement::Expr(Expr::MethodChain { receiver, calls }) => (receiver, calls),
        _ => return None,
    };
    if !matches!(receiver.as_ref(), Expr::Var(n) if n == "env") {
        return None;
    }

    let (storage, tier, set) = match calls.as_slice() {
        [storage, tier, set] => (storage, tier, set),
        _ => return None,
    };
    if storage.name != "storage" || set.name != "set" {
        return None;
    }
    if !storage.args.is_empty() || !tier.args.is_empty() {
        return None;
    }
    let key_expr = match set.args.as_slice() {
        [key_expr, _value] => key_expr,
        _ => return None,
    };

    let (enum_name, variant_name) = extract_enum_key(key_expr)?;
    Some(AccessorKey {
        tier: tier.name.clone(),
        enum_name,
        variant_name,
    })
}
/// Returns `(enum_name, variant_name)` when `expr` is a field-less enum
/// variant, optionally wrapped in any number of `Expr::Ref` layers.
///
/// Variants that carry fields, and every other expression kind, yield `None`.
fn extract_enum_key(expr: &Expr) -> Option<(String, String)> {
    // Strip reference wrappers iteratively instead of recursing.
    let mut current = expr;
    while let Expr::Ref(inner) = current {
        current = &**inner;
    }

    if let Expr::EnumVariant { enum_name, variant_name, fields } = current {
        if fields.is_empty() {
            return Some((enum_name.clone(), variant_name.clone()));
        }
    }
    None
}
/// Builds the getter helper for `key`.
///
/// Its body is the single statement
/// `return e.storage().<tier>().get(&Enum::Variant).unwrap_or(0);`
/// and the resulting [`HelperFunction`] is flagged as a reader.
fn build_reader_helper(key: &AccessorKey, name: &str) -> HelperFunction {
    let storage_key = Expr::Ref(Box::new(Expr::EnumVariant {
        enum_name: key.enum_name.clone(),
        variant_name: key.variant_name.clone(),
        fields: Vec::new(),
    }));
    let default_value = Expr::Literal(Literal::I64(0));

    let calls = vec![
        MethodCall { name: "storage".into(), args: Vec::new() },
        MethodCall { name: key.tier.clone(), args: Vec::new() },
        MethodCall { name: "get".into(), args: vec![storage_key] },
        MethodCall { name: "unwrap_or".into(), args: vec![default_value] },
    ];
    let chain = Expr::MethodChain {
        receiver: Box::new(Expr::Var("e".into())),
        calls,
    };

    HelperFunction {
        name: name.into(),
        body_stmt: Statement::Return(Some(chain)),
        is_writer: false,
    }
}
/// Builds the setter helper for `key`.
///
/// Its body is the single expression statement
/// `e.storage().<tier>().set(&Enum::Variant, &amount);`
/// and the resulting [`HelperFunction`] is flagged as a writer.
fn build_writer_helper(key: &AccessorKey, name: &str) -> HelperFunction {
    let storage_key = Expr::Ref(Box::new(Expr::EnumVariant {
        enum_name: key.enum_name.clone(),
        variant_name: key.variant_name.clone(),
        fields: Vec::new(),
    }));
    // `amount` is the helper's own parameter (see `gen_helper_tokens`).
    let amount_ref = Expr::Ref(Box::new(Expr::Var("amount".into())));

    let calls = vec![
        MethodCall { name: "storage".into(), args: Vec::new() },
        MethodCall { name: key.tier.clone(), args: Vec::new() },
        MethodCall { name: "set".into(), args: vec![storage_key, amount_ref] },
    ];
    let chain = Expr::MethodChain {
        receiver: Box::new(Expr::Var("e".into())),
        calls,
    };

    HelperFunction {
        name: name.into(),
        body_stmt: Statement::Expr(chain),
        is_writer: true,
    }
}
fn rewrite_stmts(
stmts: Vec<Statement>,
readers: &HashMap<AccessorKey, String>,
writers: &HashMap<AccessorKey, String>,
) -> Vec<Statement> {
stmts.into_iter().map(|stmt| {
if let Some(key) = try_match_reader(&stmt) {
if let Some(helper_name) = readers.get(&key) {
if let Statement::Let { name, mutable, .. } = stmt {
return Statement::Let {
name,
mutable,
value: Expr::HostCall {
module: String::new(),
name: helper_name.clone(),
args: vec![Expr::Var("&env".into())],
},
};
}
}
}
if let Some(key) = try_match_writer(&stmt) {
if let Some(helper_name) = writers.get(&key) {
if let Statement::Expr(Expr::MethodChain { calls, .. }) = &stmt {
let val_arg = calls.iter()
.find(|c| c.name == "set")
.and_then(|c| c.args.get(1))
.cloned()
.unwrap_or(Expr::Literal(Literal::I64(0)));
return Statement::Expr(Expr::HostCall {
module: String::new(),
name: helper_name.clone(),
args: vec![Expr::Var("&env".into()), val_arg],
});
}
}
}
match stmt {
Statement::If { condition, then_body, else_body } => Statement::If {
condition,
then_body: rewrite_stmts(then_body, readers, writers),
else_body: rewrite_stmts(else_body, readers, writers),
},
Statement::While { condition, body } => Statement::While {
condition,
body: rewrite_stmts(body, readers, writers),
},
Statement::Loop { body } => Statement::Loop {
body: rewrite_stmts(body, readers, writers),
},
other => other,
}
}).collect()
}
/// Emits the token stream for one helper function.
///
/// Writers are rendered as `fn <name>(e: &Env, amount: i128) { <body> }`,
/// readers as `fn <name>(e: &Env) -> i128 { <body> }`, where `<body>` is the
/// helper's `body_stmt` as rendered by `super::emit::gen_statement`.
pub fn gen_helper_tokens(helper: &HelperFunction) -> TokenStream {
    let ident = format_ident!("{}", helper.name);
    let body_tokens = super::emit::gen_statement(&helper.body_stmt);

    // Only the parameter list and return type differ between the two kinds.
    let (params, return_type) = if helper.is_writer {
        (quote! { e: &Env, amount: i128 }, quote! {})
    } else {
        (quote! { e: &Env }, quote! { -> i128 })
    };

    quote! {
        fn #ident(#params) #return_type {
            #body_tokens
        }
    }
}
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase letter when it starts a new
/// word: either the previous character is not uppercase (and not already an
/// underscore), or the letter ends an acronym run and is followed by a
/// lowercase letter. Runs of capitals therefore stay together
/// (`NFTCount` -> `nft_count`) and existing underscores are never doubled
/// (`Total_Supply` -> `total_supply`). Lowercasing is Unicode-aware.
fn to_snake(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        let prev = if i > 0 { Some(chars[i - 1]) } else { None };
        let starts_word = match prev {
            // Never lead with an underscore, never double one.
            None | Some('_') => false,
            // Inside a run of capitals: break only before the last capital
            // when it begins a lowercase word (e.g. the `C` in `NFTCount`).
            Some(p) if p.is_uppercase() => {
                chars.get(i + 1).map_or(false, |next| next.is_lowercase())
            }
            Some(_) => true,
        };
        if starts_word {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }
    result
}