1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, Data, Fields};
4
5#[proc_macro_derive(Model, attributes(validate))]
6pub fn derive_model(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = input.ident;
9
10 let table_name = name.to_string().to_lowercase() + "s";
12
13 let fields = match input.data {
15 Data::Struct(ref data) => {
16 match data.fields {
17 Fields::Named(ref fields) => {
18 fields.named.iter().collect::<Vec<_>>()
19 },
20 _ => Vec::new(),
21 }
22 },
23 _ => Vec::new(),
24 };
25
26 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
27 let field_names_str: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
28
29 let has_created_at = field_names_str.contains(&"created_at".to_string());
31 let has_updated_at = field_names_str.contains(&"updated_at".to_string());
32
33 let has_deleted_at = field_names_str.contains(&"deleted_at".to_string());
35
36 let non_id_fields: Vec<_> = fields.iter()
39 .filter(|f| {
40 let name = f.ident.as_ref().unwrap().to_string();
41 name != "id" && name != "created_at" && name != "updated_at" && name != "deleted_at"
42 })
43 .collect();
44
45 let non_id_names: Vec<_> = non_id_fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
46 let non_id_names_str: Vec<_> = non_id_names.iter().map(|f| f.to_string()).collect();
47
48 let mut create_cols_list = non_id_names_str.clone();
50 if has_created_at { create_cols_list.push("created_at".to_string()); }
51 if has_updated_at { create_cols_list.push("updated_at".to_string()); }
52 let create_cols = create_cols_list.join(", ");
55 let create_placeholders: Vec<_> = (1..=create_cols_list.len()).map(|i| format!("${}", i)).collect();
56 let create_placeholders_str = create_placeholders.join(", ");
57 let create_query = format!("INSERT INTO {} ({}) VALUES ({})", table_name, create_cols, create_placeholders_str);
58
59 let mut update_sets_list = Vec::new();
61 for (i, name) in non_id_names_str.iter().enumerate() {
62 update_sets_list.push(format!("{} = ${}", name, i + 1));
63 }
64
65 let mut param_count = non_id_names_str.len();
66 if has_updated_at {
67 param_count += 1;
68 update_sets_list.push(format!("updated_at = ${}", param_count));
69 }
70
71 let update_sets_str = update_sets_list.join(", ");
72 let update_where = format!("WHERE id = ${}", param_count + 1);
73 let update_query = format!("UPDATE {} SET {} {}", table_name, update_sets_str, update_where);
74
75 let hard_delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
77
78 let delete_impl = if has_deleted_at {
79 let soft_delete_query = format!("UPDATE {} SET deleted_at = $1 WHERE id = $2", table_name);
80 quote! {
81 async fn delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
82 let now = oxidite_db::chrono::Utc::now().timestamp();
83 let query = oxidite_db::sqlx::query(#soft_delete_query)
84 .bind(now)
85 .bind(&self.id);
86 db.execute_query(query).await?;
87 Ok(())
88 }
89 }
90 } else {
91 quote! {
92 async fn delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
93 let query = oxidite_db::sqlx::query(#hard_delete_query)
94 .bind(&self.id);
95 db.execute_query(query).await?;
96 Ok(())
97 }
98 }
99 };
100
101 let created_at_logic = if has_created_at {
103 quote! {
104 let now = oxidite_db::chrono::Utc::now().timestamp();
105 self.created_at = now;
106 let query = query.bind(now);
107 }
108 } else {
109 quote! {}
110 };
111
112 let updated_at_create_logic = if has_updated_at {
113 quote! {
114 let now = oxidite_db::chrono::Utc::now().timestamp();
115 self.updated_at = now;
116 let query = query.bind(now);
117 }
118 } else {
119 quote! {}
120 };
121
122 let updated_at_update_logic = if has_updated_at {
123 quote! {
124 let now = oxidite_db::chrono::Utc::now().timestamp();
125 self.updated_at = now;
126 let query = query.bind(now);
127 }
128 } else {
129 quote! {}
130 };
131
132 let mut validation_checks = Vec::new();
134 for field in &fields {
135 let field_name = field.ident.as_ref().unwrap();
136 for attr in &field.attrs {
137 if attr.path().is_ident("validate") {
138 let attr_str = attr.to_token_stream().to_string();
139 if attr_str.contains("email") {
140 validation_checks.push(quote! {
141 {
142 static EMAIL_REGEX: oxidite_db::once_cell::sync::Lazy<oxidite_db::regex::Regex> =
143 oxidite_db::once_cell::sync::Lazy::new(|| oxidite_db::regex::Regex::new(r"^[^@\s]+@[^@\s]+\.[^@\s]+$").unwrap());
144 if !EMAIL_REGEX.is_match(&self.#field_name) {
145 return Err(format!("Invalid email format for field {}", stringify!(#field_name)));
146 }
147 }
148 });
149 }
150 }
151 }
152 }
153
154 let expanded = quote! {
155 #[oxidite_db::async_trait]
156 impl oxidite_db::Model for #name {
157 fn table_name() -> &'static str {
158 #table_name
159 }
160
161 fn fields() -> &'static [&'static str] {
162 &[#(#field_names_str),*]
163 }
164
165 fn has_soft_delete() -> bool {
166 #has_deleted_at
167 }
168
169 async fn create(&mut self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
170 let query = oxidite_db::sqlx::query(#create_query);
171 #(
172 let query = query.bind(&self.#non_id_names);
173 )*
174 #created_at_logic
175 #updated_at_create_logic
176
177 db.execute_query(query).await?;
178 Ok(())
179 }
180
181 async fn update(&mut self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
182 let query = oxidite_db::sqlx::query(#update_query);
183 #(
184 let query = query.bind(&self.#non_id_names);
185 )*
186 #updated_at_update_logic
187
188 let query = query.bind(&self.id);
189 db.execute_query(query).await?;
190 Ok(())
191 }
192
193 #delete_impl
194
195 async fn force_delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
196 let query = oxidite_db::sqlx::query(#hard_delete_query)
197 .bind(&self.id);
198 db.execute_query(query).await?;
199 Ok(())
200 }
201
202 fn validate(&self) -> std::result::Result<(), String> {
203 #(#validation_checks)*
204 Ok(())
205 }
206 }
207 };
208
209 TokenStream::from(expanded)
210}