1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
/// Removes a leading `r#` raw-identifier marker, if present.
///
/// `syn::Ident::to_string()` renders raw identifiers such as `r#type` with the
/// marker included. Column names must not contain it, so `"r#type"` becomes
/// `"type"`. Any other input is returned unchanged as an owned `String`.
fn strip_raw_identifier_prefix(ident: &str) -> String {
    ident.strip_prefix("r#").unwrap_or(ident).to_string()
}
15
16fn parse_column_name(attrs: &[syn::Attribute], field_name: &str) -> String {
19 for attr in attrs {
20 if attr.path().is_ident("column") {
21 if let syn::Meta::List(list) = &attr.meta {
22 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
23 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
24 for meta in metas {
25 if let Meta::NameValue(nv) = meta {
26 if nv.path.is_ident("name") {
27 if let syn::Expr::Lit(syn::ExprLit {
28 lit: syn::Lit::Str(s),
29 ..
30 }) = nv.value
31 {
32 return s.value();
33 }
34 }
35 }
36 }
37 }
38 } else if let syn::Meta::NameValue(nv) = &attr.meta {
39 if nv.path.is_ident("name") {
40 if let syn::Expr::Lit(syn::ExprLit {
41 lit: syn::Lit::Str(s),
42 ..
43 }) = &nv.value
44 {
45 return s.value();
46 }
47 }
48 }
49 }
50 }
51 strip_raw_identifier_prefix(field_name)
53}
54
55#[proc_macro_derive(ModelMeta, attributes(model, column))]
79pub fn derive_model_meta(input: TokenStream) -> TokenStream {
80 let input = parse_macro_input!(input as DeriveInput);
81 let name = &input.ident;
82
83 let mut table_name = None;
85 let mut pk_field = None;
86 let mut soft_delete_field = None;
87
88 for attr in &input.attrs {
89 if attr.path().is_ident("model") {
90 if let syn::Meta::List(list) = &attr.meta {
92 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
94 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
95 for meta in metas {
96 if let Meta::NameValue(nv) = meta {
97 if nv.path.is_ident("table") {
98 if let syn::Expr::Lit(syn::ExprLit {
99 lit: syn::Lit::Str(s),
100 ..
101 }) = nv.value
102 {
103 table_name = Some(s.value());
104 }
105 } else if nv.path.is_ident("pk") {
106 if let syn::Expr::Lit(syn::ExprLit {
107 lit: syn::Lit::Str(s),
108 ..
109 }) = nv.value
110 {
111 pk_field = Some(s.value());
112 }
113 } else if nv.path.is_ident("soft_delete") {
114 if let syn::Expr::Lit(syn::ExprLit {
115 lit: syn::Lit::Str(s),
116 ..
117 }) = nv.value
118 {
119 soft_delete_field = Some(s.value());
120 }
121 }
122 }
123 }
124 }
125 } else if let syn::Meta::NameValue(nv) = &attr.meta {
126 if nv.path.is_ident("table") {
128 if let syn::Expr::Lit(syn::ExprLit {
129 lit: syn::Lit::Str(s),
130 ..
131 }) = &nv.value
132 {
133 table_name = Some(s.value());
134 }
135 } else if nv.path.is_ident("pk") {
136 if let syn::Expr::Lit(syn::ExprLit {
137 lit: syn::Lit::Str(s),
138 ..
139 }) = &nv.value
140 {
141 pk_field = Some(s.value());
142 }
143 } else if nv.path.is_ident("soft_delete") {
144 if let syn::Expr::Lit(syn::ExprLit {
145 lit: syn::Lit::Str(s),
146 ..
147 }) = &nv.value
148 {
149 soft_delete_field = Some(s.value());
150 }
151 }
152 }
153 }
154 }
155
156 let table = table_name.unwrap_or_else(|| {
158 let s = name.to_string();
159 let mut result = String::new();
161 for (i, c) in s.chars().enumerate() {
162 if c.is_uppercase() && i > 0 {
163 result.push('_');
164 }
165 result.push(c.to_ascii_lowercase());
166 }
167 result
168 });
169
170 let pk = pk_field.unwrap_or_else(|| "id".to_string());
172
173 let expanded = if let Some(soft_delete) = soft_delete_field {
175 let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
177 quote! {
178 impl sqlxplus::Model for #name {
179 const TABLE: &'static str = #table;
180 const PK: &'static str = #pk;
181 const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
182 }
183 }
184 } else {
185 quote! {
187 impl sqlxplus::Model for #name {
188 const TABLE: &'static str = #table;
189 const PK: &'static str = #pk;
190 const SOFT_DELETE_FIELD: Option<&'static str> = None;
191 }
192 }
193 };
194
195 TokenStream::from(expanded)
196}
197
198#[proc_macro_derive(CRUD, attributes(model, skip, column))]
224pub fn derive_crud(input: TokenStream) -> TokenStream {
225 let input = parse_macro_input!(input as DeriveInput);
226 let name = &input.ident;
227
228 let mut pk_field = None;
230 for attr in &input.attrs {
231 if attr.path().is_ident("model") {
232 if let syn::Meta::List(list) = &attr.meta {
233 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
234 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
235 for meta in metas {
236 if let Meta::NameValue(nv) = meta {
237 if nv.path.is_ident("pk") {
238 if let syn::Expr::Lit(syn::ExprLit {
239 lit: syn::Lit::Str(s),
240 ..
241 }) = nv.value
242 {
243 pk_field = Some(s.value());
244 }
245 }
246 }
247 }
248 }
249 } else if let syn::Meta::NameValue(nv) = &attr.meta {
250 if nv.path.is_ident("pk") {
251 if let syn::Expr::Lit(syn::ExprLit {
252 lit: syn::Lit::Str(s),
253 ..
254 }) = &nv.value
255 {
256 pk_field = Some(s.value());
257 }
258 }
259 }
260 }
261 }
262 let pk = pk_field.unwrap_or_else(|| "id".to_string());
264
265 let fields = match &input.data {
267 Data::Struct(DataStruct {
268 fields: Fields::Named(fields),
269 ..
270 }) => &fields.named,
271 _ => {
272 return syn::Error::new_spanned(
273 name,
274 "CRUD derive only supports structs with named fields",
275 )
276 .to_compile_error()
277 .into();
278 }
279 };
280
281 let mut pk_ident_opt: Option<&syn::Ident> = None;
285
286 let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
288 let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
289 let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
290 let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
291
292 let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
294 let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
295 let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
296 let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
297
298 let mut update_fields_normal_field_names: Vec<&syn::Ident> = Vec::new();
300 let mut update_fields_normal_field_columns: Vec<syn::LitStr> = Vec::new();
301 let mut update_fields_option_field_names: Vec<&syn::Ident> = Vec::new();
302 let mut update_fields_option_field_columns: Vec<syn::LitStr> = Vec::new();
303 let mut update_fields_normal_field_name_strs: Vec<syn::LitStr> = Vec::new();
305 let mut update_fields_option_field_name_strs: Vec<syn::LitStr> = Vec::new();
306
307 for field in fields {
308 let field_name = field.ident.as_ref().unwrap();
309 let field_name_str = field_name.to_string();
310 let column_name = parse_column_name(&field.attrs, &field_name_str);
312
313 let mut skip = false;
315 for attr in &field.attrs {
316 if attr.path().is_ident("skip") || attr.path().is_ident("model") {
317 skip = true;
318 break;
319 }
320 }
321
322 if !skip {
323 if field_name_str == pk {
324 pk_ident_opt = Some(field_name);
326 let is_opt = is_option_type(&field.ty);
328 let col_lit = syn::LitStr::new(&column_name, proc_macro2::Span::call_site());
329 let is_supported = if is_opt {
330 if let Some(inner_ty) = get_option_inner_type(&field.ty) {
331 is_bind_value_supported_type(inner_ty)
332 } else {
333 false
334 }
335 } else {
336 is_bind_value_supported_type(&field.ty)
337 };
338 if is_supported {
340 let field_name_lit =
341 syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
342 if is_opt {
343 update_fields_option_field_names.push(field_name);
344 update_fields_option_field_columns.push(col_lit);
345 update_fields_option_field_name_strs.push(field_name_lit);
346 } else {
347 update_fields_normal_field_names.push(field_name);
348 update_fields_normal_field_columns.push(col_lit);
349 update_fields_normal_field_name_strs.push(field_name_lit);
350 }
351 }
352 } else {
353 let is_opt = is_option_type(&field.ty);
355 let col_lit = syn::LitStr::new(&column_name, proc_macro2::Span::call_site());
356
357 let is_supported = if is_opt {
359 if let Some(inner_ty) = get_option_inner_type(&field.ty) {
360 is_bind_value_supported_type(inner_ty)
361 } else {
362 false
363 }
364 } else {
365 is_bind_value_supported_type(&field.ty)
366 };
367
368 if is_opt {
369 insert_option_field_names.push(field_name);
370 insert_option_field_columns.push(col_lit.clone());
371
372 update_option_field_names.push(field_name);
373 update_option_field_columns.push(col_lit.clone());
374
375 if is_supported {
377 let field_name_lit =
378 syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
379 update_fields_option_field_names.push(field_name);
380 update_fields_option_field_columns.push(col_lit);
381 update_fields_option_field_name_strs.push(field_name_lit);
382 }
383 } else {
384 insert_normal_field_names.push(field_name);
385 insert_normal_field_columns.push(col_lit.clone());
386
387 update_normal_field_names.push(field_name);
388 update_normal_field_columns.push(col_lit.clone());
389
390 if is_supported {
392 let field_name_lit =
393 syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
394 update_fields_normal_field_names.push(field_name);
395 update_fields_normal_field_columns.push(col_lit);
396 update_fields_normal_field_name_strs.push(field_name_lit);
397 }
398 }
399 }
400 }
401 }
402
403 let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
405
406 let expanded = quote! {
408 #[async_trait::async_trait]
410 impl sqlxplus::Crud for #name {
411 async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
413 where
414 DB: sqlx::Database + sqlxplus::DatabaseInfo,
415 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
416 E: sqlxplus::DatabaseType<DB = DB>
417 + sqlx::Executor<'c, Database = DB>
418 + Send,
419 i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
420 usize: sqlx::ColumnIndex<DB::Row>,
421 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
424 i64: for<'b> sqlx::Encode<'b, DB>,
425 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
426 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
427 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
428 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
429 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
430 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
431 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
432 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
433 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
434 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
435 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
436 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
437 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
438 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
439 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
440 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
441 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
442 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
443 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
444 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
445 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
446 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
447 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
448 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
449 {
450 use sqlxplus::Model;
451 use sqlxplus::DatabaseInfo;
452 use sqlxplus::db_pool::DbDriver;
453 let table = Self::TABLE;
454 let escaped_table = DB::escape_identifier(table);
455
456 let mut columns: Vec<&str> = Vec::new();
458 let mut placeholders: Vec<String> = Vec::new();
459 let mut placeholder_index = 0;
460
461 #(
463 columns.push(#insert_normal_field_columns);
464 placeholders.push(DB::placeholder(placeholder_index));
465 placeholder_index += 1;
466 )*
467
468 #(
470 if self.#insert_option_field_names.is_some() {
471 columns.push(#insert_option_field_columns);
472 placeholders.push(DB::placeholder(placeholder_index));
473 placeholder_index += 1;
474 }
475 )*
476
477 let sql = match DB::get_driver() {
479 DbDriver::Postgres => {
480 let pk = Self::PK;
481 let escaped_pk = DB::escape_identifier(pk);
482 format!(
483 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
484 escaped_table,
485 columns.join(", "),
486 placeholders.join(", "),
487 escaped_pk
488 )
489 }
490 _ => {
491 format!(
492 "INSERT INTO {} ({}) VALUES ({})",
493 escaped_table,
494 columns.join(", "),
495 placeholders.join(", ")
496 )
497 }
498 };
499
500 match DB::get_driver() {
502 DbDriver::Postgres => {
503 let mut query = sqlx::query_scalar::<_, i64>(&sql);
504 #(
506 query = query.bind(&self.#insert_normal_field_names);
507 )*
508 #(
510 if let Some(ref val) = self.#insert_option_field_names {
511 query = query.bind(val);
512 }
513 )*
514 let id: i64 = query.fetch_one(executor).await?;
515 Ok(id)
516 }
517 DbDriver::MySql => {
518 let mut query = sqlx::query(&sql);
519 #(
521 query = query.bind(&self.#insert_normal_field_names);
522 )*
523 #(
525 if let Some(ref val) = self.#insert_option_field_names {
526 query = query.bind(val);
527 }
528 )*
529 let result = query.execute(executor).await?;
530 unsafe {
534 use sqlx::mysql::MySqlQueryResult;
535 let ptr: *const DB::QueryResult = &result;
536 let mysql_ptr = ptr as *const MySqlQueryResult;
537 Ok((*mysql_ptr).last_insert_id() as i64)
538 }
539 }
540 DbDriver::Sqlite => {
541 let mut query = sqlx::query(&sql);
542 #(
544 query = query.bind(&self.#insert_normal_field_names);
545 )*
546 #(
548 if let Some(ref val) = self.#insert_option_field_names {
549 query = query.bind(val);
550 }
551 )*
552 let result = query.execute(executor).await?;
553 unsafe {
555 use sqlx::sqlite::SqliteQueryResult;
556 let ptr: *const DB::QueryResult = &result;
557 let sqlite_ptr = ptr as *const SqliteQueryResult;
558 Ok((*sqlite_ptr).last_insert_rowid() as i64)
559 }
560 }
561 }
562 }
563
564 async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
566 where
567 DB: sqlx::Database + sqlxplus::DatabaseInfo,
568 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
569 E: sqlxplus::DatabaseType<DB = DB>
570 + sqlx::Executor<'c, Database = DB>
571 + Send,
572 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
575 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
576 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
577 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
578 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
579 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
580 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
581 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
582 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
583 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
584 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
585 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
586 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
587 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
588 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
589 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
590 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
591 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
592 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
593 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
594 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
595 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
596 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
597 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
598 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
599 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
600 {
601 use sqlxplus::Model;
602 use sqlxplus::DatabaseInfo;
603 let table = Self::TABLE;
604 let pk = Self::PK;
605 let escaped_table = DB::escape_identifier(table);
606 let escaped_pk = DB::escape_identifier(pk);
607
608 let mut set_parts: Vec<String> = Vec::new();
610 let mut placeholder_index = 0;
611
612 #(
614 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
615 placeholder_index += 1;
616 )*
617
618 #(
620 if self.#update_option_field_names.is_some() {
621 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
622 placeholder_index += 1;
623 }
624 )*
625
626 if set_parts.is_empty() {
627 return Ok(());
628 }
629
630 let sql = format!(
631 "UPDATE {} SET {} WHERE {} = {}",
632 escaped_table,
633 set_parts.join(", "),
634 escaped_pk,
635 DB::placeholder(placeholder_index)
636 );
637
638 let mut query = sqlx::query(&sql);
639 #(
641 query = query.bind(&self.#update_normal_field_names);
642 )*
643 #(
645 if let Some(ref val) = self.#update_option_field_names {
646 query = query.bind(val);
647 }
648 )*
649 query = query.bind(&self.#pk_ident);
650 query.execute(executor).await?;
651 Ok(())
652 }
653
654 async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
656 where
657 DB: sqlx::Database + sqlxplus::DatabaseInfo,
658 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
659 E: sqlxplus::DatabaseType<DB = DB>
660 + sqlx::Executor<'c, Database = DB>
661 + Send,
662 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
665 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
666 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
667 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
668 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
669 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
670 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
671 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
672 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
673 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
674 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
675 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
676 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
677 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
678 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
679 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
680 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
681 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
682 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
683 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
684 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
685 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
686 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
687 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
688 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
689 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
690 {
691 use sqlxplus::Model;
692 use sqlxplus::DatabaseInfo;
693 use sqlxplus::db_pool::DbDriver;
694 let table = Self::TABLE;
695 let pk = Self::PK;
696 let escaped_table = DB::escape_identifier(table);
697 let escaped_pk = DB::escape_identifier(pk);
698
699 let mut set_parts: Vec<String> = Vec::new();
701 let mut placeholder_index = 0;
702
703 #(
705 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
706 placeholder_index += 1;
707 )*
708
709 match DB::get_driver() {
711 DbDriver::Sqlite => {
712 #(
714 if self.#update_option_field_names.is_some() {
715 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
716 placeholder_index += 1;
717 }
718 )*
719 }
720 _ => {
721 #(
723 if self.#update_option_field_names.is_some() {
724 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
725 placeholder_index += 1;
726 } else {
727 set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
728 }
729 )*
730 }
731 }
732
733 if set_parts.is_empty() {
734 return Ok(());
735 }
736
737 let sql = format!(
738 "UPDATE {} SET {} WHERE {} = {}",
739 escaped_table,
740 set_parts.join(", "),
741 escaped_pk,
742 DB::placeholder(placeholder_index)
743 );
744
745 let mut query = sqlx::query(&sql);
746 #(
748 query = query.bind(&self.#update_normal_field_names);
749 )*
750 #(
752 if let Some(ref val) = self.#update_option_field_names {
753 query = query.bind(val);
754 }
755 )*
756 query = query.bind(&self.#pk_ident);
757 query.execute(executor).await?;
758 Ok(())
759 }
760 }
761 };
762
763 let update_fields_impl = quote! {
768 impl sqlxplus::builder::update_builder::UpdateFields for #name {
769 fn get_field_value(&self, field_name: &str) -> Option<sqlxplus::builder::query_builder::BindValue> {
770 match field_name {
771 #(
773 #update_fields_normal_field_columns | #update_fields_normal_field_name_strs => {
774 Some(sqlxplus::builder::query_builder::BindValue::from(self.#update_fields_normal_field_names.clone()))
776 }
777 )*
778 #(
779 #update_fields_option_field_columns | #update_fields_option_field_name_strs => {
780 self.#update_fields_option_field_names.as_ref().map(|v| {
782 sqlxplus::builder::query_builder::BindValue::from(v.clone())
783 })
784 }
785 )*
786 _ => None, }
788 }
789
790 fn get_all_field_names() -> &'static [&'static str] {
791 &[
792 #(#update_normal_field_columns,)*
793 #(#update_option_field_columns,)*
794 ]
795 }
796
797 fn has_field(field_name: &str) -> bool {
798 #(
800 if field_name == #update_normal_field_columns || field_name == #update_fields_normal_field_name_strs {
801 return true;
802 }
803 )*
804 #(
805 if field_name == #update_option_field_columns || field_name == #update_fields_option_field_name_strs {
806 return true;
807 }
808 )*
809 false
810 }
811 }
812 };
813
814 let expanded = quote! {
815 #expanded
816 #update_fields_impl
817 };
818
819 TokenStream::from(expanded)
820}
821
822fn is_option_type(ty: &syn::Type) -> bool {
824 if let syn::Type::Path(type_path) = ty {
825 if let Some(seg) = type_path.path.segments.last() {
826 if seg.ident == "Option" {
827 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
828 return args.args.len() == 1;
829 }
830 }
831 }
832 }
833 false
834}
835
836fn is_bind_value_supported_type(ty: &syn::Type) -> bool {
839 if let syn::Type::Path(type_path) = ty {
840 if let Some(seg) = type_path.path.segments.last() {
841 let type_name = seg.ident.to_string();
842 match type_name.as_str() {
844 "String" | "i64" | "i32" | "i16" | "f64" | "f32" | "bool" => true,
845 "Vec" => {
846 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
848 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
849 if let syn::Type::Path(inner_path) = inner_ty {
850 if let Some(inner_seg) = inner_path.path.segments.last() {
851 return inner_seg.ident == "u8";
852 }
853 }
854 }
855 }
856 false
857 }
858 _ => false,
859 }
860 } else {
861 false
862 }
863 } else {
864 false
865 }
866}
867
868fn get_option_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
870 if let syn::Type::Path(type_path) = ty {
871 if let Some(seg) = type_path.path.segments.last() {
872 if seg.ident == "Option" {
873 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
874 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
875 return Some(inner_ty);
876 }
877 }
878 }
879 }
880 }
881 None
882}