Skip to main content

polyglot_sql/dialects/
starrocks.rs

1//! StarRocks Dialect
2//!
3//! StarRocks-specific transformations based on sqlglot patterns.
4//! StarRocks is MySQL-compatible with OLAP extensions (similar to Doris).
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, Lateral, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// Marker type for the StarRocks SQL dialect.
///
/// The struct carries no data; all dialect behaviour is provided by the `impl` blocks
/// attached to it. The derives make the marker printable, copyable and
/// default-constructible without affecting existing callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct StarRocksDialect;
14
15impl DialectImpl for StarRocksDialect {
16    fn dialect_type(&self) -> DialectType {
17        DialectType::StarRocks
18    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        use crate::tokens::TokenType;
22        let mut config = TokenizerConfig::default();
23        // StarRocks uses backticks for identifiers (MySQL-style)
24        config.identifiers.insert('`', '`');
25        // Remove double quotes from identifiers (MySQL-style)
26        config.identifiers.remove(&'"');
27        config.quotes.insert("\"".to_string(), "\"".to_string());
28        config.nested_comments = false;
29        // LARGEINT maps to INT128
30        config
31            .keywords
32            .insert("LARGEINT".to_string(), TokenType::Int128);
33        config
34    }
35
36    fn generator_config(&self) -> GeneratorConfig {
37        use crate::generator::IdentifierQuoteStyle;
38        GeneratorConfig {
39            identifier_quote: '`',
40            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
41            dialect: Some(DialectType::StarRocks),
42            // StarRocks: INSERT OVERWRITE (without TABLE keyword)
43            insert_overwrite: " OVERWRITE",
44            // StarRocks: PROPERTIES prefix for WITH properties
45            with_properties_prefix: "PROPERTIES",
46            // StarRocks uses MySQL-style settings
47            null_ordering_supported: false,
48            limit_only_literals: true,
49            semi_anti_join_with_side: false,
50            supports_table_alias_columns: false,
51            values_as_table: false,
52            tablesample_requires_parens: false,
53            tablesample_with_method: false,
54            aggregate_filter_supported: false,
55            try_supported: false,
56            supports_convert_timezone: false,
57            supports_uescape: false,
58            supports_between_flags: false,
59            query_hints: false,
60            parameter_token: "?",
61            supports_window_exclude: false,
62            supports_exploding_projections: false,
63            // StarRocks: COMMENT 'value' (naked property, no = sign)
64            schema_comment_with_eq: false,
65            ..Default::default()
66        }
67    }
68
69    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
70        match expr {
71            // IFNULL is native in StarRocks (MySQL-style)
72            Expression::IfNull(f) => Ok(Expression::IfNull(f)),
73
74            // NVL -> IFNULL in StarRocks
75            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
76
77            // TryCast -> not directly supported, use CAST
78            Expression::TryCast(c) => Ok(Expression::Cast(c)),
79
80            // SafeCast -> CAST in StarRocks
81            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
82
83            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
84            Expression::CountIf(f) => {
85                let case_expr = Expression::Case(Box::new(Case {
86                    operand: None,
87                    whens: vec![(f.this.clone(), Expression::number(1))],
88                    else_: Some(Expression::number(0)),
89                    comments: Vec::new(),
90                }));
91                Ok(Expression::Sum(Box::new(AggFunc {
92                    ignore_nulls: None,
93                    having_max: None,
94                    this: case_expr,
95                    distinct: f.distinct,
96                    filter: f.filter,
97                    order_by: Vec::new(),
98                    name: None,
99                    limit: None,
100                })))
101            }
102
103            // RAND is native in StarRocks
104            Expression::Rand(r) => Ok(Expression::Rand(r)),
105
106            // JSON arrow syntax: preserve -> for StarRocks (arrow_json_extract_sql)
107            Expression::JsonExtract(mut f) => {
108                // Set arrow_syntax to true to preserve -> operator
109                f.arrow_syntax = true;
110                Ok(Expression::JsonExtract(f))
111            }
112
113            Expression::JsonExtractScalar(mut f) => {
114                // Set arrow_syntax to true to preserve ->> operator
115                f.arrow_syntax = true;
116                Ok(Expression::JsonExtractScalar(f))
117            }
118
119            // Generic function transformations
120            Expression::Function(f) => self.transform_function(*f),
121
122            // Generic aggregate function transformations
123            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
124
125            // Cast transformations
126            Expression::Cast(c) => self.transform_cast(*c),
127
128            // Handle LATERAL UNNEST - StarRocks requires column alias "unnest" by default
129            Expression::Lateral(mut l) => {
130                self.transform_lateral(&mut l)?;
131                Ok(Expression::Lateral(l))
132            }
133
134            // Pass through everything else
135            _ => Ok(expr),
136        }
137    }
138}
139
140impl StarRocksDialect {
141    fn transform_function(&self, f: Function) -> Result<Expression> {
142        let name_upper = f.name.to_uppercase();
143        match name_upper.as_str() {
144            // NVL -> IFNULL
145            "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
146                "IFNULL".to_string(),
147                f.args,
148            )))),
149
150            // ISNULL -> IFNULL
151            "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
152                "IFNULL".to_string(),
153                f.args,
154            )))),
155
156            // COALESCE is native in StarRocks
157            "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
158                original_name: None,
159                expressions: f.args,
160            }))),
161
162            // NOW is native in StarRocks
163            "NOW" => Ok(Expression::CurrentTimestamp(
164                crate::expressions::CurrentTimestamp {
165                    precision: None,
166                    sysdate: false,
167                },
168            )),
169
170            // GETDATE -> NOW in StarRocks
171            "GETDATE" => Ok(Expression::CurrentTimestamp(
172                crate::expressions::CurrentTimestamp {
173                    precision: None,
174                    sysdate: false,
175                },
176            )),
177
178            // GROUP_CONCAT is native in StarRocks
179            "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
180
181            // STRING_AGG -> GROUP_CONCAT
182            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
183                Function::new("GROUP_CONCAT".to_string(), f.args),
184            ))),
185
186            // LISTAGG -> GROUP_CONCAT
187            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
188                "GROUP_CONCAT".to_string(),
189                f.args,
190            )))),
191
192            // SUBSTR is native in StarRocks
193            "SUBSTR" => Ok(Expression::Function(Box::new(f))),
194
195            // SUBSTRING is native in StarRocks
196            "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
197
198            // LENGTH is native in StarRocks
199            "LENGTH" => Ok(Expression::Function(Box::new(f))),
200
201            // LEN -> LENGTH
202            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
203                "LENGTH".to_string(),
204                f.args,
205            )))),
206
207            // CHARINDEX -> INSTR in StarRocks (with swapped args)
208            "CHARINDEX" if f.args.len() >= 2 => {
209                let mut args = f.args;
210                let substring = args.remove(0);
211                let string = args.remove(0);
212                Ok(Expression::Function(Box::new(Function::new(
213                    "INSTR".to_string(),
214                    vec![string, substring],
215                ))))
216            }
217
218            // STRPOS -> INSTR
219            "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
220                "INSTR".to_string(),
221                f.args,
222            )))),
223
224            // DATE_TRUNC is native in StarRocks
225            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
226
227            // ARRAY_AGG is native in StarRocks
228            "ARRAY_AGG" => Ok(Expression::Function(Box::new(f))),
229
230            // COLLECT_LIST -> ARRAY_AGG
231            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
232                Function::new("ARRAY_AGG".to_string(), f.args),
233            ))),
234
235            // ARRAY_JOIN is native in StarRocks
236            "ARRAY_JOIN" => Ok(Expression::Function(Box::new(f))),
237
238            // ARRAY_FLATTEN is native in StarRocks
239            "ARRAY_FLATTEN" => Ok(Expression::Function(Box::new(f))),
240
241            // FLATTEN -> ARRAY_FLATTEN
242            "FLATTEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
243                "ARRAY_FLATTEN".to_string(),
244                f.args,
245            )))),
246
247            // TO_DATE is native in StarRocks
248            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
249
250            // DATE_FORMAT is native in StarRocks
251            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
252
253            // strftime -> DATE_FORMAT
254            "STRFTIME" if f.args.len() >= 2 => {
255                let mut args = f.args;
256                let format = args.remove(0);
257                let date = args.remove(0);
258                Ok(Expression::Function(Box::new(Function::new(
259                    "DATE_FORMAT".to_string(),
260                    vec![date, format],
261                ))))
262            }
263
264            // TO_CHAR -> DATE_FORMAT
265            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
266                "DATE_FORMAT".to_string(),
267                f.args,
268            )))),
269
270            // JSON_EXTRACT -> arrow operator in StarRocks
271            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
272
273            // GET_JSON_OBJECT -> JSON_EXTRACT
274            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
275                Function::new("JSON_EXTRACT".to_string(), f.args),
276            ))),
277
278            // REGEXP is native in StarRocks
279            "REGEXP" => Ok(Expression::Function(Box::new(f))),
280
281            // RLIKE is native in StarRocks
282            "RLIKE" => Ok(Expression::Function(Box::new(f))),
283
284            // REGEXP_LIKE -> REGEXP
285            "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
286                Function::new("REGEXP".to_string(), f.args),
287            ))),
288
289            // ARRAY_INTERSECTION -> ARRAY_INTERSECT
290            "ARRAY_INTERSECTION" => Ok(Expression::Function(Box::new(Function::new(
291                "ARRAY_INTERSECT".to_string(),
292                f.args,
293            )))),
294
295            // ST_MAKEPOINT -> ST_POINT
296            "ST_MAKEPOINT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
297                Function::new("ST_POINT".to_string(), f.args),
298            ))),
299
300            // ST_DISTANCE(a, b) -> ST_DISTANCE_SPHERE(ST_X(a), ST_Y(a), ST_X(b), ST_Y(b))
301            "ST_DISTANCE" if f.args.len() == 2 => {
302                let a = f.args[0].clone();
303                let b = f.args[1].clone();
304                Ok(Expression::Function(Box::new(Function::new(
305                    "ST_DISTANCE_SPHERE".to_string(),
306                    vec![
307                        Expression::Function(Box::new(Function::new(
308                            "ST_X".to_string(),
309                            vec![a.clone()],
310                        ))),
311                        Expression::Function(Box::new(Function::new("ST_Y".to_string(), vec![a]))),
312                        Expression::Function(Box::new(Function::new(
313                            "ST_X".to_string(),
314                            vec![b.clone()],
315                        ))),
316                        Expression::Function(Box::new(Function::new("ST_Y".to_string(), vec![b]))),
317                    ],
318                ))))
319            }
320
321            // Pass through everything else
322            _ => Ok(Expression::Function(Box::new(f))),
323        }
324    }
325
326    fn transform_aggregate_function(
327        &self,
328        f: Box<crate::expressions::AggregateFunction>,
329    ) -> Result<Expression> {
330        let name_upper = f.name.to_uppercase();
331        match name_upper.as_str() {
332            // COUNT_IF -> SUM(CASE WHEN...)
333            "COUNT_IF" if !f.args.is_empty() => {
334                let condition = f.args.into_iter().next().unwrap();
335                let case_expr = Expression::Case(Box::new(Case {
336                    operand: None,
337                    whens: vec![(condition, Expression::number(1))],
338                    else_: Some(Expression::number(0)),
339                    comments: Vec::new(),
340                }));
341                Ok(Expression::Sum(Box::new(AggFunc {
342                    ignore_nulls: None,
343                    having_max: None,
344                    this: case_expr,
345                    distinct: f.distinct,
346                    filter: f.filter,
347                    order_by: Vec::new(),
348                    name: None,
349                    limit: None,
350                })))
351            }
352
353            // APPROX_COUNT_DISTINCT is native in StarRocks
354            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
355
356            // Pass through everything else
357            _ => Ok(Expression::AggregateFunction(f)),
358        }
359    }
360
361    fn transform_cast(&self, c: Cast) -> Result<Expression> {
362        // StarRocks: CAST(x AS TIMESTAMP/TIMESTAMPTZ) -> TIMESTAMP(x) function
363        // Similar to MySQL behavior
364        match &c.to {
365            crate::expressions::DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(
366                Function::new("TIMESTAMP".to_string(), vec![c.this]),
367            ))),
368            crate::expressions::DataType::Custom { name }
369                if name.to_uppercase() == "TIMESTAMPTZ"
370                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
371            {
372                Ok(Expression::Function(Box::new(Function::new(
373                    "TIMESTAMP".to_string(),
374                    vec![c.this],
375                ))))
376            }
377            // StarRocks type mappings are handled in the generator
378            _ => Ok(Expression::Cast(Box::new(c))),
379        }
380    }
381
382    /// Transform LATERAL UNNEST for StarRocks
383    /// StarRocks requires UNNEST to have a default column alias of "unnest" if not specified.
384    /// Python reference: starrocks.py _parse_unnest
385    fn transform_lateral(&self, l: &mut Box<Lateral>) -> Result<()> {
386        // Check if the lateral expression contains UNNEST
387        if let Expression::Unnest(_) = &*l.this {
388            // If there's a table alias but no column aliases, add "unnest" as default column
389            if l.alias.is_some() && l.column_aliases.is_empty() {
390                l.column_aliases.push("unnest".to_string());
391            }
392            // If there's no alias at all, add both table alias "unnest" and column alias "unnest"
393            else if l.alias.is_none() {
394                l.alias = Some("unnest".to_string());
395                l.column_aliases.push("unnest".to_string());
396            }
397        }
398        Ok(())
399    }
400}