1pub mod distributed;
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use arrow::array::StringArray;
14use arrow::datatypes::{DataType, Field, Schema};
15use arrow::record_batch::RecordBatch;
16use datafusion::prelude::*;
17use tracing::{info, warn};
18
19use apiary_core::error::ApiaryError;
20use apiary_core::registry_manager::RegistryManager;
21use apiary_core::storage::StorageBackend;
22use apiary_core::types::NodeId;
23use apiary_core::Result;
24use apiary_storage::cell_reader::CellReader;
25use apiary_storage::ledger::Ledger;
26
/// Interactive SQL query context over an Apiary deployment.
///
/// Tracks the hive/box selected via `USE HIVE` / `USE BOX` and routes SQL
/// through DataFusion after resolving frame references to cell data.
pub struct ApiaryQueryContext {
    /// Backend used to open ledgers and read cell files.
    storage: Arc<dyn StorageBackend>,
    /// Catalog of hives, boxes, and frames.
    registry: Arc<RegistryManager>,
    /// Hive set by `USE HIVE`, stored lowercased; `None` until selected.
    current_hive: Option<String>,
    /// Box set by `USE BOX`, stored lowercased; `None` until selected.
    current_box: Option<String>,
    /// This node's identity, used by the distributed execution path.
    /// NOTE(review): it IS read by `execute_distributed`, so the
    /// `dead_code` allowance looks stale — confirm and consider removing.
    #[allow(dead_code)] node_id: NodeId,
}
36
37impl ApiaryQueryContext {
38 pub fn new(storage: Arc<dyn StorageBackend>, registry: Arc<RegistryManager>) -> Self {
40 Self::with_node_id(storage, registry, NodeId::from("local"))
41 }
42
43 pub fn with_node_id(
45 storage: Arc<dyn StorageBackend>,
46 registry: Arc<RegistryManager>,
47 node_id: NodeId,
48 ) -> Self {
49 Self {
50 storage,
51 registry,
52 current_hive: None,
53 current_box: None,
54 node_id,
55 }
56 }
57
58 pub async fn sql(&mut self, query: &str) -> Result<Vec<RecordBatch>> {
60 let trimmed = query.trim();
61
62 if let Some(err) = check_unsupported_dml(trimmed) {
64 return Err(err);
65 }
66
67 if let Some(result) = self.handle_custom_command(trimmed).await? {
69 return Ok(result);
70 }
71
72 self.execute_standard_sql(trimmed).await
74 }
75
    /// Intercepts Apiary-specific commands before DataFusion sees the SQL.
    ///
    /// Returns `Ok(Some(batches))` when the statement was handled here,
    /// `Ok(None)` when it should fall through to standard execution.
    ///
    /// Matching is done on an uppercased copy of the statement, so commands
    /// and entity names are case-insensitive; selected names are stored
    /// lowercased.
    async fn handle_custom_command(&mut self, sql: &str) -> Result<Option<Vec<RecordBatch>>> {
        let upper = sql.to_uppercase();
        // Tolerate a trailing semicolon on commands.
        let upper = upper.trim_end_matches(';').trim();

        if let Some(name) = upper.strip_prefix("USE HIVE ") {
            let name = name.trim().to_lowercase();
            // Validate the hive exists before switching to it.
            let hives = self.registry.list_hives().await?;
            if !hives.iter().any(|h| h.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Hive".into(),
                    name: name.clone(),
                });
            }
            self.current_hive = Some(name.clone());
            let batch = single_message_batch(&format!("Current hive set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        if let Some(name) = upper.strip_prefix("USE BOX ") {
            let name = name.trim().to_lowercase();
            // A box can only be selected relative to a hive.
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message: "No hive selected. Run USE HIVE <name> first.".into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            if !boxes.iter().any(|b| b.to_lowercase() == name) {
                return Err(ApiaryError::EntityNotFound {
                    entity_type: "Box".into(),
                    name: name.clone(),
                });
            }
            self.current_box = Some(name.clone());
            let batch = single_message_batch(&format!("Current box set to '{name}'"));
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW HIVES" {
            let hives = self.registry.list_hives().await?;
            let batch = string_list_batch("hive", &hives);
            return Ok(Some(vec![batch]));
        }

        // Must be checked before the bare "SHOW BOXES" below.
        if let Some(rest) = upper.strip_prefix("SHOW BOXES IN ") {
            let hive = rest.trim().to_lowercase();
            let boxes = self.registry.list_boxes(&hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW BOXES" {
            let hive = self
                .current_hive
                .as_ref()
                .ok_or_else(|| ApiaryError::Config {
                    message:
                        "No hive selected. Run USE HIVE <name> first, or use SHOW BOXES IN <hive>."
                            .into(),
                })?;
            let boxes = self.registry.list_boxes(hive).await?;
            let batch = string_list_batch("box", &boxes);
            return Ok(Some(vec![batch]));
        }

        // Must be checked before the bare "SHOW FRAMES" below.
        if let Some(rest) = upper.strip_prefix("SHOW FRAMES IN ") {
            let parts: Vec<&str> = rest.trim().split('.').collect();
            if parts.len() != 2 {
                return Err(ApiaryError::Config {
                    message: "SHOW FRAMES IN requires hive.box format".into(),
                });
            }
            let hive = parts[0].trim().to_lowercase();
            let box_name = parts[1].trim().to_lowercase();
            let frames = self.registry.list_frames(&hive, &box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        if upper == "SHOW FRAMES" {
            let hive = self.current_hive.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No hive selected. Run USE HIVE <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let box_name = self.current_box.as_ref().ok_or_else(|| ApiaryError::Config {
                message: "No box selected. Run USE BOX <name> first, or use SHOW FRAMES IN <hive>.<box>.".into(),
            })?;
            let frames = self.registry.list_frames(hive, box_name).await?;
            let batch = string_list_batch("frame", &frames);
            return Ok(Some(vec![batch]));
        }

        if let Some(rest) = upper.strip_prefix("DESCRIBE ") {
            // Recover the original-case tail of the statement so entity
            // names keep their case. NOTE(review): this slice arithmetic
            // assumes `to_uppercase` preserves byte length, which holds for
            // ASCII only — confirm identifiers are restricted to ASCII.
            let raw_rest = sql.trim_end_matches(';').trim();
            let raw_rest = &raw_rest[raw_rest.len() - rest.len()..];
            let parts: Vec<&str> = raw_rest.trim().split('.').collect();
            if parts.len() != 3 {
                return Err(ApiaryError::Config {
                    message: "DESCRIBE requires hive.box.frame format".into(),
                });
            }
            let hive = parts[0].trim();
            let box_name = parts[1].trim();
            let frame_name = parts[2].trim();
            return Ok(Some(vec![
                self.describe_frame(hive, box_name, frame_name).await?,
            ]));
        }

        // Not a custom command: let standard SQL execution handle it.
        Ok(None)
    }
197
    /// Builds a two-column (property, value) batch describing a frame: its
    /// JSON schema, partition columns, and live cell/row/byte totals taken
    /// from the frame's ledger.
    async fn describe_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
    ) -> Result<RecordBatch> {
        // Errors here (e.g. unknown frame) propagate to the caller.
        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let frame_path = format!("{hive}/{box_name}/{frame_name}");

        // A frame can be registered before any data is written; treat a
        // missing/unreadable ledger as "no cells" rather than an error.
        let (cell_count, total_rows, total_bytes) =
            match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
                Ok(ledger) => {
                    let cells = ledger.active_cells();
                    let rows: u64 = cells.iter().map(|c| c.rows).sum();
                    let bytes: u64 = cells.iter().map(|c| c.bytes).sum();
                    (cells.len() as u64, rows, bytes)
                }
                Err(_) => (0, 0, 0),
            };

        // Schema is rendered as JSON; fall back to "{}" if serialization fails.
        let schema_json = serde_json::to_string(&frame.schema).unwrap_or_else(|_| "{}".into());

        let schema = Arc::new(Schema::new(vec![
            Field::new("property", DataType::Utf8, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let partition_str = if frame.partition_by.is_empty() {
            "(none)".to_string()
        } else {
            frame.partition_by.join(", ")
        };

        // Property names and values are positionally paired.
        let properties = vec![
            "schema",
            "partition_by",
            "cells",
            "total_rows",
            "total_bytes",
        ];
        let values = vec![
            schema_json,
            partition_str,
            cell_count.to_string(),
            total_rows.to_string(),
            total_bytes.to_string(),
        ];

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    properties
                        .into_iter()
                        .map(|s| s.to_string())
                        .collect::<Vec<_>>(),
                )),
                Arc::new(StringArray::from(values)),
            ],
        )
        .map_err(|e| ApiaryError::Internal {
            message: format!("Failed to create DESCRIBE result: {e}"),
        })
    }
264
    /// Executes a read query via DataFusion.
    ///
    /// Pipeline per referenced table: resolve the frame, open its ledger,
    /// prune cells using WHERE-derived partition/statistics hints, read the
    /// surviving cells into an in-memory table registered under the bare
    /// frame name, then run the qualifier-stripped SQL and collect results.
    async fn execute_standard_sql(&self, sql: &str) -> Result<Vec<RecordBatch>> {
        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        // Predicates are pruning hints only; DataFusion still applies the
        // full WHERE clause to the loaded data.
        let predicates = extract_where_predicates(sql);

        let session = SessionContext::new();

        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            let partition_by: Vec<String> = ledger.partition_by().to_vec();
            let (partition_filters, stat_filters) = build_filters(&predicates, &partition_by);

            // No filters means no pruning: take every active cell.
            let cells = if partition_filters.is_empty() && stat_filters.is_empty() {
                ledger.active_cells().iter().collect::<Vec<_>>()
            } else {
                ledger.prune_cells(&partition_filters, &stat_filters)
            };

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                surviving_cells = cells.len(),
                "Cell pruning complete"
            );

            // Everything pruned away: register an empty table with the
            // frame's schema so the query still type-checks and returns
            // zero rows instead of failing on a missing table.
            if cells.is_empty() {
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            // A merged read of zero rows yields None; substitute an empty
            // batch so downstream code always has a schema to work with.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Tables were registered under bare frame names; strip the
        // hive/box qualifiers from the SQL to match.
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
369
370 fn resolve_table_ref(&self, table_ref: &str) -> Result<(String, String, String, String)> {
372 let parts: Vec<&str> = table_ref.split('.').collect();
373
374 match parts.len() {
375 3 => {
376 let hive = parts[0].to_string();
377 let box_name = parts[1].to_string();
378 let frame_name = parts[2].to_string();
379 let register_name = frame_name.clone();
381 Ok((hive, box_name, frame_name, register_name))
382 }
383 2 => {
384 let hive = self.current_hive.as_ref().ok_or_else(|| {
385 ApiaryError::Resolution {
386 path: table_ref.into(),
387 reason: "No hive selected. Use 3-part name (hive.box.frame) or run USE HIVE first.".into(),
388 }
389 })?;
390 let box_name = parts[0].to_string();
391 let frame_name = parts[1].to_string();
392 let register_name = frame_name.clone();
393 Ok((hive.clone(), box_name, frame_name, register_name))
394 }
395 1 => {
396 let hive = self
397 .current_hive
398 .as_ref()
399 .ok_or_else(|| ApiaryError::Resolution {
400 path: table_ref.into(),
401 reason: "No hive selected. Use 3-part name or run USE HIVE first.".into(),
402 })?;
403 let box_name =
404 self.current_box
405 .as_ref()
406 .ok_or_else(|| ApiaryError::Resolution {
407 path: table_ref.into(),
408 reason: "No box selected. Use 3-part name or run USE BOX first.".into(),
409 })?;
410 let frame_name = parts[0].to_string();
411 let register_name = frame_name.clone();
412 Ok((hive.clone(), box_name.clone(), frame_name, register_name))
413 }
414 _ => Err(ApiaryError::Resolution {
415 path: table_ref.into(),
416 reason: "Invalid table reference. Use hive.box.frame format.".into(),
417 }),
418 }
419 }
420
    /// Runs a query across multiple nodes, using the shared storage backend
    /// as the coordination plane.
    ///
    /// Writes a task manifest, executes this node's tasks locally and
    /// publishes their partial result, then polls storage for each remote
    /// node's partial result until the manifest timeout elapses. Nodes that
    /// never report are silently omitted from the result; query scratch
    /// state is cleaned up best-effort.
    pub async fn execute_distributed(
        &self,
        sql: &str,
        assignments: HashMap<NodeId, Vec<distributed::CellInfo>>,
    ) -> Result<Vec<RecordBatch>> {
        use distributed::*;

        // One task per node; each carries the full SQL plus the storage keys
        // of the cells that node should scan.
        let mut tasks = Vec::new();
        for (node_id, cells) in &assignments {
            let task_id = format!("{}_{}", node_id.as_str(), uuid::Uuid::new_v4());
            let cell_keys: Vec<String> = cells.iter().map(|c| c.storage_key.clone()).collect();

            tasks.push(PlannedTask {
                task_id,
                node_id: node_id.clone(),
                cells: cell_keys,
                sql_fragment: sql.to_string(),
            });
        }

        // 60-second timeout baked into the manifest.
        let manifest = create_manifest(sql, tasks.clone(), None, 60);
        write_manifest(&self.storage, &manifest).await?;

        info!(
            query_id = %manifest.query_id,
            tasks = tasks.len(),
            "Distributed query manifest written"
        );

        // Execute this node's own tasks inline. Individual task failures are
        // logged and skipped rather than failing the whole query.
        let local_tasks: Vec<_> = tasks.iter().filter(|t| t.node_id == self.node_id).collect();

        let mut local_results = Vec::new();
        for task in local_tasks {
            match self.execute_task(&task.sql_fragment, &task.cells).await {
                Ok(batches) => {
                    if !batches.is_empty() {
                        local_results.extend(batches);
                    }
                }
                Err(e) => {
                    warn!(task_id = %task.task_id, error = %e, "Local task failed");
                }
            }
        }

        // Publish our partial result so peers polling storage can see it.
        if !local_results.is_empty() {
            write_partial_result(
                &self.storage,
                &manifest.query_id,
                &self.node_id,
                &local_results,
            )
            .await?;
        }

        let remote_nodes: Vec<_> = tasks
            .iter()
            .filter(|t| t.node_id != self.node_id)
            .map(|t| t.node_id.clone())
            .collect();

        let timeout = std::time::Duration::from_secs(manifest.timeout_secs);
        let start = std::time::Instant::now();
        let mut collected_results = local_results;

        // Poll each remote node in turn. The deadline is shared: time spent
        // waiting on one node shrinks the budget for the next.
        for remote_node in &remote_nodes {
            let deadline = timeout.saturating_sub(start.elapsed());
            if start.elapsed() >= timeout {
                warn!(query_id = %manifest.query_id, "Query timeout reached");
                break;
            }

            let poll_interval = std::time::Duration::from_millis(500);
            let mut attempts = 0;
            let max_attempts = (deadline.as_millis() / poll_interval.as_millis()) as usize;

            while attempts < max_attempts {
                match read_partial_result(&self.storage, &manifest.query_id, remote_node).await {
                    Ok(batches) => {
                        info!(
                            query_id = %manifest.query_id,
                            node_id = %remote_node,
                            "Partial result received"
                        );
                        collected_results.extend(batches);
                        break;
                    }
                    Err(_) => {
                        // Not ready yet (or unreadable); retry after a pause.
                        tokio::time::sleep(poll_interval).await;
                        attempts += 1;
                    }
                }
            }
        }

        // Best-effort removal of the manifest and partial results.
        if let Err(e) = cleanup_query(&self.storage, &manifest.query_id).await {
            warn!(query_id = %manifest.query_id, error = %e, "Failed to cleanup query");
        }

        Ok(collected_results)
    }
539
    /// Executes one distributed task: like `execute_standard_sql`, but each
    /// table loads only the cells whose storage keys appear in `cell_keys`.
    /// An empty `cell_keys` falls back to a full standard-SQL run.
    pub async fn execute_task(&self, sql: &str, cell_keys: &[String]) -> Result<Vec<RecordBatch>> {
        if cell_keys.is_empty() {
            return self.execute_standard_sql(sql).await;
        }

        let table_refs = extract_table_references(sql);

        if table_refs.is_empty() {
            return Err(ApiaryError::Config {
                message: "No table references found in query".into(),
            });
        }

        let session = SessionContext::new();

        // Set for O(1) membership checks while filtering cells below.
        let cell_key_set: std::collections::HashSet<&String> = cell_keys.iter().collect();

        for table_ref in &table_refs {
            let (hive, box_name, frame_name, register_name) = self.resolve_table_ref(table_ref)?;

            let frame_path = format!("{hive}/{box_name}/{frame_name}");
            let ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

            // Keep only the cells assigned to this task. Storage keys are
            // reconstructed as "<frame_path>/<cell path>" — this must match
            // how the planner built CellInfo::storage_key.
            let cells: Vec<_> = ledger
                .active_cells()
                .iter()
                .filter(|cell| {
                    let cell_storage_key = format!("{}/{}", frame_path, cell.path);
                    cell_key_set.contains(&cell_storage_key)
                })
                .collect();

            info!(
                frame = %frame_path,
                total_cells = ledger.active_cells().len(),
                assigned_cells = cells.len(),
                "Cell filtering for distributed task"
            );

            // No assigned cells for this table: register an empty table so
            // the query still runs and contributes zero rows.
            if cells.is_empty() {
                let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
                let mem_table = datafusion::datasource::MemTable::try_new(
                    empty_batch.schema(),
                    vec![vec![empty_batch]],
                )
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to create empty MemTable: {e}"),
                })?;
                session
                    .register_table(&register_name, Arc::new(mem_table))
                    .map_err(|e| ApiaryError::Internal {
                        message: format!("Failed to register table: {e}"),
                    })?;
                continue;
            }

            let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
            let merged = reader.read_cells_merged(&cells, None).await?;

            // Zero-row reads come back as None; substitute an empty batch
            // carrying the frame's schema.
            let batches = match merged {
                Some(batch) => vec![vec![batch]],
                None => {
                    let arrow_schema = frame_schema_to_arrow(ledger.schema())?;
                    vec![vec![RecordBatch::new_empty(Arc::new(arrow_schema))]]
                }
            };

            let schema = batches[0][0].schema();
            let mem_table =
                datafusion::datasource::MemTable::try_new(schema, batches).map_err(|e| {
                    ApiaryError::Internal {
                        message: format!("Failed to create MemTable: {e}"),
                    }
                })?;

            session
                .register_table(&register_name, Arc::new(mem_table))
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to register table '{register_name}': {e}"),
                })?;
        }

        // Tables are registered under bare frame names; strip qualifiers.
        let rewritten =
            rewrite_sql_table_refs(sql, &table_refs, &self.current_hive, &self.current_box);

        let df = session
            .sql(&rewritten)
            .await
            .map_err(|e| ApiaryError::Internal {
                message: format!("DataFusion query error: {e}"),
            })?;

        df.collect().await.map_err(|e| ApiaryError::Internal {
            message: format!("DataFusion execution error: {e}"),
        })
    }
652}
653
654fn check_unsupported_dml(sql: &str) -> Option<ApiaryError> {
660 let upper = sql.to_uppercase();
661 let first_word = upper.split_whitespace().next().unwrap_or("");
662
663 match first_word {
664 "DELETE" => Some(ApiaryError::Unsupported {
665 message: "DELETE is not supported. Apiary uses append-only writes. Use overwrite_frame() to replace all data in a frame.".into(),
666 }),
667 "UPDATE" => Some(ApiaryError::Unsupported {
668 message: "UPDATE is not supported. Apiary uses append-only writes. Rewrite the frame with corrected data using overwrite_frame().".into(),
669 }),
670 "INSERT" => Some(ApiaryError::Unsupported {
671 message: "INSERT is not supported via SQL. Use write_to_frame() to add data.".into(),
672 }),
673 "DROP" => Some(ApiaryError::Unsupported {
674 message: "DROP is not supported via SQL. Use the registry API for DDL operations.".into(),
675 }),
676 "CREATE" => Some(ApiaryError::Unsupported {
677 message: "CREATE is not supported via SQL. Use create_frame() for DDL operations.".into(),
678 }),
679 "ALTER" => Some(ApiaryError::Unsupported {
680 message: "ALTER is not supported via SQL. Use the registry API for DDL operations.".into(),
681 }),
682 _ => None,
683 }
684}
685
/// Best-effort scan for table names following FROM/JOIN keywords.
///
/// This is a token-level heuristic, not a SQL parser: subqueries (tokens
/// starting with '(') and trailing SQL keywords are skipped, punctuation is
/// trimmed, and duplicates are dropped while preserving first-seen order.
fn extract_table_references(sql: &str) -> Vec<String> {
    let tokens: Vec<&str> = sql.split_whitespace().collect();
    let mut refs: Vec<String> = Vec::new();

    for pair in tokens.windows(2) {
        let keyword = pair[0].to_uppercase();
        if keyword != "FROM" && keyword != "JOIN" {
            continue;
        }

        let candidate = pair[1]
            .trim_end_matches(',')
            .trim_end_matches(')')
            .trim_end_matches(';');

        // Skip subquery openers and empty leftovers.
        if candidate.is_empty() || candidate.starts_with('(') {
            continue;
        }
        // Skip keywords that can directly follow FROM/JOIN in odd queries.
        if matches!(
            candidate.to_uppercase().as_str(),
            "SELECT" | "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING"
        ) {
            continue;
        }

        if !refs.iter().any(|r| r == candidate) {
            refs.push(candidate.to_string());
        }
    }

    refs
}
721
/// One `column <op> value` comparison lifted from a WHERE clause.
/// Used purely as a cell-pruning hint — DataFusion applies the real filter.
#[derive(Debug, Clone)]
struct Predicate {
    /// Column name as written in the SQL.
    column: String,
    /// Comparison operator.
    op: PredicateOp,
    /// Right-hand side with surrounding quotes already stripped.
    value: String,
}
729
/// Comparison operator kind for a [`Predicate`].
#[derive(Debug, Clone)]
enum PredicateOp {
    Eq,
    Gt,
    Lt,
    Gte,
    Lte,
}
738
739fn extract_where_predicates(sql: &str) -> Vec<Predicate> {
748 let mut predicates = Vec::new();
749
750 let upper = sql.to_uppercase();
752 let where_pos = match upper.find(" WHERE ") {
753 Some(pos) => pos + 7,
754 None => return predicates,
755 };
756
757 let where_clause = &sql[where_pos..];
758 let end_keywords = [" GROUP ", " ORDER ", " LIMIT ", " HAVING ", ";"];
760 let end_pos = end_keywords
761 .iter()
762 .filter_map(|kw| where_clause.to_uppercase().find(kw))
763 .min()
764 .unwrap_or(where_clause.len());
765 let where_clause = &where_clause[..end_pos];
766
767 let parts: Vec<&str> = split_on_and(where_clause);
769
770 for part in parts {
771 let part = part.trim();
772 if let Some(pred) = parse_predicate(part) {
773 predicates.push(pred);
774 }
775 }
776
777 predicates
778}
779
/// Splits a WHERE clause on the case-insensitive keyword ` AND `, returning
/// the raw (untrimmed) pieces as slices of the original input.
fn split_on_and(clause: &str) -> Vec<&str> {
    const SEP: &str = " AND ";
    let upper = clause.to_uppercase();

    let mut pieces = Vec::new();
    let mut start = 0;
    // Find separators in the uppercased copy, slice the original at the
    // same byte offsets.
    for (idx, _) in upper.match_indices(SEP) {
        pieces.push(&clause[start..idx]);
        start = idx + SEP.len();
    }
    pieces.push(&clause[start..]);
    pieces
}
793
794fn parse_predicate(condition: &str) -> Option<Predicate> {
796 let condition = condition.trim();
797
798 let ops = [
800 (">=", PredicateOp::Gte),
801 ("<=", PredicateOp::Lte),
802 (">", PredicateOp::Gt),
803 ("<", PredicateOp::Lt),
804 ("=", PredicateOp::Eq),
805 ];
806
807 for (op_str, op) in &ops {
808 if let Some(pos) = condition.find(op_str) {
809 let col = condition[..pos].trim();
810 let val = condition[pos + op_str.len()..].trim();
811
812 let val = val.trim_matches('\'').trim_matches('"').to_string();
814
815 if !col.is_empty() && !val.is_empty() {
816 return Some(Predicate {
817 column: col.to_string(),
818 op: op.clone(),
819 value: val,
820 });
821 }
822 }
823 }
824
825 None
826}
827
/// Per-column numeric bounds for statistics-based cell pruning, keyed by
/// column name. The tuple is `(lower, upper)`; `None` leaves that side
/// unconstrained. Consumed by `Ledger::prune_cells`.
type StatFilters = HashMap<String, (Option<serde_json::Value>, Option<serde_json::Value>)>;
830
/// Splits parsed predicates into the two filter shapes `Ledger::prune_cells`
/// accepts: exact-match partition filters and per-column numeric bounds.
///
/// Only equality predicates on partition columns become partition filters.
/// Any predicate whose value parses as f64 contributes a stat bound:
/// `>`/`>=` set the lower bound, `<`/`<=` the upper, `=` pins both.
/// NOTE(review): strict vs inclusive (`>` vs `>=`) is collapsed here — fine
/// for pruning, which may only over-retain cells, never drop matching ones.
fn build_filters(
    predicates: &[Predicate],
    partition_columns: &[String],
) -> (HashMap<String, String>, StatFilters) {
    let mut partition_filters = HashMap::new();
    let mut stat_filters: StatFilters = HashMap::new();

    for pred in predicates {
        if partition_columns.contains(&pred.column) {
            if matches!(pred.op, PredicateOp::Eq) {
                // Later predicates on the same column overwrite earlier ones.
                partition_filters.insert(pred.column.clone(), pred.value.clone());
            }
        }

        // Non-numeric values contribute no statistics bound.
        if let Ok(num) = pred.value.parse::<f64>() {
            let json_val = serde_json::json!(num);
            let entry = stat_filters
                .entry(pred.column.clone())
                .or_insert((None, None));
            match pred.op {
                PredicateOp::Gt | PredicateOp::Gte => {
                    entry.0 = Some(json_val);
                }
                PredicateOp::Lt | PredicateOp::Lte => {
                    entry.1 = Some(json_val);
                }
                PredicateOp::Eq => {
                    entry.0 = Some(json_val.clone());
                    entry.1 = Some(json_val);
                }
            }
        }
    }

    (partition_filters, stat_filters)
}
873
874fn frame_schema_to_arrow(schema: &apiary_core::FrameSchema) -> Result<Schema> {
876 let fields: Vec<Field> = schema
877 .fields
878 .iter()
879 .map(|f| {
880 let dt = match f.data_type.as_str() {
881 "int8" => DataType::Int8,
882 "int16" => DataType::Int16,
883 "int32" => DataType::Int32,
884 "int64" => DataType::Int64,
885 "uint8" => DataType::UInt8,
886 "uint16" => DataType::UInt16,
887 "uint32" => DataType::UInt32,
888 "uint64" => DataType::UInt64,
889 "float32" | "float" => DataType::Float32,
890 "float64" | "double" => DataType::Float64,
891 "string" | "utf8" => DataType::Utf8,
892 "boolean" | "bool" => DataType::Boolean,
893 "datetime" | "timestamp" => {
894 DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
895 }
896 _ => DataType::Utf8,
897 };
898 Field::new(&f.name, dt, f.nullable)
899 })
900 .collect();
901
902 Ok(Schema::new(fields))
903}
904
/// Replaces qualified table references (`hive.box.frame`) in `sql` with the
/// bare frame name each table was registered under. Unqualified references
/// are left untouched.
///
/// This is plain textual replacement — occurrences of the reference inside
/// string literals would also be rewritten.
fn rewrite_sql_table_refs(
    sql: &str,
    table_refs: &[String],
    _current_hive: &Option<String>,
    _current_box: &Option<String>,
) -> String {
    table_refs
        .iter()
        .fold(sql.to_string(), |acc, table_ref| {
            // rsplit_once yields the segment after the last '.'; no '.' means
            // the reference is already bare and needs no rewriting.
            match table_ref.rsplit_once('.') {
                Some((_, register_name)) => acc.replace(table_ref.as_str(), register_name),
                None => acc,
            }
        })
}
925
926fn single_message_batch(message: &str) -> RecordBatch {
928 let schema = Arc::new(Schema::new(vec![Field::new(
929 "message",
930 DataType::Utf8,
931 false,
932 )]));
933 RecordBatch::try_new(
934 schema,
935 vec![Arc::new(StringArray::from(vec![message.to_string()]))],
936 )
937 .unwrap()
938}
939
940fn string_list_batch(column_name: &str, values: &[String]) -> RecordBatch {
942 let schema = Arc::new(Schema::new(vec![Field::new(
943 column_name,
944 DataType::Utf8,
945 false,
946 )]));
947 RecordBatch::try_new(
948 schema,
949 vec![Arc::new(StringArray::from(
950 values.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
951 ))],
952 )
953 .unwrap()
954}
955
// Integration-style tests over a temp-dir LocalBackend: they exercise the
// full path (registry → ledger → cell writer → query context → DataFusion).
#[cfg(test)]
mod tests {
    use super::*;
    use apiary_core::{CellSizingPolicy, FieldDef, FrameSchema, NodeId};
    use apiary_storage::cell_writer::CellWriter;
    use apiary_storage::ledger::Ledger;
    use apiary_storage::local::LocalBackend;
    use arrow::array::{Float64Array, Int64Array};

    /// Fresh local storage backend + registry rooted in a temp dir.
    /// The TempDir is returned so the directory outlives the test body.
    async fn make_test_env() -> (
        Arc<dyn StorageBackend>,
        Arc<RegistryManager>,
        tempfile::TempDir,
    ) {
        let dir = tempfile::tempdir().unwrap();
        let backend = LocalBackend::new(dir.path().to_path_buf()).await.unwrap();
        let storage: Arc<dyn StorageBackend> = Arc::new(backend);
        let registry = Arc::new(RegistryManager::new(Arc::clone(&storage)));
        let _ = registry.load_or_create().await.unwrap();
        (storage, registry, dir)
    }

    /// JSON schema used when registering the test frame.
    fn test_schema() -> serde_json::Value {
        serde_json::json!({
            "region": "string",
            "temp": "float64",
            "humidity": "int64"
        })
    }

    /// Registers `test_hive/test_box/sensors` (partitioned by `region`) and
    /// commits one batch of four rows across two regions (north/south).
    async fn setup_frame(storage: &Arc<dyn StorageBackend>, registry: &Arc<RegistryManager>) {
        registry.create_hive("test_hive").await.unwrap();
        registry.create_box("test_hive", "test_box").await.unwrap();
        registry
            .create_frame(
                "test_hive",
                "test_box",
                "sensors",
                test_schema(),
                vec!["region".into()],
            )
            .await
            .unwrap();

        // Typed schema for the ledger/writer, mirroring test_schema().
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
                FieldDef {
                    name: "humidity".into(),
                    data_type: "int64".into(),
                    nullable: true,
                },
            ],
        };

        let node_id = NodeId::new("test_node");
        let frame_path = "test_hive/test_box/sensors";

        let mut ledger = Ledger::create(
            Arc::clone(storage),
            frame_path,
            frame_schema.clone(),
            vec!["region".into()],
            &node_id,
        )
        .await
        .unwrap();

        // Large limits so the four test rows land in minimal cells.
        let sizing = CellSizingPolicy::new(256 * 1024 * 1024, 512 * 1024 * 1024, 16 * 1024 * 1024);

        let writer = CellWriter::new(
            Arc::clone(storage),
            frame_path.into(),
            frame_schema,
            vec!["region".into()],
            sizing,
        );

        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("temp", DataType::Float64, true),
            Field::new("humidity", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["north", "north", "south", "south"])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
                Arc::new(Int64Array::from(vec![50, 60, 70, 80])),
            ],
        )
        .unwrap();

        // Write cells, then commit them to the ledger so queries see them.
        let cells = writer.write(&batch).await.unwrap();
        ledger
            .commit(apiary_core::LedgerAction::AddCells { cells }, &node_id)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_select_all() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[tokio::test]
    async fn test_aggregation() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT region, AVG(temp) as avg_temp FROM test_hive.test_box.sensors GROUP BY region ORDER BY region")
            .await
            .unwrap();

        assert!(!results.is_empty());
        // One aggregated row per region (north, south).
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }

    #[tokio::test]
    async fn test_use_hive_and_box() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();
        // Bare frame name resolves via the current hive/box.
        let results = ctx.sql("SELECT * FROM sensors").await.unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[tokio::test]
    async fn test_show_hives() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW HIVES").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_show_frames() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx.sql("SHOW FRAMES IN test_hive.test_box").await.unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
    }

    #[tokio::test]
    async fn test_describe() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("DESCRIBE test_hive.test_box.sensors")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 3);
    }

    #[tokio::test]
    async fn test_delete_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx.sql("DELETE FROM test_hive.test_box.sensors").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    #[tokio::test]
    async fn test_update_blocked() {
        let (storage, registry, _dir) = make_test_env().await;
        let mut ctx = ApiaryQueryContext::new(storage, registry);

        let result = ctx
            .sql("UPDATE test_hive.test_box.sensors SET temp = 0")
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not supported"));
    }

    #[tokio::test]
    async fn test_where_filter() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT * FROM test_hive.test_box.sensors WHERE region = 'north'")
            .await
            .unwrap();

        // Two of the four rows are in region 'north'.
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 2);
    }

    #[tokio::test]
    async fn test_projection() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        let results = ctx
            .sql("SELECT temp FROM test_hive.test_box.sensors")
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].num_columns(), 1);
        assert_eq!(results[0].schema().field(0).name(), "temp");
    }

    #[test]
    fn test_extract_table_references() {
        let refs = extract_table_references("SELECT * FROM hive.box.frame WHERE x = 1");
        assert_eq!(refs, vec!["hive.box.frame"]);

        let refs = extract_table_references("SELECT * FROM frame1 JOIN frame2 ON x = y");
        assert_eq!(refs, vec!["frame1", "frame2"]);
    }

    #[test]
    fn test_extract_predicates() {
        let preds =
            extract_where_predicates("SELECT * FROM t WHERE region = 'north' AND temp > 25");
        assert_eq!(preds.len(), 2);
        assert_eq!(preds[0].column, "region");
        assert_eq!(preds[0].value, "north");
        assert_eq!(preds[1].column, "temp");
        assert_eq!(preds[1].value, "25");
    }

    #[test]
    fn test_check_unsupported_dml() {
        assert!(check_unsupported_dml("DELETE FROM t").is_some());
        assert!(check_unsupported_dml("UPDATE t SET x = 1").is_some());
        assert!(check_unsupported_dml("SELECT * FROM t").is_none());
    }

    #[tokio::test]
    async fn test_show_boxes_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();

        let results = ctx.sql("SHOW BOXES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "box");
    }

    #[tokio::test]
    async fn test_show_frames_without_qualifier() {
        let (storage, registry, _dir) = make_test_env().await;
        setup_frame(&storage, &registry).await;

        let mut ctx = ApiaryQueryContext::new(storage, registry);
        ctx.sql("USE HIVE test_hive").await.unwrap();
        ctx.sql("USE BOX test_box").await.unwrap();

        let results = ctx.sql("SHOW FRAMES").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].num_rows() >= 1);
        assert_eq!(results[0].schema().field(0).name(), "frame");
    }
}