Skip to main content

activecube_rs/sql/
clickhouse.rs

1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// SQL dialect marker for ClickHouse.
///
/// The type carries no state, so it is a zero-sized unit struct that can be
/// freely copied, compared and created in `const` contexts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ClickHouseDialect;

impl ClickHouseDialect {
    /// Creates a new dialect value.
    ///
    /// Equivalent to [`ClickHouseDialect::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
17
18impl SqlDialect for ClickHouseDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => {
26                let col = if column.contains('(') { column.clone() } else { format!("`{column}`") };
27                match alias {
28                    Some(a) => format!("{col} AS `{a}`"),
29                    None => col,
30                }
31            },
32            SelectExpr::Aggregate { function, column, alias } => {
33                let func = function.to_uppercase();
34                if func == "COUNT" && column == "*" {
35                    format!("count() AS `{alias}`")
36                } else if func == "UNIQ" {
37                    format!("uniq(`{column}`) AS `{alias}`")
38                } else {
39                    format!("{func}(`{column}`) AS `{alias}`", func = func.to_lowercase())
40                }
41            }
42        }).collect();
43        sql.push_str(&select_parts.join(", "));
44
45        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
46
47        let where_clause = compile_filter(&ir.filters, &mut bindings);
48        if !where_clause.is_empty() {
49            sql.push_str(" WHERE ");
50            sql.push_str(&where_clause);
51        }
52
53        // Auto-detect AggregatingMergeTree queries: if any SELECT column contains
54        // a -Merge combinator (e.g. argMaxMerge, sumMerge), auto-add GROUP BY
55        // for all non-aggregate columns.
56        let effective_group_by = if !ir.group_by.is_empty() {
57            ir.group_by.clone()
58        } else {
59            let has_merge_cols = ir.selects.iter().any(|s| match s {
60                SelectExpr::Column { column, .. } => column.contains("Merge("),
61                SelectExpr::Aggregate { .. } => true,
62            });
63            if has_merge_cols {
64                ir.selects.iter().filter_map(|s| match s {
65                    SelectExpr::Column { column, .. } if !column.contains("Merge(") && !column.contains('(') => {
66                        Some(column.clone())
67                    }
68                    _ => None,
69                }).collect()
70            } else {
71                vec![]
72            }
73        };
74
75        if !effective_group_by.is_empty() {
76            sql.push_str(" GROUP BY ");
77            let cols: Vec<String> = effective_group_by.iter().map(|c| format!("`{c}`")).collect();
78            sql.push_str(&cols.join(", "));
79        }
80
81        let having_clause = compile_filter(&ir.having, &mut bindings);
82        if !having_clause.is_empty() {
83            sql.push_str(" HAVING ");
84            sql.push_str(&having_clause);
85        }
86
87        if !ir.order_by.is_empty() {
88            sql.push_str(" ORDER BY ");
89            let parts: Vec<String> = ir.order_by.iter().map(|o| {
90                let dir = if o.descending { "DESC" } else { "ASC" };
91                format!("`{}` {dir}", o.column)
92            }).collect();
93            sql.push_str(&parts.join(", "));
94        }
95
96        sql.push_str(&format!(" LIMIT {}", ir.limit));
97        if ir.offset > 0 {
98            sql.push_str(&format!(" OFFSET {}", ir.offset));
99        }
100
101        (sql, bindings)
102    }
103
104    fn quote_identifier(&self, name: &str) -> String {
105        format!("`{name}`")
106    }
107
108    fn name(&self) -> &str {
109        "ClickHouse"
110    }
111}
112
113fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
114    match node {
115        FilterNode::Empty => String::new(),
116        FilterNode::Condition { column, op, value } => {
117            compile_condition(column, op, value, bindings)
118        }
119        FilterNode::And(children) => {
120            let parts: Vec<String> = children.iter()
121                .map(|c| compile_filter(c, bindings))
122                .filter(|s| !s.is_empty())
123                .collect();
124            match parts.len() {
125                0 => String::new(),
126                1 => parts.into_iter().next().unwrap(),
127                _ => format!("({})", parts.join(" AND ")),
128            }
129        }
130        FilterNode::Or(children) => {
131            let parts: Vec<String> = children.iter()
132                .map(|c| compile_filter(c, bindings))
133                .filter(|s| !s.is_empty())
134                .collect();
135            match parts.len() {
136                0 => String::new(),
137                1 => parts.into_iter().next().unwrap(),
138                _ => format!("({})", parts.join(" OR ")),
139            }
140        }
141    }
142}
143
/// Quotes a column reference for use in SQL.
///
/// A value containing `(` is treated as an expression and returned verbatim;
/// anything else is wrapped in backticks.
fn quote_col(column: &str) -> String {
    match column.find('(') {
        Some(_) => column.to_owned(),
        None => ["`", column, "`"].concat(),
    }
}
151
152fn compile_condition(
153    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
154) -> String {
155    let col = quote_col(column);
156    match op {
157        CompareOp::In | CompareOp::NotIn => {
158            if let SqlValue::String(csv) = value {
159                let items: Vec<&str> = csv.split(',').collect();
160                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
161                for item in &items {
162                    bindings.push(SqlValue::String(item.trim().to_string()));
163                }
164                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
165            } else {
166                bindings.push(value.clone());
167                format!("{col} {} (?)", op.sql_op())
168            }
169        }
170        CompareOp::Includes => {
171            if let SqlValue::String(s) = value {
172                bindings.push(SqlValue::String(format!("%{s}%")));
173            } else {
174                bindings.push(value.clone());
175            }
176            format!("{col} LIKE ?")
177        }
178        CompareOp::IsNull | CompareOp::IsNotNull => {
179            format!("{col} {}", op.sql_op())
180        }
181        _ => {
182            bindings.push(value.clone());
183            format!("{col} {} ?", op.sql_op())
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `ir` with a fresh ClickHouse dialect.
    fn render(ir: &QueryIR) -> (String, Vec<SqlValue>) {
        ClickHouseDialect::new().compile(ir)
    }

    /// Un-aliased plain column select.
    fn column(name: &str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Aggregate select with the given function, column and alias.
    fn aggregate(function: &str, column: &str, alias: &str) -> SelectExpr {
        SelectExpr::Aggregate {
            function: function.into(),
            column: column.into(),
            alias: alias.into(),
        }
    }

    /// Baseline query on `db`.`t`: no filters, grouping or ordering, limit 10.
    /// Individual tests override fields with struct-update syntax.
    fn ir_with(selects: Vec<SelectExpr>) -> QueryIR {
        QueryIR {
            cube: "T".into(),
            schema: "db".into(),
            table: "t".into(),
            selects,
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 10,
            offset: 0,
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            cube: "DEXTrades".into(),
            schema: "default".into(),
            table: "dwd_dex_trades".into(),
            ..ir_with(vec![column("tx_hash"), column("token_a_amount")])
        };
        let (sql, bindings) = render(&ir);
        assert_eq!(sql, "SELECT `tx_hash`, `token_a_amount` FROM `default`.`dwd_dex_trades` LIMIT 10");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_uniq_uses_native_function() {
        let ir = ir_with(vec![aggregate("UNIQ", "wallet", "__uniq")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("uniq(`wallet`) AS `__uniq`"), "ClickHouse should use native uniq(), got: {sql}");
    }

    #[test]
    fn test_count_star() {
        let ir = ir_with(vec![aggregate("COUNT", "*", "__count")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("count() AS `__count`"), "ClickHouse should use count() not COUNT(*), got: {sql}");
    }

    #[test]
    fn test_aggregate_lowercase() {
        let ir = ir_with(vec![
            aggregate("SUM", "amount", "__sum"),
            aggregate("AVG", "price", "__avg"),
        ]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("sum(`amount`) AS `__sum`"), "ClickHouse functions should be lowercase, got: {sql}");
        assert!(sql.contains("avg(`price`) AS `__avg`"), "got: {sql}");
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                FilterNode::Condition {
                    column: "chain_id".into(),
                    op: CompareOp::Eq,
                    value: SqlValue::Int(1),
                },
                FilterNode::Condition {
                    column: "amount_usd".into(),
                    op: CompareOp::Gt,
                    value: SqlValue::Float(1000.0),
                },
            ]),
            order_by: vec![OrderExpr { column: "block_timestamp".into(), descending: true }],
            limit: 25,
            ..ir_with(vec![column("id")])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("WHERE (`chain_id` = ? AND `amount_usd` > ?)"));
        assert!(sql.contains("ORDER BY `block_timestamp` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_having_with_aggregate_expr() {
        let ir = QueryIR {
            having: FilterNode::Condition {
                column: "sum(`amount_usd`)".into(),
                op: CompareOp::Gt,
                value: SqlValue::Float(1000000.0),
            },
            group_by: vec!["token_address".into()],
            limit: 25,
            ..ir_with(vec![
                column("token_address"),
                aggregate("SUM", "amount_usd", "__sum"),
            ])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("GROUP BY `token_address`"));
        assert!(sql.contains("HAVING sum(`amount_usd`) > ?"), "got: {sql}");
        assert_eq!(bindings.len(), 1);
    }
}