Skip to main content

polyglot_sql/dialects/
trino.rs

//! Trino Dialect
//!
//! Trino-specific transformations based on sqlglot patterns.
//! Trino is largely compatible with Presto but has some differences.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10    IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12use crate::generator::GeneratorConfig;
13use crate::tokens::TokenizerConfig;
14
/// Trino dialect.
///
/// A field-less marker type: the dialect's tokenizer settings, generator
/// settings and expression rewrites are all supplied by its `DialectImpl`
/// implementation and inherent helper methods.
pub struct TrinoDialect;
17
18impl DialectImpl for TrinoDialect {
    /// Returns the dialect tag for this implementation, always
    /// `DialectType::Trino`.
    fn dialect_type(&self) -> DialectType {
        DialectType::Trino
    }
22
23    fn tokenizer_config(&self) -> TokenizerConfig {
24        let mut config = TokenizerConfig::default();
25        // Trino uses double quotes for identifiers
26        config.identifiers.insert('"', '"');
27        // Trino does NOT support nested comments
28        config.nested_comments = false;
29        // Trino does NOT support QUALIFY - it's a valid identifier
30        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
31        config.keywords.remove("QUALIFY");
32        config
33    }
34
35    fn generator_config(&self) -> GeneratorConfig {
36        use crate::generator::IdentifierQuoteStyle;
37        GeneratorConfig {
38            identifier_quote: '"',
39            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
40            dialect: Some(DialectType::Trino),
41            limit_only_literals: true,
42            tz_to_with_time_zone: true,
43            ..Default::default()
44        }
45    }
46
47    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
48        match expr {
49            // IFNULL -> COALESCE in Trino
50            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
51                original_name: None,
52                expressions: vec![f.this, f.expression],
53                inferred_type: None,
54            }))),
55
56            // NVL -> COALESCE in Trino
57            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
58                original_name: None,
59                expressions: vec![f.this, f.expression],
60                inferred_type: None,
61            }))),
62
63            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
64            Expression::Coalesce(mut f) => {
65                f.original_name = None;
66                Ok(Expression::Coalesce(f))
67            }
68
69            // TryCast stays as TryCast (Trino supports TRY_CAST)
70            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
71
72            // SafeCast -> TRY_CAST in Trino
73            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
74
75            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
76            Expression::ILike(op) => {
77                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
78                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
79                Ok(Expression::Like(Box::new(LikeOp {
80                    left: lower_left,
81                    right: lower_right,
82                    escape: op.escape,
83                    quantifier: op.quantifier.clone(),
84                    inferred_type: None,
85                })))
86            }
87
88            // CountIf is native in Trino (keep as-is)
89            Expression::CountIf(f) => Ok(Expression::CountIf(f)),
90
91            // EXPLODE -> UNNEST in Trino
92            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
93                crate::expressions::UnnestFunc {
94                    this: f.this,
95                    expressions: Vec::new(),
96                    with_ordinality: false,
97                    alias: None,
98                    offset_alias: None,
99                },
100            ))),
101
102            // ExplodeOuter -> UNNEST in Trino
103            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
104                crate::expressions::UnnestFunc {
105                    this: f.this,
106                    expressions: Vec::new(),
107                    with_ordinality: false,
108                    alias: None,
109                    offset_alias: None,
110                },
111            ))),
112
113            // Generic function transformations
114            Expression::Function(f) => self.transform_function(*f),
115
116            // Generic aggregate function transformations
117            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
118
119            // Cast transformations
120            Expression::Cast(c) => self.transform_cast(*c),
121
122            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
123            // Trino requires SQL standard syntax for TRIM with characters
124            Expression::Trim(mut f) => {
125                if !f.sql_standard_syntax && f.characters.is_some() {
126                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
127                    f.sql_standard_syntax = true;
128                }
129                Ok(Expression::Trim(f))
130            }
131
132            // LISTAGG: Add default separator ',' if none is specified (Trino style)
133            Expression::ListAgg(mut f) => {
134                if f.separator.is_none() {
135                    f.separator = Some(Expression::Literal(Box::new(Literal::String(
136                        ",".to_string(),
137                    ))));
138                }
139                Ok(Expression::ListAgg(f))
140            }
141
142            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
143            Expression::Interval(mut interval) => {
144                if interval.unit.is_none() {
145                    if let Some(Expression::Literal(ref lit)) = interval.this {
146                        if let Literal::String(ref s) = lit.as_ref() {
147                            if let Some((value, unit)) = Self::parse_compound_interval(s) {
148                                interval.this =
149                                    Some(Expression::Literal(Box::new(Literal::String(value))));
150                                interval.unit = Some(unit);
151                            }
152                        }
153                    }
154                }
155                Ok(Expression::Interval(interval))
156            }
157
158            // Pass through everything else
159            _ => Ok(expr),
160        }
161    }
162}
163
164impl TrinoDialect {
165    /// Parse a compound interval string like "1 day" into (value, unit_spec).
166    /// Returns None if the string doesn't match a known pattern.
167    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
168        let s = s.trim();
169        let parts: Vec<&str> = s.split_whitespace().collect();
170        if parts.len() != 2 {
171            return None;
172        }
173        let value = parts[0].to_string();
174        let unit = match parts[1].to_uppercase().as_str() {
175            "YEAR" | "YEARS" => IntervalUnit::Year,
176            "MONTH" | "MONTHS" => IntervalUnit::Month,
177            "DAY" | "DAYS" => IntervalUnit::Day,
178            "HOUR" | "HOURS" => IntervalUnit::Hour,
179            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
180            "SECOND" | "SECONDS" => IntervalUnit::Second,
181            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
182            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
183            _ => return None,
184        };
185        Some((
186            value,
187            IntervalUnitSpec::Simple {
188                unit,
189                use_plural: false,
190            },
191        ))
192    }
193
194    fn transform_function(&self, f: Function) -> Result<Expression> {
195        let name_upper = f.name.to_uppercase();
196        match name_upper.as_str() {
197            // IFNULL -> COALESCE
198            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
199                original_name: None,
200                expressions: f.args,
201                inferred_type: None,
202            }))),
203
204            // NVL -> COALESCE
205            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
206                original_name: None,
207                expressions: f.args,
208                inferred_type: None,
209            }))),
210
211            // ISNULL -> COALESCE
212            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
213                original_name: None,
214                expressions: f.args,
215                inferred_type: None,
216            }))),
217
218            // GETDATE -> CURRENT_TIMESTAMP
219            "GETDATE" => Ok(Expression::CurrentTimestamp(
220                crate::expressions::CurrentTimestamp {
221                    precision: None,
222                    sysdate: false,
223                },
224            )),
225
226            // NOW -> CURRENT_TIMESTAMP
227            "NOW" => Ok(Expression::CurrentTimestamp(
228                crate::expressions::CurrentTimestamp {
229                    precision: None,
230                    sysdate: false,
231                },
232            )),
233
234            // RAND -> RANDOM in Trino
235            "RAND" => Ok(Expression::Function(Box::new(Function::new(
236                "RANDOM".to_string(),
237                vec![],
238            )))),
239
240            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
241            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
242                Function::new("LISTAGG".to_string(), f.args),
243            ))),
244
245            // STRING_AGG -> LISTAGG in Trino
246            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
247                Function::new("LISTAGG".to_string(), f.args),
248            ))),
249
250            // LISTAGG is native in Trino
251            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
252
253            // SUBSTR -> SUBSTRING
254            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
255                "SUBSTRING".to_string(),
256                f.args,
257            )))),
258
259            // LEN -> LENGTH
260            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
261                f.args.into_iter().next().unwrap(),
262            )))),
263
264            // CHARINDEX -> STRPOS in Trino (with swapped args)
265            "CHARINDEX" if f.args.len() >= 2 => {
266                let mut args = f.args;
267                let substring = args.remove(0);
268                let string = args.remove(0);
269                Ok(Expression::Function(Box::new(Function::new(
270                    "STRPOS".to_string(),
271                    vec![string, substring],
272                ))))
273            }
274
275            // INSTR -> STRPOS
276            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
277                "STRPOS".to_string(),
278                f.args,
279            )))),
280
281            // LOCATE -> STRPOS in Trino (with swapped args)
282            "LOCATE" if f.args.len() >= 2 => {
283                let mut args = f.args;
284                let substring = args.remove(0);
285                let string = args.remove(0);
286                Ok(Expression::Function(Box::new(Function::new(
287                    "STRPOS".to_string(),
288                    vec![string, substring],
289                ))))
290            }
291
292            // ARRAY_LENGTH -> CARDINALITY in Trino
293            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
294                Function::new("CARDINALITY".to_string(), f.args),
295            ))),
296
297            // SIZE -> CARDINALITY in Trino
298            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
299                "CARDINALITY".to_string(),
300                f.args,
301            )))),
302
303            // ARRAY_CONTAINS -> CONTAINS in Trino
304            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
305                Function::new("CONTAINS".to_string(), f.args),
306            ))),
307
308            // TO_DATE -> CAST to DATE or DATE_PARSE
309            "TO_DATE" if !f.args.is_empty() => {
310                if f.args.len() == 1 {
311                    Ok(Expression::Cast(Box::new(Cast {
312                        this: f.args.into_iter().next().unwrap(),
313                        to: DataType::Date,
314                        trailing_comments: Vec::new(),
315                        double_colon_syntax: false,
316                        format: None,
317                        default: None,
318                        inferred_type: None,
319                    })))
320                } else {
321                    Ok(Expression::Function(Box::new(Function::new(
322                        "DATE_PARSE".to_string(),
323                        f.args,
324                    ))))
325                }
326            }
327
328            // TO_TIMESTAMP -> CAST or DATE_PARSE
329            "TO_TIMESTAMP" if !f.args.is_empty() => {
330                if f.args.len() == 1 {
331                    Ok(Expression::Cast(Box::new(Cast {
332                        this: f.args.into_iter().next().unwrap(),
333                        to: DataType::Timestamp {
334                            precision: None,
335                            timezone: false,
336                        },
337                        trailing_comments: Vec::new(),
338                        double_colon_syntax: false,
339                        format: None,
340                        default: None,
341                        inferred_type: None,
342                    })))
343                } else {
344                    Ok(Expression::Function(Box::new(Function::new(
345                        "DATE_PARSE".to_string(),
346                        f.args,
347                    ))))
348                }
349            }
350
351            // strftime -> DATE_FORMAT in Trino
352            "STRFTIME" if f.args.len() >= 2 => {
353                let mut args = f.args;
354                let format = args.remove(0);
355                let date = args.remove(0);
356                Ok(Expression::Function(Box::new(Function::new(
357                    "DATE_FORMAT".to_string(),
358                    vec![date, format],
359                ))))
360            }
361
362            // TO_CHAR -> DATE_FORMAT in Trino
363            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
364                "DATE_FORMAT".to_string(),
365                f.args,
366            )))),
367
368            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
369            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
370                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
371            ))),
372
373            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
374            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
375                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
376            ))),
377
378            // COLLECT_LIST -> ARRAY_AGG
379            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
380                Function::new("ARRAY_AGG".to_string(), f.args),
381            ))),
382
383            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
384            "COLLECT_SET" if !f.args.is_empty() => {
385                let array_agg =
386                    Expression::Function(Box::new(Function::new("ARRAY_AGG".to_string(), f.args)));
387                Ok(Expression::Function(Box::new(Function::new(
388                    "ARRAY_DISTINCT".to_string(),
389                    vec![array_agg],
390                ))))
391            }
392
393            // RLIKE -> REGEXP_LIKE in Trino
394            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
395                "REGEXP_LIKE".to_string(),
396                f.args,
397            )))),
398
399            // REGEXP -> REGEXP_LIKE in Trino
400            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
401                "REGEXP_LIKE".to_string(),
402                f.args,
403            )))),
404
405            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
406            // For simplicity, we'll use a different approach
407            "ARRAY_SUM" if f.args.len() == 1 => {
408                // This is a complex transformation in Presto/Trino
409                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
410                // For now, pass through and let user handle it
411                Ok(Expression::Function(Box::new(f)))
412            }
413
414            // Pass through everything else
415            _ => Ok(Expression::Function(Box::new(f))),
416        }
417    }
418
419    fn transform_aggregate_function(
420        &self,
421        f: Box<crate::expressions::AggregateFunction>,
422    ) -> Result<Expression> {
423        let name_upper = f.name.to_uppercase();
424        match name_upper.as_str() {
425            // COUNT_IF -> SUM(CASE WHEN...)
426            "COUNT_IF" if !f.args.is_empty() => {
427                let condition = f.args.into_iter().next().unwrap();
428                let case_expr = Expression::Case(Box::new(Case {
429                    operand: None,
430                    whens: vec![(condition, Expression::number(1))],
431                    else_: Some(Expression::number(0)),
432                    comments: Vec::new(),
433                    inferred_type: None,
434                }));
435                Ok(Expression::Sum(Box::new(AggFunc {
436                    ignore_nulls: None,
437                    having_max: None,
438                    this: case_expr,
439                    distinct: f.distinct,
440                    filter: f.filter,
441                    order_by: Vec::new(),
442                    name: None,
443                    limit: None,
444                    inferred_type: None,
445                })))
446            }
447
448            // ANY_VALUE -> ARBITRARY in Trino
449            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
450                "ARBITRARY".to_string(),
451                f.args,
452            )))),
453
454            // GROUP_CONCAT -> LISTAGG in Trino
455            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
456                Function::new("LISTAGG".to_string(), f.args),
457            ))),
458
459            // STRING_AGG -> LISTAGG in Trino
460            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
461                Function::new("LISTAGG".to_string(), f.args),
462            ))),
463
464            // VAR -> VAR_POP in Trino
465            "VAR" if !f.args.is_empty() => {
466                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
467                    name: "VAR_POP".to_string(),
468                    args: f.args,
469                    distinct: f.distinct,
470                    filter: f.filter,
471                    order_by: Vec::new(),
472                    limit: None,
473                    ignore_nulls: None,
474                    inferred_type: None,
475                })))
476            }
477
478            // VARIANCE -> VAR_SAMP in Trino
479            "VARIANCE" if !f.args.is_empty() => {
480                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
481                    name: "VAR_SAMP".to_string(),
482                    args: f.args,
483                    distinct: f.distinct,
484                    filter: f.filter,
485                    order_by: Vec::new(),
486                    limit: None,
487                    ignore_nulls: None,
488                    inferred_type: None,
489                })))
490            }
491
492            // Pass through everything else
493            _ => Ok(Expression::AggregateFunction(f)),
494        }
495    }
496
497    fn transform_cast(&self, c: Cast) -> Result<Expression> {
498        // Trino type mappings are handled in the generator
499        Ok(Expression::Cast(Box::new(c)))
500    }
501}