Skip to main content

polyglot_sql/dialects/
starrocks.rs

1//! StarRocks Dialect
2//!
3//! StarRocks-specific transformations based on sqlglot patterns.
4//! StarRocks is MySQL-compatible with OLAP extensions (similar to Doris).
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, Lateral, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// StarRocks dialect.
///
/// A field-less marker type: it stores no state of its own. Its behavior comes
/// from the `DialectImpl` implementation and the private `transform_*` helpers
/// defined on it in this file.
pub struct StarRocksDialect;
14
15impl DialectImpl for StarRocksDialect {
    /// Identifies this implementation as the StarRocks dialect.
    fn dialect_type(&self) -> DialectType {
        DialectType::StarRocks
    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        use crate::tokens::TokenType;
22        let mut config = TokenizerConfig::default();
23        // StarRocks uses backticks for identifiers (MySQL-style)
24        config.identifiers.insert('`', '`');
25        // Remove double quotes from identifiers (MySQL-style)
26        config.identifiers.remove(&'"');
27        config.quotes.insert("\"".to_string(), "\"".to_string());
28        config.nested_comments = false;
29        // LARGEINT maps to INT128
30        config
31            .keywords
32            .insert("LARGEINT".to_string(), TokenType::Int128);
33        config
34    }
35
36    fn generator_config(&self) -> GeneratorConfig {
37        use crate::generator::IdentifierQuoteStyle;
38        GeneratorConfig {
39            identifier_quote: '`',
40            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
41            dialect: Some(DialectType::StarRocks),
42            // StarRocks: INSERT OVERWRITE (without TABLE keyword)
43            insert_overwrite: " OVERWRITE",
44            // StarRocks: PROPERTIES prefix for WITH properties
45            with_properties_prefix: "PROPERTIES",
46            // StarRocks uses MySQL-style settings
47            null_ordering_supported: false,
48            limit_only_literals: true,
49            semi_anti_join_with_side: false,
50            supports_table_alias_columns: false,
51            values_as_table: false,
52            tablesample_requires_parens: false,
53            tablesample_with_method: false,
54            aggregate_filter_supported: false,
55            try_supported: false,
56            supports_convert_timezone: false,
57            supports_uescape: false,
58            supports_between_flags: false,
59            query_hints: false,
60            parameter_token: "?",
61            supports_window_exclude: false,
62            supports_exploding_projections: false,
63            // StarRocks: COMMENT 'value' (naked property, no = sign)
64            schema_comment_with_eq: false,
65            ..Default::default()
66        }
67    }
68
69    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
70        match expr {
71            // IFNULL is native in StarRocks (MySQL-style)
72            Expression::IfNull(f) => Ok(Expression::IfNull(f)),
73
74            // NVL -> IFNULL in StarRocks
75            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
76
77            // TryCast -> not directly supported, use CAST
78            Expression::TryCast(c) => Ok(Expression::Cast(c)),
79
80            // SafeCast -> CAST in StarRocks
81            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
82
83            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
84            Expression::CountIf(f) => {
85                let case_expr = Expression::Case(Box::new(Case {
86                    operand: None,
87                    whens: vec![(f.this.clone(), Expression::number(1))],
88                    else_: Some(Expression::number(0)),
89                    comments: Vec::new(),
90                    inferred_type: None,
91                }));
92                Ok(Expression::Sum(Box::new(AggFunc {
93                    ignore_nulls: None,
94                    having_max: None,
95                    this: case_expr,
96                    distinct: f.distinct,
97                    filter: f.filter,
98                    order_by: Vec::new(),
99                    name: None,
100                    limit: None,
101                    inferred_type: None,
102                })))
103            }
104
105            // RAND is native in StarRocks
106            Expression::Rand(r) => Ok(Expression::Rand(r)),
107
108            // JSON arrow syntax: preserve -> for StarRocks (arrow_json_extract_sql)
109            Expression::JsonExtract(mut f) => {
110                // Set arrow_syntax to true to preserve -> operator
111                f.arrow_syntax = true;
112                Ok(Expression::JsonExtract(f))
113            }
114
115            Expression::JsonExtractScalar(mut f) => {
116                // Set arrow_syntax to true to preserve ->> operator
117                f.arrow_syntax = true;
118                Ok(Expression::JsonExtractScalar(f))
119            }
120
121            // Generic function transformations
122            Expression::Function(f) => self.transform_function(*f),
123
124            // Generic aggregate function transformations
125            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
126
127            // Cast transformations
128            Expression::Cast(c) => self.transform_cast(*c),
129
130            // Handle LATERAL UNNEST - StarRocks requires column alias "unnest" by default
131            Expression::Lateral(mut l) => {
132                self.transform_lateral(&mut l)?;
133                Ok(Expression::Lateral(l))
134            }
135
136            // Pass through everything else
137            _ => Ok(expr),
138        }
139    }
140}
141
142impl StarRocksDialect {
143    fn transform_function(&self, f: Function) -> Result<Expression> {
144        let name_upper = f.name.to_uppercase();
145        match name_upper.as_str() {
146            // NVL -> IFNULL
147            "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
148                "IFNULL".to_string(),
149                f.args,
150            )))),
151
152            // ISNULL -> IFNULL
153            "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
154                "IFNULL".to_string(),
155                f.args,
156            )))),
157
158            // COALESCE is native in StarRocks
159            "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
160                original_name: None,
161                expressions: f.args,
162                inferred_type: None,
163            }))),
164
165            // NOW is native in StarRocks
166            "NOW" => Ok(Expression::CurrentTimestamp(
167                crate::expressions::CurrentTimestamp {
168                    precision: None,
169                    sysdate: false,
170                },
171            )),
172
173            // GETDATE -> NOW in StarRocks
174            "GETDATE" => Ok(Expression::CurrentTimestamp(
175                crate::expressions::CurrentTimestamp {
176                    precision: None,
177                    sysdate: false,
178                },
179            )),
180
181            // GROUP_CONCAT is native in StarRocks
182            "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
183
184            // STRING_AGG -> GROUP_CONCAT
185            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
186                Function::new("GROUP_CONCAT".to_string(), f.args),
187            ))),
188
189            // LISTAGG -> GROUP_CONCAT
190            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
191                "GROUP_CONCAT".to_string(),
192                f.args,
193            )))),
194
195            // SUBSTR is native in StarRocks
196            "SUBSTR" => Ok(Expression::Function(Box::new(f))),
197
198            // SUBSTRING is native in StarRocks
199            "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
200
201            // LENGTH is native in StarRocks
202            "LENGTH" => Ok(Expression::Function(Box::new(f))),
203
204            // LEN -> LENGTH
205            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
206                "LENGTH".to_string(),
207                f.args,
208            )))),
209
210            // CHARINDEX -> INSTR in StarRocks (with swapped args)
211            "CHARINDEX" if f.args.len() >= 2 => {
212                let mut args = f.args;
213                let substring = args.remove(0);
214                let string = args.remove(0);
215                Ok(Expression::Function(Box::new(Function::new(
216                    "INSTR".to_string(),
217                    vec![string, substring],
218                ))))
219            }
220
221            // STRPOS -> INSTR
222            "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
223                "INSTR".to_string(),
224                f.args,
225            )))),
226
227            // DATE_TRUNC is native in StarRocks
228            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
229
230            // ARRAY_AGG is native in StarRocks
231            "ARRAY_AGG" => Ok(Expression::Function(Box::new(f))),
232
233            // COLLECT_LIST -> ARRAY_AGG
234            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
235                Function::new("ARRAY_AGG".to_string(), f.args),
236            ))),
237
238            // ARRAY_JOIN is native in StarRocks
239            "ARRAY_JOIN" => Ok(Expression::Function(Box::new(f))),
240
241            // ARRAY_FLATTEN is native in StarRocks
242            "ARRAY_FLATTEN" => Ok(Expression::Function(Box::new(f))),
243
244            // FLATTEN -> ARRAY_FLATTEN
245            "FLATTEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
246                "ARRAY_FLATTEN".to_string(),
247                f.args,
248            )))),
249
250            // TO_DATE is native in StarRocks
251            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
252
253            // DATE_FORMAT is native in StarRocks
254            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
255
256            // strftime -> DATE_FORMAT
257            "STRFTIME" if f.args.len() >= 2 => {
258                let mut args = f.args;
259                let format = args.remove(0);
260                let date = args.remove(0);
261                Ok(Expression::Function(Box::new(Function::new(
262                    "DATE_FORMAT".to_string(),
263                    vec![date, format],
264                ))))
265            }
266
267            // TO_CHAR -> DATE_FORMAT
268            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
269                "DATE_FORMAT".to_string(),
270                f.args,
271            )))),
272
273            // JSON_EXTRACT -> arrow operator in StarRocks
274            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
275
276            // GET_JSON_OBJECT -> JSON_EXTRACT
277            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
278                Function::new("JSON_EXTRACT".to_string(), f.args),
279            ))),
280
281            // REGEXP is native in StarRocks
282            "REGEXP" => Ok(Expression::Function(Box::new(f))),
283
284            // RLIKE is native in StarRocks
285            "RLIKE" => Ok(Expression::Function(Box::new(f))),
286
287            // REGEXP_LIKE -> REGEXP
288            "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
289                Function::new("REGEXP".to_string(), f.args),
290            ))),
291
292            // ARRAY_INTERSECTION -> ARRAY_INTERSECT
293            "ARRAY_INTERSECTION" => Ok(Expression::Function(Box::new(Function::new(
294                "ARRAY_INTERSECT".to_string(),
295                f.args,
296            )))),
297
298            // ST_MAKEPOINT -> ST_POINT
299            "ST_MAKEPOINT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
300                Function::new("ST_POINT".to_string(), f.args),
301            ))),
302
303            // ST_DISTANCE(a, b) -> ST_DISTANCE_SPHERE(ST_X(a), ST_Y(a), ST_X(b), ST_Y(b))
304            "ST_DISTANCE" if f.args.len() == 2 => {
305                let a = f.args[0].clone();
306                let b = f.args[1].clone();
307                Ok(Expression::Function(Box::new(Function::new(
308                    "ST_DISTANCE_SPHERE".to_string(),
309                    vec![
310                        Expression::Function(Box::new(Function::new(
311                            "ST_X".to_string(),
312                            vec![a.clone()],
313                        ))),
314                        Expression::Function(Box::new(Function::new("ST_Y".to_string(), vec![a]))),
315                        Expression::Function(Box::new(Function::new(
316                            "ST_X".to_string(),
317                            vec![b.clone()],
318                        ))),
319                        Expression::Function(Box::new(Function::new("ST_Y".to_string(), vec![b]))),
320                    ],
321                ))))
322            }
323
324            // Pass through everything else
325            _ => Ok(Expression::Function(Box::new(f))),
326        }
327    }
328
329    fn transform_aggregate_function(
330        &self,
331        f: Box<crate::expressions::AggregateFunction>,
332    ) -> Result<Expression> {
333        let name_upper = f.name.to_uppercase();
334        match name_upper.as_str() {
335            // COUNT_IF -> SUM(CASE WHEN...)
336            "COUNT_IF" if !f.args.is_empty() => {
337                let condition = f.args.into_iter().next().unwrap();
338                let case_expr = Expression::Case(Box::new(Case {
339                    operand: None,
340                    whens: vec![(condition, Expression::number(1))],
341                    else_: Some(Expression::number(0)),
342                    comments: Vec::new(),
343                    inferred_type: None,
344                }));
345                Ok(Expression::Sum(Box::new(AggFunc {
346                    ignore_nulls: None,
347                    having_max: None,
348                    this: case_expr,
349                    distinct: f.distinct,
350                    filter: f.filter,
351                    order_by: Vec::new(),
352                    name: None,
353                    limit: None,
354                    inferred_type: None,
355                })))
356            }
357
358            // APPROX_COUNT_DISTINCT is native in StarRocks
359            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
360
361            // Pass through everything else
362            _ => Ok(Expression::AggregateFunction(f)),
363        }
364    }
365
366    fn transform_cast(&self, c: Cast) -> Result<Expression> {
367        // StarRocks: CAST(x AS TIMESTAMP/TIMESTAMPTZ) -> TIMESTAMP(x) function
368        // Similar to MySQL behavior
369        match &c.to {
370            crate::expressions::DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(
371                Function::new("TIMESTAMP".to_string(), vec![c.this]),
372            ))),
373            crate::expressions::DataType::Custom { name }
374                if name.to_uppercase() == "TIMESTAMPTZ"
375                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
376            {
377                Ok(Expression::Function(Box::new(Function::new(
378                    "TIMESTAMP".to_string(),
379                    vec![c.this],
380                ))))
381            }
382            // StarRocks type mappings are handled in the generator
383            _ => Ok(Expression::Cast(Box::new(c))),
384        }
385    }
386
387    /// Transform LATERAL UNNEST for StarRocks
388    /// StarRocks requires UNNEST to have a default column alias of "unnest" if not specified.
389    /// Python reference: starrocks.py _parse_unnest
390    fn transform_lateral(&self, l: &mut Box<Lateral>) -> Result<()> {
391        // Check if the lateral expression contains UNNEST
392        if let Expression::Unnest(_) = &*l.this {
393            // If there's a table alias but no column aliases, add "unnest" as default column
394            if l.alias.is_some() && l.column_aliases.is_empty() {
395                l.column_aliases.push("unnest".to_string());
396            }
397            // If there's no alias at all, add both table alias "unnest" and column alias "unnest"
398            else if l.alias.is_none() {
399                l.alias = Some("unnest".to_string());
400                l.column_aliases.push("unnest".to_string());
401            }
402        }
403        Ok(())
404    }
405}