Skip to main content

ferriorm_codegen/
relations.rs

1//! Code generation for relation support: Include, `WithRelations`, batched loading.
2
3use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
/// Information about a relation from one model to another, resolved against the
/// schema into the pieces the generated batched loaders need.
///
/// Built by `collect_relations`, which stores both column names in snake_case.
pub struct RelationInfo<'a> {
    /// The relation field as declared on the owning model (e.g. `User.posts`).
    pub field: &'a Field,
    /// The model on the other end of the relation.
    pub related_model: &'a Model,
    /// Kind of relation as seen from the owning model; implicit list fields
    /// (no `@relation`) are recorded as `OneToMany`.
    pub relation_type: RelationType,
    /// The FK column on the "many" side (e.g., "`author_id`" on Post for User.posts),
    /// i.e. the first entry of `@relation(fields: [...])` on whichever model declares it.
    /// `collect_relations` falls back to "id" when no such declaration can be found.
    pub fk_column: String,
    /// The referenced column (e.g., "id" on User), i.e. the first entry of
    /// `@relation(references: [...])`; same "id" fallback as `fk_column`.
    pub ref_column: String,
}
18
19/// Collect relation fields for a model, resolving against the full schema.
20#[must_use]
21pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
22    let mut relations = Vec::new();
23
24    for field in &model.fields {
25        if let Some(rel) = &field.relation {
26            let related = schema.models.iter().find(|m| m.name == rel.related_model);
27            if let Some(related_model) = related {
28                let (fk_column, ref_column) = if rel.fields.is_empty() {
29                    // The other side has the FK (OneToMany) — find the back-reference,
30                    // filtering by relation name when one is set so multi-relations
31                    // pair correctly.
32                    find_back_reference(model, related_model, rel.name.as_deref())
33                        .unwrap_or_else(|| ("id".into(), "id".into()))
34                } else {
35                    // This side has the FK (ManyToOne)
36                    (rel.fields[0].clone(), rel.references[0].clone())
37                };
38
39                relations.push(RelationInfo {
40                    field,
41                    related_model,
42                    relation_type: rel.relation_type,
43                    fk_column: to_snake_case(&fk_column),
44                    ref_column: to_snake_case(&ref_column),
45                });
46            }
47        } else if field.is_list {
48            // Implicit relation (e.g., posts Post[]) — only valid when there's a
49            // single relation between this model and the target. Multi-relations
50            // are rejected by the validator unless every side has a name, so
51            // there's at most one matching back-reference here.
52            if let FieldKind::Model(related_name) = &field.field_type {
53                let related = schema.models.iter().find(|m| m.name == *related_name);
54                if let Some(related_model) = related {
55                    let (fk_column, ref_column) = find_back_reference(model, related_model, None)
56                        .unwrap_or_else(|| ("id".into(), "id".into()));
57
58                    relations.push(RelationInfo {
59                        field,
60                        related_model,
61                        relation_type: RelationType::OneToMany,
62                        fk_column: to_snake_case(&fk_column),
63                        ref_column: to_snake_case(&ref_column),
64                    });
65                }
66            }
67        }
68    }
69
70    relations
71}
72
73/// Find the back-reference from the related model to this model.
74/// E.g., for User.posts (Post[]), find Post.authorId @relation(fields: [authorId], references: [id]).
75///
76/// When `name` is `Some`, only matches relations with the same name, so
77/// multi-relations between the same two models pair correctly.
78fn find_back_reference(
79    parent: &Model,
80    child: &Model,
81    name: Option<&str>,
82) -> Option<(String, String)> {
83    for field in &child.fields {
84        if let Some(rel) = &field.relation
85            && rel.related_model == parent.name
86            && !rel.fields.is_empty()
87            && (name.is_none() || rel.name.as_deref() == name)
88        {
89            return Some((rel.fields[0].clone(), rel.references[0].clone()));
90        }
91    }
92    None
93}
94
95/// Generate the Include struct, `WithRelations` struct, and include-aware query methods.
96#[must_use]
97pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
98    let relations = collect_relations(model, schema);
99
100    if relations.is_empty() {
101        return quote! {};
102    }
103
104    let model_ident = format_ident!("{}", model.name);
105    let include_name = format_ident!("{}Include", model.name);
106    let with_relations_name = format_ident!("{}WithRelations", model.name);
107
108    // Include struct fields
109    let include_fields: Vec<TokenStream> = relations
110        .iter()
111        .map(|r| {
112            let name = format_ident!("{}", to_snake_case(&r.field.name));
113            quote! { pub #name: bool }
114        })
115        .collect();
116
117    // WithRelations struct fields
118    let with_rel_fields: Vec<TokenStream> = relations
119        .iter()
120        .map(|r| {
121            let name = format_ident!("{}", to_snake_case(&r.field.name));
122            let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
123            let related_struct = format_ident!("{}", r.related_model.name);
124
125            match r.relation_type {
126                RelationType::OneToMany | RelationType::ManyToMany => {
127                    quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
128                }
129                RelationType::OneToOne | RelationType::ManyToOne => {
130                    quote! { pub #name: Option<super::#related_mod::#related_struct> }
131                }
132            }
133        })
134        .collect();
135
136    // Generate the batched loading logic
137    let load_arms = gen_load_arms(&relations, model);
138
139    quote! {
140        #[derive(Debug, Clone, Default)]
141        pub struct #include_name {
142            #(#include_fields,)*
143        }
144
145        #[derive(Debug, Clone, Serialize, Deserialize)]
146        pub struct #with_relations_name {
147            #[serde(flatten)]
148            pub data: #model_ident,
149            #(#with_rel_fields,)*
150        }
151
152        impl #model_ident {
153            /// Load relations for a batch of records.
154            pub(crate) async fn load_relations(
155                records: Vec<#model_ident>,
156                include: &#include_name,
157                client: &DatabaseClient,
158            ) -> Result<Vec<#with_relations_name>, FerriormError> {
159                #load_arms
160            }
161        }
162    }
163}
164
165/// Helper: generate code to load related rows using `QueryBuilder` and dispatch
166/// to both Postgres and Sqlite.
167fn gen_batched_load_many(
168    rel: &RelationInfo<'_>,
169    load_var: &proc_macro2::Ident,
170    field_name: &proc_macro2::Ident,
171    id_source_ident: &proc_macro2::Ident,
172    lookup_col_str: &str,
173    insert_key_ident: &proc_macro2::Ident,
174    fk_optional: bool,
175) -> TokenStream {
176    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
177    let related_struct = format_ident!("{}", rel.related_model.name);
178    let related_table = &rel.related_model.db_name;
179
180    let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
181
182    // Generate the row insertion code based on whether the FK is optional
183    let insert_row_code = if fk_optional {
184        quote! {
185            if let Some(key) = row.#insert_key_ident.clone() {
186                #load_var.entry(key).or_default().push(row);
187            }
188        }
189    } else {
190        quote! {
191            #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
192        }
193    };
194
195    quote! {
196        let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
197        if include.#field_name {
198            let ids: Vec<String> = records.iter()
199                .map(|r| r.#id_source_ident.clone())
200                .collect();
201
202            if !ids.is_empty() {
203                macro_rules! build_in_query {
204                    ($db:ty) => {{
205                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
206                        let mut sep = qb.separated(", ");
207                        for id in &ids {
208                            sep.push_bind(id.clone());
209                        }
210                        qb.push(")");
211                        qb
212                    }};
213                }
214
215                macro_rules! insert_rows {
216                    ($rows:expr) => {
217                        for row in $rows {
218                            #insert_row_code
219                        }
220                    };
221                }
222
223                match client {
224                    DatabaseClient::Postgres(pool) => {
225                        let mut qb = build_in_query!(sqlx::Postgres);
226                        let related_rows: Vec<super::#related_mod::#related_struct> =
227                            qb.build_query_as().fetch_all(pool).await
228                                .map_err(FerriormError::from)?;
229                        insert_rows!(related_rows);
230                    }
231                    DatabaseClient::Sqlite(pool) => {
232                        let mut qb = build_in_query!(sqlx::Sqlite);
233                        let related_rows: Vec<super::#related_mod::#related_struct> =
234                            qb.build_query_as().fetch_all(pool).await
235                                .map_err(FerriormError::from)?;
236                        insert_rows!(related_rows);
237                    }
238                }
239            }
240        }
241    }
242}
243
244/// Helper: generate code to load related rows for a single-value (`OneToOne` / `ManyToOne`) relation.
245fn gen_batched_load_one(
246    rel: &RelationInfo<'_>,
247    load_var: &proc_macro2::Ident,
248    field_name: &proc_macro2::Ident,
249    id_source_ident: &proc_macro2::Ident,
250    lookup_col_str: &str,
251    insert_key_ident: &proc_macro2::Ident,
252    fk_is_optional: bool,
253) -> TokenStream {
254    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
255    let related_struct = format_ident!("{}", rel.related_model.name);
256    let related_table = &rel.related_model.db_name;
257
258    let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
259
260    // When the FK field is optional (Option<String>), use filter_map to skip None values.
261    let ids_collect = if fk_is_optional {
262        quote! {
263            let ids: Vec<String> = records.iter()
264                .filter_map(|r| r.#id_source_ident.clone())
265                .collect();
266        }
267    } else {
268        quote! {
269            let ids: Vec<String> = records.iter()
270                .map(|r| r.#id_source_ident.clone())
271                .collect();
272        }
273    };
274
275    quote! {
276        let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
277        if include.#field_name {
278            #ids_collect
279
280            if !ids.is_empty() {
281                macro_rules! build_in_query {
282                    ($db:ty) => {{
283                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
284                        let mut sep = qb.separated(", ");
285                        for id in &ids {
286                            sep.push_bind(id.clone());
287                        }
288                        qb.push(")");
289                        qb
290                    }};
291                }
292
293                match client {
294                    DatabaseClient::Postgres(pool) => {
295                        let mut qb = build_in_query!(sqlx::Postgres);
296                        let related_rows: Vec<super::#related_mod::#related_struct> =
297                            qb.build_query_as().fetch_all(pool).await
298                                .map_err(FerriormError::from)?;
299                        for row in related_rows {
300                            #load_var.insert(row.#insert_key_ident.clone(), row);
301                        }
302                    }
303                    DatabaseClient::Sqlite(pool) => {
304                        let mut qb = build_in_query!(sqlx::Sqlite);
305                        let related_rows: Vec<super::#related_mod::#related_struct> =
306                            qb.build_query_as().fetch_all(pool).await
307                                .map_err(FerriormError::from)?;
308                        for row in related_rows {
309                            #load_var.insert(row.#insert_key_ident.clone(), row);
310                        }
311                    }
312                }
313            }
314        }
315    }
316}
317
318#[allow(clippy::too_many_lines)]
319fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
320    let _model_ident = format_ident!("{}", model.name);
321    let with_relations_name = format_ident!("{}WithRelations", model.name);
322
323    let mut relation_loads = vec![];
324    let mut field_inits = vec![];
325
326    for rel in relations {
327        let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
328        let fk_col_str = &rel.fk_column;
329        let ref_col_str = &rel.ref_column;
330        let fk_col_ident = format_ident!("{}", rel.fk_column);
331        let ref_col_ident = format_ident!("{}", rel.ref_column);
332
333        match rel.relation_type {
334            RelationType::OneToMany | RelationType::ManyToMany => {
335                // Batched loading: SELECT * FROM related WHERE fk IN (parent_ids)
336                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
337
338                // Check if the FK column on the child (related) model is optional
339                let child_fk_optional = rel
340                    .related_model
341                    .fields
342                    .iter()
343                    .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
344
345                relation_loads.push(gen_batched_load_many(
346                    rel,
347                    &load_var,
348                    &field_name,
349                    &ref_col_ident,
350                    fk_col_str,
351                    &fk_col_ident,
352                    child_fk_optional,
353                ));
354
355                let ref_col_ident = format_ident!("{}", ref_col_str);
356                field_inits.push(quote! {
357                    #field_name: if include.#field_name {
358                        Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
359                    } else {
360                        None
361                    }
362                });
363            }
364            RelationType::OneToOne | RelationType::ManyToOne => {
365                // For ManyToOne (this model has the FK), we can batch load the parent
366                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
367                let fk_field = format_ident!("{}", fk_col_str);
368
369                // Check if this model has the FK field as a scalar
370                let fk_model_field = model
371                    .fields
372                    .iter()
373                    .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
374                let has_fk = fk_model_field.is_some();
375                let fk_is_optional = fk_model_field.is_some_and(|f| f.is_optional);
376
377                if has_fk {
378                    relation_loads.push(gen_batched_load_one(
379                        rel,
380                        &load_var,
381                        &field_name,
382                        &fk_field,
383                        ref_col_str,
384                        &ref_col_ident,
385                        fk_is_optional,
386                    ));
387
388                    if fk_is_optional {
389                        field_inits.push(quote! {
390                            #field_name: if include.#field_name {
391                                r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
392                            } else {
393                                None
394                            }
395                        });
396                    } else {
397                        field_inits.push(quote! {
398                            #field_name: if include.#field_name {
399                                #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
400                            } else {
401                                None
402                            }
403                        });
404                    }
405                } else {
406                    // The FK is on the other side (e.g., User.profile where Profile has userId)
407                    // Batch load: SELECT * FROM profiles WHERE user_id IN (user_ids)
408                    let ref_col_ident = format_ident!("{}", ref_col_str);
409
410                    relation_loads.push(gen_batched_load_one(
411                        rel,
412                        &load_var,
413                        &field_name,
414                        &ref_col_ident,
415                        fk_col_str,
416                        &fk_col_ident,
417                        false,
418                    ));
419
420                    field_inits.push(quote! {
421                        #field_name: if include.#field_name {
422                            #load_var.remove(&r.#ref_col_ident)
423                        } else {
424                            None
425                        }
426                    });
427                }
428            }
429        }
430    }
431
432    quote! {
433        #(#relation_loads)*
434
435        let mut results = Vec::with_capacity(records.len());
436        for r in records {
437            results.push(#with_relations_name {
438                #(#field_inits,)*
439                data: r,
440            });
441        }
442        Ok(results)
443    }
444}
445
446/// Generate `include()` and `exec_with_relations()` methods for `FindMany`.
447#[must_use]
448pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
449    let relations = collect_relations(model, schema);
450    if relations.is_empty() {
451        return quote! {};
452    }
453
454    let model_ident = format_ident!("{}", model.name);
455    let include_name = format_ident!("{}Include", model.name);
456    let with_relations_name = format_ident!("{}WithRelations", model.name);
457
458    quote! {
459        impl<'a> FindManyQuery<'a> {
460            pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
461                FindManyWithIncludeQuery {
462                    inner: self,
463                    include,
464                }
465            }
466        }
467
468        pub struct FindManyWithIncludeQuery<'a> {
469            inner: FindManyQuery<'a>,
470            include: #include_name,
471        }
472
473        impl<'a> FindManyWithIncludeQuery<'a> {
474            pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
475                let include = self.include;
476                let client = self.inner.client;
477                let records = FindManyQuery {
478                    client,
479                    r#where: self.inner.r#where,
480                    order_by: self.inner.order_by,
481                    skip: self.inner.skip,
482                    take: self.inner.take,
483                }.exec().await?;
484                #model_ident::load_relations(records, &include, client).await
485            }
486        }
487
488        impl<'a> FindUniqueQuery<'a> {
489            pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
490                FindUniqueWithIncludeQuery {
491                    inner: self,
492                    include,
493                }
494            }
495        }
496
497        pub struct FindUniqueWithIncludeQuery<'a> {
498            inner: FindUniqueQuery<'a>,
499            include: #include_name,
500        }
501
502        impl<'a> FindUniqueWithIncludeQuery<'a> {
503            pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
504                let include = self.include;
505                let client = self.inner.client;
506                let record = FindUniqueQuery {
507                    client,
508                    r#where: self.inner.r#where,
509                }.exec().await?;
510                match record {
511                    Some(r) => {
512                        let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
513                        Ok(results.pop())
514                    }
515                    None => Ok(None),
516                }
517            }
518        }
519    }
520}