1use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8pub struct RelationInfo<'a> {
10 pub field: &'a Field,
11 pub related_model: &'a Model,
12 pub relation_type: RelationType,
13 pub fk_column: String,
15 pub ref_column: String,
17}
18
19pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
21 let mut relations = Vec::new();
22
23 for field in &model.fields {
24 if let Some(rel) = &field.relation {
25 let related = schema.models.iter().find(|m| m.name == rel.related_model);
26 if let Some(related_model) = related {
27 let (fk_column, ref_column) = if !rel.fields.is_empty() {
28 (rel.fields[0].clone(), rel.references[0].clone())
30 } else {
31 find_back_reference(model, related_model)
33 .unwrap_or_else(|| ("id".into(), "id".into()))
34 };
35
36 relations.push(RelationInfo {
37 field,
38 related_model,
39 relation_type: rel.relation_type,
40 fk_column: to_snake_case(&fk_column),
41 ref_column: to_snake_case(&ref_column),
42 });
43 }
44 } else if field.is_list {
45 if let FieldKind::Model(related_name) = &field.field_type {
47 let related = schema.models.iter().find(|m| m.name == *related_name);
48 if let Some(related_model) = related {
49 let (fk_column, ref_column) = find_back_reference(model, related_model)
50 .unwrap_or_else(|| ("id".into(), "id".into()));
51
52 relations.push(RelationInfo {
53 field,
54 related_model,
55 relation_type: RelationType::OneToMany,
56 fk_column: to_snake_case(&fk_column),
57 ref_column: to_snake_case(&ref_column),
58 });
59 }
60 }
61 }
62 }
63
64 relations
65}
66
67fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
70 for field in &child.fields {
71 if let Some(rel) = &field.relation
72 && rel.related_model == parent.name
73 && !rel.fields.is_empty()
74 {
75 return Some((rel.fields[0].clone(), rel.references[0].clone()));
76 }
77 }
78 None
79}
80
81pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
83 let relations = collect_relations(model, schema);
84
85 if relations.is_empty() {
86 return quote! {};
87 }
88
89 let model_ident = format_ident!("{}", model.name);
90 let include_name = format_ident!("{}Include", model.name);
91 let with_relations_name = format_ident!("{}WithRelations", model.name);
92
93 let include_fields: Vec<TokenStream> = relations
95 .iter()
96 .map(|r| {
97 let name = format_ident!("{}", to_snake_case(&r.field.name));
98 quote! { pub #name: bool }
99 })
100 .collect();
101
102 let with_rel_fields: Vec<TokenStream> = relations
104 .iter()
105 .map(|r| {
106 let name = format_ident!("{}", to_snake_case(&r.field.name));
107 let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
108 let related_struct = format_ident!("{}", r.related_model.name);
109
110 match r.relation_type {
111 RelationType::OneToMany | RelationType::ManyToMany => {
112 quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
113 }
114 RelationType::OneToOne | RelationType::ManyToOne => {
115 quote! { pub #name: Option<super::#related_mod::#related_struct> }
116 }
117 }
118 })
119 .collect();
120
121 let load_arms = gen_load_arms(&relations, model);
123
124 quote! {
125 #[derive(Debug, Clone, Default)]
126 pub struct #include_name {
127 #(#include_fields,)*
128 }
129
130 #[derive(Debug, Clone, Serialize, Deserialize)]
131 pub struct #with_relations_name {
132 #[serde(flatten)]
133 pub data: #model_ident,
134 #(#with_rel_fields,)*
135 }
136
137 impl #model_ident {
138 pub(crate) async fn load_relations(
140 records: Vec<#model_ident>,
141 include: &#include_name,
142 client: &DatabaseClient,
143 ) -> Result<Vec<#with_relations_name>, FerriormError> {
144 #load_arms
145 }
146 }
147 }
148}
149
150fn gen_batched_load_many(
153 rel: &RelationInfo<'_>,
154 load_var: &proc_macro2::Ident,
155 field_name: &proc_macro2::Ident,
156 id_source_ident: &proc_macro2::Ident,
157 lookup_col_str: &str,
158 insert_key_ident: &proc_macro2::Ident,
159 fk_optional: bool,
160) -> TokenStream {
161 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
162 let related_struct = format_ident!("{}", rel.related_model.name);
163 let related_table = &rel.related_model.db_name;
164
165 let select_base = format!(
166 r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
167 related_table, lookup_col_str
168 );
169
170 let insert_row_code = if fk_optional {
172 quote! {
173 if let Some(key) = row.#insert_key_ident.clone() {
174 #load_var.entry(key).or_default().push(row);
175 }
176 }
177 } else {
178 quote! {
179 #load_var.entry(row.#insert_key_ident.clone()).or_default().push(row);
180 }
181 };
182
183 quote! {
184 let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
185 if include.#field_name {
186 let ids: Vec<String> = records.iter()
187 .map(|r| r.#id_source_ident.clone())
188 .collect();
189
190 if !ids.is_empty() {
191 macro_rules! build_in_query {
192 ($db:ty) => {{
193 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
194 let mut sep = qb.separated(", ");
195 for id in &ids {
196 sep.push_bind(id.clone());
197 }
198 qb.push(")");
199 qb
200 }};
201 }
202
203 macro_rules! insert_rows {
204 ($rows:expr) => {
205 for row in $rows {
206 #insert_row_code
207 }
208 };
209 }
210
211 match client {
212 DatabaseClient::Postgres(pool) => {
213 let mut qb = build_in_query!(sqlx::Postgres);
214 let related_rows: Vec<super::#related_mod::#related_struct> =
215 qb.build_query_as().fetch_all(pool).await
216 .map_err(FerriormError::from)?;
217 insert_rows!(related_rows);
218 }
219 DatabaseClient::Sqlite(pool) => {
220 let mut qb = build_in_query!(sqlx::Sqlite);
221 let related_rows: Vec<super::#related_mod::#related_struct> =
222 qb.build_query_as().fetch_all(pool).await
223 .map_err(FerriormError::from)?;
224 insert_rows!(related_rows);
225 }
226 }
227 }
228 }
229 }
230}
231
232fn gen_batched_load_one(
234 rel: &RelationInfo<'_>,
235 load_var: &proc_macro2::Ident,
236 field_name: &proc_macro2::Ident,
237 id_source_ident: &proc_macro2::Ident,
238 lookup_col_str: &str,
239 insert_key_ident: &proc_macro2::Ident,
240 fk_is_optional: bool,
241) -> TokenStream {
242 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
243 let related_struct = format_ident!("{}", rel.related_model.name);
244 let related_table = &rel.related_model.db_name;
245
246 let select_base = format!(
247 r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
248 related_table, lookup_col_str
249 );
250
251 let ids_collect = if fk_is_optional {
253 quote! {
254 let ids: Vec<String> = records.iter()
255 .filter_map(|r| r.#id_source_ident.clone())
256 .collect();
257 }
258 } else {
259 quote! {
260 let ids: Vec<String> = records.iter()
261 .map(|r| r.#id_source_ident.clone())
262 .collect();
263 }
264 };
265
266 quote! {
267 let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
268 if include.#field_name {
269 #ids_collect
270
271 if !ids.is_empty() {
272 macro_rules! build_in_query {
273 ($db:ty) => {{
274 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
275 let mut sep = qb.separated(", ");
276 for id in &ids {
277 sep.push_bind(id.clone());
278 }
279 qb.push(")");
280 qb
281 }};
282 }
283
284 match client {
285 DatabaseClient::Postgres(pool) => {
286 let mut qb = build_in_query!(sqlx::Postgres);
287 let related_rows: Vec<super::#related_mod::#related_struct> =
288 qb.build_query_as().fetch_all(pool).await
289 .map_err(FerriormError::from)?;
290 for row in related_rows {
291 #load_var.insert(row.#insert_key_ident.clone(), row);
292 }
293 }
294 DatabaseClient::Sqlite(pool) => {
295 let mut qb = build_in_query!(sqlx::Sqlite);
296 let related_rows: Vec<super::#related_mod::#related_struct> =
297 qb.build_query_as().fetch_all(pool).await
298 .map_err(FerriormError::from)?;
299 for row in related_rows {
300 #load_var.insert(row.#insert_key_ident.clone(), row);
301 }
302 }
303 }
304 }
305 }
306 }
307}
308
309fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
310 let _model_ident = format_ident!("{}", model.name);
311 let with_relations_name = format_ident!("{}WithRelations", model.name);
312
313 let mut relation_loads = vec![];
314 let mut field_inits = vec![];
315
316 for rel in relations {
317 let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
318 let fk_col_str = &rel.fk_column;
319 let ref_col_str = &rel.ref_column;
320 let fk_col_ident = format_ident!("{}", rel.fk_column);
321 let ref_col_ident = format_ident!("{}", rel.ref_column);
322
323 match rel.relation_type {
324 RelationType::OneToMany | RelationType::ManyToMany => {
325 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
327
328 let child_fk_optional = rel
330 .related_model
331 .fields
332 .iter()
333 .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_optional);
334
335 relation_loads.push(gen_batched_load_many(
336 rel,
337 &load_var,
338 &field_name,
339 &ref_col_ident,
340 fk_col_str,
341 &fk_col_ident,
342 child_fk_optional,
343 ));
344
345 let ref_col_ident = format_ident!("{}", ref_col_str);
346 field_inits.push(quote! {
347 #field_name: if include.#field_name {
348 Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
349 } else {
350 None
351 }
352 });
353 }
354 RelationType::OneToOne | RelationType::ManyToOne => {
355 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
357 let fk_field = format_ident!("{}", fk_col_str);
358
359 let fk_model_field = model
361 .fields
362 .iter()
363 .find(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
364 let has_fk = fk_model_field.is_some();
365 let fk_is_optional = fk_model_field.map(|f| f.is_optional).unwrap_or(false);
366
367 if has_fk {
368 relation_loads.push(gen_batched_load_one(
369 rel,
370 &load_var,
371 &field_name,
372 &fk_field,
373 ref_col_str,
374 &ref_col_ident,
375 fk_is_optional,
376 ));
377
378 if fk_is_optional {
379 field_inits.push(quote! {
380 #field_name: if include.#field_name {
381 r.#fk_field.as_ref().and_then(|fk| #load_var.remove(fk))
382 } else {
383 None
384 }
385 });
386 } else {
387 field_inits.push(quote! {
388 #field_name: if include.#field_name {
389 #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
390 } else {
391 None
392 }
393 });
394 }
395 } else {
396 let ref_col_ident = format_ident!("{}", ref_col_str);
399
400 relation_loads.push(gen_batched_load_one(
401 rel,
402 &load_var,
403 &field_name,
404 &ref_col_ident,
405 fk_col_str,
406 &fk_col_ident,
407 false,
408 ));
409
410 field_inits.push(quote! {
411 #field_name: if include.#field_name {
412 #load_var.remove(&r.#ref_col_ident)
413 } else {
414 None
415 }
416 });
417 }
418 }
419 }
420 }
421
422 quote! {
423 #(#relation_loads)*
424
425 let mut results = Vec::with_capacity(records.len());
426 for r in records {
427 results.push(#with_relations_name {
428 #(#field_inits,)*
429 data: r,
430 });
431 }
432 Ok(results)
433 }
434}
435
436pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
438 let relations = collect_relations(model, schema);
439 if relations.is_empty() {
440 return quote! {};
441 }
442
443 let model_ident = format_ident!("{}", model.name);
444 let include_name = format_ident!("{}Include", model.name);
445 let with_relations_name = format_ident!("{}WithRelations", model.name);
446
447 quote! {
448 impl<'a> FindManyQuery<'a> {
449 pub fn include(self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
450 FindManyWithIncludeQuery {
451 inner: self,
452 include,
453 }
454 }
455 }
456
457 pub struct FindManyWithIncludeQuery<'a> {
458 inner: FindManyQuery<'a>,
459 include: #include_name,
460 }
461
462 impl<'a> FindManyWithIncludeQuery<'a> {
463 pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
464 let include = self.include;
465 let client = self.inner.client;
466 let records = FindManyQuery {
467 client,
468 r#where: self.inner.r#where,
469 order_by: self.inner.order_by,
470 skip: self.inner.skip,
471 take: self.inner.take,
472 }.exec().await?;
473 #model_ident::load_relations(records, &include, client).await
474 }
475 }
476
477 impl<'a> FindUniqueQuery<'a> {
478 pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
479 FindUniqueWithIncludeQuery {
480 inner: self,
481 include,
482 }
483 }
484 }
485
486 pub struct FindUniqueWithIncludeQuery<'a> {
487 inner: FindUniqueQuery<'a>,
488 include: #include_name,
489 }
490
491 impl<'a> FindUniqueWithIncludeQuery<'a> {
492 pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
493 let include = self.include;
494 let client = self.inner.client;
495 let record = FindUniqueQuery {
496 client,
497 r#where: self.inner.r#where,
498 }.exec().await?;
499 match record {
500 Some(r) => {
501 let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
502 Ok(results.pop())
503 }
504 None => Ok(None),
505 }
506 }
507 }
508 }
509}