1pub(crate) mod parser;
9
10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use std::collections::{HashMap, HashSet};
14use syn::{parse_macro_input, GenericArgument, ItemImpl, Path, PathArguments, Type};
15
16#[proc_macro_attribute]
42pub fn reduction(attr: TokenStream, item: TokenStream) -> TokenStream {
43 let attrs = parse_macro_input!(attr as ReductionAttrs);
44 let impl_block = parse_macro_input!(item as ItemImpl);
45
46 match generate_reduction_entry(&attrs, &impl_block) {
47 Ok(tokens) => tokens.into(),
48 Err(e) => e.to_compile_error().into(),
49 }
50}
51
/// How the `overhead = { ... }` payload of `#[reduction]` was written.
enum OverheadSpec {
    /// Legacy syntax: the raw brace contents, forwarded verbatim into the
    /// generated `overhead_fn` body (no `overhead_eval_fn` support).
    Legacy(TokenStream2),
    /// Parsed syntax: `(field name, expression string)` pairs, e.g.
    /// `n = "n + 1"`, later compiled via `parser::parse_expr`.
    Parsed(Vec<(String, String)>),
}
59
/// Arguments accepted by the `#[reduction(...)]` attribute.
struct ReductionAttrs {
    // `None` when no `overhead = { ... }` was supplied; rejected later in
    // `generate_reduction_entry`, which requires an overhead specification.
    overhead: Option<OverheadSpec>,
}
64
65impl syn::parse::Parse for ReductionAttrs {
66 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
67 let mut attrs = ReductionAttrs { overhead: None };
68
69 while !input.is_empty() {
70 let ident: syn::Ident = input.parse()?;
71 input.parse::<syn::Token![=]>()?;
72
73 match ident.to_string().as_str() {
74 "overhead" => {
75 let content;
76 syn::braced!(content in input);
77 attrs.overhead = Some(parse_overhead_content(&content)?);
78 }
79 _ => {
80 return Err(syn::Error::new(
81 ident.span(),
82 format!("unknown attribute: {}", ident),
83 ));
84 }
85 }
86
87 if input.peek(syn::Token![,]) {
88 input.parse::<syn::Token![,]>()?;
89 }
90 }
91
92 Ok(attrs)
93 }
94}
95
96fn parse_overhead_content(content: syn::parse::ParseStream) -> syn::Result<OverheadSpec> {
101 let fork = content.fork();
103
104 let is_new_syntax = fork.parse::<syn::Ident>().is_ok()
106 && fork.parse::<syn::Token![=]>().is_ok()
107 && fork.parse::<syn::LitStr>().is_ok();
108
109 if is_new_syntax {
110 let mut fields = Vec::new();
112 while !content.is_empty() {
113 let field_name: syn::Ident = content.parse()?;
114 content.parse::<syn::Token![=]>()?;
115 let expr_str: syn::LitStr = content.parse()?;
116 fields.push((field_name.to_string(), expr_str.value()));
117
118 if content.peek(syn::Token![,]) {
119 content.parse::<syn::Token![,]>()?;
120 }
121 }
122 Ok(OverheadSpec::Parsed(fields))
123 } else {
124 let tokens: TokenStream2 = content.parse()?;
126 Ok(OverheadSpec::Legacy(tokens))
127 }
128}
129
130fn extract_type_name(ty: &Type) -> Option<String> {
132 match ty {
133 Type::Path(type_path) => {
134 let segment = type_path.path.segments.last()?;
135 Some(segment.ident.to_string())
136 }
137 _ => None,
138 }
139}
140
141fn collect_type_generic_names(generics: &syn::Generics) -> HashSet<String> {
144 generics
145 .params
146 .iter()
147 .filter_map(|p| {
148 if let syn::GenericParam::Type(t) = p {
149 Some(t.ident.to_string())
150 } else {
151 None
152 }
153 })
154 .collect()
155}
156
157fn type_uses_type_generics(ty: &Type, type_generics: &HashSet<String>) -> bool {
159 match ty {
160 Type::Path(type_path) => {
161 if let Some(segment) = type_path.path.segments.last() {
162 if let PathArguments::AngleBracketed(args) = &segment.arguments {
163 for arg in args.args.iter() {
164 if let GenericArgument::Type(Type::Path(inner)) = arg {
165 if let Some(ident) = inner.path.get_ident() {
166 if type_generics.contains(&ident.to_string()) {
167 return true;
168 }
169 }
170 }
171 }
172 }
173 }
174 false
175 }
176 _ => false,
177 }
178}
179
180fn make_variant_fn_body(ty: &Type, type_generics: &HashSet<String>) -> syn::Result<TokenStream2> {
185 if type_uses_type_generics(ty, type_generics) {
186 let used: Vec<_> = type_generics.iter().cloned().collect();
187 return Err(syn::Error::new_spanned(
188 ty,
189 format!(
190 "#[reduction] does not support type generics (found: {}). \
191 Make the ReduceTo impl concrete by specifying explicit types.",
192 used.join(", ")
193 ),
194 ));
195 }
196 Ok(quote! { <#ty as crate::traits::Problem>::variant() })
197}
198
199fn generate_parsed_overhead(fields: &[(String, String)]) -> syn::Result<TokenStream2> {
203 let mut field_tokens = Vec::new();
204
205 for (field_name, expr_str) in fields {
206 let parsed = parser::parse_expr(expr_str).map_err(|e| {
207 syn::Error::new(
208 proc_macro2::Span::call_site(),
209 format!("error parsing overhead expression \"{expr_str}\": {e}"),
210 )
211 })?;
212
213 let expr_ast = parsed.to_expr_tokens();
214 let name_lit = field_name.as_str();
215 field_tokens.push(quote! { (#name_lit, #expr_ast) });
216 }
217
218 Ok(quote! {
219 crate::rules::registry::ReductionOverhead::new(vec![#(#field_tokens),*])
220 })
221}
222
223fn generate_overhead_eval_fn(
228 fields: &[(String, String)],
229 source_type: &Type,
230) -> syn::Result<TokenStream2> {
231 let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
232
233 let mut field_eval_tokens = Vec::new();
234 for (field_name, expr_str) in fields {
235 let parsed = parser::parse_expr(expr_str).map_err(|e| {
236 syn::Error::new(
237 proc_macro2::Span::call_site(),
238 format!("error parsing overhead expression \"{expr_str}\": {e}"),
239 )
240 })?;
241
242 let eval_tokens = parsed.to_eval_tokens(&src_ident);
243 let name_lit = field_name.as_str();
244 field_eval_tokens.push(quote! { (#name_lit, (#eval_tokens).round() as usize) });
245 }
246
247 Ok(quote! {
248 |__any_src: &dyn std::any::Any| -> crate::types::ProblemSize {
249 let #src_ident = __any_src.downcast_ref::<#source_type>().unwrap();
250 crate::types::ProblemSize::new(vec![#(#field_eval_tokens),*])
251 }
252 })
253}
254
/// Generates the full expansion of `#[reduction(...)]`: the original impl
/// block plus an `inventory::submit!` of a `ReductionEntry`, plus a
/// compile-time check that both problem types implement `DeclaredVariant`.
///
/// Errors when the impl is not `impl ReduceTo<Target> for Source`, when
/// either type name cannot be extracted, when the types use type generics,
/// or when no overhead specification was given.
fn generate_reduction_entry(
    attrs: &ReductionAttrs,
    impl_block: &ItemImpl,
) -> syn::Result<TokenStream2> {
    // The attribute only makes sense on a trait impl (`impl Trait for Type`).
    let trait_path = impl_block
        .trait_
        .as_ref()
        .map(|(_, path, _)| path)
        .ok_or_else(|| syn::Error::new_spanned(impl_block, "Expected impl ReduceTo<T> for S"))?;

    // `Target` from `ReduceTo<Target>`.
    let target_type = extract_target_from_trait(trait_path)?;

    // `Source` is the impl's self type.
    let source_type = &impl_block.self_ty;

    let source_name = extract_type_name(source_type)
        .ok_or_else(|| syn::Error::new_spanned(source_type, "Cannot extract source type name"))?;
    let target_name = extract_type_name(&target_type)
        .ok_or_else(|| syn::Error::new_spanned(&target_type, "Cannot extract target type name"))?;

    // Type generics declared on the impl; neither side may use them.
    let type_generics = collect_type_generic_names(&impl_block.generics);

    let source_variant_body = make_variant_fn_body(source_type, &type_generics)?;
    let target_variant_body = make_variant_fn_body(&target_type, &type_generics)?;

    // Legacy overhead tokens are forwarded as-is but get a panicking
    // `overhead_eval_fn`; parsed overhead gets a real evaluator closure.
    let (overhead, overhead_eval_fn) = match &attrs.overhead {
        Some(OverheadSpec::Legacy(tokens)) => {
            let eval_fn = quote! {
                |_: &dyn std::any::Any| -> crate::types::ProblemSize {
                    panic!("overhead_eval_fn not available for legacy overhead syntax; \
                    migrate to parsed syntax: field = \"expression\"")
                }
            };
            (tokens.clone(), eval_fn)
        }
        Some(OverheadSpec::Parsed(fields)) => {
            let overhead_tokens = generate_parsed_overhead(fields)?;
            let eval_fn = generate_overhead_eval_fn(fields, source_type)?;
            (overhead_tokens, eval_fn)
        }
        None => {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "Missing overhead specification. Use #[reduction(overhead = { ... })] and specify overhead expressions for all target problem size fields.",
            ));
        }
    };

    let output = quote! {
        // Re-emit the user's impl unchanged.
        #impl_block

        inventory::submit! {
            crate::rules::registry::ReductionEntry {
                source_name: #source_name,
                target_name: #target_name,
                source_variant_fn: || { #source_variant_body },
                target_variant_fn: || { #target_variant_body },
                overhead_fn: || { #overhead },
                module_path: module_path!(),
                // Type-erased entry point: downcast, then delegate to the
                // concrete `ReduceTo::reduce_to`.
                reduce_fn: |src: &dyn std::any::Any| -> Box<dyn crate::rules::traits::DynReductionResult> {
                    let src = src.downcast_ref::<#source_type>().unwrap_or_else(|| {
                        panic!(
                            "DynReductionResult: source type mismatch: expected `{}`, got `{}`",
                            std::any::type_name::<#source_type>(),
                            std::any::type_name_of_val(src),
                        )
                    });
                    Box::new(<#source_type as crate::rules::ReduceTo<#target_type>>::reduce_to(src))
                },
                overhead_eval_fn: #overhead_eval_fn,
            }
        }

        // Compile-time assertion: both problem types must have been declared
        // via `declare_variants!` (which implements `DeclaredVariant`).
        const _: () = {
            fn _assert_declared_variant<T: crate::traits::DeclaredVariant>() {}
            fn _check() {
                _assert_declared_variant::<#source_type>();
                _assert_declared_variant::<#target_type>();
            }
        };
    };

    Ok(output)
}
347
348fn extract_target_from_trait(path: &Path) -> syn::Result<Type> {
350 let segment = path
351 .segments
352 .last()
353 .ok_or_else(|| syn::Error::new_spanned(path, "Empty trait path"))?;
354
355 if segment.ident != "ReduceTo" {
356 return Err(syn::Error::new_spanned(segment, "Expected ReduceTo trait"));
357 }
358
359 if let PathArguments::AngleBracketed(args) = &segment.arguments {
360 if let Some(GenericArgument::Type(ty)) = args.args.first() {
361 return Ok(ty.clone());
362 }
363 }
364
365 Err(syn::Error::new_spanned(
366 segment,
367 "Expected ReduceTo<Target> with type parameter",
368 ))
369}
370
/// Which brute-force entry point the generated `solve_fn` calls.
#[derive(Debug, Clone, Copy)]
enum SolverKind {
    // Optimization problem: solved via `Solver::find_best`.
    Opt,
    // Satisfaction problem: solved via `Solver::find_satisfying`.
    Sat,
}
381
/// Parsed body of a `declare_variants! { ... }` invocation: one entry per
/// `[default] (opt|sat) Type => "complexity"` line.
struct DeclareVariantsInput {
    entries: Vec<DeclareVariantEntry>,
}
386
/// One `[default] (opt|sat) Type => "complexity"` entry.
struct DeclareVariantEntry {
    // Whether the entry carried the leading `default` marker.
    is_default: bool,
    // The `opt` / `sat` marker selecting the solve strategy.
    solver_kind: SolverKind,
    // The problem type this entry declares a variant for.
    ty: Type,
    // Complexity expression kept as the original string literal; parsed
    // later by `parser::parse_expr` in `generate_declare_variants`.
    complexity: syn::LitStr,
}
394
395impl syn::parse::Parse for DeclareVariantsInput {
396 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
397 let mut entries = Vec::new();
398 while !input.is_empty() {
399 let is_default = input.peek(syn::Token![default]);
401 if is_default {
402 input.parse::<syn::Token![default]>()?;
403 }
404
405 let solver_kind = if input.peek(syn::Ident) {
407 let fork = input.fork();
408 if let Ok(ident) = fork.parse::<syn::Ident>() {
409 match ident.to_string().as_str() {
410 "opt" => {
411 input.parse::<syn::Ident>()?; SolverKind::Opt
413 }
414 "sat" => {
415 input.parse::<syn::Ident>()?; SolverKind::Sat
417 }
418 _ => {
419 return Err(syn::Error::new(
420 ident.span(),
421 "expected `opt` or `sat` before type name",
422 ));
423 }
424 }
425 } else {
426 return Err(input.error("expected `opt` or `sat` before type name"));
427 }
428 } else {
429 return Err(input.error("expected `opt` or `sat` before type name"));
430 };
431
432 let ty: Type = input.parse()?;
433 input.parse::<syn::Token![=>]>()?;
434 let complexity: syn::LitStr = input.parse()?;
435 entries.push(DeclareVariantEntry {
436 is_default,
437 solver_kind,
438 ty,
439 complexity,
440 });
441
442 if input.peek(syn::Token![,]) {
443 input.parse::<syn::Token![,]>()?;
444 }
445 }
446 Ok(DeclareVariantsInput { entries })
447 }
448}
449
450#[proc_macro]
471pub fn declare_variants(input: TokenStream) -> TokenStream {
472 let input = parse_macro_input!(input as DeclareVariantsInput);
473 match generate_declare_variants(&input) {
474 Ok(tokens) => tokens.into(),
475 Err(e) => e.to_compile_error().into(),
476 }
477}
478
479fn generate_declare_variants(input: &DeclareVariantsInput) -> syn::Result<TokenStream2> {
481 let mut defaults_per_problem: HashMap<String, Vec<usize>> = HashMap::new();
484 let mut problem_names = HashSet::new();
485 for (i, entry) in input.entries.iter().enumerate() {
486 let base_name = extract_type_name(&entry.ty).unwrap_or_default();
487 problem_names.insert(base_name.clone());
488 if entry.is_default {
489 defaults_per_problem.entry(base_name).or_default().push(i);
490 }
491 }
492
493 for (name, indices) in &defaults_per_problem {
495 if indices.len() > 1 {
496 return Err(syn::Error::new(
497 proc_macro2::Span::call_site(),
498 format!(
499 "`{name}` has more than one default variant; \
500 only one entry per problem may be marked `default`"
501 ),
502 ));
503 }
504 }
505
506 for name in problem_names {
507 if !defaults_per_problem.contains_key(&name) {
508 return Err(syn::Error::new(
509 proc_macro2::Span::call_site(),
510 format!(
511 "`{name}` must declare exactly one default variant; \
512 mark one entry with `default`"
513 ),
514 ));
515 }
516 }
517
518 let mut output = TokenStream2::new();
519
520 for entry in &input.entries {
521 let ty = &entry.ty;
522 let complexity_str = entry.complexity.value();
523 let is_default = entry.is_default;
524
525 let parsed = parser::parse_expr(&complexity_str).map_err(|e| {
527 syn::Error::new(
528 entry.complexity.span(),
529 format!("invalid complexity expression \"{complexity_str}\": {e}"),
530 )
531 })?;
532
533 let vars = parsed.variables();
535 let validation = if vars.is_empty() {
536 quote! {}
537 } else {
538 let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
539 let getter_checks: Vec<_> = vars
540 .iter()
541 .map(|var| {
542 let getter = syn::Ident::new(var, proc_macro2::Span::call_site());
543 quote! { let _ = #src_ident.#getter(); }
544 })
545 .collect();
546
547 quote! {
548 const _: () = {
549 #[allow(unused)]
550 fn _validate_complexity(#src_ident: &#ty) {
551 #(#getter_checks)*
552 }
553 };
554 }
555 };
556
557 let complexity_eval_fn = generate_complexity_eval_fn(&parsed, ty)?;
559
560 let solve_body = match entry.solver_kind {
562 SolverKind::Opt => quote! {
563 let config = <crate::solvers::BruteForce as crate::solvers::Solver>::find_best(&solver, p)?;
564 },
565 SolverKind::Sat => quote! {
566 let config = <crate::solvers::BruteForce as crate::solvers::Solver>::find_satisfying(&solver, p)?;
567 },
568 };
569
570 let dispatch_fields = quote! {
571 factory: |data: serde_json::Value| -> Result<Box<dyn crate::registry::DynProblem>, serde_json::Error> {
572 let p: #ty = serde_json::from_value(data)?;
573 Ok(Box::new(p))
574 },
575 serialize_fn: |any: &dyn std::any::Any| -> Option<serde_json::Value> {
576 let p = any.downcast_ref::<#ty>()?;
577 Some(serde_json::to_value(p).expect("serialize failed"))
578 },
579 solve_fn: |any: &dyn std::any::Any| -> Option<(Vec<usize>, String)> {
580 let p = any.downcast_ref::<#ty>()?;
581 let solver = crate::solvers::BruteForce::new();
582 #solve_body
583 let evaluation = format!("{:?}", crate::traits::Problem::evaluate(p, &config));
584 Some((config, evaluation))
585 },
586 };
587
588 output.extend(quote! {
589 impl crate::traits::DeclaredVariant for #ty {}
590
591 crate::inventory::submit! {
592 crate::registry::VariantEntry {
593 name: <#ty as crate::traits::Problem>::NAME,
594 variant_fn: || <#ty as crate::traits::Problem>::variant(),
595 complexity: #complexity_str,
596 complexity_eval_fn: #complexity_eval_fn,
597 is_default: #is_default,
598 #dispatch_fields
599 }
600 }
601
602 #validation
603 });
604 }
605
606 Ok(output)
607}
608
609fn generate_complexity_eval_fn(
614 parsed: &parser::ParsedExpr,
615 ty: &Type,
616) -> syn::Result<TokenStream2> {
617 let src_ident = syn::Ident::new("__src", proc_macro2::Span::call_site());
618 let eval_tokens = parsed.to_eval_tokens(&src_ident);
619
620 Ok(quote! {
621 |__any_src: &dyn std::any::Any| -> f64 {
622 let #src_ident = __any_src.downcast_ref::<#ty>().unwrap();
623 #eval_tokens
624 }
625 })
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: a single entry marked `default` generates successfully.
    #[test]
    fn declare_variants_accepts_single_default() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        assert!(generate_declare_variants(&input).is_ok());
    }

    // An entry without `default` must be rejected with the
    // "exactly one default" diagnostic.
    // NOTE(review): this test is an exact duplicate of
    // `declare_variants_rejects_missing_default_marker` below — consider
    // removing one of the two.
    #[test]
    fn declare_variants_requires_one_default_per_problem() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            opt Foo => "1",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        assert!(
            err.to_string().contains("exactly one default"),
            "expected 'exactly one default' in error, got: {}",
            err
        );
    }

    // Two `default` markers for the same problem name must be rejected.
    #[test]
    fn declare_variants_rejects_multiple_defaults_for_one_problem() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
            default opt Foo => "2",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        assert!(
            err.to_string().contains("more than one default"),
            "expected 'more than one default' in error, got: {}",
            err
        );
    }

    // See the duplication note on
    // `declare_variants_requires_one_default_per_problem` above.
    #[test]
    fn declare_variants_rejects_missing_default_marker() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            opt Foo => "1",
        };
        let err = generate_declare_variants(&input).unwrap_err();
        assert!(
            err.to_string().contains("exactly one default"),
            "expected 'exactly one default' in error, got: {}",
            err
        );
    }

    // With one default and one non-default entry for the same problem,
    // exactly one generated `VariantEntry` has `is_default: true`.
    // The counts are matched on the token-stream rendering, hence the
    // `is_default : true` spacing.
    #[test]
    fn declare_variants_marks_only_explicit_default() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            opt Foo => "1",
            default opt Foo => "2",
        };
        let result = generate_declare_variants(&input);
        assert!(result.is_ok());
        let tokens = result.unwrap().to_string();
        let true_count = tokens.matches("is_default : true").count();
        let false_count = tokens.matches("is_default : false").count();
        assert_eq!(true_count, 1, "should have exactly one default");
        assert_eq!(false_count, 1, "should have exactly one non-default");
    }

    // Both `opt` and `sat` solver-kind markers parse and generate.
    #[test]
    fn declare_variants_accepts_solver_kind_markers() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
            default sat Bar => "2",
        };
        assert!(generate_declare_variants(&input).is_ok());
    }

    // An entry without an `opt`/`sat` marker is a parse error.
    #[test]
    fn declare_variants_rejects_missing_solver_kind() {
        let result = syn::parse_str::<DeclareVariantsInput>("Foo => \"1\"");
        assert!(
            result.is_err(),
            "expected parse error for missing solver kind"
        );
    }

    // `opt` entries generate a `find_best` call and populate all dispatch
    // fields with real closures (never `None`).
    #[test]
    fn declare_variants_generates_find_best_for_opt_entries() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"), "expected factory field");
        assert!(
            tokens.contains("serialize_fn :"),
            "expected serialize_fn field"
        );
        assert!(tokens.contains("solve_fn :"), "expected solve_fn field");
        assert!(
            !tokens.contains("factory : None"),
            "factory should not be None"
        );
        assert!(
            !tokens.contains("serialize_fn : None"),
            "serialize_fn should not be None"
        );
        assert!(
            !tokens.contains("solve_fn : None"),
            "solve_fn should not be None"
        );
        assert!(tokens.contains("find_best"), "expected find_best in tokens");
    }

    // `sat` entries generate a `find_satisfying` call instead.
    #[test]
    fn declare_variants_generates_find_satisfying_for_sat_entries() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default sat Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"), "expected factory field");
        assert!(
            tokens.contains("serialize_fn :"),
            "expected serialize_fn field"
        );
        assert!(tokens.contains("solve_fn :"), "expected solve_fn field");
        assert!(
            !tokens.contains("factory : None"),
            "factory should not be None"
        );
        assert!(
            !tokens.contains("serialize_fn : None"),
            "serialize_fn should not be None"
        );
        assert!(
            !tokens.contains("solve_fn : None"),
            "solve_fn should not be None"
        );
        assert!(
            tokens.contains("find_satisfying"),
            "expected find_satisfying in tokens"
        );
    }

    // Unknown `#[reduction(...)]` keys are rejected with a named diagnostic.
    #[test]
    fn reduction_rejects_unexpected_attribute() {
        let extra_attr = syn::Ident::new("extra", proc_macro2::Span::call_site());
        let parse_result = syn::parse2::<ReductionAttrs>(quote! {
            #extra_attr = "unexpected", overhead = { num_vertices = "num_vertices" }
        });
        let err = match parse_result {
            Ok(_) => panic!("unexpected reduction attribute should be rejected"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unknown attribute: extra"));
    }

    // `overhead = { ... }` parses into `Some(OverheadSpec)`.
    #[test]
    fn reduction_accepts_overhead_attribute() {
        let attrs: ReductionAttrs = syn::parse_quote! {
            overhead = { n = "n" }
        };
        assert!(attrs.overhead.is_some());
    }

    // The generated dispatch fields are always concrete closures.
    #[test]
    fn declare_variants_codegen_uses_required_dispatch_fields() {
        let input: DeclareVariantsInput = syn::parse_quote! {
            default opt Foo => "1",
        };
        let tokens = generate_declare_variants(&input).unwrap().to_string();
        assert!(tokens.contains("factory :"));
        assert!(tokens.contains("serialize_fn :"));
        assert!(tokens.contains("solve_fn :"));
        assert!(!tokens.contains("factory : None"));
        assert!(!tokens.contains("serialize_fn : None"));
        assert!(!tokens.contains("solve_fn : None"));
    }
}