calculate_macro/
lib.rs

1use std::ops::Deref;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input, BinOp, Expr, Lit, Token, Type, UnOp,
9};
10
11struct CalculateExpr {
12    expr: Expr,
13    result_type: Type,
14}
15
16impl Parse for CalculateExpr {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let expr: Expr = input.parse()?;
19        input.parse::<Token![;]>()?;
20        let result_type: Type = input.parse()?;
21        Ok(Self { expr, result_type })
22    }
23}
24
25#[proc_macro]
26pub fn calc(input: TokenStream) -> TokenStream {
27    let CalculateExpr { expr, result_type } = parse_macro_input!(input as CalculateExpr);
28    let evaluated_expr = evaluate_expression(&expr, &result_type);
29    let result = quote! {
30        {
31            let result: #result_type = #evaluated_expr;
32            result
33        }
34    };
35    result.into()
36}
37
38fn evaluate_expression(expr: &Expr, result_type: &Type) -> TokenStream2 {
39    match expr {
40        Expr::Binary(op) => {
41            let left = evaluate_expression(&op.left, result_type);
42            let right = evaluate_expression(&op.right, result_type);
43            match op.op {
44                BinOp::Add(_) => quote! { #left + #right },
45                BinOp::Sub(_) => quote! { #left - #right },
46                BinOp::Mul(_) => quote! { #left * #right },
47                BinOp::Div(_) => quote! { #left / #right },
48                BinOp::Rem(_) => quote! { #left % #right },
49                BinOp::BitXor(_) => quote! { #left ^ #right },
50                BinOp::BitAnd(_) => quote! { #left & #right },
51                BinOp::BitOr(_) => quote! { #left | #right },
52                BinOp::Shl(_) => quote! { #left << #right },
53                BinOp::Shr(_) => quote! { #left >> #right },
54                // 可以添加更多二元操作符
55                _ => panic!("Unsupported binary operation"),
56            }
57        }
58        Expr::Unary(op) => {
59            let expr = evaluate_expression(&op.expr, result_type);
60            match op.op {
61                UnOp::Neg(_) => quote! { - #expr },
62                // 可以添加更多一元操作符
63                _ => panic!("Unsupported unary operation"),
64            }
65        }
66        Expr::Lit(lit) => match &lit.lit {
67            Lit::Int(int_lit) => {
68                let int_value = int_lit.base10_parse::<i128>().unwrap();
69                let int_value: TokenStream2 = match result_type {
70                    Type::Path(path) if path.path.is_ident("u16") => quote! { #int_value as u16 },
71                    Type::Path(path) if path.path.is_ident("u32") => quote! { #int_value as u32 },
72                    Type::Path(path) if path.path.is_ident("u64") => quote! { #int_value as u64 },
73                    Type::Path(path) if path.path.is_ident("u128") => quote! { #int_value as u128 },
74                    Type::Path(path) if path.path.is_ident("i8") => quote! { #int_value as i8 },
75                    Type::Path(path) if path.path.is_ident("i16") => quote! { #int_value as i16 },
76                    Type::Path(path) if path.path.is_ident("i32") => quote! { #int_value as i32 },
77                    Type::Path(path) if path.path.is_ident("i64") => quote! { #int_value as i64 },
78                    Type::Path(path) if path.path.is_ident("i128") => quote! { #int_value as i128 },
79                    // 可以添加更多整数类型处理
80                    _ => panic!("Unsupported result type for integer literal"),
81                };
82                int_value
83            }
84            Lit::Float(float_lit) => {
85                let float_value = float_lit.base10_parse::<f64>().unwrap();
86                let float_value: TokenStream2 = match result_type {
87                    Type::Path(path) if path.path.is_ident("f32") => quote! { #float_value as f32 },
88                    Type::Path(path) if path.path.is_ident("f64") => quote! { #float_value as f64 },
89                    // 可以添加更多浮点数类型处理
90                    _ => panic!("Unsupported result type for float literal"),
91                };
92                float_value
93            }
94            Lit::Bool(bool_lit) => quote! { #bool_lit },
95            // 可以添加更多字面量类型
96            _ => panic!("Unsupported literal type"),
97        },
98        Expr::Paren(paren) => {
99            let paren = evaluate_expression(&paren.expr, result_type);
100            quote! { (#paren) }
101        }
102        Expr::Call(call) => {
103            let func = match call.func.deref() {
104                Expr::Path(path) => path.path.get_ident().unwrap().to_string(),
105                _ => panic!("Unsupported function call"),
106            };
107            let args = call
108                .args
109                .iter()
110                .map(|arg| evaluate_expression(arg, result_type))
111                .collect::<Vec<_>>();
112            match func.as_str() {
113                "max" => {
114                    let [ref arg0, ref arg1] = args[..] else {
115                        panic!("Unsupported function call");
116                    };
117                    quote! { ::core::cmp::max(#arg0, #arg1) }
118                }
119                "min" => {
120                    let [ref arg0, ref arg1] = args[..] else {
121                        panic!("Unsupported function call");
122                    };
123                    quote! { ::core::cmp::min(#arg0, #arg1) }
124                }
125                // 可以添加更多函数支持
126                _ => panic!("Unsupported function call"),
127            }
128        }
129        Expr::Path(path) => {
130            // 将变量转换为目标类型
131            match result_type {
132                Type::Path(_) => quote! { #path as #result_type },
133                _ => panic!("Unsupported result type for variable"),
134            }
135        }
136        // 处理其他可能的表达式类型,如索引访问、字段访问等
137        _ => panic!("Unsupported expression type"),
138    }
139}