use convert_case::{Case, Casing};
use crate::builtin::derive_common::{
CompareFieldOptions, is_numeric_type, is_primitive_type, standalone_fn_name, type_has_derive,
};
use crate::builtin::return_types::{is_none_check, partial_ord_return_type, unwrap_option_or_null};
use crate::macros::{ts_macro_derive, ts_template};
use crate::swc_ecma_ast::Expr;
use crate::ts_syn::abi::ir::type_registry::{ResolvedTypeRef, TypeRegistry};
use crate::ts_syn::ts_ident;
use crate::ts_syn::{Data, DeriveInput, MacroforgeError, TsStream, parse_ts_macro_input};
/// A field that participates in a generated partial-ordering comparison.
///
/// Built from class/interface/object-alias fields whose `ord` decorator options do not
/// request `skip`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct OrdField {
    /// Property name as written in the TypeScript source (used as `a.<name>`).
    name: String,
    /// The field's TypeScript type annotation, used to pick a comparison strategy.
    ts_type: String,
}
/// Builds a TypeScript expression that compares `field` on two operands for a partial ordering.
///
/// The expression evaluates to `-1`, `0` or `1`, or to an "incomparable" marker. The marker is
/// `null` when `allow_null` is true and `0` otherwise. `self_var` and `other_var` are the
/// TypeScript identifiers holding the left and right operands.
///
/// Strategy, in priority order:
/// 1. If `resolved`/`registry` describe a non-collection, registry-keyed type that derives
///    `PartialOrd`, delegate to that type's standalone `PartialCompare` function.
/// 2. Otherwise, choose by the field's TypeScript type string:
///    - numeric: `<` / `>`
///    - `string`: `localeCompare`
///    - `boolean`: `false` before `true`
///    - other primitives: equal or incomparable
///    - arrays: element-wise, then by length
///    - `Date`: by `getTime()`
///    - anything else: an instance `compareTo` when present, falling back to `===`.
///
/// The array and `Date` branches wrap their logic in an IIFE that copies the operands into
/// locals named `lhs`/`rhs`. Those names must differ from `self_var`/`other_var`. A `const`
/// that shadows the operand it is initialised from would be read inside its own temporal dead
/// zone and throw a `ReferenceError` at runtime.
fn generate_field_compare_for_interface(
    field: &OrdField,
    self_var: &str,
    other_var: &str,
    allow_null: bool,
    resolved: Option<&ResolvedTypeRef>,
    registry: Option<&TypeRegistry>,
) -> String {
    // The IIFE branches below bind `lhs`/`rhs`; operands with those names would be shadowed.
    debug_assert!(
        !matches!(self_var, "lhs" | "rhs") && !matches!(other_var, "lhs" | "rhs"),
        "operand names must not collide with the `lhs`/`rhs` locals of generated IIFEs"
    );
    let field_name = &field.name;
    let ts_type = &field.ts_type;
    // Value produced when the two operands cannot be ordered.
    let null_return = if allow_null { "null" } else { "0" };
    // Delegate to the nested type's generated comparator when the registry says it derives
    // `PartialOrd`. Collections fall through to the type-string branches below.
    if let (Some(resolved), Some(registry)) = (resolved, registry)
        && !resolved.is_collection
        && resolved.registry_key.is_some()
        && type_has_derive(registry, &resolved.base_type_name, "PartialOrd")
    {
        let fn_name = standalone_fn_name(&resolved.base_type_name, "PartialCompare");
        return format!("{fn_name}({self_var}.{field_name}, {other_var}.{field_name})");
    }
    if is_numeric_type(ts_type) {
        format!(
            "({self_var}.{field_name} < {other_var}.{field_name} ? -1 : \
             {self_var}.{field_name} > {other_var}.{field_name} ? 1 : 0)"
        )
    } else if ts_type == "string" {
        format!("{self_var}.{field_name}.localeCompare({other_var}.{field_name})")
    } else if ts_type == "boolean" {
        // `false` orders before `true`.
        format!(
            "({self_var}.{field_name} === {other_var}.{field_name} ? 0 : \
             {self_var}.{field_name} ? 1 : -1)"
        )
    } else if is_primitive_type(ts_type) {
        // Other primitives only support equality; unequal values are incomparable.
        format!("({self_var}.{field_name} === {other_var}.{field_name} ? 0 : {null_return})")
    } else if ts_type.ends_with("[]") || ts_type.starts_with("Array<") {
        // Lexicographic order: the first non-equal element decides (an incomparable element
        // makes the whole array incomparable); otherwise the shorter array is smaller.
        // Operands are copied into `lhs`/`rhs`, not `a`/`b`, so the `const`s never shadow the
        // outer operands they are initialised from (temporal dead zone).
        let unwrap_opt = unwrap_option_or_null("optResult");
        format!(
            "(() => {{ \
             const lhs = {self_var}.{field_name}; \
             const rhs = {other_var}.{field_name}; \
             if (!Array.isArray(lhs) || !Array.isArray(rhs)) return {null_return}; \
             const minLen = Math.min(lhs.length, rhs.length); \
             for (let i = 0; i < minLen; i++) {{ \
             let cmp: number | null; \
             if (typeof (lhs[i] as any)?.compareTo === 'function') {{ \
             const optResult = (lhs[i] as any).compareTo(rhs[i]); \
             cmp = {unwrap_opt}; \
             }} else {{ \
             cmp = lhs[i] < rhs[i] ? -1 : lhs[i] > rhs[i] ? 1 : 0; \
             }} \
             if (cmp === null) return {null_return}; \
             if (cmp !== 0) return cmp; \
             }} \
             return lhs.length < rhs.length ? -1 : lhs.length > rhs.length ? 1 : 0; \
             }})()"
        )
    } else if ts_type == "Date" {
        // Ordered by timestamp; non-`Date` values are incomparable. Same `lhs`/`rhs`
        // naming as above to avoid shadowing the operands.
        format!(
            "(() => {{ \
             const lhs = {self_var}.{field_name}; \
             const rhs = {other_var}.{field_name}; \
             if (!(lhs instanceof Date) || !(rhs instanceof Date)) return {null_return}; \
             const ta = lhs.getTime(); \
             const tb = rhs.getTime(); \
             return ta < tb ? -1 : ta > tb ? 1 : 0; \
             }})()"
        )
    } else {
        // Values exposing an instance `compareTo` are delegated to it; otherwise only
        // equality can be decided.
        let unwrap_opt = unwrap_option_or_null("optResult");
        let is_none = is_none_check("optResult");
        format!(
            "(() => {{ \
             if (typeof ({self_var}.{field_name} as any)?.compareTo === 'function') {{ \
             const optResult = ({self_var}.{field_name} as any).compareTo({other_var}.{field_name}); \
             return {is_none} ? {null_return} : {unwrap_opt}; \
             }} \
             return {self_var}.{field_name} === {other_var}.{field_name} ? 0 : {null_return}; \
             }})()"
        )
    }
}
#[ts_macro_derive(
PartialOrd,
description = "Generates a compareTo() method for partial ordering (returns number | null: -1, 0, 1, or null)",
attributes(ord)
)]
pub fn derive_partial_ord_macro(mut input: TsStream) -> Result<TsStream, MacroforgeError> {
let input = parse_ts_macro_input!(input as DeriveInput);
let resolved_fields = input.context.resolved_fields.as_ref();
let type_registry = input.context.type_registry.as_ref();
match &input.data {
Data::Class(class) => {
let class_name = input.name();
let class_ident = ts_ident!(class_name);
let ord_fields: Vec<OrdField> = class
.fields()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let has_fields = !ord_fields.is_empty();
let fn_name_ident = ts_ident!("{}PartialCompare", class_name.to_case(Case::Camel));
let fn_name_expr: Expr = fn_name_ident.clone().into();
let return_type = partial_ord_return_type();
let return_type_ident = ts_ident!(return_type);
let standalone = if has_fields {
let mut compare_body = String::new();
for (i, f) in ord_fields.iter().enumerate() {
let cmp_var = format!("cmp{}", i);
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f,
"a",
"b",
true,
resolved,
type_registry,
);
compare_body.push_str(&format!(
"const {} = {};\nif ({} === null) return null;\nif ({} !== 0) return {};\n",
cmp_var, expr_src, cmp_var, cmp_var, cmp_var
));
}
ts_template! {
export function @{fn_name_ident}(a: @{class_ident}, b: @{class_ident}): @{return_type_ident} {
if (a === b) return 0;
{$typescript TsStream::from_string(compare_body)}
return 0;
}
}
} else {
ts_template! {
export function @{fn_name_ident}(a: @{class_ident}, b: @{class_ident}): @{return_type_ident} {
if (a === b) return 0;
return 0;
}
}
};
let class_body = ts_template!(Within {
static compareTo(a: @{class_ident}, b: @{class_ident}): @{return_type_ident} {
return @{fn_name_expr}(a, b);
}
});
Ok(standalone.merge(class_body))
}
Data::Enum(_) => {
let enum_name = input.name();
let fn_name_ident = ts_ident!("{}PartialCompare", enum_name.to_case(Case::Camel));
let return_type = partial_ord_return_type();
let return_type_ident = ts_ident!(return_type);
let result = ts_template! {
export function @{fn_name_ident}(a: @{ts_ident!(enum_name)}, b: @{ts_ident!(enum_name)}): @{return_type_ident} {
if (typeof a === "number" && typeof b === "number") {
return a < b ? -1 : a > b ? 1 : 0;
}
if (typeof a === "string" && typeof b === "string") {
return a.localeCompare(b);
}
return a === b ? 0 : null;
}
};
Ok(result)
}
Data::Interface(interface) => {
let interface_name = input.name();
let interface_ident = ts_ident!(interface_name);
let ord_fields: Vec<OrdField> = interface
.fields()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let has_fields = !ord_fields.is_empty();
let return_type = partial_ord_return_type();
let return_type_ident = ts_ident!(return_type);
let fn_name_ident = ts_ident!("{}PartialCompare", interface_name.to_case(Case::Camel));
let result = if has_fields {
let mut compare_body = String::new();
for (i, f) in ord_fields.iter().enumerate() {
let cmp_var = format!("cmp{}", i);
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f,
"a",
"b",
true,
resolved,
type_registry,
);
compare_body.push_str(&format!(
"const {} = {};\nif ({} === null) return null;\nif ({} !== 0) return {};\n",
cmp_var, expr_src, cmp_var, cmp_var, cmp_var
));
}
ts_template! {
export function @{fn_name_ident}(a: @{interface_ident}, b: @{interface_ident}): @{return_type_ident} {
if (a === b) return 0;
{$typescript TsStream::from_string(compare_body)}
return 0;
}
}
} else {
ts_template! {
export function @{fn_name_ident}(a: @{interface_ident}, b: @{interface_ident}): @{return_type_ident} {
if (a === b) return 0;
return 0;
}
}
};
Ok(result)
}
Data::TypeAlias(type_alias) => {
let type_name = input.name();
let type_ident = ts_ident!(type_name);
let return_type = partial_ord_return_type();
let return_type_ident = ts_ident!(return_type);
if type_alias.is_object() {
let ord_fields: Vec<OrdField> = type_alias
.as_object()
.unwrap()
.iter()
.filter_map(|field| {
let opts = CompareFieldOptions::from_decorators(&field.decorators, "ord");
if opts.skip {
return None;
}
Some(OrdField {
name: field.name.clone(),
ts_type: field.ts_type.clone(),
})
})
.collect();
let has_fields = !ord_fields.is_empty();
let fn_name_ident = ts_ident!("{}PartialCompare", type_name.to_case(Case::Camel));
let result = if has_fields {
let mut compare_body = String::new();
for (i, f) in ord_fields.iter().enumerate() {
let cmp_var = format!("cmp{}", i);
let resolved = resolved_fields.and_then(|rf| rf.get(&f.name));
let expr_src = generate_field_compare_for_interface(
f,
"a",
"b",
true,
resolved,
type_registry,
);
compare_body.push_str(&format!(
"const {} = {};\nif ({} === null) return null;\nif ({} !== 0) return {};\n",
cmp_var, expr_src, cmp_var, cmp_var, cmp_var
));
}
ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): @{return_type_ident} {
if (a === b) return 0;
{$typescript TsStream::from_string(compare_body)}
return 0;
}
}
} else {
ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): @{return_type_ident} {
if (a === b) return 0;
return 0;
}
}
};
Ok(result)
} else {
let fn_name_ident = ts_ident!("{}PartialCompare", type_name.to_case(Case::Camel));
let result = ts_template! {
export function @{fn_name_ident}(a: @{type_ident}, b: @{type_ident}): @{return_type_ident} {
if (a === b) return 0;
if (typeof a === "number" && typeof b === "number") {
return a < b ? -1 : a > b ? 1 : 0;
}
if (typeof a === "string" && typeof b === "string") {
return a.localeCompare(b);
}
return null;
}
};
Ok(result)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an [`OrdField`] in tests.
    fn ord_field(name: &str, ts_type: &str) -> OrdField {
        OrdField {
            name: name.to_string(),
            ts_type: ts_type.to_string(),
        }
    }

    #[test]
    fn test_partial_ord_macro_output_vanilla() {
        let ord_fields = [ord_field("id", "number")];
        let mut compare_body_str = String::new();
        for (i, f) in ord_fields.iter().enumerate() {
            let cmp_var = format!("cmp{}", i);
            let expr_src = generate_field_compare_for_interface(f, "a", "b", true, None, None);
            compare_body_str.push_str(&format!(
                "const {} = {};\nif ({} === null) return null;\nif ({} !== 0) return {};\n",
                cmp_var, expr_src, cmp_var, cmp_var, cmp_var
            ));
        }
        let return_type = partial_ord_return_type();
        let _return_type_ident = ts_ident!(return_type);
        let output = ts_template!(Within {
            compareTo(other: unknown): @{_return_type_ident} {
                if (a === b) return 0;
                {$typescript TsStream::from_string(compare_body_str)}
                return 0;
            }
        });
        let source = output.source();
        let body_content = source
            .strip_prefix("/* @macroforge:body */")
            .unwrap_or(source);
        let wrapped = format!("class __Temp {{ {} }}", body_content);
        assert!(
            macroforge_ts_syn::parse_ts_stmt(&wrapped).is_ok(),
            "Generated PartialOrd macro output should parse as class members"
        );
        assert!(
            source.contains("compareTo"),
            "Should contain compareTo method"
        );
        assert!(
            source.contains("number | null"),
            "Should have number | null return type"
        );
    }

    #[test]
    fn test_field_compare_number() {
        let result = generate_field_compare_for_interface(
            &ord_field("id", "number"),
            "a",
            "b",
            true,
            None,
            None,
        );
        assert_eq!(result, "(a.id < b.id ? -1 : a.id > b.id ? 1 : 0)");
    }

    #[test]
    fn test_field_compare_string() {
        let result = generate_field_compare_for_interface(
            &ord_field("name", "string"),
            "a",
            "b",
            true,
            None,
            None,
        );
        assert_eq!(result, "a.name.localeCompare(b.name)");
    }

    #[test]
    fn test_field_compare_boolean() {
        let result = generate_field_compare_for_interface(
            &ord_field("active", "boolean"),
            "a",
            "b",
            true,
            None,
            None,
        );
        // `false` orders before `true`.
        assert_eq!(result, "(a.active === b.active ? 0 : a.active ? 1 : -1)");
    }

    #[test]
    fn test_field_compare_date() {
        let field = ord_field("createdAt", "Date");
        let result = generate_field_compare_for_interface(&field, "a", "b", true, None, None);
        assert!(result.contains("getTime"));
        assert!(result.contains("instanceof Date"));
        assert!(result.contains("return null;"));

        // Without nulls, incomparable values collapse to 0.
        let strict = generate_field_compare_for_interface(&field, "a", "b", false, None, None);
        assert!(strict.contains("return 0;"));
        assert!(!strict.contains("null"));
    }

    #[test]
    fn test_field_compare_object_vanilla() {
        let result = generate_field_compare_for_interface(
            &ord_field("user", "User"),
            "a",
            "b",
            true,
            None,
            None,
        );
        assert!(result.contains("compareTo"));
        assert!(result.contains("=== null"));
    }

    #[test]
    fn test_field_compare_array_vanilla() {
        let result = generate_field_compare_for_interface(
            &ord_field("items", "Item[]"),
            "a",
            "b",
            true,
            None,
            None,
        );
        assert!(result.contains("Array.isArray"));
        assert!(result.contains("cmp = optResult"));
    }
}