aoc_zen_runner_macros/
lib.rs

1use aggregate::{discover_mod_contents, AocSolutionsAggregation};
2use anyhow::Context;
3use cargo_metadata::MetadataCommand;
4use domain::{AocGeneratorData, AocSolverData};
5use parser::caseargs::AocCaseArgs;
6use parser::macroargs::AocMacroArgs;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use quote::quote;
10use quote::ToTokens;
11use syn::parse_macro_input;
12use syn::Type;
13use syn::{Ident, ItemMod};
14use syn::{ItemConst, ItemFn};
15
16mod aggregate;
17mod domain;
18mod parser;
19mod partflag;
20
21// Flag macros ------------------------------------------------------------
22#[proc_macro_attribute]
23pub fn generator(_attr: TokenStream, item: TokenStream) -> TokenStream {
24    item
25}
26
27#[proc_macro_attribute]
28pub fn solver(_attr: TokenStream, item: TokenStream) -> TokenStream {
29    item
30}
31
32#[proc_macro_attribute]
33pub fn solution(_attr: TokenStream, item: TokenStream) -> TokenStream {
34    item
35}
36
37#[proc_macro_attribute]
38pub fn flag(_attr: TokenStream, item: TokenStream) -> TokenStream {
39    let item = parse_macro_input!(item as ItemFn);
40    println!("*** Flagged item in {}:\n{:#?}", file!(), &item);
41    proc_macro::TokenStream::from(item.into_token_stream())
42}
43
44// Test-related macros ----------------------------------------------------
45#[proc_macro_attribute]
46pub fn aoc_case(attr: TokenStream, item: TokenStream) -> TokenStream {
47    let args = parse_macro_input!(attr as AocCaseArgs);
48    let exp_p1 = args.expected_p1;
49    let p2 = args.expected_p2;
50    let input = parse_macro_input!(item as ItemConst);
51    let in_name = &input.ident;
52    let slug_str: String = format!("aoc_test_{}", &input.ident.to_string().to_lowercase());
53    let slug = Ident::new(&slug_str, input.ident.span());
54
55    if let Some(exp_p2) = p2 {
56        quote! {
57            #input
58
59            #[test]
60            fn #slug() {
61                let expected_p1 = #exp_p1;
62                let expected_p2 = #exp_p2;
63
64                for (idx, p1) in super::_gen_lists::P1_SOLUTIONS.iter().enumerate() {
65                    let test_label = super::_gen_lists::P1_LABELS[idx];
66                    assert_eq!(expected_p1, p1(#in_name), "Part 1 Test failed solution: {}", test_label);
67                }
68                for (idx, p2) in super::_gen_lists::P2_SOLUTIONS.iter().enumerate() {
69                    let test_label = super::_gen_lists::P1_LABELS[idx];
70                    assert_eq!(expected_p2, p2(#in_name), "Part 2 Test failed solution: {}", test_label);
71                }
72            }
73        }
74        .into()
75    } else {
76        quote! {
77            #input
78
79            #[test]
80            fn #slug() {
81                let expected_p1 = #exp_p1;
82
83                for p1 in _gen_lists::P1_SOLUTIONS {
84                    assert_eq!(expected_p1, p1(#in_name));
85                }
86            }
87        }
88        .into()
89    }
90}
91
92// AOC --------------------------------------------------------------------
93
94#[proc_macro_attribute]
95pub fn aoc(args: TokenStream, item: TokenStream) -> TokenStream {
96    let item = parse_macro_input!(item as ItemMod);
97    let mod_name = &item.ident;
98
99    let macro_args = parse_macro_input!(args as AocMacroArgs);
100
101    let agg_result = match discover_mod_contents(&item) {
102        Ok(data) => data,
103        Err(e) => {
104            return e.into_compile_error().into();
105        }
106    };
107
108    let mod_extension = gen_solution_lists_mod(&agg_result, mod_name);
109
110    let mut item_ts = item.into_token_stream();
111
112    item_ts.extend(mod_extension);
113    item_ts.extend(gen_quick_microbench());
114    item_ts.extend(gen_slow_microbench());
115    item_ts.extend(gen_main(macro_args.year_num, macro_args.day_num));
116
117    item_ts.into()
118}
119
120fn gen_idents_from_solns<'a>(
121    part_indicator: &str,
122    solns: impl Iterator<Item = (&'a AocGeneratorData<'a>, &'a AocSolverData<'a>)>,
123) -> Vec<(&'a Ident, &'a Ident, Ident)> {
124    solns
125        .map(|(gen, sol)| {
126            let g_ident = &gen.source.sig.ident;
127            let g_slug = &gen.display_slug;
128            let s_ident = &sol.source.sig.ident;
129            let s_slug = &sol.display_slug;
130            let f_ident = Ident::new(
131                format!("f_{}_{}_{}", part_indicator, g_slug, s_slug).as_str(),
132                Span::call_site(),
133            );
134            (g_ident, s_ident, f_ident)
135        })
136        .collect()
137}
138
139fn gen_composed_labels<'a>(
140    solns: impl Iterator<Item = (&'a AocGeneratorData<'a>, &'a AocSolverData<'a>)>,
141) -> Vec<String> {
142    solns
143        .map(|(gen, sol)| {
144            let g_slug = &gen.display_slug.to_string();
145            let s_slug = &sol.display_slug.to_string();
146            let label = format!("{} / {}", g_slug, s_slug);
147            label
148        })
149        .collect()
150}
151
152fn gen_solution_lists_mod(agg_result: &AocSolutionsAggregation, mod_name: &Ident) -> proc_macro2::TokenStream {
153    let p1_composed_data: Vec<(&Ident, &Ident, Ident)> = gen_idents_from_solns("p1", agg_result.p1_composed_solns());
154
155    let p1_fn_idents: Vec<&Ident> = p1_composed_data.iter().map(|(_, _, f)| f).collect();
156    let p1_gen_idents: Vec<&Ident> = p1_composed_data.iter().map(|(g, _, _)| *g).collect();
157    let p1_solver_idents: Vec<&Ident> = p1_composed_data.iter().map(|(_, s, _)| *s).collect();
158
159    let mut p1_labels = gen_composed_labels(agg_result.p1_composed_solns());
160    let mut p1_impls = p1_fn_idents.clone();
161    p1_impls.extend(agg_result.p1_user_solns().map(|sln| &sln.source.sig.ident));
162    p1_labels.extend(agg_result.p1_user_solns().map(|sln| sln.display_slug.to_string()));
163    let p1_ret = agg_result
164        .p1_result_type
165        .unwrap_or(&Type::Verbatim(quote!(String)))
166        .to_owned();
167    let p1_len = p1_impls.len();
168
169    let p2_data: Vec<(&Ident, &Ident, Ident)> = gen_idents_from_solns("p2", agg_result.p2_composed_solns());
170
171    let p2_fn_idents: Vec<&Ident> = p2_data.iter().map(|(_, _, f)| f).collect();
172    let p2_gen_idents: Vec<&Ident> = p2_data.iter().map(|(g, _, _)| *g).collect();
173    let p2_solver_idents: Vec<&Ident> = p2_data.iter().map(|(_, s, _)| *s).collect();
174
175    let mut p2_labels = gen_composed_labels(agg_result.p2_composed_solns());
176    let mut p2_impls = p2_fn_idents.clone();
177    p2_labels.extend(agg_result.p2_user_solns().map(|sln| sln.display_slug.to_string()));
178    p2_impls.extend(agg_result.p2_user_solns().map(|sln| &sln.source.sig.ident));
179    let p2_ret = agg_result
180        .p2_result_type
181        .unwrap_or(&Type::Verbatim(quote!(String)))
182        .to_owned();
183    let p2_len = p2_impls.len();
184
185    quote! {
186        mod _gen_lists {
187            use super::#mod_name::*;
188            use std::fmt::Display;
189
190            pub const P1_LABELS: [&str; #p1_len] = [ #(#p1_labels),* ];
191            pub const P2_LABELS: [&str; #p2_len] = [ #(#p2_labels),* ];
192
193            #(pub fn #p1_fn_idents(input: &str) -> #p1_ret { #p1_solver_idents(#p1_gen_idents(input)) })*
194            #(pub fn #p2_fn_idents(input: &str) -> #p2_ret { #p2_solver_idents(#p2_gen_idents(input)) })*
195            pub const P1_SOLUTIONS: [for<'r> fn(&'r str) -> #p1_ret; #p1_len] = [ #(#p1_impls),* ];
196            pub const P2_SOLUTIONS: [for<'r> fn(&'r str) -> #p2_ret; #p2_len] = [ #(#p2_impls),* ];
197        }
198    }
199}
200
201fn gen_main(year_num: u32, day_num: u32) -> proc_macro2::TokenStream {
202    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
203    let meta_res = MetadataCommand::new()
204        .current_dir(manifest_dir)
205        .exec()
206        .context("Could not use cargo metadata to find inputs directory");
207    if meta_res.is_err() {
208        let err_str = meta_res.err().unwrap().to_string();
209        return quote! { compile_error!(#err_str) };
210    }
211    let meta = meta_res.unwrap();
212    let mut input_path = meta.workspace_root;
213    input_path.push("input");
214    input_path.push(year_num.to_string());
215    input_path.push(format!("{}.txt", day_num));
216
217    let input_file = input_path.as_str();
218
219    let input_blob = if input_path.exists() {
220        quote! { include_str!(#input_file) }
221    } else {
222        quote! { "" }
223    };
224
225    quote! {
226        const AOC_RAW_INPUT: &str = #input_blob;
227
228        #[cfg(not(test))]
229        fn main() {
230            println!("## AOC {}, Day {} ----------", #year_num, #day_num);
231            if AOC_RAW_INPUT.len() == 0 {
232                println!("No input found.");
233                return;
234            }
235            let p1len = _gen_lists::P1_SOLUTIONS.len();
236            let p2len = _gen_lists::P2_SOLUTIONS.len();
237            if p1len > 0 {
238                let solution_p1 = _gen_lists::P1_SOLUTIONS[0](AOC_RAW_INPUT);
239                let label = _gen_lists::P1_LABELS[0];
240                println!("Part 1, {} Solution: {}", label, solution_p1);
241                if p1len > 1 {
242                    println!("Checking alternative Part 1 solutions...");
243                    for (idx, solver) in _gen_lists::P1_SOLUTIONS.iter().enumerate().skip(1) {
244                        let solution = solver(AOC_RAW_INPUT);
245                        if solution == solution_p1 {
246                            print!("✅");
247                        } else {
248                            println!("\nSolver {} found {}", _gen_lists::P1_LABELS[idx], solution);
249                        }
250                    }
251                    println!("\n");
252                }
253            }
254            if p2len > 0 {
255                let solution_p2 = _gen_lists::P2_SOLUTIONS[0](AOC_RAW_INPUT);
256                let label = _gen_lists::P2_LABELS[0];
257                println!("Part 2, {} Solution: {}", label, solution_p2);
258                if p2len > 1 {
259                    println!("Checking alternative Part 2 solutions...");
260                    for (idx, solver) in _gen_lists::P2_SOLUTIONS.iter().enumerate().skip(1) {
261                        let solution = solver(AOC_RAW_INPUT);
262                        if solution == solution_p2 {
263                            print!("✅");
264                        } else {
265                            println!("\nSolver {} found {}", _gen_lists::P2_LABELS[idx], solution);
266                        }
267                    }
268                    println!("\n");
269                }
270            }
271
272            println!(" ---- Quick Benches ----- ");
273            bench_quick::run_benches();
274        }
275    }
276}
277
278fn gen_quick_microbench() -> proc_macro2::TokenStream {
279    quote! {
280        mod bench_quick {
281            use std::time::Duration;
282            use microbench as mb;
283
284            pub fn run_benches() {
285                let mb_opts = mb::Options::default().time(Duration::from_secs(1));
286
287                for (idx, solver) in super::_gen_lists::P1_SOLUTIONS.iter().enumerate() {
288                    let label = format!("Part 1 - {}", super::_gen_lists::P1_LABELS[idx]);
289                    mb::bench(&mb_opts, &label, || solver(mb::retain(super::AOC_RAW_INPUT)))
290                }
291                for (idx, solver) in super::_gen_lists::P2_SOLUTIONS.iter().enumerate() {
292                    let label = format!("Part 2 - {}", super::_gen_lists::P2_LABELS[idx]);
293                    mb::bench(&mb_opts, &label, || solver( mb::retain(super::AOC_RAW_INPUT)))
294                }
295            }
296        }
297    }
298}
299
300fn gen_slow_microbench() -> proc_macro2::TokenStream {
301    quote! {
302        use pprof::criterion::{PProfProfiler, Output};
303        use pprof::flamegraph::Options as FGOptions;
304        use criterion::{Criterion, criterion_group, criterion_main, black_box};
305
306        fn bench(c: &mut Criterion) {
307            let mut group1 = c.benchmark_group("Part 1");
308            for (idx, solver_fn) in _gen_lists::P1_SOLUTIONS.iter().enumerate() {
309                let label = _gen_lists::P1_LABELS[idx];
310                group1.bench_function(label, |b| b.iter(|| solver_fn(black_box(AOC_RAW_INPUT))));
311            }
312            group1.finish();
313            let mut group2 = c.benchmark_group("Part 2");
314            for (idx, solver_fn) in _gen_lists::P2_SOLUTIONS.iter().enumerate() {
315                let label = _gen_lists::P2_LABELS[idx];
316                group2.bench_function(label, |b| b.iter(|| solver_fn(black_box(AOC_RAW_INPUT))));
317            }
318            group2.finish();
319        }
320
321        criterion_group! {
322            name = benches;
323            config = Criterion::default()
324                .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)))
325                .with_output_color(true)
326                .with_plots();
327            targets = bench
328        }
329
330        // We need this call to happen only when benchmarking. This is the closest we can get.
331        #[cfg(test)]
332        criterion_main!(benches);
333    }
334}