1use crate::compiler::ir::*;
2use crate::sql::dialect::SqlDialect;
3
/// SQL dialect that renders a [`QueryIR`] as StarRocks SQL.
///
/// Identifiers are quoted with backticks; filter values are returned
/// separately as bindings for `?` placeholders in the generated SQL.
#[derive(Debug, Clone, Copy)]
pub struct StarRocksDialect;
5
6impl StarRocksDialect {
7 pub fn new() -> Self {
8 Self
9 }
10}
11
12impl Default for StarRocksDialect {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl SqlDialect for StarRocksDialect {
19 fn compile(&self, ir: &QueryIR) -> (String, Vec<SqlValue>) {
20 let mut bindings = Vec::new();
21 let mut sql = String::new();
22
23 sql.push_str("SELECT ");
24 let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
25 SelectExpr::Column { column, alias } => match alias {
26 Some(a) => format!("`{column}` AS `{a}`"),
27 None => format!("`{column}`"),
28 },
29 SelectExpr::Aggregate { function, column, alias, condition } => {
30 let func = function.to_uppercase();
31 match (func.as_str(), column.as_str(), condition) {
32 ("COUNT", "*", None) => format!("COUNT(*) AS `{alias}`"),
33 ("COUNT", "*", Some(cond)) => format!("COUNT(IF({cond}, 1, NULL)) AS `{alias}`"),
34 ("UNIQ", col, None) => format!("COUNT(DISTINCT `{col}`) AS `{alias}`"),
35 ("UNIQ", col, Some(cond)) => format!("COUNT(DISTINCT IF({cond}, `{col}`, NULL)) AS `{alias}`"),
36 (f, col, None) => format!("{f}(`{col}`) AS `{alias}`"),
37 (f, col, Some(cond)) => format!("{f}(IF({cond}, `{col}`, NULL)) AS `{alias}`"),
38 }
39 }
40 }).collect();
41 sql.push_str(&select_parts.join(", "));
42
43 sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
44
45 let where_clause = compile_filter(&ir.filters, &mut bindings);
46 if !where_clause.is_empty() {
47 sql.push_str(" WHERE ");
48 sql.push_str(&where_clause);
49 }
50
51 if !ir.group_by.is_empty() {
52 sql.push_str(" GROUP BY ");
53 let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
54 sql.push_str(&cols.join(", "));
55 }
56
57 let having_clause = compile_filter(&ir.having, &mut bindings);
58 if !having_clause.is_empty() {
59 sql.push_str(" HAVING ");
60 sql.push_str(&having_clause);
61 }
62
63 if !ir.order_by.is_empty() {
64 sql.push_str(" ORDER BY ");
65 let parts: Vec<String> = ir.order_by.iter().map(|o| {
66 let dir = if o.descending { "DESC" } else { "ASC" };
67 format!("`{}` {dir}", o.column)
68 }).collect();
69 sql.push_str(&parts.join(", "));
70 }
71
72 sql.push_str(&format!(" LIMIT {}", ir.limit));
73 if ir.offset > 0 {
74 sql.push_str(&format!(" OFFSET {}", ir.offset));
75 }
76
77 (sql, bindings)
78 }
79
80 fn quote_identifier(&self, name: &str) -> String {
81 format!("`{name}`")
82 }
83
84 fn name(&self) -> &str {
85 "StarRocks"
86 }
87}
88
89fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
90 match node {
91 FilterNode::Empty => String::new(),
92 FilterNode::Condition { column, op, value } => {
93 compile_condition(column, op, value, bindings)
94 }
95 FilterNode::And(children) => {
96 let parts: Vec<String> = children.iter()
97 .map(|c| compile_filter(c, bindings))
98 .filter(|s| !s.is_empty())
99 .collect();
100 match parts.len() {
101 0 => String::new(),
102 1 => parts.into_iter().next().unwrap(),
103 _ => format!("({})", parts.join(" AND ")),
104 }
105 }
106 FilterNode::Or(children) => {
107 let parts: Vec<String> = children.iter()
108 .map(|c| compile_filter(c, bindings))
109 .filter(|s| !s.is_empty())
110 .collect();
111 match parts.len() {
112 0 => String::new(),
113 1 => parts.into_iter().next().unwrap(),
114 _ => format!("({})", parts.join(" OR ")),
115 }
116 }
117 }
118}
119
/// Renders a column reference used in a WHERE/HAVING condition.
///
/// Plain column names are wrapped in backticks, with any embedded backtick
/// doubled so it cannot close the identifier early. Anything containing `(`
/// is treated as a ready-made SQL expression (for example an aggregate such
/// as ``SUM(`x`)`` in a HAVING clause) and passed through unchanged.
fn quote_col(column: &str) -> String {
    if column.contains('(') {
        column.to_string()
    } else {
        format!("`{}`", column.replace('`', "``"))
    }
}
129
130fn compile_condition(
131 column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
132) -> String {
133 let col = quote_col(column);
134 match op {
135 CompareOp::In | CompareOp::NotIn => {
136 if let SqlValue::String(csv) = value {
137 let items: Vec<&str> = csv.split(',').collect();
138 let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
139 for item in &items {
140 bindings.push(SqlValue::String(item.trim().to_string()));
141 }
142 format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
143 } else {
144 bindings.push(value.clone());
145 format!("{col} {} (?)", op.sql_op())
146 }
147 }
148 CompareOp::Includes => {
149 if let SqlValue::String(s) = value {
150 bindings.push(SqlValue::String(format!("%{s}%")));
151 } else {
152 bindings.push(value.clone());
153 }
154 format!("{col} LIKE ?")
155 }
156 CompareOp::IsNull | CompareOp::IsNotNull => {
157 format!("{col} {}", op.sql_op())
158 }
159 _ => {
160 bindings.push(value.clone());
161 format!("{col} {} ?", op.sql_op())
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    fn make_dialect() -> StarRocksDialect {
        StarRocksDialect::new()
    }

    /// Baseline query over `dexes_dwd.sol_activities` selecting `tx_hash` with
    /// no filters, grouping or ordering and `LIMIT 25`; tests override only the
    /// fields they exercise via struct-update syntax.
    fn base_ir() -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects: vec![SelectExpr::Column { column: "tx_hash".into(), alias: None }],
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
            limit_by: None,
            use_final: false,
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            selects: vec![
                SelectExpr::Column { column: "tx_hash".into(), alias: None },
                SelectExpr::Column { column: "buy_amount".into(), alias: None },
            ],
            limit: 10,
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert_eq!(sql, "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                FilterNode::Condition { column: "buy_amount_usd".into(), op: CompareOp::Gt, value: SqlValue::Float(1000.0) },
                FilterNode::Condition { column: "success".into(), op: CompareOp::Eq, value: SqlValue::Bool(true) },
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                FilterNode::Condition { column: "buy_amount_usd".into(), op: CompareOp::Gt, value: SqlValue::Float(1000.0) },
                FilterNode::Or(vec![
                    FilterNode::Condition { column: "buy_token_symbol".into(), op: CompareOp::Eq, value: SqlValue::String("SOL".into()) },
                    FilterNode::Condition { column: "sell_token_symbol".into(), op: CompareOp::Eq, value: SqlValue::String("SOL".into()) },
                ]),
            ]),
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            selects: vec![
                SelectExpr::Column { column: "buy_token_symbol".into(), alias: None },
                SelectExpr::Aggregate { function: "SUM".into(), column: "buy_amount_usd".into(), alias: "__sum".into(), condition: None },
            ],
            having: FilterNode::Condition {
                column: "SUM(`buy_amount_usd`)".into(), op: CompareOp::Gt, value: SqlValue::Float(1000000.0),
            },
            group_by: vec!["buy_token_symbol".into()],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(sql.contains("HAVING SUM(`buy_amount_usd`) > ?"), "HAVING clause should not backtick-wrap aggregate expressions, got: {sql}");
        assert_eq!(bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR { limit: 10, offset: 20, ..base_ir() };
        let (sql, _) = make_dialect().compile(&ir);
        assert!(sql.ends_with("LIMIT 10 OFFSET 20"));
    }

    #[test]
    fn test_count_and_uniq_aggregates() {
        // Function names are upper-cased, COUNT(*) is kept verbatim and UNIQ
        // maps to COUNT(DISTINCT ...).
        let ir = QueryIR {
            selects: vec![
                SelectExpr::Aggregate { function: "count".into(), column: "*".into(), alias: "__count".into(), condition: None },
                SelectExpr::Aggregate { function: "uniq".into(), column: "buy_token_symbol".into(), alias: "__uniq".into(), condition: None },
            ],
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(
            sql.starts_with("SELECT COUNT(*) AS `__count`, COUNT(DISTINCT `buy_token_symbol`) AS `__uniq` FROM"),
            "got: {sql}"
        );
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_in_list_binds_each_trimmed_item() {
        let ir = QueryIR {
            filters: FilterNode::Condition {
                column: "buy_token_symbol".into(), op: CompareOp::In, value: SqlValue::String("SOL, USDC,JUP".into()),
            },
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("WHERE `buy_token_symbol` "), "got: {sql}");
        assert!(sql.contains(" (?, ?, ?)"), "got: {sql}");
        assert_eq!(bindings.len(), 3);
        assert!(matches!(&bindings[1], SqlValue::String(s) if s == "USDC"));
    }

    #[test]
    fn test_is_null_binds_nothing() {
        let ir = QueryIR {
            filters: FilterNode::Condition {
                column: "buy_amount_usd".into(), op: CompareOp::IsNull, value: SqlValue::Bool(true),
            },
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("WHERE `buy_amount_usd` "), "got: {sql}");
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_includes_wraps_value_in_wildcards() {
        let ir = QueryIR {
            filters: FilterNode::Condition {
                column: "buy_token_symbol".into(), op: CompareOp::Includes, value: SqlValue::String("SOL".into()),
            },
            ..base_ir()
        };
        let (sql, bindings) = make_dialect().compile(&ir);
        assert!(sql.contains("WHERE `buy_token_symbol` LIKE ?"), "got: {sql}");
        assert_eq!(bindings.len(), 1);
        assert!(matches!(&bindings[0], SqlValue::String(s) if s == "%SOL%"));
    }
}