// elif_orm/loading/optimizer/executor.rs
use super::plan::{QueryNode, QueryPlan};
2use crate::{
3 error::{OrmError, OrmResult},
4 loading::batch_loader::BatchLoader,
5};
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9use tokio::task::JoinHandle;
10
/// Output of `PlanExecutor::execute_plan`: per-node rows plus run metrics.
#[derive(Debug)]
pub struct ExecutionResult {
    // Rows fetched per plan node, keyed by node id, JSON-encoded.
    pub results: HashMap<String, Vec<JsonValue>>,
    // Timing/volume metrics gathered while the plan ran.
    pub stats: ExecutionStats,
    // Per-node failures; execution continues past individual node errors.
    pub errors: Vec<OrmError>,
}
21
/// Aggregate timing and volume metrics collected while executing a query plan.
#[derive(Debug, Clone)]
pub struct ExecutionStats {
    // Wall-clock time for the whole plan execution.
    pub total_duration: Duration,
    // Duration of each execution phase, in order.
    pub phase_durations: Vec<Duration>,
    // Number of queries that completed successfully.
    pub query_count: usize,
    // Total rows returned across all queries.
    pub rows_fetched: usize,
    // Number of phases that ran more than one query concurrently.
    pub parallel_phases: usize,
    // `total_duration / query_count`; filled in by `calculate_averages`.
    pub avg_query_time: Duration,
    // Peak memory in MB, if measured (not populated by this module).
    pub peak_memory_mb: Option<f64>,
}

impl ExecutionStats {
    /// Creates a zeroed stats record.
    pub fn new() -> Self {
        Self {
            total_duration: Duration::from_secs(0),
            phase_durations: Vec::new(),
            query_count: 0,
            rows_fetched: 0,
            parallel_phases: 0,
            avg_query_time: Duration::from_secs(0),
            peak_memory_mb: None,
        }
    }

    /// Recomputes `avg_query_time` from `total_duration` and `query_count`.
    ///
    /// Uses a checked conversion: the previous `query_count as u32` cast
    /// silently truncated, so a count of exactly 2^32 became 0 and
    /// `Duration / 0` panicked. Counts that do not fit in `u32` now leave
    /// the average unchanged instead.
    pub fn calculate_averages(&mut self) {
        match u32::try_from(self.query_count) {
            Ok(count) if count > 0 => {
                self.avg_query_time = self.total_duration / count;
            }
            _ => {}
        }
    }

    /// Records one phase's duration and folds it into the running total.
    ///
    /// NOTE: `execute_plan` later overwrites `total_duration` with an
    /// end-to-end wall-clock measurement, so the running sum here is a
    /// provisional value.
    pub fn add_phase_duration(&mut self, duration: Duration) {
        self.phase_durations.push(duration);
        self.total_duration += duration;
    }
}

impl Default for ExecutionStats {
    fn default() -> Self {
        Self::new()
    }
}
73
/// Executes a `QueryPlan` phase by phase, running independent nodes in parallel.
pub struct PlanExecutor {
    // Retained for future batched loading; currently unused (hence the underscore).
    _batch_loader: BatchLoader,
    // Upper bound on concurrently spawned query tasks per phase chunk.
    max_parallel_tasks: usize,
    // Per-query timeout; see `execute_node_query`.
    query_timeout: Duration,
}
83
84impl PlanExecutor {
85 pub fn new(batch_loader: BatchLoader) -> Self {
87 Self {
88 _batch_loader: batch_loader,
89 max_parallel_tasks: 10, query_timeout: Duration::from_secs(30),
91 }
92 }
93
94 pub fn with_config(
96 batch_loader: BatchLoader,
97 max_parallel_tasks: usize,
98 query_timeout: Duration,
99 ) -> Self {
100 Self {
101 _batch_loader: batch_loader,
102 max_parallel_tasks,
103 query_timeout,
104 }
105 }
106
107 pub async fn execute_plan(
109 &self,
110 plan: &QueryPlan,
111 connection: &sqlx::PgPool,
112 ) -> OrmResult<ExecutionResult> {
113 let start_time = Instant::now();
114 let mut results: HashMap<String, Vec<JsonValue>> = HashMap::new();
115 let mut stats = ExecutionStats::new();
116 let mut errors = Vec::new();
117
118 for phase in plan.execution_phases.iter() {
120 let phase_start = Instant::now();
121
122 if phase.len() == 1 {
123 let node_id = &phase[0];
125 if let Some(node) = plan.nodes.get(node_id) {
126 match self.execute_node_query(node, connection).await {
127 Ok(node_results) => {
128 stats.query_count += 1;
129 stats.rows_fetched += node_results.len();
130 results.insert(node_id.clone(), node_results);
131 }
132 Err(e) => errors.push(e),
133 }
134 }
135 } else {
136 stats.parallel_phases += 1;
138 let parallel_results = self.execute_phase_parallel(phase, plan, connection).await;
139
140 for (node_id, result) in parallel_results {
141 match result {
142 Ok(node_results) => {
143 stats.query_count += 1;
144 stats.rows_fetched += node_results.len();
145 results.insert(node_id, node_results);
146 }
147 Err(e) => errors.push(e),
148 }
149 }
150 }
151
152 let phase_duration = phase_start.elapsed();
153 stats.add_phase_duration(phase_duration);
154 }
155
156 stats.total_duration = start_time.elapsed();
157 stats.calculate_averages();
158
159 Ok(ExecutionResult {
160 results,
161 stats,
162 errors,
163 })
164 }
165
166 async fn execute_phase_parallel(
168 &self,
169 phase: &[String],
170 plan: &QueryPlan,
171 connection: &sqlx::PgPool,
172 ) -> HashMap<String, OrmResult<Vec<JsonValue>>> {
173 let _handles: Vec<JoinHandle<(String, OrmResult<Vec<JsonValue>>)>> = Vec::new();
174 let mut results = HashMap::new();
175
176 let chunks: Vec<_> = phase.chunks(self.max_parallel_tasks).collect();
178
179 for chunk in chunks {
180 let mut chunk_handles = Vec::new();
181
182 for node_id in chunk {
183 if let Some(node) = plan.nodes.get(node_id) {
184 let node_clone = node.clone();
185 let node_id_clone = node_id.clone();
186 let connection_clone = connection.clone();
187
188 let handle = tokio::spawn(async move {
189 let result =
190 Self::execute_node_query_static(&node_clone, &connection_clone).await;
191 (node_id_clone, result)
192 });
193
194 chunk_handles.push(handle);
195 }
196 }
197
198 for handle in chunk_handles {
200 match handle.await {
201 Ok((node_id, result)) => {
202 results.insert(node_id, result);
203 }
204 Err(e) => {
205 eprintln!("Task join error: {}", e);
206 }
207 }
208 }
209 }
210
211 results
212 }
213
214 async fn execute_node_query(
216 &self,
217 node: &QueryNode,
218 connection: &sqlx::PgPool,
219 ) -> OrmResult<Vec<JsonValue>> {
220 let query_future = self.execute_node_query_impl(node, connection);
222
223 match tokio::time::timeout(self.query_timeout, query_future).await {
224 Ok(result) => result,
225 Err(_) => Err(OrmError::Query(format!(
226 "Query timeout for node '{}' after {:?}",
227 node.id, self.query_timeout
228 ))),
229 }
230 }
231
232 async fn execute_node_query_impl(
234 &self,
235 node: &QueryNode,
236 connection: &sqlx::PgPool,
237 ) -> OrmResult<Vec<JsonValue>> {
238 if node.is_root() {
242 self.execute_root_query(node, connection).await
244 } else {
245 self.execute_relationship_query(node, connection).await
247 }
248 }
249
250 async fn execute_root_query(
252 &self,
253 node: &QueryNode,
254 connection: &sqlx::PgPool,
255 ) -> OrmResult<Vec<JsonValue>> {
256 use crate::query::QueryBuilder;
257
258 let mut query = QueryBuilder::<()>::new().from(&node.table);
260
261 for constraint in &node.constraints {
263 query = query.where_raw(constraint);
264 }
265
266 let limit = std::cmp::min(node.estimated_rows, 1000);
268 query = query.limit(limit as i64);
269
270 let (sql, _params) = query.to_sql_with_params();
272 let db_query = sqlx::query(&sql);
273
274 let rows = db_query
275 .fetch_all(connection)
276 .await
277 .map_err(|e| OrmError::Database(e.to_string()))?;
278
279 let results: Result<Vec<JsonValue>, OrmError> = rows
281 .into_iter()
282 .map(|row| {
283 crate::loading::batch_loader::row_conversion::convert_row_to_json(&row)
284 .map_err(|e| OrmError::Serialization(e.to_string()))
285 })
286 .collect();
287
288 results
289 }
290
291 async fn execute_relationship_query(
293 &self,
294 _node: &QueryNode,
295 _connection: &sqlx::PgPool,
296 ) -> OrmResult<Vec<JsonValue>> {
297 Ok(Vec::new())
306 }
307
308 async fn execute_node_query_static(
310 node: &QueryNode,
311 connection: &sqlx::PgPool,
312 ) -> OrmResult<Vec<JsonValue>> {
313 if node.is_root() {
315 Self::execute_root_query_static(node, connection).await
316 } else {
317 Self::execute_relationship_query_static(node, connection).await
318 }
319 }
320
321 async fn execute_root_query_static(
323 node: &QueryNode,
324 connection: &sqlx::PgPool,
325 ) -> OrmResult<Vec<JsonValue>> {
326 use crate::query::QueryBuilder;
327
328 let mut query = QueryBuilder::<()>::new().from(&node.table);
330
331 for constraint in &node.constraints {
333 query = query.where_raw(constraint);
334 }
335
336 let limit = std::cmp::min(node.estimated_rows, 1000);
338 query = query.limit(limit as i64);
339
340 let (sql, _params) = query.to_sql_with_params();
342 let db_query = sqlx::query(&sql);
343
344 let rows = db_query
345 .fetch_all(connection)
346 .await
347 .map_err(|e| OrmError::Database(e.to_string()))?;
348
349 let results: Result<Vec<JsonValue>, OrmError> = rows
351 .into_iter()
352 .map(|row| {
353 crate::loading::batch_loader::row_conversion::convert_row_to_json(&row)
354 .map_err(|e| OrmError::Serialization(e.to_string()))
355 })
356 .collect();
357
358 results
359 }
360
361 async fn execute_relationship_query_static(
363 _node: &QueryNode,
364 _connection: &sqlx::PgPool,
365 ) -> OrmResult<Vec<JsonValue>> {
366 Ok(Vec::new())
369 }
370
371 pub fn get_stats(&self) -> ExecutorStats {
373 ExecutorStats {
374 max_parallel_tasks: self.max_parallel_tasks,
375 query_timeout: self.query_timeout,
376 }
377 }
378
379 pub fn set_max_parallel_tasks(&mut self, max_tasks: usize) {
381 self.max_parallel_tasks = max_tasks;
382 }
383
384 pub fn set_query_timeout(&mut self, timeout: Duration) {
385 self.query_timeout = timeout;
386 }
387}
388
/// Snapshot of a `PlanExecutor`'s configuration, returned by `get_stats`.
#[derive(Debug, Clone)]
pub struct ExecutorStats {
    // Upper bound on concurrently spawned query tasks per phase chunk.
    pub max_parallel_tasks: usize,
    // Per-query timeout applied when executing plan nodes.
    pub query_timeout: Duration,
}