problemreductions_macros/
lib.rs1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use std::collections::HashSet;
10use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
11
12#[proc_macro_attribute]
25pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
26 let attrs = parse_macro_input!(attr as ReductionAttrs);
27 let impl_block = parse_macro_input!(item as ItemImpl);
28
29 match generate_reduction_entry(&attrs, &impl_block) {
30 Ok(tokens) => tokens.into(),
31 Err(e) => e.to_compile_error().into(),
32 }
33}
34
35struct ReductionAttrs {
37 overhead: Option<TokenStream2>,
38}
39
40impl syn::parse::Parse for ReductionAttrs {
41 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42 let mut attrs = ReductionAttrs { overhead: None };
43
44 while !input.is_empty() {
45 let ident: syn::Ident = input.parse()?;
46 input.parse::<syn::Token![=]>()?;
47
48 match ident.to_string().as_str() {
49 "overhead" => {
50 let content;
51 syn::braced!(content in input);
52 attrs.overhead = Some(content.parse()?);
53 }
54 _ => {
55 return Err(syn::Error::new(
56 ident.span(),
57 format!("unknown attribute: {}", ident),
58 ));
59 }
60 }
61
62 if input.peek(syn::Token![,]) {
63 input.parse::<syn::Token![,]>()?;
64 }
65 }
66
67 Ok(attrs)
68 }
69}
70
71fn extract_type_name(ty: &Type) -> Option<String> {
73 match ty {
74 Type::Path(type_path) => {
75 let segment = type_path.path.segments.last()?;
76 Some(segment.ident.to_string())
77 }
78 _ => None,
79 }
80}
81
82fn collect_type_generic_names(generics: &syn::Generics) -> HashSet<String> {
85 generics
86 .params
87 .iter()
88 .filter_map(|p| {
89 if let syn::GenericParam::Type(t) = p {
90 Some(t.ident.to_string())
91 } else {
92 None
93 }
94 })
95 .collect()
96}
97
98fn type_uses_type_generics(ty: &Type, type_generics: &HashSet<String>) -> bool {
100 match ty {
101 Type::Path(type_path) => {
102 if let Some(segment) = type_path.path.segments.last() {
103 if let PathArguments::AngleBracketed(args) = &segment.arguments {
104 for arg in args.args.iter() {
105 if let GenericArgument::Type(Type::Path(inner)) = arg {
106 if let Some(ident) = inner.path.get_ident() {
107 if type_generics.contains(&ident.to_string()) {
108 return true;
109 }
110 }
111 }
112 }
113 }
114 }
115 false
116 }
117 _ => false,
118 }
119}
120
121fn make_variant_fn_body(ty: &Type, type_generics: &HashSet<String>) -> syn::Result<TokenStream2> {
126 if type_uses_type_generics(ty, type_generics) {
127 let used: Vec<_> = type_generics.iter().cloned().collect();
128 return Err(syn::Error::new_spanned(
129 ty,
130 format!(
131 "#[reduction] does not support type generics (found: {}). \
132 Make the ReduceTo impl concrete by specifying explicit types.",
133 used.join(", ")
134 ),
135 ));
136 }
137 Ok(quote! { <#ty as crate::traits::Problem>::variant() })
138}
139
140fn generate_reduction_entry(
142 attrs: &ReductionAttrs,
143 impl_block: &ItemImpl,
144) -> syn::Result<TokenStream2> {
145 let trait_path = impl_block
147 .trait_
148 .as_ref()
149 .map(|(_, path, _)| path)
150 .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;
151
152 let target_type = extract_target_from_trait(trait_path)?;
154
155 let source_type = &impl_block.self_ty;
157
158 let source_name = extract_type_name(source_type)
160 .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
161 let target_name = extract_type_name(&target_type)
162 .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;
163
164 let type_generics = collect_type_generic_names(&impl_block.generics);
166
167 let source_variant_body = make_variant_fn_body(source_type, &type_generics)?;
169 let target_variant_body = make_variant_fn_body(&target_type, &type_generics)?;
170
171 let overhead = attrs.overhead.clone().unwrap_or_else(|| {
173 quote! {
174 crate::rules::registry::ReductionOverhead::default()
175 }
176 });
177
178 let output = quote! {
180 #impl_block
181
182 inventory::submit! {
183 crate::rules::registry::ReductionEntry {
184 source_name: #source_name,
185 target_name: #target_name,
186 source_variant_fn: || { #source_variant_body },
187 target_variant_fn: || { #target_variant_body },
188 overhead_fn: || { #overhead },
189 module_path: module_path!(),
190 source_size_names_fn: || { <#source_type as crate::traits::Problem>::problem_size_names() },
191 target_size_names_fn: || { <#target_type as crate::traits::Problem>::problem_size_names() },
192 reduce_fn: |src: &dyn std::any::Any| -> Box<dyn crate::rules::traits::DynReductionResult> {
193 let src = src.downcast_ref::<#source_type>().unwrap_or_else(|| {
194 panic!(
195 "DynReductionResult: source type mismatch: expected `{}`, got `{}`",
196 std::any::type_name::<#source_type>(),
197 std::any::type_name_of_val(src),
198 )
199 });
200 Box::new(<#source_type as crate::rules::ReduceTo<#target_type>>::reduce_to(src))
201 },
202 }
203 }
204 };
205
206 Ok(output)
207}
208
209fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
211 let segment = path
212 .segments
213 .last()
214 .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
215
216 if segment.ident != "ReduceTo" {
217 return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
218 }
219
220 if let PathArguments::AngleBracketed(args) = &segment.arguments {
221 if let Some(GenericArgument::Type(ty)) = args.args.first() {
222 return Ok(ty.clone());
223 }
224 }
225
226 Err(syn::Error::new_spanned(
227 segment,
228 "Expected ReduceTo<Target> with type parameter",
229 ))
230}