Skip to main content

problemreductions_macros/
lib.rs

1//! Procedural macros for problemreductions.
2//!
3//! This crate provides the `#[reduction]` attribute macro that automatically
4//! generates `ReductionEntry` registrations from `ReduceTo` impl blocks.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
10
11/// Attribute macro for automatic reduction registration.
12///
13/// This macro parses a `ReduceTo` impl block and automatically generates
14/// the corresponding `inventory::submit!` call with the correct metadata.
15///
16/// # Type Parameter Convention
17///
18/// The macro extracts graph and weight type information from type parameters:
19/// - `Problem<G>` where `G` is a graph type - extracts graph type name
20/// - `Problem<G, W>` where `W` is a weight type - weighted if W != Unweighted
21///
22/// # Example
23///
24/// ```ignore
25/// #[reduction(
26///     source_graph = "SimpleGraph",
27///     target_graph = "GridGraph",
28///     source_weighted = false,
29///     target_weighted = true,
30/// )]
31/// impl ReduceTo<IndependentSet<i32, GridGraph>> for IndependentSet<Unweighted, SimpleGraph> {
32///     type Result = ReductionISToGridIS;
33///     fn reduce_to(&self) -> Self::Result { ... }
34/// }
35/// ```
36///
37/// The macro also supports inferring from type names when explicit attributes aren't provided.
38#[proc_macro_attribute]
39pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
40    let attrs = parse_macro_input!(attr as ReductionAttrs);
41    let impl_block = parse_macro_input!(item as ItemImpl);
42
43    match generate_reduction_entry(&attrs, &impl_block) {
44        Ok(tokens) => tokens.into(),
45        Err(e) => e.to_compile_error().into(),
46    }
47}
48
49/// Parsed attributes from #[reduction(...)]
50struct ReductionAttrs {
51    source_graph: Option<String>,
52    target_graph: Option<String>,
53    source_weighted: Option<bool>,
54    target_weighted: Option<bool>,
55    overhead: Option<TokenStream2>,
56}
57
58impl syn::parse::Parse for ReductionAttrs {
59    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
60        let mut attrs = ReductionAttrs {
61            source_graph: None,
62            target_graph: None,
63            source_weighted: None,
64            target_weighted: None,
65            overhead: None,
66        };
67
68        while !input.is_empty() {
69            let ident: syn::Ident = input.parse()?;
70            input.parse::<syn::Token![=]>()?;
71
72            match ident.to_string().as_str() {
73                "source_graph" => {
74                    let lit: syn::LitStr = input.parse()?;
75                    attrs.source_graph = Some(lit.value());
76                }
77                "target_graph" => {
78                    let lit: syn::LitStr = input.parse()?;
79                    attrs.target_graph = Some(lit.value());
80                }
81                "source_weighted" => {
82                    let lit: syn::LitBool = input.parse()?;
83                    attrs.source_weighted = Some(lit.value());
84                }
85                "target_weighted" => {
86                    let lit: syn::LitBool = input.parse()?;
87                    attrs.target_weighted = Some(lit.value());
88                }
89                "overhead" => {
90                    let content;
91                    syn::braced!(content in input);
92                    attrs.overhead = Some(content.parse()?);
93                }
94                _ => {
95                    return Err(syn::Error::new(
96                        ident.span(),
97                        format!("unknown attribute: {}", ident),
98                    ));
99                }
100            }
101
102            if input.peek(syn::Token![,]) {
103                input.parse::<syn::Token![,]>()?;
104            }
105        }
106
107        Ok(attrs)
108    }
109}
110
111/// Extract the base type name from a Type (e.g., "IndependentSet" from "IndependentSet<i32>")
112fn extract_type_name(ty: &Type) -> Option<String> {
113    match ty {
114        Type::Path(type_path) => {
115            let segment = type_path.path.segments.last()?;
116            Some(segment.ident.to_string())
117        }
118        _ => None,
119    }
120}
121
122/// Extract graph type from type parameters (first parameter in `Problem<G, W>` order)
123fn extract_graph_type(ty: &Type) -> Option<String> {
124    match ty {
125        Type::Path(type_path) => {
126            let segment = type_path.path.segments.last()?;
127            if let PathArguments::AngleBracketed(args) = &segment.arguments {
128                // Get the first type argument which is the graph type
129                for arg in args.args.iter() {
130                    if let GenericArgument::Type(Type::Path(inner_path)) = arg {
131                        let name = inner_path
132                            .path
133                            .segments
134                            .last()
135                            .map(|s| s.ident.to_string())?;
136                        // Skip generic params (single uppercase letter)
137                        if name.len() == 1
138                            && name
139                                .chars()
140                                .next()
141                                .map(|c| c.is_ascii_uppercase())
142                                .unwrap_or(false)
143                        {
144                            return None; // Generic param, let it default
145                        }
146                        // Skip known weight types - for single-param problems like QUBO<W>
147                        if is_weight_type(&name) {
148                            return None; // Weight type in first position, not a graph type
149                        }
150                        return Some(name);
151                    }
152                }
153            }
154            None
155        }
156        _ => None,
157    }
158}
159
/// Report whether `name` is one of the recognised weight type names:
/// `i32`, `i64`, `f32`, `f64` or `Unweighted`. The comparison is exact.
fn is_weight_type(name: &str) -> bool {
    matches!(name, "i32" | "i64" | "f32" | "f64" | "Unweighted")
}
164
165/// Extract weight type from type parameters.
166/// For `Problem<G, W>` (two params): returns W (second param).
167/// For `Problem<W>` (single weight param): returns W (first param).
168fn extract_weight_type(ty: &Type) -> Option<Type> {
169    match ty {
170        Type::Path(type_path) => {
171            let segment = type_path.path.segments.last()?;
172            if let PathArguments::AngleBracketed(args) = &segment.arguments {
173                let type_args: Vec<_> = args
174                    .args
175                    .iter()
176                    .filter_map(|arg| {
177                        if let GenericArgument::Type(t) = arg {
178                            Some(t)
179                        } else {
180                            None
181                        }
182                    })
183                    .collect();
184
185                match type_args.len() {
186                    1 => {
187                        // Single param - check if it's a weight type
188                        let first = type_args[0];
189                        if let Type::Path(inner_path) = first {
190                            let name = inner_path.path.segments.last()?.ident.to_string();
191                            if is_weight_type(&name) {
192                                return Some(first.clone());
193                            }
194                        }
195                        None
196                    }
197                    2 => {
198                        // Two params: Problem<G, W> - return second
199                        Some(type_args[1].clone())
200                    }
201                    _ => None,
202                }
203            } else {
204                None
205            }
206        }
207        _ => None,
208    }
209}
210
211/// Get weight type name as a string for the variant.
212/// Single-letter uppercase names are treated as generic type parameters
213/// and default to "Unweighted" since they're not concrete types.
214fn get_weight_name(ty: &Type) -> String {
215    match ty {
216        Type::Path(type_path) => {
217            let name = type_path
218                .path
219                .segments
220                .last()
221                .map(|s| s.ident.to_string())
222                .unwrap_or_else(|| "Unweighted".to_string());
223            // Treat single uppercase letters as generic params, default to Unweighted
224            if name.len() == 1
225                && name
226                    .chars()
227                    .next()
228                    .map(|c| c.is_ascii_uppercase())
229                    .unwrap_or(false)
230            {
231                "Unweighted".to_string()
232            } else {
233                name
234            }
235        }
236        _ => "Unweighted".to_string(),
237    }
238}
239
240/// Generate the reduction entry code
241fn generate_reduction_entry(
242    attrs: &ReductionAttrs,
243    impl_block: &ItemImpl,
244) -> syn::Result<TokenStream2> {
245    // Extract the trait path (should be ReduceTo<Target>)
246    let trait_path = impl_block
247        .trait_
248        .as_ref()
249        .map(|(_, path, _)| path)
250        .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;
251
252    // Extract target type from ReduceTo<Target>
253    let target_type = extract_target_from_trait(trait_path)?;
254
255    // Extract source type (Self type)
256    let source_type = &impl_block.self_ty;
257
258    // Get type names
259    let source_name = extract_type_name(source_type)
260        .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
261    let target_name = extract_type_name(&target_type)
262        .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;
263
264    // Determine weight type names
265    let source_weight_name = attrs
266        .source_weighted
267        .map(|w| {
268            if w {
269                "i32".to_string()
270            } else {
271                "Unweighted".to_string()
272            }
273        })
274        .unwrap_or_else(|| {
275            extract_weight_type(source_type)
276                .map(|t| get_weight_name(&t))
277                .unwrap_or_else(|| "Unweighted".to_string())
278        });
279    let target_weight_name = attrs
280        .target_weighted
281        .map(|w| {
282            if w {
283                "i32".to_string()
284            } else {
285                "Unweighted".to_string()
286            }
287        })
288        .unwrap_or_else(|| {
289            extract_weight_type(&target_type)
290                .map(|t| get_weight_name(&t))
291                .unwrap_or_else(|| "Unweighted".to_string())
292        });
293
294    // Determine graph types
295    let source_graph = attrs
296        .source_graph
297        .clone()
298        .or_else(|| extract_graph_type(source_type))
299        .unwrap_or_else(|| "SimpleGraph".to_string());
300    let target_graph = attrs
301        .target_graph
302        .clone()
303        .or_else(|| extract_graph_type(&target_type))
304        .unwrap_or_else(|| "SimpleGraph".to_string());
305
306    // Generate overhead or use default
307    let overhead = attrs.overhead.clone().unwrap_or_else(|| {
308        quote! {
309            crate::rules::registry::ReductionOverhead::default()
310        }
311    });
312
313    // Generate the combined output
314    let output = quote! {
315        #impl_block
316
317        inventory::submit! {
318            crate::rules::registry::ReductionEntry {
319                source_name: #source_name,
320                target_name: #target_name,
321                source_variant: &[("graph", #source_graph), ("weight", #source_weight_name)],
322                target_variant: &[("graph", #target_graph), ("weight", #target_weight_name)],
323                overhead_fn: || { #overhead },
324                module_path: module_path!(),
325            }
326        }
327    };
328
329    Ok(output)
330}
331
332/// Extract the target type from ReduceTo<Target> trait path
333fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
334    let segment = path
335        .segments
336        .last()
337        .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
338
339    if segment.ident != "ReduceTo" {
340        return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
341    }
342
343    if let PathArguments::AngleBracketed(args) = &segment.arguments {
344        if let Some(GenericArgument::Type(ty)) = args.args.first() {
345            return Ok(ty.clone());
346        }
347    }
348
349    Err(syn::Error::new_spanned(
350        segment,
351        "Expected ReduceTo<Target> with type parameter",
352    ))
353}