activecube_rs/sql/starrocks.rs

1use crate::compiler::ir::*;
2use crate::compiler::ir::CompileResult;
3use crate::sql::dialect::SqlDialect;
4
/// SQL dialect that renders a query IR as StarRocks SQL.
///
/// The dialect carries no state (it is a unit struct), so it is cheap to copy
/// and `new()` and `default()` produce the same value. Identifiers are quoted
/// with backticks.
#[derive(Debug, Clone, Copy, Default)]
pub struct StarRocksDialect;

impl StarRocksDialect {
    /// Creates the dialect; identical to `StarRocksDialect::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
18
19impl SqlDialect for StarRocksDialect {
20    fn compile(&self, ir: &QueryIR) -> CompileResult {
21        let mut bindings = Vec::new();
22        let mut sql = String::new();
23
24        sql.push_str("SELECT ");
25        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
26            SelectExpr::Column { column, alias } => match alias {
27                Some(a) => format!("`{column}` AS `{a}`"),
28                None => format!("`{column}`"),
29            },
30            SelectExpr::Aggregate { function, column, alias, condition } => {
31                let func = function.to_uppercase();
32                match (func.as_str(), column.as_str(), condition) {
33                    ("COUNT", "*", None) => format!("COUNT(*) AS `{alias}`"),
34                    ("COUNT", "*", Some(cond)) => format!("COUNT(IF({cond}, 1, NULL)) AS `{alias}`"),
35                    ("UNIQ", col, None) => format!("COUNT(DISTINCT `{col}`) AS `{alias}`"),
36                    ("UNIQ", col, Some(cond)) => format!("COUNT(DISTINCT IF({cond}, `{col}`, NULL)) AS `{alias}`"),
37                    (f, col, None) => format!("{f}(`{col}`) AS `{alias}`"),
38                    (f, col, Some(cond)) => format!("{f}(IF({cond}, `{col}`, NULL)) AS `{alias}`"),
39                }
40            }
41        }).collect();
42        sql.push_str(&select_parts.join(", "));
43
44        if let Some(ref subquery) = ir.from_subquery {
45            sql.push_str(&format!(" FROM ({}) AS _t", subquery));
46        } else {
47            sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
48        }
49
50        let where_clause = compile_filter(&ir.filters, &mut bindings);
51        if !where_clause.is_empty() {
52            sql.push_str(" WHERE ");
53            sql.push_str(&where_clause);
54        }
55
56        if !ir.group_by.is_empty() {
57            sql.push_str(" GROUP BY ");
58            let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
59            sql.push_str(&cols.join(", "));
60        }
61
62        let having_clause = compile_filter(&ir.having, &mut bindings);
63        if !having_clause.is_empty() {
64            sql.push_str(" HAVING ");
65            sql.push_str(&having_clause);
66        }
67
68        if !ir.order_by.is_empty() {
69            sql.push_str(" ORDER BY ");
70            let parts: Vec<String> = ir.order_by.iter().map(|o| {
71                let dir = if o.descending { "DESC" } else { "ASC" };
72                format!("`{}` {dir}", o.column)
73            }).collect();
74            sql.push_str(&parts.join(", "));
75        }
76
77        sql.push_str(&format!(" LIMIT {}", ir.limit));
78        if ir.offset > 0 {
79            sql.push_str(&format!(" OFFSET {}", ir.offset));
80        }
81
82        CompileResult { sql, bindings, alias_remap: vec![] }
83    }
84
85    fn quote_identifier(&self, name: &str) -> String {
86        format!("`{name}`")
87    }
88
89    fn name(&self) -> &str {
90        "StarRocks"
91    }
92}
93
94fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
95    match node {
96        FilterNode::Empty => String::new(),
97        FilterNode::Condition { column, op, value } => {
98            compile_condition(column, op, value, bindings)
99        }
100        FilterNode::And(children) => {
101            let parts: Vec<String> = children.iter()
102                .map(|c| compile_filter(c, bindings))
103                .filter(|s| !s.is_empty())
104                .collect();
105            match parts.len() {
106                0 => String::new(),
107                1 => parts.into_iter().next().unwrap(),
108                _ => format!("({})", parts.join(" AND ")),
109            }
110        }
111        FilterNode::Or(children) => {
112            let parts: Vec<String> = children.iter()
113                .map(|c| compile_filter(c, bindings))
114                .filter(|s| !s.is_empty())
115                .collect();
116            match parts.len() {
117                0 => String::new(),
118                1 => parts.into_iter().next().unwrap(),
119                _ => format!("({})", parts.join(" OR ")),
120            }
121        }
122    }
123}
124
/// Wraps a column name in backticks for use in a predicate.
///
/// Anything containing `(` is treated as an already-rendered SQL expression,
/// such as the HAVING operand ``SUM(`col`)``, and is returned untouched.
/// This avoids turning it into the bogus identifier `` `SUM(`col`)` ``.
fn quote_col(column: &str) -> String {
    if !column.contains('(') {
        return format!("`{column}`");
    }
    column.to_owned()
}
134
135fn compile_condition(
136    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
137) -> String {
138    let col = quote_col(column);
139    match op {
140        CompareOp::In | CompareOp::NotIn => {
141            if let SqlValue::String(csv) = value {
142                let items: Vec<&str> = csv.split(',').collect();
143                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
144                for item in &items {
145                    bindings.push(SqlValue::String(item.trim().to_string()));
146                }
147                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
148            } else {
149                bindings.push(value.clone());
150                format!("{col} {} (?)", op.sql_op())
151            }
152        }
153        CompareOp::Includes => {
154            if let SqlValue::String(s) = value {
155                bindings.push(SqlValue::String(format!("%{s}%")));
156            } else {
157                bindings.push(value.clone());
158            }
159            format!("{col} LIKE ?")
160        }
161        CompareOp::IsNull | CompareOp::IsNotNull => {
162            format!("{col} {}", op.sql_op())
163        }
164        _ => {
165            bindings.push(value.clone());
166            format!("{col} {} ?", op.sql_op())
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    fn make_dialect() -> StarRocksDialect { StarRocksDialect::new() }

    /// Baseline IR over `dexes_dwd.sol_activities`. It has no filters,
    /// grouping, ordering or offset, and uses `LIMIT 25`. Tests override
    /// individual fields with struct-update syntax.
    fn sol_activities(selects: Vec<SelectExpr>) -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects,
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
            limit_by: None,
            use_final: false,
            joins: vec![],
            custom_query_builder: None,
            from_subquery: None,
        }
    }

    /// Unaliased column projection.
    fn col(name: &str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Leaf predicate `column <op> value`.
    fn cond(column: &str, op: CompareOp, value: SqlValue) -> FilterNode {
        FilterNode::Condition { column: column.into(), op, value }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR { limit: 10, ..sol_activities(vec![col("tx_hash"), col("buy_amount")]) };
        let r = make_dialect().compile(&ir);
        assert_eq!(r.sql, "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10");
        assert!(r.bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                cond("success", CompareOp::Eq, SqlValue::Bool(true)),
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..sol_activities(vec![col("tx_hash")])
        };
        let r = make_dialect().compile(&ir);
        assert!(r.sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(r.sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(r.bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                FilterNode::Or(vec![
                    cond("buy_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                    cond("sell_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                ]),
            ]),
            ..sol_activities(vec![col("tx_hash")])
        };
        let r = make_dialect().compile(&ir);
        assert!(r.sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(r.bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            having: cond("SUM(`buy_amount_usd`)", CompareOp::Gt, SqlValue::Float(1000000.0)),
            group_by: vec!["buy_token_symbol".into()],
            ..sol_activities(vec![
                col("buy_token_symbol"),
                SelectExpr::Aggregate { function: "SUM".into(), column: "buy_amount_usd".into(), alias: "__sum".into(), condition: None },
            ])
        };
        let r = make_dialect().compile(&ir);
        assert!(r.sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(r.sql.contains("HAVING SUM(`buy_amount_usd`) > ?"), "HAVING clause should not backtick-wrap aggregate expressions, got: {r}", r = r.sql);
        assert_eq!(r.bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR { limit: 10, offset: 20, ..sol_activities(vec![col("tx_hash")]) };
        let r = make_dialect().compile(&ir);
        assert!(r.sql.ends_with("LIMIT 10 OFFSET 20"));
    }

    #[test]
    fn test_from_subquery() {
        let ir = QueryIR {
            cube: "T".into(),
            schema: "dwd".into(),
            table: "t".into(),
            limit: 5,
            from_subquery: Some("SELECT val FROM dwd.a UNION ALL SELECT val FROM dwd.b".into()),
            ..sol_activities(vec![col("val")])
        };
        let r = make_dialect().compile(&ir);
        assert!(r.sql.contains("FROM (SELECT val FROM dwd.a UNION ALL SELECT val FROM dwd.b) AS _t"),
            "StarRocks should use subquery in FROM, got: {}", r.sql);
        assert!(!r.sql.contains("FROM `dwd`.`t`"),
            "should NOT use schema.table when from_subquery is set, got: {}", r.sql);
    }

    #[test]
    fn test_in_list_expands_to_one_placeholder_per_item() {
        let ir = QueryIR {
            filters: cond("buy_token_symbol", CompareOp::In, SqlValue::String("SOL, USDC,BONK".into())),
            ..sol_activities(vec![col("tx_hash")])
        };
        let r = make_dialect().compile(&ir);
        let expected = format!("WHERE `buy_token_symbol` {} (?, ?, ?)", CompareOp::In.sql_op());
        assert!(r.sql.contains(&expected), "got: {}", r.sql);
        let values: Vec<&str> = r
            .bindings
            .iter()
            .map(|b| match b {
                SqlValue::String(s) => s.as_str(),
                _ => "<non-string>",
            })
            .collect();
        assert_eq!(values, ["SOL", "USDC", "BONK"], "items should be trimmed and bound in order");
    }

    #[test]
    fn test_is_null_binds_nothing() {
        let ir = QueryIR {
            filters: cond("sell_token_symbol", CompareOp::IsNull, SqlValue::Bool(true)),
            ..sol_activities(vec![col("tx_hash")])
        };
        let r = make_dialect().compile(&ir);
        let expected = format!("WHERE `sell_token_symbol` {} LIMIT 25", CompareOp::IsNull.sql_op());
        assert!(r.sql.contains(&expected), "got: {}", r.sql);
        assert!(r.bindings.is_empty(), "IS NULL must not bind a value");
    }

    #[test]
    fn test_conditional_and_distinct_aggregates() {
        let ir = sol_activities(vec![
            SelectExpr::Aggregate { function: "count".into(), column: "*".into(), alias: "n".into(), condition: Some("`success` = 1".into()) },
            SelectExpr::Aggregate { function: "uniq".into(), column: "signer".into(), alias: "u".into(), condition: None },
        ]);
        let r = make_dialect().compile(&ir);
        assert_eq!(
            r.sql,
            "SELECT COUNT(IF(`success` = 1, 1, NULL)) AS `n`, COUNT(DISTINCT `signer`) AS `u` FROM `dexes_dwd`.`sol_activities` LIMIT 25"
        );
        assert!(r.bindings.is_empty());
    }
}