use crate::{
error::OrmResult,
loading::{
batch_loader::BatchLoader,
optimizer::{OptimizationStrategy, PlanExecutor, QueryNode, QueryOptimizer, QueryPlan},
query_deduplicator::QueryDeduplicator,
},
relationships::RelationshipType,
};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
/// Tuning knobs for the optimized eager loader.
#[derive(Debug, Clone)]
pub struct EagerLoadConfig {
    /// Upper bound on rows requested per batched query.
    pub max_batch_size: usize,
    /// Whether identical in-flight queries are collapsed into one.
    pub deduplicate_queries: bool,
    /// Maximum relationship nesting depth that will be loaded.
    pub max_depth: usize,
    /// Whether independent queries may run concurrently.
    pub enable_parallelism: bool,
    /// Per-query timeout, in milliseconds.
    pub query_timeout_ms: u64,
}

impl Default for EagerLoadConfig {
    /// Conservative defaults: batches of 100, depth 10, 30-second timeout,
    /// with deduplication and parallelism enabled.
    fn default() -> Self {
        EagerLoadConfig {
            query_timeout_ms: 30_000,
            enable_parallelism: true,
            max_depth: 10,
            deduplicate_queries: true,
            max_batch_size: 100,
        }
    }
}
/// Outcome of one eager-load run: the assembled data plus diagnostics.
#[derive(Debug)]
pub struct EagerLoadResult {
    /// Loaded entities keyed by root id; each value carries a `relationships` map.
    pub data: HashMap<JsonValue, JsonValue>,
    /// Timing and volume statistics for this run.
    pub stats: EagerLoadStats,
    /// Optimization strategies that were applied to the query plan.
    pub optimizations: Vec<OptimizationStrategy>,
}
/// Runtime statistics collected while executing an eager-load plan.
///
/// Every field's default is the numeric zero, so `Default` is derived rather
/// than hand-written (behavior is identical to the previous manual impl).
#[derive(Debug, Clone, Default)]
pub struct EagerLoadStats {
    /// Wall-clock time spent executing the plan, in milliseconds.
    pub execution_time_ms: u64,
    /// Number of queries issued during execution.
    pub query_count: usize,
    /// Total rows fetched across all queries.
    pub records_loaded: usize,
    /// Deepest relationship level reached by the plan.
    pub depth_loaded: usize,
    /// Fraction of lookups served from cache, in `0.0..=1.0`.
    pub cache_hit_ratio: f64,
}
/// Eager loader that plans, optimizes, and executes relationship queries in batches.
pub struct OptimizedEagerLoader {
    // Batches and caches individual record fetches.
    batch_loader: BatchLoader,
    // Rewrites query plans (parallelism, batch sizing, phase ordering).
    query_optimizer: QueryOptimizer,
    // Executes the optimized plan against the database.
    plan_executor: PlanExecutor,
    // Held but not used directly here; underscore-prefixed to silence dead-code lint.
    // NOTE(review): deduplication appears to be configured via `config.deduplicate_queries`
    // but this field is never consulted in this file — confirm where it is wired up.
    _query_deduplicator: QueryDeduplicator,
    // Active tuning configuration.
    config: EagerLoadConfig,
}
impl OptimizedEagerLoader {
pub fn new() -> Self {
let config = EagerLoadConfig::default();
let batch_loader = BatchLoader::new();
Self::with_config(config, batch_loader)
}
pub fn with_config(config: EagerLoadConfig, batch_loader: BatchLoader) -> Self {
let query_optimizer = QueryOptimizer::new();
let plan_executor = PlanExecutor::with_config(
batch_loader.clone(),
if config.enable_parallelism { 10 } else { 1 },
std::time::Duration::from_millis(config.query_timeout_ms),
);
let query_deduplicator = QueryDeduplicator::new();
Self {
batch_loader,
query_optimizer,
plan_executor,
_query_deduplicator: query_deduplicator,
config,
}
}
pub async fn load_with_relationships(
&mut self,
root_table: &str,
root_ids: Vec<JsonValue>,
relationships: &str,
connection: &sqlx::PgPool,
) -> OrmResult<EagerLoadResult> {
let start_time = std::time::Instant::now();
let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
let optimization_strategies = self.query_optimizer.optimize_plan(&mut plan)?;
let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
let execution_time = start_time.elapsed();
let stats = EagerLoadStats {
execution_time_ms: execution_time.as_millis() as u64,
query_count: execution_result.stats.query_count,
records_loaded: execution_result.stats.rows_fetched,
depth_loaded: plan.max_depth,
cache_hit_ratio: self.calculate_cache_hit_ratio().await,
};
Ok(EagerLoadResult {
data: processed_data,
stats,
optimizations: optimization_strategies,
})
}
pub async fn load_with_strategy(
&mut self,
root_table: &str,
root_ids: Vec<JsonValue>,
relationships: &str,
strategy: OptimizationStrategy,
connection: &sqlx::PgPool,
) -> OrmResult<EagerLoadResult> {
let mut plan = self.build_query_plan(root_table, &root_ids, relationships)?;
match strategy {
OptimizationStrategy::IncreaseParallelism => {
self.apply_parallel_optimization(&mut plan)?;
}
OptimizationStrategy::ReduceBatchSize => {
self.apply_batch_size_optimization(&mut plan)?;
}
OptimizationStrategy::ReorderPhases => {
plan.build_execution_phases()?;
}
_ => {
let _strategies = self.query_optimizer.optimize_plan(&mut plan)?;
}
}
let execution_result = self.plan_executor.execute_plan(&plan, connection).await?;
let processed_data = self.process_execution_results(execution_result.results, &root_ids)?;
let stats = EagerLoadStats {
execution_time_ms: 0, query_count: execution_result.stats.query_count,
records_loaded: execution_result.stats.rows_fetched,
depth_loaded: plan.max_depth,
cache_hit_ratio: self.calculate_cache_hit_ratio().await,
};
Ok(EagerLoadResult {
data: processed_data,
stats,
optimizations: vec![strategy],
})
}
fn build_query_plan(
&self,
root_table: &str,
root_ids: &[JsonValue],
relationships: &str,
) -> OrmResult<QueryPlan> {
let mut plan = QueryPlan::new();
let mut node_counter = 0;
let root_node_id = format!("root_{}", node_counter);
node_counter += 1;
let mut root_node = QueryNode::root(root_node_id.clone(), root_table.to_string());
root_node.set_estimated_rows(root_ids.len());
plan.add_node(root_node);
if !relationships.is_empty() {
self.build_relationship_nodes(
&mut plan,
&root_node_id,
relationships,
1, &mut node_counter,
)?;
}
plan.build_execution_phases()?;
Ok(plan)
}
fn build_relationship_nodes(
&self,
plan: &mut QueryPlan,
parent_node_id: &str,
relationships: &str,
depth: usize,
node_counter: &mut usize,
) -> OrmResult<()> {
if depth > self.config.max_depth {
return Ok(()); }
let parts: Vec<&str> = relationships.split(',').collect();
for part in parts {
let relation_chain: Vec<&str> = part.split('.').collect();
self.build_relation_chain(plan, parent_node_id, &relation_chain, depth, node_counter)?;
}
Ok(())
}
fn build_relation_chain(
&self,
plan: &mut QueryPlan,
parent_node_id: &str,
chain: &[&str],
depth: usize,
node_counter: &mut usize,
) -> OrmResult<()> {
if chain.is_empty() || depth > self.config.max_depth {
return Ok(());
}
let relation_name = chain[0];
let node_id = format!("{}_{}", relation_name, *node_counter);
*node_counter += 1;
let (table_name, relationship_type, foreign_key) =
self.get_relationship_info(relation_name)?;
let mut node = QueryNode::child(
node_id.clone(),
table_name,
parent_node_id.to_string(),
relationship_type,
foreign_key,
);
node.set_depth(depth);
node.set_estimated_rows(std::cmp::min(1000, self.config.max_batch_size));
plan.add_node(node);
if chain.len() > 1 {
self.build_relation_chain(plan, &node_id, &chain[1..], depth + 1, node_counter)?;
}
Ok(())
}
fn get_relationship_info(
&self,
relation: &str,
) -> OrmResult<(String, RelationshipType, String)> {
match relation {
"posts" => Ok((
"posts".to_string(),
RelationshipType::HasMany,
"user_id".to_string(),
)),
"comments" => Ok((
"comments".to_string(),
RelationshipType::HasMany,
"post_id".to_string(),
)),
"user" => Ok((
"users".to_string(),
RelationshipType::BelongsTo,
"user_id".to_string(),
)),
"profile" => Ok((
"profiles".to_string(),
RelationshipType::HasOne,
"user_id".to_string(),
)),
_ => {
Ok((
format!("{}s", relation),
RelationshipType::HasMany,
format!("{}_id", relation),
))
}
}
}
fn process_execution_results(
&self,
results: HashMap<String, Vec<JsonValue>>,
root_ids: &[JsonValue],
) -> OrmResult<HashMap<JsonValue, JsonValue>> {
let mut processed = HashMap::new();
for root_id in root_ids.iter() {
let mut entity_data = serde_json::json!({
"id": root_id,
"relationships": {}
});
for (node_id, node_results) in &results {
if node_id.starts_with("root_") {
continue; }
if let Some(obj) = entity_data.as_object_mut() {
if let Some(relationships) =
obj.get_mut("relationships").and_then(|r| r.as_object_mut())
{
relationships.insert(node_id.clone(), serde_json::json!(node_results));
}
}
}
processed.insert(root_id.clone(), entity_data);
}
Ok(processed)
}
fn apply_parallel_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
for node in plan.nodes.values_mut() {
if node.constraints.is_empty() {
node.set_parallel_safe(true);
}
}
plan.build_execution_phases()?;
Ok(())
}
fn apply_batch_size_optimization(&self, plan: &mut QueryPlan) -> OrmResult<()> {
for node in plan.nodes.values_mut() {
if node.estimated_rows > 5000 {
node.set_estimated_rows(node.estimated_rows / 2);
}
}
Ok(())
}
async fn calculate_cache_hit_ratio(&self) -> f64 {
let stats = self.batch_loader.cache_stats().await;
if stats.total_cached_records > 0 {
0.75 } else {
0.0
}
}
pub fn config(&self) -> &EagerLoadConfig {
&self.config
}
pub fn update_config(&mut self, config: EagerLoadConfig) {
self.config = config;
}
pub async fn clear_caches(&self) {
self.batch_loader.clear_cache().await;
}
}
impl Default for OptimizedEagerLoader {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Defaults must match the documented values.
    #[test]
    fn test_eager_load_config_default() {
        let cfg = EagerLoadConfig::default();
        assert_eq!(cfg.max_batch_size, 100);
        assert_eq!(cfg.max_depth, 10);
        assert!(cfg.deduplicate_queries);
        assert!(cfg.enable_parallelism);
    }

    // A two-segment chain ("posts.comments") yields one root and depth 2.
    #[test]
    fn test_build_query_plan() {
        let loader = OptimizedEagerLoader::new();
        let ids = vec![json!(1), json!(2)];
        let plan = loader
            .build_query_plan("users", &ids, "posts.comments")
            .unwrap();
        assert_eq!(plan.roots.len(), 1);
        assert!(!plan.nodes.is_empty());
        assert_eq!(plan.max_depth, 2);
    }

    // The hard-coded "posts" mapping resolves to a has-many keyed by user_id.
    #[test]
    fn test_relationship_info_mapping() {
        let loader = OptimizedEagerLoader::new();
        let (table, kind, key) = loader.get_relationship_info("posts").unwrap();
        assert_eq!(table, "posts");
        assert_eq!(kind, RelationshipType::HasMany);
        assert_eq!(key, "user_id");
    }
}