checked_rs_macro_impl/params/number_arg.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use rhai::{plugin::*, Engine};
4use syn::{parse::Parse, parse_quote, spanned::Spanned};
5
6use super::{MinOrMax, NumberKind, NumberValue};
7
8/// Represents a numerical argument obtained from parsing proc macro input tokens.
9///
10/// `NumberArg` can represent three kinds of numerical data:
11/// - A literal value directly specified in the input.
12/// - A primitive number constant, which is a predefined constant value.
13/// - An expression that is evaluated at compile time using the Rhai scripting engine.
14///
15/// This enum is designed to work within the context of procedural macros, where it helps
16/// in interpreting and manipulating numerical inputs. It provides flexibility in handling
17/// both simple numerical values and complex expressions that need to be evaluated.
18///
19/// The `NumberArg` can be converted into a `NumberValue` type through the `into_value` method.
20/// This conversion allows for direct operations on the underlying numerical value, supporting
21/// various primitive number types (e.g., `u8`, `u16`, `i32`, `usize`, etc.).
22///
23/// Example usage:
24/// ```
25/// let num_arg: NumberArg = parse_macro_input!(input as NumberArg);
26/// let num_value: NumberValue = num_arg.into_value(NumberKind::U32);
27/// // Now `num_value` can be used for further numerical operations.
28/// ```
29///
30/// Note: The `into_value` method panics if the conversion fails, so it should be used
31/// within a context where panicking is acceptable or handled appropriately.
32#[derive(Clone)]
33pub enum NumberArg {
34    Literal(syn::LitInt),
35    ConstExpr {
36        const_token: syn::Token![const],
37        kind: NumberKind,
38        block: syn::Block,
39    },
40    Constant {
41        kind: NumberKind,
42        dbl_colon: syn::Token![::],
43        ident: MinOrMax,
44    },
45}
46
47impl Parse for NumberArg {
48    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
49        if input.peek(syn::LitInt) {
50            Ok(Self::Literal(input.parse()?))
51        } else if input.peek(syn::Token![const]) {
52            Ok(Self::ConstExpr {
53                const_token: input.parse()?,
54                kind: input.parse()?,
55                block: input.parse()?,
56            })
57        } else {
58            let kind = input.parse()?;
59            let dbl_colon = input.parse()?;
60            let ident: MinOrMax = input.parse()?;
61
62            Ok(Self::Constant {
63                kind,
64                dbl_colon,
65                ident,
66            })
67        }
68    }
69}
70
71impl ToTokens for NumberArg {
72    fn to_tokens(&self, tokens: &mut TokenStream) {
73        match self {
74            Self::Literal(lit) => lit.to_tokens(tokens),
75            Self::ConstExpr { kind, .. } => tokens.extend(self.into_literal_as_tokens(*kind)),
76            Self::Constant {
77                kind,
78                dbl_colon,
79                ident,
80            } => {
81                let kind = kind.to_token_stream();
82                let dbl_colon = dbl_colon.to_token_stream();
83                let ident = ident.to_token_stream();
84
85                tokens.extend(quote! {
86                    #kind #dbl_colon #ident
87                });
88            }
89        }
90    }
91}
92
93impl std::fmt::Debug for NumberArg {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Self::Literal(lit) => write!(f, "{}", lit.to_token_stream().to_string()),
97            Self::ConstExpr { kind, block, .. } => {
98                write!(f, "const {} {}", kind, block.to_token_stream().to_string())
99            }
100            Self::Constant { kind, ident, .. } => {
101                write!(f, "{}::{}", kind, ident.to_token_stream().to_string())
102            }
103        }
104    }
105}
106
107impl std::fmt::Display for NumberArg {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        std::fmt::Debug::fmt(self, f)
110    }
111}
112
113impl NumberArg {
114    pub const LIMITS_INIT: (Option<Self>, Option<Self>) = (None, None);
115
116    /// Create a new `NumberArg` instance for a `NumberKind`'s minimum value.
117    pub fn new_min_constant(kind: NumberKind) -> Self {
118        Self::Constant {
119            kind,
120            dbl_colon: parse_quote!(::),
121            ident: MinOrMax::Min(parse_quote!(MIN)),
122        }
123    }
124
125    /// Create a new `NumberArg` instance for a `NumberKind`'s maximum value.
126    pub fn new_max_constant(kind: NumberKind) -> Self {
127        Self::Constant {
128            kind,
129            dbl_colon: parse_quote!(::),
130            ident: MinOrMax::Max(parse_quote!(MAX)),
131        }
132    }
133
134    /// Helper method to create a `NumberArg` instance from a `syn::Expr`. Mostly useful for testing.
135    pub fn from_expr(expr: &syn::Expr) -> Self {
136        parse_quote!(#expr)
137    }
138
139    /// Helper method to create a `NumberArg` instance from a `syn::LitInt`. Mostly useful for testing.
140    pub fn from_lit(lit: &syn::LitInt) -> Self {
141        Self::Literal(lit.clone())
142    }
143
144    /// Helper method to create a pair of `NumberArg` instances from a `syn::ExprRange`. Mostly useful for testing.
145    pub fn from_range_expr(kind: NumberKind, expr: &syn::ExprRange) -> (Self, Self) {
146        let start: Option<NumberArg> = expr.start.as_ref().map(|expr| parse_quote!(#expr));
147        let end: Option<NumberArg> = expr.end.as_ref().map(|expr| parse_quote!(#expr));
148
149        (
150            start.unwrap_or_else(|| NumberArg::new_min_constant(kind)),
151            end.unwrap_or_else(|| NumberArg::new_max_constant(kind)),
152        )
153    }
154
155    /// Create a new `NumberArg` instance from the lesser of two inputs.
156    pub fn min(&self, other: &Self, kind: NumberKind) -> Self {
157        let a = self.into_value(kind);
158        let b = other.into_value(kind);
159
160        if a <= b {
161            self.clone()
162        } else {
163            other.clone()
164        }
165    }
166
167    /// Create a new `NumberArg` instance from the greater of two inputs.
168    pub fn max(&self, other: &Self, kind: NumberKind) -> Self {
169        let a = self.into_value(kind);
170        let b = other.into_value(kind);
171
172        if a >= b {
173            self.clone()
174        } else {
175            other.clone()
176        }
177    }
178
179    /// Convert the `NumberArg` into a `NumberValue` instance.
180    pub fn into_value(&self, kind: NumberKind) -> NumberValue {
181        match kind {
182            NumberKind::U8 => NumberValue::U8(match self.base10_parse() {
183                Ok(n) => n,
184                Err(e) => panic!("{}", e.to_string()),
185            }),
186            NumberKind::U16 => NumberValue::U16(match self.base10_parse() {
187                Ok(n) => n,
188                Err(e) => panic!("{}", e.to_string()),
189            }),
190            NumberKind::U32 => NumberValue::U32(match self.base10_parse() {
191                Ok(n) => n,
192                Err(e) => panic!("{}", e.to_string()),
193            }),
194            NumberKind::U64 => NumberValue::U64(match self.base10_parse() {
195                Ok(n) => n,
196                Err(e) => panic!("{}", e.to_string()),
197            }),
198            NumberKind::U128 => NumberValue::U128(match self.base10_parse() {
199                Ok(n) => n,
200                Err(e) => panic!("{}", e.to_string()),
201            }),
202            NumberKind::USize => NumberValue::USize(match self.base10_parse() {
203                Ok(n) => n,
204                Err(e) => panic!("{}", e.to_string()),
205            }),
206            NumberKind::I8 => NumberValue::I8(match self.base10_parse() {
207                Ok(n) => n,
208                Err(e) => panic!("{}", e.to_string()),
209            }),
210            NumberKind::I16 => NumberValue::I16(match self.base10_parse() {
211                Ok(n) => n,
212                Err(e) => panic!("{}", e.to_string()),
213            }),
214            NumberKind::I32 => NumberValue::I32(match self.base10_parse() {
215                Ok(n) => n,
216                Err(e) => panic!("{}", e.to_string()),
217            }),
218            NumberKind::I64 => NumberValue::I64(match self.base10_parse() {
219                Ok(n) => n,
220                Err(e) => panic!("{}", e.to_string()),
221            }),
222            NumberKind::I128 => NumberValue::I128(match self.base10_parse() {
223                Ok(n) => n,
224                Err(e) => panic!("{}", e.to_string()),
225            }),
226            NumberKind::ISize => NumberValue::ISize(match self.base10_parse() {
227                Ok(n) => n,
228                Err(e) => panic!("{}", e.to_string()),
229            }),
230        }
231    }
232
233    /// Output the value as a bare literal number in a token stream.
234    pub fn into_literal_as_tokens(&self, kind: NumberKind) -> TokenStream {
235        self.into_value(kind).into_token_stream()
236    }
237
238    /// Parse the value as a base 10 number.
239    pub fn base10_parse<N>(&self) -> syn::Result<N>
240    where
241        N: std::str::FromStr,
242        N::Err: std::fmt::Display,
243    {
244        match self {
245            Self::Literal(lit) => lit.base10_parse::<N>(),
246            Self::ConstExpr { kind, block, .. } => {
247                match eval_const_expr(kind, block)?.to_string().parse() {
248                    Ok(n) => Ok(n),
249                    Err(e) => Err(syn::Error::new(block.span(), e)),
250                }
251            }
252            Self::Constant {
253                kind,
254                dbl_colon: _,
255                ident,
256            } => {
257                let n = match ident {
258                    MinOrMax::Min(..) => match kind {
259                        NumberKind::U8 => u8::MIN.to_string(),
260                        NumberKind::U16 => u16::MIN.to_string(),
261                        NumberKind::U32 => u32::MIN.to_string(),
262                        NumberKind::U64 => u64::MIN.to_string(),
263                        NumberKind::U128 => u128::MIN.to_string(),
264                        NumberKind::USize => usize::MIN.to_string(),
265                        NumberKind::I8 => i8::MIN.to_string(),
266                        NumberKind::I16 => i16::MIN.to_string(),
267                        NumberKind::I32 => i32::MIN.to_string(),
268                        NumberKind::I64 => i64::MIN.to_string(),
269                        NumberKind::I128 => i128::MIN.to_string(),
270                        NumberKind::ISize => isize::MIN.to_string(),
271                    },
272                    MinOrMax::Max(..) => match kind {
273                        NumberKind::U8 => u8::MAX.to_string(),
274                        NumberKind::U16 => u16::MAX.to_string(),
275                        NumberKind::U32 => u32::MAX.to_string(),
276                        NumberKind::U64 => u64::MAX.to_string(),
277                        NumberKind::U128 => u128::MAX.to_string(),
278                        NumberKind::USize => usize::MAX.to_string(),
279                        NumberKind::I8 => i8::MAX.to_string(),
280                        NumberKind::I16 => i16::MAX.to_string(),
281                        NumberKind::I32 => i32::MAX.to_string(),
282                        NumberKind::I64 => i64::MAX.to_string(),
283                        NumberKind::I128 => i128::MAX.to_string(),
284                        NumberKind::ISize => isize::MAX.to_string(),
285                    },
286                };
287
288                match str::parse(&n) {
289                    Ok(n) => Ok(n),
290                    Err(e) => Err(syn::Error::new(ident.span(), e)),
291                }
292            }
293        }
294    }
295}
296
/// Generates and registers Rhai modules for the primitive integer types.
///
/// - `declare { u8, i32, ... }` expands to one `#[export_module] mod rhai_<int>` per listed
///   type. Each module re-exports `std::<int>::*`.
/// - `register[engine] { u8, i32, ... }` expands to statements that build each generated
///   module with `exported_module!`. Each module is registered on `engine` as a static module
///   named after the type (e.g. `"u8"`).
macro_rules! use_rhai_int {
    (declare { $($int:ident),* $(,)? }) => {
        paste::paste! {
            $(
                #[allow(dead_code)]
                #[export_module]
                mod [<rhai_ $int>] {
                    #[allow(unused_imports)]
                    pub use std::$int::*;
                }
            )*
        }
    };
    (register[$target:ident] { $($int:ident),* $(,)? }) => {
        paste::paste! {
            $(
                $target.register_static_module(
                    stringify!($int),
                    exported_module!([<rhai_ $int>]).into(),
                );
            )*
        }
    };
}
323
324use_rhai_int! {
325    declare {
326        u8, u16, u32, u64, u128, usize,
327        i8, i16, i32, i64, i128, isize,
328    }
329}
330
331fn eval_const_expr(kind: &NumberKind, expr: &syn::Block) -> syn::Result<NumberValue> {
332    let mut engine = Engine::new();
333
334    use_rhai_int! {
335        register[engine] {
336            u8, u16, u32, u64, u128, usize,
337            i8, i16, i32, i64, i128, isize,
338        }
339    }
340
341    let stmts = &expr.stmts;
342
343    if stmts.len() != 1 {
344        return Err(syn::Error::new(expr.span(), "expected a single expression"));
345    }
346
347    let script = stmts[0].to_token_stream().to_string();
348
349    Ok(match kind {
350        NumberKind::U8 => match engine.eval_expression::<u8>(&script) {
351            Ok(n) => n.into(),
352            Err(err) => {
353                return Err(syn::Error::new(
354                    expr.span(),
355                    format!("failed to evaluate expression: {}", err),
356                ))
357            }
358        },
359        NumberKind::U16 => match engine.eval_expression::<u16>(&script) {
360            Ok(n) => n.into(),
361            Err(err) => {
362                return Err(syn::Error::new(
363                    expr.span(),
364                    format!("failed to evaluate expression: {}", err),
365                ))
366            }
367        },
368        NumberKind::U32 => match engine.eval_expression::<u32>(&script) {
369            Ok(n) => n.into(),
370            Err(err) => {
371                return Err(syn::Error::new(
372                    expr.span(),
373                    format!("failed to evaluate expression: {}", err),
374                ))
375            }
376        },
377        NumberKind::U64 => match engine.eval_expression::<u64>(&script) {
378            Ok(n) => n.into(),
379            Err(err) => {
380                return Err(syn::Error::new(
381                    expr.span(),
382                    format!("failed to evaluate expression: {}", err),
383                ))
384            }
385        },
386        NumberKind::U128 => match engine.eval_expression::<u128>(&script) {
387            Ok(n) => n.into(),
388            Err(err) => {
389                return Err(syn::Error::new(
390                    expr.span(),
391                    format!("failed to evaluate expression: {}", err),
392                ))
393            }
394        },
395        NumberKind::USize => match engine.eval_expression::<usize>(&script) {
396            Ok(n) => n.into(),
397            Err(err) => {
398                return Err(syn::Error::new(
399                    expr.span(),
400                    format!("failed to evaluate expression: {}", err),
401                ))
402            }
403        },
404        NumberKind::I8 => match engine.eval_expression::<i8>(&script) {
405            Ok(n) => n.into(),
406            Err(err) => {
407                return Err(syn::Error::new(
408                    expr.span(),
409                    format!("failed to evaluate expression: {}", err),
410                ))
411            }
412        },
413        NumberKind::I16 => match engine.eval_expression::<i16>(&script) {
414            Ok(n) => n.into(),
415            Err(err) => {
416                return Err(syn::Error::new(
417                    expr.span(),
418                    format!("failed to evaluate expression: {}", err),
419                ))
420            }
421        },
422        NumberKind::I32 => match engine.eval_expression::<i32>(&script) {
423            Ok(n) => n.into(),
424            Err(err) => {
425                return Err(syn::Error::new(
426                    expr.span(),
427                    format!("failed to evaluate expression: {}", err),
428                ))
429            }
430        },
431        NumberKind::I64 => match engine.eval_expression::<i64>(&script) {
432            Ok(n) => n.into(),
433            Err(err) => {
434                return Err(syn::Error::new(
435                    expr.span(),
436                    format!("failed to evaluate expression: {}", err),
437                ))
438            }
439        },
440        NumberKind::I128 => match engine.eval_expression::<i128>(&script) {
441            Ok(n) => n.into(),
442            Err(err) => {
443                return Err(syn::Error::new(
444                    expr.span(),
445                    format!("failed to evaluate expression: {}", err),
446                ))
447            }
448        },
449        NumberKind::ISize => match engine.eval_expression::<isize>(&script) {
450            Ok(n) => n.into(),
451            Err(err) => {
452                return Err(syn::Error::new(
453                    expr.span(),
454                    format!("failed to evaluate expression: {}", err),
455                ))
456            }
457        },
458    })
459}