use ahash::AHashMap;
use alef_core::ir::{DefaultValue, FieldDef};
use syn;
/// Fills in `typed_default` for every entry of `fields` from the `default`
/// function of `item`.
///
/// The first function item named `default` in the impl block is parsed with
/// `parse_default_body`. Each field whose name appears in the parsed map
/// receives a clone of the mapped value; every other field — and every field
/// when the impl has no `default` function at all — receives
/// `DefaultValue::Empty`.
pub(crate) fn extract_default_values(item: &syn::ItemImpl, fields: &mut [FieldDef]) {
    // Locate the `default` function, if the impl block declares one.
    let default_method = item.items.iter().find_map(|entry| match entry {
        syn::ImplItem::Fn(method) if method.sig.ident == "default" => Some(method),
        _ => None,
    });

    // `None` here means "no `default` function": every lookup below then
    // falls back to `DefaultValue::Empty`.
    let parsed = default_method.map(|method| parse_default_body(&method.block));

    for field in fields.iter_mut() {
        let value = parsed
            .as_ref()
            .and_then(|by_name| by_name.get(&field.name))
            .cloned()
            .unwrap_or(DefaultValue::Empty);
        field.typed_default = Some(value);
    }
}
/// Builds a map from field name to its default value, read from the struct
/// literal found in `block`.
///
/// Returns an empty map when `find_struct_expr` locates no struct literal.
/// Fields initialised through an unnamed (positional) member are skipped.
fn parse_default_body(block: &syn::Block) -> AHashMap<String, DefaultValue> {
    let mut by_name = AHashMap::new();

    if let Some(literal) = find_struct_expr(block) {
        for initializer in &literal.fields {
            if let Some(name) = initializer.member_named() {
                by_name.insert(name.to_string(), expr_to_default_value(&initializer.expr));
            }
        }
    }

    by_name
}
/// Searches the statements of `block`, last to first, for a struct literal.
///
/// Expression statements and the initialisers of `let` bindings are both
/// inspected through `unwrap_to_struct_expr`; any other statement kind is
/// ignored. The first hit (closest to the end of the block) is returned.
fn find_struct_expr(block: &syn::Block) -> Option<&syn::ExprStruct> {
    block.stmts.iter().rev().find_map(|stmt| match stmt {
        syn::Stmt::Expr(expr, _) => unwrap_to_struct_expr(expr),
        syn::Stmt::Local(binding) => binding
            .init
            .as_ref()
            .and_then(|init| unwrap_to_struct_expr(&init.expr)),
        _ => None,
    })
}
/// Returns the struct literal that `expr` is, or that a block expression
/// contains.
///
/// A struct literal is returned directly; a block expression is searched
/// recursively with `find_struct_expr`; anything else yields `None`.
fn unwrap_to_struct_expr(expr: &syn::Expr) -> Option<&syn::ExprStruct> {
    if let syn::Expr::Struct(literal) = expr {
        return Some(literal);
    }
    if let syn::Expr::Block(nested) = expr {
        return find_struct_expr(&nested.block);
    }
    None
}
/// Extension trait giving access to the identifier of a named field member.
trait FieldMemberExt {
    /// Returns the member's identifier when it is named, `None` otherwise.
    fn member_named(&self) -> Option<&syn::Ident>;
}

impl FieldMemberExt for syn::FieldValue {
    fn member_named(&self) -> Option<&syn::Ident> {
        // Unnamed (index-based) members carry no identifier.
        if let syn::Member::Named(name) = &self.member {
            Some(name)
        } else {
            None
        }
    }
}
fn expr_to_default_value(expr: &syn::Expr) -> DefaultValue {
match expr {
syn::Expr::Lit(lit) => match &lit.lit {
syn::Lit::Bool(b) => DefaultValue::BoolLiteral(b.value),
syn::Lit::Int(i) => {
if let Ok(val) = i.base10_parse::<i64>() {
DefaultValue::IntLiteral(val)
} else {
DefaultValue::Empty
}
}
syn::Lit::Float(f) => {
if let Ok(val) = f.base10_parse::<f64>() {
DefaultValue::FloatLiteral(val)
} else {
DefaultValue::Empty
}
}
syn::Lit::Char(c) => DefaultValue::StringLiteral(c.value().to_string()),
syn::Lit::Str(s) => DefaultValue::StringLiteral(s.value()),
_ => DefaultValue::Empty,
},
syn::Expr::Unary(unary) if matches!(unary.op, syn::UnOp::Neg(_)) => match expr_to_default_value(&unary.expr) {
DefaultValue::IntLiteral(v) => DefaultValue::IntLiteral(-v),
DefaultValue::FloatLiteral(v) => DefaultValue::FloatLiteral(-v),
_ => DefaultValue::Empty,
},
syn::Expr::MethodCall(mc) => {
let method_name = mc.method.to_string();
match method_name.as_str() {
"to_string" | "to_owned" | "into" => {
if let syn::Expr::Lit(lit) = &*mc.receiver {
if let syn::Lit::Str(s) = &lit.lit {
return DefaultValue::StringLiteral(s.value());
}
}
DefaultValue::Empty
}
_ => DefaultValue::Empty,
}
}
syn::Expr::Call(call) => {
if let syn::Expr::Path(path) = &*call.func {
let segments: Vec<String> = path.path.segments.iter().map(|s| s.ident.to_string()).collect();
if segments == ["String", "from"] && call.args.len() == 1 {
if let Some(syn::Expr::Lit(lit)) = call.args.first() {
if let syn::Lit::Str(s) = &lit.lit {
return DefaultValue::StringLiteral(s.value());
}
}
return DefaultValue::Empty;
}
if segments == ["String", "new"] && call.args.is_empty() {
return DefaultValue::StringLiteral(String::new());
}
if segments.len() == 2 && segments[1] == "new" && call.args.is_empty() {
let type_name = &segments[0];
if matches!(
type_name.as_str(),
"Vec" | "HashMap" | "HashSet" | "BTreeMap" | "BTreeSet" | "AHashMap" | "AHashSet"
) {
return DefaultValue::Empty;
}
}
if segments == ["Duration", "from_secs"] && call.args.len() == 1 {
if let Some(syn::Expr::Lit(lit)) = call.args.first() {
if let syn::Lit::Int(i) = &lit.lit {
if let Ok(val) = i.base10_parse::<i64>() {
return DefaultValue::IntLiteral(val * 1000);
}
}
}
return DefaultValue::Empty;
}
if segments == ["Duration", "from_millis"] && call.args.len() == 1 {
if let Some(syn::Expr::Lit(lit)) = call.args.first() {
if let syn::Lit::Int(i) = &lit.lit {
if let Ok(val) = i.base10_parse::<i64>() {
return DefaultValue::IntLiteral(val);
}
}
}
return DefaultValue::Empty;
}
if segments.last().is_some_and(|s| s == "default") {
return DefaultValue::Empty;
}
}
DefaultValue::Empty
}
syn::Expr::Path(path) => {
let segments: Vec<String> = path.path.segments.iter().map(|s| s.ident.to_string()).collect();
if segments.len() == 2 {
return DefaultValue::EnumVariant(segments[1].clone());
}
if segments.len() == 1 && segments[0] == "None" {
return DefaultValue::None;
}
DefaultValue::Empty
}
syn::Expr::Macro(mac) => {
let macro_name = mac
.mac
.path
.segments
.last()
.map(|s| s.ident.to_string())
.unwrap_or_default();
if matches!(macro_name.as_str(), "vec" | "hashmap" | "hashset") && mac.mac.tokens.is_empty() {
return DefaultValue::Empty;
}
DefaultValue::Empty
}
_ => DefaultValue::Empty,
}
}