Skip to main content

activecube_rs/sql/
clickhouse.rs

1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// SQL dialect that renders query IR as ClickHouse SQL.
///
/// The type is a stateless unit struct, so copying it is free and every
/// instance is interchangeable.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClickHouseDialect;

impl ClickHouseDialect {
    /// Creates a new dialect instance (equivalent to `ClickHouseDialect::default()`).
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
17
18impl SqlDialect for ClickHouseDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => {
26                let col = if column.contains('(') { column.clone() } else { format!("`{column}`") };
27                match alias {
28                    Some(a) => format!("{col} AS `{a}`"),
29                    None => col,
30                }
31            },
32            SelectExpr::Aggregate { function, column, alias, condition } => {
33                let func = function.to_uppercase();
34                match (func.as_str(), column.as_str(), condition) {
35                    ("COUNT", "*", None) => format!("count() AS `{alias}`"),
36                    ("COUNT", "*", Some(cond)) => format!("countIf({cond}) AS `{alias}`"),
37                    ("UNIQ", col, None) => format!("uniq(`{col}`) AS `{alias}`"),
38                    ("UNIQ", col, Some(cond)) => format!("uniqIf(`{col}`, {cond}) AS `{alias}`"),
39                    (_, col, None) => format!("{f}(`{col}`) AS `{alias}`", f = func.to_lowercase()),
40                    (_, col, Some(cond)) => format!("{f}If(`{col}`, {cond}) AS `{alias}`", f = func.to_lowercase()),
41                }
42            }
43        }).collect();
44        sql.push_str(&select_parts.join(", "));
45
46        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
47        if ir.use_final {
48            sql.push_str(" FINAL");
49        }
50
51        let where_clause = compile_filter(&ir.filters, &mut bindings);
52        if !where_clause.is_empty() {
53            sql.push_str(" WHERE ");
54            sql.push_str(&where_clause);
55        }
56
57        // Auto-detect AggregatingMergeTree queries: if any SELECT column contains
58        // a -Merge combinator (e.g. argMaxMerge, sumMerge), auto-add GROUP BY
59        // for all non-aggregate columns.
60        let effective_group_by = if !ir.group_by.is_empty() {
61            ir.group_by.clone()
62        } else {
63            let has_merge_cols = ir.selects.iter().any(|s| match s {
64                SelectExpr::Column { column, .. } => column.contains("Merge("),
65                SelectExpr::Aggregate { .. } => true,
66            });
67            if has_merge_cols {
68                ir.selects.iter().filter_map(|s| match s {
69                    SelectExpr::Column { column, .. } if !column.contains("Merge(") && !column.contains('(') => {
70                        Some(column.clone())
71                    }
72                    _ => None,
73                }).collect()
74            } else {
75                vec![]
76            }
77        };
78
79        if !effective_group_by.is_empty() {
80            sql.push_str(" GROUP BY ");
81            let cols: Vec<String> = effective_group_by.iter().map(|c| format!("`{c}`")).collect();
82            sql.push_str(&cols.join(", "));
83        }
84
85        let having_clause = compile_filter(&ir.having, &mut bindings);
86        if !having_clause.is_empty() {
87            sql.push_str(" HAVING ");
88            sql.push_str(&having_clause);
89        }
90
91        if !ir.order_by.is_empty() {
92            sql.push_str(" ORDER BY ");
93            let parts: Vec<String> = ir.order_by.iter().map(|o| {
94                let dir = if o.descending { "DESC" } else { "ASC" };
95                format!("`{}` {dir}", o.column)
96            }).collect();
97            sql.push_str(&parts.join(", "));
98        }
99
100        if let Some(ref lb) = ir.limit_by {
101            let by_cols: Vec<String> = lb.columns.iter().map(|c| format!("`{c}`")).collect();
102            sql.push_str(&format!(" LIMIT {} BY {}", lb.count, by_cols.join(", ")));
103            if lb.offset > 0 {
104                sql.push_str(&format!(" OFFSET {}", lb.offset));
105            }
106        }
107
108        sql.push_str(&format!(" LIMIT {}", ir.limit));
109        if ir.offset > 0 {
110            sql.push_str(&format!(" OFFSET {}", ir.offset));
111        }
112
113        (sql, bindings)
114    }
115
116    fn quote_identifier(&self, name: &str) -> String {
117        format!("`{name}`")
118    }
119
120    fn name(&self) -> &str {
121        "ClickHouse"
122    }
123}
124
125fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
126    match node {
127        FilterNode::Empty => String::new(),
128        FilterNode::Condition { column, op, value } => {
129            compile_condition(column, op, value, bindings)
130        }
131        FilterNode::And(children) => {
132            let parts: Vec<String> = children.iter()
133                .map(|c| compile_filter(c, bindings))
134                .filter(|s| !s.is_empty())
135                .collect();
136            match parts.len() {
137                0 => String::new(),
138                1 => parts.into_iter().next().unwrap(),
139                _ => format!("({})", parts.join(" AND ")),
140            }
141        }
142        FilterNode::Or(children) => {
143            let parts: Vec<String> = children.iter()
144                .map(|c| compile_filter(c, bindings))
145                .filter(|s| !s.is_empty())
146                .collect();
147            match parts.len() {
148                0 => String::new(),
149                1 => parts.into_iter().next().unwrap(),
150                _ => format!("({})", parts.join(" OR ")),
151            }
152        }
153    }
154}
155
156fn quote_col(column: &str) -> String {
157    if column.contains('(') {
158        column.to_string()
159    } else {
160        format!("`{column}`")
161    }
162}
163
164fn compile_condition(
165    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
166) -> String {
167    let col = quote_col(column);
168    match op {
169        CompareOp::In | CompareOp::NotIn => {
170            if let SqlValue::String(csv) = value {
171                let items: Vec<&str> = csv.split(',').collect();
172                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
173                for item in &items {
174                    bindings.push(SqlValue::String(item.trim().to_string()));
175                }
176                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
177            } else {
178                bindings.push(value.clone());
179                format!("{col} {} (?)", op.sql_op())
180            }
181        }
182        CompareOp::Includes => {
183            if let SqlValue::String(s) = value {
184                bindings.push(SqlValue::String(format!("%{s}%")));
185            } else {
186                bindings.push(value.clone());
187            }
188            format!("{col} LIKE ?")
189        }
190        CompareOp::IsNull | CompareOp::IsNotNull => {
191            format!("{col} {}", op.sql_op())
192        }
193        _ => {
194            bindings.push(value.clone());
195            format!("{col} {} ?", op.sql_op())
196        }
197    }
198}
199
200#[cfg(test)]
201mod tests {
202    use super::*;
203
204    fn ch() -> ClickHouseDialect { ClickHouseDialect::new() }
205
206    #[test]
207    fn test_simple_select() {
208        let ir = QueryIR {
209            cube: "DEXTrades".into(), schema: "default".into(),
210            table: "dwd_dex_trades".into(),
211            selects: vec![
212                SelectExpr::Column { column: "tx_hash".into(), alias: None },
213                SelectExpr::Column { column: "token_a_amount".into(), alias: None },
214            ],
215            filters: FilterNode::Empty, having: FilterNode::Empty,
216            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
217            limit_by: None,
218            use_final: false,
219        };
220        let (sql, bindings) = ch().compile(&ir);
221        assert_eq!(sql, "SELECT `tx_hash`, `token_a_amount` FROM `default`.`dwd_dex_trades` LIMIT 10");
222        assert!(bindings.is_empty());
223    }
224
225    #[test]
226    fn test_final_keyword() {
227        let ir = QueryIR {
228            cube: "T".into(), schema: "db".into(), table: "tokens".into(),
229            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
230            filters: FilterNode::Empty, having: FilterNode::Empty,
231            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
232            limit_by: None,
233            use_final: true,
234        };
235        let (sql, _) = ch().compile(&ir);
236        assert!(sql.contains("FROM `db`.`tokens` FINAL"), "FINAL should be appended, got: {sql}");
237    }
238
239    #[test]
240    fn test_uniq_uses_native_function() {
241        let ir = QueryIR {
242            cube: "T".into(), schema: "db".into(), table: "t".into(),
243            selects: vec![
244                SelectExpr::Aggregate { function: "UNIQ".into(), column: "wallet".into(), alias: "__uniq".into(), condition: None },
245            ],
246            filters: FilterNode::Empty, having: FilterNode::Empty,
247            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
248            limit_by: None,
249            use_final: false,
250        };
251        let (sql, _) = ch().compile(&ir);
252        assert!(sql.contains("uniq(`wallet`) AS `__uniq`"), "ClickHouse should use native uniq(), got: {sql}");
253    }
254
255    #[test]
256    fn test_count_star() {
257        let ir = QueryIR {
258            cube: "T".into(), schema: "db".into(), table: "t".into(),
259            selects: vec![
260                SelectExpr::Aggregate { function: "COUNT".into(), column: "*".into(), alias: "__count".into(), condition: None },
261            ],
262            filters: FilterNode::Empty, having: FilterNode::Empty,
263            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
264            limit_by: None,
265            use_final: false,
266        };
267        let (sql, _) = ch().compile(&ir);
268        assert!(sql.contains("count() AS `__count`"), "ClickHouse should use count() not COUNT(*), got: {sql}");
269    }
270
271    #[test]
272    fn test_aggregate_lowercase() {
273        let ir = QueryIR {
274            cube: "T".into(), schema: "db".into(), table: "t".into(),
275            selects: vec![
276                SelectExpr::Aggregate { function: "SUM".into(), column: "amount".into(), alias: "__sum".into(), condition: None },
277                SelectExpr::Aggregate { function: "AVG".into(), column: "price".into(), alias: "__avg".into(), condition: None },
278            ],
279            filters: FilterNode::Empty, having: FilterNode::Empty,
280            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
281            limit_by: None,
282            use_final: false,
283        };
284        let (sql, _) = ch().compile(&ir);
285        assert!(sql.contains("sum(`amount`) AS `__sum`"), "ClickHouse functions should be lowercase, got: {sql}");
286        assert!(sql.contains("avg(`price`) AS `__avg`"), "got: {sql}");
287    }
288
289    #[test]
290    fn test_where_and_order() {
291        let ir = QueryIR {
292            cube: "T".into(), schema: "db".into(), table: "t".into(),
293            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
294            filters: FilterNode::And(vec![
295                FilterNode::Condition { column: "chain_id".into(), op: CompareOp::Eq, value: SqlValue::Int(1) },
296                FilterNode::Condition { column: "amount_usd".into(), op: CompareOp::Gt, value: SqlValue::Float(1000.0) },
297            ]),
298            having: FilterNode::Empty, group_by: vec![],
299            order_by: vec![OrderExpr { column: "block_timestamp".into(), descending: true }],
300            limit: 25, offset: 0,
301            limit_by: None,
302            use_final: false,
303        };
304        let (sql, bindings) = ch().compile(&ir);
305        assert!(sql.contains("WHERE (`chain_id` = ? AND `amount_usd` > ?)"));
306        assert!(sql.contains("ORDER BY `block_timestamp` DESC"));
307        assert_eq!(bindings.len(), 2);
308    }
309
310    #[test]
311    fn test_having_with_aggregate_expr() {
312        let ir = QueryIR {
313            cube: "T".into(), schema: "db".into(), table: "t".into(),
314            selects: vec![
315                SelectExpr::Column { column: "token_address".into(), alias: None },
316                SelectExpr::Aggregate { function: "SUM".into(), column: "amount_usd".into(), alias: "__sum".into(), condition: None },
317            ],
318            filters: FilterNode::Empty,
319            having: FilterNode::Condition {
320                column: "sum(`amount_usd`)".into(), op: CompareOp::Gt, value: SqlValue::Float(1000000.0),
321            },
322            group_by: vec!["token_address".into()], order_by: vec![], limit: 25, offset: 0,
323            limit_by: None,
324            use_final: false,
325        };
326        let (sql, bindings) = ch().compile(&ir);
327        assert!(sql.contains("GROUP BY `token_address`"));
328        assert!(sql.contains("HAVING sum(`amount_usd`) > ?"), "got: {sql}");
329        assert_eq!(bindings.len(), 1);
330    }
331
332    #[test]
333    fn test_limit_by() {
334        let ir = QueryIR {
335            cube: "T".into(), schema: "db".into(), table: "t".into(),
336            selects: vec![
337                SelectExpr::Column { column: "owner".into(), alias: None },
338                SelectExpr::Column { column: "amount".into(), alias: None },
339            ],
340            filters: FilterNode::Empty, having: FilterNode::Empty,
341            group_by: vec![], 
342            order_by: vec![OrderExpr { column: "amount".into(), descending: true }],
343            limit: 100, offset: 0,
344            limit_by: Some(LimitByExpr { count: 3, offset: 0, columns: vec!["owner".into()] }),
345            use_final: false,
346        };
347        let (sql, _) = ch().compile(&ir);
348        assert!(sql.contains("LIMIT 3 BY `owner`"), "LIMIT BY should be present, got: {sql}");
349        assert!(sql.contains("ORDER BY `amount` DESC"), "ORDER BY should be present, got: {sql}");
350        assert!(sql.contains("LIMIT 100"), "outer LIMIT should be present, got: {sql}");
351        let order_by_pos = sql.find("ORDER BY").unwrap();
352        let limit_by_pos = sql.find("LIMIT 3 BY").unwrap();
353        let limit_pos = sql.rfind("LIMIT 100").unwrap();
354        assert!(order_by_pos < limit_by_pos, "ORDER BY should come before LIMIT BY in ClickHouse");
355        assert!(limit_by_pos < limit_pos, "LIMIT BY should come before outer LIMIT");
356    }
357
358    #[test]
359    fn test_limit_by_with_offset() {
360        let ir = QueryIR {
361            cube: "T".into(), schema: "db".into(), table: "t".into(),
362            selects: vec![SelectExpr::Column { column: "id".into(), alias: None }],
363            filters: FilterNode::Empty, having: FilterNode::Empty,
364            group_by: vec![], order_by: vec![], limit: 10, offset: 0,
365            limit_by: Some(LimitByExpr { count: 5, offset: 2, columns: vec!["token".into(), "wallet".into()] }),
366            use_final: false,
367        };
368        let (sql, _) = ch().compile(&ir);
369        assert!(sql.contains("LIMIT 5 BY `token`, `wallet` OFFSET 2"), "multi-column LIMIT BY with offset, got: {sql}");
370    }
371}