1use crate::compiler::ir::*;
2use crate::compiler::ir::CompileResult;
3use crate::sql::dialect::SqlDialect;
4
/// SQL dialect that renders a `QueryIR` as StarRocks SQL.
///
/// The type is a stateless unit struct, so it is cheap to copy and every
/// instance behaves the same. `Debug`, `Clone` and `Copy` are derived so the
/// dialect can be logged and stored inside other derivable types.
#[derive(Debug, Clone, Copy)]
pub struct StarRocksDialect;
6
7impl StarRocksDialect {
8 pub fn new() -> Self {
9 Self
10 }
11}
12
13impl Default for StarRocksDialect {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl SqlDialect for StarRocksDialect {
20 fn compile(&self, ir: &QueryIR) -> CompileResult {
21 let mut bindings = Vec::new();
22 let mut sql = String::new();
23
24 sql.push_str("SELECT ");
25 let select_parts: Vec<String> = ir.selects.iter().map(|s| match s {
26 SelectExpr::Column { column, alias } => match alias {
27 Some(a) => format!("`{column}` AS `{a}`"),
28 None => format!("`{column}`"),
29 },
30 SelectExpr::Aggregate { function, column, alias, condition } => {
31 let func = function.to_uppercase();
32 match (func.as_str(), column.as_str(), condition) {
33 ("COUNT", "*", None) => format!("COUNT(*) AS `{alias}`"),
34 ("COUNT", "*", Some(cond)) => format!("COUNT(IF({cond}, 1, NULL)) AS `{alias}`"),
35 ("UNIQ", col, None) => format!("COUNT(DISTINCT `{col}`) AS `{alias}`"),
36 ("UNIQ", col, Some(cond)) => format!("COUNT(DISTINCT IF({cond}, `{col}`, NULL)) AS `{alias}`"),
37 (f, col, None) => format!("{f}(`{col}`) AS `{alias}`"),
38 (f, col, Some(cond)) => format!("{f}(IF({cond}, `{col}`, NULL)) AS `{alias}`"),
39 }
40 }
41 }).collect();
42 sql.push_str(&select_parts.join(", "));
43
44 sql.push_str(&format!(" FROM `{}`.`{}`", ir.schema, ir.table));
45
46 let where_clause = compile_filter(&ir.filters, &mut bindings);
47 if !where_clause.is_empty() {
48 sql.push_str(" WHERE ");
49 sql.push_str(&where_clause);
50 }
51
52 if !ir.group_by.is_empty() {
53 sql.push_str(" GROUP BY ");
54 let cols: Vec<String> = ir.group_by.iter().map(|c| format!("`{c}`")).collect();
55 sql.push_str(&cols.join(", "));
56 }
57
58 let having_clause = compile_filter(&ir.having, &mut bindings);
59 if !having_clause.is_empty() {
60 sql.push_str(" HAVING ");
61 sql.push_str(&having_clause);
62 }
63
64 if !ir.order_by.is_empty() {
65 sql.push_str(" ORDER BY ");
66 let parts: Vec<String> = ir.order_by.iter().map(|o| {
67 let dir = if o.descending { "DESC" } else { "ASC" };
68 format!("`{}` {dir}", o.column)
69 }).collect();
70 sql.push_str(&parts.join(", "));
71 }
72
73 sql.push_str(&format!(" LIMIT {}", ir.limit));
74 if ir.offset > 0 {
75 sql.push_str(&format!(" OFFSET {}", ir.offset));
76 }
77
78 CompileResult { sql, bindings, alias_remap: vec![] }
79 }
80
81 fn quote_identifier(&self, name: &str) -> String {
82 format!("`{name}`")
83 }
84
85 fn name(&self) -> &str {
86 "StarRocks"
87 }
88}
89
90fn compile_filter(node: &FilterNode, bindings: &mut Vec<SqlValue>) -> String {
91 match node {
92 FilterNode::Empty => String::new(),
93 FilterNode::Condition { column, op, value } => {
94 compile_condition(column, op, value, bindings)
95 }
96 FilterNode::And(children) => {
97 let parts: Vec<String> = children.iter()
98 .map(|c| compile_filter(c, bindings))
99 .filter(|s| !s.is_empty())
100 .collect();
101 match parts.len() {
102 0 => String::new(),
103 1 => parts.into_iter().next().unwrap(),
104 _ => format!("({})", parts.join(" AND ")),
105 }
106 }
107 FilterNode::Or(children) => {
108 let parts: Vec<String> = children.iter()
109 .map(|c| compile_filter(c, bindings))
110 .filter(|s| !s.is_empty())
111 .collect();
112 match parts.len() {
113 0 => String::new(),
114 1 => parts.into_iter().next().unwrap(),
115 _ => format!("({})", parts.join(" OR ")),
116 }
117 }
118 }
119}
120
/// Quotes a filter column for use in WHERE / HAVING.
///
/// A name containing `(` is treated as an already-formed expression (for
/// example an aggregate call) and returned unchanged; anything else is wrapped
/// in backticks.
fn quote_col(column: &str) -> String {
    match column.find('(') {
        Some(_) => column.to_owned(),
        None => ["`", column, "`"].concat(),
    }
}
130
131fn compile_condition(
132 column: &str, op: &CompareOp, value: &SqlValue, bindings: &mut Vec<SqlValue>,
133) -> String {
134 let col = quote_col(column);
135 match op {
136 CompareOp::In | CompareOp::NotIn => {
137 if let SqlValue::String(csv) = value {
138 let items: Vec<&str> = csv.split(',').collect();
139 let placeholders: Vec<&str> = items.iter().map(|_| "?").collect();
140 for item in &items {
141 bindings.push(SqlValue::String(item.trim().to_string()));
142 }
143 format!("{col} {} ({})", op.sql_op(), placeholders.join(", "))
144 } else {
145 bindings.push(value.clone());
146 format!("{col} {} (?)", op.sql_op())
147 }
148 }
149 CompareOp::Includes => {
150 if let SqlValue::String(s) = value {
151 bindings.push(SqlValue::String(format!("%{s}%")));
152 } else {
153 bindings.push(value.clone());
154 }
155 format!("{col} LIKE ?")
156 }
157 CompareOp::IsNull | CompareOp::IsNotNull => {
158 format!("{col} {}", op.sql_op())
159 }
160 _ => {
161 bindings.push(value.clone());
162 format!("{col} {} ?", op.sql_op())
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `ir` with a fresh `StarRocksDialect`.
    fn compile_ir(ir: &QueryIR) -> CompileResult {
        StarRocksDialect::new().compile(ir)
    }

    /// Un-aliased plain column select.
    fn plain_column(name: &'static str) -> SelectExpr {
        SelectExpr::Column { column: name.into(), alias: None }
    }

    /// Leaf filter node.
    fn cond(column: &'static str, op: CompareOp, value: SqlValue) -> FilterNode {
        FilterNode::Condition { column: column.into(), op, value }
    }

    /// Baseline query: `tx_hash` from `dexes_dwd.sol_activities`, no filters,
    /// limit 25. Individual tests override fields via struct-update syntax.
    fn base_ir() -> QueryIR {
        QueryIR {
            cube: "DEXTrades".into(),
            schema: "dexes_dwd".into(),
            table: "sol_activities".into(),
            selects: vec![plain_column("tx_hash")],
            filters: FilterNode::Empty,
            having: FilterNode::Empty,
            group_by: vec![],
            order_by: vec![],
            limit: 25,
            offset: 0,
            limit_by: None,
            use_final: false,
            joins: vec![],
            custom_query_builder: None,
        }
    }

    #[test]
    fn test_simple_select() {
        let ir = QueryIR {
            selects: vec![plain_column("tx_hash"), plain_column("buy_amount")],
            limit: 10,
            ..base_ir()
        };
        let r = compile_ir(&ir);
        assert_eq!(r.sql, "SELECT `tx_hash`, `buy_amount` FROM `dexes_dwd`.`sol_activities` LIMIT 10");
        assert!(r.bindings.is_empty());
    }

    #[test]
    fn test_where_and_order() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                cond("success", CompareOp::Eq, SqlValue::Bool(true)),
            ]),
            order_by: vec![OrderExpr { column: "buy_amount_usd".into(), descending: true }],
            ..base_ir()
        };
        let r = compile_ir(&ir);
        assert!(r.sql.contains("WHERE (`buy_amount_usd` > ? AND `success` = ?)"));
        assert!(r.sql.contains("ORDER BY `buy_amount_usd` DESC"));
        assert_eq!(r.bindings.len(), 2);
    }

    #[test]
    fn test_or_condition() {
        let ir = QueryIR {
            filters: FilterNode::And(vec![
                cond("buy_amount_usd", CompareOp::Gt, SqlValue::Float(1000.0)),
                FilterNode::Or(vec![
                    cond("buy_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                    cond("sell_token_symbol", CompareOp::Eq, SqlValue::String("SOL".into())),
                ]),
            ]),
            ..base_ir()
        };
        let r = compile_ir(&ir);
        assert!(r.sql.contains("(`buy_token_symbol` = ? OR `sell_token_symbol` = ?)"));
        assert_eq!(r.bindings.len(), 3);
    }

    #[test]
    fn test_aggregate_with_having() {
        let ir = QueryIR {
            selects: vec![
                plain_column("buy_token_symbol"),
                SelectExpr::Aggregate {
                    function: "SUM".into(),
                    column: "buy_amount_usd".into(),
                    alias: "__sum".into(),
                    condition: None,
                },
            ],
            // The HAVING column is already an aggregate expression.
            having: cond("SUM(`buy_amount_usd`)", CompareOp::Gt, SqlValue::Float(1000000.0)),
            group_by: vec!["buy_token_symbol".into()],
            ..base_ir()
        };
        let r = compile_ir(&ir);
        assert!(r.sql.contains("GROUP BY `buy_token_symbol`"));
        assert!(r.sql.contains("HAVING SUM(`buy_amount_usd`) > ?"), "HAVING clause should not backtick-wrap aggregate expressions, got: {r}", r = r.sql);
        assert_eq!(r.bindings.len(), 1);
    }

    #[test]
    fn test_offset() {
        let ir = QueryIR {
            limit: 10,
            offset: 20,
            ..base_ir()
        };
        let r = compile_ir(&ir);
        assert!(r.sql.ends_with("LIMIT 10 OFFSET 20"));
    }
}