1use crate::{
2 error::OrmResult,
3 loading::{
4 batch_loader::BatchLoader,
5 optimizer::{OptimizationStrategy, PlanExecutor, QueryNode, QueryOptimizer, QueryPlan},
6 query_deduplicator::QueryDeduplicator,
7 },
8 relationships::RelationshipType,
9};
10use serde_json::Value as JsonValue;
11use std::collections::HashMap;
12
/// Settings that control how the optimized eager loader plans and runs queries.
#[derive(Debug, Clone)]
pub struct EagerLoadConfig {
    /// Batch size limit; the loader also uses it to cap the row estimate
    /// assigned to relationship nodes.
    pub max_batch_size: usize,
    /// Whether duplicate queries should be collapsed.
    /// NOTE(review): not read by the loader code visible in this file — confirm where it is honoured.
    pub deduplicate_queries: bool,
    /// Deepest relationship level that is added to a query plan.
    pub max_depth: usize,
    /// When `true` the plan executor is built with the value `10`, otherwise `1`.
    pub enable_parallelism: bool,
    /// Timeout handed to the plan executor, in milliseconds.
    pub query_timeout_ms: u64,
}

impl EagerLoadConfig {
    /// Default for `max_batch_size`.
    const DEFAULT_MAX_BATCH_SIZE: usize = 100;
    /// Default for `max_depth`.
    const DEFAULT_MAX_DEPTH: usize = 10;
    /// Default for `query_timeout_ms` (30 seconds).
    const DEFAULT_QUERY_TIMEOUT_MS: u64 = 30_000;
}

impl Default for EagerLoadConfig {
    /// Returns the stock configuration: batches of 100, depth 10, a 30 s timeout,
    /// with both deduplication and parallelism switched on.
    fn default() -> Self {
        EagerLoadConfig {
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_depth: Self::DEFAULT_MAX_DEPTH,
            query_timeout_ms: Self::DEFAULT_QUERY_TIMEOUT_MS,
            deduplicate_queries: true,
            enable_parallelism: true,
        }
    }
}
39
/// Outcome of one eager-load call: the assembled data plus bookkeeping about the run.
#[derive(Debug)]
pub struct EagerLoadResult {
    /// One entry per requested root id. Each value is a JSON object holding the
    /// root `id` and a `relationships` object keyed by plan node id.
    pub data: HashMap<JsonValue, JsonValue>,
    /// Counters and timings collected while the plan was executed.
    pub stats: EagerLoadStats,
    /// Strategies reported for this load: either what the optimizer returned or
    /// the single strategy the caller asked for.
    pub optimizations: Vec<OptimizationStrategy>,
}
50
/// Counters describing a single eager-load run.
///
/// `Default` is derived: every counter starts at zero and the ratio at `0.0`,
/// exactly what the previous hand-written implementation produced.
#[derive(Debug, Clone, Default)]
pub struct EagerLoadStats {
    /// Wall-clock duration of the load, in milliseconds.
    pub execution_time_ms: u64,
    /// Number of queries reported by the plan executor.
    pub query_count: usize,
    /// Number of rows reported as fetched by the plan executor.
    pub records_loaded: usize,
    /// Maximum depth of the executed query plan.
    pub depth_loaded: usize,
    /// Cache hit ratio reported by the loader for this run.
    pub cache_hit_ratio: f64,
}
77
/// Eager loader that turns a relationship specification into a query plan,
/// optimizes the plan and executes it.
pub struct OptimizedEagerLoader {
    /// Loader whose cache statistics are queried and cleared; a clone of it is
    /// handed to the plan executor.
    batch_loader: BatchLoader,
    /// Rewrites a freshly built plan and reports the strategies it applied.
    query_optimizer: QueryOptimizer,
    /// Runs a query plan against a database pool.
    plan_executor: PlanExecutor,
    /// Constructed alongside the loader but not used by the code visible in this file.
    _query_deduplicator: QueryDeduplicator,
    /// Active configuration (depth limit, batch size, ...).
    config: EagerLoadConfig,
}
86
87impl OptimizedEagerLoader {
88 pub fn new() -> Self {
90 let config = EagerLoadConfig::default();
91 let batch_loader = BatchLoader::new();
92 Self::with_config(config, batch_loader)
93 }
94
95 pub fn with_config(config: EagerLoadConfig, batch_loader: BatchLoader) -> Self {
97 let query_optimizer = QueryOptimizer::new();
98 let plan_executor = PlanExecutor::with_config(
99 batch_loader.clone(),
100 if config.enable_parallelism { 10 } else { 1 },
101 std::time::Duration::from_millis(config.query_timeout_ms),
102 );
103 let query_deduplicator = QueryDeduplicator::new();
104
105 Self {
106 batch_loader,
107 query_optimizer,
108 plan_executor,
109 _query_deduplicator: query_deduplicator,
110 config,
111 }
112 }
113
114 pub async fn load_with_relationships(
116 &mut self,
117 root_table: &str,
118 root_ids: Vec<JsonValue>,
119 relationships: &str,
120 connection: &sqlx::PgPool,
121 ) -> OrmResult<EagerLoadResult> {
122 let start_time = std::time::Instant::now();
123
124 let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
126
127 let optimization_strategies = self.query_optimizer.optimize_plan(&mut plan)?;
129
130 let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
132
133 let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
135
136 let execution_time = start_time.elapsed();
138 let stats = EagerLoadStats {
139 execution_time_ms: execution_time.as_millis() as u64,
140 query_count: execution_result.stats.query_count,
141 records_loaded: execution_result.stats.rows_fetched,
142 depth_loaded: plan.max_depth,
143 cache_hit_ratio: self.calculate_cache_hit_ratio().await,
144 };
145
146 Ok(EagerLoadResult {
147 data: processed_data,
148 stats,
149 optimizations: optimization_strategies,
150 })
151 }
152
153 pub async fn load_with_strategy(
155 &mut self,
156 root_table: &str,
157 root_ids: Vec<JsonValue>,
158 relationships: &str,
159 strategy: OptimizationStrategy,
160 connection: &sqlx::PgPool,
161 ) -> OrmResult<EagerLoadResult> {
162 let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
164
165 match strategy {
167 OptimizationStrategy::IncreaseParallelism => {
168 self.apply_parallel_optimization(&mut plan)?;
169 }
170 OptimizationStrategy::ReduceBatchSize => {
171 self.apply_batch_size_optimization(&mut plan)?;
172 }
173 OptimizationStrategy::ReorderPhases => {
174 plan.build_execution_phases()?;
175 }
176 _ => {
177 let _strategies = self.query_optimizer.optimize_plan(&mut plan)?;
179 }
180 }
181
182 let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
184 let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
185
186 let stats = EagerLoadStats {
187 execution_time_ms: 0, query_count: execution_result.stats.query_count,
189 records_loaded: execution_result.stats.rows_fetched,
190 depth_loaded: plan.max_depth,
191 cache_hit_ratio: self.calculate_cache_hit_ratio().await,
192 };
193
194 Ok(EagerLoadResult {
195 data: processed_data,
196 stats,
197 optimizations: vec![strategy],
198 })
199 }
200
201 fn build_query_plan(
203 &self,
204 root_table: &str,
205 root_ids: &[JsonValue],
206 relationships: &str,
207 ) -> OrmResult<QueryPlan> {
208 let mut plan = QueryPlan::new();
209 let mut node_counter = 0;
210
211 let root_node_id = format!("root_{}", node_counter);
213 node_counter += 1;
214
215 let mut root_node = QueryNode::root(root_node_id.clone(), root_table.to_string());
216 root_node.set_estimated_rows(root_ids.len());
217 plan.add_node(root_node);
218
219 if !relationships.is_empty() {
221 self.build_relationship_nodes(
222 &mut plan,
223 &root_node_id,
224 relationships,
225 1, &mut node_counter,
227 )?;
228 }
229
230 plan.build_execution_phases()?;
232
233 Ok(plan)
234 }
235
236 fn build_relationship_nodes(
238 &self,
239 plan: &mut QueryPlan,
240 parent_node_id: &str,
241 relationships: &str,
242 depth: usize,
243 node_counter: &mut usize,
244 ) -> OrmResult<()> {
245 if depth > self.config.max_depth {
246 return Ok(()); }
248
249 let parts: Vec<&str> = relationships.split(',').collect();
251
252 for part in parts {
253 let relation_chain: Vec<&str> = part.split('.').collect();
254 self.build_relation_chain(plan, parent_node_id, &relation_chain, depth, node_counter)?;
255 }
256
257 Ok(())
258 }
259
260 fn build_relation_chain(
262 &self,
263 plan: &mut QueryPlan,
264 parent_node_id: &str,
265 chain: &[&str],
266 depth: usize,
267 node_counter: &mut usize,
268 ) -> OrmResult<()> {
269 if chain.is_empty() || depth > self.config.max_depth {
270 return Ok(());
271 }
272
273 let relation_name = chain[0];
274 let node_id = format!("{}_{}", relation_name, *node_counter);
275 *node_counter += 1;
276
277 let (table_name, relationship_type, foreign_key) =
279 self.get_relationship_info(relation_name)?;
280
281 let mut node = QueryNode::child(
283 node_id.clone(),
284 table_name,
285 parent_node_id.to_string(),
286 relationship_type,
287 foreign_key,
288 );
289 node.set_depth(depth);
290 node.set_estimated_rows(std::cmp::min(1000, self.config.max_batch_size)); plan.add_node(node);
293
294 if chain.len() > 1 {
296 self.build_relation_chain(plan, &node_id, &chain[1..], depth + 1, node_counter)?;
297 }
298
299 Ok(())
300 }
301
302 fn get_relationship_info(
304 &self,
305 relation: &str,
306 ) -> OrmResult<(String, RelationshipType, String)> {
307 match relation {
310 "posts" => Ok((
311 "posts".to_string(),
312 RelationshipType::HasMany,
313 "user_id".to_string(),
314 )),
315 "comments" => Ok((
316 "comments".to_string(),
317 RelationshipType::HasMany,
318 "post_id".to_string(),
319 )),
320 "user" => Ok((
321 "users".to_string(),
322 RelationshipType::BelongsTo,
323 "user_id".to_string(),
324 )),
325 "profile" => Ok((
326 "profiles".to_string(),
327 RelationshipType::HasOne,
328 "user_id".to_string(),
329 )),
330 _ => {
331 Ok((
333 format!("{}s", relation),
334 RelationshipType::HasMany,
335 format!("{}_id", relation),
336 ))
337 }
338 }
339 }
340
341 fn process_execution_results(
343 &self,
344 results: HashMap<String, Vec<JsonValue>>,
345 root_ids: &[JsonValue],
346 ) -> OrmResult<HashMap<JsonValue, JsonValue>> {
347 let mut processed = HashMap::new();
348
349 for root_id in root_ids.iter() {
352 let mut entity_data = serde_json::json!({
353 "id": root_id,
354 "relationships": {}
355 });
356
357 for (node_id, node_results) in &results {
359 if node_id.starts_with("root_") {
360 continue; }
362
363 if let Some(obj) = entity_data.as_object_mut() {
365 if let Some(relationships) =
366 obj.get_mut("relationships").and_then(|r| r.as_object_mut())
367 {
368 relationships.insert(node_id.clone(), serde_json::json!(node_results));
369 }
370 }
371 }
372
373 processed.insert(root_id.clone(), entity_data);
374 }
375
376 Ok(processed)
377 }
378
379 fn apply_parallel_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
381 for node in plan.nodes.values_mut() {
382 if node.constraints.is_empty() {
383 node.set_parallel_safe(true);
384 }
385 }
386 plan.build_execution_phases()?;
387 Ok(())
388 }
389
390 fn apply_batch_size_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
392 for node in plan.nodes.values_mut() {
394 if node.estimated_rows > 5000 {
395 node.set_estimated_rows(node.estimated_rows / 2);
396 }
397 }
398 Ok(())
399 }
400
401 async fn calculate_cache_hit_ratio(&self) -> f64 {
403 let stats = self.batch_loader.cache_stats().await;
404 if stats.total_cached_records > 0 {
405 0.75 } else {
407 0.0
408 }
409 }
410
411 pub fn config(&self) -> &EagerLoadConfig {
413 &self.config
414 }
415
416 pub fn update_config(&mut self, config: EagerLoadConfig) {
418 self.config = config;
419 }
420
421 pub async fn clear_caches(&self) {
423 self.batch_loader.clear_cache().await;
424 }
425}
426
427impl Default for OptimizedEagerLoader {
428 fn default() -> Self {
429 Self::new()
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The stock configuration exposes the documented defaults.
    #[test]
    fn test_eager_load_config_default() {
        let defaults = EagerLoadConfig::default();

        assert_eq!(defaults.max_batch_size, 100);
        assert_eq!(defaults.max_depth, 10);
        assert!(defaults.deduplicate_queries);
        assert!(defaults.enable_parallelism);
    }

    /// A two-level chain yields a single root and a plan of depth 2.
    #[test]
    fn test_build_query_plan() {
        let ids = vec![json!(1), json!(2)];
        let eager_loader = OptimizedEagerLoader::new();

        let built = eager_loader
            .build_query_plan("users", &ids, "posts.comments")
            .unwrap();

        assert_eq!(plan_root_count(&built), 1);
        assert!(!built.nodes.is_empty());
        assert_eq!(built.max_depth, 2);
    }

    /// Number of root nodes recorded in a plan.
    fn plan_root_count(plan: &QueryPlan) -> usize {
        plan.roots.len()
    }

    /// The built-in `posts` relation maps to its hard-coded metadata.
    #[test]
    fn test_relationship_info_mapping() {
        let eager_loader = OptimizedEagerLoader::new();

        let (table_name, kind, foreign_key) =
            eager_loader.get_relationship_info("posts").unwrap();

        assert_eq!(table_name, "posts");
        assert_eq!(kind, RelationshipType::HasMany);
        assert_eq!(foreign_key, "user_id");
    }
}