
apiary_query/lib.rs

//! DataFusion-based SQL query engine for Apiary.
//!
//! [`ApiaryQueryContext`] wraps a DataFusion `SessionContext` and resolves
//! Apiary frame references (hive.box.frame) to in-memory tables built from
//! the frame's active Parquet cells.  Custom SQL commands (USE, SHOW,
//! DESCRIBE) are intercepted before they reach DataFusion.
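//!
//! A minimal usage sketch (storage/registry construction elided; the hive,
//! box, and frame names here are hypothetical):
//!
//! ```ignore
//! let mut ctx = ApiaryQueryContext::new(storage, registry);
//! ctx.sql("USE HIVE analytics").await?;
//! ctx.sql("USE BOX events").await?;
//! let batches = ctx.sql("SELECT * FROM clicks LIMIT 10").await?;
//! ```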

pub mod distributed;
pub mod timing;

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::prelude::*;
use tracing::{info, warn};

use apiary_core::error::ApiaryError;
use apiary_core::registry_manager::RegistryManager;
use apiary_core::storage::StorageBackend;
use apiary_core::types::NodeId;
use apiary_core::Result;
use apiary_storage::cell_reader::CellReader;
use apiary_storage::ledger::Ledger;

/// The Apiary query context: wraps DataFusion with Apiary namespace resolution.
pub struct ApiaryQueryContext {
    storage: Arc<dyn StorageBackend>,
    registry: Arc<RegistryManager>,
    current_hive: Option<String>,
    current_box: Option<String>,
    #[allow(dead_code)] // Will be used for distributed execution
    node_id: NodeId,
}

impl ApiaryQueryContext {
    /// Create a new query context.
    pub fn new(storage: Arc<dyn StorageBackend>, registry: Arc<RegistryManager>) -> Self {
        Self::with_node_id(storage, registry, NodeId::from("local"))
    }

    /// Create a new query context with a specific node ID.
    pub fn with_node_id(
        storage: Arc<dyn StorageBackend>,
        registry: Arc<RegistryManager>,
        node_id: NodeId,
    ) -> Self {
        Self {
            storage,
            registry,
            current_hive: None,
            current_box: None,
            node_id,
        }
    }

    /// Execute a SQL query and return results as RecordBatches.
    pub async fn sql(&mut self, query: &str) -> Result<Vec<RecordBatch>> {
        let trimmed = query.trim();

        // Detect and block unsupported DML
        if let Some(err) = check_unsupported_dml(trimmed) {
            return Err(err);
        }

        // Handle custom commands
        if let Some(result) = self.handle_custom_command(trimmed).await? {
            return Ok(result);
        }

        // Standard SQL: resolve frame references, register tables, execute
        self.execute_standard_sql(trimmed).await
    }

    /// Handle custom SQL commands (USE, SHOW, DESCRIBE).
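    ///
    /// Recognized forms: `USE HIVE <name>`, `USE BOX <name>`, `SHOW HIVES`,
    /// `SHOW BOXES [IN <hive>]`, `SHOW FRAMES [IN <hive>.<box>]`, and
    /// `DESCRIBE <hive>.<box>.<frame>`. Returns `Ok(None)` for anything else,
    /// so the caller falls through to standard SQL execution.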
    async fn handle_custom_command(&mut self, sql: &str) -> Result<Option<Vec<RecordBatch>>> {
        let upper = sql.to_uppercase();
        let upper = upper.trim_end_matches(';').trim();

        // USE HIVE <name>
        if let Some(name) = upper.strip_prefix("USE HIVE ") {
            let name = name.trim().to_lowercase();
            // Verify hive exists
            let hives = self.registry.list_hives().await?;
            if !hives.iter().any(|h| h.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Hive".into(),
                    name: name.clone(),
                });
            }
            self.current_hive = Some(name.clone());
            let batch = single_message_batch(&format!("Current hive set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        // USE BOX <name>
        if let Some(name) = upper.strip_prefix("USE BOX ") {
            let name = name.trim().to_lowercase();
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message: "No hive selected. Run USE HIVE <name> first.".into(),
                })?;
            // Verify box exists
            let boxes = self.registry.list_boxes(hive).await?;
            if !boxes.iter().any(|b| b.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Box".into(),
                    name: name.clone(),
                });
            }
            self.current_box = Some(name.clone());
            let batch = single_message_batch(&format!("Current box set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        // SHOW HIVES
        if upper == "SHOW HIVES" {
            let hives = self.registry.list_hives().await?;
            let batch = string_list_batch("hive", &hives);
            return Ok(Some(vec![batch]));
        }

        // SHOW BOXES IN <hive>
        if let Some(rest) = upper.strip_prefix("SHOW BOXES IN ") {
            let hive = rest.trim().to_lowercase();
            let boxes = self.registry.list_boxes(&hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        // SHOW BOXES (using current hive context)
        if upper == "SHOW BOXES" {
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message:
                        "No hive selected. Run USE HIVE <name> first, or use SHOW BOXES IN <hive>."
                            .into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        // SHOW FRAMES IN <hive>.<box>
        if let Some(rest) = upper.strip_prefix("SHOW FRAMES IN ") {
            let parts: Vec<&str> = rest.trim().split('.').collect();
            if parts.len() != 2 {
                return Err(ApiaryError::Config {
                    message: "SHOW FRAMES IN requires hive.box format".into(),
                });
            }
            let hive = parts[0].trim().to_lowercase();
            let box_name = parts[1].trim().to_lowercase();
            let frames = self.registry.list_frames(&hive, &box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        // SHOW FRAMES (using current hive and box context)
        if upper == "SHOW FRAMES" {
            let hive = self.current_hive.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No hive selected. Run USE HIVE <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let box_name = self.current_box.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No box selected. Run USE BOX <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let frames = self.registry.list_frames(hive, box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        // DESCRIBE <hive>.<box>.<frame>
        if let Some(rest) = upper.strip_prefix("DESCRIBE ") {
            // Recover the original-case argument: take the suffix of the raw
            // SQL with the same byte length as `rest` (safe because ASCII
            // uppercasing preserves byte length).
            let raw_rest = sql.trim_end_matches(';').trim();
            let raw_rest = &raw_rest[raw_rest.len() - rest.len()..];
            let parts: Vec<&str> = raw_rest.trim().split('.').collect();
            if parts.len() != 3 {
                return Err(ApiaryError::Config {
                    message: "DESCRIBE requires hive.box.frame format".into(),
                });
            }
            let hive = parts[0].trim();
            let box_name = parts[1].trim();
            let frame_name = parts[2].trim();
            return Ok(Some(vec![
                self.describe_frame(hive, box_name, frame_name).await?,
            ]));
        }

        Ok(None)
    }

    /// Produce a DESCRIBE result for a frame.
    async fn describe_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
    ) -> Result<RecordBatch> {
        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let frame_path = format!("{hive}/{box_name}/{frame_name}");

        // Get cell count and total size from ledger
        let (cell_count, total_rows, total_bytes) =
            match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
                Ok(ledger) => {
                    let cells = ledger.active_cells();
                    let rows: u64 = cells.iter().map(|c| c.rows).sum();
                    let bytes: u64 = cells.iter().map(|c| c.bytes).sum();
                    (cells.len() as u64, rows, bytes)
                }
                Err(_) => (0, 0, 0),
            };

        let schema_json = serde_json::to_string(&frame.schema).unwrap_or_else(|_| "{}".into());

        let schema = Arc::new(Schema::new(vec![
            Field::new("property", DataType::Utf8, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let partition_str = if frame.partition_by.is_empty() {
            "(none)".to_string()
        } else {
            frame.partition_by.join(", ")
        };

        let properties = vec![
            "schema",
            "partition_by",
            "cells",
            "total_rows",
            "total_bytes",
        ];
        let values = vec![
            schema_json,
            partition_str,
            cell_count.to_string(),
            total_rows.to_string(),
            total_bytes.to_string(),
        ];

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    properties
                        .into_iter()
                        .map(|s| s.to_string())
                        .collect::<Vec<_>>(),
                )),
                Arc::new(StringArray::from(values)),
            ],
        )
        .map_err(|e| ApiaryError::Internal {
            message: format!("Failed to create DESCRIBE result: {e}"),
        })
    }

    /// Execute standard SQL by resolving frame references and delegating to DataFusion.
    async fn execute_standard_sql(&self, sql: &str) -> Result<Vec<RecordBatch>> {
        let mut timings = timing::QueryTimings::begin_from_sql(sql);

        // --- query_parse phase ---
        let parse_start = timings.as_ref().map(|t| t.start_phase());

        // Extract table references from SQL
        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Extract simple WHERE predicates for pruning
        let predicates = extract_where_predicates(sql);

        if let (Some(t), Some(s)) = (timings.as_mut(), parse_start) {
            t.end_phase("parse", s);
        }

        // --- query_plan phase ---
        let plan_start = timings.as_ref().map(|t| t.start_phase());

        // Create a fresh session for this query (avoids stale table registrations)
        let session = SessionContext::new();

        if let (Some(t), Some(s)) = (timings.as_mut(), plan_start) {
            t.end_phase("plan", s);
        }

        // Resolve and register each table
        let mut file_discovery_total = std::time::Duration::ZERO;
        let mut metadata_read_total = std::time::Duration::ZERO;
        let mut data_read_total = std::time::Duration::ZERO;

        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            // --- file_discovery phase (ledger open + cell listing) ---
            let fd_start = timings.as_ref().map(|t| t.start_phase());

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            if let Some(s) = fd_start {
                file_discovery_total += s.elapsed();
            }

            // --- metadata_read phase (pruning with partition/stat filters) ---
            let mr_start = timings.as_ref().map(|t| t.start_phase());

            // Build partition and stat filters from predicates
            let partition_by: Vec<String> = ledger.partition_by().to_vec();
            let (partition_filters, stat_filters) = build_filters(&predicates, &partition_by);

            let cells = if partition_filters.is_empty() && stat_filters.is_empty() {
                ledger.active_cells().iter().collect::<Vec<_>>()
            } else {
                ledger.prune_cells(&partition_filters, &stat_filters)
            };

            if let Some(s) = mr_start {
                metadata_read_total += s.elapsed();
            }

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                surviving_cells = cells.len(),
                "Cell pruning complete"
            );

            if cells.is_empty() {
                // Register an empty table with the correct schema
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            // --- data_read phase (reading Parquet cells from storage) ---
            let dr_start = timings.as_ref().map(|t| t.start_phase());

            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            if let Some(s) = dr_start {
                data_read_total += s.elapsed();
            }

            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Record accumulated I/O phase timings
        if let Some(t) = timings.as_mut() {
            t.add_accumulated_phase("file_discovery", file_discovery_total);
            t.add_accumulated_phase("metadata_read", metadata_read_total);
            t.add_accumulated_phase("data_read", data_read_total);
        }

        // Rewrite the SQL to use the registered table names
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        // --- query_execute phase (DataFusion planning + execution) ---
        let exec_start = timings.as_ref().map(|t| t.start_phase());

        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        let results = df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })?;

        if let (Some(t), Some(s)) = (timings.as_mut(), exec_start) {
            t.end_phase("execute", s);
        }

        if let Some(t) = timings {
            t.finish();
        }

        Ok(results)
    }

    /// Resolve a table reference to (hive, box, frame, register_name).
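    ///
    /// A sketch of the resolution rules (names hypothetical):
    ///
    /// ```ignore
    /// // "h.b.f" -> ("h", "b", "f", "f")                   works without context
    /// // "b.f"   -> (current_hive, "b", "f", "f")          requires USE HIVE
    /// // "f"     -> (current_hive, current_box, "f", "f")  requires USE HIVE + USE BOX
    /// ```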
    fn resolve_table_ref(&self, table_ref: &str) -> Result<(String, String, String, String)> {
        let parts: Vec<&str> = table_ref.split('.').collect();

        match parts.len() {
            3 => {
                let hive = parts[0].to_string();
                let box_name = parts[1].to_string();
                let frame_name = parts[2].to_string();
                // Register with just the frame name to simplify SQL rewriting
                let register_name = frame_name.clone();
                Ok((hive, box_name, frame_name, register_name))
            }
            2 => {
                let hive = self.current_hive.as_ref().ok_or_else(|| {
                    ApiaryError::Resolution {
                        path: table_ref.into(),
                        reason: "No hive selected. Use 3-part name (hive.box.frame) or run USE HIVE first.".into(),
                    }
                })?;
                let box_name = parts[0].to_string();
                let frame_name = parts[1].to_string();
                let register_name = frame_name.clone();
                Ok((hive.clone(), box_name, frame_name, register_name))
            }
            1 => {
                let hive = self
                    .current_hive
                    .as_ref()
                    .ok_or_else(|| ApiaryError::Resolution {
                        path: table_ref.into(),
                        reason: "No hive selected. Use 3-part name or run USE HIVE first.".into(),
                    })?;
                let box_name =
                    self.current_box
                        .as_ref()
                        .ok_or_else(|| ApiaryError::Resolution {
                            path: table_ref.into(),
                            reason: "No box selected. Use 3-part name or run USE BOX first.".into(),
                        })?;
                let frame_name = parts[0].to_string();
                let register_name = frame_name.clone();
                Ok((hive.clone(), box_name.clone(), frame_name, register_name))
            }
            _ => Err(ApiaryError::Resolution {
                path: table_ref.into(),
                reason: "Invalid table reference. Use hive.box.frame format.".into(),
            }),
        }
    }

    /// Execute a query using distributed execution (stub for Step 7).
    ///
    /// For v1, this is a simplified implementation that:
    /// - Creates tasks from cell assignments
    /// - Generates SQL fragments (simple pass-through)
    /// - Writes the query manifest
    /// - Executes local tasks
    /// - Polls for partial results from other nodes
    /// - Merges results (simple concatenation)
    /// - Cleans up query files
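    ///
    /// A hedged sketch of a call site (node name and cell list are hypothetical):
    ///
    /// ```ignore
    /// let mut assignments: HashMap<NodeId, Vec<distributed::CellInfo>> = HashMap::new();
    /// assignments.insert(NodeId::from("node_a"), cells_for_node_a);
    /// let batches = ctx
    ///     .execute_distributed("SELECT * FROM hive.box.frame", assignments)
    ///     .await?;
    /// ```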
    pub async fn execute_distributed(
        &self,
        sql: &str,
        assignments: HashMap<NodeId, Vec<distributed::CellInfo>>,
    ) -> Result<Vec<RecordBatch>> {
        use distributed::*;

        let mut timings = timing::QueryTimings::begin_from_sql(&format!("distributed:{sql}"));

        // --- coordination_overhead phase ---
        let coord_start = timings.as_ref().map(|t| t.start_phase());

        // 1. Create tasks from assignments
        let mut tasks = Vec::new();
        for (node_id, cells) in &assignments {
            let task_id = format!("{}_{}", node_id.as_str(), uuid::Uuid::new_v4());
            let cell_keys: Vec<String> = cells.iter().map(|c| c.storage_key.clone()).collect();

            tasks.push(PlannedTask {
                task_id,
                node_id: node_id.clone(),
                cells: cell_keys,
                sql_fragment: sql.to_string(), // For v1, use original SQL
            });
        }

        // 2. Create and write manifest
        let manifest = create_manifest(sql, tasks.clone(), None, 60);
        write_manifest(&self.storage, &manifest).await?;

        info!(
            query_id = %manifest.query_id,
            tasks = tasks.len(),
            "Distributed query manifest written"
        );

        if let (Some(t), Some(s)) = (timings.as_mut(), coord_start) {
            t.end_phase("coordination", s);
        }

        // --- query_execute phase (local tasks) ---
        let exec_start = timings.as_ref().map(|t| t.start_phase());

        // 3. Execute local tasks
        let local_tasks: Vec<_> = tasks.iter().filter(|t| t.node_id == self.node_id).collect();

        let mut local_results = Vec::new();
        for task in local_tasks {
            match self.execute_task(&task.sql_fragment, &task.cells).await {
                Ok(batches) => {
                    if !batches.is_empty() {
                        local_results.extend(batches);
                    }
                }
                Err(e) => {
                    warn!(task_id = %task.task_id, error = %e, "Local task failed");
                }
            }
        }

        // Write local partial result
        if !local_results.is_empty() {
            write_partial_result(
                &self.storage,
                &manifest.query_id,
                &self.node_id,
                &local_results,
            )
            .await?;
        }

        if let (Some(t), Some(s)) = (timings.as_mut(), exec_start) {
            t.end_phase("execute", s);
        }

        // --- result_collect phase ---
        let collect_start = timings.as_ref().map(|t| t.start_phase());

        // 4. Poll for partial results from other nodes
        let remote_nodes: Vec<_> = tasks
            .iter()
            .filter(|t| t.node_id != self.node_id)
            .map(|t| t.node_id.clone())
            .collect();

        let timeout = std::time::Duration::from_secs(manifest.timeout_secs);
        let start = std::time::Instant::now();
        let mut collected_results = local_results;

        for remote_node in &remote_nodes {
            let deadline = timeout.saturating_sub(start.elapsed());
            if start.elapsed() >= timeout {
                warn!(query_id = %manifest.query_id, "Query timeout reached");
                break;
            }

            // Poll for partial result
            let poll_interval = std::time::Duration::from_millis(500);
            let mut attempts = 0;
            let max_attempts = (deadline.as_millis() / poll_interval.as_millis()) as usize;

            while attempts < max_attempts {
                match read_partial_result(&self.storage, &manifest.query_id, remote_node).await {
                    Ok(batches) => {
                        info!(
                            query_id = %manifest.query_id,
                            node_id = %remote_node,
                            "Partial result received"
                        );
                        collected_results.extend(batches);
                        break;
                    }
                    Err(_) => {
                        tokio::time::sleep(poll_interval).await;
                        attempts += 1;
                    }
                }
            }
        }

        // 5. Cleanup
        if let Err(e) = cleanup_query(&self.storage, &manifest.query_id).await {
            warn!(query_id = %manifest.query_id, error = %e, "Failed to cleanup query");
        }

        if let (Some(t), Some(s)) = (timings.as_mut(), collect_start) {
            t.end_phase("result_collect", s);
        }

        if let Some(t) = timings {
            t.finish();
        }

        Ok(collected_results)
    }

    /// Execute a task on a specific set of cells.
    ///
    /// # Arguments
    /// * `sql` - The SQL query to execute
    /// * `cell_keys` - Storage keys of cells to scan. Only these cells are registered
    ///   as the table for the query, enabling true work partitioning across nodes.
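    ///
    /// A sketch, assuming cell keys follow the `<hive>/<box>/<frame>/<cell path>`
    /// layout used by the filter below (the key shown is hypothetical):
    ///
    /// ```ignore
    /// let keys = vec!["hive/box/frame/cell-0001.parquet".to_string()];
    /// let batches = ctx.execute_task("SELECT COUNT(*) FROM hive.box.frame", &keys).await?;
    /// ```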
    pub async fn execute_task(&self, sql: &str, cell_keys: &[String]) -> Result<Vec<RecordBatch>> {
        if cell_keys.is_empty() {
            // No cells assigned; fall back to standard execution
            return self.execute_standard_sql(sql).await;
        }

        // Extract table references from SQL
        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Create a fresh session for this query
        let session = SessionContext::new();

        // Build a set for O(1) cell key lookups
        let cell_key_set: std::collections::HashSet<&String> = cell_keys.iter().collect();

        // Resolve and register each table, filtering to only the assigned cells
        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            // Filter active cells to only those in our assigned cell_keys
            let cells: Vec<_> = ledger
                .active_cells()
                .iter()
                .filter(|cell| {
                    let cell_storage_key = format!("{}/{}", frame_path, cell.path);
                    cell_key_set.contains(&cell_storage_key)
                })
                .collect();

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                assigned_cells = cells.len(),
                "Cell filtering for distributed task"
            );

            if cells.is_empty() {
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            // Read only the assigned cells
            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Rewrite the SQL to use the registered table names
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        // Execute via DataFusion
        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
}

// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------

/// Check for unsupported DML and return an error if detected.
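///
/// ```ignore
/// assert!(check_unsupported_dml("DELETE FROM t").is_some());
/// assert!(check_unsupported_dml("SELECT * FROM t").is_none());
/// ```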
fn check_unsupported_dml(sql: &str) -> Option<ApiaryError> {
    let upper = sql.to_uppercase();
    let first_word = upper.split_whitespace().next().unwrap_or("");

    match first_word {
        "DELETE" => Some(ApiaryError::Unsupported {
            message: "DELETE is not supported. Apiary uses append-only writes. Use overwrite_frame() to replace all data in a frame.".into(),
        }),
        "UPDATE" => Some(ApiaryError::Unsupported {
            message: "UPDATE is not supported. Apiary uses append-only writes. Rewrite the frame with corrected data using overwrite_frame().".into(),
        }),
        "INSERT" => Some(ApiaryError::Unsupported {
            message: "INSERT is not supported via SQL. Use write_to_frame() to add data.".into(),
        }),
        "DROP" => Some(ApiaryError::Unsupported {
            message: "DROP is not supported via SQL. Use the registry API for DDL operations.".into(),
        }),
        "CREATE" => Some(ApiaryError::Unsupported {
            message: "CREATE is not supported via SQL. Use create_frame() for DDL operations.".into(),
        }),
        "ALTER" => Some(ApiaryError::Unsupported {
            message: "ALTER is not supported via SQL. Use the registry API for DDL operations.".into(),
        }),
        _ => None,
    }
}

/// Extract table references from SQL.
///
/// Finds patterns like `FROM table` and `JOIN table`, where table can be
/// `hive.box.frame`, `box.frame`, or `frame`.
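///
/// ```ignore
/// let refs = extract_table_references("SELECT * FROM hive.box.frame WHERE x = 1");
/// assert_eq!(refs, vec!["hive.box.frame"]);
///
/// let refs = extract_table_references("SELECT * FROM frame1 JOIN frame2 ON x = y");
/// assert_eq!(refs, vec!["frame1", "frame2"]);
/// ```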
fn extract_table_references(sql: &str) -> Vec<String> {
    let mut refs = Vec::new();
    let tokens: Vec<&str> = sql.split_whitespace().collect();

    for i in 0..tokens.len() {
        let upper = tokens[i].to_uppercase();
        if (upper == "FROM" || upper == "JOIN") && i + 1 < tokens.len() {
            let table_name = tokens[i + 1]
                .trim_end_matches(',')
                .trim_end_matches(')')
                .trim_end_matches(';');
            // Skip subqueries
            if table_name.starts_with('(') || table_name.is_empty() {
                continue;
            }
            // Skip SQL keywords that follow FROM (e.g., FROM (SELECT ...))
            let table_upper = table_name.to_uppercase();
            if matches!(
                table_upper.as_str(),
                "SELECT" | "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING"
            ) {
                continue;
            }
            if !refs.contains(&table_name.to_string()) {
                refs.push(table_name.to_string());
            }
        }
    }

    refs
}

/// A simple WHERE predicate extracted from SQL.
#[derive(Debug, Clone)]
struct Predicate {
    column: String,
    op: PredicateOp,
    value: String,
}

#[derive(Debug, Clone)]
enum PredicateOp {
    Eq,
    Gt,
    Lt,
    Gte,
    Lte,
}

/// Extract simple WHERE predicates from SQL for pruning.
///
/// Handles patterns like:
/// - `column = 'value'` or `column = value`
/// - `column > N`
/// - `column < N`
/// - `column >= N`
/// - `column <= N`
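///
/// ```ignore
/// let preds = extract_where_predicates("SELECT * FROM t WHERE region = 'north' AND temp > 25");
/// assert_eq!(preds.len(), 2);
/// assert_eq!(preds[0].column, "region"); // Eq, value "north"
/// assert_eq!(preds[1].column, "temp");   // Gt, value "25"
/// ```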
fn extract_where_predicates(sql: &str) -> Vec<Predicate> {
    let mut predicates = Vec::new();

    // Find WHERE clause
    let upper = sql.to_uppercase();
    let where_pos = match upper.find(" WHERE ") {
        Some(pos) => pos + 7,
        None => return predicates,
    };

    let where_clause = &sql[where_pos..];
    // Truncate at GROUP BY, ORDER BY, LIMIT, HAVING
    let end_keywords = [" GROUP ", " ORDER ", " LIMIT ", " HAVING ", ";"];
    let upper_clause = where_clause.to_uppercase();
    let end_pos = end_keywords
        .iter()
        .filter_map(|kw| upper_clause.find(kw))
        .min()
        .unwrap_or(where_clause.len());
    let where_clause = &where_clause[..end_pos];

    // Split by AND (simple approach; doesn't handle OR or nested parens)
    let parts: Vec<&str> = split_on_and(where_clause);

    for part in parts {
        let part = part.trim();
        if let Some(pred) = parse_predicate(part) {
            predicates.push(pred);
        }
    }

    predicates
}

/// Split a WHERE clause on AND keywords (case-insensitive).
fn split_on_and(clause: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let upper = clause.to_uppercase();
    let mut last = 0;

    for (i, _) in upper.match_indices(" AND ") {
        parts.push(&clause[last..i]);
        last = i + 5; // " AND ".len()
    }
    parts.push(&clause[last..]);
    parts
}

/// Parse a single predicate condition.
fn parse_predicate(condition: &str) -> Option<Predicate> {
    let condition = condition.trim();

    // Skip inequality conditions: the bare `find("=")` below would otherwise
    // misparse `a != b` as an equality predicate on column "a !".
    if condition.contains("!=") || condition.contains("<>") {
        return None;
    }

    // Try >=, <=, >, <, = operators (two-character operators first)
    let ops = [
        (">=", PredicateOp::Gte),
        ("<=", PredicateOp::Lte),
        (">", PredicateOp::Gt),
        ("<", PredicateOp::Lt),
        ("=", PredicateOp::Eq),
    ];

    for (op_str, op) in &ops {
        if let Some(pos) = condition.find(op_str) {
            let col = condition[..pos].trim();
            let val = condition[pos + op_str.len()..].trim();

            // Clean the value: strip quotes
            let val = val.trim_matches('\'').trim_matches('"').to_string();

            if !col.is_empty() && !val.is_empty() {
                return Some(Predicate {
                    column: col.to_string(),
                    op: op.clone(),
                    value: val,
                });
            }
        }
    }

    None
}

/// Stat filter: maps column name → (min bound, max bound).
type StatFilters = HashMap<String, (Option<serde_json::Value>, Option<serde_json::Value>)>;

/// Build partition and stat filters from predicates.
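///
/// For example, with `region` as a partition column, `region = 'north'` becomes a
/// partition filter, while `temp > 25` becomes a stat filter with only a min bound
/// (a sketch; the column names are hypothetical):
///
/// ```ignore
/// let (part, stat) = build_filters(&preds, &["region".to_string()]);
/// assert_eq!(part["region"], "north");
/// assert_eq!(stat["temp"], (Some(serde_json::json!(25.0)), None));
/// ```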
fn build_filters(
    predicates: &[Predicate],
    partition_columns: &[String],
) -> (HashMap<String, String>, StatFilters) {
    let mut partition_filters = HashMap::new();
    let mut stat_filters: StatFilters = HashMap::new();

    for pred in predicates {
        if partition_columns.contains(&pred.column) {
            // Partition filter: only support equality
            if matches!(pred.op, PredicateOp::Eq) {
                partition_filters.insert(pred.column.clone(), pred.value.clone());
            }
        }

        // Stat filter: convert numeric predicates
        if let Ok(num) = pred.value.parse::<f64>() {
            let json_val = serde_json::json!(num);
            let entry = stat_filters
                .entry(pred.column.clone())
                .or_insert((None, None));
            match pred.op {
                PredicateOp::Gt | PredicateOp::Gte => {
                    // min_filter: skip cells where max < this value
                    entry.0 = Some(json_val);
                }
                PredicateOp::Lt | PredicateOp::Lte => {
                    // max_filter: skip cells where min > this value
                    entry.1 = Some(json_val);
                }
                PredicateOp::Eq => {
                    // Both bounds
                    entry.0 = Some(json_val.clone());
                    entry.1 = Some(json_val);
                }
            }
        }
    }

    (partition_filters, stat_filters)
}

/// Convert a FrameSchema to an Arrow Schema.
fn frame_schema_to_arrow(schema: &apiary_core::FrameSchema) -> Result<Schema> {
    let fields: Vec<Field> = schema
        .fields
        .iter()
        .map(|f| {
            let dt = match f.data_type.as_str() {
                "int8" => DataType::Int8,
                "int16" => DataType::Int16,
                "int32" => DataType::Int32,
                "int64" => DataType::Int64,
                "uint8" => DataType::UInt8,
                "uint16" => DataType::UInt16,
                "uint32" => DataType::UInt32,
                "uint64" => DataType::UInt64,
                "float32" | "float" => DataType::Float32,
                "float64" | "double" => DataType::Float64,
                "string" | "utf8" => DataType::Utf8,
                "boolean" | "bool" => DataType::Boolean,
                "datetime" | "timestamp" => {
                    DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
                }
                // Unknown type names fall back to Utf8 rather than erroring
                _ => DataType::Utf8,
            };
            Field::new(&f.name, dt, f.nullable)
        })
        .collect();

    Ok(Schema::new(fields))
}

/// Rewrite SQL to replace 3-part or 2-part table references with the registered names.
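///
/// ```ignore
/// let out = rewrite_sql_table_refs(
///     "SELECT * FROM hive.box.frame",
///     &["hive.box.frame".to_string()],
///     &None,
///     &None,
/// );
/// assert_eq!(out, "SELECT * FROM frame");
/// ```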
fn rewrite_sql_table_refs(
    sql: &str,
    table_refs: &[String],
    _current_hive: &Option<String>,
    _current_box: &Option<String>,
) -> String {
    let mut result = sql.to_string();

    for table_ref in table_refs {
        let parts: Vec<&str> = table_ref.split('.').collect();
        let register_name = parts.last().unwrap_or(&table_ref.as_str()).to_string();
        if parts.len() > 1 {
            // Replace the full reference with just the frame name. This is a
            // naive string replace; it assumes the reference does not also
            // appear inside string literals.
            result = result.replace(table_ref, &register_name);
        }
    }

    result
}

/// Create a single-row batch with a message.
fn single_message_batch(message: &str) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "message",
        DataType::Utf8,
        false,
    )]));
    RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(vec![message.to_string()]))],
    )
    .unwrap()
}

/// Create a batch from a list of strings.
fn string_list_batch(column_name: &str, values: &[String]) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(
        column_name,
        DataType::Utf8,
        false,
    )]));
    RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(
            values.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
        ))],
    )
    .unwrap()
}

#[cfg(test)]
mod tests {
    use super::*;
    use apiary_core::{CellSizingPolicy, FieldDef, FrameSchema, NodeId};
    use apiary_storage::cell_writer::CellWriter;
    use apiary_storage::ledger::Ledger;
    use apiary_storage::local::LocalBackend;
    use arrow::array::{Float64Array, Int64Array};

    async fn make_test_env() -> (
        Arc<dyn StorageBackend>,
        Arc<RegistryManager>,
        tempfile::TempDir,
    ) {
        let dir = tempfile::tempdir().unwrap();
        let backend = LocalBackend::new(dir.path().to_path_buf()).await.unwrap();
        let storage: Arc<dyn StorageBackend> = Arc::new(backend);
        let registry = Arc::new(RegistryManager::new(Arc::clone(&storage)));
        let _ = registry.load_or_create().await.unwrap();
        (storage, registry, dir)
    }

    fn test_schema() -> serde_json::Value {
        serde_json::json!({
            "region": "string",
            "temp": "float64",
            "humidity": "int64"
        })
    }

    async fn setup_frame(storage: &Arc<dyn StorageBackend>, registry: &Arc<RegistryManager>) {
        registry.create_hive("test_hive").await.unwrap();
        registry.create_box("test_hive", "test_box").await.unwrap();
        registry
            .create_frame(
                "test_hive",
                "test_box",
                "sensors",
                test_schema(),
                vec!["region".into()],
            )
            .await
            .unwrap();

        // Create ledger and write data
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
                FieldDef {
                    name: "humidity".into(),
                    data_type: "int64".into(),
                    nullable: true,
                },
            ],
        };

        let node_id = NodeId::new("test_node");
        let frame_path = "test_hive/test_box/sensors";

        let mut ledger = Ledger::create(
            Arc::clone(storage),
            frame_path,
            frame_schema.clone(),
            vec!["region".into()],
            &node_id,
        )
        .await
        .unwrap();

        let sizing = CellSizingPolicy::new(256 * 1024 * 1024, 512 * 1024 * 1024, 16 * 1024 * 1024);

        let writer = CellWriter::new(
            Arc::clone(storage),
            frame_path.into(),
            frame_schema,
            vec!["region".into()],
            sizing,
        );

        // Write test data
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("temp", DataType::Float64, true),
            Field::new("humidity", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["north", "north", "south", "south"])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
                Arc::new(Int64Array::from(vec![50, 60, 70, 80])),
            ],
        )
        .unwrap();

        let cells = writer.write(&batch).await.unwrap();
        ledger
            .commit(apiary_core::LedgerAction::AddCells { cells }, &node_id)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_select_all() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[tokio::test]
    async fn test_aggregation() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT region, AVG(temp) as avg_temp FROM test_hive.test_box.sensors GROUP BY region ORDER BY region")
            .await
            .unwrap();

        assert!(!results.is_empty());
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2); // north, south
    }

    #[tokio::test]
    async fn test_use_hive_and_box() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();
        let results = ctx.sql("SELECT * FROM sensors").await.unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[tokio::test]
    async fn test_show_hives() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW HIVES").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_show_frames() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW FRAMES IN test_hive.test_box").await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
    }

    #[tokio::test]
    async fn test_describe() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("DESCRIBE test_hive.test_box.sensors")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 3);
    }

    #[tokio::test]
    async fn test_delete_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx.sql("DELETE FROM test_hive.test_box.sensors").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    #[tokio::test]
    async fn test_update_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx
            .sql("UPDATE test_hive.test_box.sensors SET temp = 0")
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    #[tokio::test]
    async fn test_where_filter() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors WHERE region = 'north'")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }

    #[tokio::test]
    async fn test_projection() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT temp FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].num_columns(), 1);
        assert_eq!(results[0].schema().field(0).name(), "temp");
    }

    #[test]
    fn test_extract_table_references() {
        let refs = extract_table_references("SELECT * FROM hive.box.frame WHERE x = 1");
        assert_eq!(refs, vec!["hive.box.frame"]);

        let refs = extract_table_references("SELECT * FROM frame1 JOIN frame2 ON x = y");
        assert_eq!(refs, vec!["frame1", "frame2"]);
    }

    #[test]
    fn test_extract_predicates() {
        let preds =
            extract_where_predicates("SELECT * FROM t WHERE region = 'north' AND temp > 25");
        assert_eq!(preds.len(), 2);
        assert_eq!(preds[0].column, "region");
        assert_eq!(preds[0].value, "north");
        assert_eq!(preds[1].column, "temp");
        assert_eq!(preds[1].value, "25");
    }

    #[test]
    fn test_check_unsupported_dml() {
        assert!(check_unsupported_dml("DELETE FROM t").is_some());
        assert!(check_unsupported_dml("UPDATE t SET x = 1").is_some());
        assert!(check_unsupported_dml("SELECT * FROM t").is_none());
    }

    #[tokio::test]
    async fn test_show_boxes_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();

        let results = ctx.sql("SHOW BOXES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "box");
    }

    #[tokio::test]
    async fn test_show_frames_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();

        let results = ctx.sql("SHOW FRAMES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "frame");
    }
}