1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
use proc_macro::TokenStream; // proc_macro 提供的 TokenStream 类型,是编译器传给宏的语法片段
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Fields}; // syn 用于解析 Rust 语法树 // quote 用于生成 Rust 代码
use crate::create_crate_ident;
/// 真实宏逻辑实现
/// 输入: TokenStream(Rust 编译器传入的枚举语法片段)
/// 输出: TokenStream(生成的 impl 代码)
/// 作用: 自动为枚举生成 FromStr、Serialize/Deserialize、sqlx Decode/Type 等实现
pub fn derive_to_sqlx_enum(input: TokenStream) -> TokenStream {
// 将 TokenStream 解析成 AST(抽象语法树),方便操作
// DeriveInput 是 syn 提供的类型,代表一个可派生宏的 Rust item(struct、enum、union)
let input = parse_macro_input!(input as DeriveInput);
// 获取枚举名字,例如 enum Status {...} 的名字就是 "Status"
let name = &input.ident;
let crate_ident = create_crate_ident("db_cores");
let serde_path = if crate_ident == "crate" {
quote!(crate::serde)
} else {
quote!(::#crate_ident::serde)
};
let sqlx_path = if crate_ident == "crate" {
quote!(crate::sqlx) // 如果是db_common本身调 crate::sqlx,则直接使用 crate::sqlx,不能写成::crate::sqlx
} else {
quote!(::#crate_ident::sqlx) // 如果是其它组件调用则是 ::db_common::sqlx,则使用 ::db_common::sqlx
};
// 检查枚举类型,并收集所有 unit variant
// unit variant = 没有字段的枚举,如 Pending、Approved、Rejected
let variants = if let syn::Data::Enum(data_enum) = &input.data {
// 遍历枚举的每个 variant
data_enum
.variants
.iter()
.map(|v| {
let ident = &v.ident; // variant 名字,例如 Pending
match &v.fields {
Fields::Unit => {
// 只支持无字段的枚举
let value = ident.to_string(); // 将 variant 名字转换为字符串,用于 FromStr/Serialize
(ident.clone(), value) // 返回元组 (variant 标识符, 对应字符串)
}
_ => panic!("SqlEnum only supports unit variants"), // 如果有字段,宏直接报错
}
})
.collect::<Vec<_>>() // 收集到 Vec,避免迭代器被移动
} else {
panic!("SqlEnum can only be derived on enums"); // 如果不是枚举,宏报错
};
// 构建 FromStr match 分支
// 用于把字符串解析成枚举,例如 "Pending" => Ok(Status::Pending)
let from_str_match: Vec<_> = variants
.iter()
.map(|(ident, value)| {
quote! {
x if x.eq_ignore_ascii_case(#value) => Ok(#name::#ident), //忽略 大小
}
})
.collect();
let lowercases: Vec<_> = variants
.iter()
.map(|(ident, value)| {
quote! {
#name::#ident => #value.to_ascii_lowercase(), //忽略 大小
}
})
.collect();
// 为 Deserialize 单独构建 match 分支,使用正确的变量名
let deserialize_match: Vec<_> = variants
.iter()
.map(|(ident, value)| {
quote! {
x if x.eq_ignore_ascii_case(#value) => Ok(#name::#ident),
}
}).collect();
// 使用 quote! 生成最终代码
// quote! 可以把 Rust 代码写成模板,变量用 #var 插入
let expanded = quote! {
// 实现 std::str::FromStr
impl std::str::FromStr for #name {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
#(#from_str_match)* // 插入所有 match 分支
_ => Err(format!("Invalid value for {}: {}", stringify!(#name), s)),
}
}
}
impl #name {
pub fn to_lower(&self) -> String {
match self {
#(#lowercases)* // 插入所有 match 分支
}
}
}
// 实现 Deserialize 当sqlx 解析son内容时会调用
// impl<'de> ::#crate_ident::serde::Deserialize<'de> for #name {
// fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
// where
// D: ::#crate_ident::serde::Deserializer<'de>,
// {
// // let s = String::deserialize(deserializer)?;
// let s = <String as ::#crate_ident::serde::Deserialize>::deserialize(deserializer)?;
// match s.as_str() {
// #(#deserialize_match)*
// _ => Err(::#crate_ident::serde::de::Error::custom(
// format!(concat!("Invalid value for ", stringify!($name), ": {}"), s))
// ),
// }
// }
// }
impl<'de> #serde_path::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: #serde_path::Deserializer<'de>,
{
let s = <String as #serde_path::Deserialize>::deserialize(deserializer)?;
match s.as_str() {
#(#deserialize_match)*
_ => Err(<D::Error as #serde_path::de::Error>::custom(
format!(
concat!("Invalid value for ", stringify!(#name), ": {}"),
s
)
)),
}
}
}
// SQLite
impl<'r> #sqlx_path::Decode<'r, #sqlx_path::Sqlite> for #name {
fn decode(value: #sqlx_path::sqlite::SqliteValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
// 先把数据库字段解析成 String
let value_str = <String as #sqlx_path::Decode<#sqlx_path::Sqlite>>::decode(value)?;
// 再用 FromStr 转成枚举
#name::from_str(&value_str).map_err(|e| e.into())
}
}
// MySQL
impl<'r>#sqlx_path::Decode<'r,::sqlx::MySql> for #name {
fn decode(value: #sqlx_path::mysql::MySqlValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
let value_str = <String as #sqlx_path::Decode<#sqlx_path::MySql>>::decode(value)?;
#name::from_str(&value_str).map_err(|e| e.into())
}
}
// PostgreSQL
impl<'r>#sqlx_path::Decode<'r,#sqlx_path::Postgres> for #name {
fn decode(value: #sqlx_path::postgres::PgValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
let value_str = <String as #sqlx_path::Decode<#sqlx_path::Postgres>>::decode(value)?;
#name::from_str(&value_str).map_err(|e| e.into())
}
}
// ==================::sqlx::Type impls ==================
impl #sqlx_path::Type<#sqlx_path::Sqlite> for #name {
fn type_info() ->#sqlx_path::sqlite::SqliteTypeInfo {
<String as #sqlx_path::Type<#sqlx_path::Sqlite>>::type_info() // 用 String 类型信息
}
}
impl #sqlx_path::Type<#sqlx_path::MySql> for #name {
fn type_info() ->#sqlx_path::mysql::MySqlTypeInfo {
<String as #sqlx_path::Type<#sqlx_path::MySql>>::type_info()
}
}
impl #sqlx_path::Type<#sqlx_path::Postgres> for #name {
fn type_info() ->#sqlx_path::postgres::PgTypeInfo {
<String as #sqlx_path::Type<#sqlx_path::Postgres>>::type_info()
}
}
};
// 将 quote! 生成的 TokenStream 返回给编译器
TokenStream::from(expanded)
}