Skip to main content

polyglot_sql/dialects/
trino.rs

1//! Trino Dialect
2//!
3//! Trino-specific transformations based on sqlglot patterns.
4//! Trino is largely compatible with Presto but has some differences.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit, IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
10};
11use crate::generator::GeneratorConfig;
12use crate::tokens::TokenizerConfig;
13
/// Trino SQL dialect.
///
/// A stateless unit type. Its [`DialectImpl`] implementation supplies Trino's
/// tokenizer and generator configuration and rewrites expressions into forms
/// Trino accepts (for example `IFNULL`/`NVL` become `COALESCE` and `ILIKE`
/// becomes `LOWER(..) LIKE LOWER(..)`).
pub struct TrinoDialect;
16
17impl DialectImpl for TrinoDialect {
18    fn dialect_type(&self) -> DialectType {
19        DialectType::Trino
20    }
21
22    fn tokenizer_config(&self) -> TokenizerConfig {
23        let mut config = TokenizerConfig::default();
24        // Trino uses double quotes for identifiers
25        config.identifiers.insert('"', '"');
26        // Trino does NOT support nested comments
27        config.nested_comments = false;
28        // Trino does NOT support QUALIFY - it's a valid identifier
29        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
30        config.keywords.remove("QUALIFY");
31        config
32    }
33
34    fn generator_config(&self) -> GeneratorConfig {
35        use crate::generator::IdentifierQuoteStyle;
36        GeneratorConfig {
37            identifier_quote: '"',
38            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
39            dialect: Some(DialectType::Trino),
40            limit_only_literals: true,
41            tz_to_with_time_zone: true,
42            ..Default::default()
43        }
44    }
45
46    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
47        match expr {
48            // IFNULL -> COALESCE in Trino
49            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
50                expressions: vec![f.this, f.expression],
51            }))),
52
53            // NVL -> COALESCE in Trino
54            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
55                expressions: vec![f.this, f.expression],
56            }))),
57
58            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
59            Expression::Coalesce(mut f) => {
60                f.original_name = None;
61                Ok(Expression::Coalesce(f))
62            }
63
64            // TryCast stays as TryCast (Trino supports TRY_CAST)
65            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
66
67            // SafeCast -> TRY_CAST in Trino
68            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
69
70            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
71            Expression::ILike(op) => {
72                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
73                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
74                Ok(Expression::Like(Box::new(LikeOp {
75                    left: lower_left,
76                    right: lower_right,
77                    escape: op.escape,
78                    quantifier: op.quantifier.clone(),
79                })))
80            }
81
82            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
83            Expression::CountIf(f) => {
84                let case_expr = Expression::Case(Box::new(Case {
85                    operand: None,
86                    whens: vec![(f.this.clone(), Expression::number(1))],
87                    else_: Some(Expression::number(0)),
88                }));
89                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
90                    this: case_expr,
91                    distinct: f.distinct,
92                    filter: f.filter,
93                    order_by: Vec::new(),
94                name: None,
95                limit: None,
96                })))
97            }
98
99            // EXPLODE -> UNNEST in Trino
100            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
101                crate::expressions::UnnestFunc {
102                    this: f.this,
103                    expressions: Vec::new(),
104                    with_ordinality: false,
105                    alias: None,
106                    offset_alias: None,
107                },
108            ))),
109
110            // ExplodeOuter -> UNNEST in Trino
111            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
112                crate::expressions::UnnestFunc {
113                    this: f.this,
114                    expressions: Vec::new(),
115                    with_ordinality: false,
116                    alias: None,
117                    offset_alias: None,
118                },
119            ))),
120
121            // Generic function transformations
122            Expression::Function(f) => self.transform_function(*f),
123
124            // Generic aggregate function transformations
125            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
126
127            // Cast transformations
128            Expression::Cast(c) => self.transform_cast(*c),
129
130            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
131            // Trino requires SQL standard syntax for TRIM with characters
132            Expression::Trim(mut f) => {
133                if !f.sql_standard_syntax && f.characters.is_some() {
134                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
135                    f.sql_standard_syntax = true;
136                }
137                Ok(Expression::Trim(f))
138            }
139
140            // LISTAGG: Add default separator ',' if none is specified (Trino style)
141            Expression::ListAgg(mut f) => {
142                if f.separator.is_none() {
143                    f.separator = Some(Expression::Literal(Literal::String(",".to_string())));
144                }
145                Ok(Expression::ListAgg(f))
146            }
147
148            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
149            Expression::Interval(mut interval) => {
150                if interval.unit.is_none() {
151                    if let Some(Expression::Literal(Literal::String(ref s))) = interval.this {
152                        if let Some((value, unit)) = Self::parse_compound_interval(s) {
153                            interval.this = Some(Expression::Literal(Literal::String(value)));
154                            interval.unit = Some(unit);
155                        }
156                    }
157                }
158                Ok(Expression::Interval(interval))
159            }
160
161            // Pass through everything else
162            _ => Ok(expr),
163        }
164    }
165}
166
167impl TrinoDialect {
168    /// Parse a compound interval string like "1 day" into (value, unit_spec).
169    /// Returns None if the string doesn't match a known pattern.
170    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
171        let s = s.trim();
172        let parts: Vec<&str> = s.split_whitespace().collect();
173        if parts.len() != 2 {
174            return None;
175        }
176        let value = parts[0].to_string();
177        let unit = match parts[1].to_uppercase().as_str() {
178            "YEAR" | "YEARS" => IntervalUnit::Year,
179            "MONTH" | "MONTHS" => IntervalUnit::Month,
180            "DAY" | "DAYS" => IntervalUnit::Day,
181            "HOUR" | "HOURS" => IntervalUnit::Hour,
182            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
183            "SECOND" | "SECONDS" => IntervalUnit::Second,
184            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
185            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
186            _ => return None,
187        };
188        Some((value, IntervalUnitSpec::Simple { unit, use_plural: false }))
189    }
190
191    fn transform_function(&self, f: Function) -> Result<Expression> {
192        let name_upper = f.name.to_uppercase();
193        match name_upper.as_str() {
194            // IFNULL -> COALESCE
195            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
196                expressions: f.args,
197            }))),
198
199            // NVL -> COALESCE
200            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
201                expressions: f.args,
202            }))),
203
204            // ISNULL -> COALESCE
205            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
206                expressions: f.args,
207            }))),
208
209            // GETDATE -> CURRENT_TIMESTAMP
210            "GETDATE" => Ok(Expression::CurrentTimestamp(
211                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
212            )),
213
214            // NOW -> CURRENT_TIMESTAMP
215            "NOW" => Ok(Expression::CurrentTimestamp(
216                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
217            )),
218
219            // RAND -> RANDOM in Trino
220            "RAND" => Ok(Expression::Function(Box::new(Function::new(
221                "RANDOM".to_string(),
222                vec![],
223            )))),
224
225            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
226            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
227                Function::new("LISTAGG".to_string(), f.args),
228            ))),
229
230            // STRING_AGG -> LISTAGG in Trino
231            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
232                Function::new("LISTAGG".to_string(), f.args),
233            ))),
234
235            // LISTAGG is native in Trino
236            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
237
238            // SUBSTR -> SUBSTRING
239            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
240                "SUBSTRING".to_string(),
241                f.args,
242            )))),
243
244            // LEN -> LENGTH
245            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
246                f.args.into_iter().next().unwrap(),
247            )))),
248
249            // CHARINDEX -> STRPOS in Trino (with swapped args)
250            "CHARINDEX" if f.args.len() >= 2 => {
251                let mut args = f.args;
252                let substring = args.remove(0);
253                let string = args.remove(0);
254                Ok(Expression::Function(Box::new(Function::new(
255                    "STRPOS".to_string(),
256                    vec![string, substring],
257                ))))
258            }
259
260            // INSTR -> STRPOS
261            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
262                "STRPOS".to_string(),
263                f.args,
264            )))),
265
266            // LOCATE -> STRPOS in Trino (with swapped args)
267            "LOCATE" if f.args.len() >= 2 => {
268                let mut args = f.args;
269                let substring = args.remove(0);
270                let string = args.remove(0);
271                Ok(Expression::Function(Box::new(Function::new(
272                    "STRPOS".to_string(),
273                    vec![string, substring],
274                ))))
275            }
276
277            // ARRAY_LENGTH -> CARDINALITY in Trino
278            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
279                Function::new("CARDINALITY".to_string(), f.args),
280            ))),
281
282            // SIZE -> CARDINALITY in Trino
283            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
284                "CARDINALITY".to_string(),
285                f.args,
286            )))),
287
288            // ARRAY_CONTAINS -> CONTAINS in Trino
289            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
290                Function::new("CONTAINS".to_string(), f.args),
291            ))),
292
293            // TO_DATE -> CAST to DATE or DATE_PARSE
294            "TO_DATE" if !f.args.is_empty() => {
295                if f.args.len() == 1 {
296                    Ok(Expression::Cast(Box::new(Cast {
297                        this: f.args.into_iter().next().unwrap(),
298                        to: DataType::Date,
299                        trailing_comments: Vec::new(),
300                        double_colon_syntax: false,
301                        format: None,
302                        default: None,
303                    })))
304                } else {
305                    Ok(Expression::Function(Box::new(Function::new(
306                        "DATE_PARSE".to_string(),
307                        f.args,
308                    ))))
309                }
310            }
311
312            // TO_TIMESTAMP -> CAST or DATE_PARSE
313            "TO_TIMESTAMP" if !f.args.is_empty() => {
314                if f.args.len() == 1 {
315                    Ok(Expression::Cast(Box::new(Cast {
316                        this: f.args.into_iter().next().unwrap(),
317                        to: DataType::Timestamp { precision: None, timezone: false },
318                        trailing_comments: Vec::new(),
319                        double_colon_syntax: false,
320                        format: None,
321                        default: None,
322                    })))
323                } else {
324                    Ok(Expression::Function(Box::new(Function::new(
325                        "DATE_PARSE".to_string(),
326                        f.args,
327                    ))))
328                }
329            }
330
331            // strftime -> DATE_FORMAT in Trino
332            "STRFTIME" if f.args.len() >= 2 => {
333                let mut args = f.args;
334                let format = args.remove(0);
335                let date = args.remove(0);
336                Ok(Expression::Function(Box::new(Function::new(
337                    "DATE_FORMAT".to_string(),
338                    vec![date, format],
339                ))))
340            }
341
342            // TO_CHAR -> DATE_FORMAT in Trino
343            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
344                "DATE_FORMAT".to_string(),
345                f.args,
346            )))),
347
348            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
349            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
350                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
351            ))),
352
353            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
354            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
355                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
356            ))),
357
358            // COLLECT_LIST -> ARRAY_AGG
359            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
360                Function::new("ARRAY_AGG".to_string(), f.args),
361            ))),
362
363            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
364            "COLLECT_SET" if !f.args.is_empty() => {
365                let array_agg = Expression::Function(Box::new(Function::new(
366                    "ARRAY_AGG".to_string(),
367                    f.args,
368                )));
369                Ok(Expression::Function(Box::new(Function::new(
370                    "ARRAY_DISTINCT".to_string(),
371                    vec![array_agg],
372                ))))
373            }
374
375            // RLIKE -> REGEXP_LIKE in Trino
376            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
377                "REGEXP_LIKE".to_string(),
378                f.args,
379            )))),
380
381            // REGEXP -> REGEXP_LIKE in Trino
382            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
383                "REGEXP_LIKE".to_string(),
384                f.args,
385            )))),
386
387            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
388            // For simplicity, we'll use a different approach
389            "ARRAY_SUM" if f.args.len() == 1 => {
390                // This is a complex transformation in Presto/Trino
391                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
392                // For now, pass through and let user handle it
393                Ok(Expression::Function(Box::new(f)))
394            }
395
396            // Pass through everything else
397            _ => Ok(Expression::Function(Box::new(f))),
398        }
399    }
400
401    fn transform_aggregate_function(
402        &self,
403        f: Box<crate::expressions::AggregateFunction>,
404    ) -> Result<Expression> {
405        let name_upper = f.name.to_uppercase();
406        match name_upper.as_str() {
407            // COUNT_IF -> SUM(CASE WHEN...)
408            "COUNT_IF" if !f.args.is_empty() => {
409                let condition = f.args.into_iter().next().unwrap();
410                let case_expr = Expression::Case(Box::new(Case {
411                    operand: None,
412                    whens: vec![(condition, Expression::number(1))],
413                    else_: Some(Expression::number(0)),
414                }));
415                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
416                    this: case_expr,
417                    distinct: f.distinct,
418                    filter: f.filter,
419                    order_by: Vec::new(),
420                name: None,
421                limit: None,
422                })))
423            }
424
425            // ANY_VALUE -> ARBITRARY in Trino
426            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
427                "ARBITRARY".to_string(),
428                f.args,
429            )))),
430
431            // GROUP_CONCAT -> LISTAGG in Trino
432            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
433                Function::new("LISTAGG".to_string(), f.args),
434            ))),
435
436            // STRING_AGG -> LISTAGG in Trino
437            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
438                Function::new("LISTAGG".to_string(), f.args),
439            ))),
440
441            // VAR -> VAR_POP in Trino
442            "VAR" if !f.args.is_empty() => Ok(Expression::AggregateFunction(Box::new(
443                AggregateFunction {
444                    name: "VAR_POP".to_string(),
445                    args: f.args,
446                    distinct: f.distinct,
447                    filter: f.filter,
448                    order_by: Vec::new(),
449                    limit: None,
450                    ignore_nulls: None,
451                },
452            ))),
453
454            // VARIANCE -> VAR_SAMP in Trino
455            "VARIANCE" if !f.args.is_empty() => Ok(Expression::AggregateFunction(Box::new(
456                AggregateFunction {
457                    name: "VAR_SAMP".to_string(),
458                    args: f.args,
459                    distinct: f.distinct,
460                    filter: f.filter,
461                    order_by: Vec::new(),
462                    limit: None,
463                    ignore_nulls: None,
464                },
465            ))),
466
467            // Pass through everything else
468            _ => Ok(Expression::AggregateFunction(f)),
469        }
470    }
471
472    fn transform_cast(&self, c: Cast) -> Result<Expression> {
473        // Trino type mappings are handled in the generator
474        Ok(Expression::Cast(Box::new(c)))
475    }
476}