1use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8pub struct RelationInfo<'a> {
10 pub field: &'a Field,
11 pub related_model: &'a Model,
12 pub relation_type: RelationType,
13 pub fk_column: String,
15 pub ref_column: String,
17}
18
19#[must_use]
21pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
22 let mut relations = Vec::new();
23
24 for field in &model.fields {
25 if let Some(rel) = &field.relation {
26 let related = schema.models.iter().find(|m| m.name == rel.related_model);
27 if let Some(related_model) = related {
28 let (fk_column, ref_column) = if rel.fields.is_empty() {
29 find_back_reference(model, related_model, rel.name.as_deref())
33 .unwrap_or_else(|| ("id".into(), "id".into()))
34 } else {
35 (rel.fields[0].clone(), rel.references[0].clone())
37 };
38
39 relations.push(RelationInfo {
40 field,
41 related_model,
42 relation_type: rel.relation_type,
43 fk_column: to_snake_case(&fk_column),
44 ref_column: to_snake_case(&ref_column),
45 });
46 }
47 } else if field.is_list {
48 if let FieldKind::Model(related_name) = &field.field_type {
53 let related = schema.models.iter().find(|m| m.name == *related_name);
54 if let Some(related_model) = related {
55 let (fk_column, ref_column) = find_back_reference(model, related_model, None)
56 .unwrap_or_else(|| ("id".into(), "id".into()));
57
58 relations.push(RelationInfo {
59 field,
60 related_model,
61 relation_type: RelationType::OneToMany,
62 fk_column: to_snake_case(&fk_column),
63 ref_column: to_snake_case(&ref_column),
64 });
65 }
66 }
67 }
68 }
69
70 relations
71}
72
73fn find_back_reference(
79 parent: &Model,
80 child: &Model,
81 name: Option<&str>,
82) -> Option<(String, String)> {
83 for field in &child.fields {
84 if let Some(rel) = &field.relation
85 && rel.related_model == parent.name
86 && !rel.fields.is_empty()
87 && (name.is_none() || rel.name.as_deref() == name)
88 {
89 return Some((rel.fields[0].clone(), rel.references[0].clone()));
90 }
91 }
92 None
93}
94
95#[must_use]
97pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
98 let relations = collect_relations(model, schema);
99
100 if relations.is_empty() {
101 return quote! {};
102 }
103
104 let model_ident = format_ident!("{}", model.name);
105 let include_name = format_ident!("{}Include", model.name);
106 let with_relations_name = format_ident!("{}WithRelations", model.name);
107
108 let include_fields: Vec<TokenStream> = relations
110 .iter()
111 .map(|r| {
112 let name = format_ident!("{}", to_snake_case(&r.field.name));
113 quote! { pub #name: bool }
114 })
115 .collect();
116
117 let with_rel_fields: Vec<TokenStream> = relations
119 .iter()
120 .map(|r| {
121 let name = format_ident!("{}", to_snake_case(&r.field.name));
122 let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
123 let related_struct = format_ident!("{}", r.related_model.name);
124
125 match r.relation_type {
126 RelationType::OneToMany | RelationType::ManyToMany => {
127 quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
128 }
129 RelationType::OneToOne | RelationType::ManyToOne => {
130 quote! { pub #name: Option<super::#related_mod::#related_struct> }
131 }
132 }
133 })
134 .collect();
135
136 let load_arms = gen_load_arms(&relations, model);
138
139 quote! {
140 #[derive(Debug, Clone, Default)]
141 pub struct #include_name {
142 #(#include_fields,)*
143 }
144
145 #[derive(Debug, Clone, Serialize, Deserialize)]
146 pub struct #with_relations_name {
147 #[serde(flatten)]
148 pub data: #model_ident,
149 #(#with_rel_fields,)*
150 }
151
152 impl #model_ident {
153 pub(crate) async fn load_relations(
155 records: Vec<#model_ident>,
156 include: &#include_name,
157 client: &DatabaseClient,
158 ) -> Result<Vec<#with_relations_name>, FerriormError> {
159 #load_arms
160 }
161 }
162 }
163}
164
165fn gen_batched_load_many(
168 rel: &RelationInfo<'_>,
169 load_var: &proc_macro2::Ident,
170 field_name: &proc_macro2::Ident,
171 id_source_ident: &proc_macro2::Ident,
172 lookup_col_str: &str,
173 insert_key_ident: &proc_macro2::Ident,
174 fk_optional: bool,
175) -> TokenStream {
176 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
177 let related_struct = format_ident!("{}", rel.related_model.name);
178 let related_table = &rel.related_model.db_name;
179
180 let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
181
182 let insert_row_code = if fk_optional {
184 quote! {
185 if let Some(key) = row.#insert_key_ident.clone() {
186 #load_var.entry(key).or_default().push(row);
187 }
188 }
189 } else {
190 quote! {
191 #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
192 }
193 };
194
195 quote! {
196 let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
197 if include.#field_name {
198 let ids: Vec<String> = records.iter()
199 .map(|r| r.#id_source_ident.clone())
200 .collect();
201
202 if !ids.is_empty() {
203 macro_rules! build_in_query {
204 ($db:ty) => {{
205 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
206 let mut sep = qb.separated(", ");
207 for id in &ids {
208 sep.push_bind(id.clone());
209 }
210 qb.push(")");
211 qb
212 }};
213 }
214
215 macro_rules! insert_rows {
216 ($rows:expr) => {
217 for row in $rows {
218 #insert_row_code
219 }
220 };
221 }
222
223 match client {
224 DatabaseClient::Postgres(pool) => {
225 let mut qb = build_in_query!(sqlx::Postgres);
226 let related_rows: Vec<super::#related_mod::#related_struct> =
227 qb.build_query_as().fetch_all(pool).await
228 .map_err(FerriormError::from)?;
229 insert_rows!(related_rows);
230 }
231 DatabaseClient::Sqlite(pool) => {
232 let mut qb = build_in_query!(sqlx::Sqlite);
233 let related_rows: Vec<super::#related_mod::#related_struct> =
234 qb.build_query_as().fetch_all(pool).await
235 .map_err(FerriormError::from)?;
236 insert_rows!(related_rows);
237 }
238 }
239 }
240 }
241 }
242}
243
244fn gen_batched_load_one(
246 rel: &RelationInfo<'_>,
247 load_var: &proc_macro2::Ident,
248 field_name: &proc_macro2::Ident,
249 id_source_ident: &proc_macro2::Ident,
250 lookup_col_str: &str,
251 insert_key_ident: &proc_macro2::Ident,
252 fk_is_optional: bool,
253) -> TokenStream {
254 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
255 let related_struct = format_ident!("{}", rel.related_model.name);
256 let related_table = &rel.related_model.db_name;
257
258 let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
259
260 let ids_collect = if fk_is_optional {
262 quote! {
263 let ids: Vec<String> = records.iter()
264 .filter_map(|r| r.#id_source_ident.clone())
265 .collect();
266 }
267 } else {
268 quote! {
269 let ids: Vec<String> = records.iter()
270 .map(|r| r.#id_source_ident.clone())
271 .collect();
272 }
273 };
274
275 quote! {
276 let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
277 if include.#field_name {
278 #ids_collect
279
280 if !ids.is_empty() {
281 macro_rules! build_in_query {
282 ($db:ty) => {{
283 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
284 let mut sep = qb.separated(", ");
285 for id in &ids {
286 sep.push_bind(id.clone());
287 }
288 qb.push(")");
289 qb
290 }};
291 }
292
293 match client {
294 DatabaseClient::Postgres(pool) => {
295 let mut qb = build_in_query!(sqlx::Postgres);
296 let related_rows: Vec<super::#related_mod::#related_struct> =
297 qb.build_query_as().fetch_all(pool).await
298 .map_err(FerriormError::from)?;
299 for row in related_rows {
300 #load_var.insert(row.#insert_key_ident.clone(), row);
301 }
302 }
303 DatabaseClient::Sqlite(pool) => {
304 let mut qb = build_in_query!(sqlx::Sqlite);
305 let related_rows: Vec<super::#related_mod::#related_struct> =
306 qb.build_query_as().fetch_all(pool).await
307 .map_err(FerriormError::from)?;
308 for row in related_rows {
309 #load_var.insert(row.#insert_key_ident.clone(), row);
310 }
311 }
312 }
313 }
314 }
315 }
316}
317
318#[allow(clippy::too_many_lines)]
319fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
320 let _model_ident = format_ident!("{}", model.name);
321 let with_relations_name = format_ident!("{}WithRelations", model.name);
322
323 let mut relation_loads = vec![];
324 let mut field_inits = vec![];
325
326 for rel in relations {
327 let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
328 let fk_col_str = &rel.fk_column;
329 let ref_col_str = &rel.ref_column;
330 let fk_col_ident = format_ident!("{}", rel.fk_column);
331 let ref_col_ident = format_ident!("{}", rel.ref_column);
332
333 match rel.relation_type {
334 RelationType::OneToMany | RelationType::ManyToMany => {
335 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
337
338 let child_fk_optional = rel
340 .related_model
341 .fields
342 .iter()
343 .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
344
345 relation_loads.push(gen_batched_load_many(
346 rel,
347 &load_var,
348 &field_name,
349 &ref_col_ident,
350 fk_col_str,
351 &fk_col_ident,
352 child_fk_optional,
353 ));
354
355 let ref_col_ident = format_ident!("{}", ref_col_str);
356 field_inits.push(quote! {
357 #field_name: if include.#field_name {
358 Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
359 } else {
360 None
361 }
362 });
363 }
364 RelationType::OneToOne | RelationType::ManyToOne => {
365 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
367 let fk_field = format_ident!("{}", fk_col_str);
368
369 let fk_model_field = model
371 .fields
372 .iter()
373 .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
374 let has_fk = fk_model_field.is_some();
375 let fk_is_optional = fk_model_field.is_some_and(|f| f.is_optional);
376
377 if has_fk {
378 relation_loads.push(gen_batched_load_one(
379 rel,
380 &load_var,
381 &field_name,
382 &fk_field,
383 ref_col_str,
384 &ref_col_ident,
385 fk_is_optional,
386 ));
387
388 if fk_is_optional {
389 field_inits.push(quote! {
390 #field_name: if include.#field_name {
391 r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
392 } else {
393 None
394 }
395 });
396 } else {
397 field_inits.push(quote! {
398 #field_name: if include.#field_name {
399 #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
400 } else {
401 None
402 }
403 });
404 }
405 } else {
406 let ref_col_ident = format_ident!("{}", ref_col_str);
409
410 relation_loads.push(gen_batched_load_one(
411 rel,
412 &load_var,
413 &field_name,
414 &ref_col_ident,
415 fk_col_str,
416 &fk_col_ident,
417 false,
418 ));
419
420 field_inits.push(quote! {
421 #field_name: if include.#field_name {
422 #load_var.remove(&r.#ref_col_ident)
423 } else {
424 None
425 }
426 });
427 }
428 }
429 }
430 }
431
432 quote! {
433 #(#relation_loads)*
434
435 let mut results = Vec::with_capacity(records.len());
436 for r in records {
437 results.push(#with_relations_name {
438 #(#field_inits,)*
439 data: r,
440 });
441 }
442 Ok(results)
443 }
444}
445
446#[must_use]
448pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
449 let relations = collect_relations(model, schema);
450 if relations.is_empty() {
451 return quote! {};
452 }
453
454 let model_ident = format_ident!("{}", model.name);
455 let include_name = format_ident!("{}Include", model.name);
456 let with_relations_name = format_ident!("{}WithRelations", model.name);
457
458 quote! {
459 impl<'a> FindManyQuery<'a> {
460 pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
461 FindManyWithIncludeQuery {
462 inner: self,
463 include,
464 }
465 }
466 }
467
468 pub struct FindManyWithIncludeQuery<'a> {
469 inner: FindManyQuery<'a>,
470 include: #include_name,
471 }
472
473 impl<'a> FindManyWithIncludeQuery<'a> {
474 pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
475 let include = self.include;
476 let client = self.inner.client;
477 let records = FindManyQuery {
478 client,
479 r#where: self.inner.r#where,
480 order_by: self.inner.order_by,
481 skip: self.inner.skip,
482 take: self.inner.take,
483 }.exec().await?;
484 #model_ident::load_relations(records, &include, client).await
485 }
486 }
487
488 impl<'a> FindUniqueQuery<'a> {
489 pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
490 FindUniqueWithIncludeQuery {
491 inner: self,
492 include,
493 }
494 }
495 }
496
497 pub struct FindUniqueWithIncludeQuery<'a> {
498 inner: FindUniqueQuery<'a>,
499 include: #include_name,
500 }
501
502 impl<'a> FindUniqueWithIncludeQuery<'a> {
503 pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
504 let include = self.include;
505 let client = self.inner.client;
506 let record = FindUniqueQuery {
507 client,
508 r#where: self.inner.r#where,
509 }.exec().await?;
510 match record {
511 Some(r) => {
512 let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
513 Ok(results.pop())
514 }
515 None => Ok(None),
516 }
517 }
518 }
519 }
520}