Skip to main content

polyglot_sql/dialects/
trino.rs

1//! Trino Dialect
2//!
3//! Trino-specific transformations based on sqlglot patterns.
4//! Trino is largely compatible with Presto but has some differences.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10    IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12#[cfg(feature = "generate")]
13use crate::generator::GeneratorConfig;
14use crate::tokens::TokenizerConfig;
15
/// SQL dialect marker for Trino.
///
/// A field-less unit type: all dialect behavior is supplied through its
/// `DialectImpl` implementation and its inherent helper methods.
pub struct TrinoDialect;
18
19impl DialectImpl for TrinoDialect {
20    fn dialect_type(&self) -> DialectType {
21        DialectType::Trino
22    }
23
24    fn tokenizer_config(&self) -> TokenizerConfig {
25        let mut config = TokenizerConfig::default();
26        // Trino uses double quotes for identifiers
27        config.identifiers.insert('"', '"');
28        // Trino does NOT support nested comments
29        config.nested_comments = false;
30        // Trino does NOT support QUALIFY - it's a valid identifier
31        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
32        config.keywords.remove("QUALIFY");
33        config
34    }
35
36    #[cfg(feature = "generate")]
37
38    fn generator_config(&self) -> GeneratorConfig {
39        use crate::generator::IdentifierQuoteStyle;
40        GeneratorConfig {
41            identifier_quote: '"',
42            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
43            dialect: Some(DialectType::Trino),
44            limit_only_literals: true,
45            tz_to_with_time_zone: true,
46            ..Default::default()
47        }
48    }
49
50    #[cfg(feature = "transpile")]
51
52    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
53        match expr {
54            // IFNULL -> COALESCE in Trino
55            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
56                original_name: None,
57                expressions: vec![f.this, f.expression],
58                inferred_type: None,
59            }))),
60
61            // NVL -> COALESCE in Trino
62            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
63                original_name: None,
64                expressions: vec![f.this, f.expression],
65                inferred_type: None,
66            }))),
67
68            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
69            Expression::Coalesce(mut f) => {
70                f.original_name = None;
71                Ok(Expression::Coalesce(f))
72            }
73
74            // TryCast stays as TryCast (Trino supports TRY_CAST)
75            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
76
77            // SafeCast -> TRY_CAST in Trino
78            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
79
80            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
81            Expression::ILike(op) => {
82                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
83                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
84                Ok(Expression::Like(Box::new(LikeOp {
85                    left: lower_left,
86                    right: lower_right,
87                    escape: op.escape,
88                    quantifier: op.quantifier.clone(),
89                    inferred_type: None,
90                })))
91            }
92
93            // CountIf is native in Trino (keep as-is)
94            Expression::CountIf(f) => Ok(Expression::CountIf(f)),
95
96            // EXPLODE -> UNNEST in Trino
97            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
98                crate::expressions::UnnestFunc {
99                    this: f.this,
100                    expressions: Vec::new(),
101                    with_ordinality: false,
102                    alias: None,
103                    offset_alias: None,
104                },
105            ))),
106
107            // ExplodeOuter -> UNNEST in Trino
108            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
109                crate::expressions::UnnestFunc {
110                    this: f.this,
111                    expressions: Vec::new(),
112                    with_ordinality: false,
113                    alias: None,
114                    offset_alias: None,
115                },
116            ))),
117
118            // Generic function transformations
119            Expression::Function(f) => self.transform_function(*f),
120
121            // Generic aggregate function transformations
122            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
123
124            // Cast transformations
125            Expression::Cast(c) => self.transform_cast(*c),
126
127            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
128            // Trino requires SQL standard syntax for TRIM with characters
129            Expression::Trim(mut f) => {
130                if !f.sql_standard_syntax && f.characters.is_some() {
131                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
132                    f.sql_standard_syntax = true;
133                }
134                Ok(Expression::Trim(f))
135            }
136
137            // LISTAGG: Add default separator ',' if none is specified (Trino style)
138            Expression::ListAgg(mut f) => {
139                if f.separator.is_none() {
140                    f.separator = Some(Expression::Literal(Box::new(Literal::String(
141                        ",".to_string(),
142                    ))));
143                }
144                Ok(Expression::ListAgg(f))
145            }
146
147            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
148            Expression::Interval(mut interval) => {
149                if interval.unit.is_none() {
150                    if let Some(Expression::Literal(ref lit)) = interval.this {
151                        if let Literal::String(ref s) = lit.as_ref() {
152                            if let Some((value, unit)) = Self::parse_compound_interval(s) {
153                                interval.this =
154                                    Some(Expression::Literal(Box::new(Literal::String(value))));
155                                interval.unit = Some(unit);
156                            }
157                        }
158                    }
159                }
160                Ok(Expression::Interval(interval))
161            }
162
163            // Pass through everything else
164            _ => Ok(expr),
165        }
166    }
167}
168
169#[cfg(feature = "transpile")]
170impl TrinoDialect {
171    /// Parse a compound interval string like "1 day" into (value, unit_spec).
172    /// Returns None if the string doesn't match a known pattern.
173    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
174        let s = s.trim();
175        let parts: Vec<&str> = s.split_whitespace().collect();
176        if parts.len() != 2 {
177            return None;
178        }
179        let value = parts[0].to_string();
180        let unit = match parts[1].to_uppercase().as_str() {
181            "YEAR" | "YEARS" => IntervalUnit::Year,
182            "MONTH" | "MONTHS" => IntervalUnit::Month,
183            "DAY" | "DAYS" => IntervalUnit::Day,
184            "HOUR" | "HOURS" => IntervalUnit::Hour,
185            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
186            "SECOND" | "SECONDS" => IntervalUnit::Second,
187            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
188            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
189            _ => return None,
190        };
191        Some((
192            value,
193            IntervalUnitSpec::Simple {
194                unit,
195                use_plural: false,
196            },
197        ))
198    }
199
200    fn transform_function(&self, f: Function) -> Result<Expression> {
201        let name_upper = f.name.to_uppercase();
202        match name_upper.as_str() {
203            // IFNULL -> COALESCE
204            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
205                original_name: None,
206                expressions: f.args,
207                inferred_type: None,
208            }))),
209
210            // NVL -> COALESCE
211            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
212                original_name: None,
213                expressions: f.args,
214                inferred_type: None,
215            }))),
216
217            // ISNULL -> COALESCE
218            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
219                original_name: None,
220                expressions: f.args,
221                inferred_type: None,
222            }))),
223
224            // GETDATE -> CURRENT_TIMESTAMP
225            "GETDATE" => Ok(Expression::CurrentTimestamp(
226                crate::expressions::CurrentTimestamp {
227                    precision: None,
228                    sysdate: false,
229                },
230            )),
231
232            // NOW -> CURRENT_TIMESTAMP
233            "NOW" => Ok(Expression::CurrentTimestamp(
234                crate::expressions::CurrentTimestamp {
235                    precision: None,
236                    sysdate: false,
237                },
238            )),
239
240            // RAND -> RANDOM in Trino
241            "RAND" => Ok(Expression::Function(Box::new(Function::new(
242                "RANDOM".to_string(),
243                vec![],
244            )))),
245
246            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
247            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
248                Function::new("LISTAGG".to_string(), f.args),
249            ))),
250
251            // STRING_AGG -> LISTAGG in Trino
252            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
253                Function::new("LISTAGG".to_string(), f.args),
254            ))),
255
256            // LISTAGG is native in Trino
257            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
258
259            // SUBSTR -> SUBSTRING
260            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
261                "SUBSTRING".to_string(),
262                f.args,
263            )))),
264
265            // LEN -> LENGTH
266            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
267                f.args.into_iter().next().unwrap(),
268            )))),
269
270            // CHARINDEX -> STRPOS in Trino (with swapped args)
271            "CHARINDEX" if f.args.len() >= 2 => {
272                let mut args = f.args;
273                let substring = args.remove(0);
274                let string = args.remove(0);
275                Ok(Expression::Function(Box::new(Function::new(
276                    "STRPOS".to_string(),
277                    vec![string, substring],
278                ))))
279            }
280
281            // INSTR -> STRPOS
282            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
283                "STRPOS".to_string(),
284                f.args,
285            )))),
286
287            // LOCATE -> STRPOS in Trino (with swapped args)
288            "LOCATE" if f.args.len() >= 2 => {
289                let mut args = f.args;
290                let substring = args.remove(0);
291                let string = args.remove(0);
292                Ok(Expression::Function(Box::new(Function::new(
293                    "STRPOS".to_string(),
294                    vec![string, substring],
295                ))))
296            }
297
298            // ARRAY_LENGTH -> CARDINALITY in Trino
299            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
300                Function::new("CARDINALITY".to_string(), f.args),
301            ))),
302
303            // SIZE -> CARDINALITY in Trino
304            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
305                "CARDINALITY".to_string(),
306                f.args,
307            )))),
308
309            // ARRAY_CONTAINS -> CONTAINS in Trino
310            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
311                Function::new("CONTAINS".to_string(), f.args),
312            ))),
313
314            // TO_DATE -> CAST to DATE or DATE_PARSE
315            "TO_DATE" if !f.args.is_empty() => {
316                if f.args.len() == 1 {
317                    Ok(Expression::Cast(Box::new(Cast {
318                        this: f.args.into_iter().next().unwrap(),
319                        to: DataType::Date,
320                        trailing_comments: Vec::new(),
321                        double_colon_syntax: false,
322                        format: None,
323                        default: None,
324                        inferred_type: None,
325                    })))
326                } else {
327                    Ok(Expression::Function(Box::new(Function::new(
328                        "DATE_PARSE".to_string(),
329                        f.args,
330                    ))))
331                }
332            }
333
334            // TO_TIMESTAMP -> CAST or DATE_PARSE
335            "TO_TIMESTAMP" if !f.args.is_empty() => {
336                if f.args.len() == 1 {
337                    Ok(Expression::Cast(Box::new(Cast {
338                        this: f.args.into_iter().next().unwrap(),
339                        to: DataType::Timestamp {
340                            precision: None,
341                            timezone: false,
342                        },
343                        trailing_comments: Vec::new(),
344                        double_colon_syntax: false,
345                        format: None,
346                        default: None,
347                        inferred_type: None,
348                    })))
349                } else {
350                    Ok(Expression::Function(Box::new(Function::new(
351                        "DATE_PARSE".to_string(),
352                        f.args,
353                    ))))
354                }
355            }
356
357            // strftime -> DATE_FORMAT in Trino
358            "STRFTIME" if f.args.len() >= 2 => {
359                let mut args = f.args;
360                let format = args.remove(0);
361                let date = args.remove(0);
362                Ok(Expression::Function(Box::new(Function::new(
363                    "DATE_FORMAT".to_string(),
364                    vec![date, format],
365                ))))
366            }
367
368            // TO_CHAR -> DATE_FORMAT in Trino
369            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
370                "DATE_FORMAT".to_string(),
371                f.args,
372            )))),
373
374            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
375            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
376                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
377            ))),
378
379            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
380            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
381                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
382            ))),
383
384            // COLLECT_LIST -> ARRAY_AGG
385            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
386                Function::new("ARRAY_AGG".to_string(), f.args),
387            ))),
388
389            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
390            "COLLECT_SET" if !f.args.is_empty() => {
391                let array_agg =
392                    Expression::Function(Box::new(Function::new("ARRAY_AGG".to_string(), f.args)));
393                Ok(Expression::Function(Box::new(Function::new(
394                    "ARRAY_DISTINCT".to_string(),
395                    vec![array_agg],
396                ))))
397            }
398
399            // RLIKE -> REGEXP_LIKE in Trino
400            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
401                "REGEXP_LIKE".to_string(),
402                f.args,
403            )))),
404
405            // REGEXP -> REGEXP_LIKE in Trino
406            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
407                "REGEXP_LIKE".to_string(),
408                f.args,
409            )))),
410
411            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
412            // For simplicity, we'll use a different approach
413            "ARRAY_SUM" if f.args.len() == 1 => {
414                // This is a complex transformation in Presto/Trino
415                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
416                // For now, pass through and let user handle it
417                Ok(Expression::Function(Box::new(f)))
418            }
419
420            // Pass through everything else
421            _ => Ok(Expression::Function(Box::new(f))),
422        }
423    }
424
425    fn transform_aggregate_function(
426        &self,
427        f: Box<crate::expressions::AggregateFunction>,
428    ) -> Result<Expression> {
429        let name_upper = f.name.to_uppercase();
430        match name_upper.as_str() {
431            // COUNT_IF -> SUM(CASE WHEN...)
432            "COUNT_IF" if !f.args.is_empty() => {
433                let condition = f.args.into_iter().next().unwrap();
434                let case_expr = Expression::Case(Box::new(Case {
435                    operand: None,
436                    whens: vec![(condition, Expression::number(1))],
437                    else_: Some(Expression::number(0)),
438                    comments: Vec::new(),
439                    inferred_type: None,
440                }));
441                Ok(Expression::Sum(Box::new(AggFunc {
442                    ignore_nulls: None,
443                    having_max: None,
444                    this: case_expr,
445                    distinct: f.distinct,
446                    filter: f.filter,
447                    order_by: Vec::new(),
448                    name: None,
449                    limit: None,
450                    inferred_type: None,
451                })))
452            }
453
454            // ANY_VALUE -> ARBITRARY in Trino
455            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
456                "ARBITRARY".to_string(),
457                f.args,
458            )))),
459
460            // GROUP_CONCAT -> LISTAGG in Trino
461            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
462                Function::new("LISTAGG".to_string(), f.args),
463            ))),
464
465            // STRING_AGG -> LISTAGG in Trino
466            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
467                Function::new("LISTAGG".to_string(), f.args),
468            ))),
469
470            // VAR -> VAR_POP in Trino
471            "VAR" if !f.args.is_empty() => {
472                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
473                    name: "VAR_POP".to_string(),
474                    args: f.args,
475                    distinct: f.distinct,
476                    filter: f.filter,
477                    order_by: Vec::new(),
478                    limit: None,
479                    ignore_nulls: None,
480                    inferred_type: None,
481                })))
482            }
483
484            // VARIANCE -> VAR_SAMP in Trino
485            "VARIANCE" if !f.args.is_empty() => {
486                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
487                    name: "VAR_SAMP".to_string(),
488                    args: f.args,
489                    distinct: f.distinct,
490                    filter: f.filter,
491                    order_by: Vec::new(),
492                    limit: None,
493                    ignore_nulls: None,
494                    inferred_type: None,
495                })))
496            }
497
498            // Pass through everything else
499            _ => Ok(Expression::AggregateFunction(f)),
500        }
501    }
502
503    fn transform_cast(&self, c: Cast) -> Result<Expression> {
504        // Trino type mappings are handled in the generator
505        Ok(Expression::Cast(Box::new(c)))
506    }
507}