1use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8pub struct RelationInfo<'a> {
10 pub field: &'a Field,
11 pub related_model: &'a Model,
12 pub relation_type: RelationType,
13 pub fk_column: String,
15 pub ref_column: String,
17}
18
19#[must_use]
21pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
22 let mut relations = Vec::new();
23
24 for field in &model.fields {
25 if let Some(rel) = &field.relation {
26 let related = schema.models.iter().find(|m| m.name == rel.related_model);
27 if let Some(related_model) = related {
28 let (fk_column, ref_column) = if rel.fields.is_empty() {
29 find_back_reference(model, related_model)
31 .unwrap_or_else(|| ("id".into(), "id".into()))
32 } else {
33 (rel.fields[0].clone(), rel.references[0].clone())
35 };
36
37 relations.push(RelationInfo {
38 field,
39 related_model,
40 relation_type: rel.relation_type,
41 fk_column: to_snake_case(&fk_column),
42 ref_column: to_snake_case(&ref_column),
43 });
44 }
45 } else if field.is_list {
46 if let FieldKind::Model(related_name) = &field.field_type {
48 let related = schema.models.iter().find(|m| m.name == *related_name);
49 if let Some(related_model) = related {
50 let (fk_column, ref_column) = find_back_reference(model, related_model)
51 .unwrap_or_else(|| ("id".into(), "id".into()));
52
53 relations.push(RelationInfo {
54 field,
55 related_model,
56 relation_type: RelationType::OneToMany,
57 fk_column: to_snake_case(&fk_column),
58 ref_column: to_snake_case(&ref_column),
59 });
60 }
61 }
62 }
63 }
64
65 relations
66}
67
68fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
71 for field in &child.fields {
72 if let Some(rel) = &field.relation
73 && rel.related_model == parent.name
74 && !rel.fields.is_empty()
75 {
76 return Some((rel.fields[0].clone(), rel.references[0].clone()));
77 }
78 }
79 None
80}
81
82#[must_use]
84pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
85 let relations = collect_relations(model, schema);
86
87 if relations.is_empty() {
88 return quote! {};
89 }
90
91 let model_ident = format_ident!("{}", model.name);
92 let include_name = format_ident!("{}Include", model.name);
93 let with_relations_name = format_ident!("{}WithRelations", model.name);
94
95 let include_fields: Vec<TokenStream> = relations
97 .iter()
98 .map(|r| {
99 let name = format_ident!("{}", to_snake_case(&r.field.name));
100 quote! { pub #name: bool }
101 })
102 .collect();
103
104 let with_rel_fields: Vec<TokenStream> = relations
106 .iter()
107 .map(|r| {
108 let name = format_ident!("{}", to_snake_case(&r.field.name));
109 let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
110 let related_struct = format_ident!("{}", r.related_model.name);
111
112 match r.relation_type {
113 RelationType::OneToMany | RelationType::ManyToMany => {
114 quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
115 }
116 RelationType::OneToOne | RelationType::ManyToOne => {
117 quote! { pub #name: Option<super::#related_mod::#related_struct> }
118 }
119 }
120 })
121 .collect();
122
123 let load_arms = gen_load_arms(&relations, model);
125
126 quote! {
127 #[derive(Debug, Clone, Default)]
128 pub struct #include_name {
129 #(#include_fields,)*
130 }
131
132 #[derive(Debug, Clone, Serialize, Deserialize)]
133 pub struct #with_relations_name {
134 #[serde(flatten)]
135 pub data: #model_ident,
136 #(#with_rel_fields,)*
137 }
138
139 impl #model_ident {
140 pub(crate) async fn load_relations(
142 records: Vec<#model_ident>,
143 include: &#include_name,
144 client: &DatabaseClient,
145 ) -> Result<Vec<#with_relations_name>, FerriormError> {
146 #load_arms
147 }
148 }
149 }
150}
151
152fn gen_batched_load_many(
155 rel: &RelationInfo<'_>,
156 load_var: &proc_macro2::Ident,
157 field_name: &proc_macro2::Ident,
158 id_source_ident: &proc_macro2::Ident,
159 lookup_col_str: &str,
160 insert_key_ident: &proc_macro2::Ident,
161 fk_optional: bool,
162) -> TokenStream {
163 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
164 let related_struct = format_ident!("{}", rel.related_model.name);
165 let related_table = &rel.related_model.db_name;
166
167 let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
168
169 let insert_row_code = if fk_optional {
171 quote! {
172 if let Some(key) = row.#insert_key_ident.clone() {
173 #load_var.entry(key).or_default().push(row);
174 }
175 }
176 } else {
177 quote! {
178 #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
179 }
180 };
181
182 quote! {
183 let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
184 if include.#field_name {
185 let ids: Vec<String> = records.iter()
186 .map(|r| r.#id_source_ident.clone())
187 .collect();
188
189 if !ids.is_empty() {
190 macro_rules! build_in_query {
191 ($db:ty) => {{
192 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
193 let mut sep = qb.separated(", ");
194 for id in &ids {
195 sep.push_bind(id.clone());
196 }
197 qb.push(")");
198 qb
199 }};
200 }
201
202 macro_rules! insert_rows {
203 ($rows:expr) => {
204 for row in $rows {
205 #insert_row_code
206 }
207 };
208 }
209
210 match client {
211 DatabaseClient::Postgres(pool) => {
212 let mut qb = build_in_query!(sqlx::Postgres);
213 let related_rows: Vec<super::#related_mod::#related_struct> =
214 qb.build_query_as().fetch_all(pool).await
215 .map_err(FerriormError::from)?;
216 insert_rows!(related_rows);
217 }
218 DatabaseClient::Sqlite(pool) => {
219 let mut qb = build_in_query!(sqlx::Sqlite);
220 let related_rows: Vec<super::#related_mod::#related_struct> =
221 qb.build_query_as().fetch_all(pool).await
222 .map_err(FerriormError::from)?;
223 insert_rows!(related_rows);
224 }
225 }
226 }
227 }
228 }
229}
230
231fn gen_batched_load_one(
233 rel: &RelationInfo<'_>,
234 load_var: &proc_macro2::Ident,
235 field_name: &proc_macro2::Ident,
236 id_source_ident: &proc_macro2::Ident,
237 lookup_col_str: &str,
238 insert_key_ident: &proc_macro2::Ident,
239 fk_is_optional: bool,
240) -> TokenStream {
241 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
242 let related_struct = format_ident!("{}", rel.related_model.name);
243 let related_table = &rel.related_model.db_name;
244
245 let select_base = format!(r#"SELECT * FROM "{related_table}" WHERE "{lookup_col_str}" IN ("#);
246
247 let ids_collect = if fk_is_optional {
249 quote! {
250 let ids: Vec<String> = records.iter()
251 .filter_map(|r| r.#id_source_ident.clone())
252 .collect();
253 }
254 } else {
255 quote! {
256 let ids: Vec<String> = records.iter()
257 .map(|r| r.#id_source_ident.clone())
258 .collect();
259 }
260 };
261
262 quote! {
263 let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
264 if include.#field_name {
265 #ids_collect
266
267 if !ids.is_empty() {
268 macro_rules! build_in_query {
269 ($db:ty) => {{
270 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
271 let mut sep = qb.separated(", ");
272 for id in &ids {
273 sep.push_bind(id.clone());
274 }
275 qb.push(")");
276 qb
277 }};
278 }
279
280 match client {
281 DatabaseClient::Postgres(pool) => {
282 let mut qb = build_in_query!(sqlx::Postgres);
283 let related_rows: Vec<super::#related_mod::#related_struct> =
284 qb.build_query_as().fetch_all(pool).await
285 .map_err(FerriormError::from)?;
286 for row in related_rows {
287 #load_var.insert(row.#insert_key_ident.clone(), row);
288 }
289 }
290 DatabaseClient::Sqlite(pool) => {
291 let mut qb = build_in_query!(sqlx::Sqlite);
292 let related_rows: Vec<super::#related_mod::#related_struct> =
293 qb.build_query_as().fetch_all(pool).await
294 .map_err(FerriormError::from)?;
295 for row in related_rows {
296 #load_var.insert(row.#insert_key_ident.clone(), row);
297 }
298 }
299 }
300 }
301 }
302 }
303}
304
305#[allow(clippy::too_many_lines)]
306fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
307 let _model_ident = format_ident!("{}", model.name);
308 let with_relations_name = format_ident!("{}WithRelations", model.name);
309
310 let mut relation_loads = vec![];
311 let mut field_inits = vec![];
312
313 for rel in relations {
314 let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
315 let fk_col_str = &rel.fk_column;
316 let ref_col_str = &rel.ref_column;
317 let fk_col_ident = format_ident!("{}", rel.fk_column);
318 let ref_col_ident = format_ident!("{}", rel.ref_column);
319
320 match rel.relation_type {
321 RelationType::OneToMany | RelationType::ManyToMany => {
322 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
324
325 let child_fk_optional = rel
327 .related_model
328 .fields
329 .iter()
330 .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
331
332 relation_loads.push(gen_batched_load_many(
333 rel,
334 &load_var,
335 &field_name,
336 &ref_col_ident,
337 fk_col_str,
338 &fk_col_ident,
339 child_fk_optional,
340 ));
341
342 let ref_col_ident = format_ident!("{}", ref_col_str);
343 field_inits.push(quote! {
344 #field_name: if include.#field_name {
345 Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
346 } else {
347 None
348 }
349 });
350 }
351 RelationType::OneToOne | RelationType::ManyToOne => {
352 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
354 let fk_field = format_ident!("{}", fk_col_str);
355
356 let fk_model_field = model
358 .fields
359 .iter()
360 .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
361 let has_fk = fk_model_field.is_some();
362 let fk_is_optional = fk_model_field.is_some_and(|f| f.is_optional);
363
364 if has_fk {
365 relation_loads.push(gen_batched_load_one(
366 rel,
367 &load_var,
368 &field_name,
369 &fk_field,
370 ref_col_str,
371 &ref_col_ident,
372 fk_is_optional,
373 ));
374
375 if fk_is_optional {
376 field_inits.push(quote! {
377 #field_name: if include.#field_name {
378 r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
379 } else {
380 None
381 }
382 });
383 } else {
384 field_inits.push(quote! {
385 #field_name: if include.#field_name {
386 #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
387 } else {
388 None
389 }
390 });
391 }
392 } else {
393 let ref_col_ident = format_ident!("{}", ref_col_str);
396
397 relation_loads.push(gen_batched_load_one(
398 rel,
399 &load_var,
400 &field_name,
401 &ref_col_ident,
402 fk_col_str,
403 &fk_col_ident,
404 false,
405 ));
406
407 field_inits.push(quote! {
408 #field_name: if include.#field_name {
409 #load_var.remove(&r.#ref_col_ident)
410 } else {
411 None
412 }
413 });
414 }
415 }
416 }
417 }
418
419 quote! {
420 #(#relation_loads)*
421
422 let mut results = Vec::with_capacity(records.len());
423 for r in records {
424 results.push(#with_relations_name {
425 #(#field_inits,)*
426 data: r,
427 });
428 }
429 Ok(results)
430 }
431}
432
433#[must_use]
435pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
436 let relations = collect_relations(model, schema);
437 if relations.is_empty() {
438 return quote! {};
439 }
440
441 let model_ident = format_ident!("{}", model.name);
442 let include_name = format_ident!("{}Include", model.name);
443 let with_relations_name = format_ident!("{}WithRelations", model.name);
444
445 quote! {
446 impl<'a> FindManyQuery<'a> {
447 pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
448 FindManyWithIncludeQuery {
449 inner: self,
450 include,
451 }
452 }
453 }
454
455 pub struct FindManyWithIncludeQuery<'a> {
456 inner: FindManyQuery<'a>,
457 include: #include_name,
458 }
459
460 impl<'a> FindManyWithIncludeQuery<'a> {
461 pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
462 let include = self.include;
463 let client = self.inner.client;
464 let records = FindManyQuery {
465 client,
466 r#where: self.inner.r#where,
467 order_by: self.inner.order_by,
468 skip: self.inner.skip,
469 take: self.inner.take,
470 }.exec().await?;
471 #model_ident::load_relations(records, &include, client).await
472 }
473 }
474
475 impl<'a> FindUniqueQuery<'a> {
476 pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
477 FindUniqueWithIncludeQuery {
478 inner: self,
479 include,
480 }
481 }
482 }
483
484 pub struct FindUniqueWithIncludeQuery<'a> {
485 inner: FindUniqueQuery<'a>,
486 include: #include_name,
487 }
488
489 impl<'a> FindUniqueWithIncludeQuery<'a> {
490 pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
491 let include = self.include;
492 let client = self.inner.client;
493 let record = FindUniqueQuery {
494 client,
495 r#where: self.inner.r#where,
496 }.exec().await?;
497 match record {
498 Some(r) => {
499 let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
500 Ok(results.pop())
501 }
502 None => Ok(None),
503 }
504 }
505 }
506 }
507}