Skip to main content

polyglot_sql/dialects/
trino.rs

//! Trino Dialect
//!
//! Trino-specific transformations based on sqlglot patterns.
//! Trino is largely compatible with Presto but has some differences.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10    IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12use crate::generator::GeneratorConfig;
13use crate::tokens::TokenizerConfig;
14
/// Trino dialect.
///
/// A stateless marker type: all Trino-specific behaviour lives in its
/// `DialectImpl` implementation and in the private rewrite helpers defined
/// on the type itself. Because it carries no data it is trivially `Copy`,
/// and `Default` yields the only possible value.
#[derive(Debug, Clone, Copy, Default)]
pub struct TrinoDialect;
17
18impl DialectImpl for TrinoDialect {
19    fn dialect_type(&self) -> DialectType {
20        DialectType::Trino
21    }
22
23    fn tokenizer_config(&self) -> TokenizerConfig {
24        let mut config = TokenizerConfig::default();
25        // Trino uses double quotes for identifiers
26        config.identifiers.insert('"', '"');
27        // Trino does NOT support nested comments
28        config.nested_comments = false;
29        // Trino does NOT support QUALIFY - it's a valid identifier
30        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
31        config.keywords.remove("QUALIFY");
32        config
33    }
34
35    fn generator_config(&self) -> GeneratorConfig {
36        use crate::generator::IdentifierQuoteStyle;
37        GeneratorConfig {
38            identifier_quote: '"',
39            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
40            dialect: Some(DialectType::Trino),
41            limit_only_literals: true,
42            tz_to_with_time_zone: true,
43            ..Default::default()
44        }
45    }
46
47    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
48        match expr {
49            // IFNULL -> COALESCE in Trino
50            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
51                original_name: None,
52                expressions: vec![f.this, f.expression],
53                inferred_type: None,
54            }))),
55
56            // NVL -> COALESCE in Trino
57            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
58                original_name: None,
59                expressions: vec![f.this, f.expression],
60                inferred_type: None,
61            }))),
62
63            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
64            Expression::Coalesce(mut f) => {
65                f.original_name = None;
66                Ok(Expression::Coalesce(f))
67            }
68
69            // TryCast stays as TryCast (Trino supports TRY_CAST)
70            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
71
72            // SafeCast -> TRY_CAST in Trino
73            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
74
75            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
76            Expression::ILike(op) => {
77                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
78                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
79                Ok(Expression::Like(Box::new(LikeOp {
80                    left: lower_left,
81                    right: lower_right,
82                    escape: op.escape,
83                    quantifier: op.quantifier.clone(),
84                    inferred_type: None,
85                })))
86            }
87
88            // CountIf is native in Trino (keep as-is)
89            Expression::CountIf(f) => Ok(Expression::CountIf(f)),
90
91            // EXPLODE -> UNNEST in Trino
92            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
93                crate::expressions::UnnestFunc {
94                    this: f.this,
95                    expressions: Vec::new(),
96                    with_ordinality: false,
97                    alias: None,
98                    offset_alias: None,
99                },
100            ))),
101
102            // ExplodeOuter -> UNNEST in Trino
103            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
104                crate::expressions::UnnestFunc {
105                    this: f.this,
106                    expressions: Vec::new(),
107                    with_ordinality: false,
108                    alias: None,
109                    offset_alias: None,
110                },
111            ))),
112
113            // Generic function transformations
114            Expression::Function(f) => self.transform_function(*f),
115
116            // Generic aggregate function transformations
117            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
118
119            // Cast transformations
120            Expression::Cast(c) => self.transform_cast(*c),
121
122            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
123            // Trino requires SQL standard syntax for TRIM with characters
124            Expression::Trim(mut f) => {
125                if !f.sql_standard_syntax && f.characters.is_some() {
126                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
127                    f.sql_standard_syntax = true;
128                }
129                Ok(Expression::Trim(f))
130            }
131
132            // LISTAGG: Add default separator ',' if none is specified (Trino style)
133            Expression::ListAgg(mut f) => {
134                if f.separator.is_none() {
135                    f.separator = Some(Expression::Literal(Box::new(Literal::String(",".to_string()))));
136                }
137                Ok(Expression::ListAgg(f))
138            }
139
140            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
141            Expression::Interval(mut interval) => {
142                if interval.unit.is_none() {
143                    if let Some(Expression::Literal(ref lit)) = interval.this {
144                        if let Literal::String(ref s) = lit.as_ref() {
145                        if let Some((value, unit)) = Self::parse_compound_interval(s) {
146                            interval.this = Some(Expression::Literal(Box::new(Literal::String(value))));
147                            interval.unit = Some(unit);
148                        }
149                    }
150                    }
151                }
152                Ok(Expression::Interval(interval))
153            }
154
155            // Pass through everything else
156            _ => Ok(expr),
157        }
158    }
159}
160
161impl TrinoDialect {
162    /// Parse a compound interval string like "1 day" into (value, unit_spec).
163    /// Returns None if the string doesn't match a known pattern.
164    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
165        let s = s.trim();
166        let parts: Vec<&str> = s.split_whitespace().collect();
167        if parts.len() != 2 {
168            return None;
169        }
170        let value = parts[0].to_string();
171        let unit = match parts[1].to_uppercase().as_str() {
172            "YEAR" | "YEARS" => IntervalUnit::Year,
173            "MONTH" | "MONTHS" => IntervalUnit::Month,
174            "DAY" | "DAYS" => IntervalUnit::Day,
175            "HOUR" | "HOURS" => IntervalUnit::Hour,
176            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
177            "SECOND" | "SECONDS" => IntervalUnit::Second,
178            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
179            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
180            _ => return None,
181        };
182        Some((
183            value,
184            IntervalUnitSpec::Simple {
185                unit,
186                use_plural: false,
187            },
188        ))
189    }
190
191    fn transform_function(&self, f: Function) -> Result<Expression> {
192        let name_upper = f.name.to_uppercase();
193        match name_upper.as_str() {
194            // IFNULL -> COALESCE
195            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
196                original_name: None,
197                expressions: f.args,
198                inferred_type: None,
199            }))),
200
201            // NVL -> COALESCE
202            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
203                original_name: None,
204                expressions: f.args,
205                inferred_type: None,
206            }))),
207
208            // ISNULL -> COALESCE
209            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
210                original_name: None,
211                expressions: f.args,
212                inferred_type: None,
213            }))),
214
215            // GETDATE -> CURRENT_TIMESTAMP
216            "GETDATE" => Ok(Expression::CurrentTimestamp(
217                crate::expressions::CurrentTimestamp {
218                    precision: None,
219                    sysdate: false,
220                },
221            )),
222
223            // NOW -> CURRENT_TIMESTAMP
224            "NOW" => Ok(Expression::CurrentTimestamp(
225                crate::expressions::CurrentTimestamp {
226                    precision: None,
227                    sysdate: false,
228                },
229            )),
230
231            // RAND -> RANDOM in Trino
232            "RAND" => Ok(Expression::Function(Box::new(Function::new(
233                "RANDOM".to_string(),
234                vec![],
235            )))),
236
237            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
238            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
239                Function::new("LISTAGG".to_string(), f.args),
240            ))),
241
242            // STRING_AGG -> LISTAGG in Trino
243            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
244                Function::new("LISTAGG".to_string(), f.args),
245            ))),
246
247            // LISTAGG is native in Trino
248            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
249
250            // SUBSTR -> SUBSTRING
251            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
252                "SUBSTRING".to_string(),
253                f.args,
254            )))),
255
256            // LEN -> LENGTH
257            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
258                f.args.into_iter().next().unwrap(),
259            )))),
260
261            // CHARINDEX -> STRPOS in Trino (with swapped args)
262            "CHARINDEX" if f.args.len() >= 2 => {
263                let mut args = f.args;
264                let substring = args.remove(0);
265                let string = args.remove(0);
266                Ok(Expression::Function(Box::new(Function::new(
267                    "STRPOS".to_string(),
268                    vec![string, substring],
269                ))))
270            }
271
272            // INSTR -> STRPOS
273            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
274                "STRPOS".to_string(),
275                f.args,
276            )))),
277
278            // LOCATE -> STRPOS in Trino (with swapped args)
279            "LOCATE" if f.args.len() >= 2 => {
280                let mut args = f.args;
281                let substring = args.remove(0);
282                let string = args.remove(0);
283                Ok(Expression::Function(Box::new(Function::new(
284                    "STRPOS".to_string(),
285                    vec![string, substring],
286                ))))
287            }
288
289            // ARRAY_LENGTH -> CARDINALITY in Trino
290            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
291                Function::new("CARDINALITY".to_string(), f.args),
292            ))),
293
294            // SIZE -> CARDINALITY in Trino
295            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
296                "CARDINALITY".to_string(),
297                f.args,
298            )))),
299
300            // ARRAY_CONTAINS -> CONTAINS in Trino
301            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
302                Function::new("CONTAINS".to_string(), f.args),
303            ))),
304
305            // TO_DATE -> CAST to DATE or DATE_PARSE
306            "TO_DATE" if !f.args.is_empty() => {
307                if f.args.len() == 1 {
308                    Ok(Expression::Cast(Box::new(Cast {
309                        this: f.args.into_iter().next().unwrap(),
310                        to: DataType::Date,
311                        trailing_comments: Vec::new(),
312                        double_colon_syntax: false,
313                        format: None,
314                        default: None,
315                        inferred_type: None,
316                    })))
317                } else {
318                    Ok(Expression::Function(Box::new(Function::new(
319                        "DATE_PARSE".to_string(),
320                        f.args,
321                    ))))
322                }
323            }
324
325            // TO_TIMESTAMP -> CAST or DATE_PARSE
326            "TO_TIMESTAMP" if !f.args.is_empty() => {
327                if f.args.len() == 1 {
328                    Ok(Expression::Cast(Box::new(Cast {
329                        this: f.args.into_iter().next().unwrap(),
330                        to: DataType::Timestamp {
331                            precision: None,
332                            timezone: false,
333                        },
334                        trailing_comments: Vec::new(),
335                        double_colon_syntax: false,
336                        format: None,
337                        default: None,
338                        inferred_type: None,
339                    })))
340                } else {
341                    Ok(Expression::Function(Box::new(Function::new(
342                        "DATE_PARSE".to_string(),
343                        f.args,
344                    ))))
345                }
346            }
347
348            // strftime -> DATE_FORMAT in Trino
349            "STRFTIME" if f.args.len() >= 2 => {
350                let mut args = f.args;
351                let format = args.remove(0);
352                let date = args.remove(0);
353                Ok(Expression::Function(Box::new(Function::new(
354                    "DATE_FORMAT".to_string(),
355                    vec![date, format],
356                ))))
357            }
358
359            // TO_CHAR -> DATE_FORMAT in Trino
360            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
361                "DATE_FORMAT".to_string(),
362                f.args,
363            )))),
364
365            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
366            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
367                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
368            ))),
369
370            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
371            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
372                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
373            ))),
374
375            // COLLECT_LIST -> ARRAY_AGG
376            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
377                Function::new("ARRAY_AGG".to_string(), f.args),
378            ))),
379
380            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
381            "COLLECT_SET" if !f.args.is_empty() => {
382                let array_agg =
383                    Expression::Function(Box::new(Function::new("ARRAY_AGG".to_string(), f.args)));
384                Ok(Expression::Function(Box::new(Function::new(
385                    "ARRAY_DISTINCT".to_string(),
386                    vec![array_agg],
387                ))))
388            }
389
390            // RLIKE -> REGEXP_LIKE in Trino
391            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
392                "REGEXP_LIKE".to_string(),
393                f.args,
394            )))),
395
396            // REGEXP -> REGEXP_LIKE in Trino
397            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
398                "REGEXP_LIKE".to_string(),
399                f.args,
400            )))),
401
402            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
403            // For simplicity, we'll use a different approach
404            "ARRAY_SUM" if f.args.len() == 1 => {
405                // This is a complex transformation in Presto/Trino
406                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
407                // For now, pass through and let user handle it
408                Ok(Expression::Function(Box::new(f)))
409            }
410
411            // Pass through everything else
412            _ => Ok(Expression::Function(Box::new(f))),
413        }
414    }
415
416    fn transform_aggregate_function(
417        &self,
418        f: Box<crate::expressions::AggregateFunction>,
419    ) -> Result<Expression> {
420        let name_upper = f.name.to_uppercase();
421        match name_upper.as_str() {
422            // COUNT_IF -> SUM(CASE WHEN...)
423            "COUNT_IF" if !f.args.is_empty() => {
424                let condition = f.args.into_iter().next().unwrap();
425                let case_expr = Expression::Case(Box::new(Case {
426                    operand: None,
427                    whens: vec![(condition, Expression::number(1))],
428                    else_: Some(Expression::number(0)),
429                    comments: Vec::new(),
430                    inferred_type: None,
431                }));
432                Ok(Expression::Sum(Box::new(AggFunc {
433                    ignore_nulls: None,
434                    having_max: None,
435                    this: case_expr,
436                    distinct: f.distinct,
437                    filter: f.filter,
438                    order_by: Vec::new(),
439                    name: None,
440                    limit: None,
441                    inferred_type: None,
442                })))
443            }
444
445            // ANY_VALUE -> ARBITRARY in Trino
446            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
447                "ARBITRARY".to_string(),
448                f.args,
449            )))),
450
451            // GROUP_CONCAT -> LISTAGG in Trino
452            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
453                Function::new("LISTAGG".to_string(), f.args),
454            ))),
455
456            // STRING_AGG -> LISTAGG in Trino
457            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
458                Function::new("LISTAGG".to_string(), f.args),
459            ))),
460
461            // VAR -> VAR_POP in Trino
462            "VAR" if !f.args.is_empty() => {
463                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
464                    name: "VAR_POP".to_string(),
465                    args: f.args,
466                    distinct: f.distinct,
467                    filter: f.filter,
468                    order_by: Vec::new(),
469                    limit: None,
470                    ignore_nulls: None,
471                    inferred_type: None,
472                })))
473            }
474
475            // VARIANCE -> VAR_SAMP in Trino
476            "VARIANCE" if !f.args.is_empty() => {
477                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
478                    name: "VAR_SAMP".to_string(),
479                    args: f.args,
480                    distinct: f.distinct,
481                    filter: f.filter,
482                    order_by: Vec::new(),
483                    limit: None,
484                    ignore_nulls: None,
485                    inferred_type: None,
486                })))
487            }
488
489            // Pass through everything else
490            _ => Ok(Expression::AggregateFunction(f)),
491        }
492    }
493
494    fn transform_cast(&self, c: Cast) -> Result<Expression> {
495        // Trino type mappings are handled in the generator
496        Ok(Expression::Cast(Box::new(c)))
497    }
498}