Skip to main content

activecube_rs/sql/
clickhouse.rs

1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// SQL dialect marker for ClickHouse.
///
/// The type carries no state, so it is free to copy and every instance is
/// interchangeable with every other.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClickHouseDialect;

impl ClickHouseDialect {
    /// Creates the dialect. Equivalent to [`ClickHouseDialect::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
17
18impl SqlDialect for ClickHouseDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => match alias {
26                Some(a) => format!("`{column}` AS `{a}`"),
27                None => format!("`{column}`"),
28            },
29            SelectExpr::Aggregate { function, column, alias } => {
30                let func = function.to_uppercase();
31                if func == "COUNT" && column == "*" {
32                    format!("count() AS `{alias}`")
33                } else if func == "UNIQ" {
34                    format!("uniq(`{column}`) AS `{alias}`")
35                } else {
36                    format!("{func}(`{column}`) AS `{alias}`", func = func.to_lowercase())
37                }
38            }
39        }).collect();
40        sql.push_str(&select_parts.join(", "));
41
42        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
43
44        let where_clause = compile_filter(&ir.filters, &mut bindings);
45        if !where_clause.is_empty() {
46            sql.push_str(" WHERE ");
47            sql.push_str(&where_clause);
48        }
49
50        if !ir.group_by.is_empty() {
51            sql.push_str(" GROUP BY ");
52            let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
53            sql.push_str(&cols.join(", "));
54        }
55
56        let having_clause = compile_filter(&ir.having, &mut bindings);
57        if !having_clause.is_empty() {
58            sql.push_str(" HAVING ");
59            sql.push_str(&having_clause);
60        }
61
62        if !ir.order_by.is_empty() {
63            sql.push_str(" ORDER BY ");
64            let parts: Vec<String> = ir.order_by.iter().map(|o| {
65                let dir = if o.descending { "DESC" } else { "ASC" };
66                format!("`{}` {dir}", o.column)
67            }).collect();
68            sql.push_str(&parts.join(", "));
69        }
70
71        sql.push_str(&format!(" LIMIT {}", ir.limit));
72        if ir.offset > 0 {
73            sql.push_str(&format!(" OFFSET {}", ir.offset));
74        }
75
76        (sql, bindings)
77    }
78
79    fn quote_identifier(&self, name: &str) -> String {
80        format!("`{name}`")
81    }
82
83    fn name(&self) -> &str {
84        "ClickHouse"
85    }
86}
87
88fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
89    match node {
90        FilterNode::Empty => String::new(),
91        FilterNode::Condition { column, op, value } => {
92            compile_condition(column, op, value, bindings)
93        }
94        FilterNode::And(children) => {
95            let parts: Vec<String> = children.iter()
96                .map(|c| compile_filter(c, bindings))
97                .filter(|s| !s.is_empty())
98                .collect();
99            match parts.len() {
100                0 => String::new(),
101                1 => parts.into_iter().next().unwrap(),
102                _ => format!("({})", parts.join(" AND ")),
103            }
104        }
105        FilterNode::Or(children) => {
106            let parts: Vec<String> = children.iter()
107                .map(|c| compile_filter(c, bindings))
108                .filter(|s| !s.is_empty())
109                .collect();
110            match parts.len() {
111                0 => String::new(),
112                1 => parts.into_iter().next().unwrap(),
113                _ => format!("({})", parts.join(" OR ")),
114            }
115        }
116    }
117}
118
/// Quotes `column` for use on the left-hand side of a filter condition.
///
/// A value containing `(` is treated as an already-rendered expression (for
/// example an aggregate call used in `HAVING`) and is returned unchanged.
/// Anything else is wrapped in backticks, with embedded backticks and
/// backslashes backslash-escaped so the name cannot break out of its quoting.
fn quote_col(column: &str) -> String {
    if column.contains('(') {
        // NOTE(review): expressions pass through verbatim and unescaped; this
        // assumes they are produced by trusted code — verify against callers.
        return column.to_string();
    }

    let mut quoted = String::with_capacity(column.len() + 2);
    quoted.push('`');
    for ch in column.chars() {
        if matches!(ch, '`' | '\\') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('`');
    quoted
}
126
127fn compile_condition(
128    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
129) -> String {
130    let col = quote_col(column);
131    match op {
132        CompareOp::In | CompareOp::NotIn => {
133            if let SqlValue::String(csv) = value {
134                let items: Vec<&str> = csv.split(',').collect();
135                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
136                for item in &items {
137                    bindings.push(SqlValue::String(item.trim().to_string()));
138                }
139                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
140            } else {
141                bindings.push(value.clone());
142                format!("{col} {} (?)", op.sql_op())
143            }
144        }
145        CompareOp::Includes => {
146            if let SqlValue::String(s) = value {
147                bindings.push(SqlValue::String(format!("%{s}%")));
148            } else {
149                bindings.push(value.clone());
150            }
151            format!("{col} LIKE ?")
152        }
153        CompareOp::IsNull | CompareOp::IsNotNull => {
154            format!("{col} {}", op.sql_op())
155        }
156        _ => {
157            bindings.push(value.clone());
158            format!("{col} {} ?", op.sql_op())
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `ir` with a fresh ClickHouse dialect.
    fn render(ir: &QueryIR) -> (String, Vec<SqlValue>) {
        ClickHouseDialect::new().compile(ir)
    }

    /// Unfiltered, unordered query over `db`.`t` with `LIMIT 10`; tests override
    /// individual fields via struct-update syntax.
    fn bare_query(selects: Vec<SelectExpr>) -> QueryIR {
        QueryIR {
            cube: "T".into(),
            schema: "db".into(),
            table: "t".into(),
            selects,
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 10,
            offset: 0,
        }
    }

    /// Plain, un-aliased column selection.
    fn col(name: &'static str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Aggregate selection `function(column) AS alias`.
    fn agg(function: &'static str, column: &'static str, alias: &'static str) -> SelectExpr {
        SelectExpr::Aggregate {
            function: function.into(),
            column: column.into(),
            alias: alias.into(),
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            cube: "DEXTrades".into(),
            schema: "default".into(),
            table: "dwd_dex_trades".into(),
            ..bare_query(vec![col("tx_hash"), col("token_a_amount")])
        };
        let (sql, bindings) = render(&ir);
        assert_eq!(sql, "SELECT `tx_hash`, `token_a_amount` FROM `default`.`dwd_dex_trades` LIMIT 10");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_uniq_uses_native_function() {
        let ir = bare_query(vec![agg("UNIQ", "wallet", "__uniq")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("uniq(`wallet`) AS `__uniq`"), "ClickHouse should use native uniq(), got: {sql}");
    }

    #[test]
    fn test_count_star() {
        let ir = bare_query(vec![agg("COUNT", "*", "__count")]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("count() AS `__count`"), "ClickHouse should use count() not COUNT(*), got: {sql}");
    }

    #[test]
    fn test_aggregate_lowercase() {
        let ir = bare_query(vec![
            agg("SUM", "amount", "__sum"),
            agg("AVG", "price", "__avg"),
        ]);
        let (sql, _) = render(&ir);
        assert!(sql.contains("sum(`amount`) AS `__sum`"), "ClickHouse functions should be lowercase, got: {sql}");
        assert!(sql.contains("avg(`price`) AS `__avg`"), "got: {sql}");
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                FilterNode::Condition { column: "chain_id".into(), op: CompareOp::Eq, value: SqlValue::Int(1) },
                FilterNode::Condition { column: "amount_usd".into(), op: CompareOp::Gt, value: SqlValue::Float(1000.0) },
            ]),
            order_by: vec![OrderExpr { column: "block_timestamp".into(), descending: true }],
            limit: 25,
            ..bare_query(vec![col("id")])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("WHERE (`chain_id` = ? AND `amount_usd` > ?)"));
        assert!(sql.contains("ORDER BY `block_timestamp` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_having_with_aggregate_expr() {
        let ir = QueryIR {
            // A column containing `(` is passed through as a raw expression.
            having: FilterNode::Condition {
                column: "sum(`amount_usd`)".into(), op: CompareOp::Gt, value: SqlValue::Float(1000000.0),
            },
            group_by: vec!["token_address".into()],
            limit: 25,
            ..bare_query(vec![col("token_address"), agg("SUM", "amount_usd", "__sum")])
        };
        let (sql, bindings) = render(&ir);
        assert!(sql.contains("GROUP BY `token_address`"));
        assert!(sql.contains("HAVING sum(`amount_usd`) > ?"), "got: {sql}");
        assert_eq!(bindings.len(), 1);
    }
}