1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, LitStr, parse_macro_input};
6
7#[proc_macro_derive(Model, attributes(model))]
8pub fn derive_model(input: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(input as DeriveInput);
10 let name = &input.ident;
11
12 let mut table_name = name.to_string().to_lowercase() + "s"; let mut pk_fields = Vec::new();
14 let mut generated_fields = Vec::new();
15 let mut columns = Vec::new();
16 let mut has_many_rels = Vec::new(); for attr in &input.attrs {
20 if attr.path().is_ident("model") {
21 let _ = attr.parse_nested_meta(|meta| {
22 if meta.path.is_ident("table_name") {
23 let value = meta.value()?;
24 let s: LitStr = value.parse()?;
25 table_name = s.value();
26 }
27 if meta.path.is_ident("has_many") {
28 let mut target_ident: Option<syn::Ident> = None;
29 let mut fk_name = String::new();
30
31 let _ = meta.parse_nested_meta(|inner| {
32 if inner.path.is_ident("fk") {
33 let value = inner.value()?;
34 let s: LitStr = value.parse()?;
35 fk_name = s.value();
36 } else if target_ident.is_none() {
37 target_ident = inner.path.get_ident().cloned();
38 }
39 Ok(())
40 });
41
42 if let Some(ident) = target_ident {
43 has_many_rels.push((ident, fk_name));
44 }
45 }
46 Ok(())
47 });
48 }
49 }
50
51 let mut field_types = Vec::new();
52 let mut non_pk_fields = Vec::new();
53 let mut non_pk_types = Vec::new();
54 let mut belongs_to_fks = Vec::new(); let fields_list = if let Data::Struct(data_struct) = &input.data {
57 if let Fields::Named(syn_fields) = &data_struct.fields {
58 let mut extracted = Vec::new();
59 for f in &syn_fields.named {
60 let field_name = match &f.ident {
61 Some(ident) => ident,
62 None => {
63 return syn::Error::new_spanned(f, "All fields must have names")
64 .to_compile_error()
65 .into();
66 }
67 };
68 let field_name_str = field_name.to_string();
69 columns.push(field_name_str.clone());
70 field_types.push(f.ty.clone());
71
72 let mut is_pk = false;
73 let mut is_gen = false;
74 for attr in &f.attrs {
76 if attr.path().is_ident("model") {
77 let _ = attr.parse_nested_meta(|meta| {
78 if meta.path.is_ident("primary_key") {
79 is_pk = true;
80 }
81 if meta.path.is_ident("generated") {
82 is_gen = true;
83 }
84 if meta.path.is_ident("belongs_to") {
85 let _ = meta.parse_nested_meta(|inner| {
86 if let Some(ident) = inner.path.get_ident() {
87 belongs_to_fks.push((field_name.clone(), ident.clone()));
88 }
89 Ok(())
90 });
91 }
92 Ok(())
93 });
94 }
95 }
96
97 if is_pk {
98 pk_fields.push(field_name.clone());
99 let ty = &f.ty;
100 let ty_str = quote::quote!(#ty).to_string().replace(" ", "");
101 if ty_str == "i32" || ty_str == "i64" {
102 is_gen = true;
103 }
104 } else {
105 non_pk_fields.push(field_name.clone());
106 non_pk_types.push(f.ty.clone());
107 }
108
109 if is_gen {
110 generated_fields.push(field_name.clone());
111 }
112
113 extracted.push(field_name.clone());
114 }
115 extracted
116 } else {
117 return syn::Error::new_spanned(
118 input,
119 "Model can only be derived for structs with named fields",
120 )
121 .to_compile_error()
122 .into();
123 }
124 } else {
125 return syn::Error::new_spanned(
126 input,
127 "Model can only be derived for structs with named fields",
128 )
129 .to_compile_error()
130 .into();
131 };
132
133 if pk_fields.is_empty() {
134 if columns.contains(&"id".to_string()) {
135 pk_fields.push(syn::Ident::new("id", proc_macro2::Span::call_site()));
136 generated_fields.push(syn::Ident::new("id", proc_macro2::Span::call_site()));
137 } else {
138 return syn::Error::new_spanned(name, "Model requires at least one primary key field (e.g., #[model(primary_key)] id) or a field named 'id'").to_compile_error().into();
139 }
140 }
141
142 let field_names_str: Vec<String> = columns.clone();
143 let pk_names_str: Vec<String> = pk_fields.iter().map(|i| i.to_string()).collect();
144 let gen_names_str: Vec<String> = generated_fields.iter().map(|i| i.to_string()).collect();
145
146 let column_enum_name =
147 syn::Ident::new(&format!("{}Column", name), proc_macro2::Span::call_site());
148 let _active_model_name = syn::Ident::new(
149 &format!("{}ActiveModel", name),
150 proc_macro2::Span::call_site(),
151 );
152
153 let gen_field_names = generated_fields.clone();
154 let gen_fields_len = generated_fields.len();
155
156 let mut column_defs = Vec::new();
157 let mut col_names = Vec::new();
158 let mut col_types = Vec::new();
159 for (i, field_name) in columns.iter().enumerate() {
160 let ty = &field_types[i];
161 let is_pk = pk_names_str.contains(field_name);
162 let is_gen = gen_names_str.contains(field_name);
163
164 let mut not_null = true;
165 let mut inner_ty = ty;
166 if let syn::Type::Path(type_path) = ty
167 && let Some(segment) = type_path.path.segments.last()
168 && segment.ident == "Option"
169 {
170 not_null = false;
171 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
172 && let Some(syn::GenericArgument::Type(t)) = args.args.first()
173 {
174 inner_ty = t;
175 }
176 }
177
178 let type_str = quote::quote!(#inner_ty).to_string().replace(" ", "");
179
180 let mut sql_type = match type_str.as_str() {
181 "i32" if is_gen && is_pk && pk_fields.len() == 1 => "SERIAL PRIMARY KEY".to_string(),
182 "i32" if is_gen => "SERIAL".to_string(),
183 "i32" => "INT".to_string(),
184 "i64" if is_gen && is_pk && pk_fields.len() == 1 => "BIGSERIAL PRIMARY KEY".to_string(),
185 "i64" if is_gen => "BIGSERIAL".to_string(),
186 "i64" => "BIGINT".to_string(),
187 "String" => "TEXT".to_string(),
188 "bool" => "BOOLEAN".to_string(),
189 "f64" => "DOUBLE PRECISION".to_string(),
190 _ => "TEXT".to_string(),
191 };
192
193 if is_pk && pk_fields.len() == 1 && !sql_type.contains("PRIMARY KEY") {
194 sql_type.push_str(" PRIMARY KEY");
195 }
196
197 if not_null && !sql_type.contains("PRIMARY KEY") && !sql_type.contains("SERIAL") {
198 sql_type.push_str(" NOT NULL");
199 }
200
201 col_names.push(field_name.clone());
202 col_types.push(sql_type.clone());
203 column_defs.push(format!("{} {}", field_name, sql_type));
204 }
205
206 if pk_fields.len() > 1 {
207 let pk_csv = pk_names_str.join(", ");
208 column_defs.push(format!("PRIMARY KEY ({})", pk_csv));
209 }
210
211 let base_sql = format!(
212 "CREATE TABLE IF NOT EXISTS {} (\n {}\n)",
213 table_name,
214 column_defs.join(",\n ")
215 );
216
217 let fk_fields: Vec<_> = belongs_to_fks.iter().map(|(f, _)| f.clone()).collect();
218 let fk_models: Vec<_> = belongs_to_fks.iter().map(|(_, m)| m.clone()).collect();
219
220 let hm_targets: Vec<_> = has_many_rels.iter().map(|(m, _)| m.clone()).collect();
221 let hm_fks: Vec<_> = has_many_rels.iter().map(|(_, fk)| fk.clone()).collect();
222 let fetch_hm_names: Vec<_> = hm_targets
223 .iter()
224 .map(|m| {
225 syn::Ident::new(
226 &format!("fetch_{}s", m.to_string().to_lowercase()),
227 proc_macro2::Span::call_site(),
228 )
229 })
230 .collect();
231 let fetch_bt_names: Vec<_> = fk_fields
232 .iter()
233 .map(|f| {
234 let fname = f.to_string();
235 let base = fname.strip_suffix("_id").unwrap_or(&fname);
236 syn::Ident::new(&format!("fetch_{}", base), proc_macro2::Span::call_site())
237 })
238 .collect();
239 let first_pk = pk_fields[0].clone();
240 let field_names_join = field_names_str.join(", ");
241 let fields_indices: Vec<usize> = (0..columns.len()).collect();
242
243 let expanded = quote! {
244 impl chopin_orm::Model for #name {
245 fn table_name() -> &'static str {
246 #table_name
247 }
248
249 fn create_table_stmt() -> String {
250 let mut sql = String::from(#base_sql);
251 #(
252 sql.pop(); sql.pop(); let fk_constraint = format!(",\n FOREIGN KEY ({}) REFERENCES {} (id)\n)", stringify!(#fk_fields), <#fk_models as chopin_orm::Model>::table_name());
255 sql.push_str(&fk_constraint);
256 )*
257 sql
258 }
259
260 fn column_definitions() -> Vec<(&'static str, &'static str)> {
261 vec![
262 #( (#col_names, #col_types) ),*
263 ]
264 }
265
266 fn primary_key_columns() -> &'static [&'static str] {
267 &[#(#pk_names_str),*]
268 }
269
270 fn generated_columns() -> &'static [&'static str] {
271 &[#(#gen_names_str),*]
272 }
273
274 fn columns() -> &'static [&'static str] {
275 &[#(#field_names_str),*]
276 }
277
278 fn select_clause() -> &'static str {
279 const COLS: &[&str] = &[#(#field_names_str),*];
280 const JOINED: &str = #field_names_join;
281 JOINED
282 }
283
284 fn primary_key_values(&self) -> Vec<chopin_pg::PgValue> {
285 use chopin_pg::types::ToSql;
286 vec![
287 #(self.#pk_fields.to_sql()),*
288 ]
289 }
290
291 fn get_values(&self) -> Vec<chopin_pg::PgValue> {
292 use chopin_pg::types::ToSql;
293 vec![
294 #(self.#fields_list.to_sql()),*
295 ]
296 }
297
298 fn set_generated_values(&mut self, mut values: Vec<chopin_pg::PgValue>) -> chopin_orm::OrmResult<()> {
299 if values.len() != #gen_fields_len {
300 return Err(chopin_orm::OrmError::ModelError("Generated values length mismatch".to_string()));
301 }
302 let mut iter = values.into_iter();
303 #(
304 if let Some(val) = iter.next() {
305 self.#gen_field_names = chopin_orm::ExtractValue::from_pg_value(val)?;
306 }
307 )*
308 Ok(())
309 }
310 }
311
312 impl chopin_orm::FromRow for #name {
313 fn from_row(row: &chopin_pg::Row) -> chopin_orm::OrmResult<Self> {
314 Ok(Self {
315 #(
316 #fields_list: chopin_orm::ExtractValue::extract_at(row, #fields_indices)?,
317 )*
318 })
319 }
320 }
321
322 impl #name {
323 #(
324 pub fn #fetch_bt_names(&self, executor: &mut impl chopin_orm::Executor) -> chopin_orm::OrmResult<Option<#fk_models>> {
325 use chopin_pg::types::ToSql;
326 use chopin_pg::types::ToParam;
327 let qb = #fk_models::find().filter((
328 format!("{} = $1", <#fk_models as chopin_orm::Model>::primary_key_columns()[0]),
329 vec![self.#fk_fields.to_param()]
330 ));
331 qb.one(executor)
332 }
333 )*
334
335 #(
336 pub fn #fetch_hm_names(&self, executor: &mut impl chopin_orm::Executor) -> chopin_orm::OrmResult<Vec<#hm_targets>> {
337 use chopin_pg::types::ToSql;
338 use chopin_pg::types::ToParam;
339 let target_pk: chopin_pg::PgValue = self.#first_pk.clone().to_param();
340 let qb = #hm_targets::find().filter((
341 format!("{} = $1", #hm_fks),
342 vec![target_pk]
343 ));
344 qb.all(executor)
345 }
346 )*
347 }
348
349 #[allow(non_camel_case_types)]
350 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
351 pub enum #column_enum_name {
352 #(#fields_list),*
353 }
354
355 impl chopin_orm::builder::ColumnTrait<#name> for #column_enum_name {
356 fn column_name(&self) -> &'static str {
357 match self {
358 #(Self::#fields_list => #field_names_str),*
359 }
360 }
361 }
362 };
363
364 let active_expanded = quote! {};
365
366 let mut belongs_to_field_names = Vec::new();
367 let mut belongs_to_related_models = Vec::new();
368 for (f, r) in &belongs_to_fks {
369 belongs_to_field_names.push(f.clone());
370 belongs_to_related_models.push(r.clone());
371 }
372
373 let final_expanded = quote! {
374 #expanded
375 #active_expanded
376
377 #(
378 impl chopin_orm::HasForeignKey<#belongs_to_related_models> for #name {
379 fn foreign_key_info() -> (&'static str, Vec<(&'static str, &'static str)>) {
380 (<Self as chopin_orm::Model>::table_name(), vec![(stringify!(#belongs_to_field_names), "id")])
381 }
382 }
383 )*
384 };
385
386 TokenStream::from(final_expanded)
387}