use convert_case::{Case, Casing};
use crate::builtin::derive_common::{
CompareFieldOptions, is_numeric_type, is_primitive_type, standalone_fn_name, type_has_derive,
};
use crate::macros::{ts_macro_derive, ts_template};
use crate::swc_ecma_ast::{Expr, Ident};
use crate::ts_syn::abi::ir::type_registry::{ResolvedTypeRef, TypeRegistry};
use crate::ts_syn::{
Data, DeriveInput, MacroforgeError, TsStream, parse_ts_expr, parse_ts_macro_input, ts_ident,
};
/// One field that takes part in the generated ordering.
///
/// Built from a class / interface / object-type-alias field whose `ord` decorator
/// options do not request `skip`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct OrdField {
    /// Property name as written on the source type; spliced verbatim into the
    /// generated `a.<name>` / `b.<name>` accessors.
    name: String,
    /// The field's TypeScript type annotation as text (e.g. `number`, `string[]`,
    /// `Date`); selects the comparison strategy.
    ts_type: String,
}
/// Builds the TypeScript source of an expression that orders one field of two values.
///
/// `self_var` / `other_var` name the two values being compared. The expression reads
/// `<var>.<field.name>` on each side and evaluates to `-1`, `0` or `1`, or to whatever a
/// delegated comparator / `compareTo` returns.
///
/// Strategy, checked in this order:
/// 1. Non-collection field whose resolved type is registered and derives `Ord`: call that
///    type's standalone comparator (`standalone_fn_name(.., "Compare")`).
/// 2. Numeric type (per `is_numeric_type`): `<` / `>` comparison.
/// 3. `string`: `localeCompare`, clamped to -1/0/1.
/// 4. `boolean`: `false` sorts before `true`.
/// 5. Any other primitive (per `is_primitive_type`): always `0`.
/// 6. `T[]` / `Array<T>`: element-wise lexicographic comparison, then by length.
///    - Elements with a `compareTo` method use it (a nullish result counts as `0`).
///    - Other elements use `<` / `>`.
///    - A nullish array acts as empty.
/// 7. `Date`: compare `getTime()`; a nullish date counts as `0`.
/// 8. Anything else: the field's own `compareTo` method if present (nullish result → `0`),
///    otherwise `0`.
///
/// Despite the name, the derive also uses this for classes and object type aliases.
fn generate_field_compare_for_interface(
    field: &OrdField,
    self_var: &str,
    other_var: &str,
    resolved: Option<&ResolvedTypeRef>,
    registry: Option<&TypeRegistry>,
) -> String {
    // Access paths for both operands, e.g. `a.id` / `b.id`.
    let lhs = format!("{self_var}.{}", field.name);
    let rhs = format!("{other_var}.{}", field.name);

    // Delegate to the field type's own generated comparator when the registry shows
    // that type derives `Ord`.
    let ord_base_type = match (resolved, registry) {
        (Some(r), Some(reg))
            if !r.is_collection
                && r.registry_key.is_some()
                && type_has_derive(reg, &r.base_type_name, "Ord") =>
        {
            Some(&r.base_type_name)
        }
        _ => None,
    };
    if let Some(base_type_name) = ord_base_type {
        let comparator = standalone_fn_name(base_type_name, "Compare");
        return format!("{comparator}({lhs}, {rhs})");
    }

    match field.ts_type.as_str() {
        _ if is_numeric_type(&field.ts_type) => {
            format!("({lhs} < {rhs} ? -1 : {lhs} > {rhs} ? 1 : 0)")
        }
        "string" => {
            format!("((cmp => cmp < 0 ? -1 : cmp > 0 ? 1 : 0)({lhs}.localeCompare({rhs})))")
        }
        "boolean" => format!("({lhs} === {rhs} ? 0 : {lhs} ? 1 : -1)"),
        _ if is_primitive_type(&field.ts_type) => String::from("0"),
        _ if field.ts_type.ends_with("[]") || field.ts_type.starts_with("Array<") => format!(
            "(() => {{ \
             const a = {lhs} ?? []; \
             const b = {rhs} ?? []; \
             const minLen = Math.min(a.length, b.length); \
             for (let i = 0; i < minLen; i++) {{ \
             const cmp = typeof (a[i] as any)?.compareTo === 'function' \
             ? (a[i] as any).compareTo(b[i]) ?? 0 \
             : (a[i] < b[i] ? -1 : a[i] > b[i] ? 1 : 0); \
             if (cmp !== 0) return cmp; \
             }} \
             return a.length < b.length ? -1 : a.length > b.length ? 1 : 0; \
             }})()"
        ),
        "Date" => format!(
            "(() => {{ \
             const ta = {lhs}?.getTime() ?? 0; \
             const tb = {rhs}?.getTime() ?? 0; \
             return ta < tb ? -1 : ta > tb ? 1 : 0; \
             }})()"
        ),
        _ => format!(
            "(typeof ({lhs} as any)?.compareTo === 'function' \
             ? ({lhs} as any).compareTo({rhs}) ?? 0 \
             : 0)"
        ),
    }
}
#[ts_macro_derive(
Ord,
description = "Generates a compareTo() method for total ordering (returns -1, 0, or 1, never null)",
attributes(ord)
)]
pub fn derive_ord_macro(mut input: TsStream) -> Result<TsStream, MacroforgeError> {
let input = parse_ts_macro_input!(input as DeriveInput);
let resolved_fields = input.context.resolved_fields.as_ref();
let type_registry = input.context.type_registry.as_ref();
match &input.data {
Data::Class(class) => {
let class_name = input.name();
let class_ident = ts_ident!(class_name);
let ord_fields: Vec<OrdField> = class
.fields()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let fn_name_ident = ts_ident!("{}Compare", class_name.to_case(Case::Camel));
let fn_name_expr: Expr = fn_name_ident.clone().into();
let standalone = if !ord_fields.is_empty() {
let compare_steps: Vec<(Ident, Expr)> = ord_fields
.iter()
.enumerate()
.map(|(i, f)| {
let cmp_ident = ts_ident!(format!("cmp{}", i));
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f,
"a",
"b",
resolved,
type_registry,
);
let expr = parse_ts_expr(&expr_src).map_err(|err| {
MacroforgeError::new(
input.decorator_span(),
format!(
"@derive(Ord): invalid comparison expression for '{}': {err:?}",
f.name
),
)
})?;
Ok((cmp_ident, *expr))
})
.collect::<Result<_, MacroforgeError>>()?;
let _ = &compare_steps;
ts_template! {
export function @{fn_name_ident}(a: @{class_ident}, b: @{class_ident}): number {
if (a === b) return 0;
{#for (cmp_ident, cmp_expr) in &compare_steps}
const @{cmp_ident.clone()} = @{cmp_expr.clone()};
if (@{cmp_ident.clone()} !== 0) return @{cmp_ident.clone()};
{/for}
return 0;
}
}
} else {
ts_template! {
export function @{fn_name_ident}(a: @{class_ident}, b: @{class_ident}): number {
if (a === b) return 0;
return 0;
}
}
};
let class_body = ts_template!(Within {
static compareTo(a: @{class_ident}, b: @{class_ident}): number {
return @{fn_name_expr}(a, b);
}
});
Ok(standalone.merge(class_body))
}
Data::Enum(_) => {
let enum_name = input.name();
let fn_name_ident = ts_ident!("{}Compare", enum_name.to_case(Case::Camel));
Ok(ts_template! {
export function @{fn_name_ident}(a: @{ts_ident!(enum_name)}, b: @{ts_ident!(enum_name)}): number {
if (typeof a === "number" && typeof b === "number") {
return a < b ? -1 : a > b ? 1 : 0;
}
if (typeof a === "string" && typeof b === "string") {
const cmp = a.localeCompare(b);
return cmp < 0 ? -1 : cmp > 0 ? 1 : 0;
}
return 0;
}
})
}
Data::Interface(interface) => {
let interface_name = input.name();
let interface_ident = ts_ident!(interface_name);
let ord_fields: Vec<OrdField> = interface
.fields()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let fn_name_ident = ts_ident!("{}Compare", interface_name.to_case(Case::Camel));
if !ord_fields.is_empty() {
let compare_steps: Vec<(Ident, Expr)> = ord_fields
.iter()
.enumerate()
.map(|(i, f)| {
let cmp_ident = ts_ident!(format!("cmp{}", i));
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f,
"a",
"b",
resolved,
type_registry,
);
let expr = parse_ts_expr(&expr_src).map_err(|err| {
MacroforgeError::new(
input.decorator_span(),
format!(
"@derive(Ord): invalid comparison expression for '{}': {err:?}",
f.name
),
)
})?;
Ok((cmp_ident, *expr))
})
.collect::<Result<_, MacroforgeError>>()?;
let _ = &compare_steps;
Ok(ts_template! {
export function @{fn_name_ident}(a: @{interface_ident}, b: @{interface_ident}): number {
if (a === b) return 0;
{#for (cmp_ident, cmp_expr) in &compare_steps}
const @{cmp_ident.clone()} = @{cmp_expr.clone()};
if (@{cmp_ident.clone()} !== 0) return @{cmp_ident.clone()};
{/for}
return 0;
}
})
} else {
Ok(ts_template! {
export function @{fn_name_ident}(a: @{interface_ident}, b: @{interface_ident}): number {
if (a === b) return 0;
return 0;
}
})
}
}
Data::TypeAlias(type_alias) => {
let type_name = input.name();
let type_ident = ts_ident!(type_name);
if type_alias.is_object() {
let ord_fields: Vec<OrdField> = type_alias
.as_object()
.unwrap()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let fn_name_ident = ts_ident!("{}Compare", type_name.to_case(Case::Camel));
if !ord_fields.is_empty() {
let compare_steps: Vec<(Ident, Expr)> = ord_fields
.iter()
.enumerate()
.map(|(i, f)| {
let cmp_ident = ts_ident!(format!("cmp{}", i));
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f, "a", "b", resolved, type_registry,
);
let expr = parse_ts_expr(&expr_src).map_err(|err| {
MacroforgeError::new(
input.decorator_span(),
format!(
"@derive(Ord): invalid comparison expression for '{}': {err:?}",
f.name
),
)
})?;
Ok((cmp_ident, *expr))
})
.collect::<Result<_, MacroforgeError>>()?;
let _ = &compare_steps;
Ok(ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): number {
if (a === b) return 0;
{#for (cmp_ident, cmp_expr) in &compare_steps}
const @{cmp_ident.clone()} = @{cmp_expr.clone()};
if (@{cmp_ident.clone()} !== 0) return @{cmp_ident.clone()};
{/for}
return 0;
}
})
} else {
Ok(ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): number {
if (a === b) return 0;
return 0;
}
})
}
} else {
let fn_name_ident = ts_ident!("{}Compare", type_name.to_case(Case::Camel));
Ok(ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): number {
if (a === b) return 0;
if (typeof a === "number" && typeof b === "number") {
return a < b ? -1 : a > b ? 1 : 0;
}
if (typeof a === "string" && typeof b === "string") {
const cmp = a.localeCompare(b);
return cmp < 0 ? -1 : cmp > 0 ? 1 : 0;
}
return 0;
}
})
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an [`OrdField`].
    fn field(name: &str, ts_type: &str) -> OrdField {
        OrdField {
            name: name.to_string(),
            ts_type: ts_type.to_string(),
        }
    }

    /// Comparison source for field `name` of type `ts_type` between `a` and `b`,
    /// with no type resolution available.
    fn compare_src(name: &str, ts_type: &str) -> String {
        generate_field_compare_for_interface(&field(name, ts_type), "a", "b", None, None)
    }

    #[test]
    fn test_ord_macro_output() {
        let class_name = "User";
        let class_ident = ts_ident!(class_name);
        let ord_fields = vec![field("id", "number")];
        let compare_steps: Vec<(Ident, Expr)> = ord_fields
            .iter()
            .enumerate()
            .map(|(i, f)| {
                let cmp_ident = ts_ident!(format!("cmp{}", i));
                let expr_src = generate_field_compare_for_interface(f, "a", "b", None, None);
                let expr = parse_ts_expr(&expr_src).expect("compare expr should parse");
                (cmp_ident, *expr)
            })
            .collect();
        let _ = &compare_steps;
        // Use the same `static compareTo(a, b)` signature as the real class expansion,
        // so the body's references to `a` and `b` are bound parameters.
        let output = ts_template!(Within {
            static compareTo(a: @{class_ident}, b: @{class_ident}): number {
                if (a === b) return 0;
                {#for (cmp_ident, cmp_expr) in &compare_steps}
                    const @{cmp_ident.clone()} = @{cmp_expr.clone()};
                    if (@{cmp_ident.clone()} !== 0) return @{cmp_ident.clone()};
                {/for}
                return 0;
            }
        });
        let source = output.source();
        let body_content = source
            .strip_prefix("/* @macroforge:body */")
            .unwrap_or(source);
        let wrapped = format!("class __Temp {{ {} }}", body_content);
        assert!(
            macroforge_ts_syn::parse_ts_stmt(&wrapped).is_ok(),
            "Generated Ord macro output should parse as class members"
        );
        assert!(
            source.contains("compareTo"),
            "Should contain compareTo method"
        );
        assert!(
            !source.contains("number | null"),
            "Should not contain nullable return type"
        );
    }

    #[test]
    fn test_field_compare_number() {
        assert_eq!(
            compare_src("id", "number"),
            "(a.id < b.id ? -1 : a.id > b.id ? 1 : 0)"
        );
    }

    #[test]
    fn test_field_compare_string() {
        assert_eq!(
            compare_src("name", "string"),
            "((cmp => cmp < 0 ? -1 : cmp > 0 ? 1 : 0)(a.name.localeCompare(b.name)))"
        );
    }

    #[test]
    fn test_field_compare_boolean() {
        assert_eq!(
            compare_src("active", "boolean"),
            "(a.active === b.active ? 0 : a.active ? 1 : -1)"
        );
    }

    #[test]
    fn test_field_compare_date() {
        let result = compare_src("createdAt", "Date");
        assert!(result.contains("a.createdAt?.getTime() ?? 0"));
        assert!(result.contains("b.createdAt?.getTime() ?? 0"));
    }

    #[test]
    fn test_field_compare_array() {
        for ts_type in ["number[]", "Array<string>"] {
            let result = compare_src("tags", ts_type);
            assert!(result.contains("a.tags ?? []"), "{ts_type}: {result}");
            assert!(result.contains("b.tags ?? []"), "{ts_type}: {result}");
            assert!(result.contains("Math.min(a.length, b.length)"), "{ts_type}: {result}");
        }
    }

    #[test]
    fn test_field_compare_object_no_null() {
        assert_eq!(
            compare_src("user", "User"),
            "(typeof (a.user as any)?.compareTo === 'function' \
             ? (a.user as any).compareTo(b.user) ?? 0 \
             : 0)"
        );
    }

    #[test]
    fn test_field_compare_all_branches_parse() {
        for ts_type in ["number", "string", "boolean", "number[]", "Array<User>", "Date", "User"] {
            let src = compare_src("f", ts_type);
            assert!(
                parse_ts_expr(&src).is_ok(),
                "comparison for '{ts_type}' should parse: {src}"
            );
        }
    }
}