use proc_macro2::TokenStream;
use quote::quote;
use crate::ast::*;
pub fn generate_hook_actions(
actions: &[HookAction],
_lookup_by: &Option<FieldPath>,
) -> TokenStream {
let action_code: Vec<TokenStream> = actions
.iter()
.map(|action| {
match action {
HookAction::RegisterPdaMapping {
pda_field,
seed_field,
lookup_name: _,
} => {
let pda_raw = pda_field.segments.last().cloned().unwrap_or_default();
let seed_raw = seed_field.segments.last().cloned().unwrap_or_default();
let pda_camel = crate::event_type_helpers::snake_to_lower_camel(&pda_raw);
let seed_camel = crate::event_type_helpers::snake_to_lower_camel(&seed_raw);
quote! {
let pda_val = ctx.account(#pda_camel).or_else(|| ctx.account(#pda_raw));
let seed_val = ctx.account(#seed_camel).or_else(|| ctx.account(#seed_raw));
if let (Some(pda), Some(seed)) = (pda_val, seed_val) {
ctx.register_pda_reverse_lookup(&pda, &seed);
}
}
}
HookAction::SetField {
target_field,
source,
condition,
} => {
let set_code = generate_set_field_code(target_field, source);
if let Some(cond) = condition {
let cond_code = generate_condition_code(cond);
quote! {
if #cond_code {
#set_code
}
}
} else {
set_code
}
}
HookAction::IncrementField {
target_field,
increment_by,
condition,
} => {
let increment_code = quote! {
ctx.increment(#target_field, #increment_by);
};
if let Some(cond) = condition {
let cond_code = generate_condition_code(cond);
quote! {
if #cond_code {
#increment_code
}
}
} else {
increment_code
}
}
}
})
.collect();
quote! {
#(#action_code)*
}
}
pub fn generate_set_field_code(target_field: &str, source: &MappingSource) -> TokenStream {
match source {
MappingSource::FromSource {
path,
default: _,
transform: _,
} => {
let field_str = path.segments.last().cloned().unwrap_or_default();
quote! {
if let Some(value) = ctx.data::<hyperstack::runtime::serde_json::Value>(#field_str) {
ctx.set(#target_field, value);
}
}
}
MappingSource::Constant(value) => {
let value_str = serde_json::to_string(value).unwrap_or_else(|_| "null".to_string());
quote! {
ctx.set(#target_field, hyperstack::runtime::serde_json::from_str::<hyperstack::runtime::serde_json::Value>(#value_str).unwrap_or(hyperstack::runtime::serde_json::Value::Null));
}
}
MappingSource::FromContext { field } => match field.as_str() {
"slot" => {
quote! { ctx.set(#target_field, hyperstack::runtime::serde_json::json!(ctx.slot().unwrap_or(0))); }
}
"signature" => {
quote! { ctx.set(#target_field, hyperstack::runtime::serde_json::json!(ctx.signature().unwrap_or_default())); }
}
"timestamp" => {
quote! { ctx.set(#target_field, hyperstack::runtime::serde_json::json!(ctx.timestamp())); }
}
_ => quote! {},
},
_ => quote! {},
}
}
pub fn generate_condition_code(condition: &ConditionExpr) -> TokenStream {
if let Some(ref parsed) = condition.parsed {
generate_parsed_condition_code(parsed)
} else {
let _expr = &condition.expression;
quote! { true }
}
}
pub fn generate_parsed_condition_code(condition: &ParsedCondition) -> TokenStream {
match condition {
ParsedCondition::Comparison { field, op, value } => {
let field_str = field.segments.last().cloned().unwrap_or_default();
let value_str = serde_json::to_string(value).unwrap_or_else(|_| "null".to_string());
match op {
ComparisonOp::Equal | ComparisonOp::NotEqual => {
let op_code = match op {
ComparisonOp::Equal => quote! { == },
ComparisonOp::NotEqual => quote! { != },
_ => unreachable!(),
};
quote! {
ctx.data::<hyperstack::runtime::serde_json::Value>(#field_str)
.map(|v| v #op_code hyperstack::runtime::serde_json::from_str::<hyperstack::runtime::serde_json::Value>(#value_str).unwrap_or(hyperstack::runtime::serde_json::Value::Null))
.unwrap_or(false)
}
}
ComparisonOp::GreaterThan
| ComparisonOp::GreaterThanOrEqual
| ComparisonOp::LessThan
| ComparisonOp::LessThanOrEqual => {
let op_code = match op {
ComparisonOp::GreaterThan => quote! { > },
ComparisonOp::GreaterThanOrEqual => quote! { >= },
ComparisonOp::LessThan => quote! { < },
ComparisonOp::LessThanOrEqual => quote! { <= },
_ => unreachable!(),
};
quote! {
{
let field_val: Option<i64> = ctx.data(#field_str);
let compare_val: Option<i64> = hyperstack::runtime::serde_json::from_str(#value_str).ok();
match (field_val, compare_val) {
(Some(f), Some(c)) => f #op_code c,
_ => false
}
}
}
}
}
}
ParsedCondition::Logical { op, conditions } => {
let condition_codes: Vec<TokenStream> = conditions
.iter()
.map(generate_parsed_condition_code)
.collect();
match op {
LogicalOp::And => quote! { #(#condition_codes)&&* },
LogicalOp::Or => quote! { #(#condition_codes)||* },
}
}
}
}
pub fn to_snake_case(s: &str) -> String {
s.split("::")
.last()
.unwrap_or(s)
.chars()
.enumerate()
.flat_map(|(i, c)| {
if c.is_uppercase() && i > 0 {
vec!['_', c.to_lowercase().next().unwrap()]
} else {
vec![c.to_lowercase().next().unwrap()]
}
})
.collect()
}