aoc_toolbox_derive/
lib.rs

1use std::cell::RefCell;
2use std::collections::hash_map::Entry;
3use std::collections::HashMap;
4
5use proc_macro2::{Ident, TokenStream};
6use quote::{format_ident, quote, ToTokens};
7use syn::{parse_macro_input, AttributeArgs, ItemFn, Lit, NestedMeta};
8
9use inflector::Inflector;
10
thread_local! {
    /// Registry of every solver declared with `#[aoc_solver]`.
    ///
    /// Outer key: the day argument; inner key: the part argument; value: the optional
    /// third attribute argument (`result`), which `aoc_main!` forwards to `add_solver`.
    ///
    /// NOTE(review): this state lives per thread of the proc-macro host, so it assumes every
    /// `#[aoc_solver]` of the crate expands on the same thread, and before `aoc_main!` —
    /// the compiler does not document either guarantee; confirm.
    static AOC_SOLVERS: RefCell<HashMap<String, HashMap<String, Option<String>>>> =
        RefCell::default();
}
14
15fn format_trait(day: &String, part: &String) -> Ident {
16    format_ident!("{}{}", day.to_title_case(), part.to_title_case())
17}
18
19fn format_module(day: &String, part: &String) -> Ident {
20    format_ident!("Mod{}{}", day.to_title_case(), part.to_title_case())
21}
22
23#[proc_macro_attribute]
24pub fn aoc_solver(
25    input: proc_macro::TokenStream,
26    annotated_item: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28    let attributes = parse_macro_input!(input as AttributeArgs);
29    let attributes: Vec<String> = attributes
30        .iter()
31        .map(|a| match a {
32            NestedMeta::Lit(Lit::Str(s)) => s.value(),
33            _ => panic!("Attribute is not a string"),
34        })
35        .collect();
36    if attributes.len() != 2 && attributes.len() != 3 {
37        panic!("Number of attributes must be two or three");
38    }
39    let day = attributes[0].clone();
40    let part = attributes[1].clone();
41    let result = if attributes.len() == 3 {
42        Some(attributes[2].clone())
43    } else {
44        None
45    };
46
47    let trait_name = format_trait(&day, &part);
48    let module_name = format_module(&day, &part);
49    let func = parse_macro_input!(annotated_item as ItemFn);
50    let function = format_ident!("{}", func.sig.ident.to_string());
51    let solve_impl = quote! {
52        mod #module_name {
53            use super::*;
54            use crate::#trait_name;
55
56            impl<'a> #trait_name for aoc_toolbox::Aoc<'a> {
57                fn solve() -> String {
58                    #function(aoc_toolbox::utils::load_input(#day))
59                }
60            }
61        }
62    };
63
64    AOC_SOLVERS.with(|solvers| {
65        match solvers.borrow_mut().entry(day) {
66            Entry::Occupied(mut e) => {
67                if e.get().contains_key(&part) {
68                    panic!("Part \"{}\" for day \"{}\" exists already", part, e.key());
69                }
70                e.get_mut().insert(part, result);
71            }
72            Entry::Vacant(e) => {
73                e.insert(HashMap::new()).insert(part, result);
74            }
75        };
76    });
77
78    quote! {
79        #func
80
81        #solve_impl
82    }
83    .into_token_stream()
84    .into()
85}
86
87#[proc_macro]
88pub fn aoc_main(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
89    let year = parse_macro_input!(input as Lit);
90
91    AOC_SOLVERS.with(|solvers| {
92        let traits: Vec<TokenStream> = solvers
93            .borrow_mut()
94            .iter()
95            .flat_map(|(day, parts)| {
96                let ret: Vec<TokenStream> = parts
97                    .iter()
98                    .map(|(part, _result)| {
99                        let trait_name = format_trait(day, part);
100
101                        quote! {
102                            trait #trait_name {
103                                fn solve() -> String;
104                            }
105                        }
106                    })
107                    .collect();
108                ret
109            })
110            .collect();
111
112        let adders: Vec<TokenStream> = solvers
113            .borrow_mut()
114            .iter()
115            .flat_map(|(day, parts)| {
116                let ret: Vec<TokenStream> = parts
117                    .iter()
118                    .map(|(part, result)| {
119                        let trait_name = format_trait(day, part);
120
121                        if result.is_none() {
122                            return quote! {
123                                aoc.add_solver(#day, #part, None, <aoc_toolbox::Aoc as #trait_name>::solve);
124                            };
125                        }
126
127                        quote! {
128                            aoc.add_solver(#day, #part, Some(#result), <aoc_toolbox::Aoc as #trait_name>::solve);
129                        }
130                    })
131                    .collect();
132                ret
133            })
134            .collect();
135
136        quote! {
137            #( #traits )*
138
139            fn main() -> Result<(), Box<dyn std::error::Error>> {
140                use aoc_toolbox::{clap, Parser};
141
142                #[derive(aoc_toolbox::Parser, Debug)]
143                #[clap(about, long_about = None)]
144                struct Args {
145                    /// Name of the solver to run
146                    #[clap(index = 1, value_parser, default_value = "all")]
147                    solver: String,
148
149                    /// List all available solvers
150                    #[clap(short, long, value_parser, exclusive = true)]
151                    list: bool,
152
153                    /// List all available solvers
154                    #[clap(short = 'd', long, value_parser)]
155                    with_duration: bool,
156                }
157
158                let args = Args::parse();
159
160                let mut aoc = aoc_toolbox::Aoc::new( #year , args.with_duration);
161                #( #adders )*
162
163                if args.list {
164                    aoc.list();
165                    return Ok(());
166                }
167
168                aoc.run(args.solver)?;
169                Ok(())
170            }
171        }
172        .into()
173    })
174}