elasticube_core/query.rs

//! Query API for ElastiCube
//!
//! Provides a fluent API for building and executing analytical queries
//! against ElastiCube data using Apache DataFusion.

use crate::cache::{QueryCache, QueryCacheKey};
use crate::cube::ElastiCube;
use crate::error::{Error, Result};
use crate::optimization::OptimizationConfig;
use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::*;
use std::sync::Arc;

/// Query builder for ElastiCube queries
///
/// Provides a fluent API for building and executing queries against a cube.
/// Supports both SQL queries and a DataFrame-style fluent API.
///
/// # Examples
///
/// ```rust,ignore
/// // SQL query
/// let results = cube.query()?
///     .sql("SELECT region, SUM(sales) FROM cube GROUP BY region")
///     .execute()
///     .await?;
///
/// // Fluent API query
/// let results = cube.query()?
///     .select(&["region", "SUM(sales) as total_sales"])
///     .filter("sales > 1000")
///     .group_by(&["region"])
///     .order_by(&["total_sales DESC"])
///     .limit(10)
///     .execute()
///     .await?;
/// ```
pub struct QueryBuilder {
    /// Reference to the parent cube
    cube: Arc<ElastiCube>,

    /// DataFusion SessionContext for query execution
    ctx: SessionContext,

    /// Optimization configuration
    #[allow(dead_code)] // Used for creating SessionContext, may be used in future features
    config: OptimizationConfig,

    /// Optional query cache
    cache: Option<Arc<QueryCache>>,

    /// Optional SQL query string (takes precedence over fluent API)
    sql_query: Option<String>,

    /// SELECT columns/expressions
    select_exprs: Vec<String>,

    /// WHERE filter condition
    filter_expr: Option<String>,

    /// GROUP BY columns
    group_by_exprs: Vec<String>,

    /// ORDER BY expressions
    order_by_exprs: Vec<String>,

    /// LIMIT clause
    limit_count: Option<usize>,

    /// OFFSET clause
    offset_count: Option<usize>,
}

impl QueryBuilder {
    /// Create a new query builder for the given cube
    pub(crate) fn new(cube: Arc<ElastiCube>) -> Result<Self> {
        Self::with_config(cube, OptimizationConfig::default())
    }

    /// Create a new query builder with custom optimization configuration
    pub(crate) fn with_config(cube: Arc<ElastiCube>, config: OptimizationConfig) -> Result<Self> {
        // Create SessionContext with optimization settings
        let session_config = config.to_session_config();
        let runtime_env = config.to_runtime_env();
        let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);

        // Create query cache if enabled
        let cache = if config.enable_query_cache {
            Some(Arc::new(QueryCache::new(config.max_cache_entries)))
        } else {
            None
        };

        Ok(Self {
            cube,
            ctx,
            config,
            cache,
            sql_query: None,
            select_exprs: Vec::new(),
            filter_expr: None,
            group_by_exprs: Vec::new(),
            order_by_exprs: Vec::new(),
            limit_count: None,
            offset_count: None,
        })
    }

    /// Execute a raw SQL query
    ///
    /// # Arguments
    /// * `query` - SQL query string (can reference the cube as "cube")
    ///
    /// # Example
    /// ```rust,ignore
    /// let results = cube.query()?
    ///     .sql("SELECT region, SUM(sales) as total FROM cube GROUP BY region")
    ///     .execute()
    ///     .await?;
    /// ```
    pub fn sql(mut self, query: impl Into<String>) -> Self {
        self.sql_query = Some(query.into());
        self
    }

    /// Select specific columns or expressions
    ///
    /// # Arguments
    /// * `columns` - Column names or SQL expressions
    ///
    /// # Example
    /// ```rust,ignore
    /// .select(&["region", "product", "SUM(sales) as total_sales"])
    /// ```
    pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.select_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

    /// Add a WHERE filter condition
    ///
    /// # Arguments
    /// * `condition` - SQL filter expression
    ///
    /// # Example
    /// ```rust,ignore
    /// .filter("sales > 1000 AND region = 'North'")
    /// ```
    pub fn filter(mut self, condition: impl Into<String>) -> Self {
        self.filter_expr = Some(condition.into());
        self
    }

    /// Add WHERE filter (alias for filter)
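    ///
    /// # Example
    /// ```rust,ignore
    /// .where_clause("region = 'North'")
    /// ```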
    pub fn where_clause(self, condition: impl Into<String>) -> Self {
        self.filter(condition)
    }

    /// Group by columns
    ///
    /// # Arguments
    /// * `columns` - Column names to group by
    ///
    /// # Example
    /// ```rust,ignore
    /// .group_by(&["region", "product"])
    /// ```
    pub fn group_by(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.group_by_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

    /// Order results by columns
    ///
    /// # Arguments
    /// * `columns` - Column names with optional ASC/DESC
    ///
    /// # Example
    /// ```rust,ignore
    /// .order_by(&["total_sales DESC", "region ASC"])
    /// ```
    pub fn order_by(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.order_by_exprs = columns.iter().map(|c| c.as_ref().to_string()).collect();
        self
    }

    /// Limit the number of results
    ///
    /// # Example
    /// ```rust,ignore
    /// .limit(100)
    /// ```
    pub fn limit(mut self, count: usize) -> Self {
        self.limit_count = Some(count);
        self
    }

    /// Skip a number of results
    ///
    /// # Example
    /// ```rust,ignore
    /// .offset(50)
    /// ```
    pub fn offset(mut self, count: usize) -> Self {
        self.offset_count = Some(count);
        self
    }

    /// OLAP Operation: Slice - filter on a single dimension
    ///
    /// # Example
    /// ```rust,ignore
    /// .slice("region", "North")
    /// ```
    pub fn slice(self, dimension: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        let condition = format!("{} = '{}'", dimension.as_ref(), value.as_ref());
        self.filter(condition)
    }

    /// OLAP Operation: Dice - filter on multiple dimensions
    ///
    /// # Example
    /// ```rust,ignore
    /// .dice(&[("region", "North"), ("product", "Widget")])
    /// ```
    pub fn dice(self, filters: &[(impl AsRef<str>, impl AsRef<str>)]) -> Self {
        let conditions: Vec<String> = filters
            .iter()
            .map(|(dim, val)| format!("{} = '{}'", dim.as_ref(), val.as_ref()))
            .collect();
        let combined = conditions.join(" AND ");
        self.filter(combined)
    }

    /// OLAP Operation: Drill-down - navigate down a hierarchy
    ///
    /// This selects data at a more granular level by including a lower-level dimension.
    ///
    /// # Example
    /// ```rust,ignore
    /// // Drill down from year to month
    /// .drill_down("year", &["year", "month"])
    /// ```
    pub fn drill_down(
        mut self,
        _parent_level: impl AsRef<str>,
        child_levels: &[impl AsRef<str>],
    ) -> Self {
        // Add child levels to GROUP BY
        self.group_by_exprs
            .extend(child_levels.iter().map(|c| c.as_ref().to_string()));
        self
    }

    /// OLAP Operation: Roll-up - aggregate across dimensions
    ///
    /// This aggregates data by removing one or more dimensions from grouping.
    ///
    /// # Example
    /// ```rust,ignore
    /// .roll_up(&["region"]) // Aggregate across all regions
    /// ```
    pub fn roll_up(mut self, dimensions_to_remove: &[impl AsRef<str>]) -> Self {
        let to_remove: Vec<String> = dimensions_to_remove
            .iter()
            .map(|d| d.as_ref().to_string())
            .collect();

        self.group_by_exprs
            .retain(|col| !to_remove.contains(col));
        self
    }

    /// Execute the query and return results
    ///
    /// # Returns
    /// A QueryResult containing the data and metadata
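    ///
    /// # Example
    /// ```rust,ignore
    /// let result = cube.query()?
    ///     .select(&["region", "SUM(sales) as total_sales"])
    ///     .group_by(&["region"])
    ///     .execute()
    ///     .await?;
    /// println!("{} rows returned", result.row_count());
    /// ```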
    pub async fn execute(mut self) -> Result<QueryResult> {
        // Build the query SQL string for caching
        let query_sql = if let Some(sql) = &self.sql_query {
            sql.clone()
        } else {
            self.build_sql_query()
        };

        // Check cache if enabled
        if let Some(cache) = &self.cache {
            let cache_key = QueryCacheKey::new(&query_sql);
            if let Some(cached_result) = cache.get(&cache_key) {
                return Ok(cached_result);
            }
        }

        // Register the cube data as a MemTable
        self.register_cube_data().await?;

        // Execute the query
        let dataframe = if let Some(sql) = &self.sql_query {
            // Execute raw SQL query
            self.execute_sql(sql).await?
        } else {
            // Build and execute fluent API query
            self.execute_fluent_query().await?
        };

        // Collect results
        let batches = dataframe
            .collect()
            .await
            .map_err(|e| Error::query(format!("Failed to collect query results: {}", e)))?;

        let row_count = batches.iter().map(|b| b.num_rows()).sum();

        let result = QueryResult {
            batches,
            row_count,
        };

        // Cache the result if caching is enabled
        if let Some(cache) = &self.cache {
            let cache_key = QueryCacheKey::new(&query_sql);
            cache.put(cache_key, result.clone());
        }

        Ok(result)
    }

    /// Register cube data as a DataFusion MemTable
    async fn register_cube_data(&mut self) -> Result<()> {
        let schema = self.cube.arrow_schema().clone();
        let data = self.cube.data().to_vec();

        // MemTable expects Vec<Vec<RecordBatch>> (partitions)
        // We'll use a single partition with all our batches
        let partitions = vec![data];

        let mem_table = MemTable::try_new(schema, partitions)
            .map_err(|e| Error::query(format!("Failed to create MemTable: {}", e)))?;

        self.ctx
            .register_table("cube", Arc::new(mem_table))
            .map_err(|e| Error::query(format!("Failed to register table: {}", e)))?;

        Ok(())
    }

    /// Execute a raw SQL query
    async fn execute_sql(&self, query: &str) -> Result<DataFrame> {
        self.ctx
            .sql(query)
            .await
            .map_err(|e| Error::query(format!("SQL execution failed: {}", e)))
    }

    /// Expand calculated fields in an expression
    ///
    /// Replaces references to calculated measures and virtual dimensions
    /// with their underlying expressions. Performs recursive expansion
    /// to handle nested calculated fields.
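    ///
    /// For example, assuming the schema defines a calculated measure `profit`
    /// with the expression `revenue - cost`, the input `"SUM(profit)"` expands
    /// to `"SUM((revenue - cost))"`.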
    fn expand_calculated_fields(&self, expr: &str) -> String {
        let mut expanded = expr.to_string();
        let schema = self.cube.schema();

        // Keep expanding until no more changes occur (handles nested calculated fields)
        // Use a maximum iteration count to prevent infinite loops
        const MAX_ITERATIONS: usize = 10;
        for _ in 0..MAX_ITERATIONS {
            let before = expanded.clone();

            // Expand virtual dimensions first (they can be used in calculated measures)
            for vdim in schema.virtual_dimensions() {
                let pattern = vdim.name();
                // Use word boundaries to avoid partial matches
                // e.g., don't replace "year" in "yearly_sales"
                let regex_pattern = format!(r"\b{}\b", regex::escape(pattern));
                if let Ok(re) = regex::Regex::new(&regex_pattern) {
                    let replacement = format!("({})", vdim.expression());
                    expanded = re.replace_all(&expanded, replacement.as_str()).to_string();
                }
            }

            // Expand calculated measures
            for calc_measure in schema.calculated_measures() {
                let pattern = calc_measure.name();
                let regex_pattern = format!(r"\b{}\b", regex::escape(pattern));
                if let Ok(re) = regex::Regex::new(&regex_pattern) {
                    let replacement = format!("({})", calc_measure.expression());
                    expanded = re.replace_all(&expanded, replacement.as_str()).to_string();
                }
            }

            // If no changes were made, we're done
            if expanded == before {
                break;
            }
        }

        expanded
    }

    /// Build SQL query string from fluent API parameters
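    ///
    /// For example, `.select(&["region", "SUM(sales) as total"]).filter("sales > 100")
    /// .group_by(&["region"]).limit(10)` produces
    /// `SELECT region, SUM(sales) as total FROM cube WHERE sales > 100 GROUP BY region LIMIT 10`
    /// (assuming none of the referenced names are calculated fields, which would be expanded).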
    fn build_sql_query(&self) -> String {
        let mut query_str = String::from("SELECT ");

        // SELECT clause - expand calculated fields
        if self.select_exprs.is_empty() {
            query_str.push('*');
        } else {
            let expanded_selects: Vec<String> = self
                .select_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_selects.join(", "));
        }

        query_str.push_str(" FROM cube");

        // WHERE clause - expand calculated fields
        if let Some(filter) = &self.filter_expr {
            query_str.push_str(" WHERE ");
            let expanded_filter = self.expand_calculated_fields(filter);
            query_str.push_str(&expanded_filter);
        }

        // GROUP BY clause - expand calculated fields
        if !self.group_by_exprs.is_empty() {
            query_str.push_str(" GROUP BY ");
            let expanded_groups: Vec<String> = self
                .group_by_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_groups.join(", "));
        }

        // ORDER BY clause - expand calculated fields
        if !self.order_by_exprs.is_empty() {
            query_str.push_str(" ORDER BY ");
            let expanded_orders: Vec<String> = self
                .order_by_exprs
                .iter()
                .map(|expr| self.expand_calculated_fields(expr))
                .collect();
            query_str.push_str(&expanded_orders.join(", "));
        }

        // LIMIT clause
        if let Some(limit) = self.limit_count {
            query_str.push_str(&format!(" LIMIT {}", limit));
        }

        // OFFSET clause
        if let Some(offset) = self.offset_count {
            query_str.push_str(&format!(" OFFSET {}", offset));
        }

        query_str
    }

    /// Build and execute a fluent API query
    async fn execute_fluent_query(&self) -> Result<DataFrame> {
        let query_str = self.build_sql_query();
        self.execute_sql(&query_str).await
    }
}

/// Query result containing the executed query data
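///
/// # Example
/// ```rust,ignore
/// let result = cube.query()?.sql("SELECT * FROM cube").execute().await?;
/// for batch in result.batches() {
///     println!("batch with {} rows", batch.num_rows());
/// }
/// ```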
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Result data as Arrow RecordBatches
    batches: Vec<RecordBatch>,

    /// Total number of rows in the result
    row_count: usize,
}

impl QueryResult {
    /// Create a new QueryResult (for testing purposes)
    #[cfg(test)]
    pub(crate) fn new_for_testing(batches: Vec<RecordBatch>, row_count: usize) -> Self {
        Self {
            batches,
            row_count,
        }
    }

    /// Get the result batches
    pub fn batches(&self) -> &[RecordBatch] {
        &self.batches
    }

    /// Get the total number of rows
    pub fn row_count(&self) -> usize {
        self.row_count
    }

    /// Check if the result is empty
    pub fn is_empty(&self) -> bool {
        self.row_count == 0
    }

    /// Get a pretty-printed string representation of the results
    ///
    /// Useful for debugging and testing
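    ///
    /// # Example
    /// ```rust,ignore
    /// println!("{}", result.pretty_print()?);
    /// ```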
    pub fn pretty_print(&self) -> Result<String> {
        use arrow::util::pretty::pretty_format_batches;

        pretty_format_batches(&self.batches)
            .map(|display| display.to_string())
            .map_err(|e| Error::query(format!("Failed to format results: {}", e)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::ElastiCubeBuilder;
    use crate::cube::AggFunc;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};

    fn create_test_cube() -> Result<ElastiCube> {
        // Create test data
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("product", DataType::Utf8, false),
            Field::new("sales", DataType::Float64, false),
            Field::new("quantity", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![
                    "North", "South", "North", "East", "South",
                ])),
                Arc::new(StringArray::from(vec![
                    "Widget", "Widget", "Gadget", "Widget", "Gadget",
                ])),
                Arc::new(Float64Array::from(vec![100.0, 200.0, 150.0, 175.0, 225.0])),
                Arc::new(Int32Array::from(vec![10, 20, 15, 17, 22])),
            ],
        )
        .unwrap();

        ElastiCubeBuilder::new("test_cube")
            .add_dimension("region", DataType::Utf8)?
            .add_dimension("product", DataType::Utf8)?
            .add_measure("sales", DataType::Float64, AggFunc::Sum)?
            .add_measure("quantity", DataType::Int32, AggFunc::Sum)?
            .load_record_batches(schema, vec![batch])?
            .build()
    }

    #[tokio::test]
    async fn test_query_select_all() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube.query().unwrap().execute().await.unwrap();

        assert_eq!(result.row_count(), 5);
        assert_eq!(result.batches().len(), 1);
    }

    #[tokio::test]
    async fn test_query_select_columns() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "sales"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 5);
        // Check that we only got 2 columns
        assert_eq!(result.batches()[0].num_columns(), 2);
    }

    #[tokio::test]
    async fn test_query_filter() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .filter("sales > 150")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3); // 200, 175, 225
    }

    #[tokio::test]
    async fn test_query_group_by() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "SUM(sales) as total_sales"])
            .group_by(&["region"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3); // North, South, East
    }

    #[tokio::test]
    async fn test_query_order_by() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "sales"])
            .order_by(&["sales DESC"])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 5);
        // First row should have highest sales (225)
    }

    #[tokio::test]
    async fn test_query_limit() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .limit(3)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_query_sql() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .sql("SELECT region, SUM(sales) as total FROM cube GROUP BY region ORDER BY total DESC")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 3);
    }

    #[tokio::test]
    async fn test_olap_slice() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .slice("region", "North")
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 2); // 2 North entries
    }

    #[tokio::test]
    async fn test_olap_dice() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .dice(&[("region", "North"), ("product", "Widget")])
            .execute()
            .await
            .unwrap();

        assert_eq!(result.row_count(), 1); // 1 North Widget
    }

    #[tokio::test]
    async fn test_complex_query() {
        let cube = create_test_cube().unwrap();
        let arc_cube = Arc::new(cube);

        let result = arc_cube
            .query()
            .unwrap()
            .select(&["region", "product", "SUM(sales) as total_sales", "AVG(quantity) as avg_qty"])
            .filter("sales > 100")
            .group_by(&["region", "product"])
            .order_by(&["total_sales DESC"])
            .limit(5)
            .execute()
            .await
            .unwrap();

        assert!(result.row_count() > 0);
    }
}