Skip to main content

activecube_rs/sql/
clickhouse.rs

1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// Marker type selecting the ClickHouse flavour of SQL generation.
///
/// The type carries no state, so every value is interchangeable; it is
/// `Copy` and zero-sized. `Default` is derived rather than hand-written, and
/// `Debug` is available so the dialect can appear in diagnostics.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ClickHouseDialect;

impl ClickHouseDialect {
    /// Creates the dialect. Equivalent to [`ClickHouseDialect::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
17
18impl SqlDialect for ClickHouseDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => {
26                let col = if column.contains('(') { column.clone() } else { format!("`{column}`") };
27                match alias {
28                    Some(a) => format!("{col} AS `{a}`"),
29                    None => col,
30                }
31            },
32            SelectExpr::Aggregate { function, column, alias, condition } => {
33                let func = function.to_uppercase();
34                match (func.as_str(), column.as_str(), condition) {
35                    ("COUNT", "*", None) => format!("count() AS `{alias}`"),
36                    ("COUNT", "*", Some(cond)) => format!("countIf({cond}) AS `{alias}`"),
37                    ("UNIQ", col, None) => format!("uniq(`{col}`) AS `{alias}`"),
38                    ("UNIQ", col, Some(cond)) => format!("uniqIf(`{col}`, {cond}) AS `{alias}`"),
39                    (_, col, None) => format!("{f}(`{col}`) AS `{alias}`", f = func.to_lowercase()),
40                    (_, col, Some(cond)) => format!("{f}If(`{col}`, {cond}) AS `{alias}`", f = func.to_lowercase()),
41                }
42            }
43        }).collect();
44        sql.push_str(&select_parts.join(", "));
45
46        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
47        if ir.use_final {
48            sql.push_str(" FINAL");
49        }
50
51        let where_clause = compile_filter(&ir.filters, &mut bindings);
52        if !where_clause.is_empty() {
53            sql.push_str(" WHERE ");
54            sql.push_str(&where_clause);
55        }
56
57        // Auto-detect AggregatingMergeTree queries: if any SELECT column contains
58        // a -Merge combinator (e.g. argMaxMerge, sumMerge), auto-add GROUP BY
59        // for all non-aggregate columns.
60        let effective_group_by = if !ir.group_by.is_empty() {
61            ir.group_by.clone()
62        } else {
63            let has_merge_cols = ir.selects.iter().any(|s| match s {
64                SelectExpr::Column { column, .. } => column.contains("Merge("),
65                SelectExpr::Aggregate { .. } => true,
66            });
67            if has_merge_cols {
68                ir.selects.iter().filter_map(|s| match s {
69                    SelectExpr::Column { column, .. } if !column.contains("Merge(") && !column.contains('(') => {
70                        Some(column.clone())
71                    }
72                    _ => None,
73                }).collect()
74            } else {
75                vec![]
76            }
77        };
78
79        if !effective_group_by.is_empty() {
80            sql.push_str(" GROUP BY ");
81            let cols: Vec<String> = effective_group_by.iter().map(|c| format!("`{c}`")).collect();
82            sql.push_str(&cols.join(", "));
83        }
84
85        let having_clause = compile_filter(&ir.having, &mut bindings);
86        if !having_clause.is_empty() {
87            sql.push_str(" HAVING ");
88            sql.push_str(&having_clause);
89        }
90
91        if !ir.order_by.is_empty() {
92            sql.push_str(" ORDER BY ");
93            let parts: Vec<String> = ir.order_by.iter().map(|o| {
94                let dir = if o.descending { "DESC" } else { "ASC" };
95                format!("`{}` {dir}", o.column)
96            }).collect();
97            sql.push_str(&parts.join(", "));
98        }
99
100        sql.push_str(&format!(" LIMIT {}", ir.limit));
101        if ir.offset > 0 {
102            sql.push_str(&format!(" OFFSET {}", ir.offset));
103        }
104
105        (sql, bindings)
106    }
107
108    fn quote_identifier(&self, name: &str) -> String {
109        format!("`{name}`")
110    }
111
112    fn name(&self) -> &str {
113        "ClickHouse"
114    }
115}
116
117fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
118    match node {
119        FilterNode::Empty => String::new(),
120        FilterNode::Condition { column, op, value } => {
121            compile_condition(column, op, value, bindings)
122        }
123        FilterNode::And(children) => {
124            let parts: Vec<String> = children.iter()
125                .map(|c| compile_filter(c, bindings))
126                .filter(|s| !s.is_empty())
127                .collect();
128            match parts.len() {
129                0 => String::new(),
130                1 => parts.into_iter().next().unwrap(),
131                _ => format!("({})", parts.join(" AND ")),
132            }
133        }
134        FilterNode::Or(children) => {
135            let parts: Vec<String> = children.iter()
136                .map(|c| compile_filter(c, bindings))
137                .filter(|s| !s.is_empty())
138                .collect();
139            match parts.len() {
140                0 => String::new(),
141                1 => parts.into_iter().next().unwrap(),
142                _ => format!("({})", parts.join(" OR ")),
143            }
144        }
145    }
146}
147
/// Quotes a column reference for use in SQL.
///
/// Anything containing `(` is taken to be an expression and returned
/// unchanged; any other input is wrapped in backticks.
fn quote_col(column: &str) -> String {
    match column.find('(') {
        Some(_) => column.to_owned(),
        None => format!("`{column}`"),
    }
}
155
156fn compile_condition(
157    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
158) -> String {
159    let col = quote_col(column);
160    match op {
161        CompareOp::In | CompareOp::NotIn => {
162            if let SqlValue::String(csv) = value {
163                let items: Vec<&str> = csv.split(',').collect();
164                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
165                for item in &items {
166                    bindings.push(SqlValue::String(item.trim().to_string()));
167                }
168                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
169            } else {
170                bindings.push(value.clone());
171                format!("{col} {} (?)", op.sql_op())
172            }
173        }
174        CompareOp::Includes => {
175            if let SqlValue::String(s) = value {
176                bindings.push(SqlValue::String(format!("%{s}%")));
177            } else {
178                bindings.push(value.clone());
179            }
180            format!("{col} LIKE ?")
181        }
182        CompareOp::IsNull | CompareOp::IsNotNull => {
183            format!("{col} {}", op.sql_op())
184        }
185        _ => {
186            bindings.push(value.clone());
187            format!("{col} {} ?", op.sql_op())
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `ir` with a fresh ClickHouse dialect.
    fn render(ir: &QueryIR) -> (String, Vec<SqlValue>) {
        ClickHouseDialect::new().compile(ir)
    }

    /// Baseline query on `db`.`t`: no filters, grouping or ordering,
    /// `LIMIT 10`, no `FINAL`. Tests override only the fields they exercise.
    fn base_query(selects: Vec<SelectExpr>) -> QueryIR {
        QueryIR {
            cube: "T".into(),
            schema: "db".into(),
            table: "t".into(),
            selects,
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 10,
            offset: 0,
            use_final: false,
        }
    }

    /// Plain, un-aliased column selection.
    fn column(name: &str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Unconditional aggregate selection.
    fn aggregate(function: &str, column: &str, alias: &str) -> SelectExpr {
        SelectExpr::Aggregate {
            function: function.into(),
            column: column.into(),
            alias: alias.into(),
            condition: None,
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            cube: "DEXTrades".into(),
            schema: "default".into(),
            table: "dwd_dex_trades".into(),
            ..base_query(vec![column("tx_hash"), column("token_a_amount")])
        };
        let (sql, bindings) = render(&ir);
        assert_eq!(sql, "SELECT `tx_hash`, `token_a_amount` FROM `default`.`dwd_dex_trades` LIMIT 10");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_final_keyword() {
        let ir = QueryIR {
            table: "tokens".into(),
            use_final: true,
            ..base_query(vec![column("id")])
        };
        let (sql, _) = render(&ir);
        assert!(sql.contains("FROM `db`.`tokens` FINAL"), "FINAL should be appended, got: {sql}");
    }

    #[test]
    fn test_uniq_uses_native_function() {
        let ir = base_query(vec![aggregate("UNIQ", "wallet", "__uniq")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("uniq(`wallet`) AS `__uniq`"), "ClickHouse should use native uniq(), got: {sql}");
    }

    #[test]
    fn test_count_star() {
        let ir = base_query(vec![aggregate("COUNT", "*", "__count")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("count() AS `__count`"), "ClickHouse should use count() not COUNT(*), got: {sql}");
    }

    #[test]
    fn test_aggregate_lowercase() {
        let ir = base_query(vec![
            aggregate("SUM", "amount", "__sum"),
            aggregate("AVG", "price", "__avg"),
        ]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("sum(`amount`) AS `__sum`"), "ClickHouse functions should be lowercase, got: {sql}");
        assert!(sql.contains("avg(`price`) AS `__avg`"), "got: {sql}");
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                FilterNode::Condition {
                    column: "chain_id".into(),
                    op: CompareOp::Eq,
                    value: SqlValue::Int(1),
                },
                FilterNode::Condition {
                    column: "amount_usd".into(),
                    op: CompareOp::Gt,
                    value: SqlValue::Float(1000.0),
                },
            ]),
            order_by: vec![OrderExpr { column: "block_timestamp".into(), descending: true }],
            limit: 25,
            ..base_query(vec![column("id")])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("WHERE (`chain_id` = ? AND `amount_usd` > ?)"));
        assert!(sql.contains("ORDER BY `block_timestamp` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_having_with_aggregate_expr() {
        let ir = QueryIR {
            having: FilterNode::Condition {
                column: "sum(`amount_usd`)".into(),
                op: CompareOp::Gt,
                value: SqlValue::Float(1000000.0),
            },
            group_by: vec!["token_address".into()],
            limit: 25,
            ..base_query(vec![
                column("token_address"),
                aggregate("SUM", "amount_usd", "__sum"),
            ])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("GROUP BY `token_address`"));
        assert!(sql.contains("HAVING sum(`amount_usd`) > ?"), "got: {sql}");
        assert_eq!(bindings.len(), 1);
    }
}