use crate::core::ir::{EnumDef, ParamDef, TypeRef};
use ahash::{AHashMap, AHashSet};
/// Returns the indices of parameters that must be JSON-encoded before being
/// passed to the NIF.
///
/// A parameter qualifies when its type is `Vec<Named>` and `Named` is not listed
/// in `opaque_types`. Callers such as `nif_arg` wrap these arguments in
/// `Jason.encode!/1`.
pub(in crate::backends::rustler::gen_bindings) fn json_encode_param_indices(
    params: &[ParamDef],
    opaque_types: &AHashSet<String>,
) -> AHashSet<usize> {
    params
        .iter()
        .enumerate()
        .filter(|(_, param)| match &param.ty {
            // Only vectors of non-opaque named types are encoded; opaque types
            // are excluded from JSON encoding.
            TypeRef::Vec(inner) => matches!(
                inner.as_ref(),
                TypeRef::Named(inner_name) if !opaque_types.contains(inner_name)
            ),
            _ => false,
        })
        .map(|(idx, _)| idx)
        .collect()
}
/// Maps parameter indices to the tagged enum each parameter carries.
///
/// A parameter is included when its type is `Named` or `Vec<Named>` and the name
/// resolves (via `enum_lookup`) to an enum that has a serde tag and is not
/// untagged. The resulting [`TaggedEnumParam`] records the enum name and whether
/// the parameter is a vector.
pub(in crate::backends::rustler::gen_bindings) fn tagged_enum_param_map(
    params: &[ParamDef],
    enum_lookup: &AHashMap<String, &EnumDef>,
) -> AHashMap<usize, TaggedEnumParam> {
    params
        .iter()
        .enumerate()
        .filter_map(|(idx, param)| {
            let (inner_name, is_vec) = match &param.ty {
                TypeRef::Vec(inner) => match inner.as_ref() {
                    TypeRef::Named(n) => (n.as_str(), true),
                    _ => return None,
                },
                TypeRef::Named(n) => (n.as_str(), false),
                _ => return None,
            };
            let enum_def = enum_lookup.get(inner_name)?;
            // Mirror the guards in `emit_tagged_enum_encoder`. An enum flagged
            // untagged gets no `encode_*` helper, so routing a parameter through
            // one would reference an undefined Elixir function.
            let has_encoder = enum_def.serde_tag.is_some() && !enum_def.serde_untagged;
            has_encoder.then(|| {
                (
                    idx,
                    TaggedEnumParam {
                        enum_name: enum_def.name.clone(),
                        is_vec,
                    },
                )
            })
        })
        .collect()
}
/// A NIF parameter whose type is a tagged enum, or a `Vec` of one.
#[derive(Clone, Debug)]
pub(in crate::backends::rustler::gen_bindings) struct TaggedEnumParam {
    /// Name of the enum; the Elixir encoder helper name is derived from it.
    pub enum_name: String,
    /// `true` when the parameter is a vector, in which case each element is
    /// encoded individually (via `Enum.map`) before JSON encoding.
    pub is_vec: bool,
}
pub(in crate::backends::rustler::gen_bindings) fn nif_arg(
index: usize,
param: &str,
json_encode_params: &AHashSet<usize>,
tagged_enum_params: &AHashMap<usize, TaggedEnumParam>,
) -> String {
if let Some(te) = tagged_enum_params.get(&index) {
let helper = encoder_fn_name(&te.enum_name);
if te.is_vec {
format!("Jason.encode!(Enum.map({param}, &{helper}/1))")
} else {
format!("Jason.encode!({helper}({param}))")
}
} else if json_encode_params.contains(&index) {
format!("Jason.encode!({param})")
} else {
param.to_string()
}
}
pub(in crate::backends::rustler::gen_bindings) fn keyword_nif_arg(
index: usize,
param: &str,
json_encode_params: &AHashSet<usize>,
tagged_enum_params: &AHashMap<usize, TaggedEnumParam>,
) -> String {
if let Some(te) = tagged_enum_params.get(&index) {
let helper = encoder_fn_name(&te.enum_name);
let mapped = if te.is_vec {
format!("Jason.encode!(Enum.map(v, &{helper}/1))")
} else {
format!("Jason.encode!({helper}(v))")
};
format!("case Keyword.get(opts, :{param}) do nil -> nil; v -> {mapped} end")
} else if json_encode_params.contains(&index) {
format!("case Keyword.get(opts, :{param}) do nil -> nil; v -> Jason.encode!(v) end")
} else {
format!("Keyword.get(opts, :{param})")
}
}
/// Name of the private Elixir encoder function generated for `enum_name`:
/// `encode_` followed by the snake_case form of the name.
pub(in crate::backends::rustler::gen_bindings) fn encoder_fn_name(enum_name: &str) -> String {
    use crate::codegen::naming::pascal_to_snake;

    let snake = pascal_to_snake(enum_name);
    format!("encode_{snake}")
}
/// Emits the Elixir `defp` clauses of the encoder helper for an internally
/// tagged enum.
///
/// The helper is named by [`encoder_fn_name`] and accepts:
/// - `:variant` or `{:variant, _}` for a variant without fields, producing a map
///   with only the tag key;
/// - `{:variant, %{} = data}` for a variant with fields, re-keying `data` (applying
///   per-field serde renames and stringifying atom keys) and adding the tag key;
/// - any map, returned unchanged;
/// - anything else, which raises `ArgumentError`.
///
/// Variants and fields marked `binding_excluded` are skipped. Returns an empty
/// string when the enum has no serde tag or is untagged.
pub(in crate::backends::rustler::gen_bindings) fn emit_tagged_enum_encoder(enum_def: &EnumDef) -> String {
    use crate::codegen::naming::{pascal_to_snake, wire_field_name, wire_variant_value};

    /// Escapes `s` for embedding inside an Elixir double-quoted string literal.
    ///
    /// Besides `\` and `"`, `#` is escaped as `\#` so that a `#{...}` sequence in
    /// a serde name is emitted literally instead of being interpolated at runtime.
    fn escape_elixir_string(s: &str) -> String {
        let mut escaped = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                '#' => escaped.push_str("\\#"),
                _ => escaped.push(c),
            }
        }
        escaped
    }

    let Some(tag) = enum_def.serde_tag.as_deref() else {
        return String::new();
    };
    if enum_def.serde_untagged {
        return String::new();
    }
    // The tag is spliced into string literals just like the variant/field names,
    // so it gets the same escaping.
    let tag = escape_elixir_string(tag);
    let fn_name = encoder_fn_name(&enum_def.name);
    let rename_all = enum_def.serde_rename_all.as_deref();
    let mut out = String::with_capacity(1024);
    let mut first_clause = true;
    for variant in enum_def.variants.iter().filter(|v| !v.binding_excluded) {
        // Blank line between consecutive clause groups.
        if !first_clause {
            out.push('\n');
        }
        first_clause = false;

        let atom = pascal_to_snake(&variant.name);
        let wire = wire_variant_value(&variant.name, variant.serde_rename.as_deref(), rename_all);
        let wire_escaped = escape_elixir_string(&wire);

        if variant.fields.is_empty() {
            out.push_str(&format!(
                " defp {fn_name}(:{atom}), do: %{{\"{tag}\" => \"{wire_escaped}\"}}\n"
            ));
            out.push('\n');
            out.push_str(&format!(
                " defp {fn_name}({{:{atom}, _}}), do: %{{\"{tag}\" => \"{wire_escaped}\"}}\n"
            ));
            continue;
        }

        out.push_str(&format!(" defp {fn_name}({{:{atom}, %{{}} = data}}) do\n"));
        out.push_str(" data\n");
        out.push_str(" |> Enum.reduce(%{}, fn {k, v}, acc ->\n");
        out.push_str(" key =\n");
        out.push_str(" case k do\n");
        for field in variant.fields.iter().filter(|f| !f.binding_excluded) {
            // Field names are not affected by the enum-level `rename_all`.
            let wire_field = wire_field_name(&field.name, field.serde_rename.as_deref(), None);
            // Only renamed fields need an explicit clause; the generic atom
            // clause below already yields the field's own name.
            if wire_field != field.name {
                let wire_field_escaped = escape_elixir_string(&wire_field);
                out.push_str(&format!(" :{} -> \"{}\"\n", field.name, wire_field_escaped));
            }
        }
        out.push_str(" k when is_atom(k) -> Atom.to_string(k)\n");
        out.push_str(" k when is_binary(k) -> k\n");
        out.push_str(" end\n\n");
        out.push_str(" Map.put(acc, key, v)\n");
        out.push_str(" end)\n");
        out.push_str(&format!(" |> Map.put(\"{tag}\", \"{wire_escaped}\")\n"));
        out.push_str(" end\n");
    }
    if !first_clause {
        out.push('\n');
    }
    // Maps are assumed to already be in wire form and are passed through as-is.
    out.push_str(&format!(" defp {fn_name}(%{{}} = m), do: m\n"));
    out.push('\n');
    out.push_str(&format!(
        " defp {fn_name}(other),\n do: raise(ArgumentError, \"expected {} (atom, {{atom, map}}, or map), got: \" <> inspect(other))\n\n",
        enum_def.name
    ));
    out
}