// elif_orm/loading/optimizer/executor.rs

use super::plan::{QueryNode, QueryPlan};
use crate::{
    error::{OrmError, OrmResult},
    loading::batch_loader::BatchLoader,
};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
10
11#[derive(Debug)]
13pub struct ExecutionResult {
14 pub results: HashMap<String, Vec<JsonValue>>,
16 pub stats: ExecutionStats,
18 pub errors: Vec<OrmError>,
20}
21
/// Aggregate timing and volume statistics for one plan execution.
#[derive(Debug, Clone)]
pub struct ExecutionStats {
    /// Wall-clock time for the whole plan run.
    pub total_duration: Duration,
    /// Wall-clock time of each execution phase, in order.
    pub phase_durations: Vec<Duration>,
    /// Number of node queries that completed successfully.
    pub query_count: usize,
    /// Total rows fetched across all successful queries.
    pub rows_fetched: usize,
    /// Number of phases that ran more than one query concurrently.
    pub parallel_phases: usize,
    /// `total_duration / query_count`; see [`Self::calculate_averages`].
    pub avg_query_time: Duration,
    /// Peak memory in MB, when available (never set in this module).
    pub peak_memory_mb: Option<f64>,
}

impl ExecutionStats {
    /// Creates a zeroed statistics record.
    pub fn new() -> Self {
        Self {
            total_duration: Duration::ZERO,
            phase_durations: Vec::new(),
            query_count: 0,
            rows_fetched: 0,
            parallel_phases: 0,
            avg_query_time: Duration::ZERO,
            peak_memory_mb: None,
        }
    }

    /// Recomputes `avg_query_time` from the current totals.
    ///
    /// Uses 128-bit nanosecond arithmetic: the previous
    /// `total_duration / self.query_count as u32` silently truncated the
    /// count on 64-bit targets and could even truncate to zero (a
    /// division-by-zero panic) despite the `> 0` guard.
    pub fn calculate_averages(&mut self) {
        if self.query_count > 0 {
            let avg_nanos = self.total_duration.as_nanos() / self.query_count as u128;
            // The average never exceeds the total; saturate in the
            // (practically unreachable) case it still overflows u64 nanos.
            self.avg_query_time =
                Duration::from_nanos(u64::try_from(avg_nanos).unwrap_or(u64::MAX));
        }
    }

    /// Records one phase's duration and folds it into `total_duration`.
    pub fn add_phase_duration(&mut self, duration: Duration) {
        self.phase_durations.push(duration);
        self.total_duration += duration;
    }
}

impl Default for ExecutionStats {
    fn default() -> Self {
        Self::new()
    }
}
73
74pub struct PlanExecutor {
76 batch_loader: BatchLoader,
78 max_parallel_tasks: usize,
80 query_timeout: Duration,
82}
83
84impl PlanExecutor {
85 pub fn new(batch_loader: BatchLoader) -> Self {
87 Self {
88 batch_loader,
89 max_parallel_tasks: 10, query_timeout: Duration::from_secs(30),
91 }
92 }
93
94 pub fn with_config(
96 batch_loader: BatchLoader,
97 max_parallel_tasks: usize,
98 query_timeout: Duration,
99 ) -> Self {
100 Self {
101 batch_loader,
102 max_parallel_tasks,
103 query_timeout,
104 }
105 }
106
107 pub async fn execute_plan(
109 &self,
110 plan: &QueryPlan,
111 connection: &sqlx::PgPool,
112 ) -> OrmResult<ExecutionResult> {
113 let start_time = Instant::now();
114 let mut results: HashMap<String, Vec<JsonValue>> = HashMap::new();
115 let mut stats = ExecutionStats::new();
116 let mut errors = Vec::new();
117
118 for (phase_index, phase) in plan.execution_phases.iter().enumerate() {
120 let phase_start = Instant::now();
121
122 if phase.len() == 1 {
123 let node_id = &phase[0];
125 if let Some(node) = plan.nodes.get(node_id) {
126 match self.execute_node_query(node, connection).await {
127 Ok(node_results) => {
128 stats.query_count += 1;
129 stats.rows_fetched += node_results.len();
130 results.insert(node_id.clone(), node_results);
131 }
132 Err(e) => errors.push(e),
133 }
134 }
135 } else {
136 stats.parallel_phases += 1;
138 let parallel_results = self.execute_phase_parallel(phase, plan, connection).await;
139
140 for (node_id, result) in parallel_results {
141 match result {
142 Ok(node_results) => {
143 stats.query_count += 1;
144 stats.rows_fetched += node_results.len();
145 results.insert(node_id, node_results);
146 }
147 Err(e) => errors.push(e),
148 }
149 }
150 }
151
152 let phase_duration = phase_start.elapsed();
153 stats.add_phase_duration(phase_duration);
154 }
155
156 stats.total_duration = start_time.elapsed();
157 stats.calculate_averages();
158
159 Ok(ExecutionResult {
160 results,
161 stats,
162 errors,
163 })
164 }
165
166 async fn execute_phase_parallel(
168 &self,
169 phase: &[String],
170 plan: &QueryPlan,
171 connection: &sqlx::PgPool,
172 ) -> HashMap<String, OrmResult<Vec<JsonValue>>> {
173 let mut handles: Vec<JoinHandle<(String, OrmResult<Vec<JsonValue>>)>> = Vec::new();
174 let mut results = HashMap::new();
175
176 let chunks: Vec<_> = phase.chunks(self.max_parallel_tasks).collect();
178
179 for chunk in chunks {
180 let mut chunk_handles = Vec::new();
181
182 for node_id in chunk {
183 if let Some(node) = plan.nodes.get(node_id) {
184 let node_clone = node.clone();
185 let node_id_clone = node_id.clone();
186 let connection_clone = connection.clone();
187
188 let handle = tokio::spawn(async move {
189 let result = Self::execute_node_query_static(&node_clone, &connection_clone).await;
190 (node_id_clone, result)
191 });
192
193 chunk_handles.push(handle);
194 }
195 }
196
197 for handle in chunk_handles {
199 match handle.await {
200 Ok((node_id, result)) => {
201 results.insert(node_id, result);
202 }
203 Err(e) => {
204 eprintln!("Task join error: {}", e);
205 }
206 }
207 }
208 }
209
210 results
211 }
212
213 async fn execute_node_query(
215 &self,
216 node: &QueryNode,
217 connection: &sqlx::PgPool,
218 ) -> OrmResult<Vec<JsonValue>> {
219 let query_future = self.execute_node_query_impl(node, connection);
221
222 match tokio::time::timeout(self.query_timeout, query_future).await {
223 Ok(result) => result,
224 Err(_) => Err(OrmError::Query(format!(
225 "Query timeout for node '{}' after {:?}",
226 node.id, self.query_timeout
227 ))),
228 }
229 }
230
231 async fn execute_node_query_impl(
233 &self,
234 node: &QueryNode,
235 connection: &sqlx::PgPool,
236 ) -> OrmResult<Vec<JsonValue>> {
237 if node.is_root() {
241 self.execute_root_query(node, connection).await
243 } else {
244 self.execute_relationship_query(node, connection).await
246 }
247 }
248
249 async fn execute_root_query(
251 &self,
252 node: &QueryNode,
253 connection: &sqlx::PgPool,
254 ) -> OrmResult<Vec<JsonValue>> {
255 use crate::query::QueryBuilder;
256
257 let mut query = QueryBuilder::<()>::new().from(&node.table);
259
260 for constraint in &node.constraints {
262 query = query.where_raw(constraint);
263 }
264
265 let limit = std::cmp::min(node.estimated_rows, 1000);
267 query = query.limit(limit as i64);
268
269 let (sql, _params) = query.to_sql_with_params();
271 let db_query = sqlx::query(&sql);
272
273 let rows = db_query.fetch_all(connection).await
274 .map_err(|e| OrmError::Database(e.to_string()))?;
275
276 let results: Result<Vec<JsonValue>, OrmError> = rows.into_iter()
278 .map(|row| {
279 crate::loading::batch_loader::row_conversion::convert_row_to_json(&row)
280 .map_err(|e| OrmError::Serialization(e.to_string()))
281 })
282 .collect();
283
284 results
285 }
286
287 async fn execute_relationship_query(
289 &self,
290 node: &QueryNode,
291 _connection: &sqlx::PgPool,
292 ) -> OrmResult<Vec<JsonValue>> {
293 Ok(Vec::new())
302 }
303
304 async fn execute_node_query_static(
306 node: &QueryNode,
307 connection: &sqlx::PgPool,
308 ) -> OrmResult<Vec<JsonValue>> {
309 if node.is_root() {
311 Self::execute_root_query_static(node, connection).await
312 } else {
313 Self::execute_relationship_query_static(node, connection).await
314 }
315 }
316
317 async fn execute_root_query_static(
319 node: &QueryNode,
320 connection: &sqlx::PgPool,
321 ) -> OrmResult<Vec<JsonValue>> {
322 use crate::query::QueryBuilder;
323
324 let mut query = QueryBuilder::<()>::new().from(&node.table);
326
327 for constraint in &node.constraints {
329 query = query.where_raw(constraint);
330 }
331
332 let limit = std::cmp::min(node.estimated_rows, 1000);
334 query = query.limit(limit as i64);
335
336 let (sql, _params) = query.to_sql_with_params();
338 let db_query = sqlx::query(&sql);
339
340 let rows = db_query.fetch_all(connection).await
341 .map_err(|e| OrmError::Database(e.to_string()))?;
342
343 let results: Result<Vec<JsonValue>, OrmError> = rows.into_iter()
345 .map(|row| {
346 crate::loading::batch_loader::row_conversion::convert_row_to_json(&row)
347 .map_err(|e| OrmError::Serialization(e.to_string()))
348 })
349 .collect();
350
351 results
352 }
353
354 async fn execute_relationship_query_static(
356 node: &QueryNode,
357 _connection: &sqlx::PgPool,
358 ) -> OrmResult<Vec<JsonValue>> {
359 Ok(Vec::new())
362 }
363
364 pub fn get_stats(&self) -> ExecutorStats {
366 ExecutorStats {
367 max_parallel_tasks: self.max_parallel_tasks,
368 query_timeout: self.query_timeout,
369 }
370 }
371
372 pub fn set_max_parallel_tasks(&mut self, max_tasks: usize) {
374 self.max_parallel_tasks = max_tasks;
375 }
376
377 pub fn set_query_timeout(&mut self, timeout: Duration) {
378 self.query_timeout = timeout;
379 }
380}
381
/// Point-in-time view of a [`PlanExecutor`]'s configuration.
#[derive(Debug, Clone)]
pub struct ExecutorStats {
    /// Concurrency limit applied within a single phase.
    pub max_parallel_tasks: usize,
    /// Timeout applied to each node query.
    pub query_timeout: Duration,
}