use convert_case::{Case, Casing};
use crate::builtin::derive_common::{
CompareFieldOptions, collection_element_type, is_primitive_type, standalone_fn_name,
type_has_derive,
};
use crate::macros::{ts_macro_derive, ts_template};
use crate::swc_ecma_ast::Expr;
use crate::ts_syn::abi::ir::type_registry::{ResolvedTypeRef, TypeRegistry};
use crate::ts_syn::{
Data, DeriveInput, MacroforgeError, TsStream, parse_ts_expr, parse_ts_macro_input, ts_ident,
};
/// A single field that takes part in a generated equality comparison.
#[derive(Debug, Clone)]
pub struct EqField {
    /// Property name as it appears on the TypeScript value (used as `x.<name>`).
    pub name: String,
    /// The field's declared TypeScript type, as source text (e.g. `"number"`,
    /// `"string[]"`, `"Map<K, V>"`); used to pick a comparison strategy.
    pub ts_type: String,
}
/// Builds a TypeScript boolean expression comparing `field` on two values.
///
/// The expression compares `{self_var}.{name}` with `{other_var}.{name}`. It is
/// meant to be joined with `&&` into the body of a generated `...Equals` function.
///
/// When both `resolved` and `registry` are provided, type information wins:
///
/// * a non-collection type that has a registry key and derives `PartialEq` is
///   compared with its standalone equality function
///   (`standalone_fn_name(base, "Equals")`);
/// * a collection whose element type has a registry key and derives `PartialEq`
///   is compared element-wise with that element's standalone function. `Map`
///   compares entries by key, `Set` compares members, and any other collection
///   is treated as an array.
///
/// Otherwise the strategy is chosen from the textual `ts_type`:
/// * `===` for primitives (as decided by `is_primitive_type`);
/// * element-wise comparison for `T[]` / `Array<T>`;
/// * `getTime()` for `Date`;
/// * entry-wise comparison for `Map<..>`;
/// * membership for `Set<..>`;
/// * a duck-typed `.equals(...)` call for everything else.
///
/// The array, map and catch-all cases call an `equals` method on the value when
/// one is present, and fall back to `===` otherwise.
pub fn generate_field_equality_for_interface(
    field: &EqField,
    self_var: &str,
    other_var: &str,
    resolved: Option<&ResolvedTypeRef>,
    registry: Option<&TypeRegistry>,
) -> String {
    let field_name = &field.name;
    let ts_type = &field.ts_type;
    if let (Some(resolved), Some(registry)) = (resolved, registry) {
        // Registered types deriving PartialEq have a generated standalone
        // equality function; delegate to it directly.
        if !resolved.is_collection
            && resolved.registry_key.is_some()
            && type_has_derive(registry, &resolved.base_type_name, "PartialEq")
        {
            let fn_name = standalone_fn_name(&resolved.base_type_name, "Equals");
            return format!("{fn_name}({self_var}.{field_name}, {other_var}.{field_name})");
        }
        // Collections of such types are compared element by element with the
        // element type's standalone equality function.
        if resolved.is_collection
            && let Some(elem) = collection_element_type(resolved)
            && elem.registry_key.is_some()
            && type_has_derive(registry, &elem.base_type_name, "PartialEq")
        {
            let elem_fn = standalone_fn_name(&elem.base_type_name, "Equals");
            return match resolved.base_type_name.as_str() {
                "Map" => format!(
                    "({self_var}.{field_name} instanceof Map && {other_var}.{field_name} instanceof Map && \
                     {self_var}.{field_name}.size === {other_var}.{field_name}.size && \
                     Array.from({self_var}.{field_name}.entries()).every(([k, v]) => \
                     {other_var}.{field_name}.has(k) && \
                     {elem_fn}(v, {other_var}.{field_name}.get(k))))"
                ),
                // `Set.prototype.has` matches by identity, so structurally equal
                // but distinct members would never be found. Use `has` as a fast
                // path, then search with the element's equality function. The
                // check runs in both directions so that several members of one
                // set cannot all match the same member of the other.
                "Set" => format!(
                    "({self_var}.{field_name} instanceof Set && {other_var}.{field_name} instanceof Set && \
                     {self_var}.{field_name}.size === {other_var}.{field_name}.size && \
                     Array.from({self_var}.{field_name}).every(v => \
                     {other_var}.{field_name}.has(v) || \
                     Array.from({other_var}.{field_name}).some(w => {elem_fn}(v, w))) && \
                     Array.from({other_var}.{field_name}).every(w => \
                     {self_var}.{field_name}.has(w) || \
                     Array.from({self_var}.{field_name}).some(v => {elem_fn}(v, w))))"
                ),
                _ => format!(
                    "(Array.isArray({self_var}.{field_name}) && Array.isArray({other_var}.{field_name}) && \
                     {self_var}.{field_name}.length === {other_var}.{field_name}.length && \
                     {self_var}.{field_name}.every((v, i) => \
                     {elem_fn}(v, {other_var}.{field_name}[i])))"
                ),
            };
        }
    }
    if is_primitive_type(ts_type) {
        format!("{self_var}.{field_name} === {other_var}.{field_name}")
    } else if ts_type.ends_with("[]") || ts_type.starts_with("Array<") {
        format!(
            "(Array.isArray({self_var}.{field_name}) && Array.isArray({other_var}.{field_name}) && \
             {self_var}.{field_name}.length === {other_var}.{field_name}.length && \
             {self_var}.{field_name}.every((v, i) => \
             typeof (v as any)?.equals === 'function' \
             ? (v as any).equals({other_var}.{field_name}[i]) \
             : v === {other_var}.{field_name}[i]))"
        )
    } else if ts_type == "Date" {
        format!(
            "({self_var}.{field_name} instanceof Date && {other_var}.{field_name} instanceof Date \
             ? {self_var}.{field_name}.getTime() === {other_var}.{field_name}.getTime() \
             : {self_var}.{field_name} === {other_var}.{field_name})"
        )
    } else if ts_type.starts_with("Map<") {
        format!(
            "({self_var}.{field_name} instanceof Map && {other_var}.{field_name} instanceof Map && \
             {self_var}.{field_name}.size === {other_var}.{field_name}.size && \
             Array.from({self_var}.{field_name}.entries()).every(([k, v]) => \
             {other_var}.{field_name}.has(k) && \
             (typeof (v as any)?.equals === 'function' \
             ? (v as any).equals({other_var}.{field_name}.get(k)) \
             : v === {other_var}.{field_name}.get(k))))"
        )
    } else if ts_type.starts_with("Set<") {
        format!(
            "({self_var}.{field_name} instanceof Set && {other_var}.{field_name} instanceof Set && \
             {self_var}.{field_name}.size === {other_var}.{field_name}.size && \
             Array.from({self_var}.{field_name}).every(v => {other_var}.{field_name}.has(v)))"
        )
    } else {
        format!(
            "(typeof ({self_var}.{field_name} as any)?.equals === 'function' \
             ? ({self_var}.{field_name} as any).equals({other_var}.{field_name}) \
             : {self_var}.{field_name} === {other_var}.{field_name})"
        )
    }
}
/// Entry point for `@derive(PartialEq)`.
///
/// Every supported declaration kind gets a standalone
/// `export function <name>Equals(a, b): boolean`, where `<name>` is the
/// declaration name converted to camelCase:
///
/// * classes, interfaces and object type aliases compare their fields with `&&`,
///   using [`generate_field_equality_for_interface`] for each field. Fields whose
///   `partialEq` decorator options set `skip` are left out, and an empty field
///   list compares as `true`;
/// * classes also get a `static equals(a, b)` member that delegates to the
///   standalone function;
/// * enums compare with `a === b`;
/// * non-object type aliases compare with `===`, and fall back to comparing
///   `JSON.stringify` output when both values are non-null objects.
///
/// The field-based and type-alias functions return `true` early when `a === b`.
///
/// # Errors
///
/// Returns a [`MacroforgeError`] at the decorator span when the generated field
/// comparison does not parse as a TypeScript expression.
#[ts_macro_derive(
    PartialEq,
    description = "Generates an equals() method for field-by-field comparison",
    attributes(partialEq)
)]
pub fn derive_partial_eq_macro(mut input: TsStream) -> Result<TsStream, MacroforgeError> {
    let input = parse_ts_macro_input!(input as DeriveInput);
    let resolved_fields = input.context.resolved_fields.as_ref();
    let type_registry = input.context.type_registry.as_ref();

    // Shared by classes, interfaces and object type aliases. It joins the
    // per-field `a`/`b` comparisons with `&&` (or uses `true` when no field
    // takes part) and parses the result so it can be interpolated into a template.
    let comparison_expr_for = |fields: &[EqField]| {
        let comparison_src = if fields.is_empty() {
            "true".to_string()
        } else {
            fields
                .iter()
                .map(|f| {
                    let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
                    generate_field_equality_for_interface(f, "a", "b", resolved, type_registry)
                })
                .collect::<Vec<_>>()
                .join(" && ")
        };
        parse_ts_expr(&comparison_src).map_err(|err| {
            MacroforgeError::new(
                input.decorator_span(),
                format!("@derive(PartialEq): invalid comparison expression: {err:?}"),
            )
        })
    };

    match &input.data {
        Data::Class(class) => {
            let class_name = input.name();
            let class_ident = ts_ident!(class_name);
            let eq_fields: Vec<EqField> = class
                .fields()
                .iter()
                .filter_map(|field| {
                    let opts = CompareFieldOptions::from_decorators(&field.decorators, "partialEq");
                    if opts.skip {
                        return None;
                    }
                    Some(EqField {
                        name: field.name.clone(),
                        ts_type: field.ts_type.clone(),
                    })
                })
                .collect();
            let fn_name_ident = ts_ident!("{}Equals", class_name.to_case(Case::Camel));
            let fn_name_expr: Expr = fn_name_ident.clone().into();
            let comparison_expr = comparison_expr_for(&eq_fields)?;
            let standalone = ts_template! {
                export function @{fn_name_ident}(a: @{class_ident}, b: @{class_ident}): boolean {
                    if (a === b) return true;
                    return @{comparison_expr};
                }
            };
            // The static member delegates so the comparison logic lives in one place.
            let class_body = ts_template!(Within {
                static equals(a: @{class_ident}, b: @{class_ident}): boolean {
                    return @{fn_name_expr}(a, b);
                }
            });
            Ok(standalone.merge(class_body))
        }
        Data::Enum(_) => {
            let enum_name = input.name();
            let fn_name_ident = ts_ident!("{}Equals", enum_name.to_case(Case::Camel));
            Ok(ts_template! {
                export function @{fn_name_ident}(a: @{ts_ident!(enum_name)}, b: @{ts_ident!(enum_name)}): boolean {
                    return a === b;
                }
            })
        }
        Data::Interface(interface) => {
            let interface_name = input.name();
            let interface_ident = ts_ident!(interface_name);
            let eq_fields: Vec<EqField> = interface
                .fields()
                .iter()
                .filter_map(|field| {
                    let opts = CompareFieldOptions::from_decorators(&field.decorators, "partialEq");
                    if opts.skip {
                        return None;
                    }
                    Some(EqField {
                        name: field.name.clone(),
                        ts_type: field.ts_type.clone(),
                    })
                })
                .collect();
            let comparison_expr = comparison_expr_for(&eq_fields)?;
            let fn_name_ident = ts_ident!("{}Equals", interface_name.to_case(Case::Camel));
            Ok(ts_template! {
                export function @{fn_name_ident}(a: @{interface_ident}, b: @{interface_ident}): boolean {
                    if (a === b) return true;
                    return @{comparison_expr};
                }
            })
        }
        Data::TypeAlias(type_alias) => {
            let type_name = input.name();
            let type_ident = ts_ident!(type_name);
            let fn_name_ident = ts_ident!("{}Equals", type_name.to_case(Case::Camel));
            if type_alias.is_object() {
                let eq_fields: Vec<EqField> = type_alias
                    .as_object()
                    .expect("as_object() must succeed once is_object() returned true")
                    .iter()
                    .filter_map(|field| {
                        let opts =
                            CompareFieldOptions::from_decorators(&field.decorators, "partialEq");
                        if opts.skip {
                            return None;
                        }
                        Some(EqField {
                            name: field.name.clone(),
                            ts_type: field.ts_type.clone(),
                        })
                    })
                    .collect();
                let comparison_expr = comparison_expr_for(&eq_fields)?;
                Ok(ts_template! {
                    export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): boolean {
                        if (a === b) return true;
                        return @{comparison_expr};
                    }
                })
            } else {
                // Non-object aliases have no field list to walk. This falls back to
                // identity, then to a JSON serialization comparison for objects.
                Ok(ts_template! {
                    export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): boolean {
                        if (a === b) return true;
                        if (typeof a === "object" && typeof b === "object" && a !== null && b !== null) {
                            return JSON.stringify(a) === JSON.stringify(b);
                        }
                        return false;
                    }
                })
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`EqField`] from borrowed name/type strings.
    fn eq_field(name: &str, ts_type: &str) -> EqField {
        EqField {
            name: name.to_string(),
            ts_type: ts_type.to_string(),
        }
    }

    /// Renders the `a`/`b` comparison for `field` with no resolved type info.
    /// Only the `ts_type`-string based strategies are exercised this way.
    fn unresolved_equality(field: &EqField) -> String {
        generate_field_equality_for_interface(field, "a", "b", None, None)
    }

    /// The joined per-field comparisons must parse as an expression, and a
    /// method returning it must parse as a class member.
    #[test]
    fn test_partial_eq_macro_output() {
        let fields = [eq_field("id", "number"), eq_field("name", "string")];
        let comparison = fields
            .iter()
            .map(unresolved_equality)
            .collect::<Vec<_>>()
            .join(" && ");
        let comparison_expr =
            parse_ts_expr(&comparison).expect("comparison expr should parse");
        let output = ts_template!(Within {
            equals(other: unknown): boolean {
                if (a === b) return true;
                return @{comparison_expr};
            }
        });
        let source = output.source();
        let body_content = source
            .strip_prefix("/* @macroforge:body */")
            .unwrap_or(source);
        let wrapped = format!("class __Temp {{ {} }}", body_content);
        let parsed = macroforge_ts_syn::parse_ts_stmt(&wrapped);
        assert!(
            parsed.is_ok(),
            "Generated PartialEq macro output should parse as class members"
        );
        assert!(source.contains("equals"), "Should contain equals method");
    }

    /// Primitive fields compare with strict equality.
    #[test]
    fn test_field_equality_primitive() {
        let rendered = unresolved_equality(&eq_field("id", "number"));
        assert!(rendered.contains("a.id === b.id"));
    }

    /// Unknown object types use a duck-typed `equals` call.
    #[test]
    fn test_field_equality_object() {
        let rendered = unresolved_equality(&eq_field("user", "User"));
        assert!(rendered.contains("equals"));
    }

    /// Array fields are compared element by element.
    #[test]
    fn test_field_equality_array() {
        let rendered = unresolved_equality(&eq_field("items", "string[]"));
        for needle in ["Array.isArray", "every"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Dates compare by timestamp.
    #[test]
    fn test_field_equality_date() {
        let rendered = unresolved_equality(&eq_field("createdAt", "Date"));
        assert!(rendered.contains("getTime"));
    }
}