1use std::collections::HashMap;
4use sqlx::{Pool, Postgres, Row};
5
6use crate::error::ModelResult;
7use crate::model::Model;
8use crate::query::QueryBuilder;
9use crate::relationships::eager_loading::EagerLoader;
10use crate::relationships::constraints::RelationshipConstraintBuilder;
11use crate::loading::{OptimizedEagerLoader, EagerLoadConfig};
12
13pub trait QueryBuilderWithMethods<M> {
15 fn with(self, relation: &str) -> QueryBuilderWithEagerLoading<M>;
17
18 fn with_where<F>(self, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
20 where
21 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static;
22
23 fn with_when(self, condition: bool, relation: &str) -> QueryBuilderWithEagerLoading<M>;
25
26 fn with_count(self, relation: &str) -> QueryBuilderWithEagerLoading<M>;
28
29 fn with_count_where<F>(self, alias: &str, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
31 where
32 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static;
33}
34
35pub struct QueryBuilderWithEagerLoading<M> {
37 query: QueryBuilder<M>,
39 eager_loader: EagerLoader,
41 count_relations: HashMap<String, String>, optimization_enabled: bool,
45 optimized_loader: Option<OptimizedEagerLoader>,
47 batch_size: Option<usize>,
49}
50
51impl<M> QueryBuilderWithEagerLoading<M> {
52 pub fn new(query: QueryBuilder<M>) -> Self {
54 Self {
55 query,
56 eager_loader: EagerLoader::new(),
57 count_relations: HashMap::new(),
58 optimization_enabled: false,
59 optimized_loader: None,
60 batch_size: None,
61 }
62 }
63
64 pub fn with(mut self, relation: &str) -> Self {
66 self.eager_loader = self.eager_loader.with(relation);
67 self
68 }
69
70 pub fn with_where<F>(mut self, relation: &str, constraint: F) -> Self
72 where
73 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
74 {
75 self.eager_loader = self.eager_loader.with_constraint(relation, constraint);
76 self
77 }
78
79 pub fn with_when(self, condition: bool, relation: &str) -> Self {
81 if condition {
82 self.with(relation)
83 } else {
84 self
85 }
86 }
87
88 pub fn with_count(mut self, relation: &str) -> Self {
90 self.count_relations.insert(format!("{}_count", relation), relation.to_string());
91 self
92 }
93
94 pub fn with_count_where<F>(mut self, alias: &str, relation: &str, _constraint: F) -> Self
96 where
97 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
98 {
99 self.count_relations.insert(alias.to_string(), relation.to_string());
101 self
102 }
103
104 pub async fn get(mut self, pool: &Pool<Postgres>) -> ModelResult<Vec<M>>
106 where
107 M: Model + Send + Sync,
108 {
109 let mut models = self.query.clone().get(pool).await?;
111
112 if models.is_empty() {
113 return Ok(models);
114 }
115
116 if self.optimization_enabled && self.optimized_loader.is_some() {
118 let loaded_relations = self.eager_loader.loaded_relations();
120 let relationship_names = loaded_relations
121 .iter()
122 .map(|s| s.as_str())
123 .collect::<Vec<&str>>()
124 .join(",");
125 if !relationship_names.is_empty() {
126 let root_ids: Vec<serde_json::Value> = models
127 .iter()
128 .filter_map(|m| m.primary_key())
129 .map(|pk| serde_json::Value::String(pk.to_string()))
130 .collect();
131
132 if let Some(ref mut loader) = self.optimized_loader {
133 let _result = loader.load_with_relationships(
134 M::table_name(),
135 root_ids,
136 &relationship_names,
137 pool,
138 ).await.map_err(|e| crate::error::ModelError::Database(e.to_string()))?;
139
140 }
143 }
144 } else {
145 self.eager_loader.load_for_models(pool, &models).await?;
147 }
148
149 if !self.count_relations.is_empty() {
151 self.load_relationship_counts(pool, &mut models).await?;
152 }
153
154 self.attach_relationships_to_models(&mut models)?;
156
157 Ok(models)
158 }
159
160 pub async fn first(self, pool: &Pool<Postgres>) -> ModelResult<Option<M>>
162 where
163 M: Model + Send + Sync,
164 {
165 let models = self.get(pool).await?;
166 Ok(models.into_iter().next())
167 }
168
169 pub async fn first_or_fail(self, pool: &Pool<Postgres>) -> ModelResult<M>
171 where
172 M: Model + Send + Sync,
173 {
174 self.first(pool).await?
175 .ok_or_else(|| crate::error::ModelError::NotFound(
176 format!("No {} found", M::table_name())
177 ))
178 }
179
180 pub fn where_eq<V>(mut self, field: &str, value: V) -> Self
182 where
183 V: ToString + Send + Sync + 'static,
184 {
185 self.query = self.query.where_eq(field, value.to_string());
186 self
187 }
188
189 pub fn where_condition<V>(mut self, field: &str, operator: &str, value: V) -> Self
191 where
192 V: ToString + Send + Sync + 'static,
193 {
194 self.query = self.query.where_condition(field, operator, value.to_string());
197 self
198 }
199
200 pub fn order_by(mut self, field: &str) -> Self {
202 self.query = self.query.order_by(field);
203 self
204 }
205
206 pub fn order_by_desc(mut self, field: &str) -> Self {
208 self.query = self.query.order_by_desc(field);
209 self
210 }
211
212 pub fn limit(mut self, count: i64) -> Self {
214 self.query = self.query.limit(count);
215 self
216 }
217
218 pub fn offset(mut self, count: i64) -> Self {
220 self.query = self.query.offset(count);
221 self
222 }
223
224 pub fn optimize_loading(mut self) -> Self {
226 self.optimization_enabled = true;
227 self.optimized_loader = Some(OptimizedEagerLoader::new());
228 self
229 }
230
231 pub fn optimize_loading_with_config(mut self, config: EagerLoadConfig) -> Self {
233 self.optimization_enabled = true;
234 let batch_loader = crate::loading::BatchLoader::with_config(
235 crate::loading::BatchConfig::default()
236 );
237 self.optimized_loader = Some(OptimizedEagerLoader::with_config(config, batch_loader));
238 self
239 }
240
241 pub fn batch_size(mut self, size: usize) -> Self {
243 self.batch_size = Some(size);
244
245 if let Some(ref mut loader) = self.optimized_loader {
247 let mut config = loader.config().clone();
248 config.max_batch_size = size;
249 loader.update_config(config);
250 }
251
252 self
253 }
254
255 pub fn parallel_loading(mut self, enabled: bool) -> Self {
257 if let Some(ref mut loader) = self.optimized_loader {
259 let mut config = loader.config().clone();
260 config.enable_parallelism = enabled;
261 loader.update_config(config);
262 } else if enabled {
263 let mut config = EagerLoadConfig::default();
265 config.enable_parallelism = true;
266 self = self.optimize_loading_with_config(config);
267 }
268
269 self
270 }
271
272 pub fn max_depth(mut self, depth: usize) -> Self {
274 if let Some(ref mut loader) = self.optimized_loader {
276 let mut config = loader.config().clone();
277 config.max_depth = depth;
278 loader.update_config(config);
279 }
280
281 self
282 }
283
284 async fn load_relationship_counts(&self, pool: &Pool<Postgres>, models: &mut [M]) -> ModelResult<()>
286 where
287 M: Model + Send + Sync,
288 {
289 for (_alias, relation) in &self.count_relations {
290 let model_ids: Vec<String> = models
292 .iter()
293 .filter_map(|m| m.primary_key().map(|pk| pk.to_string()))
294 .collect();
295
296 if model_ids.is_empty() {
297 continue;
298 }
299
300 let (count_query, params) = self.build_secure_count_query(relation, &model_ids)?;
302
303 let mut query = sqlx::query(&count_query);
305 for param in params {
306 query = query.bind(param);
307 }
308
309 let rows = query.fetch_all(pool).await
310 .map_err(|e| crate::error::ModelError::Database(e.to_string()))?;
311
312 let mut counts: HashMap<String, i64> = HashMap::new();
314 for row in rows {
315 let parent_id: String = row.get("parent_id");
316 let count: i64 = row.get("count");
317 counts.insert(parent_id, count);
318 }
319
320 }
324
325 Ok(())
326 }
327
328 fn build_secure_count_query(&self, relation: &str, parent_ids: &[String]) -> ModelResult<(String, Vec<String>)> {
330 use crate::security::{escape_identifier, validate_identifier};
331
332 validate_identifier(relation).map_err(|_|
334 crate::error::ModelError::Validation(
335 format!("Invalid relationship name: {}", relation)
336 )
337 )?;
338
339 let (table_name, foreign_key) = match relation {
341 "posts" => ("posts", "user_id"),
342 "comments" => ("comments", "post_id"),
343 "profile" => ("profiles", "user_id"),
344 _ => {
345 validate_identifier(relation).map_err(|_|
348 crate::error::ModelError::Validation(
349 format!("Invalid table name derived from relation: {}", relation)
350 )
351 )?;
352 (relation, "parent_id")
353 }
354 };
355
356 validate_identifier(table_name)?;
358 validate_identifier(foreign_key)?;
359
360 let escaped_table = escape_identifier(table_name);
362 let escaped_foreign_key = escape_identifier(foreign_key);
363
364 let placeholders: Vec<String> = (1..=parent_ids.len())
366 .map(|i| format!("${}", i))
367 .collect();
368 let placeholders_str = placeholders.join(", ");
369
370 let query = format!(
371 "SELECT {} as parent_id, COUNT(*) as count FROM {} WHERE {} IN ({}) GROUP BY {}",
372 escaped_foreign_key, escaped_table, escaped_foreign_key, placeholders_str, escaped_foreign_key
373 );
374
375 Ok((query, parent_ids.to_vec()))
377 }
378
379 fn attach_relationships_to_models(&self, models: &mut [M]) -> ModelResult<()>
381 where
382 M: Model + Send + Sync,
383 {
384 for model in models {
395 if let Some(pk) = model.primary_key() {
396 let pk_str = pk.to_string();
397
398 for relation in self.eager_loader.loaded_relations() {
400 if let Some(_data) = self.eager_loader.get_loaded_data(relation, &pk_str) {
401 }
404 }
405 }
406 }
407
408 Ok(())
409 }
410}
411
412impl<M> QueryBuilderWithMethods<M> for QueryBuilder<M>
414where
415 M: Model + Send + Sync,
416{
417 fn with(self, relation: &str) -> QueryBuilderWithEagerLoading<M> {
418 QueryBuilderWithEagerLoading::new(self).with(relation)
419 }
420
421 fn with_where<F>(self, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
422 where
423 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
424 {
425 QueryBuilderWithEagerLoading::new(self).with_where(relation, constraint)
426 }
427
428 fn with_when(self, condition: bool, relation: &str) -> QueryBuilderWithEagerLoading<M> {
429 QueryBuilderWithEagerLoading::new(self).with_when(condition, relation)
430 }
431
432 fn with_count(self, relation: &str) -> QueryBuilderWithEagerLoading<M> {
433 QueryBuilderWithEagerLoading::new(self).with_count(relation)
434 }
435
436 fn with_count_where<F>(self, alias: &str, relation: &str, constraint: F) -> QueryBuilderWithEagerLoading<M>
437 where
438 F: FnOnce(RelationshipConstraintBuilder) -> RelationshipConstraintBuilder + 'static,
439 {
440 QueryBuilderWithEagerLoading::new(self).with_count_where(alias, relation, constraint)
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::QueryBuilder;
    use crate::relationships::eager_loading::EagerLoadSpec;

    /// Compile-time smoke test: the base builder can be constructed.
    #[test]
    fn test_query_builder_with_trait_exists() {
        let _query = QueryBuilder::<()>::new();
    }

    /// Compile-time smoke test: the eager-loading wrapper can be constructed.
    #[test]
    fn test_query_builder_with_eager_loading_struct() {
        let base_query = QueryBuilder::<()>::new();
        let _with_query = QueryBuilderWithEagerLoading::new(base_query);
    }

    /// Compile-time smoke test: relations can be registered on an `EagerLoader`.
    #[test]
    fn test_eager_loader_creation() {
        let loader = EagerLoader::new();
        let _loader_with_relation = loader.with("posts");
    }

    /// Compile-time smoke test: the constraint builder chain type-checks.
    #[test]
    fn test_relationship_constraint_builder_creation() {
        let _builder = RelationshipConstraintBuilder::new()
            .where_eq("status", "published")
            .where_gt("views", 1000)
            .order_by_desc("created_at")
            .limit(5);
    }

    #[test]
    fn test_eager_loading_spec_creation() {
        let spec = EagerLoadSpec {
            relation: "posts".to_string(),
            constraints: None,
        };

        assert_eq!(spec.relation, "posts");
        assert!(spec.constraints.is_none());
    }

    /// Compile-time smoke test: the public entry points are all constructible.
    #[test]
    fn test_api_compatibility() {
        let _query = QueryBuilder::<()>::new();
        let _loader = EagerLoader::new();
        let _constraint_builder = RelationshipConstraintBuilder::new();
        let _with_eager_loading = QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new());
    }

    #[test]
    fn test_with_count_registers_aliases() {
        let builder = QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new())
            .with_count("posts")
            .with_count_where("popular_posts", "posts", |c| c);

        assert_eq!(builder.count_relations.len(), 2);
        assert_eq!(
            builder.count_relations.get("posts_count").map(String::as_str),
            Some("posts")
        );
        assert_eq!(
            builder.count_relations.get("popular_posts").map(String::as_str),
            Some("posts")
        );
    }

    #[test]
    fn test_build_secure_count_query_returns_parent_ids() {
        let builder = QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new());
        let ids = vec!["1".to_string(), "2".to_string()];

        let (sql, params) = match builder.build_secure_count_query("posts", &ids) {
            Ok(built) => built,
            Err(_) => panic!("`posts` is a known relation and must be accepted"),
        };

        assert!(sql.contains("COUNT(*)"));
        assert_eq!(params, ids);
    }

    #[test]
    fn test_build_secure_count_query_rejects_injection() {
        let builder = QueryBuilderWithEagerLoading::new(QueryBuilder::<()>::new());
        let ids = vec!["1".to_string()];

        assert!(builder
            .build_secure_count_query("posts; DROP TABLE users", &ids)
            .is_err());
    }
}