Skip to main content

ferriorm_codegen/
relations.rs

//! Code generation for relation support: Include, WithRelations, batched loading.
2
3use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8/// Information about a relation from one model to another.
9pub struct RelationInfo<'a> {
10    pub field: &'a Field,
11    pub related_model: &'a Model,
12    pub relation_type: RelationType,
13    /// The FK column on the "many" side (e.g., "author_id" on Post for User.posts)
14    pub fk_column: String,
15    /// The referenced column (e.g., "id" on User)
16    pub ref_column: String,
17}
18
19/// Collect relation fields for a model, resolving against the full schema.
20pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
21    let mut relations = Vec::new();
22
23    for field in &model.fields {
24        if let Some(rel) = &field.relation {
25            let related = schema.models.iter().find(|m| m.name == rel.related_model);
26            if let Some(related_model) = related {
27                let (fk_column, ref_column) = if !rel.fields.is_empty() {
28                    // This side has the FK (ManyToOne)
29                    (rel.fields[0].clone(), rel.references[0].clone())
30                } else {
31                    // The other side has the FK (OneToMany) — find the back-reference
32                    find_back_reference(model, related_model)
33                        .unwrap_or_else(|| ("id".into(), "id".into()))
34                };
35
36                relations.push(RelationInfo {
37                    field,
38                    related_model,
39                    relation_type: rel.relation_type,
40                    fk_column: to_snake_case(&fk_column),
41                    ref_column: to_snake_case(&ref_column),
42                });
43            }
44        } else if field.is_list {
45            // Implicit relation (e.g., posts Post[])
46            if let FieldKind::Model(related_name) = &field.field_type {
47                let related = schema.models.iter().find(|m| m.name == *related_name);
48                if let Some(related_model) = related {
49                    let (fk_column, ref_column) = find_back_reference(model, related_model)
50                        .unwrap_or_else(|| ("id".into(), "id".into()));
51
52                    relations.push(RelationInfo {
53                        field,
54                        related_model,
55                        relation_type: RelationType::OneToMany,
56                        fk_column: to_snake_case(&fk_column),
57                        ref_column: to_snake_case(&ref_column),
58                    });
59                }
60            }
61        }
62    }
63
64    relations
65}
66
67/// Find the back-reference from the related model to this model.
68/// E.g., for User.posts (Post[]), find Post.authorId @relation(fields: [authorId], references: [id])
69fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
70    for field in &child.fields {
71        if let Some(rel) = &field.relation
72            && rel.related_model == parent.name
73            && !rel.fields.is_empty()
74        {
75            return Some((rel.fields[0].clone(), rel.references[0].clone()));
76        }
77    }
78    None
79}
80
81/// Generate the Include struct, WithRelations struct, and include-aware query methods.
82pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
83    let relations = collect_relations(model, schema);
84
85    if relations.is_empty() {
86        return quote! {};
87    }
88
89    let model_ident = format_ident!("{}", model.name);
90    let include_name = format_ident!("{}Include", model.name);
91    let with_relations_name = format_ident!("{}WithRelations", model.name);
92
93    // Include struct fields
94    let include_fields: Vec<TokenStream> = relations
95        .iter()
96        .map(|r| {
97            let name = format_ident!("{}", to_snake_case(&r.field.name));
98            quote! { pub #name: bool }
99        })
100        .collect();
101
102    // WithRelations struct fields
103    let with_rel_fields: Vec<TokenStream> = relations
104        .iter()
105        .map(|r| {
106            let name = format_ident!("{}", to_snake_case(&r.field.name));
107            let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
108            let related_struct = format_ident!("{}", r.related_model.name);
109
110            match r.relation_type {
111                RelationType::OneToMany | RelationType::ManyToMany => {
112                    quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
113                }
114                RelationType::OneToOne | RelationType::ManyToOne => {
115                    quote! { pub #name: Option<super::#related_mod::#related_struct> }
116                }
117            }
118        })
119        .collect();
120
121    // Generate the batched loading logic
122    let load_arms = gen_load_arms(&relations, model);
123
124    quote! {
125        #[derive(Debug, Clone, Default)]
126        pub struct #include_name {
127            #(#include_fields,)*
128        }
129
130        #[derive(Debug, Clone, Serialize, Deserialize)]
131        pub struct #with_relations_name {
132            #[serde(flatten)]
133            pub data: #model_ident,
134            #(#with_rel_fields,)*
135        }
136
137        impl #model_ident {
138            /// Load relations for a batch of records.
139            pub(crate) async fn load_relations(
140                records: Vec<#model_ident>,
141                include: &#include_name,
142                client: &DatabaseClient,
143            ) -> Result<Vec<#with_relations_name>, FerriormError> {
144                #load_arms
145            }
146        }
147    }
148}
149
150/// Helper: generate code to load related rows using QueryBuilder and dispatch
151/// to both Postgres and Sqlite.
152fn gen_batched_load_many(
153    rel: &RelationInfo<'_>,
154    load_var: &proc_macro2::Ident,
155    field_name: &proc_macro2::Ident,
156    id_source_ident: &proc_macro2::Ident,
157    lookup_col_str: &str,
158    insert_key_ident: &proc_macro2::Ident,
159    fk_optional: bool,
160) -> TokenStream {
161    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
162    let related_struct = format_ident!("{}", rel.related_model.name);
163    let related_table = &rel.related_model.db_name;
164
165    let select_base = format!(
166        r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
167        related_table, lookup_col_str
168    );
169
170    // Generate the row insertion code based on whether the FK is optional
171    let insert_row_code = if fk_optional {
172        quote! {
173            if let Some(key) = row.#insert_key_ident.clone() {
174                #load_var.entry(key).or_default().push(row);
175            }
176        }
177    } else {
178        quote! {
179            #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
180        }
181    };
182
183    quote! {
184        let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
185        if include.#field_name {
186            let ids: Vec<String> = records.iter()
187                .map(|r| r.#id_source_ident.clone())
188                .collect();
189
190            if !ids.is_empty() {
191                macro_rules! build_in_query {
192                    ($db:ty) => {{
193                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
194                        let mut sep = qb.separated(", ");
195                        for id in &ids {
196                            sep.push_bind(id.clone());
197                        }
198                        qb.push(")");
199                        qb
200                    }};
201                }
202
203                macro_rules! insert_rows {
204                    ($rows:expr) => {
205                        for row in $rows {
206                            #insert_row_code
207                        }
208                    };
209                }
210
211                match client {
212                    DatabaseClient::Postgres(pool) => {
213                        let mut qb = build_in_query!(sqlx::Postgres);
214                        let related_rows: Vec<super::#related_mod::#related_struct> =
215                            qb.build_query_as().fetch_all(pool).await
216                                .map_err(FerriormError::from)?;
217                        insert_rows!(related_rows);
218                    }
219                    DatabaseClient::Sqlite(pool) => {
220                        let mut qb = build_in_query!(sqlx::Sqlite);
221                        let related_rows: Vec<super::#related_mod::#related_struct> =
222                            qb.build_query_as().fetch_all(pool).await
223                                .map_err(FerriormError::from)?;
224                        insert_rows!(related_rows);
225                    }
226                }
227            }
228        }
229    }
230}
231
232/// Helper: generate code to load related rows for a single-value (OneToOne / ManyToOne) relation.
233fn gen_batched_load_one(
234    rel: &RelationInfo<'_>,
235    load_var: &proc_macro2::Ident,
236    field_name: &proc_macro2::Ident,
237    id_source_ident: &proc_macro2::Ident,
238    lookup_col_str: &str,
239    insert_key_ident: &proc_macro2::Ident,
240    fk_is_optional: bool,
241) -> TokenStream {
242    let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
243    let related_struct = format_ident!("{}", rel.related_model.name);
244    let related_table = &rel.related_model.db_name;
245
246    let select_base = format!(
247        r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
248        related_table, lookup_col_str
249    );
250
251    // When the FK field is optional (Option<String>), use filter_map to skip None values.
252    let ids_collect = if fk_is_optional {
253        quote! {
254            let ids: Vec<String> = records.iter()
255                .filter_map(|r| r.#id_source_ident.clone())
256                .collect();
257        }
258    } else {
259        quote! {
260            let ids: Vec<String> = records.iter()
261                .map(|r| r.#id_source_ident.clone())
262                .collect();
263        }
264    };
265
266    quote! {
267        let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
268        if include.#field_name {
269            #ids_collect
270
271            if !ids.is_empty() {
272                macro_rules! build_in_query {
273                    ($db:ty) => {{
274                        let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
275                        let mut sep = qb.separated(", ");
276                        for id in &ids {
277                            sep.push_bind(id.clone());
278                        }
279                        qb.push(")");
280                        qb
281                    }};
282                }
283
284                match client {
285                    DatabaseClient::Postgres(pool) => {
286                        let mut qb = build_in_query!(sqlx::Postgres);
287                        let related_rows: Vec<super::#related_mod::#related_struct> =
288                            qb.build_query_as().fetch_all(pool).await
289                                .map_err(FerriormError::from)?;
290                        for row in related_rows {
291                            #load_var.insert(row.#insert_key_ident.clone(), row);
292                        }
293                    }
294                    DatabaseClient::Sqlite(pool) => {
295                        let mut qb = build_in_query!(sqlx::Sqlite);
296                        let related_rows: Vec<super::#related_mod::#related_struct> =
297                            qb.build_query_as().fetch_all(pool).await
298                                .map_err(FerriormError::from)?;
299                        for row in related_rows {
300                            #load_var.insert(row.#insert_key_ident.clone(), row);
301                        }
302                    }
303                }
304            }
305        }
306    }
307}
308
309fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
310    let _model_ident = format_ident!("{}", model.name);
311    let with_relations_name = format_ident!("{}WithRelations", model.name);
312
313    let mut relation_loads = vec![];
314    let mut field_inits = vec![];
315
316    for rel in relations {
317        let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
318        let fk_col_str = &rel.fk_column;
319        let ref_col_str = &rel.ref_column;
320        let fk_col_ident = format_ident!("{}", rel.fk_column);
321        let ref_col_ident = format_ident!("{}", rel.ref_column);
322
323        match rel.relation_type {
324            RelationType::OneToMany | RelationType::ManyToMany => {
325                // Batched loading: SELECT * FROM related WHERE fk IN (parent_ids)
326                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
327
328                // Check if the FK column on the child (related) model is optional
329                let child_fk_optional = rel
330                    .related_model
331                    .fields
332                    .iter()
333                    .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
334
335                relation_loads.push(gen_batched_load_many(
336                    rel,
337                    &load_var,
338                    &field_name,
339                    &ref_col_ident,
340                    fk_col_str,
341                    &fk_col_ident,
342                    child_fk_optional,
343                ));
344
345                let ref_col_ident = format_ident!("{}", ref_col_str);
346                field_inits.push(quote! {
347                    #field_name: if include.#field_name {
348                        Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
349                    } else {
350                        None
351                    }
352                });
353            }
354            RelationType::OneToOne | RelationType::ManyToOne => {
355                // For ManyToOne (this model has the FK), we can batch load the parent
356                let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
357                let fk_field = format_ident!("{}", fk_col_str);
358
359                // Check if this model has the FK field as a scalar
360                let fk_model_field = model
361                    .fields
362                    .iter()
363                    .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
364                let has_fk = fk_model_field.is_some();
365                let fk_is_optional = fk_model_field.map(|f| f.is_optional).unwrap_or(false);
366
367                if has_fk {
368                    relation_loads.push(gen_batched_load_one(
369                        rel,
370                        &load_var,
371                        &field_name,
372                        &fk_field,
373                        ref_col_str,
374                        &ref_col_ident,
375                        fk_is_optional,
376                    ));
377
378                    if fk_is_optional {
379                        field_inits.push(quote! {
380                            #field_name: if include.#field_name {
381                                r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
382                            } else {
383                                None
384                            }
385                        });
386                    } else {
387                        field_inits.push(quote! {
388                            #field_name: if include.#field_name {
389                                #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
390                            } else {
391                                None
392                            }
393                        });
394                    }
395                } else {
396                    // The FK is on the other side (e.g., User.profile where Profile has userId)
397                    // Batch load: SELECT * FROM profiles WHERE user_id IN (user_ids)
398                    let ref_col_ident = format_ident!("{}", ref_col_str);
399
400                    relation_loads.push(gen_batched_load_one(
401                        rel,
402                        &load_var,
403                        &field_name,
404                        &ref_col_ident,
405                        fk_col_str,
406                        &fk_col_ident,
407                        false,
408                    ));
409
410                    field_inits.push(quote! {
411                        #field_name: if include.#field_name {
412                            #load_var.remove(&r.#ref_col_ident)
413                        } else {
414                            None
415                        }
416                    });
417                }
418            }
419        }
420    }
421
422    quote! {
423        #(#relation_loads)*
424
425        let mut results = Vec::with_capacity(records.len());
426        for r in records {
427            results.push(#with_relations_name {
428                #(#field_inits,)*
429                data: r,
430            });
431        }
432        Ok(results)
433    }
434}
435
436/// Generate `include()` and `exec_with_relations()` methods for FindMany.
437pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
438    let relations = collect_relations(model, schema);
439    if relations.is_empty() {
440        return quote! {};
441    }
442
443    let model_ident = format_ident!("{}", model.name);
444    let include_name = format_ident!("{}Include", model.name);
445    let with_relations_name = format_ident!("{}WithRelations", model.name);
446
447    quote! {
448        impl<'a> FindManyQuery<'a> {
449            pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
450                FindManyWithIncludeQuery {
451                    inner: self,
452                    include,
453                }
454            }
455        }
456
457        pub struct FindManyWithIncludeQuery<'a> {
458            inner: FindManyQuery<'a>,
459            include: #include_name,
460        }
461
462        impl<'a> FindManyWithIncludeQuery<'a> {
463            pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
464                let include = self.include;
465                let client = self.inner.client;
466                let records = FindManyQuery {
467                    client,
468                    r#where: self.inner.r#where,
469                    order_by: self.inner.order_by,
470                    skip: self.inner.skip,
471                    take: self.inner.take,
472                }.exec().await?;
473                #model_ident::load_relations(records, &include, client).await
474            }
475        }
476
477        impl<'a> FindUniqueQuery<'a> {
478            pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
479                FindUniqueWithIncludeQuery {
480                    inner: self,
481                    include,
482                }
483            }
484        }
485
486        pub struct FindUniqueWithIncludeQuery<'a> {
487            inner: FindUniqueQuery<'a>,
488            include: #include_name,
489        }
490
491        impl<'a> FindUniqueWithIncludeQuery<'a> {
492            pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
493                let include = self.include;
494                let client = self.inner.client;
495                let record = FindUniqueQuery {
496                    client,
497                    r#where: self.inner.r#where,
498                }.exec().await?;
499                match record {
500                    Some(r) => {
501                        let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
502                        Ok(results.pop())
503                    }
504                    None => Ok(None),
505                }
506            }
507        }
508    }
509}