crate::ix!();
/// Combines two sets of generics into one, with `wire_generics` taking precedence in order.
///
/// The result starts as a copy of `wire_generics`. Every parameter of `op_generics` is then
/// appended after the wire parameters, in its original order. Parameters are NOT
/// deduplicated: a name declared on both sides appears twice in the result.
///
/// Where-clauses are combined as follows:
/// - both sides have one: the op predicates are appended after the wire predicates;
/// - only the op side has one: the op where-clause is copied over wholesale;
/// - otherwise: the wire where-clause (or its absence) is kept unchanged.
pub fn unify_generics_ast(
    wire_generics: &syn::Generics,
    op_generics: &syn::Generics,
) -> syn::Generics {
    let mut combined = wire_generics.clone();

    // Op parameters go after the wire parameters, preserving their relative order.
    combined.params.extend(op_generics.params.iter().cloned());

    if let Some(op_where) = &op_generics.where_clause {
        if let Some(wire_where) = &mut combined.where_clause {
            wire_where
                .predicates
                .extend(op_where.predicates.iter().cloned());
        } else {
            // The wire side had no where-clause: adopt the op one as-is.
            combined.where_clause = Some(op_where.clone());
        }
    }

    combined
}
#[cfg(test)]
mod test_unify_generics_ast {
    use super::*;
    use quote::ToTokens;
    use syn::{parse_str, DeriveInput, GenericParam, Generics};

    /// Parses a generics snippet such as `"<T: Clone> where T: Default"` by splicing it
    /// into `struct Dummy...;` and extracting the generics.
    /// Panics with the offending source text if parsing fails.
    fn parse_generics_str(input: &str) -> Generics {
        let code = format!("struct Dummy{};", input);
        match parse_str::<DeriveInput>(&code) {
            Ok(ast) => ast.generics,
            Err(e) => panic!("Failed parsing generics from {}: {}", code, e),
        }
    }

    /// Renders `gens` through `ToTokens` as a single-line string.
    fn normalize_gens(gens: &Generics) -> String {
        gens.to_token_stream().to_string().replace('\n', " ")
    }

    /// Renders the where-clause as a single-line string, or `""` when there is none.
    fn where_clause_string(gens: &Generics) -> String {
        gens.where_clause
            .as_ref()
            .map(|wc| wc.to_token_stream().to_string().replace('\n', " "))
            .unwrap_or_default()
    }

    /// Names of the generic parameters in declaration order; lifetimes keep their
    /// leading `'` (e.g. `["'a", "T", "N"]`). Stronger than substring probes on the
    /// rendered tokens, which can match unrelated text.
    fn param_names(gens: &Generics) -> Vec<String> {
        gens.params
            .iter()
            .map(|param| match param {
                GenericParam::Lifetime(lt) => lt.lifetime.to_string(),
                GenericParam::Type(ty) => ty.ident.to_string(),
                GenericParam::Const(c) => c.ident.to_string(),
            })
            .collect()
    }

    /// Number of where-clause predicates; 0 when there is no where-clause.
    fn where_predicate_count(gens: &Generics) -> usize {
        gens.where_clause
            .as_ref()
            .map_or(0, |wc| wc.predicates.len())
    }

    /// Thin alias for the function under test so every test reads uniformly.
    fn merge_ast(wire: &Generics, op: &Generics) -> Generics {
        unify_generics_ast(wire, op)
    }

    #[test]
    fn unify_no_generics() {
        let wire = parse_generics_str("");
        let op = parse_generics_str("");
        let merged = merge_ast(&wire, &op);
        assert!(merged.params.is_empty(), "Expected no params");
        assert!(merged.where_clause.is_none(), "Expected no where-clause");
    }

    #[test]
    fn unify_wire_generics_only() {
        let wire = parse_generics_str("<T: Clone + Default>");
        let op = parse_generics_str("");
        let merged = merge_ast(&wire, &op);
        assert_eq!(merged.params.len(), 1, "Should have 1 param (T)");
        assert_eq!(param_names(&merged), vec!["T"]);
        let merged_str = normalize_gens(&merged);
        assert!(
            merged_str.contains("T : Clone + Default"),
            "Missing T:Clone+Default in merged generics; got: {}",
            merged_str
        );
        assert!(merged.where_clause.is_none());
    }

    #[test]
    fn unify_op_generics_only() {
        let wire = parse_generics_str("");
        let op = parse_generics_str("<U: Send, const N: usize>");
        let merged = merge_ast(&wire, &op);
        assert_eq!(merged.params.len(), 2, "Should have U, N");
        assert_eq!(param_names(&merged), vec!["U", "N"]);
        let merged_str = normalize_gens(&merged);
        assert!(
            merged_str.contains("U : Send") && merged_str.contains("const N : usize"),
            "Missing U:Send or N in merged generics; got: {}",
            merged_str
        );
        assert!(merged.where_clause.is_none());
    }

    #[test]
    fn unify_type_and_const_generics() {
        let wire = parse_generics_str("<T, 'a, const M: usize>");
        let op = parse_generics_str("<U, const N: usize>");
        let merged = merge_ast(&wire, &op);
        // Wire params first (in their declared order), then op params.
        assert_eq!(param_names(&merged), vec!["T", "'a", "M", "U", "N"]);
        let merged_str = normalize_gens(&merged);
        assert!(merged_str.contains("M : usize"), "No M param!");
        assert!(merged_str.contains("N : usize"), "No N param!");
        assert!(merged.where_clause.is_none());
    }

    #[test]
    fn unify_where_clauses() {
        let wire = parse_generics_str("<T: Clone> where T: Default");
        let op = parse_generics_str("<U: Send> where U: Sync");
        let merged = merge_ast(&wire, &op);
        assert_eq!(param_names(&merged), vec!["T", "U"]);
        let merged_str = normalize_gens(&merged);
        assert!(
            merged_str.contains("T : Clone") && merged_str.contains("U : Send"),
            "Expected T:Clone, U:Send in generics; got {}",
            merged_str
        );
        assert_eq!(where_predicate_count(&merged), 2, "Expected both predicates kept");
        let wc_str = where_clause_string(&merged);
        assert!(
            wc_str.contains("T : Default") && wc_str.contains("U : Sync"),
            "Expected T:Default, U:Sync in where-clause; got {}",
            wc_str
        );
    }

    #[test]
    fn unify_keeps_wire_where_clause_when_op_has_none() {
        let wire = parse_generics_str("<T> where T: Clone");
        let op = parse_generics_str("<U>");
        let merged = merge_ast(&wire, &op);
        assert_eq!(param_names(&merged), vec!["T", "U"]);
        assert_eq!(where_predicate_count(&merged), 1);
        let wc_str = where_clause_string(&merged);
        assert!(
            wc_str.contains("T : Clone"),
            "Expected T:Clone in where-clause; got {}",
            wc_str
        );
    }

    #[test]
    fn unify_adopts_op_where_clause_when_wire_has_none() {
        let wire = parse_generics_str("<T>");
        let op = parse_generics_str("<U> where U: Sync");
        let merged = merge_ast(&wire, &op);
        assert_eq!(param_names(&merged), vec!["T", "U"]);
        assert_eq!(where_predicate_count(&merged), 1);
        let wc_str = where_clause_string(&merged);
        assert!(
            wc_str.contains("U : Sync"),
            "Expected U:Sync in where-clause; got {}",
            wc_str
        );
    }

    #[test]
    fn unify_complex_generics() {
        let wire_str = r#"
            <'a, A, B: 'a + std::fmt::Debug>
            where
                A: core::iter::Iterator<Item = i32>,
                B: core::clone::Clone
        "#;
        let wire = parse_generics_str(wire_str);
        let op_str = r#"
            <C: ::core::cmp::Eq + ::core::cmp::Ord, const N: usize>
            where
                C: ::core::marker::Copy
        "#;
        let op = parse_generics_str(op_str);
        let merged = merge_ast(&wire, &op);
        assert_eq!(param_names(&merged), vec!["'a", "A", "B", "C", "N"]);
        let merged_str = normalize_gens(&merged);
        info!("Merged generics => {}", merged_str);
        assert!(
            merged_str.contains("B : 'a") && merged_str.contains("Debug"),
            "B missing 'a or Debug"
        );
        assert!(
            merged_str.contains("C : ::core::cmp :: Eq + ::core::cmp :: Ord")
                || merged_str.contains("C : :: core :: cmp :: Eq + :: core :: cmp :: Ord"),
            "C missing Eq+Ord"
        );
        assert!(merged_str.contains("N : usize"), "No N param found");
        // Two wire predicates followed by one op predicate.
        assert_eq!(where_predicate_count(&merged), 3);
        let wc_str = where_clause_string(&merged);
        info!("Merged where-clause => {}", wc_str);
        assert!(
            wc_str.contains("A : core :: iter :: Iterator") && wc_str.contains("Item = i32"),
            "Missing A:Iterator<Item=i32> in where-clause"
        );
        assert!(
            wc_str.contains("B : core :: clone :: Clone"),
            "Missing B:Clone in where-clause"
        );
        assert!(
            wc_str.contains("C : :: core :: marker :: Copy")
                || wc_str.contains("C : core :: marker :: Copy"),
            "Missing C:Copy in where-clause"
        );
    }

    #[test]
    fn unify_conflicting_params_not_deduplicated() {
        let wire = parse_generics_str("<T>");
        let op = parse_generics_str("<T>");
        let merged = merge_ast(&wire, &op);
        assert_eq!(merged.params.len(), 2, "Expected T, T repeated (no dedup).");
        assert_eq!(param_names(&merged), vec!["T", "T"]);
    }

    #[test]
    fn parse_error_handling() {
        let input = "<T where T: Clone";
        let code = format!("struct Dummy{};", input);
        let parse_res = parse_str::<DeriveInput>(&code);
        assert!(parse_res.is_err(), "Expected parse error for malformed input");
    }
}