care_macro/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::quote;
5use syn::{spanned::Spanned, Block, Expr, ItemFn, ItemStatic, Stmt};
6
/// Separator placed before each entry appended to the `_CARE_INTERNAL_STATE_PARAMS` and
/// `_CARE_INTERNAL_STATE_ITEMS` environment variables, and used to split them again.
///
/// NOTE(review): presumably chosen because a run of three newlines is not expected to
/// appear inside the stringified `quote!` output of an entry.
const STATE_VAR_SEPARATOR: &str = "\n\n\n";
8
9#[rustfmt::skip]
10fn dereference_state_vars(expr: &mut Expr, vars: &HashSet<String>) {
11    match expr {
12        Expr::Await(syn::ExprAwait { base: expr, .. }) |
13        Expr::Cast(syn::ExprCast { expr, .. }) |
14        Expr::Field(syn::ExprField { base: expr, .. }) |
15        Expr::Group(syn::ExprGroup { expr, .. }) |
16        Expr::Let(syn::ExprLet { expr, .. }) |
17        Expr::Paren(syn::ExprParen { expr, .. }) |
18        Expr::Range(syn::ExprRange { start: Some(expr), end: None, .. } |
19                    syn::ExprRange { start: None, end: Some(expr), .. }) |
20        Expr::Reference(syn::ExprReference { expr, .. }) |
21        Expr::Return(syn::ExprReturn { expr: Some(expr), .. }) |
22        Expr::Try(syn::ExprTry { expr, .. }) |
23        Expr::Unary(syn::ExprUnary { expr, .. }) |
24        Expr::Break(syn::ExprBreak { expr: Some(expr), .. }) |
25        Expr::Yield(syn::ExprYield { expr: Some(expr), .. }) => {
26            dereference_state_vars(expr, vars);
27        }
28        Expr::Assign(syn::ExprAssign { left: expr, right: expr2, .. }) |
29        Expr::Index(syn::ExprIndex { expr, index: expr2, .. }) |
30        Expr::Range(syn::ExprRange { start: Some(expr), end: Some(expr2), .. }) |
31        Expr::Binary(syn::ExprBinary { left: expr, right: expr2, .. }) |
32        Expr::Repeat(syn::ExprRepeat { expr, len: expr2, .. }) => {
33            dereference_state_vars(expr, vars);
34            dereference_state_vars(expr2, vars);
35        }
36        Expr::Array(syn::ExprArray { elems: exprs, .. }) |
37        Expr::Tuple(syn::ExprTuple { elems: exprs, .. }) => {
38            for expr in exprs {
39                dereference_state_vars(expr, vars);
40            }
41        }
42        Expr::Call(syn::ExprCall { func: expr, args: exprs, .. }) |
43        Expr::MethodCall(syn::ExprMethodCall { receiver: expr, args: exprs, .. }) => {
44            dereference_state_vars(expr, vars);
45            for expr in exprs {
46                dereference_state_vars(expr, vars);
47            }
48        }
49        Expr::Async(syn::ExprAsync { block: Block { stmts, .. }, ..}) |
50        Expr::Block(syn::ExprBlock { block: Block { stmts, .. }, ..}) |
51        Expr::Loop(syn::ExprLoop { body: Block { stmts, .. }, .. }) |
52        Expr::TryBlock(syn::ExprTryBlock { block: Block { stmts, .. }, .. }) |
53        Expr::Unsafe(syn::ExprUnsafe { block: Block { stmts, .. }, ..}) => {
54            for stmt in stmts {
55                dereference_state_vars_stmt(stmt, vars);
56            }
57        }
58        Expr::ForLoop(syn::ExprForLoop { expr, body: Block { stmts, .. }, .. }) |
59        Expr::While(syn::ExprWhile { cond: expr, body: Block { stmts, ..}, .. }) => {
60            dereference_state_vars(expr, vars);
61            for stmt in stmts {
62                dereference_state_vars_stmt(stmt, vars);
63            }
64        }
65        Expr::Match(syn::ExprMatch { expr, arms, .. }) => {
66            dereference_state_vars(expr, vars);
67            for arm in arms {
68                dereference_state_vars(&mut arm.body, vars);
69            }
70        },
71        Expr::Struct(syn::ExprStruct { fields, .. }) => {
72            for field in fields {
73                dereference_state_vars(&mut field.expr, vars);
74            }
75        }
76        Expr::If(syn::ExprIf { cond: expr, then_branch: block, else_branch, .. }) => {
77            dereference_state_vars(expr, vars);
78            for stmt in &mut block.stmts {
79                dereference_state_vars_stmt(stmt, vars)
80            }
81            if let Some(else_branch) = else_branch {
82                dereference_state_vars(&mut else_branch.1, vars);
83            }
84        }
85        Expr::Path(syn::ExprPath { path: syn::Path { leading_colon: None, segments }, .. }) => {
86            if segments.len() == 1 {
87                if let Some(seg) = segments.first_mut() {
88                    if vars.contains(&seg.ident.to_string()) {
89                        *expr = Expr::Paren(syn::ExprParen {
90                            attrs: Vec::new(),
91                            paren_token: syn::token::Paren(seg.span()),
92                            expr: Box::new(Expr::Unary(syn::ExprUnary {
93                                attrs: Vec::new(),
94                                op: syn::UnOp::Deref(syn::token::Star(seg.span())),
95                                expr: Box::new(expr.clone()),
96                            })),
97                        });
98                    }
99                }
100            }
101        }
102        _ => {},
103    }
104}
105
106#[rustfmt::skip]
107fn dereference_state_vars_stmt(stmt: &mut Stmt, vars: &HashSet<String>) {
108    match stmt {
109            syn::Stmt::Local(syn::Local {
110                init: Some(init), ..
111            }) => dereference_state_vars(&mut init.expr, vars),
112            syn::Stmt::Expr(expr, _) => dereference_state_vars(expr, vars),
113            _ => {}
114    }
115}
116
117fn care_macro_shared(func: proc_macro::TokenStream, name: &str) -> proc_macro::TokenStream {
118    let func = TokenStream::from(func);
119    let input: ItemFn = match syn::parse2(func.clone()) {
120        Ok(i) => i,
121        Err(e) => return token_stream_with_error(func, e),
122    };
123    let state_params = std::env::var("_CARE_INTERNAL_STATE_PARAMS")
124        .ok()
125        .unwrap_or_default();
126    let func_name = format!("care_{}", input.sig.ident);
127    let var_name = format!("_CARE_INTERNAL_{name}");
128    if std::env::var(&var_name).is_ok() {
129        return func.into();
130    }
131    std::env::set_var(&var_name, func_name.clone());
132
133    let state_vars: HashSet<_> = state_params
134        .split(STATE_VAR_SEPARATOR)
135        .filter(|s| !s.is_empty())
136        .map(|p| p.split_once(':').unwrap().0.trim().to_string())
137        .collect();
138
139    let state_params = if input.sig.inputs.is_empty() {
140        state_params.trim_start_matches(STATE_VAR_SEPARATOR)
141    } else {
142        &state_params
143    };
144    let new_params: TokenStream = state_params
145        .replace(STATE_VAR_SEPARATOR, ",")
146        .parse()
147        .unwrap();
148    let asyncness = input.sig.asyncness;
149    let ident = Ident::new(&func_name, input.sig.ident.span());
150    let generics = input.sig.generics;
151    let inputs = input.sig.inputs;
152    let output = input.sig.output;
153    let mut block = input.block;
154    for stmt in &mut block.stmts {
155        dereference_state_vars_stmt(stmt, &state_vars);
156    }
157    let result = quote! {
158        #asyncness fn #ident #generics (#inputs #new_params) #output
159        #block
160    };
161
162    result.into()
163}
164
165#[proc_macro_attribute]
166pub fn care_state(
167    _attr: proc_macro::TokenStream,
168    def: proc_macro::TokenStream,
169) -> proc_macro::TokenStream {
170    let def = TokenStream::from(def);
171    let item: ItemStatic = match syn::parse2::<ItemStatic>(def.clone()) {
172        Ok(i) => i,
173        Err(e) => return token_stream_with_error(def, e),
174    };
175    let ident = item.ident.clone();
176    let ident_state = Ident::new(&(item.ident.to_string() + "_state"), item.ident.span());
177    let ty = item.ty;
178    let expr = item.expr;
179    std::env::set_var(
180        "_CARE_INTERNAL_STATE_DEFS",
181        std::env::var("_CARE_INTERNAL_STATE_DEFS")
182            .ok()
183            .unwrap_or_default()
184            + &quote! { let mut #ident_state: #ty = #expr; }.to_string(),
185    );
186    std::env::set_var(
187        "_CARE_INTERNAL_STATE_PARAMS",
188        std::env::var("_CARE_INTERNAL_STATE_PARAMS")
189            .ok()
190            .unwrap_or_default()
191            + STATE_VAR_SEPARATOR
192            + &quote! { #ident: &mut #ty }.to_string(),
193    );
194    std::env::set_var(
195        "_CARE_INTERNAL_STATE_ITEMS",
196        std::env::var("_CARE_INTERNAL_STATE_ITEMS")
197            .ok()
198            .unwrap_or_default()
199            + STATE_VAR_SEPARATOR
200            + &quote! { #ident_state }.to_string(),
201    );
202    proc_macro::TokenStream::new()
203}
204
205#[proc_macro_attribute]
206pub fn care_init(
207    _attr: proc_macro::TokenStream,
208    func: proc_macro::TokenStream,
209) -> proc_macro::TokenStream {
210    care_macro_shared(func, "INIT")
211}
212
213#[proc_macro_attribute]
214pub fn care_update(
215    _attr: proc_macro::TokenStream,
216    func: proc_macro::TokenStream,
217) -> proc_macro::TokenStream {
218    care_macro_shared(func, "UPDATE")
219}
220
221#[proc_macro_attribute]
222pub fn care_draw(
223    _attr: proc_macro::TokenStream,
224    func: proc_macro::TokenStream,
225) -> proc_macro::TokenStream {
226    care_macro_shared(func, "DRAW")
227}
228
229#[proc_macro_attribute]
230pub fn care_async_main(
231    _attr: proc_macro::TokenStream,
232    func: proc_macro::TokenStream,
233) -> proc_macro::TokenStream {
234    care_macro_shared(func, "ASYNC_MAIN")
235}
236
237#[proc_macro]
238pub fn care_main(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
239    // TODO: Config
240    let attr = TokenStream::from(attr);
241
242    let conf: Expr = match syn::parse2(attr.clone()) {
243        Ok(i) => i,
244        Err(e) => return token_stream_with_error(attr, e),
245    };
246
247    let init_fn = std::env::var("_CARE_INTERNAL_INIT").ok();
248    let update_fn = std::env::var("_CARE_INTERNAL_UPDATE").ok();
249    let draw_fn = std::env::var("_CARE_INTERNAL_DRAW").ok();
250    let async_main_fn = std::env::var("_CARE_INTERNAL_ASYNC_MAIN").ok();
251
252    let state_lets: TokenStream = std::env::var("_CARE_INTERNAL_STATE_DEFS")
253        .ok()
254        .map(|st| st.parse().unwrap())
255        .unwrap_or_default();
256
257    let additional_params: TokenStream = std::env::var("_CARE_INTERNAL_STATE_ITEMS")
258        .ok()
259        .map(|st| {
260            st.trim_start_matches(STATE_VAR_SEPARATOR)
261                .replace(STATE_VAR_SEPARATOR, ",")
262                .parse()
263                .unwrap()
264        })
265        .unwrap_or_default();
266
267    if (init_fn.is_some() || update_fn.is_some() || draw_fn.is_some()) && async_main_fn.is_some() {
268        panic!("You cannot define a #[care::async] function along with any other #[care::init], #[care::update] or #[care::draw] function.");
269    }
270    if let Some(async_main_fn) = async_main_fn {
271        let fn_ident = Ident::new(&async_main_fn, Span::call_site());
272        return quote! {
273            fn main() {
274                let config = { #conf };
275                ::care::window::open(env!("CARGO_CRATE_NAME"));
276                #state_lets
277                ::care::event::main_async(#fn_ident(#additional_params));
278            }
279        }
280        .into();
281    }
282
283    let init_call = maybe_call_function(init_fn, quote! {app_args, #additional_params});
284    let update_call = maybe_call_function(update_fn, quote! {delta_time, #additional_params});
285    let draw_call = maybe_call_function(draw_fn, quote! {#additional_params});
286
287    let result = quote! {
288        fn main() {
289            let config = { #conf };
290            ::care::window::open(env!("CARGO_CRATE_NAME"));
291            ::care::event::main_loop(move || {
292                #state_lets
293                let app_args: Vec<_> = ::std::env::args().collect();
294                #init_call
295                (::std::time::Instant::now(), (#additional_params))
296            }, move |(last_time, (#additional_params))| {
297                let next_time = ::std::time::Instant::now();
298                let delta_time = next_time.duration_since(*last_time).as_secs_f64() as ::care::math::Fl;
299                *last_time = next_time;
300                #update_call
301                #draw_call
302            });
303        }
304    };
305
306    result.into()
307}
308
309fn maybe_call_function(fn_name: Option<String>, params: TokenStream) -> TokenStream {
310    if let Some(fn_name) = fn_name {
311        let fn_ident = Ident::new(&fn_name, Span::call_site());
312        quote! {
313            #fn_ident(#params);
314        }
315    } else {
316        quote! {}
317    }
318}
319
320// From tokio (https://github.com/tokio-rs/tokio/blob/tokio-1.35.1/tokio-macros/src/entry.rs#L416)
321fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> proc_macro::TokenStream {
322    tokens.extend(error.into_compile_error());
323    tokens.into()
324}