Skip to main content

activecube_rs/sql/
starrocks.rs

1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// SQL dialect that compiles a [`QueryIR`] into StarRocks SQL with `?` placeholders.
///
/// The type carries no state; all behaviour lives in its [`SqlDialect`] implementation,
/// so it is cheap to copy and to print for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct StarRocksDialect;
5
6impl StarRocksDialect {
7    pub fn new() -> Self {
8        Self
9    }
10}
11
12impl Default for StarRocksDialect {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl SqlDialect for StarRocksDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => match alias {
26                Some(a) => format!("`{column}` AS `{a}`"),
27                None => format!("`{column}`"),
28            },
29            SelectExpr::Aggregate { function, column, alias } => {
30                let func = function.to_uppercase();
31                if func == "COUNT" && column == "*" {
32                    format!("COUNT(*) AS `{alias}`")
33                } else if func == "UNIQ" {
34                    format!("COUNT(DISTINCT `{column}`) AS `{alias}`")
35                } else {
36                    format!("{func}(`{column}`) AS `{alias}`")
37                }
38            }
39        }).collect();
40        sql.push_str(&select_parts.join(", "));
41
42        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
43
44        let where_clause = compile_filter(&ir.filters, &mut bindings);
45        if !where_clause.is_empty() {
46            sql.push_str(" WHERE ");
47            sql.push_str(&where_clause);
48        }
49
50        if !ir.group_by.is_empty() {
51            sql.push_str(" GROUP BY ");
52            let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
53            sql.push_str(&cols.join(", "));
54        }
55
56        let having_clause = compile_filter(&ir.having, &mut bindings);
57        if !having_clause.is_empty() {
58            sql.push_str(" HAVING ");
59            sql.push_str(&having_clause);
60        }
61
62        if !ir.order_by.is_empty() {
63            sql.push_str(" ORDER BY ");
64            let parts: Vec<String> = ir.order_by.iter().map(|o| {
65                let dir = if o.descending { "DESC" } else { "ASC" };
66                format!("`{}` {dir}", o.column)
67            }).collect();
68            sql.push_str(&parts.join(", "));
69        }
70
71        sql.push_str(&format!(" LIMIT {}", ir.limit));
72        if ir.offset > 0 {
73            sql.push_str(&format!(" OFFSET {}", ir.offset));
74        }
75
76        (sql, bindings)
77    }
78
79    fn quote_identifier(&self, name: &str) -> String {
80        format!("`{name}`")
81    }
82
83    fn name(&self) -> &str {
84        "StarRocks"
85    }
86}
87
88fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
89    match node {
90        FilterNode::Empty => String::new(),
91        FilterNode::Condition { column, op, value } => {
92            compile_condition(column, op, value, bindings)
93        }
94        FilterNode::And(children) => {
95            let parts: Vec<String> = children.iter()
96                .map(|c| compile_filter(c, bindings))
97                .filter(|s| !s.is_empty())
98                .collect();
99            match parts.len() {
100                0 => String::new(),
101                1 => parts.into_iter().next().unwrap(),
102                _ => format!("({})", parts.join(" AND ")),
103            }
104        }
105        FilterNode::Or(children) => {
106            let parts: Vec<String> = children.iter()
107                .map(|c| compile_filter(c, bindings))
108                .filter(|s| !s.is_empty())
109                .collect();
110            match parts.len() {
111                0 => String::new(),
112                1 => parts.into_iter().next().unwrap(),
113                _ => format!("({})", parts.join(" OR ")),
114            }
115        }
116    }
117}
118
/// Backtick-quote a plain column name, but pass expressions through untouched.
///
/// Any input containing `(` is treated as an already-rendered expression (for example an
/// aggregate like `SUM(`col`)`). Quoting such an expression would produce the invalid
/// identifier `` `SUM(`col`)` ``, so it is returned unchanged.
fn quote_col(column: &str) -> String {
    let looks_like_expression = column.contains('(');
    if !looks_like_expression {
        return format!("`{column}`");
    }
    column.to_owned()
}
128
129fn compile_condition(
130    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
131) -> String {
132    let col = quote_col(column);
133    match op {
134        CompareOp::In | CompareOp::NotIn => {
135            if let SqlValue::String(csv) = value {
136                let items: Vec<&str> = csv.split(',').collect();
137                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
138                for item in &items {
139                    bindings.push(SqlValue::String(item.trim().to_string()));
140                }
141                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
142            } else {
143                bindings.push(value.clone());
144                format!("{col} {} (?)", op.sql_op())
145            }
146        }
147        CompareOp::Includes => {
148            if let SqlValue::String(s) = value {
149                bindings.push(SqlValue::String(format!("%{s}%")));
150            } else {
151                bindings.push(value.clone());
152            }
153            format!("{col} LIKE ?")
154        }
155        CompareOp::IsNull | CompareOp::IsNotNull => {
156            format!("{col} {}", op.sql_op())
157        }
158        _ => {
159            bindings.push(value.clone());
160            format!("{col} {} ?", op.sql_op())
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Query over `dexes_dwd`.`sol_activities` with no filters, grouping or ordering,
    /// `LIMIT 25` and no offset. Tests override individual fields via `..base_query(..)`.
    fn base_query(selects: Vec<SelectExpr>) -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects,
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
        }
    }

    /// Plain, un-aliased column selection.
    fn select_col(name: &str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Leaf filter condition.
    fn condition(column: &str, op: CompareOp, value: SqlValue) -> FilterNode {
        FilterNode::Condition { column: column.into(), op, value }
    }

    /// Compile `ir` with a fresh StarRocks dialect.
    fn compile_ir(ir: &QueryIR) -> (String, Vec<SqlValue>) {
        StarRocksDialect::new().compile(ir)
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            limit: 10,
            ..base_query(vec![select_col("tx_hash"), select_col("buy_amount")])
        };
        let (sql, bindings) = compile_ir(&ir);
        assert_eq!(sql, "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                condition("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                condition("success", CompareOp::Eq, SqlValue::Bool(true)),
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..base_query(vec![select_col("tx_hash")])
        };
        let (sql, bindings) = compile_ir(&ir);
        assert!(sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                condition("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                FilterNode::Or(vec![
                    condition("buy_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                    condition("sell_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                ]),
            ]),
            ..base_query(vec![select_col("tx_hash")])
        };
        let (sql, bindings) = compile_ir(&ir);
        assert!(sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            having: condition(
                "SUM(`buy_amount_usd`)", CompareOp::Gt, SqlValue::Float(1000000.0),
            ),
            group_by: vec!["buy_token_symbol".into()],
            ..base_query(vec![
                select_col("buy_token_symbol"),
                SelectExpr::Aggregate {
                    function: "SUM".into(),
                    column: "buy_amount_usd".into(),
                    alias: "__sum".into(),
                },
            ])
        };
        let (sql, bindings) = compile_ir(&ir);
        assert!(sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(sql.contains("HAVING SUM(`buy_amount_usd`) > ?"), "HAVING clause should not backtick-wrap aggregate expressions, got: {sql}");
        assert_eq!(bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR { limit: 10, offset: 20, ..base_query(vec![select_col("tx_hash")]) };
        let (sql, _) = compile_ir(&ir);
        assert!(sql.ends_with("LIMIT 10 OFFSET 20"));
    }
}