1use crate::{
2 error::{OrmError, OrmResult},
3 loading::{
4 batch_loader::BatchLoader,
5 optimizer::{QueryOptimizer, QueryPlan, QueryNode, PlanExecutor, OptimizationStrategy},
6 query_deduplicator::QueryDeduplicator,
7 },
8 relationships::RelationshipType,
9};
10use serde_json::Value as JsonValue;
11use std::collections::HashMap;
12
/// Tuning knobs for [`OptimizedEagerLoader`].
#[derive(Debug, Clone)]
pub struct EagerLoadConfig {
    /// Upper bound applied to the estimated row count of every relationship
    /// node in a query plan (the estimate is additionally capped at 1000).
    pub max_batch_size: usize,
    /// Whether duplicate queries should be collapsed.
    ///
    /// NOTE(review): this flag is not read anywhere in this file — confirm
    /// where it is consumed.
    pub deduplicate_queries: bool,
    /// Deepest relationship level added to a query plan; segments of a
    /// relation path beyond this depth are skipped without an error.
    pub max_depth: usize,
    /// When `true` the plan executor is built with `10` as its second
    /// argument, otherwise `1` (presumably the number of concurrent
    /// queries — confirm against `PlanExecutor::with_config`).
    pub enable_parallelism: bool,
    /// Timeout handed to the plan executor, in milliseconds.
    pub query_timeout_ms: u64,
}
27
28impl Default for EagerLoadConfig {
29 fn default() -> Self {
30 Self {
31 max_batch_size: 100,
32 deduplicate_queries: true,
33 max_depth: 10,
34 enable_parallelism: true,
35 query_timeout_ms: 30000,
36 }
37 }
38}
39
40#[derive(Debug)]
42pub struct EagerLoadResult {
43 pub data: HashMap<JsonValue, JsonValue>,
45 pub stats: EagerLoadStats,
47 pub optimizations: Vec<OptimizationStrategy>,
49}
50
/// Execution statistics attached to an [`EagerLoadResult`].
///
/// `Default` yields an all-zero record; it is derived because every field's
/// own default (`0` / `0.0`) is exactly the value wanted here.
#[derive(Debug, Clone, Default)]
pub struct EagerLoadStats {
    /// Wall-clock duration of the load, in milliseconds.
    pub execution_time_ms: u64,
    /// Number of queries the plan executor reported.
    pub query_count: usize,
    /// Number of rows the plan executor reported as fetched.
    pub records_loaded: usize,
    /// Maximum depth of the executed query plan.
    pub depth_loaded: usize,
    /// Cache hit ratio reported for the run, as a fraction.
    pub cache_hit_ratio: f64,
}
77
78pub struct OptimizedEagerLoader {
80 batch_loader: BatchLoader,
81 query_optimizer: QueryOptimizer,
82 plan_executor: PlanExecutor,
83 query_deduplicator: QueryDeduplicator,
84 config: EagerLoadConfig,
85}
86
87impl OptimizedEagerLoader {
88 pub fn new() -> Self {
90 let config = EagerLoadConfig::default();
91 let batch_loader = BatchLoader::new();
92 Self::with_config(config, batch_loader)
93 }
94
95 pub fn with_config(config: EagerLoadConfig, batch_loader: BatchLoader) -> Self {
97 let query_optimizer = QueryOptimizer::new();
98 let plan_executor = PlanExecutor::with_config(
99 batch_loader.clone(),
100 if config.enable_parallelism { 10 } else { 1 },
101 std::time::Duration::from_millis(config.query_timeout_ms),
102 );
103 let query_deduplicator = QueryDeduplicator::new();
104
105 Self {
106 batch_loader,
107 query_optimizer,
108 plan_executor,
109 query_deduplicator,
110 config,
111 }
112 }
113
114 pub async fn load_with_relationships(
116 &mut self,
117 root_table: &str,
118 root_ids: Vec<JsonValue>,
119 relationships: &str,
120 connection: &sqlx::PgPool,
121 ) -> OrmResult<EagerLoadResult> {
122 let start_time = std::time::Instant::now();
123
124 let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
126
127 let optimization_strategies = self.query_optimizer.optimize_plan(&mut plan)?;
129
130 let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
132
133 let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
135
136 let execution_time = start_time.elapsed();
138 let stats = EagerLoadStats {
139 execution_time_ms: execution_time.as_millis() as u64,
140 query_count: execution_result.stats.query_count,
141 records_loaded: execution_result.stats.rows_fetched,
142 depth_loaded: plan.max_depth,
143 cache_hit_ratio: self.calculate_cache_hit_ratio().await,
144 };
145
146 Ok(EagerLoadResult {
147 data: processed_data,
148 stats,
149 optimizations: optimization_strategies,
150 })
151 }
152
153 pub async fn load_with_strategy(
155 &mut self,
156 root_table: &str,
157 root_ids: Vec<JsonValue>,
158 relationships: &str,
159 strategy: OptimizationStrategy,
160 connection: &sqlx::PgPool,
161 ) -> OrmResult<EagerLoadResult> {
162 let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
164
165 match strategy {
167 OptimizationStrategy::IncreaseParallelism => {
168 self.apply_parallel_optimization(&mut plan)?;
169 }
170 OptimizationStrategy::ReduceBatchSize => {
171 self.apply_batch_size_optimization(&mut plan)?;
172 }
173 OptimizationStrategy::ReorderPhases => {
174 plan.build_execution_phases()?;
175 }
176 _ => {
177 let _strategies = self.query_optimizer.optimize_plan(&mut plan)?;
179 }
180 }
181
182 let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
184 let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
185
186 let stats = EagerLoadStats {
187 execution_time_ms: 0, query_count: execution_result.stats.query_count,
189 records_loaded: execution_result.stats.rows_fetched,
190 depth_loaded: plan.max_depth,
191 cache_hit_ratio: self.calculate_cache_hit_ratio().await,
192 };
193
194 Ok(EagerLoadResult {
195 data: processed_data,
196 stats,
197 optimizations: vec![strategy],
198 })
199 }
200
201 fn build_query_plan(
203 &self,
204 root_table: &str,
205 root_ids: &[JsonValue],
206 relationships: &str,
207 ) -> OrmResult<QueryPlan> {
208 let mut plan = QueryPlan::new();
209 let mut node_counter = 0;
210
211 let root_node_id = format!("root_{}", node_counter);
213 node_counter += 1;
214
215 let mut root_node = QueryNode::root(root_node_id.clone(), root_table.to_string());
216 root_node.set_estimated_rows(root_ids.len());
217 plan.add_node(root_node);
218
219 if !relationships.is_empty() {
221 self.build_relationship_nodes(
222 &mut plan,
223 &root_node_id,
224 relationships,
225 1, &mut node_counter,
227 )?;
228 }
229
230 plan.build_execution_phases()?;
232
233 Ok(plan)
234 }
235
236 fn build_relationship_nodes(
238 &self,
239 plan: &mut QueryPlan,
240 parent_node_id: &str,
241 relationships: &str,
242 depth: usize,
243 node_counter: &mut usize,
244 ) -> OrmResult<()> {
245 if depth > self.config.max_depth {
246 return Ok(()); }
248
249 let parts: Vec<&str> = relationships.split(',').collect();
251
252 for part in parts {
253 let relation_chain: Vec<&str> = part.split('.').collect();
254 self.build_relation_chain(
255 plan,
256 parent_node_id,
257 &relation_chain,
258 depth,
259 node_counter,
260 )?;
261 }
262
263 Ok(())
264 }
265
266 fn build_relation_chain(
268 &self,
269 plan: &mut QueryPlan,
270 parent_node_id: &str,
271 chain: &[&str],
272 depth: usize,
273 node_counter: &mut usize,
274 ) -> OrmResult<()> {
275 if chain.is_empty() || depth > self.config.max_depth {
276 return Ok(());
277 }
278
279 let relation_name = chain[0];
280 let node_id = format!("{}_{}", relation_name, *node_counter);
281 *node_counter += 1;
282
283 let (table_name, relationship_type, foreign_key) = self.get_relationship_info(relation_name)?;
285
286 let mut node = QueryNode::child(
288 node_id.clone(),
289 table_name,
290 parent_node_id.to_string(),
291 relationship_type,
292 foreign_key,
293 );
294 node.set_depth(depth);
295 node.set_estimated_rows(std::cmp::min(1000, self.config.max_batch_size)); plan.add_node(node);
298
299 if chain.len() > 1 {
301 self.build_relation_chain(
302 plan,
303 &node_id,
304 &chain[1..],
305 depth + 1,
306 node_counter,
307 )?;
308 }
309
310 Ok(())
311 }
312
313 fn get_relationship_info(&self, relation: &str) -> OrmResult<(String, RelationshipType, String)> {
315 match relation {
318 "posts" => Ok(("posts".to_string(), RelationshipType::HasMany, "user_id".to_string())),
319 "comments" => Ok(("comments".to_string(), RelationshipType::HasMany, "post_id".to_string())),
320 "user" => Ok(("users".to_string(), RelationshipType::BelongsTo, "user_id".to_string())),
321 "profile" => Ok(("profiles".to_string(), RelationshipType::HasOne, "user_id".to_string())),
322 _ => {
323 Ok((
325 format!("{}s", relation),
326 RelationshipType::HasMany,
327 format!("{}_id", relation),
328 ))
329 }
330 }
331 }
332
333 fn process_execution_results(
335 &self,
336 results: HashMap<String, Vec<JsonValue>>,
337 root_ids: &[JsonValue],
338 ) -> OrmResult<HashMap<JsonValue, JsonValue>> {
339 let mut processed = HashMap::new();
340
341 for (i, root_id) in root_ids.iter().enumerate() {
344 let mut entity_data = serde_json::json!({
345 "id": root_id,
346 "relationships": {}
347 });
348
349 for (node_id, node_results) in &results {
351 if node_id.starts_with("root_") {
352 continue; }
354
355 if let Some(obj) = entity_data.as_object_mut() {
357 if let Some(relationships) = obj.get_mut("relationships").and_then(|r| r.as_object_mut()) {
358 relationships.insert(node_id.clone(), serde_json::json!(node_results));
359 }
360 }
361 }
362
363 processed.insert(root_id.clone(), entity_data);
364 }
365
366 Ok(processed)
367 }
368
369 fn apply_parallel_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
371 for node in plan.nodes.values_mut() {
372 if node.constraints.is_empty() {
373 node.set_parallel_safe(true);
374 }
375 }
376 plan.build_execution_phases()?;
377 Ok(())
378 }
379
380 fn apply_batch_size_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
382 for node in plan.nodes.values_mut() {
384 if node.estimated_rows > 5000 {
385 node.set_estimated_rows(node.estimated_rows / 2);
386 }
387 }
388 Ok(())
389 }
390
391 async fn calculate_cache_hit_ratio(&self) -> f64 {
393 let stats = self.batch_loader.cache_stats().await;
394 if stats.total_cached_records > 0 {
395 0.75 } else {
397 0.0
398 }
399 }
400
401 pub fn config(&self) -> &EagerLoadConfig {
403 &self.config
404 }
405
406 pub fn update_config(&mut self, config: EagerLoadConfig) {
408 self.config = config;
409 }
410
411 pub async fn clear_caches(&self) {
413 self.batch_loader.clear_cache().await;
414 }
415}
416
417impl Default for OptimizedEagerLoader {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every field of the default configuration is pinned, including the timeout.
    #[test]
    fn test_eager_load_config_default() {
        let config = EagerLoadConfig::default();
        assert_eq!(config.max_batch_size, 100);
        assert!(config.deduplicate_queries);
        assert_eq!(config.max_depth, 10);
        assert!(config.enable_parallelism);
        assert_eq!(config.query_timeout_ms, 30000);
    }

    /// Default statistics start out zeroed.
    #[test]
    fn test_eager_load_stats_default() {
        let stats = EagerLoadStats::default();
        assert_eq!(stats.execution_time_ms, 0);
        assert_eq!(stats.query_count, 0);
        assert_eq!(stats.records_loaded, 0);
        assert_eq!(stats.depth_loaded, 0);
        assert_eq!(stats.cache_hit_ratio, 0.0);
    }

    /// A two-segment relation path yields a single-root plan of depth 2.
    #[test]
    fn test_build_query_plan() {
        let loader = OptimizedEagerLoader::new();
        let root_ids = vec![json!(1), json!(2)];

        let plan = loader
            .build_query_plan("users", &root_ids, "posts.comments")
            .unwrap();

        assert_eq!(plan.roots.len(), 1);
        assert!(!plan.nodes.is_empty());
        assert_eq!(plan.max_depth, 2);
    }

    /// Explicitly mapped relation names resolve to their fixed metadata.
    #[test]
    fn test_relationship_info_mapping() {
        let loader = OptimizedEagerLoader::new();

        let (table, rel_type, fk) = loader.get_relationship_info("posts").unwrap();
        assert_eq!(table, "posts");
        assert_eq!(rel_type, RelationshipType::HasMany);
        assert_eq!(fk, "user_id");

        let (table, rel_type, fk) = loader.get_relationship_info("user").unwrap();
        assert_eq!(table, "users");
        assert_eq!(rel_type, RelationshipType::BelongsTo);
        assert_eq!(fk, "user_id");

        let (table, rel_type, fk) = loader.get_relationship_info("profile").unwrap();
        assert_eq!(table, "profiles");
        assert_eq!(rel_type, RelationshipType::HasOne);
        assert_eq!(fk, "user_id");
    }

    /// Unmapped relation names fall back to `<name>s` / `<name>_id` as has-many.
    #[test]
    fn test_relationship_info_fallback() {
        let loader = OptimizedEagerLoader::new();

        let (table, rel_type, fk) = loader.get_relationship_info("tag").unwrap();
        assert_eq!(table, "tags");
        assert_eq!(rel_type, RelationshipType::HasMany);
        assert_eq!(fk, "tag_id");
    }
}