Skip to main content

laminar_sql/translator/
having_translator.rs

1//! HAVING clause translator
2//!
3//! Translates a parsed HAVING expression into a configuration that can
4//! filter aggregated `RecordBatch` results. The filter is applied after
5//! window aggregation emission, before downstream consumption.
6
7use std::sync::Arc;
8
9use crate::datafusion::create_session_context;
10use arrow_array::RecordBatch;
11use arrow_schema::SchemaRef;
12use datafusion::datasource::MemTable;
13
/// Configuration for a post-aggregation HAVING filter.
///
/// Holds only the predicate text; the filtering itself happens in
/// [`HavingFilterConfig::filter_batch`], which evaluates the predicate
/// against an aggregated `RecordBatch`.
#[derive(Debug, Clone)]
pub struct HavingFilterConfig {
    /// The HAVING predicate as a SQL expression string.
    /// It is interpolated verbatim into a `WHERE` clause, so it must be a
    /// valid SQL boolean expression over the aggregated column names
    /// (e.g. `count_star > 10`).
    predicate: String,
}
20
21impl HavingFilterConfig {
22    /// Creates a new HAVING filter configuration.
23    #[must_use]
24    pub fn new(predicate: String) -> Self {
25        Self { predicate }
26    }
27
28    /// Returns the predicate SQL string.
29    #[must_use]
30    pub fn predicate(&self) -> &str {
31        &self.predicate
32    }
33
34    /// Filters a `RecordBatch` by evaluating the HAVING predicate.
35    ///
36    /// Registers the batch as a temporary table in a DataFusion context,
37    /// applies the predicate as a WHERE clause, and returns matching rows.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if the predicate cannot be parsed, the schema
42    /// doesn't match, or evaluation fails.
43    pub async fn filter_batch(
44        &self,
45        batch: &RecordBatch,
46    ) -> Result<RecordBatch, HavingFilterError> {
47        if batch.num_rows() == 0 {
48            return Ok(batch.clone());
49        }
50
51        let schema = batch.schema();
52
53        // Register the batch as a temporary MemTable so column names resolve
54        let ctx = create_session_context();
55        let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])
56            .map_err(|e| HavingFilterError::SchemaError(e.to_string()))?;
57        ctx.register_table("_having_input", Arc::new(mem_table))
58            .map_err(|e| HavingFilterError::SchemaError(e.to_string()))?;
59
60        // Execute: SELECT * FROM _having_input WHERE <predicate>
61        let sql = format!("SELECT * FROM _having_input WHERE {}", self.predicate);
62        let df = ctx
63            .sql(&sql)
64            .await
65            .map_err(|e| HavingFilterError::ParseError(e.to_string()))?;
66
67        let batches = df
68            .collect()
69            .await
70            .map_err(|e| HavingFilterError::EvaluationError(e.to_string()))?;
71
72        if batches.is_empty() {
73            return Ok(RecordBatch::new_empty(schema));
74        }
75
76        arrow::compute::concat_batches(&schema, &batches)
77            .map_err(|e| HavingFilterError::EvaluationError(e.to_string()))
78    }
79
80    /// Returns the output schema (same as input — HAVING only filters rows).
81    #[must_use]
82    pub fn output_schema(&self, input_schema: &SchemaRef) -> SchemaRef {
83        Arc::clone(input_schema)
84    }
85}
86
/// Errors from HAVING filter operations.
///
/// Each variant wraps a human-readable message stringified from the
/// underlying DataFusion/Arrow error.
#[derive(Debug, thiserror::Error)]
pub enum HavingFilterError {
    /// Failed to parse the HAVING expression (e.g. bad syntax or an
    /// unresolvable column reference at planning time)
    #[error("HAVING parse error: {0}")]
    ParseError(String),

    /// Schema mismatch or missing columns when building/registering the
    /// temporary input table
    #[error("HAVING schema error: {0}")]
    SchemaError(String),

    /// Evaluation failed while executing the query or concatenating results
    #[error("HAVING evaluation error: {0}")]
    EvaluationError(String),
}
102
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Four symbols with mixed counts/volumes, shared by the filter tests.
    fn test_batch() -> RecordBatch {
        let fields = vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("count_star", DataType::Int64, false),
            Field::new("sum_volume", DataType::Float64, false),
        ];
        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["AAPL", "GOOG", "MSFT", "TSLA"])),
            Arc::new(Int64Array::from(vec![100, 5, 50, 200])),
            Arc::new(Float64Array::from(vec![50000.0, 1000.0, 25000.0, 80000.0])),
        ];
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
    }

    #[tokio::test]
    async fn test_filter_count_greater_than() {
        let filter = HavingFilterConfig::new("count_star > 10".to_string());
        let kept = filter.filter_batch(&test_batch()).await.unwrap();

        // Survivors: AAPL(100), MSFT(50), TSLA(200).
        assert_eq!(kept.num_rows(), 3);
        let names = kept
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let symbols: Vec<&str> = names.iter().map(|v| v.unwrap()).collect();
        assert_eq!(symbols, vec!["AAPL", "MSFT", "TSLA"]);
    }

    #[tokio::test]
    async fn test_filter_all_pass() {
        // Trivially-true predicate keeps every row.
        let filter = HavingFilterConfig::new("count_star > 0".to_string());
        let kept = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(kept.num_rows(), 4);
    }

    #[tokio::test]
    async fn test_filter_none_pass() {
        // Impossible threshold drops every row.
        let filter = HavingFilterConfig::new("count_star > 1000".to_string());
        let kept = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(kept.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_filter_compound_predicate() {
        let filter =
            HavingFilterConfig::new("count_star >= 50 AND sum_volume > 30000".to_string());
        let kept = filter.filter_batch(&test_batch()).await.unwrap();

        // AAPL: 100 >= 50 AND 50000 > 30000 → yes
        // GOOG: 5 >= 50 → no
        // MSFT: 50 >= 50 AND 25000 > 30000 → no
        // TSLA: 200 >= 50 AND 80000 > 30000 → yes
        assert_eq!(kept.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_filter_or_predicate() {
        let filter =
            HavingFilterConfig::new("count_star > 150 OR sum_volume < 2000".to_string());
        let kept = filter.filter_batch(&test_batch()).await.unwrap();

        // AAPL: 100 > 150 OR 50000 < 2000 → no
        // GOOG: 5 > 150 OR 1000 < 2000 → yes
        // MSFT: 50 > 150 OR 25000 < 2000 → no
        // TSLA: 200 > 150 OR 80000 < 2000 → yes
        assert_eq!(kept.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_filter_empty_batch() {
        // Zero-row input short-circuits and stays empty.
        let filter = HavingFilterConfig::new("count_star > 10".to_string());
        let schema = Arc::new(Schema::new(vec![Field::new(
            "count_star",
            DataType::Int64,
            false,
        )]));
        let empty = RecordBatch::new_empty(schema);
        let kept = filter.filter_batch(&empty).await.unwrap();
        assert_eq!(kept.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_filter_invalid_column() {
        // Unknown column must surface as an error, not silently pass.
        let filter = HavingFilterConfig::new("nonexistent_col > 10".to_string());
        let outcome = filter.filter_batch(&test_batch()).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predicate_accessor() {
        let filter = HavingFilterConfig::new("SUM(volume) > 1000".to_string());
        assert_eq!(filter.predicate(), "SUM(volume) > 1000");
    }

    #[test]
    fn test_output_schema_unchanged() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let filter = HavingFilterConfig::new("a > 0".to_string());
        let produced = filter.output_schema(&schema);
        assert_eq!(*produced, *schema);
    }
}