Skip to main content

problemreductions_macros/
lib.rs

1//! Procedural macros for problemreductions.
2//!
3//! This crate provides the `#[reduction]` attribute macro that automatically
4//! generates `ReductionEntry` registrations from `ReduceTo` impl blocks,
5//! and the `declare_variants!` proc macro for compile-time validated variant
6//! registration.
7
8pub(crate) mod parser;
9
10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use std::collections::{HashMap, HashSet};
14use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
15
16/// Attribute macro for automatic reduction registration.
17///
18/// Parses a `ReduceTo` impl block and generates the corresponding `inventory::submit!`
19/// call. Variant fields are derived from `Problem::variant()`.
20///
21/// **Type generics are not supported** — all `ReduceTo` impls must use concrete types.
22/// If you need a reduction for a generic problem, write separate impls for each concrete
23/// type combination.
24///
25/// # Attributes
26///
27/// - `overhead = { expr }` — overhead specification
28///
29/// ## New syntax (preferred):
30/// ```ignore
31/// #[reduction(overhead = {
32///     num_vars = "num_vertices^2",
33///     num_constraints = "num_edges",
34/// })]
35/// ```
36///
37/// ## Legacy syntax (still supported):
38/// ```ignore
39/// #[reduction(overhead = { ReductionOverhead::new(vec![...]) })]
40/// ```
41#[proc_macro_attribute]
42pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
43    let attrs = parse_macro_input!(attr as ReductionAttrs);
44    let impl_block = parse_macro_input!(item as ItemImpl);
45
46    match generate_reduction_entry(&attrs, &impl_block) {
47        Ok(tokens) => tokens.into(),
48        Err(e) => e.to_compile_error().into(),
49    }
50}
51
52/// Overhead specification: either new parsed syntax or legacy raw tokens.
53enum OverheadSpec {
54    /// Legacy syntax: raw token stream (e.g., `ReductionOverhead::new(...)`)
55    Legacy(TokenStream2),
56    /// New syntax: list of (field_name, expression_string) pairs
57    Parsed(Vec<(String, String)>),
58}
59
60/// Parsed attributes from #[reduction(...)]
61struct ReductionAttrs {
62    overhead: Option<OverheadSpec>,
63}
64
65impl syn::parse::Parse for ReductionAttrs {
66    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
67        let mut attrs = ReductionAttrs { overhead: None };
68
69        while !input.is_empty() {
70            let ident: syn::Ident = input.parse()?;
71            input.parse::<syn::Token![=]>()?;
72
73            match ident.to_string().as_str() {
74                "overhead" => {
75                    let content;
76                    syn::braced!(content in input);
77                    attrs.overhead = Some(parse_overhead_content(&content)?);
78                }
79                _ => {
80                    return Err(syn::Error::new(
81                        ident.span(),
82                        format!("unknown attribute: {}", ident),
83                    ));
84                }
85            }
86
87            if input.peek(syn::Token![,]) {
88                input.parse::<syn::Token![,]>()?;
89            }
90        }
91
92        Ok(attrs)
93    }
94}
95
96/// Detect and parse the overhead content as either new or legacy syntax.
97///
98/// New syntax detection: the first tokens are `ident = "string_literal"`.
99/// Legacy syntax: everything else (starts with a path like `ReductionOverhead::...`).
100fn parse_overhead_content(content: syn::parse::ParseStream) -> syn::Result<OverheadSpec> {
101    // Fork to peek ahead without consuming
102    let fork = content.fork();
103
104    // Try to detect new syntax: ident = "string"
105    let is_new_syntax = fork.parse::<syn::Ident>().is_ok()
106        && fork.parse::<syn::Token![=]>().is_ok()
107        && fork.parse::<syn::LitStr>().is_ok();
108
109    if is_new_syntax {
110        // Parse new syntax: field_name = "expression", ...
111        let mut fields = Vec::new();
112        while !content.is_empty() {
113            let field_name: syn::Ident = content.parse()?;
114            content.parse::<syn::Token![=]>()?;
115            let expr_str: syn::LitStr = content.parse()?;
116            fields.push((field_name.to_string(), expr_str.value()));
117
118            if content.peek(syn::Token![,]) {
119                content.parse::<syn::Token![,]>()?;
120            }
121        }
122        Ok(OverheadSpec::Parsed(fields))
123    } else {
124        // Legacy syntax: parse as raw token stream
125        let tokens: TokenStream2 = content.parse()?;
126        Ok(OverheadSpec::Legacy(tokens))
127    }
128}
129
130/// Extract the base type name from a Type (e.g., "IndependentSet" from "IndependentSet<i32>")
131fn extract_type_name(ty: &Type) -> Option<String> {
132    match ty {
133        Type::Path(type_path) => {
134            let segment = type_path.path.segments.last()?;
135            Some(segment.ident.to_string())
136        }
137        _ => None,
138    }
139}
140
141/// Collect type generic parameter names from impl generics.
142/// e.g., `impl<G: Graph, W: NumericSize>` → {"G", "W"}
143fn collect_type_generic_names(generics: &syn::Generics) -> HashSet<String> {
144    generics
145        .params
146        .iter()
147        .filter_map(|p| {
148            if let syn::GenericParam::Type(t) = p {
149                Some(t.ident.to_string())
150            } else {
151                None
152            }
153        })
154        .collect()
155}
156
157/// Check if a type uses any of the given type generic parameters.
158fn type_uses_type_generics(ty: &Type, type_generics: &HashSet<String>) -> bool {
159    match ty {
160        Type::Path(type_path) => {
161            if let Some(segment) = type_path.path.segments.last() {
162                if let PathArguments::AngleBracketed(args) = &segment.arguments {
163                    for arg in args.args.iter() {
164                        if let GenericArgument::Type(Type::Path(inner)) = arg {
165                            if let Some(ident) = inner.path.get_ident() {
166                                if type_generics.contains(&ident.to_string()) {
167                                    return true;
168                                }
169                            }
170                        }
171                    }
172                }
173            }
174            false
175        }
176        _ => false,
177    }
178}
179
180/// Generate the variant fn body for a type.
181///
182/// Calls `Problem::variant()` on the concrete type.
183/// Errors if the type uses any type generics — all `ReduceTo` impls must be concrete.
184fn make_variant_fn_body(ty: &Type, type_generics: &HashSet<String>) -> syn::Result<TokenStream2> {
185    if type_uses_type_generics(ty, type_generics) {
186        let used: Vec<_> = type_generics.iter().cloned().collect();
187        return Err(syn::Error::new_spanned(
188            ty,
189            format!(
190                "#[reduction] does not support type generics (found: {}). \
191                 Make the ReduceTo impl concrete by specifying explicit types.",
192                used.join(", ")
193            ),
194        ));
195    }
196    Ok(quote! { <#ty as crate::traits::Problem>::variant() })
197}
198
199/// Generate overhead code from the new parsed syntax.
200///
201/// Produces a `ReductionOverhead` constructor that uses `Expr` AST values.
202fn generate_parsed_overhead(fields: &[(String, String)]) -> syn::Result<TokenStream2> {
203    let mut field_tokens = Vec::new();
204
205    for (field_name, expr_str) in fields {
206        let parsed = parser::parse_expr(expr_str).map_err(|e| {
207            syn::Error::new(
208                proc_macro2::Span::call_site(),
209                format!("error parsing overhead expression \"{expr_str}\": {e}"),
210            )
211        })?;
212
213        let expr_ast = parsed.to_expr_tokens();
214        let name_lit = field_name.as_str();
215        field_tokens.push(quote! { (#name_lit, #expr_ast) });
216    }
217
218    Ok(quote! {
219        crate::rules::registry::ReductionOverhead::new(vec![#(#field_tokens),*])
220    })
221}
222
223/// Generate a compiled overhead evaluation function from parsed overhead fields.
224///
225/// Produces a closure that downcasts `&dyn Any` to `&SourceType`, calls getter methods
226/// for each variable in the expressions, and returns a `ProblemSize`.
227fn generate_overhead_eval_fn(
228    fields: &[(String, String)],
229    source_type: &Type,
230) -> syn::Result<TokenStream2> {
231    let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
232
233    let mut field_eval_tokens = Vec::new();
234    for (field_name, expr_str) in fields {
235        let parsed = parser::parse_expr(expr_str).map_err(|e| {
236            syn::Error::new(
237                proc_macro2::Span::call_site(),
238                format!("error parsing overhead expression \"{expr_str}\": {e}"),
239            )
240        })?;
241
242        let eval_tokens = parsed.to_eval_tokens(&src_ident);
243        let name_lit = field_name.as_str();
244        field_eval_tokens.push(quote! { (#name_lit, (#eval_tokens).round() as usize) });
245    }
246
247    Ok(quote! {
248        |__any_src: &dyn std::any::Any| -> crate::types::ProblemSize {
249            let #src_ident = __any_src.downcast_ref::<#source_type>().unwrap();
250            crate::types::ProblemSize::new(vec![#(#field_eval_tokens),*])
251        }
252    })
253}
254
255/// Generate the reduction entry code
256fn generate_reduction_entry(
257    attrs: &ReductionAttrs,
258    impl_block: &ItemImpl,
259) -> syn::Result<TokenStream2> {
260    // Extract the trait path (should be ReduceTo<Target>)
261    let trait_path = impl_block
262        .trait_
263        .as_ref()
264        .map(|(_, path, _)| path)
265        .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;
266
267    // Extract target type from ReduceTo<Target>
268    let target_type = extract_target_from_trait(trait_path)?;
269
270    // Extract source type (Self type)
271    let source_type = &impl_block.self_ty;
272
273    // Get type names
274    let source_name = extract_type_name(source_type)
275        .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
276    let target_name = extract_type_name(&target_type)
277        .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;
278
279    // Collect generic parameter info from the impl block
280    let type_generics = collect_type_generic_names(&impl_block.generics);
281
282    // Generate variant fn bodies
283    let source_variant_body = make_variant_fn_body(source_type, &type_generics)?;
284    let target_variant_body = make_variant_fn_body(&target_type, &type_generics)?;
285
286    // Generate overhead and eval fn
287    let (overhead, overhead_eval_fn) = match &attrs.overhead {
288        Some(OverheadSpec::Legacy(tokens)) => {
289            let eval_fn = quote! {
290                |_: &dyn std::any::Any| -> crate::types::ProblemSize {
291                    panic!("overhead_eval_fn not available for legacy overhead syntax; \
292                            migrate to parsed syntax: field = \"expression\"")
293                }
294            };
295            (tokens.clone(), eval_fn)
296        }
297        Some(OverheadSpec::Parsed(fields)) => {
298            let overhead_tokens = generate_parsed_overhead(fields)?;
299            let eval_fn = generate_overhead_eval_fn(fields, source_type)?;
300            (overhead_tokens, eval_fn)
301        }
302        None => {
303            return Err(syn::Error::new(
304                proc_macro2::Span::call_site(),
305                "Missing overhead specification. Use #[reduction(overhead = { ... })] and specify overhead expressions for all target problem size fields.",
306            ));
307        }
308    };
309
310    // Generate the combined output
311    let output = quote! {
312        #impl_block
313
314        inventory::submit! {
315            crate::rules::registry::ReductionEntry {
316                source_name: #source_name,
317                target_name: #target_name,
318                source_variant_fn: || { #source_variant_body },
319                target_variant_fn: || { #target_variant_body },
320                overhead_fn: || { #overhead },
321                module_path: module_path!(),
322                reduce_fn: |src: &dyn std::any::Any| -> Box<dyn crate::rules::traits::DynReductionResult> {
323                    let src = src.downcast_ref::<#source_type>().unwrap_or_else(|| {
324                        panic!(
325                            "DynReductionResult: source type mismatch: expected `{}`, got `{}`",
326                            std::any::type_name::<#source_type>(),
327                            std::any::type_name_of_val(src),
328                        )
329                    });
330                    Box::new(<#source_type as crate::rules::ReduceTo<#target_type>>::reduce_to(src))
331                },
332                overhead_eval_fn: #overhead_eval_fn,
333            }
334        }
335
336        const _: () = {
337            fn _assert_declared_variant<T: crate::traits::DeclaredVariant>() {}
338            fn _check() {
339                _assert_declared_variant::<#source_type>();
340                _assert_declared_variant::<#target_type>();
341            }
342        };
343    };
344
345    Ok(output)
346}
347
348/// Extract the target type from ReduceTo<Target> trait path
349fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
350    let segment = path
351        .segments
352        .last()
353        .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
354
355    if segment.ident != "ReduceTo" {
356        return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
357    }
358
359    if let PathArguments::AngleBracketed(args) = &segment.arguments {
360        if let Some(GenericArgument::Type(ty)) = args.args.first() {
361            return Ok(ty.clone());
362        }
363    }
364
365    Err(syn::Error::new_spanned(
366        segment,
367        "Expected ReduceTo<Target> with type parameter",
368    ))
369}
370
371// --- declare_variants! proc macro ---
372
/// Which brute-force search the generated `solve_fn` runs for a declared variant.
#[derive(Clone, Copy, Debug)]
enum SolverKind {
    /// `opt` entries: optimization problems, solved with `find_best`.
    Opt,
    /// `sat` entries: satisfaction problems, solved with `find_satisfying`.
    Sat,
}
381
382/// Input for the `declare_variants!` proc macro.
383struct DeclareVariantsInput {
384    entries: Vec<DeclareVariantEntry>,
385}
386
387/// A single entry: `[default] opt|sat Type => "complexity_string"`.
388struct DeclareVariantEntry {
389    is_default: bool,
390    solver_kind: SolverKind,
391    ty: Type,
392    complexity: syn::LitStr,
393}
394
395impl syn::parse::Parse for DeclareVariantsInput {
396    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
397        let mut entries = Vec::new();
398        while !input.is_empty() {
399            // Optionally accept a `default` keyword before the type
400            let is_default = input.peek(syn::Token![default]);
401            if is_default {
402                input.parse::<syn::Token![default]>()?;
403            }
404
405            // Require `opt` or `sat` keyword
406            let solver_kind = if input.peek(syn::Ident) {
407                let fork = input.fork();
408                if let Ok(ident) = fork.parse::<syn::Ident>() {
409                    match ident.to_string().as_str() {
410                        "opt" => {
411                            input.parse::<syn::Ident>()?; // consume
412                            SolverKind::Opt
413                        }
414                        "sat" => {
415                            input.parse::<syn::Ident>()?; // consume
416                            SolverKind::Sat
417                        }
418                        _ => {
419                            return Err(syn::Error::new(
420                                ident.span(),
421                                "expected `opt` or `sat` before type name",
422                            ));
423                        }
424                    }
425                } else {
426                    return Err(input.error("expected `opt` or `sat` before type name"));
427                }
428            } else {
429                return Err(input.error("expected `opt` or `sat` before type name"));
430            };
431
432            let ty: Type = input.parse()?;
433            input.parse::<syn::Token![=>]>()?;
434            let complexity: syn::LitStr = input.parse()?;
435            entries.push(DeclareVariantEntry {
436                is_default,
437                solver_kind,
438                ty,
439                complexity,
440            });
441
442            if input.peek(syn::Token![,]) {
443                input.parse::<syn::Token![,]>()?;
444            }
445        }
446        Ok(DeclareVariantsInput { entries })
447    }
448}
449
450/// Declare explicit problem variants with per-variant complexity metadata.
451///
452/// Each entry generates:
453/// 1. A `DeclaredVariant` trait impl for compile-time checking
454/// 2. A `VariantEntry` inventory submission for runtime graph building
455/// 3. A compiled `complexity_eval_fn` that calls getter methods
456/// 4. A const validation block verifying all variable names are valid getters
457///
458/// Complexity strings must use only numeric literals and getter method names.
459/// Mathematical constants (epsilon, omega, etc.) should be inlined as numbers
460/// and documented in comments or docstrings.
461///
462/// # Example
463///
464/// ```ignore
465/// declare_variants! {
466///     MaximumIndependentSet<SimpleGraph, i32>   => "1.1996^num_vertices",
467///     MaximumIndependentSet<KingsSubgraph, i32> => "2^sqrt(num_vertices)",
468/// }
469/// ```
470#[proc_macro]
471pub fn declare_variants(input: TokenStream) -> TokenStream {
472    let input = parse_macro_input!(input as DeclareVariantsInput);
473    match generate_declare_variants(&input) {
474        Ok(tokens) => tokens.into(),
475        Err(e) => e.to_compile_error().into(),
476    }
477}
478
479/// Generate code for all `declare_variants!` entries.
480fn generate_declare_variants(input: &DeclareVariantsInput) -> syn::Result<TokenStream2> {
481    // Validate default markers per problem name.
482    // Group entries by their base type name (e.g., "MaximumIndependentSet").
483    let mut defaults_per_problem: HashMap<String, Vec<usize>> = HashMap::new();
484    let mut problem_names = HashSet::new();
485    for (i, entry) in input.entries.iter().enumerate() {
486        let base_name = extract_type_name(&entry.ty).unwrap_or_default();
487        problem_names.insert(base_name.clone());
488        if entry.is_default {
489            defaults_per_problem.entry(base_name).or_default().push(i);
490        }
491    }
492
493    // Check for multiple defaults for the same problem
494    for (name, indices) in &defaults_per_problem {
495        if indices.len() > 1 {
496            return Err(syn::Error::new(
497                proc_macro2::Span::call_site(),
498                format!(
499                    "`{name}` has more than one default variant; \
500                     only one entry per problem may be marked `default`"
501                ),
502            ));
503        }
504    }
505
506    for name in problem_names {
507        if !defaults_per_problem.contains_key(&name) {
508            return Err(syn::Error::new(
509                proc_macro2::Span::call_site(),
510                format!(
511                    "`{name}` must declare exactly one default variant; \
512                     mark one entry with `default`"
513                ),
514            ));
515        }
516    }
517
518    let mut output = TokenStream2::new();
519
520    for entry in &input.entries {
521        let ty = &entry.ty;
522        let complexity_str = entry.complexity.value();
523        let is_default = entry.is_default;
524
525        // Parse the complexity expression to validate syntax
526        let parsed = parser::parse_expr(&complexity_str).map_err(|e| {
527            syn::Error::new(
528                entry.complexity.span(),
529                format!("invalid complexity expression \"{complexity_str}\": {e}"),
530            )
531        })?;
532
533        // Generate getter validation for all variables
534        let vars = parsed.variables();
535        let validation = if vars.is_empty() {
536            quote! {}
537        } else {
538            let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
539            let getter_checks: Vec<_> = vars
540                .iter()
541                .map(|var| {
542                    let getter = syn::Ident::new(var, proc_macro2::Span::call_site());
543                    quote! { let _ = #src_ident.#getter(); }
544                })
545                .collect();
546
547            quote! {
548                const _: () = {
549                    #[allow(unused)]
550                    fn _validate_complexity(#src_ident: &#ty) {
551                        #(#getter_checks)*
552                    }
553                };
554            }
555        };
556
557        // Generate compiled complexity eval fn
558        let complexity_eval_fn = generate_complexity_eval_fn(&parsed, ty)?;
559
560        // Generate dispatch fields based on solver kind
561        let solve_body = match entry.solver_kind {
562            SolverKind::Opt => quote! {
563                let config = <crate::solvers::BruteForce as crate::solvers::Solver>::find_best(&solver, p)?;
564            },
565            SolverKind::Sat => quote! {
566                let config = <crate::solvers::BruteForce as crate::solvers::Solver>::find_satisfying(&solver, p)?;
567            },
568        };
569
570        let dispatch_fields = quote! {
571            factory: |data: serde_json::Value| -> Result<Box<dyn crate::registry::DynProblem>, serde_json::Error> {
572                let p: #ty = serde_json::from_value(data)?;
573                Ok(Box::new(p))
574            },
575            serialize_fn: |any: &dyn std::any::Any| -> Option<serde_json::Value> {
576                let p = any.downcast_ref::<#ty>()?;
577                Some(serde_json::to_value(p).expect("serialize failed"))
578            },
579            solve_fn: |any: &dyn std::any::Any| -> Option<(Vec<usize>, String)> {
580                let p = any.downcast_ref::<#ty>()?;
581                let solver = crate::solvers::BruteForce::new();
582                #solve_body
583                let evaluation = format!("{:?}", crate::traits::Problem::evaluate(p, &config));
584                Some((config, evaluation))
585            },
586        };
587
588        output.extend(quote! {
589            impl crate::traits::DeclaredVariant for #ty {}
590
591            crate::inventory::submit! {
592                crate::registry::VariantEntry {
593                    name: <#ty as crate::traits::Problem>::NAME,
594                    variant_fn: || <#ty as crate::traits::Problem>::variant(),
595                    complexity: #complexity_str,
596                    complexity_eval_fn: #complexity_eval_fn,
597                    is_default: #is_default,
598                    #dispatch_fields
599                }
600            }
601
602            #validation
603        });
604    }
605
606    Ok(output)
607}
608
609/// Generate a compiled complexity evaluation function.
610///
611/// Produces a closure that downcasts `&dyn Any` to the problem type, calls getter
612/// methods for all variables, and returns the worst-case time complexity as f64.
613fn generate_complexity_eval_fn(
614    parsed: &parser::ParsedExpr,
615    ty: &Type,
616) -> syn::Result<TokenStream2> {
617    let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
618    let eval_tokens = parsed.to_eval_tokens(&src_ident);
619
620    Ok(quote! {
621        |__any_src: &dyn std::any::Any| -> f64 {
622            let #src_ident = __any_src.downcast_ref::<#ty>().unwrap();
623            #eval_tokens
624        }
625    })
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    // ---- declare_variants! ----

    #[test]
    fn declare_variants_accepts_single_default() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        assert!(generate_declare_variants(&input).is_ok());
    }

    #[test]
    fn declare_variants_requires_one_default_per_problem() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            opt Foo => "1",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        assert!(
            err.to_string().contains("exactly one default"),
            "expected 'exactly one default' in error, got: {}",
            err
        );
    }

    #[test]
    fn declare_variants_rejects_multiple_defaults_for_one_problem() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
            default opt Foo => "2",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        assert!(
            err.to_string().contains("more than one default"),
            "expected 'more than one default' in error, got: {}",
            err
        );
    }

    #[test]
    fn declare_variants_rejects_missing_default_marker() {
        // `Foo` has a default but `Bar` does not; the error must name `Bar`.
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
            opt Bar => "2",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("exactly one default"),
            "expected 'exactly one default' in error, got: {}",
            message
        );
        assert!(
            message.contains("`Bar`"),
            "expected the error to name `Bar`, got: {}",
            message
        );
    }

    #[test]
    fn declare_variants_marks_only_explicit_default() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            opt Foo => "1",
            default opt Foo => "2",
        };
        let result = generate_declare_variants(&input);
        assert!(result.is_ok());
        let tokens = result.unwrap().to_string();
        let true_count = tokens.matches("is_default : true").count();
        let false_count = tokens.matches("is_default : false").count();
        assert_eq!(true_count, 1, "should have exactly one default");
        assert_eq!(false_count, 1, "should have exactly one non-default");
    }

    #[test]
    fn declare_variants_accepts_solver_kind_markers() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
            default sat Bar => "2",
        };
        assert!(generate_declare_variants(&input).is_ok());
    }

    #[test]
    fn declare_variants_rejects_missing_solver_kind() {
        let result = syn::parse_str::<DeclareVariantsInput>("Foo => \"1\"");
        assert!(
            result.is_err(),
            "expected parse error for missing solver kind"
        );
    }

    #[test]
    fn declare_variants_rejects_unknown_solver_kind() {
        let err = syn::parse_str::<DeclareVariantsInput>("default fast Foo => \"1\"")
            .err()
            .expect("expected parse error for unknown solver kind");
        assert!(
            err.to_string().contains("expected `opt` or `sat`"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn declare_variants_generates_find_best_for_opt_entries() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"), "expected factory field");
        assert!(
            tokens.contains("serialize_fn :"),
            "expected serialize_fn field"
        );
        assert!(tokens.contains("solve_fn :"), "expected solve_fn field");
        assert!(
            !tokens.contains("factory : None"),
            "factory should not be None"
        );
        assert!(
            !tokens.contains("serialize_fn : None"),
            "serialize_fn should not be None"
        );
        assert!(
            !tokens.contains("solve_fn : None"),
            "solve_fn should not be None"
        );
        assert!(tokens.contains("find_best"), "expected find_best in tokens");
    }

    #[test]
    fn declare_variants_generates_find_satisfying_for_sat_entries() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default sat Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"), "expected factory field");
        assert!(
            tokens.contains("serialize_fn :"),
            "expected serialize_fn field"
        );
        assert!(tokens.contains("solve_fn :"), "expected solve_fn field");
        assert!(
            !tokens.contains("factory : None"),
            "factory should not be None"
        );
        assert!(
            !tokens.contains("serialize_fn : None"),
            "serialize_fn should not be None"
        );
        assert!(
            !tokens.contains("solve_fn : None"),
            "solve_fn should not be None"
        );
        assert!(
            tokens.contains("find_satisfying"),
            "expected find_satisfying in tokens"
        );
    }

    #[test]
    fn declare_variants_codegen_uses_required_dispatch_fields() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"));
        assert!(tokens.contains("serialize_fn :"));
        assert!(tokens.contains("solve_fn :"));
        assert!(!tokens.contains("factory : None"));
        assert!(!tokens.contains("serialize_fn : None"));
        assert!(!tokens.contains("solve_fn : None"));
    }

    // ---- #[reduction] ----

    #[test]
    fn reduction_rejects_unexpected_attribute() {
        let extra_attr = syn::Ident::new("extra", proc_macro2::Span::call_site());
        let parse_result = syn::parse2::<ReductionAttrs>(quote! {
            #extra_attr = "unexpected", overhead = { num_vertices = "num_vertices" }
        });
        let err = match parse_result {
            Ok(_) => panic!("unexpected reduction attribute should be rejected"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unknown attribute: extra"));
    }

    #[test]
    fn reduction_accepts_overhead_attribute() {
        let attrs: ReductionAttrs = syn::parse_quote! {
            overhead = { n = "n" }
        };
        assert!(attrs.overhead.is_some());
    }

    #[test]
    fn reduction_parses_new_overhead_syntax() {
        let attrs: ReductionAttrs = syn::parse_quote! {
            overhead = { num_vars = "num_vertices^2", num_constraints = "num_edges" }
        };
        match attrs.overhead {
            Some(OverheadSpec::Parsed(fields)) => assert_eq!(
                fields,
                vec![
                    ("num_vars".to_string(), "num_vertices^2".to_string()),
                    ("num_constraints".to_string(), "num_edges".to_string()),
                ]
            ),
            _ => panic!("expected parsed overhead syntax"),
        }
    }

    #[test]
    fn reduction_keeps_legacy_overhead_tokens() {
        let attrs: ReductionAttrs = syn::parse_quote! {
            overhead = { ReductionOverhead::new(vec![]) }
        };
        assert!(matches!(attrs.overhead, Some(OverheadSpec::Legacy(_))));
    }

    #[test]
    fn reduction_rejects_type_generics() {
        let attrs: ReductionAttrs = syn::parse_quote! { overhead = { n = "n" } };
        let impl_block: ItemImpl = syn::parse_quote! {
            impl<G> ReduceTo<Foo<G>> for Bar<G> {}
        };
        let err = generate_reduction_entry(&attrs, &impl_block).unwrap_err();
        assert!(
            err.to_string().contains("does not support type generics"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn reduction_requires_reduce_to_trait() {
        let attrs: ReductionAttrs = syn::parse_quote! { overhead = { n = "n" } };

        let inherent: ItemImpl = syn::parse_quote! { impl Bar {} };
        let err = generate_reduction_entry(&attrs, &inherent).unwrap_err();
        assert!(
            err.to_string().contains("Expected impl ReduceTo<T> for S"),
            "unexpected error: {}",
            err
        );

        let other_trait: ItemImpl = syn::parse_quote! { impl Clone for Bar {} };
        let err = generate_reduction_entry(&attrs, &other_trait).unwrap_err();
        assert!(
            err.to_string().contains("Expected ReduceTo trait"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn reduction_requires_overhead() {
        let attrs = ReductionAttrs { overhead: None };
        let impl_block: ItemImpl = syn::parse_quote! { impl ReduceTo<Foo> for Bar {} };
        let err = generate_reduction_entry(&attrs, &impl_block).unwrap_err();
        assert!(
            err.to_string().contains("Missing overhead specification"),
            "unexpected error: {}",
            err
        );
    }
}