// polyglot_sql/dialects/trino.rs

//! Trino dialect.
//!
//! Trino-specific transformations based on sqlglot patterns.
//! Trino is largely compatible with Presto but has some differences.

6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9    AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10    IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12use crate::generator::GeneratorConfig;
13use crate::tokens::TokenizerConfig;
14
15/// Trino dialect
16pub struct TrinoDialect;
17
18impl DialectImpl for TrinoDialect {
19    fn dialect_type(&self) -> DialectType {
20        DialectType::Trino
21    }
22
23    fn tokenizer_config(&self) -> TokenizerConfig {
24        let mut config = TokenizerConfig::default();
25        // Trino uses double quotes for identifiers
26        config.identifiers.insert('"', '"');
27        // Trino does NOT support nested comments
28        config.nested_comments = false;
29        // Trino does NOT support QUALIFY - it's a valid identifier
30        // (unlike Snowflake, BigQuery, DuckDB which have QUALIFY clause)
31        config.keywords.remove("QUALIFY");
32        config
33    }
34
35    fn generator_config(&self) -> GeneratorConfig {
36        use crate::generator::IdentifierQuoteStyle;
37        GeneratorConfig {
38            identifier_quote: '"',
39            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
40            dialect: Some(DialectType::Trino),
41            limit_only_literals: true,
42            tz_to_with_time_zone: true,
43            ..Default::default()
44        }
45    }
46
47    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
48        match expr {
49            // IFNULL -> COALESCE in Trino
50            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
51                original_name: None,
52                expressions: vec![f.this, f.expression],
53            }))),
54
55            // NVL -> COALESCE in Trino
56            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
57                original_name: None,
58                expressions: vec![f.this, f.expression],
59            }))),
60
61            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
62            Expression::Coalesce(mut f) => {
63                f.original_name = None;
64                Ok(Expression::Coalesce(f))
65            }
66
67            // TryCast stays as TryCast (Trino supports TRY_CAST)
68            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
69
70            // SafeCast -> TRY_CAST in Trino
71            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
72
73            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
74            Expression::ILike(op) => {
75                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
76                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
77                Ok(Expression::Like(Box::new(LikeOp {
78                    left: lower_left,
79                    right: lower_right,
80                    escape: op.escape,
81                    quantifier: op.quantifier.clone(),
82                })))
83            }
84
85            // CountIf is native in Trino (keep as-is)
86            Expression::CountIf(f) => Ok(Expression::CountIf(f)),
87
88            // EXPLODE -> UNNEST in Trino
89            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
90                crate::expressions::UnnestFunc {
91                    this: f.this,
92                    expressions: Vec::new(),
93                    with_ordinality: false,
94                    alias: None,
95                    offset_alias: None,
96                },
97            ))),
98
99            // ExplodeOuter -> UNNEST in Trino
100            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
101                crate::expressions::UnnestFunc {
102                    this: f.this,
103                    expressions: Vec::new(),
104                    with_ordinality: false,
105                    alias: None,
106                    offset_alias: None,
107                },
108            ))),
109
110            // Generic function transformations
111            Expression::Function(f) => self.transform_function(*f),
112
113            // Generic aggregate function transformations
114            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
115
116            // Cast transformations
117            Expression::Cast(c) => self.transform_cast(*c),
118
119            // TRIM: Convert comma syntax TRIM(str, chars) to SQL standard TRIM(chars FROM str)
120            // Trino requires SQL standard syntax for TRIM with characters
121            Expression::Trim(mut f) => {
122                if !f.sql_standard_syntax && f.characters.is_some() {
123                    // Convert from TRIM(str, chars) to TRIM(chars FROM str)
124                    f.sql_standard_syntax = true;
125                }
126                Ok(Expression::Trim(f))
127            }
128
129            // LISTAGG: Add default separator ',' if none is specified (Trino style)
130            Expression::ListAgg(mut f) => {
131                if f.separator.is_none() {
132                    f.separator = Some(Expression::Literal(Literal::String(",".to_string())));
133                }
134                Ok(Expression::ListAgg(f))
135            }
136
137            // Interval: Split compound string intervals like INTERVAL '1 day' into INTERVAL '1' DAY
138            Expression::Interval(mut interval) => {
139                if interval.unit.is_none() {
140                    if let Some(Expression::Literal(Literal::String(ref s))) = interval.this {
141                        if let Some((value, unit)) = Self::parse_compound_interval(s) {
142                            interval.this = Some(Expression::Literal(Literal::String(value)));
143                            interval.unit = Some(unit);
144                        }
145                    }
146                }
147                Ok(Expression::Interval(interval))
148            }
149
150            // Pass through everything else
151            _ => Ok(expr),
152        }
153    }
154}
155
156impl TrinoDialect {
157    /// Parse a compound interval string like "1 day" into (value, unit_spec).
158    /// Returns None if the string doesn't match a known pattern.
159    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
160        let s = s.trim();
161        let parts: Vec<&str> = s.split_whitespace().collect();
162        if parts.len() != 2 {
163            return None;
164        }
165        let value = parts[0].to_string();
166        let unit = match parts[1].to_uppercase().as_str() {
167            "YEAR" | "YEARS" => IntervalUnit::Year,
168            "MONTH" | "MONTHS" => IntervalUnit::Month,
169            "DAY" | "DAYS" => IntervalUnit::Day,
170            "HOUR" | "HOURS" => IntervalUnit::Hour,
171            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
172            "SECOND" | "SECONDS" => IntervalUnit::Second,
173            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
174            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
175            _ => return None,
176        };
177        Some((
178            value,
179            IntervalUnitSpec::Simple {
180                unit,
181                use_plural: false,
182            },
183        ))
184    }
185
186    fn transform_function(&self, f: Function) -> Result<Expression> {
187        let name_upper = f.name.to_uppercase();
188        match name_upper.as_str() {
189            // IFNULL -> COALESCE
190            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
191                original_name: None,
192                expressions: f.args,
193            }))),
194
195            // NVL -> COALESCE
196            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
197                original_name: None,
198                expressions: f.args,
199            }))),
200
201            // ISNULL -> COALESCE
202            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
203                original_name: None,
204                expressions: f.args,
205            }))),
206
207            // GETDATE -> CURRENT_TIMESTAMP
208            "GETDATE" => Ok(Expression::CurrentTimestamp(
209                crate::expressions::CurrentTimestamp {
210                    precision: None,
211                    sysdate: false,
212                },
213            )),
214
215            // NOW -> CURRENT_TIMESTAMP
216            "NOW" => Ok(Expression::CurrentTimestamp(
217                crate::expressions::CurrentTimestamp {
218                    precision: None,
219                    sysdate: false,
220                },
221            )),
222
223            // RAND -> RANDOM in Trino
224            "RAND" => Ok(Expression::Function(Box::new(Function::new(
225                "RANDOM".to_string(),
226                vec![],
227            )))),
228
229            // GROUP_CONCAT -> LISTAGG in Trino (Trino supports LISTAGG)
230            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
231                Function::new("LISTAGG".to_string(), f.args),
232            ))),
233
234            // STRING_AGG -> LISTAGG in Trino
235            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
236                Function::new("LISTAGG".to_string(), f.args),
237            ))),
238
239            // LISTAGG is native in Trino
240            "LISTAGG" => Ok(Expression::Function(Box::new(f))),
241
242            // SUBSTR -> SUBSTRING
243            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
244                "SUBSTRING".to_string(),
245                f.args,
246            )))),
247
248            // LEN -> LENGTH
249            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
250                f.args.into_iter().next().unwrap(),
251            )))),
252
253            // CHARINDEX -> STRPOS in Trino (with swapped args)
254            "CHARINDEX" if f.args.len() >= 2 => {
255                let mut args = f.args;
256                let substring = args.remove(0);
257                let string = args.remove(0);
258                Ok(Expression::Function(Box::new(Function::new(
259                    "STRPOS".to_string(),
260                    vec![string, substring],
261                ))))
262            }
263
264            // INSTR -> STRPOS
265            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
266                "STRPOS".to_string(),
267                f.args,
268            )))),
269
270            // LOCATE -> STRPOS in Trino (with swapped args)
271            "LOCATE" if f.args.len() >= 2 => {
272                let mut args = f.args;
273                let substring = args.remove(0);
274                let string = args.remove(0);
275                Ok(Expression::Function(Box::new(Function::new(
276                    "STRPOS".to_string(),
277                    vec![string, substring],
278                ))))
279            }
280
281            // ARRAY_LENGTH -> CARDINALITY in Trino
282            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
283                Function::new("CARDINALITY".to_string(), f.args),
284            ))),
285
286            // SIZE -> CARDINALITY in Trino
287            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
288                "CARDINALITY".to_string(),
289                f.args,
290            )))),
291
292            // ARRAY_CONTAINS -> CONTAINS in Trino
293            "ARRAY_CONTAINS" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
294                Function::new("CONTAINS".to_string(), f.args),
295            ))),
296
297            // TO_DATE -> CAST to DATE or DATE_PARSE
298            "TO_DATE" if !f.args.is_empty() => {
299                if f.args.len() == 1 {
300                    Ok(Expression::Cast(Box::new(Cast {
301                        this: f.args.into_iter().next().unwrap(),
302                        to: DataType::Date,
303                        trailing_comments: Vec::new(),
304                        double_colon_syntax: false,
305                        format: None,
306                        default: None,
307                    })))
308                } else {
309                    Ok(Expression::Function(Box::new(Function::new(
310                        "DATE_PARSE".to_string(),
311                        f.args,
312                    ))))
313                }
314            }
315
316            // TO_TIMESTAMP -> CAST or DATE_PARSE
317            "TO_TIMESTAMP" if !f.args.is_empty() => {
318                if f.args.len() == 1 {
319                    Ok(Expression::Cast(Box::new(Cast {
320                        this: f.args.into_iter().next().unwrap(),
321                        to: DataType::Timestamp {
322                            precision: None,
323                            timezone: false,
324                        },
325                        trailing_comments: Vec::new(),
326                        double_colon_syntax: false,
327                        format: None,
328                        default: None,
329                    })))
330                } else {
331                    Ok(Expression::Function(Box::new(Function::new(
332                        "DATE_PARSE".to_string(),
333                        f.args,
334                    ))))
335                }
336            }
337
338            // strftime -> DATE_FORMAT in Trino
339            "STRFTIME" if f.args.len() >= 2 => {
340                let mut args = f.args;
341                let format = args.remove(0);
342                let date = args.remove(0);
343                Ok(Expression::Function(Box::new(Function::new(
344                    "DATE_FORMAT".to_string(),
345                    vec![date, format],
346                ))))
347            }
348
349            // TO_CHAR -> DATE_FORMAT in Trino
350            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
351                "DATE_FORMAT".to_string(),
352                f.args,
353            )))),
354
355            // LEVENSHTEIN -> LEVENSHTEIN_DISTANCE in Trino
356            "LEVENSHTEIN" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
357                Function::new("LEVENSHTEIN_DISTANCE".to_string(), f.args),
358            ))),
359
360            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Trino
361            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
362                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
363            ))),
364
365            // COLLECT_LIST -> ARRAY_AGG
366            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
367                Function::new("ARRAY_AGG".to_string(), f.args),
368            ))),
369
370            // COLLECT_SET -> ARRAY_DISTINCT(ARRAY_AGG())
371            "COLLECT_SET" if !f.args.is_empty() => {
372                let array_agg =
373                    Expression::Function(Box::new(Function::new("ARRAY_AGG".to_string(), f.args)));
374                Ok(Expression::Function(Box::new(Function::new(
375                    "ARRAY_DISTINCT".to_string(),
376                    vec![array_agg],
377                ))))
378            }
379
380            // RLIKE -> REGEXP_LIKE in Trino
381            "RLIKE" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
382                "REGEXP_LIKE".to_string(),
383                f.args,
384            )))),
385
386            // REGEXP -> REGEXP_LIKE in Trino
387            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
388                "REGEXP_LIKE".to_string(),
389                f.args,
390            )))),
391
392            // ARRAY_SUM -> REDUCE in Trino (complex transformation)
393            // For simplicity, we'll use a different approach
394            "ARRAY_SUM" if f.args.len() == 1 => {
395                // This is a complex transformation in Presto/Trino
396                // ARRAY_SUM(arr) -> REDUCE(arr, 0, (s, x) -> s + x, s -> s)
397                // For now, pass through and let user handle it
398                Ok(Expression::Function(Box::new(f)))
399            }
400
401            // Pass through everything else
402            _ => Ok(Expression::Function(Box::new(f))),
403        }
404    }
405
406    fn transform_aggregate_function(
407        &self,
408        f: Box<crate::expressions::AggregateFunction>,
409    ) -> Result<Expression> {
410        let name_upper = f.name.to_uppercase();
411        match name_upper.as_str() {
412            // COUNT_IF -> SUM(CASE WHEN...)
413            "COUNT_IF" if !f.args.is_empty() => {
414                let condition = f.args.into_iter().next().unwrap();
415                let case_expr = Expression::Case(Box::new(Case {
416                    operand: None,
417                    whens: vec![(condition, Expression::number(1))],
418                    else_: Some(Expression::number(0)),
419                    comments: Vec::new(),
420                }));
421                Ok(Expression::Sum(Box::new(AggFunc {
422                    ignore_nulls: None,
423                    having_max: None,
424                    this: case_expr,
425                    distinct: f.distinct,
426                    filter: f.filter,
427                    order_by: Vec::new(),
428                    name: None,
429                    limit: None,
430                })))
431            }
432
433            // ANY_VALUE -> ARBITRARY in Trino
434            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
435                "ARBITRARY".to_string(),
436                f.args,
437            )))),
438
439            // GROUP_CONCAT -> LISTAGG in Trino
440            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
441                Function::new("LISTAGG".to_string(), f.args),
442            ))),
443
444            // STRING_AGG -> LISTAGG in Trino
445            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
446                Function::new("LISTAGG".to_string(), f.args),
447            ))),
448
449            // VAR -> VAR_POP in Trino
450            "VAR" if !f.args.is_empty() => {
451                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
452                    name: "VAR_POP".to_string(),
453                    args: f.args,
454                    distinct: f.distinct,
455                    filter: f.filter,
456                    order_by: Vec::new(),
457                    limit: None,
458                    ignore_nulls: None,
459                })))
460            }
461
462            // VARIANCE -> VAR_SAMP in Trino
463            "VARIANCE" if !f.args.is_empty() => {
464                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
465                    name: "VAR_SAMP".to_string(),
466                    args: f.args,
467                    distinct: f.distinct,
468                    filter: f.filter,
469                    order_by: Vec::new(),
470                    limit: None,
471                    ignore_nulls: None,
472                })))
473            }
474
475            // Pass through everything else
476            _ => Ok(Expression::AggregateFunction(f)),
477        }
478    }
479
480    fn transform_cast(&self, c: Cast) -> Result<Expression> {
481        // Trino type mappings are handled in the generator
482        Ok(Expression::Cast(Box::new(c)))
483    }
484}