//! ClickHouse SQL dialect implementation.
//! Source path: activecube_rs/sql/clickhouse.rs

1use std::collections::HashMap;
2
3use crate::compiler::ir::*;
4use crate::compiler::ir::CompileResult;
5use crate::sql::dialect::SqlDialect;
6
/// SQL dialect targeting ClickHouse.
///
/// Zero-sized marker type; all behaviour lives in the `SqlDialect` impl.
pub struct ClickHouseDialect;
8
impl ClickHouseDialect {
    /// Create a new dialect instance (the type carries no state).
    pub fn new() -> Self {
        Self
    }
}
14
impl Default for ClickHouseDialect {
    // Delegates to `new()` so both construction paths stay in sync.
    fn default() -> Self {
        Self::new()
    }
}
20
21impl SqlDialect for ClickHouseDialect {
22    fn compile(&self, ir: &QueryIR) -> CompileResult {
23        let mut bindings = Vec::new();
24        let mut sql = String::new();
25
26        // Build augmented selects: aggregate expressions used in HAVING need
27        // aliases in SELECT so ClickHouse can resolve them (especially -Merge combinators).
28        let mut augmented_selects = ir.selects.clone();
29        let mut agg_alias_map: HashMap<String, String> = HashMap::new();
30        let mut alias_remap: Vec<(String, String)> = Vec::new();
31        let mut alias_counter = 0u32;
32
33        let having_cols: std::collections::HashSet<String> =
34            collect_filter_columns(&ir.having).into_iter().collect();
35        let has_having_agg = having_cols.iter().any(|c| c.contains('('));
36
37        if has_having_agg {
38            // Only alias columns that are BOTH aggregate AND referenced in HAVING.
39            for sel in &mut augmented_selects {
40                if let SelectExpr::Column { column, alias } = sel {
41                    if column.contains('(') && having_cols.contains(column.as_str()) {
42                        if alias.is_none() {
43                            let a = format!("__f_{alias_counter}");
44                            alias_counter += 1;
45                            alias_remap.push((a.clone(), column.clone()));
46                            agg_alias_map.insert(column.clone(), a.clone());
47                            *alias = Some(a);
48                        } else if let Some(existing) = alias {
49                            agg_alias_map.insert(column.clone(), existing.clone());
50                        }
51                    }
52                }
53            }
54            // Add missing HAVING columns not in SELECT
55            for col in &having_cols {
56                if col.contains('(') && !agg_alias_map.contains_key(col.as_str()) {
57                    let a = format!("__f_{alias_counter}");
58                    alias_counter += 1;
59                    agg_alias_map.insert(col.clone(), a.clone());
60                    augmented_selects.push(SelectExpr::Column {
61                        column: col.clone(),
62                        alias: Some(a),
63                    });
64                    // No remap needed: this column wasn't requested by user
65                }
66            }
67        }
68
69        sql.push_str("SELECT ");
70        let select_parts: Vec<String> = augmented_selects.iter().map(|s| match s {
71            SelectExpr::Column { column, alias } => {
72                let col = if column.contains('(') { column.clone() } else { format!("`{column}`") };
73                match alias {
74                    Some(a) => format!("{col} AS `{a}`"),
75                    None => col,
76                }
77            },
78            SelectExpr::Aggregate { function, column, alias, condition } => {
79                let func = function.to_uppercase();
80                match (func.as_str(), column.as_str(), condition) {
81                    ("COUNT", "*", None) => format!("count() AS `{alias}`"),
82                    ("COUNT", "*", Some(cond)) => format!("countIf({cond}) AS `{alias}`"),
83                    ("UNIQ", col, None) => format!("uniq(`{col}`) AS `{alias}`"),
84                    ("UNIQ", col, Some(cond)) => format!("uniqIf(`{col}`, {cond}) AS `{alias}`"),
85                    (_, col, None) => format!("{f}(`{col}`) AS `{alias}`", f = func.to_lowercase()),
86                    (_, col, Some(cond)) => format!("{f}If(`{col}`, {cond}) AS `{alias}`", f = func.to_lowercase()),
87                }
88            }
89        }).collect();
90        sql.push_str(&select_parts.join(", "));
91
92        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
93        if ir.use_final {
94            sql.push_str(" FINAL");
95        }
96
97        let where_clause = compile_filter(&ir.filters, &mut bindings);
98        if !where_clause.is_empty() {
99            sql.push_str(" WHERE ");
100            sql.push_str(&where_clause);
101        }
102
103        // Auto-detect AggregatingMergeTree queries: if any SELECT column contains
104        // a -Merge combinator (e.g. argMaxMerge, sumMerge), auto-add GROUP BY
105        // for all non-aggregate columns.
106        let effective_group_by = if !ir.group_by.is_empty() {
107            ir.group_by.clone()
108        } else {
109            let has_merge_cols = augmented_selects.iter().any(|s| match s {
110                SelectExpr::Column { column, .. } => column.contains("Merge("),
111                SelectExpr::Aggregate { .. } => true,
112            });
113            if has_merge_cols {
114                augmented_selects.iter().filter_map(|s| match s {
115                    SelectExpr::Column { column, .. } if !column.contains("Merge(") && !column.contains('(') => {
116                        Some(column.clone())
117                    }
118                    _ => None,
119                }).collect()
120            } else {
121                vec![]
122            }
123        };
124
125        if !effective_group_by.is_empty() {
126            sql.push_str(" GROUP BY ");
127            let cols: Vec<String> = effective_group_by.iter().map(|c| format!("`{c}`")).collect();
128            sql.push_str(&cols.join(", "));
129        }
130
131        if has_having_agg {
132            let having_clause = compile_filter_with_aliases(&ir.having, &mut bindings, &agg_alias_map);
133            if !having_clause.is_empty() {
134                sql.push_str(" HAVING ");
135                sql.push_str(&having_clause);
136            }
137        } else {
138            let having_clause = compile_filter(&ir.having, &mut bindings);
139            if !having_clause.is_empty() {
140                sql.push_str(" HAVING ");
141                sql.push_str(&having_clause);
142            }
143        }
144
145        if !ir.order_by.is_empty() {
146            sql.push_str(" ORDER BY ");
147            let parts: Vec<String> = ir.order_by.iter().map(|o| {
148                let col = if o.column.contains('(') {
149                    agg_alias_map.get(&o.column)
150                        .map(|a| format!("`{a}`"))
151                        .unwrap_or_else(|| o.column.clone())
152                } else {
153                    format!("`{}`", o.column)
154                };
155                let dir = if o.descending { "DESC" } else { "ASC" };
156                format!("{col} {dir}")
157            }).collect();
158            sql.push_str(&parts.join(", "));
159        }
160
161        if let Some(ref lb) = ir.limit_by {
162            let by_cols: Vec<String> = lb.columns.iter().map(|c| format!("`{c}`")).collect();
163            sql.push_str(&format!(" LIMIT {} BY {}", lb.count, by_cols.join(", ")));
164            if lb.offset > 0 {
165                sql.push_str(&format!(" OFFSET {}", lb.offset));
166            }
167        }
168
169        sql.push_str(&format!(" LIMIT {}", ir.limit));
170        if ir.offset > 0 {
171            sql.push_str(&format!(" OFFSET {}", ir.offset));
172        }
173
174        CompileResult { sql, bindings, alias_remap }
175    }
176
177    fn quote_identifier(&self, name: &str) -> String {
178        format!("`{name}`")
179    }
180
181    fn name(&self) -> &str {
182        "ClickHouse"
183    }
184}
185
186/// Collect all column names referenced in a filter tree.
187fn collect_filter_columns(node: &FilterNode) -> Vec<String> {
188    match node {
189        FilterNode::Empty => vec![],
190        FilterNode::Condition { column, .. } => vec![column.clone()],
191        FilterNode::And(children) | FilterNode::Or(children) => {
192            children.iter().flat_map(collect_filter_columns).collect()
193        }
194    }
195}
196
197/// Like `compile_filter` but replaces aggregate expression columns with their
198/// SELECT aliases so ClickHouse can resolve them in HAVING scope.
199fn compile_filter_with_aliases(
200    node: &FilterNode,
201    bindings: &mut Vec<SqlValue>,
202    aliases: &HashMap<String, String>,
203) -> String {
204    match node {
205        FilterNode::Empty => String::new(),
206        FilterNode::Condition { column, op, value } => {
207            let effective_col = aliases.get(column)
208                .map(|a| a.as_str())
209                .unwrap_or(column.as_str());
210            compile_condition(effective_col, op, value, bindings)
211        }
212        FilterNode::And(children) => {
213            let parts: Vec<String> = children.iter()
214                .map(|c| compile_filter_with_aliases(c, bindings, aliases))
215                .filter(|s| !s.is_empty())
216                .collect();
217            match parts.len() {
218                0 => String::new(),
219                1 => parts.into_iter().next().unwrap(),
220                _ => format!("({})", parts.join(" AND ")),
221            }
222        }
223        FilterNode::Or(children) => {
224            let parts: Vec<String> = children.iter()
225                .map(|c| compile_filter_with_aliases(c, bindings, aliases))
226                .filter(|s| !s.is_empty())
227                .collect();
228            match parts.len() {
229                0 => String::new(),
230                1 => parts.into_iter().next().unwrap(),
231                _ => format!("({})", parts.join(" OR ")),
232            }
233        }
234    }
235}
236
237fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
238    match node {
239        FilterNode::Empty => String::new(),
240        FilterNode::Condition { column, op, value } => {
241            compile_condition(column, op, value, bindings)
242        }
243        FilterNode::And(children) => {
244            let parts: Vec<String> = children.iter()
245                .map(|c| compile_filter(c, bindings))
246                .filter(|s| !s.is_empty())
247                .collect();
248            match parts.len() {
249                0 => String::new(),
250                1 => parts.into_iter().next().unwrap(),
251                _ => format!("({})", parts.join(" AND ")),
252            }
253        }
254        FilterNode::Or(children) => {
255            let parts: Vec<String> = children.iter()
256                .map(|c| compile_filter(c, bindings))
257                .filter(|s| !s.is_empty())
258                .collect();
259            match parts.len() {
260                0 => String::new(),
261                1 => parts.into_iter().next().unwrap(),
262                _ => format!("({})", parts.join(" OR ")),
263            }
264        }
265    }
266}
267
/// Backtick-quote a plain column name; anything containing a parenthesis is
/// treated as an expression and passed through verbatim.
fn quote_col(column: &str) -> String {
    match column.find('(') {
        Some(_) => column.to_string(),
        None => format!("`{column}`"),
    }
}
275
276fn compile_condition(
277    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
278) -> String {
279    let col = quote_col(column);
280    match op {
281        CompareOp::In | CompareOp::NotIn => {
282            if let SqlValue::String(csv) = value {
283                let items: Vec<&str> = csv.split(',').collect();
284                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
285                for item in &items {
286                    bindings.push(SqlValue::String(item.trim().to_string()));
287                }
288                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
289            } else {
290                bindings.push(value.clone());
291                format!("{col} {} (?)", op.sql_op())
292            }
293        }
294        CompareOp::Includes => {
295            if let SqlValue::String(s) = value {
296                bindings.push(SqlValue::String(format!("%{s}%")));
297            } else {
298                bindings.push(value.clone());
299            }
300            format!("{col} LIKE ?")
301        }
302        CompareOp::IsNull | CompareOp::IsNotNull => {
303            format!("{col} {}", op.sql_op())
304        }
305        _ => {
306            bindings.push(value.clone());
307            format!("{col} {} ?", op.sql_op())
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    //! Snapshot-style tests over the generated ClickHouse SQL text.
    use super::*;

    fn ch() -> ClickHouseDialect { ClickHouseDialect::new() }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            cube: "DEXTrades".into(), schema: "default".into(),
            table: "dwd_dex_trades".into(),
            selects: vec![
                SelectExpr::Column { column: "tx_hash".into(), alias: None },
                SelectExpr::Column { column: "token_a_amount".into(), alias: None },
            ],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert_eq!(r.sql, "SELECT `tx_hash`, `token_a_amount` FROM `default`.`dwd_dex_trades` LIMIT 10");
        assert!(r.bindings.is_empty());
    }

    #[test]
    fn test_final_keyword() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "tokens".into(),
            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: None,
            use_final: true,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("FROM `db`.`tokens` FINAL"), "FINAL should be appended, got: {}", r.sql);
    }

    #[test]
    fn test_uniq_uses_native_function() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Aggregate { function: "UNIQ".into(), column: "wallet".into(), alias: "__uniq".into(), condition: None },
            ],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("uniq(`wallet`) AS `__uniq`"), "ClickHouse should use native uniq(), got: {}", r.sql);
    }

    #[test]
    fn test_count_star() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Aggregate { function: "COUNT".into(), column: "*".into(), alias: "__count".into(), condition: None },
            ],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("count() AS `__count`"), "ClickHouse should use count() not COUNT(*), got: {}", r.sql);
    }

    #[test]
    fn test_aggregate_lowercase() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Aggregate { function: "SUM".into(), column: "amount".into(), alias: "__sum".into(), condition: None },
                SelectExpr::Aggregate { function: "AVG".into(), column: "price".into(), alias: "__avg".into(), condition: None },
            ],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("sum(`amount`) AS `__sum`"), "ClickHouse functions should be lowercase, got: {}", r.sql);
        assert!(r.sql.contains("avg(`price`) AS `__avg`"), "got: {}", r.sql);
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
            filters: FilterNode::And(vec![
                FilterNode::Condition { column: "chain_id".into(), op: CompareOp::Eq, value: SqlValue::Int(1) },
                FilterNode::Condition { column: "amount_usd".into(), op: CompareOp::Gt, value: SqlValue::Float(1000.0) },
            ]),
            having: FilterNode::Empty, group_by: vec![],
            order_by: vec![OrderExpr { column: "block_timestamp".into(), descending: true }],
            limit: 25, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("WHERE (`chain_id` = ? AND `amount_usd` > ?)"));
        assert!(r.sql.contains("ORDER BY `block_timestamp` DESC"));
        assert_eq!(r.bindings.len(), 2);
    }

    #[test]
    fn test_having_with_aggregate_expr() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Column { column: "token_address".into(), alias: None },
                SelectExpr::Aggregate { function: "SUM".into(), column: "amount_usd".into(), alias: "__sum".into(), condition: None },
            ],
            filters: FilterNode::Empty,
            having: FilterNode::Condition {
                column: "sum(`amount_usd`)".into(), op: CompareOp::Gt, value: SqlValue::Float(1000000.0),
            },
            group_by: vec!["token_address".into()], order_by: vec![], limit: 25, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        assert!(r.sql.contains("GROUP BY `token_address`"));
        // HAVING references sum(`amount_usd`) which is not a SelectExpr::Column,
        // so it gets added as a new aliased column for HAVING reference
        assert!(r.sql.contains("HAVING `__f_0` > ?"), "expected alias in HAVING, got: {}", r.sql);
        assert!(r.sql.contains("sum(`amount_usd`) AS `__f_0`"), "expected alias in SELECT, got: {}", r.sql);
        assert_eq!(r.bindings.len(), 1);
    }

    #[test]
    fn test_having_appends_missing_agg_column() {
        // HAVING references an aggregate column NOT in SELECT — it should be
        // auto-appended to SELECT with an alias.
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Column { column: "pool_address".into(), alias: None },
                SelectExpr::Column { column: "argMaxMerge(latest_liquidity_usd_state)".into(), alias: None },
            ],
            filters: FilterNode::Empty,
            having: FilterNode::And(vec![
                FilterNode::Condition {
                    column: "argMaxMerge(latest_liquidity_usd_state)".into(),
                    op: CompareOp::Gt, value: SqlValue::Float(2.0),
                },
                FilterNode::Condition {
                    column: "argMaxMerge(latest_token_a_amount_state)".into(),
                    op: CompareOp::Gt, value: SqlValue::Float(3.0),
                },
            ]),
            group_by: vec!["pool_address".into()], order_by: vec![], limit: 25, offset: 0,
            limit_by: None,
            use_final: false,
        };
        let r = ch().compile(&ir);
        // The existing column referenced in HAVING should be aliased
        assert!(r.sql.contains("argMaxMerge(latest_liquidity_usd_state) AS `__f_0`"),
            "existing HAVING col should be aliased, got: {}", r.sql);
        // The missing column should be appended to SELECT
        assert!(r.sql.contains("argMaxMerge(latest_token_a_amount_state) AS `__f_1`"),
            "missing agg col should be appended, got: {}", r.sql);
        // HAVING should reference aliases
        assert!(r.sql.contains("HAVING (`__f_0` > ? AND `__f_1` > ?)"),
            "HAVING should use aliases, got: {}", r.sql);
        assert_eq!(r.bindings.len(), 2);
        // alias_remap should map the aliased column back to original
        assert_eq!(r.alias_remap.len(), 1);
        assert_eq!(r.alias_remap[0], ("__f_0".to_string(), "argMaxMerge(latest_liquidity_usd_state)".to_string()));
    }

    #[test]
    fn test_limit_by() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![
                SelectExpr::Column { column: "owner".into(), alias: None },
                SelectExpr::Column { column: "amount".into(), alias: None },
            ],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![OrderExpr { column: "amount".into(), descending: true }],
            limit: 100, offset: 0,
            limit_by: Some(LimitByExpr { count: 3, offset: 0, columns: vec!["owner".into()] }),
            use_final: false,
        };
        let r = ch().compile(&ir);
        let sql = &r.sql;
        assert!(sql.contains("LIMIT 3 BY `owner`"), "LIMIT BY should be present, got: {sql}");
        assert!(sql.contains("ORDER BY `amount` DESC"), "ORDER BY should be present, got: {sql}");
        assert!(sql.contains("LIMIT 100"), "outer LIMIT should be present, got: {sql}");
        let order_by_pos = sql.find("ORDER BY").unwrap();
        let limit_by_pos = sql.find("LIMIT 3 BY").unwrap();
        let limit_pos = sql.rfind("LIMIT 100").unwrap();
        assert!(order_by_pos < limit_by_pos, "ORDER BY should come before LIMIT BY in ClickHouse");
        assert!(limit_by_pos < limit_pos, "LIMIT BY should come before outer LIMIT");
    }

    #[test]
    fn test_limit_by_with_offset() {
        let ir = QueryIR {
            cube: "T".into(), schema: "db".into(), table: "t".into(),
            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
            filters: FilterNode::Empty, having: FilterNode::Empty,
            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
            limit_by: Some(LimitByExpr { count: 5, offset: 2, columns: vec!["token".into(), "wallet".into()] }),
            use_final: false,
        };
        let r = ch().compile(&ir);
        // ClickHouse grammar places the per-group offset BEFORE the BY keyword:
        // `LIMIT n OFFSET m BY cols` — `... BY cols OFFSET m` would be a global
        // row offset instead.
        assert!(r.sql.contains("LIMIT 5 OFFSET 2 BY `token`, `wallet`"), "multi-column LIMIT BY with offset, got: {}", r.sql);
    }
}