Skip to main content

problemreductions_macros/
lib.rs

1//! Procedural macros for problemreductions.
2//!
3//! This crate provides the `#[reduction]` attribute macro that automatically
4//! generates `ReductionEntry` registrations from `ReduceTo` impl blocks.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use std::collections::HashSet;
10use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
11
12/// Attribute macro for automatic reduction registration.
13///
14/// Parses a `ReduceTo` impl block and generates the corresponding `inventory::submit!`
15/// call. Variant fields are derived from `Problem::variant()`.
16///
17/// **Type generics are not supported** — all `ReduceTo` impls must use concrete types.
18/// If you need a reduction for a generic problem, write separate impls for each concrete
19/// type combination.
20///
21/// # Attributes
22///
23/// - `overhead = { expr }` — overhead specification (required for non-trivial reductions)
24#[proc_macro_attribute]
25pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
26    let attrs = parse_macro_input!(attr as ReductionAttrs);
27    let impl_block = parse_macro_input!(item as ItemImpl);
28
29    match generate_reduction_entry(&attrs, &impl_block) {
30        Ok(tokens) => tokens.into(),
31        Err(e) => e.to_compile_error().into(),
32    }
33}
34
35/// Parsed attributes from #[reduction(...)]
36struct ReductionAttrs {
37    overhead: Option<TokenStream2>,
38}
39
40impl syn::parse::Parse for ReductionAttrs {
41    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42        let mut attrs = ReductionAttrs { overhead: None };
43
44        while !input.is_empty() {
45            let ident: syn::Ident = input.parse()?;
46            input.parse::<syn::Token![=]>()?;
47
48            match ident.to_string().as_str() {
49                "overhead" => {
50                    let content;
51                    syn::braced!(content in input);
52                    attrs.overhead = Some(content.parse()?);
53                }
54                _ => {
55                    return Err(syn::Error::new(
56                        ident.span(),
57                        format!("unknown attribute: {}", ident),
58                    ));
59                }
60            }
61
62            if input.peek(syn::Token![,]) {
63                input.parse::<syn::Token![,]>()?;
64            }
65        }
66
67        Ok(attrs)
68    }
69}
70
71/// Extract the base type name from a Type (e.g., "IndependentSet" from "IndependentSet<i32>")
72fn extract_type_name(ty: &Type) -> Option<String> {
73    match ty {
74        Type::Path(type_path) => {
75            let segment = type_path.path.segments.last()?;
76            Some(segment.ident.to_string())
77        }
78        _ => None,
79    }
80}
81
82/// Collect type generic parameter names from impl generics.
83/// e.g., `impl<G: Graph, W: NumericSize>` → {"G", "W"}
84fn collect_type_generic_names(generics: &syn::Generics) -> HashSet<String> {
85    generics
86        .params
87        .iter()
88        .filter_map(|p| {
89            if let syn::GenericParam::Type(t) = p {
90                Some(t.ident.to_string())
91            } else {
92                None
93            }
94        })
95        .collect()
96}
97
98/// Check if a type uses any of the given type generic parameters.
99fn type_uses_type_generics(ty: &Type, type_generics: &HashSet<String>) -> bool {
100    match ty {
101        Type::Path(type_path) => {
102            if let Some(segment) = type_path.path.segments.last() {
103                if let PathArguments::AngleBracketed(args) = &segment.arguments {
104                    for arg in args.args.iter() {
105                        if let GenericArgument::Type(Type::Path(inner)) = arg {
106                            if let Some(ident) = inner.path.get_ident() {
107                                if type_generics.contains(&ident.to_string()) {
108                                    return true;
109                                }
110                            }
111                        }
112                    }
113                }
114            }
115            false
116        }
117        _ => false,
118    }
119}
120
121/// Generate the variant fn body for a type.
122///
123/// Calls `Problem::variant()` on the concrete type.
124/// Errors if the type uses any type generics — all `ReduceTo` impls must be concrete.
125fn make_variant_fn_body(ty: &Type, type_generics: &HashSet<String>) -> syn::Result<TokenStream2> {
126    if type_uses_type_generics(ty, type_generics) {
127        let used: Vec<_> = type_generics.iter().cloned().collect();
128        return Err(syn::Error::new_spanned(
129            ty,
130            format!(
131                "#[reduction] does not support type generics (found: {}). \
132                 Make the ReduceTo impl concrete by specifying explicit types.",
133                used.join(", ")
134            ),
135        ));
136    }
137    Ok(quote! { <#ty as crate::traits::Problem>::variant() })
138}
139
140/// Generate the reduction entry code
141fn generate_reduction_entry(
142    attrs: &ReductionAttrs,
143    impl_block: &ItemImpl,
144) -> syn::Result<TokenStream2> {
145    // Extract the trait path (should be ReduceTo<Target>)
146    let trait_path = impl_block
147        .trait_
148        .as_ref()
149        .map(|(_, path, _)| path)
150        .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;
151
152    // Extract target type from ReduceTo<Target>
153    let target_type = extract_target_from_trait(trait_path)?;
154
155    // Extract source type (Self type)
156    let source_type = &impl_block.self_ty;
157
158    // Get type names
159    let source_name = extract_type_name(source_type)
160        .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
161    let target_name = extract_type_name(&target_type)
162        .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;
163
164    // Collect generic parameter info from the impl block
165    let type_generics = collect_type_generic_names(&impl_block.generics);
166
167    // Generate variant fn bodies
168    let source_variant_body = make_variant_fn_body(source_type, &type_generics)?;
169    let target_variant_body = make_variant_fn_body(&target_type, &type_generics)?;
170
171    // Generate overhead or use default
172    let overhead = attrs.overhead.clone().unwrap_or_else(|| {
173        quote! {
174            crate::rules::registry::ReductionOverhead::default()
175        }
176    });
177
178    // Generate the combined output
179    let output = quote! {
180        #impl_block
181
182        inventory::submit! {
183            crate::rules::registry::ReductionEntry {
184                source_name: #source_name,
185                target_name: #target_name,
186                source_variant_fn: || { #source_variant_body },
187                target_variant_fn: || { #target_variant_body },
188                overhead_fn: || { #overhead },
189                module_path: module_path!(),
190                source_size_names_fn: || { <#source_type as crate::traits::Problem>::problem_size_names() },
191                target_size_names_fn: || { <#target_type as crate::traits::Problem>::problem_size_names() },
192                reduce_fn: |src: &dyn std::any::Any| -> Box<dyn crate::rules::traits::DynReductionResult> {
193                    let src = src.downcast_ref::<#source_type>().unwrap_or_else(|| {
194                        panic!(
195                            "DynReductionResult: source type mismatch: expected `{}`, got `{}`",
196                            std::any::type_name::<#source_type>(),
197                            std::any::type_name_of_val(src),
198                        )
199                    });
200                    Box::new(<#source_type as crate::rules::ReduceTo<#target_type>>::reduce_to(src))
201                },
202            }
203        }
204    };
205
206    Ok(output)
207}
208
209/// Extract the target type from ReduceTo<Target> trait path
210fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
211    let segment = path
212        .segments
213        .last()
214        .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
215
216    if segment.ident != "ReduceTo" {
217        return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
218    }
219
220    if let PathArguments::AngleBracketed(args) = &segment.arguments {
221        if let Some(GenericArgument::Type(ty)) = args.args.first() {
222            return Ok(ty.clone());
223        }
224    }
225
226    Err(syn::Error::new_spanned(
227        segment,
228        "Expected ReduceTo<Target> with type parameter",
229    ))
230}