laminar_sql/translator/
having_translator.rs1use std::sync::Arc;
8
9use crate::datafusion::create_session_context;
10use arrow_array::RecordBatch;
11use arrow_schema::SchemaRef;
12use datafusion::datasource::MemTable;
13
/// Configuration for applying a SQL `HAVING` predicate to aggregated record batches.
#[derive(Debug, Clone)]
pub struct HavingFilterConfig {
    // Raw SQL predicate text, e.g. `"count_star > 10"`; interpolated into a
    // WHERE clause by `filter_batch`.
    predicate: String,
}
20
21impl HavingFilterConfig {
22 #[must_use]
24 pub fn new(predicate: String) -> Self {
25 Self { predicate }
26 }
27
28 #[must_use]
30 pub fn predicate(&self) -> &str {
31 &self.predicate
32 }
33
34 pub async fn filter_batch(
44 &self,
45 batch: &RecordBatch,
46 ) -> Result<RecordBatch, HavingFilterError> {
47 if batch.num_rows() == 0 {
48 return Ok(batch.clone());
49 }
50
51 let schema = batch.schema();
52
53 let ctx = create_session_context();
55 let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])
56 .map_err(|e| HavingFilterError::SchemaError(e.to_string()))?;
57 ctx.register_table("_having_input", Arc::new(mem_table))
58 .map_err(|e| HavingFilterError::SchemaError(e.to_string()))?;
59
60 let sql = format!("SELECT * FROM _having_input WHERE {}", self.predicate);
62 let df = ctx
63 .sql(&sql)
64 .await
65 .map_err(|e| HavingFilterError::ParseError(e.to_string()))?;
66
67 let batches = df
68 .collect()
69 .await
70 .map_err(|e| HavingFilterError::EvaluationError(e.to_string()))?;
71
72 if batches.is_empty() {
73 return Ok(RecordBatch::new_empty(schema));
74 }
75
76 arrow::compute::concat_batches(&schema, &batches)
77 .map_err(|e| HavingFilterError::EvaluationError(e.to_string()))
78 }
79
80 #[must_use]
82 pub fn output_schema(&self, input_schema: &SchemaRef) -> SchemaRef {
83 Arc::clone(input_schema)
84 }
85}
86
87#[derive(Debug, thiserror::Error)]
89pub enum HavingFilterError {
90 #[error("HAVING parse error: {0}")]
92 ParseError(String),
93
94 #[error("HAVING schema error: {0}")]
96 SchemaError(String),
97
98 #[error("HAVING evaluation error: {0}")]
100 EvaluationError(String),
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Four symbols with the aggregate columns a HAVING filter would see
    /// downstream of a GROUP BY.
    fn test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("count_star", DataType::Int64, false),
            Field::new("sum_volume", DataType::Float64, false),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["AAPL", "GOOG", "MSFT", "TSLA"])),
                Arc::new(Int64Array::from(vec![100, 5, 50, 200])),
                Arc::new(Float64Array::from(vec![50000.0, 1000.0, 25000.0, 80000.0])),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_filter_count_greater_than() {
        let filter = HavingFilterConfig::new("count_star > 10".to_string());
        let filtered = filter.filter_batch(&test_batch()).await.unwrap();

        assert_eq!(filtered.num_rows(), 3);
        let symbols: Vec<&str> = filtered
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .iter()
            .map(Option::unwrap)
            .collect();
        assert_eq!(symbols, vec!["AAPL", "MSFT", "TSLA"]);
    }

    #[tokio::test]
    async fn test_filter_all_pass() {
        let filter = HavingFilterConfig::new("count_star > 0".to_string());
        let filtered = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(filtered.num_rows(), 4);
    }

    #[tokio::test]
    async fn test_filter_none_pass() {
        let filter = HavingFilterConfig::new("count_star > 1000".to_string());
        let filtered = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(filtered.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_filter_compound_predicate() {
        let filter =
            HavingFilterConfig::new("count_star >= 50 AND sum_volume > 30000".to_string());
        let filtered = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(filtered.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_filter_or_predicate() {
        let filter =
            HavingFilterConfig::new("count_star > 150 OR sum_volume < 2000".to_string());
        let filtered = filter.filter_batch(&test_batch()).await.unwrap();
        assert_eq!(filtered.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_filter_empty_batch() {
        let filter = HavingFilterConfig::new("count_star > 10".to_string());
        let schema = Arc::new(Schema::new(vec![Field::new(
            "count_star",
            DataType::Int64,
            false,
        )]));
        let empty = RecordBatch::new_empty(schema);
        let filtered = filter.filter_batch(&empty).await.unwrap();
        assert_eq!(filtered.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_filter_invalid_column() {
        let filter = HavingFilterConfig::new("nonexistent_col > 10".to_string());
        let outcome = filter.filter_batch(&test_batch()).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predicate_accessor() {
        let filter = HavingFilterConfig::new("SUM(volume) > 1000".to_string());
        assert_eq!(filter.predicate(), "SUM(volume) > 1000");
    }

    #[test]
    fn test_output_schema_unchanged() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let filter = HavingFilterConfig::new("a > 0".to_string());
        let produced = filter.output_schema(&schema);
        assert_eq!(*produced, *schema);
    }
}