1use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{
9 AggFunc, AggregateFunction, Case, Cast, DataType, Expression, Function, IntervalUnit,
10 IntervalUnitSpec, LikeOp, Literal, UnaryFunc, VarArgFunc,
11};
12#[cfg(feature = "generate")]
13use crate::generator::GeneratorConfig;
14use crate::tokens::TokenizerConfig;
15
/// SQL dialect implementation for Trino.
///
/// Supplies Trino's tokenizer/generator configuration and, with the
/// `transpile` feature, rewrites expressions from other dialects into forms
/// Trino accepts.
pub struct TrinoDialect;
18
19impl DialectImpl for TrinoDialect {
20 fn dialect_type(&self) -> DialectType {
21 DialectType::Trino
22 }
23
24 fn tokenizer_config(&self) -> TokenizerConfig {
25 let mut config = TokenizerConfig::default();
26 config.identifiers.insert('"', '"');
28 config.nested_comments = false;
30 config.keywords.remove("QUALIFY");
33 config
34 }
35
36 #[cfg(feature = "generate")]
37
38 fn generator_config(&self) -> GeneratorConfig {
39 use crate::generator::IdentifierQuoteStyle;
40 GeneratorConfig {
41 identifier_quote: '"',
42 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
43 dialect: Some(DialectType::Trino),
44 limit_only_literals: true,
45 tz_to_with_time_zone: true,
46 ..Default::default()
47 }
48 }
49
50 #[cfg(feature = "transpile")]
51
52 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
53 match expr {
54 Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
56 original_name: None,
57 expressions: vec![f.this, f.expression],
58 inferred_type: None,
59 }))),
60
61 Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
63 original_name: None,
64 expressions: vec![f.this, f.expression],
65 inferred_type: None,
66 }))),
67
68 Expression::Coalesce(mut f) => {
70 f.original_name = None;
71 Ok(Expression::Coalesce(f))
72 }
73
74 Expression::TryCast(c) => Ok(Expression::TryCast(c)),
76
77 Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
79
80 Expression::ILike(op) => {
82 let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
83 let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
84 Ok(Expression::Like(Box::new(LikeOp {
85 left: lower_left,
86 right: lower_right,
87 escape: op.escape,
88 quantifier: op.quantifier.clone(),
89 inferred_type: None,
90 })))
91 }
92
93 Expression::CountIf(f) => Ok(Expression::CountIf(f)),
95
96 Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
98 crate::expressions::UnnestFunc {
99 this: f.this,
100 expressions: Vec::new(),
101 with_ordinality: false,
102 alias: None,
103 offset_alias: None,
104 },
105 ))),
106
107 Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
109 crate::expressions::UnnestFunc {
110 this: f.this,
111 expressions: Vec::new(),
112 with_ordinality: false,
113 alias: None,
114 offset_alias: None,
115 },
116 ))),
117
118 Expression::Function(f) => self.transform_function(*f),
120
121 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
123
124 Expression::Cast(c) => self.transform_cast(*c),
126
127 Expression::Trim(mut f) => {
130 if !f.sql_standard_syntax && f.characters.is_some() {
131 f.sql_standard_syntax = true;
133 }
134 Ok(Expression::Trim(f))
135 }
136
137 Expression::ListAgg(mut f) => {
139 if f.separator.is_none() {
140 f.separator = Some(Expression::Literal(Box::new(Literal::String(
141 ",".to_string(),
142 ))));
143 }
144 Ok(Expression::ListAgg(f))
145 }
146
147 Expression::Interval(mut interval) => {
149 if interval.unit.is_none() {
150 if let Some(Expression::Literal(ref lit)) = interval.this {
151 if let Literal::String(ref s) = lit.as_ref() {
152 if let Some((value, unit)) = Self::parse_compound_interval(s) {
153 interval.this =
154 Some(Expression::Literal(Box::new(Literal::String(value))));
155 interval.unit = Some(unit);
156 }
157 }
158 }
159 }
160 Ok(Expression::Interval(interval))
161 }
162
163 _ => Ok(expr),
165 }
166 }
167}
168
#[cfg(feature = "transpile")]
impl TrinoDialect {
    /// Splits a unit-less interval string such as `"1 day"` into its value
    /// (`"1"`) and a simple interval unit (`DAY`).
    ///
    /// Returns `None` unless the string holds exactly two whitespace-separated
    /// tokens and the second one is a recognised unit name (singular or plural,
    /// matched case-insensitively).
    fn parse_compound_interval(s: &str) -> Option<(String, IntervalUnitSpec)> {
        // `split_whitespace` already ignores leading and trailing whitespace.
        let parts: Vec<&str> = s.split_whitespace().collect();
        let (value, unit_name) = match parts.as_slice() {
            [value, unit_name] => (*value, *unit_name),
            _ => return None,
        };
        let unit = match unit_name.to_uppercase().as_str() {
            "YEAR" | "YEARS" => IntervalUnit::Year,
            "MONTH" | "MONTHS" => IntervalUnit::Month,
            "DAY" | "DAYS" => IntervalUnit::Day,
            "HOUR" | "HOURS" => IntervalUnit::Hour,
            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
            "SECOND" | "SECONDS" => IntervalUnit::Second,
            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
            _ => return None,
        };
        Some((
            value.to_string(),
            IntervalUnitSpec::Simple {
                unit,
                use_plural: false,
            },
        ))
    }

    /// Builds a plain function call `name(args...)`.
    fn function_call(name: &str, args: Vec<Expression>) -> Expression {
        Expression::Function(Box::new(Function::new(name.to_string(), args)))
    }

    /// Builds `name(second, first)` from the first two of `args`, dropping any
    /// further arguments. Callers must guarantee `args.len() >= 2`.
    fn swapped_pair_call(name: &str, args: Vec<Expression>) -> Expression {
        let mut args = args.into_iter();
        let first = args.next().expect("caller guarantees at least two arguments");
        let second = args.next().expect("caller guarantees at least two arguments");
        Self::function_call(name, vec![second, first])
    }

    /// With a single argument, builds `CAST(arg AS to)`; with more (a value
    /// plus a format), builds `DATE_PARSE(args...)` with the format string
    /// passed through unconverted.
    fn cast_or_date_parse(args: Vec<Expression>, to: DataType) -> Expression {
        if args.len() == 1 {
            let this = args.into_iter().next().expect("length checked above");
            Expression::Cast(Box::new(Cast {
                this,
                to,
                trailing_comments: Vec::new(),
                double_colon_syntax: false,
                format: None,
                default: None,
                inferred_type: None,
            }))
        } else {
            Self::function_call("DATE_PARSE", args)
        }
    }

    /// Re-emits an aggregate call under `name`, keeping its arguments, the
    /// DISTINCT flag and the FILTER clause; ORDER BY, LIMIT and IGNORE NULLS
    /// are cleared.
    fn renamed_aggregate(f: Box<AggregateFunction>, name: &str) -> Expression {
        Expression::AggregateFunction(Box::new(AggregateFunction {
            name: name.to_string(),
            args: f.args,
            distinct: f.distinct,
            filter: f.filter,
            order_by: Vec::new(),
            limit: None,
            ignore_nulls: None,
            inferred_type: None,
        }))
    }

    /// Maps scalar function names from other dialects onto Trino equivalents.
    ///
    /// Matching is case-insensitive. A call whose name is unknown, or whose
    /// argument count does not satisfy the mapping's guard, is returned
    /// unchanged.
    fn transform_function(&self, f: Function) -> Result<Expression> {
        let name_upper = f.name.to_uppercase();
        let expr = match name_upper.as_str() {
            "IFNULL" | "NVL" | "ISNULL" if f.args.len() == 2 => {
                Expression::Coalesce(Box::new(VarArgFunc {
                    original_name: None,
                    expressions: f.args,
                    inferred_type: None,
                }))
            }

            "GETDATE" | "NOW" => {
                Expression::CurrentTimestamp(crate::expressions::CurrentTimestamp {
                    precision: None,
                    sysdate: false,
                })
            }

            // Any seed argument is discarded.
            "RAND" => Self::function_call("RANDOM", Vec::new()),

            "GROUP_CONCAT" | "STRING_AGG" if !f.args.is_empty() => {
                Self::function_call("LISTAGG", f.args)
            }

            "SUBSTR" => Self::function_call("SUBSTRING", f.args),

            "LEN" if f.args.len() == 1 => Expression::Length(Box::new(UnaryFunc::new(
                f.args
                    .into_iter()
                    .next()
                    .expect("guard ensures exactly one argument"),
            ))),

            // CHARINDEX/LOCATE take (needle, haystack[, start]); STRPOS takes
            // (haystack, needle). An optional start position is dropped.
            "CHARINDEX" | "LOCATE" if f.args.len() >= 2 => {
                Self::swapped_pair_call("STRPOS", f.args)
            }

            // NOTE(review): any third INSTR argument is forwarded unchanged as
            // STRPOS's third argument — confirm the semantics line up.
            "INSTR" if f.args.len() >= 2 => Self::function_call("STRPOS", f.args),

            "ARRAY_LENGTH" | "SIZE" if f.args.len() == 1 => {
                Self::function_call("CARDINALITY", f.args)
            }

            "ARRAY_CONTAINS" if f.args.len() == 2 => Self::function_call("CONTAINS", f.args),

            "TO_DATE" if !f.args.is_empty() => Self::cast_or_date_parse(f.args, DataType::Date),

            "TO_TIMESTAMP" if !f.args.is_empty() => Self::cast_or_date_parse(
                f.args,
                DataType::Timestamp {
                    precision: None,
                    timezone: false,
                },
            ),

            // Assumes (format, value) argument order, as in SQLite.
            // NOTE(review): DuckDB's strftime takes (value, format) — confirm.
            "STRFTIME" if f.args.len() >= 2 => Self::swapped_pair_call("DATE_FORMAT", f.args),

            "TO_CHAR" if f.args.len() >= 2 => Self::function_call("DATE_FORMAT", f.args),

            "LEVENSHTEIN" if !f.args.is_empty() => {
                Self::function_call("LEVENSHTEIN_DISTANCE", f.args)
            }

            "GET_JSON_OBJECT" if f.args.len() == 2 => {
                Self::function_call("JSON_EXTRACT_SCALAR", f.args)
            }

            "COLLECT_LIST" if !f.args.is_empty() => Self::function_call("ARRAY_AGG", f.args),

            // COLLECT_SET(x) -> ARRAY_DISTINCT(ARRAY_AGG(x))
            "COLLECT_SET" if !f.args.is_empty() => Self::function_call(
                "ARRAY_DISTINCT",
                vec![Self::function_call("ARRAY_AGG", f.args)],
            ),

            "RLIKE" | "REGEXP" if f.args.len() == 2 => Self::function_call("REGEXP_LIKE", f.args),

            _ => Expression::Function(Box::new(f)),
        };
        Ok(expr)
    }

    /// Maps aggregate function names from other dialects onto Trino
    /// equivalents (case-insensitive); anything else is returned unchanged.
    ///
    /// COUNT_IF is deliberately left to the default arm: Trino supports it
    /// natively, whereas a `SUM(CASE WHEN c THEN 1 ELSE 0 END)` rewrite would
    /// yield NULL instead of 0 over empty input.
    fn transform_aggregate_function(&self, f: Box<AggregateFunction>) -> Result<Expression> {
        let name_upper = f.name.to_uppercase();
        let expr = match name_upper.as_str() {
            // Kept as an aggregate so DISTINCT and FILTER survive the rename.
            "ANY_VALUE" if !f.args.is_empty() => Self::renamed_aggregate(f, "ARBITRARY"),

            "GROUP_CONCAT" | "STRING_AGG" if !f.args.is_empty() => {
                Self::function_call("LISTAGG", f.args)
            }

            // VAR (T-SQL) and VARIANCE are *sample* variance.
            "VAR" | "VARIANCE" if !f.args.is_empty() => Self::renamed_aggregate(f, "VAR_SAMP"),

            // VARP (T-SQL) is population variance.
            "VARP" if !f.args.is_empty() => Self::renamed_aggregate(f, "VAR_POP"),

            _ => Expression::AggregateFunction(f),
        };
        Ok(expr)
    }

    /// Casts currently need no Trino-specific rewriting and are returned as-is.
    fn transform_cast(&self, c: Cast) -> Result<Expression> {
        Ok(Expression::Cast(Box::new(c)))
    }
}