1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Field, Fields, parse_macro_input};
4
5mod relations;
6
7#[proc_macro_derive(Model, attributes(has_many, belongs_to, premix))]
8pub fn derive_model(input: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(input as DeriveInput);
10 match derive_model_impl(&input) {
11 Ok(tokens) => TokenStream::from(tokens),
12 Err(err) => TokenStream::from(err.to_compile_error()),
13 }
14}
15
16fn derive_model_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
17 let impl_block = generate_generic_impl(input)?;
18 let rel_block = relations::impl_relations(input)?;
19 Ok(quote! {
20 #impl_block
21 #rel_block
22 })
23}
24
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Runs `generate_generic_impl` and renders the successful expansion as a string.
    fn expand(input: &DeriveInput) -> String {
        generate_generic_impl(input)
            .expect("expansion should succeed")
            .to_string()
    }

    /// Runs `generate_generic_impl` and renders the expected failure as a string.
    fn expand_error(input: &DeriveInput) -> String {
        generate_generic_impl(input)
            .expect_err("expansion should fail")
            .to_string()
    }

    #[test]
    fn generate_generic_impl_includes_table_and_columns() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                version: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = expand(&model);
        for needle in ["CREATE TABLE IF NOT EXISTS", "users", "deleted_at", "version"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn generate_generic_impl_rejects_tuple_struct() {
        let model: DeriveInput = parse_quote! {
            struct User(i32, String);
        };
        assert!(expand_error(&model).contains("named fields"));
    }

    #[test]
    fn generate_generic_impl_rejects_non_struct() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
                B,
            }
        };
        assert!(expand_error(&model).contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_version_update_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                version: i32,
                name: String,
            }
        };
        assert!(expand(&model).contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_no_version_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        assert!(!expand(&model).contains("version = version + 1"));
    }

    #[test]
    fn is_ignored_detects_attribute() {
        let candidate: Field = parse_quote! {
            #[premix(ignore)]
            ignored: Option<String>
        };
        assert!(is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_other_attrs() {
        let candidate: Field = parse_quote! {
            #[serde(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_premix_other_arg() {
        let candidate: Field = parse_quote! {
            #[premix(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_when_premix_has_no_args() {
        let candidate: Field = parse_quote! {
            #[premix]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn derive_model_impl_emits_tokens() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = derive_model_impl(&model).unwrap().to_string();
        assert!(rendered.contains("impl"));
    }

    #[test]
    fn derive_model_impl_propagates_error() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
            }
        };
        let failure = derive_model_impl(&model).unwrap_err();
        assert!(failure.to_string().contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_includes_soft_delete_delete_impl() {
        let model: DeriveInput = parse_quote! {
            struct AuditLog {
                id: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = expand(&model);
        assert!(rendered.contains("deleted_at ="));
        assert!(rendered.contains("has_soft_delete"));
    }

    #[test]
    fn generate_generic_impl_ignores_marked_fields() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                #[premix(ignore)]
                temp: Option<String>,
            }
        };
        let rendered = expand(&model);
        // The ignored field is still initialised in `from_row`, but never used as a column name.
        assert!(rendered.contains("temp : None"));
        assert!(!rendered.contains("\"temp\""));
    }

    #[test]
    fn generate_generic_impl_adds_relation_bounds() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                #[has_many(Post)]
                posts: Vec<Post>,
            }
        };
        assert!(expand(&model).contains("Post : premix_core :: Model < DB >"));
    }

    #[test]
    fn generate_generic_impl_records_field_names() {
        let model: DeriveInput = parse_quote! {
            struct Account {
                id: i32,
                user_id: i32,
                is_active: bool,
            }
        };
        let rendered = expand(&model);
        assert!(rendered.contains("\"user_id\""));
        assert!(rendered.contains("\"is_active\""));
    }
}
208
209fn generate_generic_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
210 let struct_name = &input.ident;
211 let table_name = struct_name.to_string().to_lowercase() + "s";
212
213 let all_fields = if let Data::Struct(data) = &input.data {
214 if let Fields::Named(fields) = &data.fields {
215 &fields.named
216 } else {
217 return Err(syn::Error::new_spanned(
218 &data.fields,
219 "Premix Model only supports structs with named fields",
220 ));
221 }
222 } else {
223 return Err(syn::Error::new_spanned(
224 input,
225 "Premix Model only supports structs",
226 ));
227 };
228
229 let mut db_fields = Vec::new();
230 let mut ignored_field_idents = Vec::new();
231
232 for field in all_fields {
233 if is_ignored(field) {
234 ignored_field_idents.push(field.ident.as_ref().unwrap());
235 } else {
236 db_fields.push(field);
237 }
238 }
239
240 let field_idents: Vec<_> = db_fields
241 .iter()
242 .map(|f| f.ident.as_ref().unwrap())
243 .collect();
244 let field_types: Vec<_> = db_fields.iter().map(|f| &f.ty).collect();
245 let field_indices: Vec<_> = (0..db_fields.len()).collect();
246 let field_names: Vec<_> = field_idents.iter().map(|id| id.to_string()).collect();
247 let field_idents_len = field_idents.len();
248
249 let eager_load_body = relations::generate_eager_load_body(input)?;
250 let has_version = field_names.contains(&"version".to_string());
251 let has_soft_delete = field_names.contains(&"deleted_at".to_string());
252
253 let update_impl = if has_version {
254 quote! {
255 async fn update<'a, E>(&mut self, executor: E) -> Result<premix_core::UpdateResult, premix_core::sqlx::Error>
256 where
257 E: premix_core::IntoExecutor<'a, DB = DB>
258 {
259 let mut executor = executor.into_executor();
260 let table_name = Self::table_name();
261 let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_core::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
262 let id_p = <DB as premix_core::SqlDialect>::placeholder(1 + #field_idents_len);
263 let ver_p = <DB as premix_core::SqlDialect>::placeholder(2 + #field_idents_len);
264 let sql = format!(
265 "UPDATE {} SET {}, version = version + 1 WHERE id = {} AND version = {}",
266 table_name, set_clause, id_p, ver_p
267 );
268
269 let mut query = premix_core::sqlx::query::<DB>(&sql)
270 #( .bind(&self.#field_idents) )*
271 .bind(&self.id)
272 .bind(&self.version);
273
274 let result = executor.execute(query).await?;
275
276 if <DB as premix_core::SqlDialect>::rows_affected(&result) == 0 {
277 let exists_p = <DB as premix_core::SqlDialect>::placeholder(1);
278 let exists_sql = format!("SELECT id FROM {} WHERE id = {}", table_name, exists_p);
279 let exists_query = premix_core::sqlx::query_as::<DB, (i32,)>(&exists_sql).bind(&self.id);
280 let exists = executor.fetch_optional(exists_query).await?;
281
282 if exists.is_none() {
283 Ok(premix_core::UpdateResult::NotFound)
284 } else {
285 Ok(premix_core::UpdateResult::VersionConflict)
286 }
287 } else {
288 self.version += 1;
289 Ok(premix_core::UpdateResult::Success)
290 }
291 }
292 }
293 } else {
294 quote! {
295 async fn update<'a, E>(&mut self, executor: E) -> Result<premix_core::UpdateResult, premix_core::sqlx::Error>
296 where
297 E: premix_core::IntoExecutor<'a, DB = DB>
298 {
299 let mut executor = executor.into_executor();
300 let table_name = Self::table_name();
301 let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_core::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
302 let id_p = <DB as premix_core::SqlDialect>::placeholder(1 + #field_idents_len);
303 let sql = format!("UPDATE {} SET {} WHERE id = {}", table_name, set_clause, id_p);
304
305 let mut query = premix_core::sqlx::query::<DB>(&sql)
306 #( .bind(&self.#field_idents) )*
307 .bind(&self.id);
308
309 let result = executor.execute(query).await?;
310
311 if <DB as premix_core::SqlDialect>::rows_affected(&result) == 0 {
312 Ok(premix_core::UpdateResult::NotFound)
313 } else {
314 Ok(premix_core::UpdateResult::Success)
315 }
316 }
317 }
318 };
319
320 let delete_impl = if has_soft_delete {
321 quote! {
322 async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
323 where
324 E: premix_core::IntoExecutor<'a, DB = DB>
325 {
326 let mut executor = executor.into_executor();
327 let table_name = Self::table_name();
328 let id_p = <DB as premix_core::SqlDialect>::placeholder(1);
329 let sql = format!("UPDATE {} SET deleted_at = {} WHERE id = {}", table_name, <DB as premix_core::SqlDialect>::current_timestamp_fn(), id_p);
330
331 let query = premix_core::sqlx::query::<DB>(&sql).bind(&self.id);
332 executor.execute(query).await?;
333
334 self.deleted_at = Some("DELETED".to_string());
335 Ok(())
336 }
337 fn has_soft_delete() -> bool { true }
338 }
339 } else {
340 quote! {
341 async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
342 where
343 E: premix_core::IntoExecutor<'a, DB = DB>
344 {
345 let mut executor = executor.into_executor();
346 let table_name = Self::table_name();
347 let id_p = <DB as premix_core::SqlDialect>::placeholder(1);
348 let sql = format!("DELETE FROM {} WHERE id = {}", table_name, id_p);
349
350 let query = premix_core::sqlx::query::<DB>(&sql).bind(&self.id);
351 executor.execute(query).await?;
352
353 Ok(())
354 }
355 fn has_soft_delete() -> bool { false }
356 }
357 };
358
359 let mut related_model_bounds = Vec::new();
360 for field in all_fields {
361 for attr in &field.attrs {
362 if (attr.path().is_ident("has_many") || attr.path().is_ident("belongs_to"))
363 && let Ok(related_ident) = attr.parse_args::<syn::Ident>()
364 {
365 related_model_bounds.push(quote! { #related_ident: premix_core::Model<DB> });
366 }
367 }
368 }
369
370 Ok(quote! {
372 impl<'r, R> premix_core::sqlx::FromRow<'r, R> for #struct_name
373 where
374 R: premix_core::sqlx::Row,
375 R::Database: premix_core::sqlx::Database,
376 #(
377 #field_types: premix_core::sqlx::Type<R::Database> + premix_core::sqlx::Decode<'r, R::Database>,
378 )*
379 for<'c> &'c str: premix_core::sqlx::ColumnIndex<R>,
380 {
381 fn from_row(row: &'r R) -> Result<Self, premix_core::sqlx::Error> {
382 use premix_core::sqlx::Row;
383 Ok(Self {
384 #(
385 #field_idents: row.try_get(#field_names)?,
386 )*
387 #(
388 #ignored_field_idents: None,
389 )*
390 })
391 }
392 }
393
394 #[premix_core::async_trait::async_trait]
395 impl<DB> premix_core::Model<DB> for #struct_name
396 where
397 DB: premix_core::SqlDialect,
398 for<'c> &'c str: premix_core::sqlx::ColumnIndex<DB::Row>,
399 usize: premix_core::sqlx::ColumnIndex<DB::Row>,
400 for<'q> <DB as premix_core::sqlx::Database>::Arguments<'q>: premix_core::sqlx::IntoArguments<'q, DB>,
401 for<'c> &'c mut <DB as premix_core::sqlx::Database>::Connection: premix_core::sqlx::Executor<'c, Database = DB>,
402 i32: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
403 i64: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
404 String: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
405 bool: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
406 Option<String>: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
407 #( #related_model_bounds, )*
408 {
409 fn table_name() -> &'static str {
410 #table_name
411 }
412
413 fn create_table_sql() -> String {
414 let mut cols = vec!["id ".to_string() + <DB as premix_core::SqlDialect>::auto_increment_pk()];
415 #(
416 if #field_names != "id" {
417 let field_name: &str = #field_names;
418 let sql_type = if field_name.ends_with("_id") {
419 <DB as premix_core::SqlDialect>::int_type()
420 } else {
421 match field_name {
422 "name" | "title" | "status" | "email" | "role" => <DB as premix_core::SqlDialect>::text_type(),
423 "age" | "version" | "price" | "balance" => <DB as premix_core::SqlDialect>::int_type(),
424 "is_active" => <DB as premix_core::SqlDialect>::bool_type(),
425 "deleted_at" => <DB as premix_core::SqlDialect>::text_type(),
426 _ => <DB as premix_core::SqlDialect>::text_type(),
427 }
428 };
429 cols.push(format!("{} {}", #field_names, sql_type));
430 }
431 )*
432 format!("CREATE TABLE IF NOT EXISTS {} ({})", #table_name, cols.join(", "))
433 }
434
435 fn list_columns() -> Vec<String> {
436 vec![ #( #field_names.to_string() ),* ]
437 }
438
439 async fn save<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
440 where
441 E: premix_core::IntoExecutor<'a, DB = DB>
442 {
443 let mut executor = executor.into_executor();
444 use premix_core::ModelHooks;
445 self.before_save().await?;
446
447 let columns: Vec<&str> = vec![ #( #field_names ),* ]
449 .into_iter()
450 .filter(|&c| {
451 if c == "id" { return self.id != 0; }
452 true
453 })
454 .collect();
455
456 let placeholders = (1..=columns.len())
457 .map(|i| <DB as premix_core::SqlDialect>::placeholder(i))
458 .collect::<Vec<_>>()
459 .join(", ");
460
461 let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, columns.join(", "), placeholders);
462
463 let mut query = premix_core::sqlx::query::<DB>(&sql);
464
465 #(
467 if #field_names != "id" {
468 query = query.bind(&self.#field_idents);
469 } else {
470 if self.id != 0 {
471 query = query.bind(&self.id);
472 }
473 }
474 )*
475
476 let result = executor.execute(query).await?;
477
478 let last_id = <DB as premix_core::SqlDialect>::last_insert_id(&result);
480 if last_id > 0 {
481 self.id = last_id as i32;
482 }
483
484 self.after_save().await?;
485 Ok(())
486 }
487
488 #update_impl
489 #delete_impl
490
491 async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, premix_core::sqlx::Error>
492 where
493 E: premix_core::IntoExecutor<'a, DB = DB>
494 {
495 let mut executor = executor.into_executor();
496 let p = <DB as premix_core::SqlDialect>::placeholder(1);
497 let mut where_clause = format!("WHERE id = {}", p);
498 if Self::has_soft_delete() {
499 where_clause.push_str(" AND deleted_at IS NULL");
500 }
501 let sql = format!("SELECT * FROM {} {} LIMIT 1", #table_name, where_clause);
502 let query = premix_core::sqlx::query_as::<DB, Self>(&sql).bind(id);
503
504 executor.fetch_optional(query).await
505 }
506
507 async fn eager_load<'a, E>(models: &mut [Self], relation: &str, executor: E) -> Result<(), premix_core::sqlx::Error>
508 where
509 E: premix_core::IntoExecutor<'a, DB = DB>
510 {
511 let mut executor = executor.into_executor();
512 #eager_load_body
513 }
514 }
515 })
516}
517
518fn is_ignored(field: &Field) -> bool {
519 for attr in &field.attrs {
520 if attr.path().is_ident("premix")
521 && let Ok(meta) = attr.parse_args::<syn::Ident>()
522 && meta == "ignore"
523 {
524 return true;
525 }
526 }
527 false
528}