1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 Attribute, Data, DeriveInput, Field, Fields, Ident, Token, parse_macro_input,
5 punctuated::Punctuated,
6};
7
8mod relations;
9
10#[proc_macro_derive(Model, attributes(has_many, belongs_to, premix))]
11pub fn derive_model(input: TokenStream) -> TokenStream {
12 let input = parse_macro_input!(input as DeriveInput);
13 match derive_model_impl(&input) {
14 Ok(tokens) => TokenStream::from(tokens),
15 Err(err) => TokenStream::from(err.to_compile_error()),
16 }
17}
18
19fn derive_model_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
20 let impl_block = generate_generic_impl(input)?;
21 let rel_block = relations::impl_relations(input)?;
22 Ok(quote! {
23 #impl_block
24 #rel_block
25 })
26}
27
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Expands the generic impl for `input` and renders it as a token string.
    fn expand(input: &DeriveInput) -> String {
        let tokens = generate_generic_impl(input).unwrap();
        tokens.to_string()
    }

    /// Expands the generic impl for `input`, expecting a failure, and returns
    /// the error message.
    fn expand_error(input: &DeriveInput) -> String {
        let error = generate_generic_impl(input).unwrap_err();
        error.to_string()
    }

    #[test]
    fn generate_generic_impl_includes_table_and_columns() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                version: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = expand(&model);
        for needle in ["CREATE TABLE IF NOT EXISTS", "users", "deleted_at", "version"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn generate_generic_impl_rejects_tuple_struct() {
        let model: DeriveInput = parse_quote! {
            struct User(i32, String);
        };
        let message = expand_error(&model);
        assert!(message.contains("named fields"));
    }

    #[test]
    fn generate_generic_impl_rejects_non_struct() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
                B,
            }
        };
        let message = expand_error(&model);
        assert!(message.contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_version_update_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                version: i32,
                name: String,
            }
        };
        let rendered = expand(&model);
        assert!(rendered.contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_no_version_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = expand(&model);
        assert!(!rendered.contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_includes_default_hooks_and_validation() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = expand(&model);
        for needle in ["ModelHooks", "ModelValidation"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn generate_generic_impl_skips_custom_hooks_and_validation() {
        let model: DeriveInput = parse_quote! {
            #[premix(custom_hooks, custom_validation)]
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = expand(&model);
        for needle in [
            "impl premix_orm :: ModelHooks",
            "impl premix_orm :: ModelValidation",
        ] {
            assert!(!rendered.contains(needle));
        }
    }

    #[test]
    fn is_ignored_detects_attribute() {
        let candidate: Field = parse_quote! {
            #[premix(ignore)]
            ignored: Option<String>
        };
        assert!(is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_other_attrs() {
        let candidate: Field = parse_quote! {
            #[serde(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_premix_other_arg() {
        let candidate: Field = parse_quote! {
            #[premix(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_when_premix_has_no_args() {
        let candidate: Field = parse_quote! {
            #[premix]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn derive_model_impl_emits_tokens() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = derive_model_impl(&model).unwrap().to_string();
        assert!(rendered.contains("impl"));
    }

    #[test]
    fn derive_model_impl_propagates_error() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
            }
        };
        let message = derive_model_impl(&model).unwrap_err().to_string();
        assert!(message.contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_includes_soft_delete_delete_impl() {
        let model: DeriveInput = parse_quote! {
            struct AuditLog {
                id: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = expand(&model);
        for needle in ["deleted_at =", "has_soft_delete"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn generate_generic_impl_ignores_marked_fields() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                #[premix(ignore)]
                temp: Option<String>,
            }
        };
        let rendered = expand(&model);
        // The ignored field is still initialised in `from_row`...
        assert!(rendered.contains("temp : None"));
        // ...but never appears as a column name literal.
        assert!(!rendered.contains("\"temp\""));
    }

    #[test]
    fn generate_generic_impl_adds_relation_bounds() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                #[has_many(Post)]
                posts: Vec<Post>,
            }
        };
        let rendered = expand(&model);
        assert!(rendered.contains("Post : premix_orm :: Model < DB >"));
    }

    #[test]
    fn generate_generic_impl_records_field_names() {
        let model: DeriveInput = parse_quote! {
            struct Account {
                id: i32,
                user_id: i32,
                is_active: bool,
            }
        };
        let rendered = expand(&model);
        for needle in ["\"user_id\"", "\"is_active\""] {
            assert!(rendered.contains(needle));
        }
    }
}
238
239fn generate_generic_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
240 let struct_name = &input.ident;
241 let table_name = struct_name.to_string().to_lowercase() + "s";
242 let custom_hooks = has_premix_flag(&input.attrs, "custom_hooks");
243 let custom_validation = has_premix_flag(&input.attrs, "custom_validation");
244
245 let all_fields = if let Data::Struct(data) = &input.data {
246 if let Fields::Named(fields) = &data.fields {
247 &fields.named
248 } else {
249 return Err(syn::Error::new_spanned(
250 &data.fields,
251 "Premix Model only supports structs with named fields",
252 ));
253 }
254 } else {
255 return Err(syn::Error::new_spanned(
256 input,
257 "Premix Model only supports structs",
258 ));
259 };
260
261 let mut db_fields = Vec::new();
262 let mut ignored_field_idents = Vec::new();
263
264 for field in all_fields {
265 if is_ignored(field) {
266 ignored_field_idents.push(field.ident.as_ref().unwrap());
267 } else {
268 db_fields.push(field);
269 }
270 }
271
272 let field_idents: Vec<_> = db_fields
273 .iter()
274 .map(|f| f.ident.as_ref().unwrap())
275 .collect();
276 let field_types: Vec<_> = db_fields.iter().map(|f| &f.ty).collect();
277 let field_indices: Vec<_> = (0..db_fields.len()).collect();
278 let field_names: Vec<_> = field_idents.iter().map(|id| id.to_string()).collect();
279 let field_idents_len = field_idents.len();
280
281 let eager_load_body = relations::generate_eager_load_body(input)?;
282 let has_version = field_names.contains(&"version".to_string());
283 let has_soft_delete = field_names.contains(&"deleted_at".to_string());
284
285 let update_impl = if has_version {
286 quote! {
287 async fn update<'a, E>(&mut self, executor: E) -> Result<premix_orm::UpdateResult, premix_orm::sqlx::Error>
288 where
289 E: premix_orm::IntoExecutor<'a, DB = DB>
290 {
291 let mut executor = executor.into_executor();
292 let table_name = Self::table_name();
293 let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_orm::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
294 let id_p = <DB as premix_orm::SqlDialect>::placeholder(1 + #field_idents_len);
295 let ver_p = <DB as premix_orm::SqlDialect>::placeholder(2 + #field_idents_len);
296 let sql = format!(
297 "UPDATE {} SET {}, version = version + 1 WHERE id = {} AND version = {}",
298 table_name, set_clause, id_p, ver_p
299 );
300
301 let mut query = premix_orm::sqlx::query::<DB>(&sql)
302 #( .bind(&self.#field_idents) )*
303 .bind(&self.id)
304 .bind(&self.version);
305
306 let result = executor.execute(query).await?;
307
308 if <DB as premix_orm::SqlDialect>::rows_affected(&result) == 0 {
309 let exists_p = <DB as premix_orm::SqlDialect>::placeholder(1);
310 let exists_sql = format!("SELECT id FROM {} WHERE id = {}", table_name, exists_p);
311 let exists_query = premix_orm::sqlx::query_as::<DB, (i32,)>(&exists_sql).bind(&self.id);
312 let exists = executor.fetch_optional(exists_query).await?;
313
314 if exists.is_none() {
315 Ok(premix_orm::UpdateResult::NotFound)
316 } else {
317 Ok(premix_orm::UpdateResult::VersionConflict)
318 }
319 } else {
320 self.version += 1;
321 Ok(premix_orm::UpdateResult::Success)
322 }
323 }
324 }
325 } else {
326 quote! {
327 async fn update<'a, E>(&mut self, executor: E) -> Result<premix_orm::UpdateResult, premix_orm::sqlx::Error>
328 where
329 E: premix_orm::IntoExecutor<'a, DB = DB>
330 {
331 let mut executor = executor.into_executor();
332 let table_name = Self::table_name();
333 let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_orm::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
334 let id_p = <DB as premix_orm::SqlDialect>::placeholder(1 + #field_idents_len);
335 let sql = format!("UPDATE {} SET {} WHERE id = {}", table_name, set_clause, id_p);
336
337 let mut query = premix_orm::sqlx::query::<DB>(&sql)
338 #( .bind(&self.#field_idents) )*
339 .bind(&self.id);
340
341 let result = executor.execute(query).await?;
342
343 if <DB as premix_orm::SqlDialect>::rows_affected(&result) == 0 {
344 Ok(premix_orm::UpdateResult::NotFound)
345 } else {
346 Ok(premix_orm::UpdateResult::Success)
347 }
348 }
349 }
350 };
351
352 let delete_impl = if has_soft_delete {
353 quote! {
354 async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
355 where
356 E: premix_orm::IntoExecutor<'a, DB = DB>
357 {
358 let mut executor = executor.into_executor();
359 let table_name = Self::table_name();
360 let id_p = <DB as premix_orm::SqlDialect>::placeholder(1);
361 let sql = format!("UPDATE {} SET deleted_at = {} WHERE id = {}", table_name, <DB as premix_orm::SqlDialect>::current_timestamp_fn(), id_p);
362
363 let query = premix_orm::sqlx::query::<DB>(&sql).bind(&self.id);
364 executor.execute(query).await?;
365
366 self.deleted_at = Some("DELETED".to_string());
367 Ok(())
368 }
369 fn has_soft_delete() -> bool { true }
370 }
371 } else {
372 quote! {
373 async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
374 where
375 E: premix_orm::IntoExecutor<'a, DB = DB>
376 {
377 let mut executor = executor.into_executor();
378 let table_name = Self::table_name();
379 let id_p = <DB as premix_orm::SqlDialect>::placeholder(1);
380 let sql = format!("DELETE FROM {} WHERE id = {}", table_name, id_p);
381
382 let query = premix_orm::sqlx::query::<DB>(&sql).bind(&self.id);
383 executor.execute(query).await?;
384
385 Ok(())
386 }
387 fn has_soft_delete() -> bool { false }
388 }
389 };
390
391 let mut related_model_bounds = Vec::new();
392 for field in all_fields {
393 for attr in &field.attrs {
394 if (attr.path().is_ident("has_many") || attr.path().is_ident("belongs_to"))
395 && let Ok(related_ident) = attr.parse_args::<syn::Ident>()
396 {
397 related_model_bounds.push(quote! { #related_ident: premix_orm::Model<DB> });
398 }
399 }
400 }
401
402 let hooks_impl = if custom_hooks {
403 quote! {}
404 } else {
405 quote! {
406 #[premix_orm::async_trait::async_trait]
407 impl premix_orm::ModelHooks for #struct_name {}
408 }
409 };
410
411 let validation_impl = if custom_validation {
412 quote! {}
413 } else {
414 quote! {
415 impl premix_orm::ModelValidation for #struct_name {}
416 }
417 };
418
419 Ok(quote! {
421 impl<'r, R> premix_orm::sqlx::FromRow<'r, R> for #struct_name
422 where
423 R: premix_orm::sqlx::Row,
424 R::Database: premix_orm::sqlx::Database,
425 #(
426 #field_types: premix_orm::sqlx::Type<R::Database> + premix_orm::sqlx::Decode<'r, R::Database>,
427 )*
428 for<'c> &'c str: premix_orm::sqlx::ColumnIndex<R>,
429 {
430 fn from_row(row: &'r R) -> Result<Self, premix_orm::sqlx::Error> {
431 use premix_orm::sqlx::Row;
432 Ok(Self {
433 #(
434 #field_idents: row.try_get(#field_names)?,
435 )*
436 #(
437 #ignored_field_idents: None,
438 )*
439 })
440 }
441 }
442
443 #[premix_orm::async_trait::async_trait]
444 impl<DB> premix_orm::Model<DB> for #struct_name
445 where
446 DB: premix_orm::SqlDialect,
447 for<'c> &'c str: premix_orm::sqlx::ColumnIndex<DB::Row>,
448 usize: premix_orm::sqlx::ColumnIndex<DB::Row>,
449 for<'q> <DB as premix_orm::sqlx::Database>::Arguments<'q>: premix_orm::sqlx::IntoArguments<'q, DB>,
450 for<'c> &'c mut <DB as premix_orm::sqlx::Database>::Connection: premix_orm::sqlx::Executor<'c, Database = DB>,
451 i32: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
452 i64: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
453 String: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
454 bool: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
455 Option<String>: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
456 #( #related_model_bounds, )*
457 {
458 fn table_name() -> &'static str {
459 #table_name
460 }
461
462 fn create_table_sql() -> String {
463 let mut cols = vec!["id ".to_string() + <DB as premix_orm::SqlDialect>::auto_increment_pk()];
464 #(
465 if #field_names != "id" {
466 let field_name: &str = #field_names;
467 let sql_type = if field_name.ends_with("_id") {
468 <DB as premix_orm::SqlDialect>::int_type()
469 } else {
470 match field_name {
471 "name" | "title" | "status" | "email" | "role" => <DB as premix_orm::SqlDialect>::text_type(),
472 "age" | "version" | "price" | "balance" => <DB as premix_orm::SqlDialect>::int_type(),
473 "is_active" => <DB as premix_orm::SqlDialect>::bool_type(),
474 "deleted_at" => <DB as premix_orm::SqlDialect>::text_type(),
475 _ => <DB as premix_orm::SqlDialect>::text_type(),
476 }
477 };
478 cols.push(format!("{} {}", #field_names, sql_type));
479 }
480 )*
481 format!("CREATE TABLE IF NOT EXISTS {} ({})", #table_name, cols.join(", "))
482 }
483
484 fn list_columns() -> Vec<String> {
485 vec![ #( #field_names.to_string() ),* ]
486 }
487
488 async fn save<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
489 where
490 E: premix_orm::IntoExecutor<'a, DB = DB>
491 {
492 let mut executor = executor.into_executor();
493 use premix_orm::ModelHooks;
494 self.before_save().await?;
495
496 let columns: Vec<&str> = vec![ #( #field_names ),* ]
498 .into_iter()
499 .filter(|&c| {
500 if c == "id" { return self.id != 0; }
501 true
502 })
503 .collect();
504
505 let placeholders = (1..=columns.len())
506 .map(|i| <DB as premix_orm::SqlDialect>::placeholder(i))
507 .collect::<Vec<_>>()
508 .join(", ");
509
510 let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, columns.join(", "), placeholders);
511
512 let mut query = premix_orm::sqlx::query::<DB>(&sql);
513
514 #(
516 if #field_names != "id" {
517 query = query.bind(&self.#field_idents);
518 } else {
519 if self.id != 0 {
520 query = query.bind(&self.id);
521 }
522 }
523 )*
524
525 let result = executor.execute(query).await?;
526
527 let last_id = <DB as premix_orm::SqlDialect>::last_insert_id(&result);
529 if last_id > 0 {
530 self.id = last_id as i32;
531 }
532
533 self.after_save().await?;
534 Ok(())
535 }
536
537 #update_impl
538 #delete_impl
539
540 async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, premix_orm::sqlx::Error>
541 where
542 E: premix_orm::IntoExecutor<'a, DB = DB>
543 {
544 let mut executor = executor.into_executor();
545 let p = <DB as premix_orm::SqlDialect>::placeholder(1);
546 let mut where_clause = format!("WHERE id = {}", p);
547 if Self::has_soft_delete() {
548 where_clause.push_str(" AND deleted_at IS NULL");
549 }
550 let sql = format!("SELECT * FROM {} {} LIMIT 1", #table_name, where_clause);
551 let query = premix_orm::sqlx::query_as::<DB, Self>(&sql).bind(id);
552
553 executor.fetch_optional(query).await
554 }
555
556 async fn eager_load<'a, E>(models: &mut [Self], relation: &str, executor: E) -> Result<(), premix_orm::sqlx::Error>
557 where
558 E: premix_orm::IntoExecutor<'a, DB = DB>
559 {
560 let mut executor = executor.into_executor();
561 #eager_load_body
562 }
563 }
564
565 #hooks_impl
566 #validation_impl
567 })
568}
569
570fn is_ignored(field: &Field) -> bool {
571 for attr in &field.attrs {
572 if attr.path().is_ident("premix")
573 && let Ok(meta) = attr.parse_args::<syn::Ident>()
574 && meta == "ignore"
575 {
576 return true;
577 }
578 }
579 false
580}
581
582fn has_premix_flag(attrs: &[Attribute], flag: &str) -> bool {
583 for attr in attrs {
584 if attr.path().is_ident("premix") {
585 let args = attr.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated);
586 if let Ok(args) = args {
587 if args.iter().any(|ident| ident == flag) {
588 return true;
589 }
590 }
591 }
592 }
593 false
594}