sqlxplus_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6/// 去除原始标识符的 r# 前缀
7/// 例如:r#type -> type
8fn strip_raw_identifier_prefix(ident: &str) -> String {
9    if ident.starts_with("r#") {
10        ident[2..].to_string()
11    } else {
12        ident.to_string()
13    }
14}
15
16/// 解析字段的 column 属性,获取列名
17/// 如果指定了 name,使用指定的列名;否则使用去除 r# 前缀后的字段名
18fn parse_column_name(attrs: &[syn::Attribute], field_name: &str) -> String {
19    for attr in attrs {
20        if attr.path().is_ident("column") {
21            if let syn::Meta::List(list) = &attr.meta {
22                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
23                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
24                    for meta in metas {
25                        if let Meta::NameValue(nv) = meta {
26                            if nv.path.is_ident("name") {
27                                if let syn::Expr::Lit(syn::ExprLit {
28                                    lit: syn::Lit::Str(s),
29                                    ..
30                                }) = nv.value
31                                {
32                                    return s.value();
33                                }
34                            }
35                        }
36                    }
37                }
38            } else if let syn::Meta::NameValue(nv) = &attr.meta {
39                if nv.path.is_ident("name") {
40                    if let syn::Expr::Lit(syn::ExprLit {
41                        lit: syn::Lit::Str(s),
42                        ..
43                    }) = &nv.value
44                    {
45                        return s.value();
46                    }
47                }
48            }
49        }
50    }
51    // 如果没有指定 name,使用去除 r# 前缀后的字段名
52    strip_raw_identifier_prefix(field_name)
53}
54
55/// 生成 Model trait 的实现
56///
57/// 自动生成 `TABLE`、`PK` 和可选的 `SOFT_DELETE_FIELD` 常量
58///
59/// 使用示例:
60/// ```ignore
61/// // 物理删除模式(默认)
62/// #[derive(ModelMeta)]
63/// #[model(table = "users", pk = "id")]
64/// struct User {
65///     id: i64,
66///     name: String,
67/// }
68///
69/// // 逻辑删除模式
70/// #[derive(ModelMeta)]
71/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
72/// struct UserWithSoftDelete {
73///     id: i64,
74///     name: String,
75///     is_deleted: i32, // 逻辑删除字段:0=未删除,1=已删除
76/// }
77/// ```
78#[proc_macro_derive(ModelMeta, attributes(model, column))]
79pub fn derive_model_meta(input: TokenStream) -> TokenStream {
80    let input = parse_macro_input!(input as DeriveInput);
81    let name = &input.ident;
82
83    // 解析属性
84    let mut table_name = None;
85    let mut pk_field = None;
86    let mut soft_delete_field = None;
87
88    for attr in &input.attrs {
89        if attr.path().is_ident("model") {
90            // 在 syn 2.0 中,使用 meta() 方法获取元数据
91            if let syn::Meta::List(list) = &attr.meta {
92                // 解析列表中的每个 Meta::NameValue
93                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
94                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
95                    for meta in metas {
96                        if let Meta::NameValue(nv) = meta {
97                            if nv.path.is_ident("table") {
98                                if let syn::Expr::Lit(syn::ExprLit {
99                                    lit: syn::Lit::Str(s),
100                                    ..
101                                }) = nv.value
102                                {
103                                    table_name = Some(s.value());
104                                }
105                            } else if nv.path.is_ident("pk") {
106                                if let syn::Expr::Lit(syn::ExprLit {
107                                    lit: syn::Lit::Str(s),
108                                    ..
109                                }) = nv.value
110                                {
111                                    pk_field = Some(s.value());
112                                }
113                            } else if nv.path.is_ident("soft_delete") {
114                                if let syn::Expr::Lit(syn::ExprLit {
115                                    lit: syn::Lit::Str(s),
116                                    ..
117                                }) = nv.value
118                                {
119                                    soft_delete_field = Some(s.value());
120                                }
121                            }
122                        }
123                    }
124                }
125            } else if let syn::Meta::NameValue(nv) = &attr.meta {
126                // 单个 NameValue 的情况
127                if nv.path.is_ident("table") {
128                    if let syn::Expr::Lit(syn::ExprLit {
129                        lit: syn::Lit::Str(s),
130                        ..
131                    }) = &nv.value
132                    {
133                        table_name = Some(s.value());
134                    }
135                } else if nv.path.is_ident("pk") {
136                    if let syn::Expr::Lit(syn::ExprLit {
137                        lit: syn::Lit::Str(s),
138                        ..
139                    }) = &nv.value
140                    {
141                        pk_field = Some(s.value());
142                    }
143                } else if nv.path.is_ident("soft_delete") {
144                    if let syn::Expr::Lit(syn::ExprLit {
145                        lit: syn::Lit::Str(s),
146                        ..
147                    }) = &nv.value
148                    {
149                        soft_delete_field = Some(s.value());
150                    }
151                }
152            }
153        }
154    }
155
156    // 如果没有指定表名,使用结构体名称的小写蛇形命名方式
157    let table = table_name.unwrap_or_else(|| {
158        let s = name.to_string();
159        // 将 PascalCase 转换为 snake_case
160        let mut result = String::new();
161        for (i, c) in s.chars().enumerate() {
162            if c.is_uppercase() && i > 0 {
163                result.push('_');
164            }
165            result.push(c.to_ascii_lowercase());
166        }
167        result
168    });
169
170    // 如果没有指定主键,默认使用 "id"
171    let pk = pk_field.unwrap_or_else(|| "id".to_string());
172
173    // 生成实现代码
174    let expanded = if let Some(soft_delete) = soft_delete_field {
175        // 如果指定了逻辑删除字段,生成包含 SOFT_DELETE_FIELD 的实现
176        let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
177        quote! {
178            impl sqlxplus::Model for #name {
179                const TABLE: &'static str = #table;
180                const PK: &'static str = #pk;
181                const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
182            }
183        }
184    } else {
185        // 如果没有指定逻辑删除字段,SOFT_DELETE_FIELD 为 None
186        quote! {
187            impl sqlxplus::Model for #name {
188                const TABLE: &'static str = #table;
189                const PK: &'static str = #pk;
190                const SOFT_DELETE_FIELD: Option<&'static str> = None;
191            }
192        }
193    };
194
195    TokenStream::from(expanded)
196}
197
198/// 生成 CRUD trait 的实现
199///
200/// 自动生成 insert 和 update 方法的实现
201///
202/// 使用示例:
203/// ```ignore
204/// // 物理删除模式
205/// #[derive(CRUD, FromRow, ModelMeta)]
206/// #[model(table = "users", pk = "id")]
207/// struct User {
208///     id: i64,
209///     name: String,
210///     email: String,
211/// }
212///
213/// // 逻辑删除模式
214/// #[derive(CRUD, FromRow, ModelMeta)]
215/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
216/// struct UserWithSoftDelete {
217///     id: i64,
218///     name: String,
219///     email: String,
220///     is_deleted: i32, // 逻辑删除字段
221/// }
222/// ```
223#[proc_macro_derive(CRUD, attributes(model, skip, column))]
224pub fn derive_crud(input: TokenStream) -> TokenStream {
225    let input = parse_macro_input!(input as DeriveInput);
226    let name = &input.ident;
227
228    // 解析 #[model(pk = "...")],获取主键字段名,默认 "id"
229    let mut pk_field = None;
230    for attr in &input.attrs {
231        if attr.path().is_ident("model") {
232            if let syn::Meta::List(list) = &attr.meta {
233                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
234                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
235                    for meta in metas {
236                        if let Meta::NameValue(nv) = meta {
237                            if nv.path.is_ident("pk") {
238                                if let syn::Expr::Lit(syn::ExprLit {
239                                    lit: syn::Lit::Str(s),
240                                    ..
241                                }) = nv.value
242                                {
243                                    pk_field = Some(s.value());
244                                }
245                            }
246                        }
247                    }
248                }
249            } else if let syn::Meta::NameValue(nv) = &attr.meta {
250                if nv.path.is_ident("pk") {
251                    if let syn::Expr::Lit(syn::ExprLit {
252                        lit: syn::Lit::Str(s),
253                        ..
254                    }) = &nv.value
255                    {
256                        pk_field = Some(s.value());
257                    }
258                }
259            }
260        }
261    }
262    // 如果没有指定主键,默认使用 "id"
263    let pk = pk_field.unwrap_or_else(|| "id".to_string());
264
265    // 获取字段列表(必须是具名字段的结构体)
266    let fields = match &input.data {
267        Data::Struct(DataStruct {
268            fields: Fields::Named(fields),
269            ..
270        }) => &fields.named,
271        _ => {
272            return syn::Error::new_spanned(
273                name,
274                "CRUD derive only supports structs with named fields",
275            )
276            .to_compile_error()
277            .into();
278        }
279    };
280
281    // 收集字段信息
282    // - pk_ident: 主键字段 Ident
283    // - insert_*/update_*: 非主键字段(INSERT / UPDATE 使用)
284    let mut pk_ident_opt: Option<&syn::Ident> = None;
285
286    // INSERT 使用的字段(非主键)
287    let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
288    let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
289    let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
290    let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
291
292    // UPDATE 使用的字段(非主键)
293    let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
294    let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
295    let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
296    let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
297
298    // 用于 UpdateFields trait:只包含 BindValue 支持的类型
299    let mut update_fields_normal_field_names: Vec<&syn::Ident> = Vec::new();
300    let mut update_fields_normal_field_columns: Vec<syn::LitStr> = Vec::new();
301    let mut update_fields_option_field_names: Vec<&syn::Ident> = Vec::new();
302    let mut update_fields_option_field_columns: Vec<syn::LitStr> = Vec::new();
303    // 字段名(用于支持 r#type 这样的原始标识符,在 match 中同时匹配字段名和列名)
304    let mut update_fields_normal_field_name_strs: Vec<syn::LitStr> = Vec::new();
305    let mut update_fields_option_field_name_strs: Vec<syn::LitStr> = Vec::new();
306
307    for field in fields {
308        let field_name = field.ident.as_ref().unwrap();
309        let field_name_str = field_name.to_string();
310        // 获取数据库列名(处理 r#type 这样的原始标识符和 #[column(name = "...")] 属性)
311        let column_name = parse_column_name(&field.attrs, &field_name_str);
312
313        // 检查属性:skip / model
314        let mut skip = false;
315        for attr in &field.attrs {
316            if attr.path().is_ident("skip") || attr.path().is_ident("model") {
317                skip = true;
318                break;
319            }
320        }
321
322        if !skip {
323            if field_name_str == pk {
324                // 记录主键字段
325                pk_ident_opt = Some(field_name);
326                // 主键字段也需要添加到 UpdateFields,因为 UpdateBuilder 需要获取主键值
327                let is_opt = is_option_type(&field.ty);
328                let col_lit = syn::LitStr::new(&column_name, proc_macro2::Span::call_site());
329                let is_supported = if is_opt {
330                    if let Some(inner_ty) = get_option_inner_type(&field.ty) {
331                        is_bind_value_supported_type(inner_ty)
332                    } else {
333                        false
334                    }
335                } else {
336                    is_bind_value_supported_type(&field.ty)
337                };
338                // 如果主键类型是支持的类型,添加到 UpdateFields
339                if is_supported {
340                    let field_name_lit =
341                        syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
342                    if is_opt {
343                        update_fields_option_field_names.push(field_name);
344                        update_fields_option_field_columns.push(col_lit);
345                        update_fields_option_field_name_strs.push(field_name_lit);
346                    } else {
347                        update_fields_normal_field_names.push(field_name);
348                        update_fields_normal_field_columns.push(col_lit);
349                        update_fields_normal_field_name_strs.push(field_name_lit);
350                    }
351                }
352            } else {
353                // 非主键字段用于 INSERT / UPDATE
354                let is_opt = is_option_type(&field.ty);
355                let col_lit = syn::LitStr::new(&column_name, proc_macro2::Span::call_site());
356
357                // 检查是否是 BindValue 支持的类型
358                let is_supported = if is_opt {
359                    if let Some(inner_ty) = get_option_inner_type(&field.ty) {
360                        is_bind_value_supported_type(inner_ty)
361                    } else {
362                        false
363                    }
364                } else {
365                    is_bind_value_supported_type(&field.ty)
366                };
367
368                if is_opt {
369                    insert_option_field_names.push(field_name);
370                    insert_option_field_columns.push(col_lit.clone());
371
372                    update_option_field_names.push(field_name);
373                    update_option_field_columns.push(col_lit.clone());
374
375                    // 只为支持的类型添加到 UpdateFields
376                    if is_supported {
377                        let field_name_lit =
378                            syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
379                        update_fields_option_field_names.push(field_name);
380                        update_fields_option_field_columns.push(col_lit);
381                        update_fields_option_field_name_strs.push(field_name_lit);
382                    }
383                } else {
384                    insert_normal_field_names.push(field_name);
385                    insert_normal_field_columns.push(col_lit.clone());
386
387                    update_normal_field_names.push(field_name);
388                    update_normal_field_columns.push(col_lit.clone());
389
390                    // 只为支持的类型添加到 UpdateFields
391                    if is_supported {
392                        let field_name_lit =
393                            syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
394                        update_fields_normal_field_names.push(field_name);
395                        update_fields_normal_field_columns.push(col_lit);
396                        update_fields_normal_field_name_strs.push(field_name_lit);
397                    }
398                }
399            }
400        }
401    }
402
403    // 编译期确保主键字段存在
404    let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
405
406    // 生成实现代码
407    let expanded = quote! {
408        // Trait 方法实现
409        #[async_trait::async_trait]
410        impl sqlxplus::Crud for #name {
411            // 泛型版本的 insert(自动类型推断)
412            async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
413            where
414                DB: sqlx::Database + sqlxplus::DatabaseInfo,
415                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
416                E: sqlxplus::DatabaseType<DB = DB>
417                    + sqlx::Executor<'c, Database = DB>
418                    + Send,
419                i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
420                usize: sqlx::ColumnIndex<DB::Row>,
421                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
422                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
423                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
424                i64: for<'b> sqlx::Encode<'b, DB>,
425                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
426                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
427                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
428                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
429                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
430                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
431                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
432                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
433                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
434                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
435                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
436                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
437                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
438                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
439                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
440                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
441                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
442                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
443                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
444                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
445                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
446                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
447                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
448                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
449            {
450                use sqlxplus::Model;
451                use sqlxplus::DatabaseInfo;
452                use sqlxplus::db_pool::DbDriver;
453                let table = Self::TABLE;
454                let escaped_table = DB::escape_identifier(table);
455
456                // 构建列名和占位符
457                let mut columns: Vec<&str> = Vec::new();
458                let mut placeholders: Vec<String> = Vec::new();
459                let mut placeholder_index = 0;
460
461                // 非 Option 字段:始终参与 INSERT
462                #(
463                    columns.push(#insert_normal_field_columns);
464                    placeholders.push(DB::placeholder(placeholder_index));
465                    placeholder_index += 1;
466                )*
467
468                // Option 字段:仅当为 Some 时参与 INSERT
469                #(
470                    if self.#insert_option_field_names.is_some() {
471                        columns.push(#insert_option_field_columns);
472                        placeholders.push(DB::placeholder(placeholder_index));
473                        placeholder_index += 1;
474                    }
475                )*
476
477                // 根据数据库类型构建 SQL
478                let sql = match DB::get_driver() {
479                    DbDriver::Postgres => {
480                        let pk = Self::PK;
481                        let escaped_pk = DB::escape_identifier(pk);
482                        format!(
483                            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
484                            escaped_table,
485                            columns.join(", "),
486                            placeholders.join(", "),
487                            escaped_pk
488                        )
489                    }
490                    _ => {
491                        format!(
492                            "INSERT INTO {} ({}) VALUES ({})",
493                            escaped_table,
494                            columns.join(", "),
495                            placeholders.join(", ")
496                        )
497                    }
498                };
499
500                // 根据数据库类型执行查询
501                match DB::get_driver() {
502                    DbDriver::Postgres => {
503                        let mut query = sqlx::query_scalar::<_, i64>(&sql);
504                        // 非 Option 字段:始终绑定
505                        #(
506                            query = query.bind(&self.#insert_normal_field_names);
507                        )*
508                        // Option 字段:仅当为 Some 时绑定
509                        #(
510                            if let Some(ref val) = self.#insert_option_field_names {
511                                query = query.bind(val);
512                            }
513                        )*
514                        let id: i64 = query.fetch_one(executor).await?;
515                        Ok(id)
516                    }
517                    DbDriver::MySql => {
518                        let mut query = sqlx::query(&sql);
519                        // 非 Option 字段:始终绑定
520                        #(
521                            query = query.bind(&self.#insert_normal_field_names);
522                        )*
523                        // Option 字段:仅当为 Some 时绑定
524                        #(
525                            if let Some(ref val) = self.#insert_option_field_names {
526                                query = query.bind(val);
527                            }
528                        )*
529                        let result = query.execute(executor).await?;
530                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
531                        // 这是安全的,因为我们已经通过 DB::get_driver() 确认了数据库类型
532                        // 并且我们知道 DB = MySql,所以 result 的类型是 MySqlQueryResult
533                        unsafe {
534                            use sqlx::mysql::MySqlQueryResult;
535                            let ptr: *const DB::QueryResult = &result;
536                            let mysql_ptr = ptr as *const MySqlQueryResult;
537                            Ok((*mysql_ptr).last_insert_id() as i64)
538                        }
539                    }
540                    DbDriver::Sqlite => {
541                        let mut query = sqlx::query(&sql);
542                        // 非 Option 字段:始终绑定
543                        #(
544                            query = query.bind(&self.#insert_normal_field_names);
545                        )*
546                        // Option 字段:仅当为 Some 时绑定
547                        #(
548                            if let Some(ref val) = self.#insert_option_field_names {
549                                query = query.bind(val);
550                            }
551                        )*
552                        let result = query.execute(executor).await?;
553                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
554                        unsafe {
555                            use sqlx::sqlite::SqliteQueryResult;
556                            let ptr: *const DB::QueryResult = &result;
557                            let sqlite_ptr = ptr as *const SqliteQueryResult;
558                            Ok((*sqlite_ptr).last_insert_rowid() as i64)
559                        }
560                    }
561                }
562            }
563
564            // 泛型版本的 update(自动类型推断)
565            async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
566            where
567                DB: sqlx::Database + sqlxplus::DatabaseInfo,
568                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
569                E: sqlxplus::DatabaseType<DB = DB>
570                    + sqlx::Executor<'c, Database = DB>
571                    + Send,
572                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
573                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
574                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
575                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
576                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
577                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
578                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
579                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
580                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
581                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
582                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
583                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
584                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
585                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
586                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
587                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
588                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
589                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
590                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
591                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
592                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
593                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
594                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
595                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
596                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
597                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
598                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
599                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
600            {
601                use sqlxplus::Model;
602                use sqlxplus::DatabaseInfo;
603                let table = Self::TABLE;
604                let pk = Self::PK;
605                let escaped_table = DB::escape_identifier(table);
606                let escaped_pk = DB::escape_identifier(pk);
607
608                // 构建 UPDATE SET 子句(Patch 语义)
609                let mut set_parts: Vec<String> = Vec::new();
610                let mut placeholder_index = 0;
611
612                // 非 Option 字段
613                #(
614                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
615                    placeholder_index += 1;
616                )*
617
618                // Option 字段
619                #(
620                    if self.#update_option_field_names.is_some() {
621                        set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
622                        placeholder_index += 1;
623                    }
624                )*
625
626                if set_parts.is_empty() {
627                    return Ok(());
628                }
629
630                let sql = format!(
631                    "UPDATE {} SET {} WHERE {} = {}",
632                    escaped_table,
633                    set_parts.join(", "),
634                    escaped_pk,
635                    DB::placeholder(placeholder_index)
636                );
637
638                let mut query = sqlx::query(&sql);
639                // 非 Option 字段:始终绑定
640                #(
641                    query = query.bind(&self.#update_normal_field_names);
642                )*
643                // Option 字段:仅当为 Some 时绑定
644                #(
645                    if let Some(ref val) = self.#update_option_field_names {
646                        query = query.bind(val);
647                    }
648                )*
649                query = query.bind(&self.#pk_ident);
650                query.execute(executor).await?;
651                Ok(())
652            }
653
654            // 泛型版本的 update_with_none(自动类型推断)
655            async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
656            where
657                DB: sqlx::Database + sqlxplus::DatabaseInfo,
658                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
659                E: sqlxplus::DatabaseType<DB = DB>
660                    + sqlx::Executor<'c, Database = DB>
661                    + Send,
662                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
663                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
664                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
665                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
666                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
667                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
668                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
669                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
670                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
671                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
672                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
673                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
674                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
675                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
676                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
677                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
678                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
679                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
680                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
681                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
682                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
683                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
684                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
685                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
686                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
687                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
688                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
689                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
690            {
691                use sqlxplus::Model;
692                use sqlxplus::DatabaseInfo;
693                use sqlxplus::db_pool::DbDriver;
694                let table = Self::TABLE;
695                let pk = Self::PK;
696                let escaped_table = DB::escape_identifier(table);
697                let escaped_pk = DB::escape_identifier(pk);
698
699                // 构建 UPDATE SET 子句(Reset 语义)
700                let mut set_parts: Vec<String> = Vec::new();
701                let mut placeholder_index = 0;
702
703                // 非 Option 字段:始终更新为当前值
704                #(
705                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
706                    placeholder_index += 1;
707                )*
708
709                // Option 字段:根据数据库类型处理
710                match DB::get_driver() {
711                    DbDriver::Sqlite => {
712                        // SQLite 不支持 DEFAULT,跳过 None 字段
713                        #(
714                            if self.#update_option_field_names.is_some() {
715                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
716                                placeholder_index += 1;
717                            }
718                        )*
719                    }
720                    _ => {
721                        // MySQL 和 PostgreSQL 使用 DEFAULT
722                        #(
723                            if self.#update_option_field_names.is_some() {
724                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
725                                placeholder_index += 1;
726                            } else {
727                                set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
728                            }
729                        )*
730                    }
731                }
732
733                if set_parts.is_empty() {
734                    return Ok(());
735                }
736
737                let sql = format!(
738                    "UPDATE {} SET {} WHERE {} = {}",
739                    escaped_table,
740                    set_parts.join(", "),
741                    escaped_pk,
742                    DB::placeholder(placeholder_index)
743                );
744
745                let mut query = sqlx::query(&sql);
746                // 非 Option 字段:始终绑定
747                #(
748                    query = query.bind(&self.#update_normal_field_names);
749                )*
750                // Option 字段:仅当为 Some 时绑定(None 使用 DEFAULT 或跳过)
751                #(
752                    if let Some(ref val) = self.#update_option_field_names {
753                        query = query.bind(val);
754                    }
755                )*
756                query = query.bind(&self.#pk_ident);
757                query.execute(executor).await?;
758                Ok(())
759            }
760        }
761    };
762
763    // 生成 UpdateFields trait 实现(用于 UpdateBuilder 和 InsertBuilder)
764    // 注意:只对 BindValue 支持的基本类型生成转换代码
765    // 对于复杂类型(如 DateTime、JsonValue 等),get_field_value 返回 None
766    // InsertBuilder 和 UpdateBuilder 需要直接使用 sqlx::bind 来处理这些类型
767    let update_fields_impl = quote! {
768        impl sqlxplus::builder::update_builder::UpdateFields for #name {
769            fn get_field_value(&self, field_name: &str) -> Option<sqlxplus::builder::query_builder::BindValue> {
770                match field_name {
771                    // 支持字段名和列名两种匹配方式(处理 r#type 这样的原始标识符)
772                    #(
773                        #update_fields_normal_field_columns | #update_fields_normal_field_name_strs => {
774                            // 对于非 Option 类型,转换为 BindValue(只包含支持的类型)
775                            Some(sqlxplus::builder::query_builder::BindValue::from(self.#update_fields_normal_field_names.clone()))
776                        }
777                    )*
778                    #(
779                        #update_fields_option_field_columns | #update_fields_option_field_name_strs => {
780                            // 对于 Option 类型,如果是 Some 则转换,None 则返回 None(只包含支持的类型)
781                            self.#update_fields_option_field_names.as_ref().map(|v| {
782                                sqlxplus::builder::query_builder::BindValue::from(v.clone())
783                            })
784                        }
785                    )*
786                    _ => None, // 不支持的类型或未包含的字段返回 None
787                }
788            }
789
790            fn get_all_field_names() -> &'static [&'static str] {
791                &[
792                    #(#update_normal_field_columns,)*
793                    #(#update_option_field_columns,)*
794                ]
795            }
796
797            fn has_field(field_name: &str) -> bool {
798                // 支持字段名和列名两种匹配方式(处理 r#type 这样的原始标识符)
799                #(
800                    if field_name == #update_normal_field_columns || field_name == #update_fields_normal_field_name_strs {
801                        return true;
802                    }
803                )*
804                #(
805                    if field_name == #update_option_field_columns || field_name == #update_fields_option_field_name_strs {
806                        return true;
807                    }
808                )*
809                false
810            }
811        }
812    };
813
814    let expanded = quote! {
815        #expanded
816        #update_fields_impl
817    };
818
819    TokenStream::from(expanded)
820}
821
822/// 判断字段类型是否为 Option<T>
823fn is_option_type(ty: &syn::Type) -> bool {
824    if let syn::Type::Path(type_path) = ty {
825        if let Some(seg) = type_path.path.segments.last() {
826            if seg.ident == "Option" {
827                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
828                    return args.args.len() == 1;
829                }
830            }
831        }
832    }
833    false
834}
835
836/// 检查类型是否是 BindValue 支持的基本类型
837/// 支持的类型:String, i64, i32, i16, f64, f32, bool, Vec<u8>
838fn is_bind_value_supported_type(ty: &syn::Type) -> bool {
839    if let syn::Type::Path(type_path) = ty {
840        if let Some(seg) = type_path.path.segments.last() {
841            let type_name = seg.ident.to_string();
842            // 检查是否是支持的基本类型
843            match type_name.as_str() {
844                "String" | "i64" | "i32" | "i16" | "f64" | "f32" | "bool" => true,
845                "Vec" => {
846                    // 对于 Vec,检查是否是 Vec<u8>
847                    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
848                        if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
849                            if let syn::Type::Path(inner_path) = inner_ty {
850                                if let Some(inner_seg) = inner_path.path.segments.last() {
851                                    return inner_seg.ident == "u8";
852                                }
853                            }
854                        }
855                    }
856                    false
857                }
858                _ => false,
859            }
860        } else {
861            false
862        }
863    } else {
864        false
865    }
866}
867
868/// 获取 Option 内部的类型
869fn get_option_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
870    if let syn::Type::Path(type_path) = ty {
871        if let Some(seg) = type_path.path.segments.last() {
872            if seg.ident == "Option" {
873                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
874                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
875                        return Some(inner_ty);
876                    }
877                }
878            }
879        }
880    }
881    None
882}