Skip to main content

polyglot_sql/dialects/
trino.rs

1//! Trino Dialect
2//!
3//! Trino-specific transformations based on sqlglot patterns.
4//! Trino is largely compatible with Presto but has some differences.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10    IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12use crate::generator::GeneratorConfig;
13use crate::tokens::TokenizerConfig;
14
15/// Trino dialect
16pub struct TrinoDialect;
17
18impl DialectImpl for TrinoDialect {
19    fn dialect_type(&self) -> DialectType {
20        DialectType::Trino
21    }
22
23    fn tokenizer_config(&self) -> TokenizerConfig {
24        let mut config = TokenizerConfig::default();
25        // Trino uses double quotes for identifiers
26        config.identifiers.insert('"', '"');
27        // Trino does NOT support nested comments
28        config.nested_comments = false;
29        // Trino does NOT support QUALIFY - it's a valid identifier
30        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
31        config.keywords.remove("QUALIFY");
32        config
33    }
34
35    fn generator_config(&self) -> GeneratorConfig {
36        use crate::generator::IdentifierQuoteStyle;
37        GeneratorConfig {
38            identifier_quote: '"',
39            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
40            dialect: Some(DialectType::Trino),
41            limit_only_literals: true,
42            tz_to_with_time_zone: true,
43            ..Default::default()
44        }
45    }
46
47    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
48        match expr {
49            // IFNULL -> COALESCE in Trino
50            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
51                original_name: None,
52                expressions: vec![f.this, f.expression],
53                inferred_type: None,
54            }))),
55
56            // NVL -> COALESCE in Trino
57            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
58                original_name: None,
59                expressions: vec![f.this, f.expression],
60                inferred_type: None,
61            }))),
62
63            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
64            Expression::Coalesce(mut f) => {
65                f.original_name = None;
66                Ok(Expression::Coalesce(f))
67            }
68
69            // TryCast stays as TryCast (Trino supports TRY_CAST)
70            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
71
72            // SafeCast -> TRY_CAST in Trino
73            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
74
75            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
76            Expression::ILike(op) => {
77                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
78                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
79                Ok(Expression::Like(Box::new(LikeOp {
80                    left: lower_left,
81                    right: lower_right,
82                    escape: op.escape,
83                    quantifier: op.quantifier.clone(),
84                    inferred_type: None,
85                })))
86            }
87
88            // CountIf is native in Trino (keep as-is)
89            Expression::CountIf(f) => Ok(Expression::CountIf(f)),
90
91            // EXPLODE -> UNNEST in Trino
92            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
93                crate::expressions::UnnestFunc {
94                    this: f.this,
95                    expressions: Vec::new(),
96                    with_ordinality: false,
97                    alias: None,
98                    offset_alias: None,
99                },
100            ))),
101
102            // ExplodeOuter -> UNNEST in Trino
103            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
104                crate::expressions::UnnestFunc {
105                    this: f.this,
106                    expressions: Vec::new(),
107                    with_ordinality: false,
108                    alias: None,
109                    offset_alias: None,
110                },
111            ))),
112
113            // Generic function transformations
114            Expression::Function(f) => self.transform_function(*f),
115
116            // Generic aggregate function transformations
117            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
118
119            // Cast transformations
120            Expression::Cast(c) => self.transform_cast(*c),
121
122            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
123            // Trino requires SQL standard syntax for TRIM with characters
124            Expression::Trim(mut f) => {
125                if !f.sql_standard_syntax && f.characters.is_some() {
126                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
127                    f.sql_standard_syntax = true;
128                }
129                Ok(Expression::Trim(f))
130            }
131
132            // LISTAGG: Add default separator ',' if none is specified (Trino style)
133            Expression::ListAgg(mut f) => {
134                if f.separator.is_none() {
135                    f.separator = Some(Expression::Literal(Literal::String(",".to_string())));
136                }
137                Ok(Expression::ListAgg(f))
138            }
139
140            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
141            Expression::Interval(mut interval) => {
142                if interval.unit.is_none() {
143                    if let Some(Expression::Literal(Literal::String(ref s))) = interval.this {
144                        if let Some((value, unit)) = Self::parse_compound_interval(s) {
145                            interval.this = Some(Expression::Literal(Literal::String(value)));
146                            interval.unit = Some(unit);
147                        }
148                    }
149                }
150                Ok(Expression::Interval(interval))
151            }
152
153            // Pass through everything else
154            _ => Ok(expr),
155        }
156    }
157}
158
159impl TrinoDialect {
160    /// Parse a compound interval string like "1 day" into (value, unit_spec).
161    /// Returns None if the string doesn't match a known pattern.
162    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
163        let s = s.trim();
164        let parts: Vec<&str> = s.split_whitespace().collect();
165        if parts.len() != 2 {
166            return None;
167        }
168        let value = parts[0].to_string();
169        let unit = match parts[1].to_uppercase().as_str() {
170            "YEAR" | "YEARS" => IntervalUnit::Year,
171            "MONTH" | "MONTHS" => IntervalUnit::Month,
172            "DAY" | "DAYS" => IntervalUnit::Day,
173            "HOUR" | "HOURS" => IntervalUnit::Hour,
174            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
175            "SECOND" | "SECONDS" => IntervalUnit::Second,
176            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
177            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
178            _ => return None,
179        };
180        Some((
181            value,
182            IntervalUnitSpec::Simple {
183                unit,
184                use_plural: false,
185            },
186        ))
187    }
188
189    fn transform_function(&self, f: Function) -> Result<Expression> {
190        let name_upper = f.name.to_uppercase();
191        match name_upper.as_str() {
192            // IFNULL -> COALESCE
193            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
194                original_name: None,
195                expressions: f.args,
196                inferred_type: None,
197            }))),
198
199            // NVL -> COALESCE
200            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
201                original_name: None,
202                expressions: f.args,
203                inferred_type: None,
204            }))),
205
206            // ISNULL -> COALESCE
207            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
208                original_name: None,
209                expressions: f.args,
210                inferred_type: None,
211            }))),
212
213            // GETDATE -> CURRENT_TIMESTAMP
214            "GETDATE" => Ok(Expression::CurrentTimestamp(
215                crate::expressions::CurrentTimestamp {
216                    precision: None,
217                    sysdate: false,
218                },
219            )),
220
221            // NOW -> CURRENT_TIMESTAMP
222            "NOW" => Ok(Expression::CurrentTimestamp(
223                crate::expressions::CurrentTimestamp {
224                    precision: None,
225                    sysdate: false,
226                },
227            )),
228
229            // RAND -> RANDOM in Trino
230            "RAND" => Ok(Expression::Function(Box::new(Function::new(
231                "RANDOM".to_string(),
232                vec![],
233            )))),
234
235            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
236            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
237                Function::new("LISTAGG".to_string(), f.args),
238            ))),
239
240            // STRING_AGG -> LISTAGG in Trino
241            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
242                Function::new("LISTAGG".to_string(), f.args),
243            ))),
244
245            // LISTAGG is native in Trino
246            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
247
248            // SUBSTR -> SUBSTRING
249            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
250                "SUBSTRING".to_string(),
251                f.args,
252            )))),
253
254            // LEN -> LENGTH
255            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
256                f.args.into_iter().next().unwrap(),
257            )))),
258
259            // CHARINDEX -> STRPOS in Trino (with swapped args)
260            "CHARINDEX" if f.args.len() >= 2 => {
261                let mut args = f.args;
262                let substring = args.remove(0);
263                let string = args.remove(0);
264                Ok(Expression::Function(Box::new(Function::new(
265                    "STRPOS".to_string(),
266                    vec![string, substring],
267                ))))
268            }
269
270            // INSTR -> STRPOS
271            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
272                "STRPOS".to_string(),
273                f.args,
274            )))),
275
276            // LOCATE -> STRPOS in Trino (with swapped args)
277            "LOCATE" if f.args.len() >= 2 => {
278                let mut args = f.args;
279                let substring = args.remove(0);
280                let string = args.remove(0);
281                Ok(Expression::Function(Box::new(Function::new(
282                    "STRPOS".to_string(),
283                    vec![string, substring],
284                ))))
285            }
286
287            // ARRAY_LENGTH -> CARDINALITY in Trino
288            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
289                Function::new("CARDINALITY".to_string(), f.args),
290            ))),
291
292            // SIZE -> CARDINALITY in Trino
293            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
294                "CARDINALITY".to_string(),
295                f.args,
296            )))),
297
298            // ARRAY_CONTAINS -> CONTAINS in Trino
299            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
300                Function::new("CONTAINS".to_string(), f.args),
301            ))),
302
303            // TO_DATE -> CAST to DATE or DATE_PARSE
304            "TO_DATE" if !f.args.is_empty() => {
305                if f.args.len() == 1 {
306                    Ok(Expression::Cast(Box::new(Cast {
307                        this: f.args.into_iter().next().unwrap(),
308                        to: DataType::Date,
309                        trailing_comments: Vec::new(),
310                        double_colon_syntax: false,
311                        format: None,
312                        default: None,
313                        inferred_type: None,
314                    })))
315                } else {
316                    Ok(Expression::Function(Box::new(Function::new(
317                        "DATE_PARSE".to_string(),
318                        f.args,
319                    ))))
320                }
321            }
322
323            // TO_TIMESTAMP -> CAST or DATE_PARSE
324            "TO_TIMESTAMP" if !f.args.is_empty() => {
325                if f.args.len() == 1 {
326                    Ok(Expression::Cast(Box::new(Cast {
327                        this: f.args.into_iter().next().unwrap(),
328                        to: DataType::Timestamp {
329                            precision: None,
330                            timezone: false,
331                        },
332                        trailing_comments: Vec::new(),
333                        double_colon_syntax: false,
334                        format: None,
335                        default: None,
336                        inferred_type: None,
337                    })))
338                } else {
339                    Ok(Expression::Function(Box::new(Function::new(
340                        "DATE_PARSE".to_string(),
341                        f.args,
342                    ))))
343                }
344            }
345
346            // strftime -> DATE_FORMAT in Trino
347            "STRFTIME" if f.args.len() >= 2 => {
348                let mut args = f.args;
349                let format = args.remove(0);
350                let date = args.remove(0);
351                Ok(Expression::Function(Box::new(Function::new(
352                    "DATE_FORMAT".to_string(),
353                    vec![date, format],
354                ))))
355            }
356
357            // TO_CHAR -> DATE_FORMAT in Trino
358            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
359                "DATE_FORMAT".to_string(),
360                f.args,
361            )))),
362
363            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
364            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
365                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
366            ))),
367
368            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
369            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
370                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
371            ))),
372
373            // COLLECT_LIST -> ARRAY_AGG
374            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
375                Function::new("ARRAY_AGG".to_string(), f.args),
376            ))),
377
378            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
379            "COLLECT_SET" if !f.args.is_empty() => {
380                let array_agg =
381                    Expression::Function(Box::new(Function::new("ARRAY_AGG".to_string(), f.args)));
382                Ok(Expression::Function(Box::new(Function::new(
383                    "ARRAY_DISTINCT".to_string(),
384                    vec![array_agg],
385                ))))
386            }
387
388            // RLIKE -> REGEXP_LIKE in Trino
389            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
390                "REGEXP_LIKE".to_string(),
391                f.args,
392            )))),
393
394            // REGEXP -> REGEXP_LIKE in Trino
395            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
396                "REGEXP_LIKE".to_string(),
397                f.args,
398            )))),
399
400            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
401            // For simplicity, we'll use a different approach
402            "ARRAY_SUM" if f.args.len() == 1 => {
403                // This is a complex transformation in Presto/Trino
404                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
405                // For now, pass through and let user handle it
406                Ok(Expression::Function(Box::new(f)))
407            }
408
409            // Pass through everything else
410            _ => Ok(Expression::Function(Box::new(f))),
411        }
412    }
413
414    fn transform_aggregate_function(
415        &self,
416        f: Box<crate::expressions::AggregateFunction>,
417    ) -> Result<Expression> {
418        let name_upper = f.name.to_uppercase();
419        match name_upper.as_str() {
420            // COUNT_IF -> SUM(CASE WHEN...)
421            "COUNT_IF" if !f.args.is_empty() => {
422                let condition = f.args.into_iter().next().unwrap();
423                let case_expr = Expression::Case(Box::new(Case {
424                    operand: None,
425                    whens: vec![(condition, Expression::number(1))],
426                    else_: Some(Expression::number(0)),
427                    comments: Vec::new(),
428                    inferred_type: None,
429                }));
430                Ok(Expression::Sum(Box::new(AggFunc {
431                    ignore_nulls: None,
432                    having_max: None,
433                    this: case_expr,
434                    distinct: f.distinct,
435                    filter: f.filter,
436                    order_by: Vec::new(),
437                    name: None,
438                    limit: None,
439                    inferred_type: None,
440                })))
441            }
442
443            // ANY_VALUE -> ARBITRARY in Trino
444            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
445                "ARBITRARY".to_string(),
446                f.args,
447            )))),
448
449            // GROUP_CONCAT -> LISTAGG in Trino
450            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
451                Function::new("LISTAGG".to_string(), f.args),
452            ))),
453
454            // STRING_AGG -> LISTAGG in Trino
455            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
456                Function::new("LISTAGG".to_string(), f.args),
457            ))),
458
459            // VAR -> VAR_POP in Trino
460            "VAR" if !f.args.is_empty() => {
461                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
462                    name: "VAR_POP".to_string(),
463                    args: f.args,
464                    distinct: f.distinct,
465                    filter: f.filter,
466                    order_by: Vec::new(),
467                    limit: None,
468                    ignore_nulls: None,
469                    inferred_type: None,
470                })))
471            }
472
473            // VARIANCE -> VAR_SAMP in Trino
474            "VARIANCE" if !f.args.is_empty() => {
475                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
476                    name: "VAR_SAMP".to_string(),
477                    args: f.args,
478                    distinct: f.distinct,
479                    filter: f.filter,
480                    order_by: Vec::new(),
481                    limit: None,
482                    ignore_nulls: None,
483                    inferred_type: None,
484                })))
485            }
486
487            // Pass through everything else
488            _ => Ok(Expression::AggregateFunction(f)),
489        }
490    }
491
492    fn transform_cast(&self, c: Cast) -> Result<Expression> {
493        // Trino type mappings are handled in the generator
494        Ok(Expression::Cast(Box::new(c)))
495    }
496}