1use ferriorm_core::schema::{Field, FieldKind, Model, RelationType, Schema};
4use ferriorm_core::utils::to_snake_case;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
/// Resolved description of a single relation on a model, as produced by
/// [`collect_relations`].
pub struct RelationInfo<'a> {
    /// The field on the owning model that declares (or implies) the relation.
    pub field: &'a Field,
    /// The model on the other side of the relation, found in the schema by name.
    pub related_model: &'a Model,
    /// Cardinality of the relation.
    pub relation_type: RelationType,
    /// Snake-cased foreign-key column name: the first entry of the relation's
    /// `fields` list, or `"id"` when no explicit or back-referenced pair exists.
    pub fk_column: String,
    /// Snake-cased referenced column name: the first entry of the relation's
    /// `references` list, or `"id"` when no explicit or back-referenced pair exists.
    pub ref_column: String,
}
18
19pub fn collect_relations<'a>(model: &'a Model, schema: &'a Schema) -> Vec<RelationInfo<'a>> {
21 let mut relations = Vec::new();
22
23 for field in &model.fields {
24 if let Some(rel) = &field.relation {
25 let related = schema.models.iter().find(|m| m.name == rel.related_model);
26 if let Some(related_model) = related {
27 let (fk_column, ref_column) = if !rel.fields.is_empty() {
28 (rel.fields[0].clone(), rel.references[0].clone())
30 } else {
31 find_back_reference(model, related_model)
33 .unwrap_or_else(|| ("id".into(), "id".into()))
34 };
35
36 relations.push(RelationInfo {
37 field,
38 related_model,
39 relation_type: rel.relation_type,
40 fk_column: to_snake_case(&fk_column),
41 ref_column: to_snake_case(&ref_column),
42 });
43 }
44 } else if field.is_list {
45 if let FieldKind::Model(related_name) = &field.field_type {
47 let related = schema.models.iter().find(|m| m.name == *related_name);
48 if let Some(related_model) = related {
49 let (fk_column, ref_column) = find_back_reference(model, related_model)
50 .unwrap_or_else(|| ("id".into(), "id".into()));
51
52 relations.push(RelationInfo {
53 field,
54 related_model,
55 relation_type: RelationType::OneToMany,
56 fk_column: to_snake_case(&fk_column),
57 ref_column: to_snake_case(&ref_column),
58 });
59 }
60 }
61 }
62 }
63
64 relations
65}
66
67fn find_back_reference(parent: &Model, child: &Model) -> Option<(String, String)> {
70 for field in &child.fields {
71 if let Some(rel) = &field.relation
72 && rel.related_model == parent.name
73 && !rel.fields.is_empty()
74 {
75 return Some((rel.fields[0].clone(), rel.references[0].clone()));
76 }
77 }
78 None
79}
80
81pub fn gen_relation_types(model: &Model, schema: &Schema) -> TokenStream {
83 let relations = collect_relations(model, schema);
84
85 if relations.is_empty() {
86 return quote! {};
87 }
88
89 let model_ident = format_ident!("{}", model.name);
90 let include_name = format_ident!("{}Include", model.name);
91 let with_relations_name = format_ident!("{}WithRelations", model.name);
92
93 let include_fields: Vec<TokenStream> = relations
95 .iter()
96 .map(|r| {
97 let name = format_ident!("{}", to_snake_case(&r.field.name));
98 quote! { pub #name: bool }
99 })
100 .collect();
101
102 let with_rel_fields: Vec<TokenStream> = relations
104 .iter()
105 .map(|r| {
106 let name = format_ident!("{}", to_snake_case(&r.field.name));
107 let related_mod = format_ident!("{}", to_snake_case(&r.related_model.name));
108 let related_struct = format_ident!("{}", r.related_model.name);
109
110 match r.relation_type {
111 RelationType::OneToMany | RelationType::ManyToMany => {
112 quote! { pub #name: Option<Vec<super::#related_mod::#related_struct>> }
113 }
114 RelationType::OneToOne | RelationType::ManyToOne => {
115 quote! { pub #name: Option<super::#related_mod::#related_struct> }
116 }
117 }
118 })
119 .collect();
120
121 let load_arms = gen_load_arms(&relations, model);
123
124 quote! {
125 #[derive(Debug, Clone, Default)]
126 pub struct #include_name {
127 #(#include_fields,)*
128 }
129
130 #[derive(Debug, Clone, Serialize, Deserialize)]
131 pub struct #with_relations_name {
132 #[serde(flatten)]
133 pub data: #model_ident,
134 #(#with_rel_fields,)*
135 }
136
137 impl #model_ident {
138 pub(crate) async fn load_relations(
140 records: Vec<#model_ident>,
141 include: &#include_name,
142 client: &DatabaseClient,
143 ) -> Result<Vec<#with_relations_name>, FerriormError> {
144 #load_arms
145 }
146 }
147 }
148}
149
150fn gen_batched_load_many(
153 rel: &RelationInfo<'_>,
154 load_var: &proc_macro2::Ident,
155 field_name: &proc_macro2::Ident,
156 id_source_ident: &proc_macro2::Ident,
157 lookup_col_str: &str,
158 insert_key_ident: &proc_macro2::Ident,
159) -> TokenStream {
160 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
161 let related_struct = format_ident!("{}", rel.related_model.name);
162 let related_table = &rel.related_model.db_name;
163
164 let select_base = format!(
165 r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
166 related_table, lookup_col_str
167 );
168
169 quote! {
170 let mut #load_var: std::collections::HashMap<String, Vec<super::#related_mod::#related_struct>> = std::collections::HashMap::new();
171 if include.#field_name {
172 let ids: Vec<String> = records.iter()
173 .map(|r| r.#id_source_ident.clone())
174 .collect();
175
176 if !ids.is_empty() {
177 macro_rules! build_in_query {
178 ($db:ty) => {{
179 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
180 let mut sep = qb.separated(", ");
181 for id in &ids {
182 sep.push_bind(id.clone());
183 }
184 qb.push(")");
185 qb
186 }};
187 }
188
189 match client {
190 DatabaseClient::Postgres(pool) => {
191 let mut qb = build_in_query!(sqlx::Postgres);
192 let related_rows: Vec<super::#related_mod::#related_struct> =
193 qb.build_query_as().fetch_all(pool).await
194 .map_err(FerriormError::from)?;
195 for row in related_rows {
196 #load_var.entry(row.#insert_key_ident.clone())
197 .or_default()
198 .push(row);
199 }
200 }
201 DatabaseClient::Sqlite(pool) => {
202 let mut qb = build_in_query!(sqlx::Sqlite);
203 let related_rows: Vec<super::#related_mod::#related_struct> =
204 qb.build_query_as().fetch_all(pool).await
205 .map_err(FerriormError::from)?;
206 for row in related_rows {
207 #load_var.entry(row.#insert_key_ident.clone())
208 .or_default()
209 .push(row);
210 }
211 }
212 }
213 }
214 }
215 }
216}
217
218fn gen_batched_load_one(
220 rel: &RelationInfo<'_>,
221 load_var: &proc_macro2::Ident,
222 field_name: &proc_macro2::Ident,
223 id_source_ident: &proc_macro2::Ident,
224 lookup_col_str: &str,
225 insert_key_ident: &proc_macro2::Ident,
226) -> TokenStream {
227 let related_mod = format_ident!("{}", to_snake_case(&rel.related_model.name));
228 let related_struct = format_ident!("{}", rel.related_model.name);
229 let related_table = &rel.related_model.db_name;
230
231 let select_base = format!(
232 r#"SELECT * FROM "{}" WHERE "{}" IN ("#,
233 related_table, lookup_col_str
234 );
235
236 quote! {
237 let mut #load_var: std::collections::HashMap<String, super::#related_mod::#related_struct> = std::collections::HashMap::new();
238 if include.#field_name {
239 let ids: Vec<String> = records.iter()
240 .map(|r| r.#id_source_ident.clone())
241 .collect();
242
243 if !ids.is_empty() {
244 macro_rules! build_in_query {
245 ($db:ty) => {{
246 let mut qb = sqlx::QueryBuilder::<$db>::new(#select_base);
247 let mut sep = qb.separated(", ");
248 for id in &ids {
249 sep.push_bind(id.clone());
250 }
251 qb.push(")");
252 qb
253 }};
254 }
255
256 match client {
257 DatabaseClient::Postgres(pool) => {
258 let mut qb = build_in_query!(sqlx::Postgres);
259 let related_rows: Vec<super::#related_mod::#related_struct> =
260 qb.build_query_as().fetch_all(pool).await
261 .map_err(FerriormError::from)?;
262 for row in related_rows {
263 #load_var.insert(row.#insert_key_ident.clone(), row);
264 }
265 }
266 DatabaseClient::Sqlite(pool) => {
267 let mut qb = build_in_query!(sqlx::Sqlite);
268 let related_rows: Vec<super::#related_mod::#related_struct> =
269 qb.build_query_as().fetch_all(pool).await
270 .map_err(FerriormError::from)?;
271 for row in related_rows {
272 #load_var.insert(row.#insert_key_ident.clone(), row);
273 }
274 }
275 }
276 }
277 }
278 }
279}
280
281fn gen_load_arms(relations: &[RelationInfo<'_>], model: &Model) -> TokenStream {
282 let _model_ident = format_ident!("{}", model.name);
283 let with_relations_name = format_ident!("{}WithRelations", model.name);
284
285 let mut relation_loads = vec![];
286 let mut field_inits = vec![];
287
288 for rel in relations {
289 let field_name = format_ident!("{}", to_snake_case(&rel.field.name));
290 let fk_col_str = &rel.fk_column;
291 let ref_col_str = &rel.ref_column;
292 let fk_col_ident = format_ident!("{}", rel.fk_column);
293 let ref_col_ident = format_ident!("{}", rel.ref_column);
294
295 match rel.relation_type {
296 RelationType::OneToMany | RelationType::ManyToMany => {
297 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
299
300 relation_loads.push(gen_batched_load_many(
301 rel,
302 &load_var,
303 &field_name,
304 &ref_col_ident,
305 fk_col_str,
306 &fk_col_ident,
307 ));
308
309 let ref_col_ident = format_ident!("{}", ref_col_str);
310 field_inits.push(quote! {
311 #field_name: if include.#field_name {
312 Some(#load_var.remove(&r.#ref_col_ident).unwrap_or_default())
313 } else {
314 None
315 }
316 });
317 }
318 RelationType::OneToOne | RelationType::ManyToOne => {
319 let load_var = format_ident!("{}_map", to_snake_case(&rel.field.name));
321 let fk_field = format_ident!("{}", fk_col_str);
322
323 let has_fk = model
325 .fields
326 .iter()
327 .any(|f| to_snake_case(&f.name) == *fk_col_str && f.is_scalar());
328
329 if has_fk {
330 relation_loads.push(gen_batched_load_one(
331 rel,
332 &load_var,
333 &field_name,
334 &fk_field,
335 ref_col_str,
336 &ref_col_ident,
337 ));
338
339 field_inits.push(quote! {
340 #field_name: if include.#field_name {
341 #load_var.remove(&r.#fk_field).map(Some).unwrap_or(None)
342 } else {
343 None
344 }
345 });
346 } else {
347 let ref_col_ident = format_ident!("{}", ref_col_str);
350
351 relation_loads.push(gen_batched_load_one(
352 rel,
353 &load_var,
354 &field_name,
355 &ref_col_ident,
356 fk_col_str,
357 &fk_col_ident,
358 ));
359
360 field_inits.push(quote! {
361 #field_name: if include.#field_name {
362 #load_var.remove(&r.#ref_col_ident)
363 } else {
364 None
365 }
366 });
367 }
368 }
369 }
370 }
371
372 quote! {
373 #(#relation_loads)*
374
375 let mut results = Vec::with_capacity(records.len());
376 for mut r in records {
377 results.push(#with_relations_name {
378 #(#field_inits,)*
379 data: r,
380 });
381 }
382 Ok(results)
383 }
384}
385
386pub fn gen_find_many_include(model: &Model, schema: &Schema) -> TokenStream {
388 let relations = collect_relations(model, schema);
389 if relations.is_empty() {
390 return quote! {};
391 }
392
393 let model_ident = format_ident!("{}", model.name);
394 let include_name = format_ident!("{}Include", model.name);
395 let with_relations_name = format_ident!("{}WithRelations", model.name);
396
397 quote! {
398 impl<'a> FindManyQuery<'a> {
399 pub fn include(mut self, include: #include_name) -> FindManyWithIncludeQuery<'a> {
400 FindManyWithIncludeQuery {
401 inner: self,
402 include,
403 }
404 }
405 }
406
407 pub struct FindManyWithIncludeQuery<'a> {
408 inner: FindManyQuery<'a>,
409 include: #include_name,
410 }
411
412 impl<'a> FindManyWithIncludeQuery<'a> {
413 pub async fn exec(self) -> Result<Vec<#with_relations_name>, FerriormError> {
414 let include = self.include;
415 let client = self.inner.client;
416 let records = FindManyQuery {
417 client,
418 r#where: self.inner.r#where,
419 order_by: self.inner.order_by,
420 skip: self.inner.skip,
421 take: self.inner.take,
422 }.exec().await?;
423 #model_ident::load_relations(records, &include, client).await
424 }
425 }
426
427 impl<'a> FindUniqueQuery<'a> {
428 pub fn include(self, include: #include_name) -> FindUniqueWithIncludeQuery<'a> {
429 FindUniqueWithIncludeQuery {
430 inner: self,
431 include,
432 }
433 }
434 }
435
436 pub struct FindUniqueWithIncludeQuery<'a> {
437 inner: FindUniqueQuery<'a>,
438 include: #include_name,
439 }
440
441 impl<'a> FindUniqueWithIncludeQuery<'a> {
442 pub async fn exec(self) -> Result<Option<#with_relations_name>, FerriormError> {
443 let include = self.include;
444 let client = self.inner.client;
445 let record = FindUniqueQuery {
446 client,
447 r#where: self.inner.r#where,
448 }.exec().await?;
449 match record {
450 Some(r) => {
451 let mut results = #model_ident::load_relations(vec![r], &include, client).await?;
452 Ok(results.pop())
453 }
454 None => Ok(None),
455 }
456 }
457 }
458 }
459}