// activecube_rs/sql/starrocks.rs

use crate::compiler::ir::*;
use crate::sql::dialect::SqlDialect;
3
/// SQL dialect marker for StarRocks.
///
/// The type is a stateless unit struct: every value is interchangeable, so it
/// is `Copy` and its `Default` is the same value that [`StarRocksDialect::new`]
/// returns.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StarRocksDialect;

impl StarRocksDialect {
    /// Creates the dialect. Equivalent to `StarRocksDialect::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
17
18impl SqlDialect for StarRocksDialect {
19    fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20        let mut bindings = Vec::new();
21        let mut sql = String::new();
22
23        sql.push_str("SELECT ");
24        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25            SelectExpr::Column { column, alias } => match alias {
26                Some(a) => format!("`{column}` AS `{a}`"),
27                None => format!("`{column}`"),
28            },
29            SelectExpr::Aggregate { function, column, alias, condition } => {
30                let func = function.to_uppercase();
31                match (func.as_str(), column.as_str(), condition) {
32                    ("COUNT", "*", None) => format!("COUNT(*) AS `{alias}`"),
33                    ("COUNT", "*", Some(cond)) => format!("COUNT(IF({cond}, 1, NULL)) AS `{alias}`"),
34                    ("UNIQ", col, None) => format!("COUNT(DISTINCT `{col}`) AS `{alias}`"),
35                    ("UNIQ", col, Some(cond)) => format!("COUNT(DISTINCT IF({cond}, `{col}`, NULL)) AS `{alias}`"),
36                    (f, col, None) => format!("{f}(`{col}`) AS `{alias}`"),
37                    (f, col, Some(cond)) => format!("{f}(IF({cond}, `{col}`, NULL)) AS `{alias}`"),
38                }
39            }
40        }).collect();
41        sql.push_str(&select_parts.join(", "));
42
43        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
44
45        let where_clause = compile_filter(&ir.filters, &mut bindings);
46        if !where_clause.is_empty() {
47            sql.push_str(" WHERE ");
48            sql.push_str(&where_clause);
49        }
50
51        if !ir.group_by.is_empty() {
52            sql.push_str(" GROUP BY ");
53            let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
54            sql.push_str(&cols.join(", "));
55        }
56
57        let having_clause = compile_filter(&ir.having, &mut bindings);
58        if !having_clause.is_empty() {
59            sql.push_str(" HAVING ");
60            sql.push_str(&having_clause);
61        }
62
63        if !ir.order_by.is_empty() {
64            sql.push_str(" ORDER BY ");
65            let parts: Vec<String> = ir.order_by.iter().map(|o| {
66                let dir = if o.descending { "DESC" } else { "ASC" };
67                format!("`{}` {dir}", o.column)
68            }).collect();
69            sql.push_str(&parts.join(", "));
70        }
71
72        sql.push_str(&format!(" LIMIT {}", ir.limit));
73        if ir.offset > 0 {
74            sql.push_str(&format!(" OFFSET {}", ir.offset));
75        }
76
77        (sql, bindings)
78    }
79
80    fn quote_identifier(&self, name: &str) -> String {
81        format!("`{name}`")
82    }
83
84    fn name(&self) -> &str {
85        "StarRocks"
86    }
87}
88
89fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
90    match node {
91        FilterNode::Empty => String::new(),
92        FilterNode::Condition { column, op, value } => {
93            compile_condition(column, op, value, bindings)
94        }
95        FilterNode::And(children) => {
96            let parts: Vec<String> = children.iter()
97                .map(|c| compile_filter(c, bindings))
98                .filter(|s| !s.is_empty())
99                .collect();
100            match parts.len() {
101                0 => String::new(),
102                1 => parts.into_iter().next().unwrap(),
103                _ => format!("({})", parts.join(" AND ")),
104            }
105        }
106        FilterNode::Or(children) => {
107            let parts: Vec<String> = children.iter()
108                .map(|c| compile_filter(c, bindings))
109                .filter(|s| !s.is_empty())
110                .collect();
111            match parts.len() {
112                0 => String::new(),
113                1 => parts.into_iter().next().unwrap(),
114                _ => format!("({})", parts.join(" OR ")),
115            }
116        }
117    }
118}
119
/// Quotes a filter column for use in a `WHERE` / `HAVING` predicate.
///
/// A value containing `(` is treated as an already-formed expression (for
/// example `SUM(`col`)`) and returned unchanged, so it is not wrapped in a
/// second pair of backticks. Anything else is wrapped in backticks, with
/// embedded backticks doubled so the name cannot end the quoted identifier
/// early.
///
/// NOTE(review): the `(` pass-through emits its input verbatim; callers must
/// only hand trusted expressions to this path.
fn quote_col(column: &str) -> String {
    if column.contains('(') {
        column.to_owned()
    } else {
        format!("`{}`", column.replace('`', "``"))
    }
}
129
130fn compile_condition(
131    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
132) -> String {
133    let col = quote_col(column);
134    match op {
135        CompareOp::In | CompareOp::NotIn => {
136            if let SqlValue::String(csv) = value {
137                let items: Vec<&str> = csv.split(',').collect();
138                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
139                for item in &items {
140                    bindings.push(SqlValue::String(item.trim().to_string()));
141                }
142                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
143            } else {
144                bindings.push(value.clone());
145                format!("{col} {} (?)", op.sql_op())
146            }
147        }
148        CompareOp::Includes => {
149            if let SqlValue::String(s) = value {
150                bindings.push(SqlValue::String(format!("%{s}%")));
151            } else {
152                bindings.push(value.clone());
153            }
154            format!("{col} LIKE ?")
155        }
156        CompareOp::IsNull | CompareOp::IsNotNull => {
157            format!("{col} {}", op.sql_op())
158        }
159        _ => {
160            bindings.push(value.clone());
161            format!("{col} {} ?", op.sql_op())
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    //! Unit tests for the StarRocks SQL compiler. Each test starts from
    //! `base_ir()` and overrides only the fields it exercises.

    use super::*;

    fn make_dialect() -> StarRocksDialect {
        StarRocksDialect::new()
    }

    /// Un-aliased column projection.
    fn col(name: &'static str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Single `column <op> value` filter leaf.
    fn cond(column: &'static str, op: CompareOp, value: SqlValue) -> FilterNode {
        FilterNode::Condition { column: column.into(), op, value }
    }

    /// Minimal query: `SELECT tx_hash` from the DEX trades table, no filters,
    /// limit 25, offset 0.
    fn base_ir() -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects: vec![col("tx_hash")],
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
            limit_by: None,
            use_final: false,
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            selects: vec![col("tx_hash"), col("buy_amount")],
            limit: 10,
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert_eq!(
            sql,
            "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10"
        );
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                cond("success", CompareOp::Eq, SqlValue::Bool(true)),
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                FilterNode::Or(vec![
                    cond("buy_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                    cond("sell_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                ]),
            ]),
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            selects: vec![
                col("buy_token_symbol"),
                SelectExpr::Aggregate {
                    function: "SUM".into(),
                    column: "buy_amount_usd".into(),
                    alias: "__sum".into(),
                    condition: None,
                },
            ],
            having: cond("SUM(`buy_amount_usd`)", CompareOp::Gt, SqlValue::Float(1000000.0)),
            group_by: vec!["buy_token_symbol".into()],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(
            sql.contains("HAVING SUM(`buy_amount_usd`) > ?"),
            "HAVING clause should not backtick-wrap aggregate expressions, got: {sql}"
        );
        assert_eq!(bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR { limit: 10, offset: 20, ..base_ir() };
        let (sql, _) = make_dialect().compile(&ir);
        assert!(sql.ends_with("LIMIT 10 OFFSET 20"));
    }

    #[test]
    fn test_in_list_expands_one_placeholder_per_item() {
        let ir = QueryIR {
            filters: cond(
                "buy_token_symbol",
                CompareOp::In,
                SqlValue::String("SOL, USDC,BONK".into()),
            ),
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("(?, ?, ?)"), "expected three placeholders, got: {sql}");
        assert_eq!(bindings.len(), 3);
    }

    #[test]
    fn test_uniq_compiles_to_count_distinct() {
        let ir = QueryIR {
            selects: vec![SelectExpr::Aggregate {
                function: "uniq".into(),
                column: "buy_token_symbol".into(),
                alias: "__uniq".into(),
                condition: None,
            }],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(
            sql.starts_with("SELECT COUNT(DISTINCT `buy_token_symbol`) AS `__uniq` FROM"),
            "got: {sql}"
        );
        assert!(bindings.is_empty());
    }
}