Skip to main content

activecube_rs/sql/
starrocks.rs

1use crate::compiler::ir::*;
2use crate::compiler::ir::CompileResult;
3use crate::sql::dialect::SqlDialect;
4
/// Stateless marker type that selects the StarRocks SQL dialect.
///
/// The type has no fields, so every instance is interchangeable. It derives
/// the common traits so it can be printed, copied, compared, and built with
/// `Default::default()` without any hand-written impls.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StarRocksDialect;

impl StarRocksDialect {
    /// Creates a dialect instance.
    ///
    /// Equivalent to `StarRocksDialect::default()`; usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
18
19impl SqlDialect for StarRocksDialect {
20    fn compile(&self, ir: &QueryIR) -> CompileResult {
21        let mut bindings = Vec::new();
22        let mut sql = String::new();
23
24        sql.push_str("SELECT ");
25        let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
26            SelectExpr::Column { column, alias } => match alias {
27                Some(a) => format!("`{column}` AS `{a}`"),
28                None => format!("`{column}`"),
29            },
30            SelectExpr::Aggregate { function, column, alias, condition } => {
31                let func = function.to_uppercase();
32                match (func.as_str(), column.as_str(), condition) {
33                    ("COUNT", "*", None) => format!("COUNT(*) AS `{alias}`"),
34                    ("COUNT", "*", Some(cond)) => format!("COUNT(IF({cond}, 1, NULL)) AS `{alias}`"),
35                    ("UNIQ", col, None) => format!("COUNT(DISTINCT `{col}`) AS `{alias}`"),
36                    ("UNIQ", col, Some(cond)) => format!("COUNT(DISTINCT IF({cond}, `{col}`, NULL)) AS `{alias}`"),
37                    (f, col, None) => format!("{f}(`{col}`) AS `{alias}`"),
38                    (f, col, Some(cond)) => format!("{f}(IF({cond}, `{col}`, NULL)) AS `{alias}`"),
39                }
40            }
41        }).collect();
42        sql.push_str(&select_parts.join(", "));
43
44        sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
45
46        let where_clause = compile_filter(&ir.filters, &mut bindings);
47        if !where_clause.is_empty() {
48            sql.push_str(" WHERE ");
49            sql.push_str(&where_clause);
50        }
51
52        if !ir.group_by.is_empty() {
53            sql.push_str(" GROUP BY ");
54            let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
55            sql.push_str(&cols.join(", "));
56        }
57
58        let having_clause = compile_filter(&ir.having, &mut bindings);
59        if !having_clause.is_empty() {
60            sql.push_str(" HAVING ");
61            sql.push_str(&having_clause);
62        }
63
64        if !ir.order_by.is_empty() {
65            sql.push_str(" ORDER BY ");
66            let parts: Vec<String> = ir.order_by.iter().map(|o| {
67                let dir = if o.descending { "DESC" } else { "ASC" };
68                format!("`{}` {dir}", o.column)
69            }).collect();
70            sql.push_str(&parts.join(", "));
71        }
72
73        sql.push_str(&format!(" LIMIT {}", ir.limit));
74        if ir.offset > 0 {
75            sql.push_str(&format!(" OFFSET {}", ir.offset));
76        }
77
78        CompileResult { sql, bindings, alias_remap: vec![] }
79    }
80
81    fn quote_identifier(&self, name: &str) -> String {
82        format!("`{name}`")
83    }
84
85    fn name(&self) -> &str {
86        "StarRocks"
87    }
88}
89
90fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
91    match node {
92        FilterNode::Empty => String::new(),
93        FilterNode::Condition { column, op, value } => {
94            compile_condition(column, op, value, bindings)
95        }
96        FilterNode::And(children) => {
97            let parts: Vec<String> = children.iter()
98                .map(|c| compile_filter(c, bindings))
99                .filter(|s| !s.is_empty())
100                .collect();
101            match parts.len() {
102                0 => String::new(),
103                1 => parts.into_iter().next().unwrap(),
104                _ => format!("({})", parts.join(" AND ")),
105            }
106        }
107        FilterNode::Or(children) => {
108            let parts: Vec<String> = children.iter()
109                .map(|c| compile_filter(c, bindings))
110                .filter(|s| !s.is_empty())
111                .collect();
112            match parts.len() {
113                0 => String::new(),
114                1 => parts.into_iter().next().unwrap(),
115                _ => format!("({})", parts.join(" OR ")),
116            }
117        }
118    }
119}
120
/// Backtick-quotes a plain column name.
///
/// Anything containing `(` is treated as an expression (for example an
/// aggregate call) and returned unchanged, so that it is not turned into a
/// single quoted identifier.
fn quote_col(column: &str) -> String {
    match column.find('(') {
        Some(_) => column.to_owned(),
        None => ["`", column, "`"].concat(),
    }
}
130
131fn compile_condition(
132    column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
133) -> String {
134    let col = quote_col(column);
135    match op {
136        CompareOp::In | CompareOp::NotIn => {
137            if let SqlValue::String(csv) = value {
138                let items: Vec<&str> = csv.split(',').collect();
139                let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
140                for item in &items {
141                    bindings.push(SqlValue::String(item.trim().to_string()));
142                }
143                format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
144            } else {
145                bindings.push(value.clone());
146                format!("{col} {} (?)", op.sql_op())
147            }
148        }
149        CompareOp::Includes => {
150            if let SqlValue::String(s) = value {
151                bindings.push(SqlValue::String(format!("%{s}%")));
152            } else {
153                bindings.push(value.clone());
154            }
155            format!("{col} LIKE ?")
156        }
157        CompareOp::IsNull | CompareOp::IsNotNull => {
158            format!("{col} {}", op.sql_op())
159        }
160        _ => {
161            bindings.push(value.clone());
162            format!("{col} {} ?", op.sql_op())
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline query shared by the tests: one `tx_hash` column, no filters,
    /// no grouping or ordering, `LIMIT 25`, no offset. Each test overrides
    /// only the fields it cares about via struct-update syntax.
    fn base_ir() -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects: vec![plain("tx_hash")],
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
            limit_by: None,
            use_final: false,
            joins: vec![],
            custom_query_builder: None,
        }
    }

    /// Un-aliased column select.
    fn plain(name: &'static str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Leaf filter condition.
    fn cond(column: &'static str, op: CompareOp, value: SqlValue) -> FilterNode {
        FilterNode::Condition { column: column.into(), op, value }
    }

    /// Compiles `ir` with a fresh StarRocks dialect.
    fn run(ir: &QueryIR) -> CompileResult {
        StarRocksDialect::new().compile(ir)
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            selects: vec![plain("tx_hash"), plain("buy_amount")],
            limit: 10,
            ..base_ir()
        };
        let out = run(&ir);
        assert_eq!(out.sql, "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10");
        assert!(out.bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                cond("success", CompareOp::Eq, SqlValue::Bool(true)),
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..base_ir()
        };
        let out = run(&ir);
        assert!(out.sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(out.sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(out.bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let either_side_is_sol = FilterNode::Or(vec![
            cond("buy_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
            cond("sell_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
        ]);
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                either_side_is_sol,
            ]),
            ..base_ir()
        };
        let out = run(&ir);
        assert!(out.sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(out.bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            selects: vec![
                plain("buy_token_symbol"),
                SelectExpr::Aggregate {
                    function: "SUM".into(),
                    column: "buy_amount_usd".into(),
                    alias: "__sum".into(),
                    condition: None,
                },
            ],
            having: cond("SUM(`buy_amount_usd`)", CompareOp::Gt, SqlValue::Float(1000000.0)),
            group_by: vec!["buy_token_symbol".into()],
            ..base_ir()
        };
        let out = run(&ir);
        assert!(out.sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(out.sql.contains("HAVING SUM(`buy_amount_usd`) > ?"), "HAVING clause should not backtick-wrap aggregate expressions, got: {r}", r = out.sql);
        assert_eq!(out.bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR { limit: 10, offset: 20, ..base_ir() };
        let out = run(&ir);
        assert!(out.sql.ends_with("LIMIT 10 OFFSET 20"));
    }
}