kagou/
lib.rs

1mod seaorm;
2// 导入过程宏相关的核心库。proc_macro 是编译器提供的库,用于处理令牌流
3use proc_macro::TokenStream;
4// quote 库允许我们将 Rust 语法转换为 TokenStream
5use quote::quote;
6// syn 库用于解析 Rust 代码,提供语法树结构
7use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, ItemFn, PathArguments, Type};
8// 数据库查询的 增删改查功能
9#[proc_macro]
10pub fn seaorm_curd(args: TokenStream) -> TokenStream {
11    seaorm::curd::page(args)
12    // if cfg!(feature = "seaorm_page") {
13    //     //开启指定 features 特性
14    //     // 调用其他函数
15    //     seaorm::curd::page(args)
16    // } else {
17    //     seaorm::curd::page(args)
18    // }
19}
20// 过程宏   只能包含以下三种宏之一:
21// 函数式的过程宏 (proc_macro)
22#[proc_macro]
23pub fn my_macro(input: TokenStream) -> TokenStream {
24    let input = input.to_string();
25    let output = format!("println!(\"Hello from macro! Input was: {}\");", input);
26    output.parse().unwrap()
27}
28
29// derive过程宏派生宏 (proc_macro_derive)
30#[proc_macro_derive(MyDerive)]
31pub fn my_derive(input: TokenStream) -> TokenStream {
32    let input = parse_macro_input!(input as DeriveInput);
33    let name = input.ident;
34
35    let output = quote! {
36        impl #name {
37            fn hello(&self) {
38                println!("Hello from derived impl for {}!", stringify!(#name));
39            }
40        }
41    };
42
43    output.into()
44}
45
46// 属性宏 (proc_macro_attribute)
47#[proc_macro_attribute]
48pub fn my_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
49    let attr = attr.to_string();
50    let input = parse_macro_input!(item as ItemFn);
51    let fn_name = &input.sig.ident;
52
53    let output = quote! {
54        #input
55
56        fn #fn_name() {
57            println!("This function was modified by my_attr with attribute: {}", #attr);
58        }
59    };
60
61    output.into()
62}
63
64/// 为 SeaORM ActiveModel 生成从 JSON 值更新字段的方法
65/// 这是一个派生宏(derive macro),通过 #[derive(UpdateFromJson)] 使用
66/// 输入是应用该宏的结构体的 TokenStream,输出是生成的代码的 TokenStream
67/// 核心功能:根据提供的 JSON 数据(通常是 Option<T> 字段)更新 SeaORM ActiveModel
68/// 只有 JSON 数据中存在的 Some(...) 字段才会被更新,未提供的字段保持原状
69#[proc_macro_derive(UpdateFromJson)]
70pub fn run(input: TokenStream) -> TokenStream {
71    // 将输入的 TokenStream 解析为 DeriveInput 语法树结构
72    // DeriveInput 包含了所有关于结构体、枚举或联合体的信息
73    let input = parse_macro_input!(input as DeriveInput);
74
75    // 获取结构体的名称(标识符),用于生成 impl 块
76    let struct_name = &input.ident;
77
78    // 处理结构体的数据(字段)
79    // 检查输入数据是否是一个结构体
80    let fields = if let Data::Struct(data_struct) = &input.data {
81        &data_struct.fields
82    } else {
83        // 如果不是结构体,返回编译错误
84        // 过程宏应该返回 TokenStream,这里使用 to_compile_error 创建错误信息
85        return syn::Error::new_spanned(input, "`UpdateFromJson` 此宏仅支持应用于结构体中")
86            .to_compile_error()
87            .into();
88    };
89
90    // 为每个字段生成更新语句
91    // 检查字段是否是有名字段(相对于元组结构体的无名字段)
92    let update_arms = if let Fields::Named(fields_named) = fields {
93        // 遍历所有有名字段,为每个字段生成更新代码
94        fields_named.named.iter().map(|field| {
95            // 获取字段的名称
96            let field_name = &field.ident;
97            // 将字段名转换为字符串,用于匹配
98            let field_name_str = field_name.as_ref().unwrap().to_string();
99            // 获取字段的类型,用于类型检查和转换
100            let field_type = &field.ty;
101
102            // 根据字段名和类型生成不同的处理逻辑
103            // 使用字段名匹配来识别需要特殊处理的字段
104            match field_name_str.as_str() {
105                // 时间字段特殊处理:String -> DateTime
106                // 假设数据库模型中对应字段是 DateTime 类型,但 JSON 数据中是字符串
107                "time_line" => {
108                    quote! {
109                        #field_name_str => {
110                            if let Some(time_str) = &data.#field_name {
111                                // 将 RFC3339 格式的时间字符串解析为 DateTime
112                                // 使用 with_timezone(&chrono::Utc) 转换为 UTC 时区
113                                match chrono::DateTime::parse_from_rfc3339(time_str)
114                                    .map(|dt| dt.with_timezone(&chrono::Utc))
115                                {
116                                    Ok(dt) => {
117                                        // 解析成功,设置 ActiveModel 的对应字段
118                                        active_model.#field_name = sea_orm::ActiveValue::Set(Some(dt));
119                                    }
120                                    Err(e) => {
121                                        // 解析失败,记录警告日志但不中断程序
122                                        log::warn!(
123                                            "Failed to parse datetime string for field '{}': {}",
124                                            #field_name_str,
125                                            e
126                                        );
127                                    }
128                                }
129                            }
130                        }
131                    }
132                }
133                // 金额字段特殊处理:f64 -> Decimal (如果需要)
134                // 假设数据库模型中对应字段是 Decimal 类型
135                "price" | "fee" => {
136                    quote! {
137                        #field_name_str => {
138                            if let Some(num) = data.#field_name {
139                                // 这里假设 ActiveModel 的对应字段已经是 Decimal 类型
140                                // 如果需要转换,可以在这里添加 f64 到 Decimal 的转换逻辑
141                                active_model.#field_name = sea_orm::ActiveValue::Set(Some(num));
142                            }
143                        }
144                    }
145                }
146                // JSON 字段特殊处理
147                // 假设数据库模型中对应字段是 sea_orm::sea_query::types::Json 类型
148                "rules" | "store_quanxain" => {
149                    quote! {
150                        #field_name_str => {
151                            if let Some(json_value) = &data.#field_name {
152                                // 将 fache::json::Value 转换为 sea_orm 的 Json 类型
153                                active_model.#field_name = sea_orm::ActiveValue::Set(
154                                    Some(sea_orm::sea_query::types::Json(json_value.clone()))
155                                );
156                            }
157                        }
158                    }
159                }
160                // 默认处理:直接赋值
161                // 对于大多数字段,直接进行值赋值
162                _ => {
163                    // 检查字段类型,确保兼容性
164                    if is_type_compatible(field_type) {
165                        quote! {
166                            #field_name_str => {
167                                if let Some(value) = data.#field_name {
168                                    // 直接设置字段值,假设类型兼容
169                                    active_model.#field_name = sea_orm::ActiveValue::Set(value);
170                                }
171                            }
172                        }
173                    } else {
174                        // 对于不兼容的类型,生成错误或跳过
175                        quote! {
176                            #field_name_str => {
177                                println!("Field '{}' has incompatible type and will be skipped", #field_name_str);
178                            }
179                        }
180                    }
181                }
182            }
183        })
184    } else {
185        // 处理没有命名字段的结构体(如元组结构体)
186        return syn::Error::new_spanned(input, "`UpdateFromJson` 仅支持有名字的结构体")
187            .to_compile_error()
188            .into();
189    };
190
191    // 生成最终的 impl 块代码
192    // 使用 quote! 宏生成完整的实现代码
193    let expanded = quote! {
194        impl #struct_name {
195            /// 根据提供的 JSON 数据更新 SeaORM ActiveModel 的字段
196            /// 只有 JSON 数据中存在的 `Some(...)` 字段才会被更新
197            /// 未提供的字段将保持其在 ActiveModel 中的原状(NotSet 或原有值)
198            ///
199            /// # 参数
200            /// - `active_model`: 要更新的 SeaORM ActiveModel 的可变引用
201            /// - `data`: 包含更新数据的结构体实例(&Self)
202            ///
203            /// # 注意
204            /// 此方法只会更新提供的 Some(...) 字段,None 字段会被忽略
205            pub fn update_active_model_from_json(
206                &self,
207                active_model: &mut sea_orm::ActiveModel,
208            ) {
209                let data = self;
210                // 这里会生成一个匹配字段名的 match 语句(或类似逻辑)
211                // 为了简化示例,我们直接遍历结构体的字段
212                #(#update_arms)*
213            }
214        }
215    };
216
217    // 将生成的代码转换为 TokenStream 并返回
218    // 编译器会将这个 TokenStream 插入到原始代码中
219    TokenStream::from(expanded)
220}
221
222/// 检查字段类型是否与 SeaORM ActiveValue 兼容
223/// 这是一个简化版的实现,实际应用中应该根据具体需求进行更详细的检查
224fn is_type_compatible(field_type: &Type) -> bool {
225    println!("当前的类型 '{:#?}' ", field_type);
226    // 处理路径类型(最常见的情况)
227    if let Type::Path(type_path) = field_type {
228        // 获取最后一段路径(通常是类型名称)
229        if let Some(segment) = type_path.path.segments.last() {
230            let type_name = segment.ident.to_string();
231            // 检查基本类型
232            match type_name.as_str() {
233                // 基本数值类型
234                "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
235                | "u128" | "usize" | "f32" | "f64" => return true,
236
237                // 布尔类型
238                "bool" => return true,
239
240                // 字符串类型
241                "String" | "str" => return true,
242
243                // SeaORM 和常用类型
244                "DateTime"
245                | "DateTimeWithTimeZone"
246                | "DateTimeUtc"
247                | "DateTimeLocal"
248                | "NaiveDate"
249                | "NaiveDateTime"
250                | "Decimal"
251                | "Json" => return true,
252
253                // JSON 类型 (serde 和 sea_orm)
254                "Value" => {
255                    // 检查是否是 serde_json::Value 或 sea_orm::sea_query::types::Json
256                    if let Some(path) = type_path.path.segments.first() {
257                        return path.ident == "serde_json" || path.ident == "sea_orm";
258                    }
259                    return false;
260                }
261
262                // Option<T> 类型 - 需要检查内部类型
263                "Option" => {
264                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
265                        if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
266                            // 递归检查内部的类型
267                            return is_type_compatible(inner_type);
268                        }
269                    }
270                    return false;
271                }
272
273                // 其他类型需要进一步检查
274                _ => {}
275            }
276        }
277    }
278
279    // 处理引用类型 (&T, &mut T)
280    if let Type::Reference(type_ref) = field_type {
281        return is_type_compatible(&type_ref.elem);
282    }
283
284    // 处理数组类型 [T; N]
285    if let Type::Array(type_array) = field_type {
286        return is_type_compatible(&type_array.elem);
287    }
288
289    // 处理元组类型 (T, U, ...)
290    if let Type::Tuple(type_tuple) = field_type {
291        return type_tuple.elems.iter().all(is_type_compatible);
292    }
293
294    // 对于未知或复杂类型,默认返回 false
295    // 在生产环境中,可以在这里添加更多特定的检查或记录警告
296    false
297}