Skip to main content

ferriorm_codegen/
relations.rs

1//! Code generation for relation support: Include, `WithRelations`, batched loading.
2
3use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8/// Information about a relation from one model to another.
9pub struct RelationInfo<'a> {
10    pub field: &'a Field,
11    pub related_model: &'a Model,
12    pub relation_type: RelationType,
13    /// The FK column on the "many" side (e.g., "`author_id`" on Post for User.posts)
14    pub fk_column: String,
15    /// The referenced column (e.g., "id" on User)
16    pub ref_column: String,
17}
18
19/// Collect relation fields for a model, resolving against the full schema.
20#[must_use]
21pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
22    let mut relations = Vec::new();
23
24    for field in &model.fields {
25        if let Some(rel) = &field.relation {
26            let related = schema.models.iter().find(|m| m.name == rel.related_model);
27            if let Some(related_model) = related {
28                let (fk_column, ref_column) = if rel.fields.is_empty() {
29                    // The other side has the FK (OneToMany) — find the back-reference
30                    find_back_reference(model, related_model)
31                        .unwrap_or_else(|| ("id".into(), "id".into()))
32                } else {
33                    // This side has the FK (ManyToOne)
34                    (rel.fields[0].clone(), rel.references[0].clone())
35                };
36
37                relations.push(RelationInfo {
38                    field,
39                    related_model,
40                    relation_type: rel.relation_type,
41                    fk_column: to_snake_case(&fk_column),
42                    ref_column: to_snake_case(&ref_column),
43                });
44            }
45        } else if field.is_list {
46            // Implicit relation (e.g., posts Post[])
47            if let FieldKind::Model(related_name) = &field.field_type {
48                let related = schema.models.iter().find(|m| m.name == *related_name);
49                if let Some(related_model) = related {
50                    let (fk_column, ref_column) = find_back_reference(model, related_model)
51                        .unwrap_or_else(|| ("id".into(), "id".into()));
52
53                    relations.push(RelationInfo {
54                        field,
55                        related_model,
56                        relation_type: RelationType::OneToMany,
57                        fk_column: to_snake_case(&fk_column),
58                        ref_column: to_snake_case(&ref_column),
59                    });
60                }
61            }
62        }
63    }
64
65    relations
66}
67
68/// Find the back-reference from the related model to this model.
69/// E.g., for User.posts (Post[]), find Post.authorId @relation(fields: [authorId], references: [id])
70fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
71    for field in &child.fields {
72        if let Some(rel) = &field.relation
73            && rel.related_model == parent.name
74            && !rel.fields.is_empty()
75        {
76            return Some((rel.fields[0].clone(), rel.references[0].clone()));
77        }
78    }
79    None
80}
81
82/// Generate the Include struct, `WithRelations` struct, and include-aware query methods.
83#[must_use]
84pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
85    let relations = collect_relations(model, schema);
86
87    if relations.is_empty() {
88        return quote! {};
89    }
90
91    let model_ident = format_ident!("{}", model.name);
92    let include_name = format_ident!("{}Include", model.name);
93    let with_relations_name = format_ident!("{}WithRelations", model.name);
94
95    // Include struct fields
96    let include_fields: Vec<TokenStream> = relations
97        .iter()
98        .map(|r| {
99            let name = format_ident!("{}", to_snake_case(&r.field.name));
100            quote! { pub #name: bool }
101        })
102        .collect();
103
104    // WithRelations struct fields
105    let with_rel_fields: Vec<TokenStream> = relations
106        .iter()
107        .map(|r| {
108            let name = format_ident!("{}", to_snake_case(&r.field.name));
109            let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
110            let related_struct = format_ident!("{}", r.related_model.name);
111
112            match r.relation_type {
113                RelationType::OneToMany | RelationType::ManyToMany => {
114                    quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
115                }
116                RelationType::OneToOne | RelationType::ManyToOne => {
117                    quote! { pub #name: Option<super::#related_mod::#related_struct> }
118                }
119            }
120        })
121        .collect();
122
123    // Generate the batched loading logic
124    let load_arms = gen_load_arms(&relations, model);
125
126    quote! {
127        #[derive(Debug, Clone, Default)]
128        pub struct #include_name {
129            #(#include_fields,)*
130        }
131
132        #[derive(Debug, Clone, Serialize, Deserialize)]
133        pub struct #with_relations_name {
134            #[serde(flatten)]
135            pub data: #model_ident,
136            #(#with_rel_fields,)*
137        }
138
139        impl #model_ident {
140            /// Load relations for a batch of records.
141            pub(crate) async fn load_relations(
142                records: Vec<#model_ident>,
143                include: &#include_name,
144                client: &DatabaseClient,
145            ) -> Result<Vec<#with_relations_name>, FerriormError> {
146                #load_arms
147            }
148        }
149    }
150}
151
152/// Helper: generate code to load related rows using `QueryBuilder` and dispatch
153/// to both Postgres and Sqlite.
154fn gen_batched_load_many(
155    rel: &RelationInfo<'_>,
156    load_var: &proc_macro2::Ident,
157    field_name: &proc_macro2::Ident,
158    id_source_ident: &proc_macro2::Ident,
159    lookup_col_str: &str,
160    insert_key_ident: &proc_macro2::Ident,
161    fk_optional: bool,
162) -> TokenStream {
163    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
164    let related_struct = format_ident!("{}", rel.related_model.name);
165    let related_table = &rel.related_model.db_name;
166
167    let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
168
169    // Generate the row insertion code based on whether the FK is optional
170    let insert_row_code = if fk_optional {
171        quote! {
172            if let Some(key) = row.#insert_key_ident.clone() {
173                #load_var.entry(key).or_default().push(row);
174            }
175        }
176    } else {
177        quote! {
178            #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
179        }
180    };
181
182    quote! {
183        let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
184        if include.#field_name {
185            let ids: Vec<String> = records.iter()
186                .map(|r| r.#id_source_ident.clone())
187                .collect();
188
189            if !ids.is_empty() {
190                macro_rules! build_in_query {
191                    ($db:ty) => {{
192                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
193                        let mut sep = qb.separated(", ");
194                        for id in &ids {
195                            sep.push_bind(id.clone());
196                        }
197                        qb.push(")");
198                        qb
199                    }};
200                }
201
202                macro_rules! insert_rows {
203                    ($rows:expr) => {
204                        for row in $rows {
205                            #insert_row_code
206                        }
207                    };
208                }
209
210                match client {
211                    DatabaseClient::Postgres(pool) => {
212                        let mut qb = build_in_query!(sqlx::Postgres);
213                        let related_rows: Vec<super::#related_mod::#related_struct> =
214                            qb.build_query_as().fetch_all(pool).await
215                                .map_err(FerriormError::from)?;
216                        insert_rows!(related_rows);
217                    }
218                    DatabaseClient::Sqlite(pool) => {
219                        let mut qb = build_in_query!(sqlx::Sqlite);
220                        let related_rows: Vec<super::#related_mod::#related_struct> =
221                            qb.build_query_as().fetch_all(pool).await
222                                .map_err(FerriormError::from)?;
223                        insert_rows!(related_rows);
224                    }
225                }
226            }
227        }
228    }
229}
230
231/// Helper: generate code to load related rows for a single-value (`OneToOne` / `ManyToOne`) relation.
232fn gen_batched_load_one(
233    rel: &RelationInfo<'_>,
234    load_var: &proc_macro2::Ident,
235    field_name: &proc_macro2::Ident,
236    id_source_ident: &proc_macro2::Ident,
237    lookup_col_str: &str,
238    insert_key_ident: &proc_macro2::Ident,
239    fk_is_optional: bool,
240) -> TokenStream {
241    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
242    let related_struct = format_ident!("{}", rel.related_model.name);
243    let related_table = &rel.related_model.db_name;
244
245    let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
246
247    // When the FK field is optional (Option<String>), use filter_map to skip None values.
248    let ids_collect = if fk_is_optional {
249        quote! {
250            let ids: Vec<String> = records.iter()
251                .filter_map(|r| r.#id_source_ident.clone())
252                .collect();
253        }
254    } else {
255        quote! {
256            let ids: Vec<String> = records.iter()
257                .map(|r| r.#id_source_ident.clone())
258                .collect();
259        }
260    };
261
262    quote! {
263        let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
264        if include.#field_name {
265            #ids_collect
266
267            if !ids.is_empty() {
268                macro_rules! build_in_query {
269                    ($db:ty) => {{
270                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
271                        let mut sep = qb.separated(", ");
272                        for id in &ids {
273                            sep.push_bind(id.clone());
274                        }
275                        qb.push(")");
276                        qb
277                    }};
278                }
279
280                match client {
281                    DatabaseClient::Postgres(pool) => {
282                        let mut qb = build_in_query!(sqlx::Postgres);
283                        let related_rows: Vec<super::#related_mod::#related_struct> =
284                            qb.build_query_as().fetch_all(pool).await
285                                .map_err(FerriormError::from)?;
286                        for row in related_rows {
287                            #load_var.insert(row.#insert_key_ident.clone(), row);
288                        }
289                    }
290                    DatabaseClient::Sqlite(pool) => {
291                        let mut qb = build_in_query!(sqlx::Sqlite);
292                        let related_rows: Vec<super::#related_mod::#related_struct> =
293                            qb.build_query_as().fetch_all(pool).await
294                                .map_err(FerriormError::from)?;
295                        for row in related_rows {
296                            #load_var.insert(row.#insert_key_ident.clone(), row);
297                        }
298                    }
299                }
300            }
301        }
302    }
303}
304
305#[allow(clippy::too_many_lines)]
306fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
307    let _model_ident = format_ident!("{}", model.name);
308    let with_relations_name = format_ident!("{}WithRelations", model.name);
309
310    let mut relation_loads = vec![];
311    let mut field_inits = vec![];
312
313    for rel in relations {
314        let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
315        let fk_col_str = &rel.fk_column;
316        let ref_col_str = &rel.ref_column;
317        let fk_col_ident = format_ident!("{}", rel.fk_column);
318        let ref_col_ident = format_ident!("{}", rel.ref_column);
319
320        match rel.relation_type {
321            RelationType::OneToMany | RelationType::ManyToMany => {
322                // Batched loading: SELECT * FROM related WHERE fk IN (parent_ids)
323                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
324
325                // Check if the FK column on the child (related) model is optional
326                let child_fk_optional = rel
327                    .related_model
328                    .fields
329                    .iter()
330                    .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
331
332                relation_loads.push(gen_batched_load_many(
333                    rel,
334                    &load_var,
335                    &field_name,
336                    &ref_col_ident,
337                    fk_col_str,
338                    &fk_col_ident,
339                    child_fk_optional,
340                ));
341
342                let ref_col_ident = format_ident!("{}", ref_col_str);
343                field_inits.push(quote! {
344                    #field_name: if include.#field_name {
345                        Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
346                    } else {
347                        None
348                    }
349                });
350            }
351            RelationType::OneToOne | RelationType::ManyToOne => {
352                // For ManyToOne (this model has the FK), we can batch load the parent
353                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
354                let fk_field = format_ident!("{}", fk_col_str);
355
356                // Check if this model has the FK field as a scalar
357                let fk_model_field = model
358                    .fields
359                    .iter()
360                    .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
361                let has_fk = fk_model_field.is_some();
362                let fk_is_optional = fk_model_field.is_some_and(|f| f.is_optional);
363
364                if has_fk {
365                    relation_loads.push(gen_batched_load_one(
366                        rel,
367                        &load_var,
368                        &field_name,
369                        &fk_field,
370                        ref_col_str,
371                        &ref_col_ident,
372                        fk_is_optional,
373                    ));
374
375                    if fk_is_optional {
376                        field_inits.push(quote! {
377                            #field_name: if include.#field_name {
378                                r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
379                            } else {
380                                None
381                            }
382                        });
383                    } else {
384                        field_inits.push(quote! {
385                            #field_name: if include.#field_name {
386                                #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
387                            } else {
388                                None
389                            }
390                        });
391                    }
392                } else {
393                    // The FK is on the other side (e.g., User.profile where Profile has userId)
394                    // Batch load: SELECT * FROM profiles WHERE user_id IN (user_ids)
395                    let ref_col_ident = format_ident!("{}", ref_col_str);
396
397                    relation_loads.push(gen_batched_load_one(
398                        rel,
399                        &load_var,
400                        &field_name,
401                        &ref_col_ident,
402                        fk_col_str,
403                        &fk_col_ident,
404                        false,
405                    ));
406
407                    field_inits.push(quote! {
408                        #field_name: if include.#field_name {
409                            #load_var.remove(&r.#ref_col_ident)
410                        } else {
411                            None
412                        }
413                    });
414                }
415            }
416        }
417    }
418
419    quote! {
420        #(#relation_loads)*
421
422        let mut results = Vec::with_capacity(records.len());
423        for r in records {
424            results.push(#with_relations_name {
425                #(#field_inits,)*
426                data: r,
427            });
428        }
429        Ok(results)
430    }
431}
432
433/// Generate `include()` and `exec_with_relations()` methods for `FindMany`.
434#[must_use]
435pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
436    let relations = collect_relations(model, schema);
437    if relations.is_empty() {
438        return quote! {};
439    }
440
441    let model_ident = format_ident!("{}", model.name);
442    let include_name = format_ident!("{}Include", model.name);
443    let with_relations_name = format_ident!("{}WithRelations", model.name);
444
445    quote! {
446        impl<'a> FindManyQuery<'a> {
447            pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
448                FindManyWithIncludeQuery {
449                    inner: self,
450                    include,
451                }
452            }
453        }
454
455        pub struct FindManyWithIncludeQuery<'a> {
456            inner: FindManyQuery<'a>,
457            include: #include_name,
458        }
459
460        impl<'a> FindManyWithIncludeQuery<'a> {
461            pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
462                let include = self.include;
463                let client = self.inner.client;
464                let records = FindManyQuery {
465                    client,
466                    r#where: self.inner.r#where,
467                    order_by: self.inner.order_by,
468                    skip: self.inner.skip,
469                    take: self.inner.take,
470                }.exec().await?;
471                #model_ident::load_relations(records, &include, client).await
472            }
473        }
474
475        impl<'a> FindUniqueQuery<'a> {
476            pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
477                FindUniqueWithIncludeQuery {
478                    inner: self,
479                    include,
480                }
481            }
482        }
483
484        pub struct FindUniqueWithIncludeQuery<'a> {
485            inner: FindUniqueQuery<'a>,
486            include: #include_name,
487        }
488
489        impl<'a> FindUniqueWithIncludeQuery<'a> {
490            pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
491                let include = self.include;
492                let client = self.inner.client;
493                let record = FindUniqueQuery {
494                    client,
495                    r#where: self.inner.r#where,
496                }.exec().await?;
497                match record {
498                    Some(r) => {
499                        let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
500                        Ok(results.pop())
501                    }
502                    None => Ok(None),
503                }
504            }
505        }
506    }
507}