1use sqlx::{Pool, Postgres, Row};
4use std::collections::HashMap;
5
6use crate::error::ModelResult;
7use crate::loading::{EagerLoadConfig, OptimizedEagerLoader};
8use crate::model::Model;
9use crate::query::QueryBuilder;
10use crate::relationships::constraints::RelationshipConstraintBuilder;
11use crate::relationships::eager_loading::EagerLoader;
12
13pub trait QueryBuilderWithMethods<M> {
15 fn with(self, relation: &str) -> QueryBuilderWithEagerLoading<M>;
17
18 fn with_where<F>(self, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
20 where
21 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static;
22
23 fn with_when(self, condition: bool, relation: &str) -> QueryBuilderWithEagerLoading<M>;
25
26 fn with_count(self, relation: &str) -> QueryBuilderWithEagerLoading<M>;
28
29 fn with_count_where<F>(
31 self,
32 alias: &str,
33 relation: &str,
34 constraint: F,
35 ) -> QueryBuilderWithEagerLoading<M>
36 where
37 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static;
38}
39
40pub struct QueryBuilderWithEagerLoading<M> {
42 query: QueryBuilder<M>,
44 eager_loader: EagerLoader,
46 count_relations: HashMap<String, String>, optimization_enabled: bool,
50 optimized_loader: Option<OptimizedEagerLoader>,
52 batch_size: Option<usize>,
54}
55
56impl<M> QueryBuilderWithEagerLoading<M> {
57 pub fn new(query: QueryBuilder<M>) -> Self {
59 Self {
60 query,
61 eager_loader: EagerLoader::new(),
62 count_relations: HashMap::new(),
63 optimization_enabled: false,
64 optimized_loader: None,
65 batch_size: None,
66 }
67 }
68
69 pub fn with(mut self, relation: &str) -> Self {
71 self.eager_loader = self.eager_loader.with(relation);
72 self
73 }
74
75 pub fn with_where<F>(mut self, relation: &str, constraint: F) -> Self
77 where
78 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
79 {
80 self.eager_loader = self.eager_loader.with_constraint(relation, constraint);
81 self
82 }
83
84 pub fn with_when(self, condition: bool, relation: &str) -> Self {
86 if condition {
87 self.with(relation)
88 } else {
89 self
90 }
91 }
92
93 pub fn with_count(mut self, relation: &str) -> Self {
95 self.count_relations
96 .insert(format!("{}_count", relation), relation.to_string());
97 self
98 }
99
100 pub fn with_count_where<F>(mut self, alias: &str, relation: &str, _constraint: F) -> Self
102 where
103 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
104 {
105 self.count_relations
107 .insert(alias.to_string(), relation.to_string());
108 self
109 }
110
111 pub async fn get(mut self, pool: &Pool<Postgres>) -> ModelResult<Vec<M>>
113 where
114 M: Model + Send + Sync,
115 {
116 let mut models = self.query.clone().get(pool).await?;
118
119 if models.is_empty() {
120 return Ok(models);
121 }
122
123 if self.optimization_enabled && self.optimized_loader.is_some() {
125 let loaded_relations = self.eager_loader.loaded_relations();
127 let relationship_names = loaded_relations
128 .iter()
129 .map(|s| s.as_str())
130 .collect::<Vec<&str>>()
131 .join(",");
132 if !relationship_names.is_empty() {
133 let root_ids: Vec<serde_json::Value> = models
134 .iter()
135 .filter_map(|m| m.primary_key())
136 .map(|pk| serde_json::Value::String(pk.to_string()))
137 .collect();
138
139 if let Some(ref mut loader) = self.optimized_loader {
140 let _result = loader
141 .load_with_relationships(
142 M::table_name(),
143 root_ids,
144 &relationship_names,
145 pool,
146 )
147 .await
148 .map_err(|e| crate::error::ModelError::Database(e.to_string()))?;
149
150 }
153 }
154 } else {
155 self.eager_loader.load_for_models(pool, &models).await?;
157 }
158
159 if !self.count_relations.is_empty() {
161 self.load_relationship_counts(pool, &mut models).await?;
162 }
163
164 self.attach_relationships_to_models(&mut models)?;
166
167 Ok(models)
168 }
169
170 pub async fn first(self, pool: &Pool<Postgres>) -> ModelResult<Option<M>>
172 where
173 M: Model + Send + Sync,
174 {
175 let models = self.get(pool).await?;
176 Ok(models.into_iter().next())
177 }
178
179 pub async fn first_or_fail(self, pool: &Pool<Postgres>) -> ModelResult<M>
181 where
182 M: Model + Send + Sync,
183 {
184 self.first(pool).await?.ok_or_else(|| {
185 crate::error::ModelError::NotFound(format!("No {} found", M::table_name()))
186 })
187 }
188
189 pub fn where_eq<V>(mut self, field: &str, value: V) -> Self
191 where
192 V: ToString + Send + Sync + 'static,
193 {
194 self.query = self.query.where_eq(field, value.to_string());
195 self
196 }
197
198 pub fn where_condition<V>(mut self, field: &str, operator: &str, value: V) -> Self
200 where
201 V: ToString + Send + Sync + 'static,
202 {
203 self.query = self
206 .query
207 .where_condition(field, operator, value.to_string());
208 self
209 }
210
211 pub fn order_by(mut self, field: &str) -> Self {
213 self.query = self.query.order_by(field);
214 self
215 }
216
217 pub fn order_by_desc(mut self, field: &str) -> Self {
219 self.query = self.query.order_by_desc(field);
220 self
221 }
222
223 pub fn limit(mut self, count: i64) -> Self {
225 self.query = self.query.limit(count);
226 self
227 }
228
229 pub fn offset(mut self, count: i64) -> Self {
231 self.query = self.query.offset(count);
232 self
233 }
234
235 pub fn optimize_loading(mut self) -> Self {
237 self.optimization_enabled = true;
238 self.optimized_loader = Some(OptimizedEagerLoader::new());
239 self
240 }
241
242 pub fn optimize_loading_with_config(mut self, config: EagerLoadConfig) -> Self {
244 self.optimization_enabled = true;
245 let batch_loader =
246 crate::loading::BatchLoader::with_config(crate::loading::BatchConfig::default());
247 self.optimized_loader = Some(OptimizedEagerLoader::with_config(config, batch_loader));
248 self
249 }
250
251 pub fn batch_size(mut self, size: usize) -> Self {
253 self.batch_size = Some(size);
254
255 if let Some(ref mut loader) = self.optimized_loader {
257 let mut config = loader.config().clone();
258 config.max_batch_size = size;
259 loader.update_config(config);
260 }
261
262 self
263 }
264
265 pub fn parallel_loading(mut self, enabled: bool) -> Self {
267 if let Some(ref mut loader) = self.optimized_loader {
269 let mut config = loader.config().clone();
270 config.enable_parallelism = enabled;
271 loader.update_config(config);
272 } else if enabled {
273 let mut config = EagerLoadConfig::default();
275 config.enable_parallelism = true;
276 self = self.optimize_loading_with_config(config);
277 }
278
279 self
280 }
281
282 pub fn max_depth(mut self, depth: usize) -> Self {
284 if let Some(ref mut loader) = self.optimized_loader {
286 let mut config = loader.config().clone();
287 config.max_depth = depth;
288 loader.update_config(config);
289 }
290
291 self
292 }
293
294 async fn load_relationship_counts(
296 &self,
297 pool: &Pool<Postgres>,
298 models: &mut [M],
299 ) -> ModelResult<()>
300 where
301 M: Model + Send + Sync,
302 {
303 for relation in self.count_relations.values() {
304 let model_ids: Vec<String> = models
306 .iter()
307 .filter_map(|m| m.primary_key().map(|pk| pk.to_string()))
308 .collect();
309
310 if model_ids.is_empty() {
311 continue;
312 }
313
314 let (count_query, params) = self.build_secure_count_query(relation, &model_ids)?;
316
317 let mut query = sqlx::query(&count_query);
319 for param in params {
320 query = query.bind(param);
321 }
322
323 let rows = query
324 .fetch_all(pool)
325 .await
326 .map_err(|e| crate::error::ModelError::Database(e.to_string()))?;
327
328 let mut counts: HashMap<String, i64> = HashMap::new();
330 for row in rows {
331 let parent_id: String = row.get("parent_id");
332 let count: i64 = row.get("count");
333 counts.insert(parent_id, count);
334 }
335
336 }
340
341 Ok(())
342 }
343
344 fn build_secure_count_query(
346 &self,
347 relation: &str,
348 parent_ids: &[String],
349 ) -> ModelResult<(String, Vec<String>)> {
350 use crate::security::{escape_identifier, validate_identifier};
351
352 validate_identifier(relation).map_err(|_| {
354 crate::error::ModelError::Validation(format!("Invalid relationship name: {}", relation))
355 })?;
356
357 let (table_name, foreign_key) = match relation {
359 "posts" => ("posts", "user_id"),
360 "comments" => ("comments", "post_id"),
361 "profile" => ("profiles", "user_id"),
362 _ => {
363 validate_identifier(relation).map_err(|_| {
366 crate::error::ModelError::Validation(format!(
367 "Invalid table name derived from relation: {}",
368 relation
369 ))
370 })?;
371 (relation, "parent_id")
372 }
373 };
374
375 validate_identifier(table_name)?;
377 validate_identifier(foreign_key)?;
378
379 let escaped_table = escape_identifier(table_name);
381 let escaped_foreign_key = escape_identifier(foreign_key);
382
383 let placeholders: Vec<String> = (1..=parent_ids.len()).map(|i| format!("${}", i)).collect();
385 let placeholders_str = placeholders.join(", ");
386
387 let query = format!(
388 "SELECT {} as parent_id, COUNT(*) as count FROM {} WHERE {} IN ({}) GROUP BY {}",
389 escaped_foreign_key,
390 escaped_table,
391 escaped_foreign_key,
392 placeholders_str,
393 escaped_foreign_key
394 );
395
396 Ok((query, parent_ids.to_vec()))
398 }
399
400 fn attach_relationships_to_models(&self, models: &mut [M]) -> ModelResult<()>
402 where
403 M: Model + Send + Sync,
404 {
405 for model in models {
416 if let Some(pk) = model.primary_key() {
417 let pk_str = pk.to_string();
418
419 for relation in self.eager_loader.loaded_relations() {
421 if let Some(_data) = self.eager_loader.get_loaded_data(relation, &pk_str) {
422 }
425 }
426 }
427 }
428
429 Ok(())
430 }
431}
432
433impl<M> QueryBuilderWithMethods<M> for QueryBuilder<M>
435where
436 M: Model + Send + Sync,
437{
438 fn with(self, relation: &str) -> QueryBuilderWithEagerLoading<M> {
439 QueryBuilderWithEagerLoading::new(self).with(relation)
440 }
441
442 fn with_where<F>(self, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
443 where
444 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
445 {
446 QueryBuilderWithEagerLoading::new(self).with_where(relation, constraint)
447 }
448
449 fn with_when(self, condition: bool, relation: &str) -> QueryBuilderWithEagerLoading<M> {
450 QueryBuilderWithEagerLoading::new(self).with_when(condition, relation)
451 }
452
453 fn with_count(self, relation: &str) -> QueryBuilderWithEagerLoading<M> {
454 QueryBuilderWithEagerLoading::new(self).with_count(relation)
455 }
456
457 fn with_count_where<F>(
458 self,
459 alias: &str,
460 relation: &str,
461 constraint: F,
462 ) -> QueryBuilderWithEagerLoading<M>
463 where
464 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
465 {
466 QueryBuilderWithEagerLoading::new(self).with_count_where(alias, relation, constraint)
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::QueryBuilder;
    use crate::relationships::eager_loading::EagerLoadSpec;

    /// Fresh wrapper around an empty `QueryBuilder<()>`.
    fn wrapped() -> QueryBuilderWithEagerLoading<()> {
        QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new())
    }

    #[test]
    fn test_query_builder_with_trait_exists() {
        // Compile-time check that the base builder can be constructed.
        let _query = QueryBuilder::<()>::new();
    }

    #[test]
    fn test_query_builder_with_eager_loading_struct() {
        let query = wrapped();

        assert!(query.count_relations.is_empty());
        assert!(!query.optimization_enabled);
        assert!(query.optimized_loader.is_none());
        assert!(query.batch_size.is_none());
    }

    #[test]
    fn test_with_count_uses_relation_suffix_alias() {
        let query = wrapped().with_count("posts");

        assert_eq!(
            query.count_relations.get("posts_count").map(String::as_str),
            Some("posts")
        );
    }

    #[test]
    fn test_with_count_where_uses_explicit_alias() {
        let query = wrapped().with_count_where("published_posts", "posts", |c| c);

        assert_eq!(
            query.count_relations.get("published_posts").map(String::as_str),
            Some("posts")
        );
        assert!(!query.count_relations.contains_key("posts_count"));
    }

    #[test]
    fn test_batch_size_is_recorded() {
        let query = wrapped().batch_size(50);

        assert_eq!(query.batch_size, Some(50));
        assert!(query.optimized_loader.is_none());
    }

    #[test]
    fn test_parallel_loading_disabled_does_not_enable_optimization() {
        let query = wrapped().parallel_loading(false);

        assert!(!query.optimization_enabled);
        assert!(query.optimized_loader.is_none());
    }

    #[test]
    fn test_optimize_loading_enables_optimized_loader() {
        let query = wrapped().optimize_loading();

        assert!(query.optimization_enabled);
        assert!(query.optimized_loader.is_some());
    }

    #[test]
    fn test_build_secure_count_query_uses_placeholders() {
        let ids = vec!["1".to_string(), "2".to_string()];

        let (sql, params) = match wrapped().build_secure_count_query("posts", &ids) {
            Ok(built) => built,
            Err(_) => panic!("`posts` should be accepted as a relation name"),
        };

        assert!(sql.contains("COUNT(*) as count"));
        assert!(sql.contains("IN ($1, $2)"));
        assert_eq!(params, ids);
    }

    #[test]
    fn test_eager_loader_creation() {
        let loader = EagerLoader::new();
        let _loader_with_relation = loader.with("posts");
    }

    #[test]
    fn test_relationship_constraint_builder_creation() {
        let _builder = RelationshipConstraintBuilder::new()
            .where_eq("status", "published")
            .where_gt("views", 1000)
            .order_by_desc("created_at")
            .limit(5);
    }

    #[test]
    fn test_eager_loading_spec_creation() {
        let spec = EagerLoadSpec {
            relation: "posts".to_string(),
            constraints: None,
        };

        assert_eq!(spec.relation, "posts");
        assert!(spec.constraints.is_none());
    }

    #[test]
    fn test_api_compatibility() {
        let _query = QueryBuilder::<()>::new();
        let _loader = EagerLoader::new();
        let _constraint_builder = RelationshipConstraintBuilder::new();
        let _with_eager_loading = QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new());
    }
}