// apiary_query/lib.rs
1//! DataFusion-based SQL query engine for Apiary.
2//!
3//! [`ApiaryQueryContext`] wraps a DataFusion `SessionContext` and resolves
4//! Apiary frame references (hive.box.frame) to in-memory tables built from
5//! the frame's active Parquet cells.  Custom SQL commands (USE, SHOW,
6//! DESCRIBE) are intercepted before they reach DataFusion.
7
8pub mod distributed;
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use arrow::array::StringArray;
14use arrow::datatypes::{DataType, Field, Schema};
15use arrow::record_batch::RecordBatch;
16use datafusion::prelude::*;
17use tracing::{info, warn};
18
19use apiary_core::error::ApiaryError;
20use apiary_core::registry_manager::RegistryManager;
21use apiary_core::storage::StorageBackend;
22use apiary_core::types::NodeId;
23use apiary_core::Result;
24use apiary_storage::cell_reader::CellReader;
25use apiary_storage::ledger::Ledger;
26
/// The Apiary query context — wraps DataFusion with Apiary namespace resolution.
pub struct ApiaryQueryContext {
    /// Backend used to open ledgers and read Parquet cells.
    storage: Arc<dyn StorageBackend>,
    /// Registry used to resolve hives, boxes, and frames.
    registry: Arc<RegistryManager>,
    /// Hive selected via `USE HIVE`, if any (stored lowercased).
    current_hive: Option<String>,
    /// Box selected via `USE BOX`, if any (stored lowercased).
    current_box: Option<String>,
    /// Identity of this node, compared against task assignments in
    /// distributed execution.
    #[allow(dead_code)] // Will be used for distributed execution
    node_id: NodeId,
}
36
37impl ApiaryQueryContext {
38    /// Create a new query context.
39    pub fn new(storage: Arc<dyn StorageBackend>, registry: Arc<RegistryManager>) -> Self {
40        Self::with_node_id(storage, registry, NodeId::from("local"))
41    }
42
43    /// Create a new query context with a specific node ID.
44    pub fn with_node_id(
45        storage: Arc<dyn StorageBackend>,
46        registry: Arc<RegistryManager>,
47        node_id: NodeId,
48    ) -> Self {
49        Self {
50            storage,
51            registry,
52            current_hive: None,
53            current_box: None,
54            node_id,
55        }
56    }
57
58    /// Execute a SQL query and return results as RecordBatches.
59    pub async fn sql(&mut self, query: &str) -> Result<Vec<RecordBatch>> {
60        let trimmed = query.trim();
61
62        // Detect and block unsupported DML
63        if let Some(err) = check_unsupported_dml(trimmed) {
64            return Err(err);
65        }
66
67        // Handle custom commands
68        if let Some(result) = self.handle_custom_command(trimmed).await? {
69            return Ok(result);
70        }
71
72        // Standard SQL: resolve frame references, register tables, execute
73        self.execute_standard_sql(trimmed).await
74    }
75
    /// Handle custom SQL commands (USE, SHOW, DESCRIBE).
    ///
    /// Returns `Ok(Some(batches))` when the statement was one of the custom
    /// commands, `Ok(None)` when it is not recognized and should fall through
    /// to standard SQL execution, and `Err` for a recognized command with
    /// invalid arguments or missing session context.
    async fn handle_custom_command(&mut self, sql: &str) -> Result<Option<Vec<RecordBatch>>> {
        // All matching happens on an uppercased, `;`-stripped copy; names
        // extracted from it are lowercased before registry comparisons.
        let upper = sql.to_uppercase();
        let upper = upper.trim_end_matches(';').trim();

        // USE HIVE <name>
        if let Some(name) = upper.strip_prefix("USE HIVE ") {
            let name = name.trim().to_lowercase();
            // Verify hive exists
            let hives = self.registry.list_hives().await?;
            if !hives.iter().any(|h| h.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Hive".into(),
                    name: name.clone(),
                });
            }
            self.current_hive = Some(name.clone());
            let batch = single_message_batch(&format!("Current hive set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        // USE BOX <name> — requires a current hive so the box can be verified.
        if let Some(name) = upper.strip_prefix("USE BOX ") {
            let name = name.trim().to_lowercase();
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message: "No hive selected. Run USE HIVE <name> first.".into(),
                })?;
            // Verify box exists
            let boxes = self.registry.list_boxes(hive).await?;
            if !boxes.iter().any(|b| b.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Box".into(),
                    name: name.clone(),
                });
            }
            self.current_box = Some(name.clone());
            let batch = single_message_batch(&format!("Current box set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        // SHOW HIVES
        if upper == "SHOW HIVES" {
            let hives = self.registry.list_hives().await?;
            let batch = string_list_batch("hive", &hives);
            return Ok(Some(vec![batch]));
        }

        // SHOW BOXES IN <hive> — explicit hive, no session context needed.
        if let Some(rest) = upper.strip_prefix("SHOW BOXES IN ") {
            let hive = rest.trim().to_lowercase();
            let boxes = self.registry.list_boxes(&hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        // SHOW BOXES (using current hive context)
        if upper == "SHOW BOXES" {
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message:
                        "No hive selected. Run USE HIVE <name> first, or use SHOW BOXES IN <hive>."
                            .into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        // SHOW FRAMES IN <hive>.<box>
        if let Some(rest) = upper.strip_prefix("SHOW FRAMES IN ") {
            let parts: Vec<&str> = rest.trim().split('.').collect();
            if parts.len() != 2 {
                return Err(ApiaryError::Config {
                    message: "SHOW FRAMES IN requires hive.box format".into(),
                });
            }
            let hive = parts[0].trim().to_lowercase();
            let box_name = parts[1].trim().to_lowercase();
            let frames = self.registry.list_frames(&hive, &box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        // SHOW FRAMES (using current hive and box context)
        if upper == "SHOW FRAMES" {
            let hive = self.current_hive.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No hive selected. Run USE HIVE <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let box_name = self.current_box.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No box selected. Run USE BOX <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let frames = self.registry.list_frames(hive, box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        // DESCRIBE <hive>.<box>.<frame>
        if let Some(rest) = upper.strip_prefix("DESCRIBE ") {
            // Re-slice the ORIGINAL string so identifier case is preserved
            // (names are passed through to the registry as written).
            // NOTE(review): this offset arithmetic assumes `to_uppercase()`
            // preserved byte length — true for ASCII, not for every Unicode
            // scalar (e.g. `ß` → `SS`). Confirm identifiers are ASCII-only.
            let raw_rest = sql.trim_end_matches(';').trim();
            let raw_rest = &raw_rest[raw_rest.len() - rest.len()..];
            let parts: Vec<&str> = raw_rest.trim().split('.').collect();
            if parts.len() != 3 {
                return Err(ApiaryError::Config {
                    message: "DESCRIBE requires hive.box.frame format".into(),
                });
            }
            let hive = parts[0].trim();
            let box_name = parts[1].trim();
            let frame_name = parts[2].trim();
            return Ok(Some(vec![
                self.describe_frame(hive, box_name, frame_name).await?,
            ]));
        }

        // Not a custom command — caller should run it as standard SQL.
        Ok(None)
    }
197
198    /// Produce a DESCRIBE result for a frame.
199    async fn describe_frame(
200        &self,
201        hive: &str,
202        box_name: &str,
203        frame_name: &str,
204    ) -> Result<RecordBatch> {
205        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
206        let frame_path = format!("{hive}/{box_name}/{frame_name}");
207
208        // Get cell count and total size from ledger
209        let (cell_count, total_rows, total_bytes) =
210            match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
211                Ok(ledger) => {
212                    let cells = ledger.active_cells();
213                    let rows: u64 = cells.iter().map(|c| c.rows).sum();
214                    let bytes: u64 = cells.iter().map(|c| c.bytes).sum();
215                    (cells.len() as u64, rows, bytes)
216                }
217                Err(_) => (0, 0, 0),
218            };
219
220        let schema_json = serde_json::to_string(&frame.schema).unwrap_or_else(|_| "{}".into());
221
222        let schema = Arc::new(Schema::new(vec![
223            Field::new("property", DataType::Utf8, false),
224            Field::new("value", DataType::Utf8, false),
225        ]));
226
227        let partition_str = if frame.partition_by.is_empty() {
228            "(none)".to_string()
229        } else {
230            frame.partition_by.join(", ")
231        };
232
233        let properties = vec![
234            "schema",
235            "partition_by",
236            "cells",
237            "total_rows",
238            "total_bytes",
239        ];
240        let values = vec![
241            schema_json,
242            partition_str,
243            cell_count.to_string(),
244            total_rows.to_string(),
245            total_bytes.to_string(),
246        ];
247
248        RecordBatch::try_new(
249            schema,
250            vec![
251                Arc::new(StringArray::from(
252                    properties
253                        .into_iter()
254                        .map(|s| s.to_string())
255                        .collect::<Vec<_>>(),
256                )),
257                Arc::new(StringArray::from(values)),
258            ],
259        )
260        .map_err(|e| ApiaryError::Internal {
261            message: format!("Failed to create DESCRIBE result: {e}"),
262        })
263    }
264
    /// Execute standard SQL by resolving frame references and delegating to DataFusion.
    ///
    /// Each referenced frame is materialized as an in-memory DataFusion table
    /// built from its active (and optionally pruned) Parquet cells; the SQL is
    /// then rewritten to the registered table names and executed on a fresh
    /// session.
    async fn execute_standard_sql(&self, sql: &str) -> Result<Vec<RecordBatch>> {
        // Extract table references from SQL
        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Extract simple WHERE predicates for pruning
        let predicates = extract_where_predicates(sql);

        // Create a fresh session for this query (avoids stale table registrations)
        let session = SessionContext::new();

        // Resolve and register each table
        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            // Open ledger and prune cells
            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            // Build partition and stat filters from predicates
            let partition_by: Vec<String> = ledger.partition_by().to_vec();
            let (partition_filters, stat_filters) = build_filters(&predicates, &partition_by);

            // Skip the pruning pass entirely when no usable filter applies.
            let cells = if partition_filters.is_empty() && stat_filters.is_empty() {
                ledger.active_cells().iter().collect::<Vec<_>>()
            } else {
                ledger.prune_cells(&partition_filters, &stat_filters)
            };

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                surviving_cells = cells.len(),
                "Cell pruning complete"
            );

            if cells.is_empty() {
                // Register an empty table with the correct schema so the
                // query still validates and simply returns zero rows.
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            // Read surviving cells
            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            // `None` from the reader means the cells held no rows; fall back
            // to an empty batch carrying the ledger schema.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Rewrite the SQL to use the registered table names
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        // Execute via DataFusion
        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
369
370    /// Resolve a table reference to (hive, box, frame, register_name).
371    fn resolve_table_ref(&self, table_ref: &str) -> Result<(String, String, String, String)> {
372        let parts: Vec<&str> = table_ref.split('.').collect();
373
374        match parts.len() {
375            3 => {
376                let hive = parts[0].to_string();
377                let box_name = parts[1].to_string();
378                let frame_name = parts[2].to_string();
379                // Register with just the frame name to simplify SQL rewriting
380                let register_name = frame_name.clone();
381                Ok((hive, box_name, frame_name, register_name))
382            }
383            2 => {
384                let hive = self.current_hive.as_ref().ok_or_else(|| {
385                    ApiaryError::Resolution {
386                        path: table_ref.into(),
387                        reason: "No hive selected. Use 3-part name (hive.box.frame) or run USE HIVE first.".into(),
388                    }
389                })?;
390                let box_name = parts[0].to_string();
391                let frame_name = parts[1].to_string();
392                let register_name = frame_name.clone();
393                Ok((hive.clone(), box_name, frame_name, register_name))
394            }
395            1 => {
396                let hive = self
397                    .current_hive
398                    .as_ref()
399                    .ok_or_else(|| ApiaryError::Resolution {
400                        path: table_ref.into(),
401                        reason: "No hive selected. Use 3-part name or run USE HIVE first.".into(),
402                    })?;
403                let box_name =
404                    self.current_box
405                        .as_ref()
406                        .ok_or_else(|| ApiaryError::Resolution {
407                            path: table_ref.into(),
408                            reason: "No box selected. Use 3-part name or run USE BOX first.".into(),
409                        })?;
410                let frame_name = parts[0].to_string();
411                let register_name = frame_name.clone();
412                Ok((hive.clone(), box_name.clone(), frame_name, register_name))
413            }
414            _ => Err(ApiaryError::Resolution {
415                path: table_ref.into(),
416                reason: "Invalid table reference. Use hive.box.frame format.".into(),
417            }),
418        }
419    }
420
    /// Execute a query using distributed execution (stub for Step 7).
    ///
    /// For v1, this is a simplified implementation that:
    /// - Creates tasks from cell assignments
    /// - Generates SQL fragments (simple pass-through)
    /// - Writes the query manifest
    /// - Executes local tasks
    /// - Polls for partial results from other nodes
    /// - Merges results (simple concatenation)
    /// - Cleans up query files
    pub async fn execute_distributed(
        &self,
        sql: &str,
        assignments: HashMap<NodeId, Vec<distributed::CellInfo>>,
    ) -> Result<Vec<RecordBatch>> {
        use distributed::*;

        // 1. Create tasks from assignments — one task per node, carrying the
        // storage keys of the cells that node should scan.
        let mut tasks = Vec::new();
        for (node_id, cells) in &assignments {
            // Task IDs embed the node name for traceability plus a UUID for
            // uniqueness across retries.
            let task_id = format!("{}_{}", node_id.as_str(), uuid::Uuid::new_v4());
            let cell_keys: Vec<String> = cells.iter().map(|c| c.storage_key.clone()).collect();

            tasks.push(PlannedTask {
                task_id,
                node_id: node_id.clone(),
                cells: cell_keys,
                sql_fragment: sql.to_string(), // For v1, use original SQL
            });
        }

        // 2. Create and write manifest.  The `60` is presumably the timeout in
        // seconds (read back below as `manifest.timeout_secs`); the `None`
        // argument's meaning is not visible from here — see `create_manifest`.
        let manifest = create_manifest(sql, tasks.clone(), None, 60);
        write_manifest(&self.storage, &manifest).await?;

        info!(
            query_id = %manifest.query_id,
            tasks = tasks.len(),
            "Distributed query manifest written"
        );

        // 3. Execute local tasks.  Failures are logged and skipped so one bad
        // task does not fail the whole query.
        let local_tasks: Vec<_> = tasks.iter().filter(|t| t.node_id == self.node_id).collect();

        let mut local_results = Vec::new();
        for task in local_tasks {
            match self.execute_task(&task.sql_fragment, &task.cells).await {
                Ok(batches) => {
                    if !batches.is_empty() {
                        local_results.extend(batches);
                    }
                }
                Err(e) => {
                    warn!(task_id = %task.task_id, error = %e, "Local task failed");
                }
            }
        }

        // Write local partial result so other nodes can read it.
        if !local_results.is_empty() {
            write_partial_result(
                &self.storage,
                &manifest.query_id,
                &self.node_id,
                &local_results,
            )
            .await?;
        }

        // 4. Poll for partial results from other nodes, sequentially, all
        // sharing one wall-clock budget of `manifest.timeout_secs`.
        let remote_nodes: Vec<_> = tasks
            .iter()
            .filter(|t| t.node_id != self.node_id)
            .map(|t| t.node_id.clone())
            .collect();

        let timeout = std::time::Duration::from_secs(manifest.timeout_secs);
        let start = std::time::Instant::now();
        let mut collected_results = local_results;

        for remote_node in &remote_nodes {
            // Remaining budget for this node (saturates at zero once spent).
            let deadline = timeout.saturating_sub(start.elapsed());
            if start.elapsed() >= timeout {
                warn!(query_id = %manifest.query_id, "Query timeout reached");
                break;
            }

            // Poll every 500 ms until the result appears or the remaining
            // budget runs out.
            // NOTE(review): every read error (not just "not yet written") is
            // treated as "not ready" and retried — confirm that's intended.
            let poll_interval = std::time::Duration::from_millis(500);
            let mut attempts = 0;
            let max_attempts = (deadline.as_millis() / poll_interval.as_millis()) as usize;

            while attempts < max_attempts {
                match read_partial_result(&self.storage, &manifest.query_id, remote_node).await {
                    Ok(batches) => {
                        info!(
                            query_id = %manifest.query_id,
                            node_id = %remote_node,
                            "Partial result received"
                        );
                        collected_results.extend(batches);
                        break;
                    }
                    Err(_) => {
                        tokio::time::sleep(poll_interval).await;
                        attempts += 1;
                    }
                }
            }
        }

        // 5. Cleanup is best-effort; a failure only logs a warning.
        if let Err(e) = cleanup_query(&self.storage, &manifest.query_id).await {
            warn!(query_id = %manifest.query_id, error = %e, "Failed to cleanup query");
        }

        // Merged output is a simple concatenation of all partial results.
        Ok(collected_results)
    }
539
    /// Execute a task on a specific set of cells.
    ///
    /// # Arguments
    /// * `sql` - The SQL query to execute
    /// * `cell_keys` - Storage keys of cells to scan. Only these cells are registered
    ///   as the table for the query, enabling true work partitioning across nodes.
    pub async fn execute_task(&self, sql: &str, cell_keys: &[String]) -> Result<Vec<RecordBatch>> {
        if cell_keys.is_empty() {
            // No cells assigned — fall back to standard execution
            return self.execute_standard_sql(sql).await;
        }

        // Extract table references from SQL
        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Create a fresh session for this query
        let session = SessionContext::new();

        // Build a set for O(1) cell key lookups
        let cell_key_set: std::collections::HashSet<&String> = cell_keys.iter().collect();

        // Resolve and register each table, filtering to only the assigned cells
        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            // Filter active cells to only those in our assigned cell_keys.
            // The key is reconstructed as `<frame_path>/<cell.path>` —
            // NOTE(review): assumes `CellInfo::storage_key` on the planner
            // side uses the same format; confirm, or no cell will match.
            let cells: Vec<_> = ledger
                .active_cells()
                .iter()
                .filter(|cell| {
                    let cell_storage_key = format!("{}/{}", frame_path, cell.path);
                    cell_key_set.contains(&cell_storage_key)
                })
                .collect();

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                assigned_cells = cells.len(),
                "Cell filtering for distributed task"
            );

            if cells.is_empty() {
                // None of this frame's cells are assigned to this node:
                // register an empty table with the right schema so the SQL
                // still validates and contributes zero rows.
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            // Read only the assigned cells
            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            // `None` from the reader means the cells held no rows.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Rewrite the SQL to use the registered table names
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        // Execute via DataFusion
        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
652}
653
654// ---------------------------------------------------------------------------
655// Helper functions
656// ---------------------------------------------------------------------------
657
658/// Check for unsupported DML and return an error if detected.
659fn check_unsupported_dml(sql: &str) -> Option<ApiaryError> {
660    let upper = sql.to_uppercase();
661    let first_word = upper.split_whitespace().next().unwrap_or("");
662
663    match first_word {
664        "DELETE" => Some(ApiaryError::Unsupported {
665            message: "DELETE is not supported. Apiary uses append-only writes. Use overwrite_frame() to replace all data in a frame.".into(),
666        }),
667        "UPDATE" => Some(ApiaryError::Unsupported {
668            message: "UPDATE is not supported. Apiary uses append-only writes. Rewrite the frame with corrected data using overwrite_frame().".into(),
669        }),
670        "INSERT" => Some(ApiaryError::Unsupported {
671            message: "INSERT is not supported via SQL. Use write_to_frame() to add data.".into(),
672        }),
673        "DROP" => Some(ApiaryError::Unsupported {
674            message: "DROP is not supported via SQL. Use the registry API for DDL operations.".into(),
675        }),
676        "CREATE" => Some(ApiaryError::Unsupported {
677            message: "CREATE is not supported via SQL. Use create_frame() for DDL operations.".into(),
678        }),
679        "ALTER" => Some(ApiaryError::Unsupported {
680            message: "ALTER is not supported via SQL. Use the registry API for DDL operations.".into(),
681        }),
682        _ => None,
683    }
684}
685
/// Extract table references from SQL.
///
/// Finds patterns like `FROM table` and `JOIN table`, where table can be
/// `hive.box.frame`, `box.frame`, or `frame`.  Duplicates are returned once,
/// in first-appearance order.  Subqueries (`FROM (SELECT …)`) and keywords
/// that can directly follow FROM are skipped.
fn extract_table_references(sql: &str) -> Vec<String> {
    let mut refs = Vec::new();
    let tokens: Vec<&str> = sql.split_whitespace().collect();

    for (i, token) in tokens.iter().enumerate() {
        let upper = token.to_uppercase();
        if (upper == "FROM" || upper == "JOIN") && i + 1 < tokens.len() {
            // Strip any run of trailing punctuation.  The previous
            // fixed-order trim (`,` then `)` then `;`) left `foo)` behind
            // for a token like `foo);`.
            let table_name = tokens[i + 1].trim_end_matches(|c| matches!(c, ',' | ')' | ';'));
            // Skip subqueries
            if table_name.starts_with('(') || table_name.is_empty() {
                continue;
            }
            // Skip SQL keywords that follow FROM (e.g. FROM (SELECT ...))
            let table_upper = table_name.to_uppercase();
            if matches!(
                table_upper.as_str(),
                "SELECT" | "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING"
            ) {
                continue;
            }
            if !refs.contains(&table_name.to_string()) {
                refs.push(table_name.to_string());
            }
        }
    }

    refs
}
721
/// A simple WHERE predicate extracted from SQL.
///
/// Represents a single `column <op> value` condition; the value is kept as
/// raw literal text with surrounding quotes stripped.
#[derive(Debug, Clone)]
struct Predicate {
    /// Column name exactly as written in the SQL.
    column: String,
    /// Comparison operator.
    op: PredicateOp,
    /// Literal value with surrounding single/double quotes stripped.
    value: String,
}
729
/// Comparison operator of a [`Predicate`].
#[derive(Debug, Clone)]
enum PredicateOp {
    /// `=`
    Eq,
    /// `>`
    Gt,
    /// `<`
    Lt,
    /// `>=`
    Gte,
    /// `<=`
    Lte,
}
738
739/// Extract simple WHERE predicates from SQL for pruning.
740///
741/// Handles patterns like:
742/// - `column = 'value'` or `column = value`
743/// - `column > N`
744/// - `column < N`
745/// - `column >= N`
746/// - `column <= N`
747fn extract_where_predicates(sql: &str) -> Vec<Predicate> {
748    let mut predicates = Vec::new();
749
750    // Find WHERE clause
751    let upper = sql.to_uppercase();
752    let where_pos = match upper.find(" WHERE ") {
753        Some(pos) => pos + 7,
754        None => return predicates,
755    };
756
757    let where_clause = &sql[where_pos..];
758    // Truncate at GROUP BY, ORDER BY, LIMIT, HAVING
759    let end_keywords = [" GROUP ", " ORDER ", " LIMIT ", " HAVING ", ";"];
760    let end_pos = end_keywords
761        .iter()
762        .filter_map(|kw| where_clause.to_uppercase().find(kw))
763        .min()
764        .unwrap_or(where_clause.len());
765    let where_clause = &where_clause[..end_pos];
766
767    // Split by AND (simple approach — doesn't handle OR or nested parens)
768    let parts: Vec<&str> = split_on_and(where_clause);
769
770    for part in parts {
771        let part = part.trim();
772        if let Some(pred) = parse_predicate(part) {
773            predicates.push(pred);
774        }
775    }
776
777    predicates
778}
779
/// Split a WHERE clause on AND keywords (case-insensitive).
///
/// Scans for the five-byte pattern ` AND ` with ASCII case-folding directly
/// over `clause`, so every offset refers to the original string.  (The
/// previous implementation looked up offsets in `clause.to_uppercase()` and
/// applied them to `clause`; Unicode uppercasing can change byte length —
/// e.g. `ﬁ` → `FI` — which mis-sliced or panicked on non-ASCII input.)
fn split_on_and(clause: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let bytes = clause.as_bytes();
    let mut last = 0;
    let mut i = 0;

    while i + 5 <= bytes.len() {
        let is_and = bytes[i] == b' '
            && bytes[i + 1].eq_ignore_ascii_case(&b'a')
            && bytes[i + 2].eq_ignore_ascii_case(&b'n')
            && bytes[i + 3].eq_ignore_ascii_case(&b'd')
            && bytes[i + 4] == b' ';
        if is_and {
            // Slice boundaries land on ASCII spaces, so they are always
            // valid char boundaries.
            parts.push(&clause[last..i]);
            last = i + 5;
            i += 5; // non-overlapping, mirrors `match_indices` semantics
        } else {
            i += 1;
        }
    }

    parts.push(&clause[last..]);
    parts
}
793
794/// Parse a single predicate condition.
795fn parse_predicate(condition: &str) -> Option<Predicate> {
796    let condition = condition.trim();
797
798    // Try >=, <=, >, <, = operators
799    let ops = [
800        (">=", PredicateOp::Gte),
801        ("<=", PredicateOp::Lte),
802        (">", PredicateOp::Gt),
803        ("<", PredicateOp::Lt),
804        ("=", PredicateOp::Eq),
805    ];
806
807    for (op_str, op) in &ops {
808        if let Some(pos) = condition.find(op_str) {
809            let col = condition[..pos].trim();
810            let val = condition[pos + op_str.len()..].trim();
811
812            // Clean the value: strip quotes
813            let val = val.trim_matches('\'').trim_matches('"').to_string();
814
815            if !col.is_empty() && !val.is_empty() {
816                return Some(Predicate {
817                    column: col.to_string(),
818                    op: op.clone(),
819                    value: val,
820                });
821            }
822        }
823    }
824
825    None
826}
827
/// Stat filter: maps column name → (min bound, max bound).
///
/// A `Some` min bound lets cells be skipped when the column's max is below
/// it; a `Some` max bound lets cells be skipped when the column's min is
/// above it (see `build_filters`).
type StatFilters = HashMap<String, (Option<serde_json::Value>, Option<serde_json::Value>)>;
830
831/// Build partition and stat filters from predicates.
832fn build_filters(
833    predicates: &[Predicate],
834    partition_columns: &[String],
835) -> (HashMap<String, String>, StatFilters) {
836    let mut partition_filters = HashMap::new();
837    let mut stat_filters: StatFilters = HashMap::new();
838
839    for pred in predicates {
840        if partition_columns.contains(&pred.column) {
841            // Partition filter: only support equality
842            if matches!(pred.op, PredicateOp::Eq) {
843                partition_filters.insert(pred.column.clone(), pred.value.clone());
844            }
845        }
846
847        // Stat filter: convert numeric predicates
848        if let Ok(num) = pred.value.parse::<f64>() {
849            let json_val = serde_json::json!(num);
850            let entry = stat_filters
851                .entry(pred.column.clone())
852                .or_insert((None, None));
853            match pred.op {
854                PredicateOp::Gt | PredicateOp::Gte => {
855                    // min_filter: skip cells where max < this value
856                    entry.0 = Some(json_val);
857                }
858                PredicateOp::Lt | PredicateOp::Lte => {
859                    // max_filter: skip cells where min > this value
860                    entry.1 = Some(json_val);
861                }
862                PredicateOp::Eq => {
863                    // Both bounds
864                    entry.0 = Some(json_val.clone());
865                    entry.1 = Some(json_val);
866                }
867            }
868        }
869    }
870
871    (partition_filters, stat_filters)
872}
873
874/// Convert a FrameSchema to an Arrow Schema.
875fn frame_schema_to_arrow(schema: &apiary_core::FrameSchema) -> Result<Schema> {
876    let fields: Vec<Field> = schema
877        .fields
878        .iter()
879        .map(|f| {
880            let dt = match f.data_type.as_str() {
881                "int8" => DataType::Int8,
882                "int16" => DataType::Int16,
883                "int32" => DataType::Int32,
884                "int64" => DataType::Int64,
885                "uint8" => DataType::UInt8,
886                "uint16" => DataType::UInt16,
887                "uint32" => DataType::UInt32,
888                "uint64" => DataType::UInt64,
889                "float32" | "float" => DataType::Float32,
890                "float64" | "double" => DataType::Float64,
891                "string" | "utf8" => DataType::Utf8,
892                "boolean" | "bool" => DataType::Boolean,
893                "datetime" | "timestamp" => {
894                    DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
895                }
896                _ => DataType::Utf8,
897            };
898            Field::new(&f.name, dt, f.nullable)
899        })
900        .collect();
901
902    Ok(Schema::new(fields))
903}
904
/// Rewrite SQL to replace 3-part or 2-part table references with the registered names.
///
/// References are processed longest-first so a shorter qualified name that
/// is a prefix of a longer one (e.g. `a.b.frame` vs `a.b.frame2`) cannot
/// clobber the longer reference before it is rewritten.
///
/// Single-part references are left untouched — they already match the
/// registered table name.
fn rewrite_sql_table_refs(
    sql: &str,
    table_refs: &[String],
    _current_hive: &Option<String>,
    _current_box: &Option<String>,
) -> String {
    let mut result = sql.to_string();

    // Longest references first: avoids partial replacement of overlapping
    // qualified names.
    let mut refs: Vec<&String> = table_refs.iter().collect();
    refs.sort_by_key(|r| std::cmp::Reverse(r.len()));

    for table_ref in refs {
        // Multi-part reference: keep only the trailing frame name, which is
        // the name the table was registered under.
        if let Some((_, frame_name)) = table_ref.rsplit_once('.') {
            result = result.replace(table_ref.as_str(), frame_name);
        }
    }

    result
}
925
926/// Create a single-row batch with a message.
927fn single_message_batch(message: &str) -> RecordBatch {
928    let schema = Arc::new(Schema::new(vec![Field::new(
929        "message",
930        DataType::Utf8,
931        false,
932    )]));
933    RecordBatch::try_new(
934        schema,
935        vec![Arc::new(StringArray::from(vec![message.to_string()]))],
936    )
937    .unwrap()
938}
939
940/// Create a batch from a list of strings.
941fn string_list_batch(column_name: &str, values: &[String]) -> RecordBatch {
942    let schema = Arc::new(Schema::new(vec![Field::new(
943        column_name,
944        DataType::Utf8,
945        false,
946    )]));
947    RecordBatch::try_new(
948        schema,
949        vec![Arc::new(StringArray::from(
950            values.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
951        ))],
952    )
953    .unwrap()
954}
955
#[cfg(test)]
mod tests {
    use super::*;
    use apiary_core::{CellSizingPolicy, FieldDef, FrameSchema, NodeId};
    use apiary_storage::cell_writer::CellWriter;
    use apiary_storage::ledger::Ledger;
    use apiary_storage::local::LocalBackend;
    use arrow::array::{Float64Array, Int64Array};

    /// Build a fresh local-disk storage backend and an empty registry.
    ///
    /// The returned `TempDir` must be kept alive by the caller — dropping
    /// it deletes the backing directory out from under the backend.
    async fn make_test_env() -> (
        Arc<dyn StorageBackend>,
        Arc<RegistryManager>,
        tempfile::TempDir,
    ) {
        let dir = tempfile::tempdir().unwrap();
        let backend = LocalBackend::new(dir.path().to_path_buf()).await.unwrap();
        let storage: Arc<dyn StorageBackend> = Arc::new(backend);
        let registry = Arc::new(RegistryManager::new(Arc::clone(&storage)));
        let _ = registry.load_or_create().await.unwrap();
        (storage, registry, dir)
    }

    /// JSON schema used when registering the test frame.
    fn test_schema() -> serde_json::Value {
        serde_json::json!({
            "region": "string",
            "temp": "float64",
            "humidity": "int64"
        })
    }

    /// Register `test_hive.test_box.sensors` (partitioned by `region`) and
    /// commit four rows of sample data (two "north", two "south") via the
    /// cell writer and ledger.
    async fn setup_frame(storage: &Arc<dyn StorageBackend>, registry: &Arc<RegistryManager>) {
        registry.create_hive("test_hive").await.unwrap();
        registry.create_box("test_hive", "test_box").await.unwrap();
        registry
            .create_frame(
                "test_hive",
                "test_box",
                "sensors",
                test_schema(),
                vec!["region".into()],
            )
            .await
            .unwrap();

        // Create ledger and write data
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
                FieldDef {
                    name: "humidity".into(),
                    data_type: "int64".into(),
                    nullable: true,
                },
            ],
        };

        let node_id = NodeId::new("test_node");
        let frame_path = "test_hive/test_box/sensors";

        let mut ledger = Ledger::create(
            Arc::clone(storage),
            frame_path,
            frame_schema.clone(),
            vec!["region".into()],
            &node_id,
        )
        .await
        .unwrap();

        // Generous size thresholds so all four rows land in a single cell.
        let sizing = CellSizingPolicy::new(256 * 1024 * 1024, 512 * 1024 * 1024, 16 * 1024 * 1024);

        let writer = CellWriter::new(
            Arc::clone(storage),
            frame_path.into(),
            frame_schema,
            vec!["region".into()],
            sizing,
        );

        // Write test data
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("temp", DataType::Float64, true),
            Field::new("humidity", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["north", "north", "south", "south"])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
                Arc::new(Int64Array::from(vec![50, 60, 70, 80])),
            ],
        )
        .unwrap();

        let cells = writer.write(&batch).await.unwrap();
        ledger
            .commit(apiary_core::LedgerAction::AddCells { cells }, &node_id)
            .await
            .unwrap();
    }

    /// SELECT * over a fully-qualified frame returns every committed row.
    #[tokio::test]
    async fn test_select_all() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    /// GROUP BY aggregation collapses the four rows into one per region.
    #[tokio::test]
    async fn test_aggregation() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT region, AVG(temp) as avg_temp FROM test_hive.test_box.sensors GROUP BY region ORDER BY region")
            .await
            .unwrap();

        assert!(!results.is_empty());
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2); // north, south
    }

    /// USE HIVE / USE BOX let subsequent queries use bare frame names.
    #[tokio::test]
    async fn test_use_hive_and_box() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();
        let results = ctx.sql("SELECT * FROM sensors").await.unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    /// SHOW HIVES lists the single registered hive.
    #[tokio::test]
    async fn test_show_hives() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW HIVES").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].num_rows(), 1);
    }

    /// SHOW FRAMES with an explicit hive.box qualifier works.
    #[tokio::test]
    async fn test_show_frames() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW FRAMES IN test_hive.test_box").await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
    }

    /// DESCRIBE returns at least one row per schema field (3 fields here).
    #[tokio::test]
    async fn test_describe() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("DESCRIBE test_hive.test_box.sensors")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 3);
    }

    /// DELETE is rejected with a "not supported" error.
    #[tokio::test]
    async fn test_delete_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx.sql("DELETE FROM test_hive.test_box.sensors").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    /// UPDATE is rejected with a "not supported" error.
    #[tokio::test]
    async fn test_update_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx
            .sql("UPDATE test_hive.test_box.sensors SET temp = 0")
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    /// WHERE on the partition column returns only matching rows.
    #[tokio::test]
    async fn test_where_filter() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors WHERE region = 'north'")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }

    /// Column projection returns only the requested column.
    #[tokio::test]
    async fn test_projection() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT temp FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].num_columns(), 1);
        assert_eq!(results[0].schema().field(0).name(), "temp");
    }

    /// Table-reference extraction handles qualified names and JOINs.
    #[test]
    fn test_extract_table_references() {
        let refs = extract_table_references("SELECT * FROM hive.box.frame WHERE x = 1");
        assert_eq!(refs, vec!["hive.box.frame"]);

        let refs = extract_table_references("SELECT * FROM frame1 JOIN frame2 ON x = y");
        assert_eq!(refs, vec!["frame1", "frame2"]);
    }

    /// Predicate extraction splits on AND and strips quotes from values.
    #[test]
    fn test_extract_predicates() {
        let preds =
            extract_where_predicates("SELECT * FROM t WHERE region = 'north' AND temp > 25");
        assert_eq!(preds.len(), 2);
        assert_eq!(preds[0].column, "region");
        assert_eq!(preds[0].value, "north");
        assert_eq!(preds[1].column, "temp");
        assert_eq!(preds[1].value, "25");
    }

    /// DML detection flags DELETE/UPDATE but passes SELECT through.
    #[test]
    fn test_check_unsupported_dml() {
        assert!(check_unsupported_dml("DELETE FROM t").is_some());
        assert!(check_unsupported_dml("UPDATE t SET x = 1").is_some());
        assert!(check_unsupported_dml("SELECT * FROM t").is_none());
    }

    /// SHOW BOXES without a qualifier falls back to the current hive.
    #[tokio::test]
    async fn test_show_boxes_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();

        let results = ctx.sql("SHOW BOXES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "box");
    }

    /// SHOW FRAMES without a qualifier falls back to the current hive/box.
    #[tokio::test]
    async fn test_show_frames_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();

        let results = ctx.sql("SHOW FRAMES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "frame");
    }
}