1pub mod distributed;
9pub mod timing;
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use arrow::array::StringArray;
15use arrow::datatypes::{DataType, Field, Schema};
16use arrow::record_batch::RecordBatch;
17use datafusion::prelude::*;
18use tracing::{info, warn};
19
20use apiary_core::error::ApiaryError;
21use apiary_core::registry_manager::RegistryManager;
22use apiary_core::storage::StorageBackend;
23use apiary_core::types::NodeId;
24use apiary_core::Result;
25use apiary_storage::cell_reader::CellReader;
26use apiary_storage::ledger::Ledger;
27
/// Session-scoped query execution context over Apiary storage.
///
/// Holds the storage backend and catalog, plus the mutable `USE HIVE` /
/// `USE BOX` defaults used to resolve partial table references.
pub struct ApiaryQueryContext {
    // Object store that ledgers and cells are read from.
    storage: Arc<dyn StorageBackend>,
    // Catalog of hives/boxes/frames used to validate and resolve names.
    registry: Arc<RegistryManager>,
    // Session defaults set via USE HIVE / USE BOX; None until selected.
    current_hive: Option<String>,
    current_box: Option<String>,
    // This node's identity; compared against task assignments during
    // distributed execution to pick up local work.
    #[allow(dead_code)]
    node_id: NodeId,
}
37
38impl ApiaryQueryContext {
    /// Create a context for a single local node (node id `"local"`).
    pub fn new(storage: Arc<dyn StorageBackend>, registry: Arc<RegistryManager>) -> Self {
        Self::with_node_id(storage, registry, NodeId::from("local"))
    }

    /// Create a context with an explicit node identity, used to claim this
    /// node's tasks during distributed execution.
    pub fn with_node_id(
        storage: Arc<dyn StorageBackend>,
        registry: Arc<RegistryManager>,
        node_id: NodeId,
    ) -> Self {
        Self {
            storage,
            registry,
            // Nothing selected until the session runs USE HIVE / USE BOX.
            current_hive: None,
            current_box: None,
            node_id,
        }
    }
58
    /// Execute a SQL statement, which may be either a standard query or an
    /// Apiary-specific command (USE / SHOW / DESCRIBE).
    ///
    /// Mutating verbs (DELETE/UPDATE/INSERT/DDL) are rejected up front with a
    /// targeted error, since frames are append-only.
    pub async fn sql(&mut self, query: &str) -> Result<Vec<RecordBatch>> {
        let trimmed = query.trim();

        // Fail fast on unsupported DML/DDL before any planning.
        if let Some(err) = check_unsupported_dml(trimmed) {
            return Err(err);
        }

        // Apiary commands are handled without touching DataFusion.
        if let Some(result) = self.handle_custom_command(trimmed).await? {
            return Ok(result);
        }

        self.execute_standard_sql(trimmed).await
    }
76
    /// Intercept Apiary-specific statements (USE / SHOW / DESCRIBE) before
    /// they reach the SQL engine.
    ///
    /// Returns `Ok(Some(batches))` when the statement was handled here, or
    /// `Ok(None)` so the caller falls through to standard SQL execution.
    async fn handle_custom_command(&mut self, sql: &str) -> Result<Option<Vec<RecordBatch>>> {
        // Commands match case-insensitively, with a trailing ';' ignored.
        let upper = sql.to_uppercase();
        let upper = upper.trim_end_matches(';').trim();

        if let Some(name) = upper.strip_prefix("USE HIVE ") {
            let name = name.trim().to_lowercase();
            // Validate the hive exists before changing the session default.
            let hives = self.registry.list_hives().await?;
            if !hives.iter().any(|h| h.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Hive".into(),
                    name: name.clone(),
                });
            }
            self.current_hive = Some(name.clone());
            let batch = single_message_batch(&format!("Current hive set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        if let Some(name) = upper.strip_prefix("USE BOX ") {
            let name = name.trim().to_lowercase();
            // A box is only meaningful inside the currently selected hive.
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message: "No hive selected. Run USE HIVE <name> first.".into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            if !boxes.iter().any(|b| b.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Box".into(),
                    name: name.clone(),
                });
            }
            self.current_box = Some(name.clone());
            let batch = single_message_batch(&format!("Current box set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW HIVES" {
            let hives = self.registry.list_hives().await?;
            let batch = string_list_batch("hive", &hives);
            return Ok(Some(vec![batch]));
        }

        // The explicit-hive form must be checked before bare "SHOW BOXES".
        if let Some(rest) = upper.strip_prefix("SHOW BOXES IN ") {
            let hive = rest.trim().to_lowercase();
            let boxes = self.registry.list_boxes(&hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW BOXES" {
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message:
                        "No hive selected. Run USE HIVE <name> first, or use SHOW BOXES IN <hive>."
                            .into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        if let Some(rest) = upper.strip_prefix("SHOW FRAMES IN ") {
            let parts: Vec<&str> = rest.trim().split('.').collect();
            if parts.len() != 2 {
                return Err(ApiaryError::Config {
                    message: "SHOW FRAMES IN requires hive.box format".into(),
                });
            }
            let hive = parts[0].trim().to_lowercase();
            let box_name = parts[1].trim().to_lowercase();
            let frames = self.registry.list_frames(&hive, &box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW FRAMES" {
            let hive = self.current_hive.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No hive selected. Run USE HIVE <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let box_name = self.current_box.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No box selected. Run USE BOX <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let frames = self.registry.list_frames(hive, box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        if let Some(rest) = upper.strip_prefix("DESCRIBE ") {
            // Recover the original-case argument by slicing the raw SQL at the
            // same offset as the uppercased remainder. NOTE(review): this
            // assumes to_uppercase() preserved byte lengths (true for ASCII
            // identifiers) — confirm if non-ASCII names become possible.
            let raw_rest = sql.trim_end_matches(';').trim();
            let raw_rest = &raw_rest[raw_rest.len() - rest.len()..];
            let parts: Vec<&str> = raw_rest.trim().split('.').collect();
            if parts.len() != 3 {
                return Err(ApiaryError::Config {
                    message: "DESCRIBE requires hive.box.frame format".into(),
                });
            }
            let hive = parts[0].trim();
            let box_name = parts[1].trim();
            let frame_name = parts[2].trim();
            return Ok(Some(vec![
                self.describe_frame(hive, box_name, frame_name).await?,
            ]));
        }

        // Not an Apiary command; let the SQL engine handle it.
        Ok(None)
    }
198
    /// Build the DESCRIBE result for a frame: a two-column (property, value)
    /// batch with the JSON schema, partition columns, and cell/row/byte
    /// totals taken from the frame's ledger.
    async fn describe_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
    ) -> Result<RecordBatch> {
        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let frame_path = format!("{hive}/{box_name}/{frame_name}");

        // A frame that has never been written to may not have a ledger yet;
        // report zeros instead of failing DESCRIBE.
        let (cell_count, total_rows, total_bytes) =
            match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
                Ok(ledger) => {
                    let cells = ledger.active_cells();
                    let rows: u64 = cells.iter().map(|c| c.rows).sum();
                    let bytes: u64 = cells.iter().map(|c| c.bytes).sum();
                    (cells.len() as u64, rows, bytes)
                }
                Err(_) => (0, 0, 0),
            };

        // Schema is rendered as JSON; serialization failure degrades to "{}".
        let schema_json = serde_json::to_string(&frame.schema).unwrap_or_else(|_| "{}".into());

        let schema = Arc::new(Schema::new(vec![
            Field::new("property", DataType::Utf8, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let partition_str = if frame.partition_by.is_empty() {
            "(none)".to_string()
        } else {
            frame.partition_by.join(", ")
        };

        // Parallel lists: position i of `properties` pairs with `values[i]`
        // and defines the row order of the DESCRIBE output.
        let properties = vec![
            "schema",
            "partition_by",
            "cells",
            "total_rows",
            "total_bytes",
        ];
        let values = vec![
            schema_json,
            partition_str,
            cell_count.to_string(),
            total_rows.to_string(),
            total_bytes.to_string(),
        ];

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    properties
                        .into_iter()
                        .map(|s| s.to_string())
                        .collect::<Vec<_>>(),
                )),
                Arc::new(StringArray::from(values)),
            ],
        )
        .map_err(|e| ApiaryError::Internal {
            message: format!("Failed to create DESCRIBE result: {e}"),
        })
    }
265
    /// Plan and run a standard SQL query.
    ///
    /// For each referenced frame: open its ledger, prune cells using simple
    /// WHERE predicates, load surviving cells into an in-memory table, and
    /// register it under the bare frame name with a fresh DataFusion session.
    /// Qualified names in the SQL are then rewritten to match the registered
    /// names before execution. Per-phase timings are recorded when enabled.
    async fn execute_standard_sql(&self, sql: &str) -> Result<Vec<RecordBatch>> {
        // None when timing is disabled for this query; every phase below is
        // guarded on that Option.
        let mut timings = timing::QueryTimings::begin_from_sql(sql);

        let parse_start = timings.as_ref().map(|t| t.start_phase());

        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Heuristic WHERE-clause predicates used only as pruning hints.
        let predicates = extract_where_predicates(sql);

        if let (Some(t), Some(s)) = (timings.as_mut(), parse_start) {
            t.end_phase("parse", s);
        }

        let plan_start = timings.as_ref().map(|t| t.start_phase());

        let session = SessionContext::new();

        if let (Some(t), Some(s)) = (timings.as_mut(), plan_start) {
            t.end_phase("plan", s);
        }

        // These accumulate across all tables and are reported once at the end.
        let mut file_discovery_total = std::time::Duration::ZERO;
        let mut metadata_read_total = std::time::Duration::ZERO;
        let mut data_read_total = std::time::Duration::ZERO;

        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let fd_start = timings.as_ref().map(|t| t.start_phase());

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            if let Some(s) = fd_start {
                file_discovery_total += s.elapsed();
            }

            let mr_start = timings.as_ref().map(|t| t.start_phase());

            // Partition-equality and min/max stat filters derived from the
            // WHERE predicates; no filters means scan every active cell.
            let partition_by: Vec<String> = ledger.partition_by().to_vec();
            let (partition_filters, stat_filters) = build_filters(&predicates, &partition_by);

            let cells = if partition_filters.is_empty() && stat_filters.is_empty() {
                ledger.active_cells().iter().collect::<Vec<_>>()
            } else {
                ledger.prune_cells(&partition_filters, &stat_filters)
            };

            if let Some(s) = mr_start {
                metadata_read_total += s.elapsed();
            }

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                surviving_cells = cells.len(),
                "Cell pruning complete"
            );

            // All cells pruned away: register an empty table with the frame's
            // schema so the query still plans and returns zero rows.
            if cells.is_empty() {
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            let dr_start = timings.as_ref().map(|t| t.start_phase());

            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            if let Some(s) = dr_start {
                data_read_total += s.elapsed();
            }

            // A merged read may still come back empty; fall back to an empty
            // batch carrying the frame's schema.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        if let Some(t) = timings.as_mut() {
            t.add_accumulated_phase("file_discovery", file_discovery_total);
            t.add_accumulated_phase("metadata_read", metadata_read_total);
            t.add_accumulated_phase("data_read", data_read_total);
        }

        // Strip hive/box qualifiers so the SQL matches the registered names.
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        let exec_start = timings.as_ref().map(|t| t.start_phase());

        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        let results = df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })?;

        if let (Some(t), Some(s)) = (timings.as_mut(), exec_start) {
            t.end_phase("execute", s);
        }

        if let Some(t) = timings {
            t.finish();
        }

        Ok(results)
    }
428
    /// Resolve a dotted table reference into
    /// `(hive, box, frame, register_name)`.
    ///
    /// 1-, 2-, and 3-part names are accepted; missing leading parts come from
    /// the session's USE HIVE / USE BOX defaults. The register name is always
    /// the bare frame name — it must match both the DataFusion registration
    /// and the substitution done by `rewrite_sql_table_refs`.
    fn resolve_table_ref(&self, table_ref: &str) -> Result<(String, String, String, String)> {
        let parts: Vec<&str> = table_ref.split('.').collect();

        match parts.len() {
            3 => {
                // Fully qualified: hive.box.frame.
                let hive = parts[0].to_string();
                let box_name = parts[1].to_string();
                let frame_name = parts[2].to_string();
                let register_name = frame_name.clone();
                Ok((hive, box_name, frame_name, register_name))
            }
            2 => {
                // box.frame — the hive comes from the session default.
                let hive = self.current_hive.as_ref().ok_or_else(|| {
                    ApiaryError::Resolution {
                        path: table_ref.into(),
                        reason: "No hive selected. Use 3-part name (hive.box.frame) or run USE HIVE first.".into(),
                    }
                })?;
                let box_name = parts[0].to_string();
                let frame_name = parts[1].to_string();
                let register_name = frame_name.clone();
                Ok((hive.clone(), box_name, frame_name, register_name))
            }
            1 => {
                // Bare frame — both hive and box come from session defaults.
                let hive = self
                    .current_hive
                    .as_ref()
                    .ok_or_else(|| ApiaryError::Resolution {
                        path: table_ref.into(),
                        reason: "No hive selected. Use 3-part name or run USE HIVE first.".into(),
                    })?;
                let box_name =
                    self.current_box
                        .as_ref()
                        .ok_or_else(|| ApiaryError::Resolution {
                            path: table_ref.into(),
                            reason: "No box selected. Use 3-part name or run USE BOX first.".into(),
                        })?;
                let frame_name = parts[0].to_string();
                let register_name = frame_name.clone();
                Ok((hive.clone(), box_name.clone(), frame_name, register_name))
            }
            // Zero or four-plus segments can never resolve.
            _ => Err(ApiaryError::Resolution {
                path: table_ref.into(),
                reason: "Invalid table reference. Use hive.box.frame format.".into(),
            }),
        }
    }
479
    /// Coordinate a distributed query across the nodes in `assignments`.
    ///
    /// Writes a shared manifest of per-node tasks to storage, executes this
    /// node's own tasks locally, publishes the local partial result, then
    /// polls storage for each remote node's partial result until the
    /// manifest's timeout elapses. Remote results that never arrive are
    /// skipped with a warning — the caller receives whatever was collected.
    pub async fn execute_distributed(
        &self,
        sql: &str,
        assignments: HashMap<NodeId, Vec<distributed::CellInfo>>,
    ) -> Result<Vec<RecordBatch>> {
        use distributed::*;

        let mut timings = timing::QueryTimings::begin_from_sql(&format!("distributed:{sql}"));

        let coord_start = timings.as_ref().map(|t| t.start_phase());

        // One task per node; each carries the full SQL plus the storage keys
        // of the cells that node is responsible for reading.
        let mut tasks = Vec::new();
        for (node_id, cells) in &assignments {
            let task_id = format!("{}_{}", node_id.as_str(), uuid::Uuid::new_v4());
            let cell_keys: Vec<String> = cells.iter().map(|c| c.storage_key.clone()).collect();

            tasks.push(PlannedTask {
                task_id,
                node_id: node_id.clone(),
                cells: cell_keys,
                sql_fragment: sql.to_string(),
            });
        }

        // Publish the plan; 60s is the result-collection budget below.
        let manifest = create_manifest(sql, tasks.clone(), None, 60);
        write_manifest(&self.storage, &manifest).await?;

        info!(
            query_id = %manifest.query_id,
            tasks = tasks.len(),
            "Distributed query manifest written"
        );

        if let (Some(t), Some(s)) = (timings.as_mut(), coord_start) {
            t.end_phase("coordination", s);
        }

        let exec_start = timings.as_ref().map(|t| t.start_phase());

        // Run this node's share. A failed local task is logged and skipped
        // rather than failing the whole query (best-effort semantics).
        let local_tasks: Vec<_> = tasks.iter().filter(|t| t.node_id == self.node_id).collect();

        let mut local_results = Vec::new();
        for task in local_tasks {
            match self.execute_task(&task.sql_fragment, &task.cells).await {
                Ok(batches) => {
                    if !batches.is_empty() {
                        local_results.extend(batches);
                    }
                }
                Err(e) => {
                    warn!(task_id = %task.task_id, error = %e, "Local task failed");
                }
            }
        }

        // Publish our partial result so the coordinating reader can pick it up.
        if !local_results.is_empty() {
            write_partial_result(
                &self.storage,
                &manifest.query_id,
                &self.node_id,
                &local_results,
            )
            .await?;
        }

        if let (Some(t), Some(s)) = (timings.as_mut(), exec_start) {
            t.end_phase("execute", s);
        }

        let collect_start = timings.as_ref().map(|t| t.start_phase());

        let remote_nodes: Vec<_> = tasks
            .iter()
            .filter(|t| t.node_id != self.node_id)
            .map(|t| t.node_id.clone())
            .collect();

        let timeout = std::time::Duration::from_secs(manifest.timeout_secs);
        let start = std::time::Instant::now();
        let mut collected_results = local_results;

        // Poll for remote partial results sequentially, node by node; each
        // node only gets the budget remaining after the previous ones.
        for remote_node in &remote_nodes {
            let deadline = timeout.saturating_sub(start.elapsed());
            if start.elapsed() >= timeout {
                warn!(query_id = %manifest.query_id, "Query timeout reached");
                break;
            }

            let poll_interval = std::time::Duration::from_millis(500);
            let mut attempts = 0;
            let max_attempts = (deadline.as_millis() / poll_interval.as_millis()) as usize;

            while attempts < max_attempts {
                match read_partial_result(&self.storage, &manifest.query_id, remote_node).await {
                    Ok(batches) => {
                        info!(
                            query_id = %manifest.query_id,
                            node_id = %remote_node,
                            "Partial result received"
                        );
                        collected_results.extend(batches);
                        break;
                    }
                    Err(_) => {
                        // Not available yet (or unreadable) — wait and retry.
                        tokio::time::sleep(poll_interval).await;
                        attempts += 1;
                    }
                }
            }
        }

        // Best-effort cleanup of manifest/partial files; failure is logged.
        if let Err(e) = cleanup_query(&self.storage, &manifest.query_id).await {
            warn!(query_id = %manifest.query_id, error = %e, "Failed to cleanup query");
        }

        if let (Some(t), Some(s)) = (timings.as_mut(), collect_start) {
            t.end_phase("result_collect", s);
        }

        if let Some(t) = timings {
            t.finish();
        }

        Ok(collected_results)
    }
625
    /// Execute a distributed task fragment against only the cells assigned to
    /// this node.
    ///
    /// An empty `cell_keys` list means "no restriction" and falls back to a
    /// normal query over all active cells. Otherwise each referenced frame is
    /// registered with just its assigned cells — or an empty table when none
    /// of its cells were assigned here — so the fragment still plans.
    pub async fn execute_task(&self, sql: &str, cell_keys: &[String]) -> Result<Vec<RecordBatch>> {
        if cell_keys.is_empty() {
            return self.execute_standard_sql(sql).await;
        }

        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        let session = SessionContext::new();

        // Set for O(1) membership checks while filtering ledger cells.
        let cell_key_set: std::collections::HashSet<&String> = cell_keys.iter().collect();

        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            // Keep only cells whose storage key was assigned to this task;
            // keys are reconstructed as "<frame_path>/<cell_path>".
            let cells: Vec<_> = ledger
                .active_cells()
                .iter()
                .filter(|cell| {
                    let cell_storage_key = format!("{}/{}", frame_path, cell.path);
                    cell_key_set.contains(&cell_storage_key)
                })
                .collect();

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                assigned_cells = cells.len(),
                "Cell filtering for distributed task"
            );

            // None of this frame's cells belong to this task: register an
            // empty table with the correct schema so the SQL still plans.
            if cells.is_empty() {
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            // A merged read may legitimately come back empty; substitute an
            // empty batch carrying the frame's schema.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Strip hive/box qualifiers so the SQL matches the registered names.
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
738}
739
740fn check_unsupported_dml(sql: &str) -> Option<ApiaryError> {
746 let upper = sql.to_uppercase();
747 let first_word = upper.split_whitespace().next().unwrap_or("");
748
749 match first_word {
750 "DELETE" => Some(ApiaryError::Unsupported {
751 message: "DELETE is not supported. Apiary uses append-only writes. Use overwrite_frame() to replace all data in a frame.".into(),
752 }),
753 "UPDATE" => Some(ApiaryError::Unsupported {
754 message: "UPDATE is not supported. Apiary uses append-only writes. Rewrite the frame with corrected data using overwrite_frame().".into(),
755 }),
756 "INSERT" => Some(ApiaryError::Unsupported {
757 message: "INSERT is not supported via SQL. Use write_to_frame() to add data.".into(),
758 }),
759 "DROP" => Some(ApiaryError::Unsupported {
760 message: "DROP is not supported via SQL. Use the registry API for DDL operations.".into(),
761 }),
762 "CREATE" => Some(ApiaryError::Unsupported {
763 message: "CREATE is not supported via SQL. Use create_frame() for DDL operations.".into(),
764 }),
765 "ALTER" => Some(ApiaryError::Unsupported {
766 message: "ALTER is not supported via SQL. Use the registry API for DDL operations.".into(),
767 }),
768 _ => None,
769 }
770}
771
/// Scan a SQL string for table names appearing after FROM or JOIN keywords.
///
/// This is a lightweight whitespace-token heuristic, not a SQL parse:
/// subqueries (`FROM (`) are skipped, trailing punctuation is stripped, SQL
/// keywords are filtered out, and duplicates are collapsed while preserving
/// first-seen order.
fn extract_table_references(sql: &str) -> Vec<String> {
    let tokens: Vec<&str> = sql.split_whitespace().collect();
    let mut refs: Vec<String> = Vec::new();

    // Look at each adjacent (keyword, candidate) token pair.
    for window in tokens.windows(2) {
        let keyword = window[0].to_uppercase();
        if keyword != "FROM" && keyword != "JOIN" {
            continue;
        }

        let candidate = window[1]
            .trim_end_matches(',')
            .trim_end_matches(')')
            .trim_end_matches(';');

        // "(..." marks a subquery, not a table name.
        if candidate.is_empty() || candidate.starts_with('(') {
            continue;
        }

        // Guard against keywords directly following FROM/JOIN.
        let upper = candidate.to_uppercase();
        if matches!(
            upper.as_str(),
            "SELECT" | "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING"
        ) {
            continue;
        }

        let owned = candidate.to_string();
        if !refs.contains(&owned) {
            refs.push(owned);
        }
    }

    refs
}
807
/// A single `column <op> literal` comparison extracted from a WHERE clause.
///
/// Used only as a pruning hint — the SQL engine still applies the full
/// filter during execution, so dropping a predicate is always safe.
#[derive(Debug, Clone)]
struct Predicate {
    column: String,
    op: PredicateOp,
    value: String,
}

/// Comparison operators recognized for pruning.
#[derive(Debug, Clone)]
enum PredicateOp {
    Eq,
    Gt,
    Lt,
    Gte,
    Lte,
}

/// Extract simple `col <op> value` predicates from a query's WHERE clause.
///
/// Heuristic scan, not a SQL parse: only top-level `AND`-joined conjuncts
/// with the five supported operators are returned; anything unrecognized is
/// ignored, which keeps pruning conservative.
fn extract_where_predicates(sql: &str) -> Vec<Predicate> {
    let mut predicates = Vec::new();

    let upper = sql.to_uppercase();
    let where_pos = match upper.find(" WHERE ") {
        Some(pos) => pos + 7,
        None => return predicates,
    };

    // NOTE(review): byte offsets computed on the uppercased copy are applied
    // to the original string; this assumes uppercasing preserves byte lengths
    // (true for ASCII SQL) — confirm if non-ASCII identifiers are supported.
    let where_clause = &sql[where_pos..];

    // Cut the clause off at the first trailing keyword, if any. The uppercase
    // copy is computed once instead of once per keyword.
    let clause_upper = where_clause.to_uppercase();
    let end_keywords = [" GROUP ", " ORDER ", " LIMIT ", " HAVING ", ";"];
    let end_pos = end_keywords
        .iter()
        .filter_map(|kw| clause_upper.find(kw))
        .min()
        .unwrap_or(where_clause.len());
    let where_clause = &where_clause[..end_pos];

    let parts: Vec<&str> = split_on_and(where_clause);

    for part in parts {
        let part = part.trim();
        if let Some(pred) = parse_predicate(part) {
            predicates.push(pred);
        }
    }

    predicates
}

/// Split a WHERE clause on case-insensitive ` AND ` boundaries.
fn split_on_and(clause: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let upper = clause.to_uppercase();
    let mut last = 0;

    for (i, _) in upper.match_indices(" AND ") {
        parts.push(&clause[last..i]);
        last = i + 5; // skip over " AND " (5 bytes)
    }
    parts.push(&clause[last..]);
    parts
}

/// Parse a single `column <op> value` condition into a [`Predicate`].
///
/// A single layer of surrounding quotes on the value is stripped. Returns
/// `None` for anything unsupported so pruning stays conservative.
fn parse_predicate(condition: &str) -> Option<Predicate> {
    let condition = condition.trim();

    // Reject inequality operators explicitly: a naive scan for "=" would
    // otherwise misparse `a != 5` as an equality predicate on column "a !",
    // which could feed a bogus stat filter and wrongly prune cells that do
    // contain matching rows.
    if condition.contains("!=") || condition.contains("<>") {
        return None;
    }

    // Two-character operators are tried first so ">=" is not read as ">".
    let ops = [
        (">=", PredicateOp::Gte),
        ("<=", PredicateOp::Lte),
        (">", PredicateOp::Gt),
        ("<", PredicateOp::Lt),
        ("=", PredicateOp::Eq),
    ];

    for (op_str, op) in &ops {
        if let Some(pos) = condition.find(op_str) {
            let col = condition[..pos].trim();
            let val = condition[pos + op_str.len()..].trim();

            let val = val.trim_matches('\'').trim_matches('"').to_string();

            if !col.is_empty() && !val.is_empty() {
                return Some(Predicate {
                    column: col.to_string(),
                    op: op.clone(),
                    value: val,
                });
            }
        }
    }

    None
}
913
914type StatFilters = HashMap<String, (Option<serde_json::Value>, Option<serde_json::Value>)>;
916
917fn build_filters(
919 predicates: &[Predicate],
920 partition_columns: &[String],
921) -> (HashMap<String, String>, StatFilters) {
922 let mut partition_filters = HashMap::new();
923 let mut stat_filters: StatFilters = HashMap::new();
924
925 for pred in predicates {
926 if partition_columns.contains(&pred.column) {
927 if matches!(pred.op, PredicateOp::Eq) {
929 partition_filters.insert(pred.column.clone(), pred.value.clone());
930 }
931 }
932
933 if let Ok(num) = pred.value.parse::<f64>() {
935 let json_val = serde_json::json!(num);
936 let entry = stat_filters
937 .entry(pred.column.clone())
938 .or_insert((None, None));
939 match pred.op {
940 PredicateOp::Gt | PredicateOp::Gte => {
941 entry.0 = Some(json_val);
943 }
944 PredicateOp::Lt | PredicateOp::Lte => {
945 entry.1 = Some(json_val);
947 }
948 PredicateOp::Eq => {
949 entry.0 = Some(json_val.clone());
951 entry.1 = Some(json_val);
952 }
953 }
954 }
955 }
956
957 (partition_filters, stat_filters)
958}
959
960fn frame_schema_to_arrow(schema: &apiary_core::FrameSchema) -> Result<Schema> {
962 let fields: Vec<Field> = schema
963 .fields
964 .iter()
965 .map(|f| {
966 let dt = match f.data_type.as_str() {
967 "int8" => DataType::Int8,
968 "int16" => DataType::Int16,
969 "int32" => DataType::Int32,
970 "int64" => DataType::Int64,
971 "uint8" => DataType::UInt8,
972 "uint16" => DataType::UInt16,
973 "uint32" => DataType::UInt32,
974 "uint64" => DataType::UInt64,
975 "float32" | "float" => DataType::Float32,
976 "float64" | "double" => DataType::Float64,
977 "string" | "utf8" => DataType::Utf8,
978 "boolean" | "bool" => DataType::Boolean,
979 "datetime" | "timestamp" => {
980 DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
981 }
982 _ => DataType::Utf8,
983 };
984 Field::new(&f.name, dt, f.nullable)
985 })
986 .collect();
987
988 Ok(Schema::new(fields))
989}
990
/// Rewrite qualified `hive.box.frame` references in `sql` to the bare frame
/// name the tables were registered under in the session context.
///
/// Single-part references are already registered under their own name and
/// are left untouched. NOTE(review): replacement is plain substring
/// substitution, so a qualified name appearing inside a string literal
/// would be rewritten too — acceptable for the simple queries supported.
fn rewrite_sql_table_refs(
    sql: &str,
    table_refs: &[String],
    _current_hive: &Option<String>,
    _current_box: &Option<String>,
) -> String {
    table_refs.iter().fold(sql.to_string(), |acc, full_ref| {
        let segments: Vec<&str> = full_ref.split('.').collect();
        if segments.len() < 2 {
            // Already a bare name; nothing to strip.
            return acc;
        }
        let short = segments[segments.len() - 1];
        acc.replace(full_ref, short)
    })
}
1011
1012fn single_message_batch(message: &str) -> RecordBatch {
1014 let schema = Arc::new(Schema::new(vec![Field::new(
1015 "message",
1016 DataType::Utf8,
1017 false,
1018 )]));
1019 RecordBatch::try_new(
1020 schema,
1021 vec![Arc::new(StringArray::from(vec![message.to_string()]))],
1022 )
1023 .unwrap()
1024}
1025
1026fn string_list_batch(column_name: &str, values: &[String]) -> RecordBatch {
1028 let schema = Arc::new(Schema::new(vec![Field::new(
1029 column_name,
1030 DataType::Utf8,
1031 false,
1032 )]));
1033 RecordBatch::try_new(
1034 schema,
1035 vec![Arc::new(StringArray::from(
1036 values.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
1037 ))],
1038 )
1039 .unwrap()
1040}
1041
1042#[cfg(test)]
1043mod tests {
1044 use super::*;
1045 use apiary_core::{CellSizingPolicy, FieldDef, FrameSchema, NodeId};
1046 use apiary_storage::cell_writer::CellWriter;
1047 use apiary_storage::ledger::Ledger;
1048 use apiary_storage::local::LocalBackend;
1049 use arrow::array::{Float64Array, Int64Array};
1050
    // Build an isolated test environment: a temp-dir-backed storage backend
    // plus a freshly initialized registry. The TempDir is returned so the
    // directory (and everything written to it) outlives the test body.
    async fn make_test_env() -> (
        Arc<dyn StorageBackend>,
        Arc<RegistryManager>,
        tempfile::TempDir,
    ) {
        let dir = tempfile::tempdir().unwrap();
        let backend = LocalBackend::new(dir.path().to_path_buf()).await.unwrap();
        let storage: Arc<dyn StorageBackend> = Arc::new(backend);
        let registry = Arc::new(RegistryManager::new(Arc::clone(&storage)));
        let _ = registry.load_or_create().await.unwrap();
        (storage, registry, dir)
    }

    // Registry-facing schema for the test frame, expressed as loose JSON.
    fn test_schema() -> serde_json::Value {
        serde_json::json!({
            "region": "string",
            "temp": "float64",
            "humidity": "int64"
        })
    }
1071
    // Create hive/box/frame in the registry and write one committed cell of
    // four rows (two "north", two "south") so query tests have data.
    async fn setup_frame(storage: &Arc<dyn StorageBackend>, registry: &Arc<RegistryManager>) {
        registry.create_hive("test_hive").await.unwrap();
        registry.create_box("test_hive", "test_box").await.unwrap();
        registry
            .create_frame(
                "test_hive",
                "test_box",
                "sensors",
                test_schema(),
                vec!["region".into()],
            )
            .await
            .unwrap();

        // Storage-level schema mirroring test_schema().
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
                FieldDef {
                    name: "humidity".into(),
                    data_type: "int64".into(),
                    nullable: true,
                },
            ],
        };

        let node_id = NodeId::new("test_node");
        let frame_path = "test_hive/test_box/sensors";

        let mut ledger = Ledger::create(
            Arc::clone(storage),
            frame_path,
            frame_schema.clone(),
            vec!["region".into()],
            &node_id,
        )
        .await
        .unwrap();

        // Generous size targets so all four rows land in a single cell.
        let sizing = CellSizingPolicy::new(256 * 1024 * 1024, 512 * 1024 * 1024, 16 * 1024 * 1024);

        let writer = CellWriter::new(
            Arc::clone(storage),
            frame_path.into(),
            frame_schema,
            vec!["region".into()],
            sizing,
        );

        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("temp", DataType::Float64, true),
            Field::new("humidity", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["north", "north", "south", "south"])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
                Arc::new(Int64Array::from(vec![50, 60, 70, 80])),
            ],
        )
        .unwrap();

        // Write the cell and commit it so queries can see the data.
        let cells = writer.write(&batch).await.unwrap();
        ledger
            .commit(apiary_core::LedgerAction::AddCells { cells }, &node_id)
            .await
            .unwrap();
    }
1153
    // Full scan of a 3-part qualified frame returns every written row.
    #[tokio::test]
    async fn test_select_all() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    // GROUP BY over the two regions yields exactly two aggregate rows.
    #[tokio::test]
    async fn test_aggregation() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT region, AVG(temp) as avg_temp FROM test_hive.test_box.sensors GROUP BY region ORDER BY region")
            .await
            .unwrap();

        assert!(!results.is_empty());
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }

    // After USE HIVE / USE BOX, a bare frame name resolves correctly.
    #[tokio::test]
    async fn test_use_hive_and_box() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();
        let results = ctx.sql("SELECT * FROM sensors").await.unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    // SHOW HIVES lists the single hive created by setup_frame.
    #[tokio::test]
    async fn test_show_hives() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW HIVES").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].num_rows(), 1);
    }

    // SHOW FRAMES IN hive.box lists at least the "sensors" frame.
    #[tokio::test]
    async fn test_show_frames() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW FRAMES IN test_hive.test_box").await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
    }

    // DESCRIBE returns a property/value batch with at least schema,
    // partition_by, and cell-count rows.
    #[tokio::test]
    async fn test_describe() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("DESCRIBE test_hive.test_box.sensors")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 3);
    }

    // DELETE is rejected before planning with a "not supported" error.
    #[tokio::test]
    async fn test_delete_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx.sql("DELETE FROM test_hive.test_box.sensors").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    // UPDATE is likewise rejected before planning.
    #[tokio::test]
    async fn test_update_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx
            .sql("UPDATE test_hive.test_box.sensors SET temp = 0")
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    // A partition-column equality filter returns only the matching rows.
    #[tokio::test]
    async fn test_where_filter() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors WHERE region = 'north'")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }
1276
1277 #[tokio::test]
1278 async fn test_projection() {
1279 let (storage, registry, _dir) = make_test_env().await;
1280 setup_frame(&storage, ®istry).await;
1281
1282 let mut ctx = ApiaryQueryContext::new(storage, registry);
1283 let results = ctx
1284 .sql("SELECT temp FROM test_hive.test_box.sensors")
1285 .await
1286 .unwrap();
1287
1288 assert!(!results.is_empty());
1289 assert_eq!(results[0].num_columns(), 1);
1290 assert_eq!(results[0].schema().field(0).name(), "temp");
1291 }
1292
1293 #[test]
1294 fn test_extract_table_references() {
1295 let refs = extract_table_references("SELECT * FROM hive.box.frame WHERE x = 1");
1296 assert_eq!(refs, vec!["hive.box.frame"]);
1297
1298 let refs = extract_table_references("SELECT * FROM frame1 JOIN frame2 ON x = y");
1299 assert_eq!(refs, vec!["frame1", "frame2"]);
1300 }
1301
1302 #[test]
1303 fn test_extract_predicates() {
1304 let preds =
1305 extract_where_predicates("SELECT * FROM t WHERE region = 'north' AND temp > 25");
1306 assert_eq!(preds.len(), 2);
1307 assert_eq!(preds[0].column, "region");
1308 assert_eq!(preds[0].value, "north");
1309 assert_eq!(preds[1].column, "temp");
1310 assert_eq!(preds[1].value, "25");
1311 }
1312
1313 #[test]
1314 fn test_check_unsupported_dml() {
1315 assert!(check_unsupported_dml("DELETE FROM t").is_some());
1316 assert!(check_unsupported_dml("UPDATE t SET x = 1").is_some());
1317 assert!(check_unsupported_dml("SELECT * FROM t").is_none());
1318 }
1319
1320 #[tokio::test]
1321 async fn test_show_boxes_without_qualifier() {
1322 let (storage, registry, _dir) = make_test_env().await;
1323 setup_frame(&storage, ®istry).await;
1324
1325 let mut ctx = ApiaryQueryContext::new(storage, registry);
1326 ctx.sql("USE HIVE test_hive").await.unwrap();
1327
1328 let results = ctx.sql("SHOW BOXES").await.unwrap();
1329 assert_eq!(results.len(), 1);
1330 assert!(results[0].num_rows() >= 1);
1331 assert_eq!(results[0].schema().field(0).name(), "box");
1332 }
1333
1334 #[tokio::test]
1335 async fn test_show_frames_without_qualifier() {
1336 let (storage, registry, _dir) = make_test_env().await;
1337 setup_frame(&storage, ®istry).await;
1338
1339 let mut ctx = ApiaryQueryContext::new(storage, registry);
1340 ctx.sql("USE HIVE test_hive").await.unwrap();
1341 ctx.sql("USE BOX test_box").await.unwrap();
1342
1343 let results = ctx.sql("SHOW FRAMES").await.unwrap();
1344 assert_eq!(results.len(), 1);
1345 assert!(results[0].num_rows() >= 1);
1346 assert_eq!(results[0].schema().field(0).name(), "frame");
1347 }
1348}