use crate::cache::{QueryCache, QueryCacheKey};
use crate::cube::ElastiCube;
use crate::error::{Error, Result};
use crate::optimization::OptimizationConfig;
use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::*;
use std::sync::Arc;

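/// Builds and executes queries against an [`ElastiCube`].
///
/// A query can be given as raw SQL via [`QueryBuilder::sql`], or assembled
/// with the fluent methods (`select`, `filter`, `group_by`, `order_by`,
/// `limit`, ...), which are compiled to SQL against a table registered under
/// the name `cube`. A minimal sketch of fluent usage, mirroring the tests at
/// the end of this module:
///
/// ```ignore
/// let result = cube
///     .query()?
///     .select(&["region", "SUM(sales) as total_sales"])
///     .group_by(&["region"])
///     .execute()
///     .await?;
/// println!("{}", result.pretty_print()?);
/// ```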
pub struct QueryBuilder {
    /// The cube being queried
    cube: Arc<ElastiCube>,

    /// DataFusion session context
    ctx: SessionContext,

    /// Optimization settings used to configure the session
    #[allow(dead_code)]
    config: OptimizationConfig,

    /// Optional result cache, keyed by the generated SQL text
    cache: Option<Arc<QueryCache>>,

    /// Raw SQL query (takes precedence over the fluent clauses)
    sql_query: Option<String>,

    /// SELECT expressions
    select_exprs: Vec<String>,

    /// WHERE condition
    filter_expr: Option<String>,

    /// GROUP BY expressions
    group_by_exprs: Vec<String>,

    /// ORDER BY expressions
    order_by_exprs: Vec<String>,

    /// LIMIT count
    limit_count: Option<usize>,

    /// OFFSET count
    offset_count: Option<usize>,
}

impl QueryBuilder {
    /// Creates a query builder with the default [`OptimizationConfig`]
    pub(crate) fn new(cube: Arc<ElastiCube>) -> Result<Self> {
        Self::with_config(cube, OptimizationConfig::default())
    }

    /// Creates a query builder with an explicit optimization configuration
    pub(crate) fn with_config(cube: Arc<ElastiCube>, config: OptimizationConfig) -> Result<Self> {
        // Build the DataFusion session from the optimization settings
        let session_config = config.to_session_config();
        let runtime_env = config.to_runtime_env();
        let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);

        // Set up the result cache if enabled
        let cache = if config.enable_query_cache {
            Some(Arc::new(QueryCache::new(config.max_cache_entries)))
        } else {
            None
        };

        Ok(Self {
            cube,
            ctx,
            config,
            cache,
            sql_query: None,
            select_exprs: Vec::new(),
            filter_expr: None,
            group_by_exprs: Vec::new(),
            order_by_exprs: Vec::new(),
            limit_count: None,
            offset_count: None,
        })
    }

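    /// Sets a raw SQL query to run against the cube's data.
    ///
    /// The data is registered as a table named `cube`, so queries take the
    /// form `SELECT ... FROM cube ...`. A raw query takes precedence over any
    /// fluent clauses set on this builder.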
    pub fn sql(mut self, query: impl Into<String>) -> Self {
        self.sql_query = Some(query.into());
        self
    }

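    /// Sets the columns or expressions to select.
    ///
    /// Calculated measures and virtual dimensions are expanded to their
    /// underlying expressions before the SQL is generated.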
    pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.select_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

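    /// Sets a `WHERE` condition, e.g. `"sales > 150"`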
    pub fn filter(mut self, condition: impl Into<String>) -> Self {
        self.filter_expr = Some(condition.into());
        self
    }

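    /// Alias for [`QueryBuilder::filter`]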
    pub fn where_clause(self, condition: impl Into<String>) -> Self {
        self.filter(condition)
    }

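    /// Sets the `GROUP BY` columns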
    pub fn group_by(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.group_by_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

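    /// Sets the `ORDER BY` expressions, e.g. `"sales DESC"`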
    pub fn order_by(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.order_by_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

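    /// Limits the number of rows returned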
    pub fn limit(mut self, count: usize) -> Self {
        self.limit_count = Some(count);
        self
    }

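    /// Skips the first `count` rows of the result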
    pub fn offset(mut self, count: usize) -> Self {
        self.offset_count = Some(count);
        self
    }

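    /// OLAP slice: filters the cube to a single value of one dimension.
    ///
    /// Note that the value is interpolated directly into the generated SQL as
    /// a quoted string literal and is not escaped.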
    pub fn slice(self, dimension: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        let condition = format!("{} = '{}'", dimension.as_ref(), value.as_ref());
        self.filter(condition)
    }

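    /// OLAP dice: filters the cube on several dimension/value pairs combined
    /// with `AND`. As with [`QueryBuilder::slice`], values are interpolated
    /// into the SQL without escaping.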
    pub fn dice(self, filters: &[(impl AsRef<str>, impl AsRef<str>)]) -> Self {
        let conditions: Vec<String> = filters
            .iter()
            .map(|(dim, val)| format!("{} = '{}'", dim.as_ref(), val.as_ref()))
            .collect();
        let combined = conditions.join(" AND ");
        self.filter(combined)
    }

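    /// OLAP drill-down: adds finer-grained levels to the `GROUP BY` clause.
    ///
    /// The parent level is accepted but currently unused.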
    pub fn drill_down(
        mut self,
        _parent_level: impl AsRef<str>,
        child_levels: &[impl AsRef<str>],
    ) -> Self {
        self.group_by_exprs
            .extend(child_levels.iter().map(|c| c.as_ref().to_string()));
        self
    }

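    /// OLAP roll-up: removes the given dimensions from the `GROUP BY` clause,
    /// aggregating to a coarser level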
    pub fn roll_up(mut self, dimensions_to_remove: &[impl AsRef<str>]) -> Self {
        let to_remove: Vec<String> = dimensions_to_remove
            .iter()
            .map(|d| d.as_ref().to_string())
            .collect();

        self.group_by_exprs.retain(|col| !to_remove.contains(col));
        self
    }

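    /// Executes the query and collects the results.
    ///
    /// Uses the raw SQL query if one was set, otherwise builds SQL from the
    /// fluent clauses. When caching is enabled, results are looked up and
    /// stored by the generated SQL text.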
    pub async fn execute(mut self) -> Result<QueryResult> {
        // Determine the SQL to execute (also used as the cache key)
        let query_sql = if let Some(sql) = &self.sql_query {
            sql.clone()
        } else {
            self.build_sql_query()
        };

        // Return a cached result if one exists
        if let Some(cache) = &self.cache {
            let cache_key = QueryCacheKey::new(&query_sql);
            if let Some(cached_result) = cache.get(&cache_key) {
                return Ok(cached_result);
            }
        }

        self.register_cube_data().await?;

        let dataframe = if let Some(sql) = &self.sql_query {
            self.execute_sql(sql).await?
        } else {
            self.execute_fluent_query().await?
        };

        let batches = dataframe
            .collect()
            .await
            .map_err(|e| Error::query(format!("Failed to collect query results: {}", e)))?;

        let row_count = batches.iter().map(|b| b.num_rows()).sum();

        let result = QueryResult { batches, row_count };

        // Store the result for subsequent identical queries
        if let Some(cache) = &self.cache {
            let cache_key = QueryCacheKey::new(&query_sql);
            cache.put(cache_key, result.clone());
        }

        Ok(result)
    }

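    /// Registers the cube's record batches as a DataFusion `MemTable` named `cube`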
    async fn register_cube_data(&mut self) -> Result<()> {
        let schema = self.cube.arrow_schema().clone();
        let data = self.cube.data().to_vec();

        // A single partition containing all of the cube's batches
        let partitions = vec![data];

        let mem_table = MemTable::try_new(schema, partitions)
            .map_err(|e| Error::query(format!("Failed to create MemTable: {}", e)))?;

        self.ctx
            .register_table("cube", Arc::new(mem_table))
            .map_err(|e| Error::query(format!("Failed to register table: {}", e)))?;

        Ok(())
    }

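    /// Runs a SQL string through the DataFusion session context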
    async fn execute_sql(&self, query: &str) -> Result<DataFrame> {
        self.ctx
            .sql(query)
            .await
            .map_err(|e| Error::query(format!("SQL execution failed: {}", e)))
    }

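    /// Replaces references to virtual dimensions and calculated measures with
    /// their underlying expressions, wrapped in parentheses.
    ///
    /// Names are matched on word boundaries, and expansion repeats (up to
    /// `MAX_ITERATIONS` passes) so that fields defined in terms of other
    /// calculated fields are fully resolved; iteration stops once a pass
    /// makes no further changes.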
    fn expand_calculated_fields(&self, expr: &str) -> String {
        let mut expanded = expr.to_string();
        let schema = self.cube.schema();

        const MAX_ITERATIONS: usize = 10;
        for _ in 0..MAX_ITERATIONS {
            let before = expanded.clone();

            for vdim in schema.virtual_dimensions() {
                let pattern = vdim.name();
                // Match the field name only on word boundaries
                let regex_pattern = format!(r"\b{}\b", regex::escape(pattern));
                if let Ok(re) = regex::Regex::new(&regex_pattern) {
                    let replacement = format!("({})", vdim.expression());
                    expanded = re.replace_all(&expanded, replacement.as_str()).to_string();
                }
            }

            for calc_measure in schema.calculated_measures() {
                let pattern = calc_measure.name();
                let regex_pattern = format!(r"\b{}\b", regex::escape(pattern));
                if let Ok(re) = regex::Regex::new(&regex_pattern) {
                    let replacement = format!("({})", calc_measure.expression());
                    expanded = re.replace_all(&expanded, replacement.as_str()).to_string();
                }
            }

            // A pass with no changes means expansion has converged
            if expanded == before {
                break;
            }
        }

        expanded
    }

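    /// Builds a SQL string from the fluent clauses, expanding calculated
    /// fields in the SELECT, WHERE, GROUP BY, and ORDER BY clauses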
    fn build_sql_query(&self) -> String {
        let mut query_str = String::from("SELECT ");

        // SELECT clause (defaults to *)
        if self.select_exprs.is_empty() {
            query_str.push('*');
        } else {
            let expanded_selects: Vec<String> = self
                .select_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_selects.join(", "));
        }

        query_str.push_str(" FROM cube");

        // WHERE clause
        if let Some(filter) = &self.filter_expr {
            query_str.push_str(" WHERE ");
            let expanded_filter = self.expand_calculated_fields(filter);
            query_str.push_str(&expanded_filter);
        }

        // GROUP BY clause
        if !self.group_by_exprs.is_empty() {
            query_str.push_str(" GROUP BY ");
            let expanded_groups: Vec<String> = self
                .group_by_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_groups.join(", "));
        }

        // ORDER BY clause
        if !self.order_by_exprs.is_empty() {
            query_str.push_str(" ORDER BY ");
            let expanded_orders: Vec<String> = self
                .order_by_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_orders.join(", "));
        }

        // LIMIT / OFFSET
        if let Some(limit) = self.limit_count {
            query_str.push_str(&format!(" LIMIT {}", limit));
        }

        if let Some(offset) = self.offset_count {
            query_str.push_str(&format!(" OFFSET {}", offset));
        }

        query_str
    }

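    /// Builds SQL from the fluent clauses and executes it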
    async fn execute_fluent_query(&self) -> Result<DataFrame> {
        let query_str = self.build_sql_query();
        self.execute_sql(&query_str).await
    }
}

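/// The result of an executed query: Arrow record batches plus a row count.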
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Result data as Arrow record batches
    batches: Vec<RecordBatch>,

    /// Total number of rows across all batches
    row_count: usize,
}

impl QueryResult {
    /// Creates a result directly from batches; used by unit tests
    #[cfg(test)]
    pub(crate) fn new_for_testing(batches: Vec<RecordBatch>, row_count: usize) -> Self {
        Self { batches, row_count }
    }

    /// Returns the result batches
    pub fn batches(&self) -> &[RecordBatch] {
        &self.batches
    }

    /// Returns the total number of rows
    pub fn row_count(&self) -> usize {
        self.row_count
    }

    /// Returns `true` if the result contains no rows
    pub fn is_empty(&self) -> bool {
        self.row_count == 0
    }

    /// Formats the batches as a human-readable table
    pub fn pretty_print(&self) -> Result<String> {
        use arrow::util::pretty::pretty_format_batches;

        pretty_format_batches(&self.batches)
            .map(|display| display.to_string())
            .map_err(|e| Error::query(format!("Failed to format results: {}", e)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::ElastiCubeBuilder;
    use crate::cube::AggFunc;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};

    fn create_test_cube() -> Result<ElastiCube> {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("product", DataType::Utf8, false),
            Field::new("sales", DataType::Float64, false),
            Field::new("quantity", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![
                    "North", "South", "North", "East", "South",
                ])),
                Arc::new(StringArray::from(vec![
                    "Widget", "Widget", "Gadget", "Widget", "Gadget",
                ])),
                Arc::new(Float64Array::from(vec![100.0, 200.0, 150.0, 175.0, 225.0])),
                Arc::new(Int32Array::from(vec![10, 20, 15, 17, 22])),
            ],
        )
        .unwrap();

        ElastiCubeBuilder::new("test_cube")
            .add_dimension("region", DataType::Utf8)?
            .add_dimension("product", DataType::Utf8)?
            .add_measure("sales", DataType::Float64, AggFunc::Sum)?
            .add_measure("quantity", DataType::Int32, AggFunc::Sum)?
            .load_record_batches(schema, vec![batch])?
            .build()
    }

    #[tokio::test]
    async fn test_query_select_all() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube.query().unwrap().execute().await.unwrap();

        assert_eq!(result.row_count(), 5);
        assert_eq!(result.batches().len(), 1);
    }

    #[tokio::test]
    async fn test_query_select_columns() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "sales"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 5);
        assert_eq!(result.batches()[0].num_columns(), 2);
    }

    #[tokio::test]
    async fn test_query_filter() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .filter("sales > 150")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_query_group_by() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "SUM(sales) as total_sales"])
            .group_by(&["region"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_query_order_by() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "sales"])
            .order_by(&["sales DESC"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 5);
    }

    #[tokio::test]
    async fn test_query_limit() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .limit(3)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_query_sql() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .sql("SELECT region, SUM(sales) as total FROM cube GROUP BY region ORDER BY total DESC")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_olap_slice() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .slice("region", "North")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 2);
    }

    #[tokio::test]
    async fn test_olap_dice() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .dice(&[("region", "North"), ("product", "Widget")])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 1);
    }

    #[tokio::test]
    async fn test_complex_query() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&[
                "region",
                "product",
                "SUM(sales) as total_sales",
                "AVG(quantity) as avg_qty",
            ])
            .filter("sales > 100")
            .group_by(&["region", "product"])
            .order_by(&["total_sales DESC"])
            .limit(5)
            .execute()
            .await
            .unwrap();

        assert!(result.row_count() > 0);
    }
}