Skip to main content

ferriorm_codegen/
relations.rs

1//! Code generation for relation support: Include, WithRelations, batched loading.
2
3use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8/// Information about a relation from one model to another.
9pub struct RelationInfo<'a> {
10    pub field: &'a Field,
11    pub related_model: &'a Model,
12    pub relation_type: RelationType,
13    /// The FK column on the "many" side (e.g., "author_id" on Post for User.posts)
14    pub fk_column: String,
15    /// The referenced column (e.g., "id" on User)
16    pub ref_column: String,
17}
18
19/// Collect relation fields for a model, resolving against the full schema.
20pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
21    let mut relations = Vec::new();
22
23    for field in &model.fields {
24        if let Some(rel) = &field.relation {
25            let related = schema.models.iter().find(|m| m.name == rel.related_model);
26            if let Some(related_model) = related {
27                let (fk_column, ref_column) = if !rel.fields.is_empty() {
28                    // This side has the FK (ManyToOne)
29                    (rel.fields[0].clone(), rel.references[0].clone())
30                } else {
31                    // The other side has the FK (OneToMany) — find the back-reference
32                    find_back_reference(model, related_model)
33                        .unwrap_or_else(|| ("id".into(), "id".into()))
34                };
35
36                relations.push(RelationInfo {
37                    field,
38                    related_model,
39                    relation_type: rel.relation_type,
40                    fk_column: to_snake_case(&fk_column),
41                    ref_column: to_snake_case(&ref_column),
42                });
43            }
44        } else if field.is_list {
45            // Implicit relation (e.g., posts Post[])
46            if let FieldKind::Model(related_name) = &field.field_type {
47                let related = schema.models.iter().find(|m| m.name == *related_name);
48                if let Some(related_model) = related {
49                    let (fk_column, ref_column) = find_back_reference(model, related_model)
50                        .unwrap_or_else(|| ("id".into(), "id".into()));
51
52                    relations.push(RelationInfo {
53                        field,
54                        related_model,
55                        relation_type: RelationType::OneToMany,
56                        fk_column: to_snake_case(&fk_column),
57                        ref_column: to_snake_case(&ref_column),
58                    });
59                }
60            }
61        }
62    }
63
64    relations
65}
66
67/// Find the back-reference from the related model to this model.
68/// E.g., for User.posts (Post[]), find Post.authorId @relation(fields: [authorId], references: [id])
69fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
70    for field in &child.fields {
71        if let Some(rel) = &field.relation
72            && rel.related_model == parent.name
73            && !rel.fields.is_empty()
74        {
75            return Some((rel.fields[0].clone(), rel.references[0].clone()));
76        }
77    }
78    None
79}
80
81/// Generate the Include struct, WithRelations struct, and include-aware query methods.
82pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
83    let relations = collect_relations(model, schema);
84
85    if relations.is_empty() {
86        return quote! {};
87    }
88
89    let model_ident = format_ident!("{}", model.name);
90    let include_name = format_ident!("{}Include", model.name);
91    let with_relations_name = format_ident!("{}WithRelations", model.name);
92
93    // Include struct fields
94    let include_fields: Vec<TokenStream> = relations
95        .iter()
96        .map(|r| {
97            let name = format_ident!("{}", to_snake_case(&r.field.name));
98            quote! { pub #name: bool }
99        })
100        .collect();
101
102    // WithRelations struct fields
103    let with_rel_fields: Vec<TokenStream> = relations
104        .iter()
105        .map(|r| {
106            let name = format_ident!("{}", to_snake_case(&r.field.name));
107            let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
108            let related_struct = format_ident!("{}", r.related_model.name);
109
110            match r.relation_type {
111                RelationType::OneToMany | RelationType::ManyToMany => {
112                    quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
113                }
114                RelationType::OneToOne | RelationType::ManyToOne => {
115                    quote! { pub #name: Option<super::#related_mod::#related_struct> }
116                }
117            }
118        })
119        .collect();
120
121    // Generate the batched loading logic
122    let load_arms = gen_load_arms(&relations, model);
123
124    quote! {
125        #[derive(Debug, Clone, Default)]
126        pub struct #include_name {
127            #(#include_fields,)*
128        }
129
130        #[derive(Debug, Clone, Serialize, Deserialize)]
131        pub struct #with_relations_name {
132            #[serde(flatten)]
133            pub data: #model_ident,
134            #(#with_rel_fields,)*
135        }
136
137        impl #model_ident {
138            /// Load relations for a batch of records.
139            pub(crate) async fn load_relations(
140                records: Vec<#model_ident>,
141                include: &#include_name,
142                client: &DatabaseClient,
143            ) -> Result<Vec<#with_relations_name>, FerriormError> {
144                #load_arms
145            }
146        }
147    }
148}
149
150/// Helper: generate code to load related rows using QueryBuilder and dispatch
151/// to both Postgres and Sqlite.
152fn gen_batched_load_many(
153    rel: &RelationInfo<'_>,
154    load_var: &proc_macro2::Ident,
155    field_name: &proc_macro2::Ident,
156    id_source_ident: &proc_macro2::Ident,
157    lookup_col_str: &str,
158    insert_key_ident: &proc_macro2::Ident,
159) -> TokenStream {
160    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
161    let related_struct = format_ident!("{}", rel.related_model.name);
162    let related_table = &rel.related_model.db_name;
163
164    let select_base = format!(
165        r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
166        related_table, lookup_col_str
167    );
168
169    quote! {
170        let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
171        if include.#field_name {
172            let ids: Vec<String> = records.iter()
173                .map(|r| r.#id_source_ident.clone())
174                .collect();
175
176            if !ids.is_empty() {
177                macro_rules! build_in_query {
178                    ($db:ty) => {{
179                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
180                        let mut sep = qb.separated(", ");
181                        for id in &ids {
182                            sep.push_bind(id.clone());
183                        }
184                        qb.push(")");
185                        qb
186                    }};
187                }
188
189                match client {
190                    DatabaseClient::Postgres(pool) => {
191                        let mut qb = build_in_query!(sqlx::Postgres);
192                        let related_rows: Vec<super::#related_mod::#related_struct> =
193                            qb.build_query_as().fetch_all(pool).await
194                                .map_err(FerriormError::from)?;
195                        for row in related_rows {
196                            #load_var.entry(row.#insert_key_ident.clone())
197                                .or_default()
198                                .push(row);
199                        }
200                    }
201                    DatabaseClient::Sqlite(pool) => {
202                        let mut qb = build_in_query!(sqlx::Sqlite);
203                        let related_rows: Vec<super::#related_mod::#related_struct> =
204                            qb.build_query_as().fetch_all(pool).await
205                                .map_err(FerriormError::from)?;
206                        for row in related_rows {
207                            #load_var.entry(row.#insert_key_ident.clone())
208                                .or_default()
209                                .push(row);
210                        }
211                    }
212                }
213            }
214        }
215    }
216}
217
218/// Helper: generate code to load related rows for a single-value (OneToOne / ManyToOne) relation.
219fn gen_batched_load_one(
220    rel: &RelationInfo<'_>,
221    load_var: &proc_macro2::Ident,
222    field_name: &proc_macro2::Ident,
223    id_source_ident: &proc_macro2::Ident,
224    lookup_col_str: &str,
225    insert_key_ident: &proc_macro2::Ident,
226) -> TokenStream {
227    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
228    let related_struct = format_ident!("{}", rel.related_model.name);
229    let related_table = &rel.related_model.db_name;
230
231    let select_base = format!(
232        r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
233        related_table, lookup_col_str
234    );
235
236    quote! {
237        let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
238        if include.#field_name {
239            let ids: Vec<String> = records.iter()
240                .map(|r| r.#id_source_ident.clone())
241                .collect();
242
243            if !ids.is_empty() {
244                macro_rules! build_in_query {
245                    ($db:ty) => {{
246                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
247                        let mut sep = qb.separated(", ");
248                        for id in &ids {
249                            sep.push_bind(id.clone());
250                        }
251                        qb.push(")");
252                        qb
253                    }};
254                }
255
256                match client {
257                    DatabaseClient::Postgres(pool) => {
258                        let mut qb = build_in_query!(sqlx::Postgres);
259                        let related_rows: Vec<super::#related_mod::#related_struct> =
260                            qb.build_query_as().fetch_all(pool).await
261                                .map_err(FerriormError::from)?;
262                        for row in related_rows {
263                            #load_var.insert(row.#insert_key_ident.clone(), row);
264                        }
265                    }
266                    DatabaseClient::Sqlite(pool) => {
267                        let mut qb = build_in_query!(sqlx::Sqlite);
268                        let related_rows: Vec<super::#related_mod::#related_struct> =
269                            qb.build_query_as().fetch_all(pool).await
270                                .map_err(FerriormError::from)?;
271                        for row in related_rows {
272                            #load_var.insert(row.#insert_key_ident.clone(), row);
273                        }
274                    }
275                }
276            }
277        }
278    }
279}
280
281fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
282    let _model_ident = format_ident!("{}", model.name);
283    let with_relations_name = format_ident!("{}WithRelations", model.name);
284
285    let mut relation_loads = vec![];
286    let mut field_inits = vec![];
287
288    for rel in relations {
289        let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
290        let fk_col_str = &rel.fk_column;
291        let ref_col_str = &rel.ref_column;
292        let fk_col_ident = format_ident!("{}", rel.fk_column);
293        let ref_col_ident = format_ident!("{}", rel.ref_column);
294
295        match rel.relation_type {
296            RelationType::OneToMany | RelationType::ManyToMany => {
297                // Batched loading: SELECT * FROM related WHERE fk IN (parent_ids)
298                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
299
300                relation_loads.push(gen_batched_load_many(
301                    rel,
302                    &load_var,
303                    &field_name,
304                    &ref_col_ident,
305                    fk_col_str,
306                    &fk_col_ident,
307                ));
308
309                let ref_col_ident = format_ident!("{}", ref_col_str);
310                field_inits.push(quote! {
311                    #field_name: if include.#field_name {
312                        Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
313                    } else {
314                        None
315                    }
316                });
317            }
318            RelationType::OneToOne | RelationType::ManyToOne => {
319                // For ManyToOne (this model has the FK), we can batch load the parent
320                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
321                let fk_field = format_ident!("{}", fk_col_str);
322
323                // Check if this model has the FK field as a scalar
324                let has_fk = model
325                    .fields
326                    .iter()
327                    .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
328
329                if has_fk {
330                    relation_loads.push(gen_batched_load_one(
331                        rel,
332                        &load_var,
333                        &field_name,
334                        &fk_field,
335                        ref_col_str,
336                        &ref_col_ident,
337                    ));
338
339                    field_inits.push(quote! {
340                        #field_name: if include.#field_name {
341                            #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
342                        } else {
343                            None
344                        }
345                    });
346                } else {
347                    // The FK is on the other side (e.g., User.profile where Profile has userId)
348                    // Batch load: SELECT * FROM profiles WHERE user_id IN (user_ids)
349                    let ref_col_ident = format_ident!("{}", ref_col_str);
350
351                    relation_loads.push(gen_batched_load_one(
352                        rel,
353                        &load_var,
354                        &field_name,
355                        &ref_col_ident,
356                        fk_col_str,
357                        &fk_col_ident,
358                    ));
359
360                    field_inits.push(quote! {
361                        #field_name: if include.#field_name {
362                            #load_var.remove(&r.#ref_col_ident)
363                        } else {
364                            None
365                        }
366                    });
367                }
368            }
369        }
370    }
371
372    quote! {
373        #(#relation_loads)*
374
375        let mut results = Vec::with_capacity(records.len());
376        for mut r in records {
377            results.push(#with_relations_name {
378                #(#field_inits,)*
379                data: r,
380            });
381        }
382        Ok(results)
383    }
384}
385
386/// Generate `include()` and `exec_with_relations()` methods for FindMany.
387pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
388    let relations = collect_relations(model, schema);
389    if relations.is_empty() {
390        return quote! {};
391    }
392
393    let model_ident = format_ident!("{}", model.name);
394    let include_name = format_ident!("{}Include", model.name);
395    let with_relations_name = format_ident!("{}WithRelations", model.name);
396
397    quote! {
398        impl<'a> FindManyQuery<'a> {
399            pub fn include(mut self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
400                FindManyWithIncludeQuery {
401                    inner: self,
402                    include,
403                }
404            }
405        }
406
407        pub struct FindManyWithIncludeQuery<'a> {
408            inner: FindManyQuery<'a>,
409            include: #include_name,
410        }
411
412        impl<'a> FindManyWithIncludeQuery<'a> {
413            pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
414                let include = self.include;
415                let client = self.inner.client;
416                let records = FindManyQuery {
417                    client,
418                    r#where: self.inner.r#where,
419                    order_by: self.inner.order_by,
420                    skip: self.inner.skip,
421                    take: self.inner.take,
422                }.exec().await?;
423                #model_ident::load_relations(records, &include, client).await
424            }
425        }
426
427        impl<'a> FindUniqueQuery<'a> {
428            pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
429                FindUniqueWithIncludeQuery {
430                    inner: self,
431                    include,
432                }
433            }
434        }
435
436        pub struct FindUniqueWithIncludeQuery<'a> {
437            inner: FindUniqueQuery<'a>,
438            include: #include_name,
439        }
440
441        impl<'a> FindUniqueWithIncludeQuery<'a> {
442            pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
443                let include = self.include;
444                let client = self.inner.client;
445                let record = FindUniqueQuery {
446                    client,
447                    r#where: self.inner.r#where,
448                }.exec().await?;
449                match record {
450                    Some(r) => {
451                        let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
452                        Ok(results.pop())
453                    }
454                    None => Ok(None),
455                }
456            }
457        }
458    }
459}