1use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14 AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
19pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23 fn dialect_type(&self) -> DialectType {
24 DialectType::Athena
25 }
26
27 fn tokenizer_config(&self) -> TokenizerConfig {
28 let mut config = TokenizerConfig::default();
29 config.identifiers.insert('"', '"');
31 config.identifiers.insert('`', '`');
33 config.nested_comments = false;
34 config.string_escapes.push('\\');
36 config
37 }
38
39 fn generator_config(&self) -> GeneratorConfig {
40 GeneratorConfig {
42 identifier_quote: '"',
43 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44 dialect: Some(DialectType::Athena),
45 schema_comment_with_eq: false,
46 ..Default::default()
47 }
48 }
49
50 fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51 if should_use_hive_engine(expr) {
52 GeneratorConfig {
54 identifier_quote: '`',
55 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56 dialect: Some(DialectType::Athena),
57 schema_comment_with_eq: false,
58 ..Default::default()
59 }
60 } else {
61 GeneratorConfig {
63 identifier_quote: '"',
64 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65 dialect: Some(DialectType::Athena),
66 schema_comment_with_eq: false,
67 ..Default::default()
68 }
69 }
70 }
71
72 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73 match expr {
74 Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
76 original_name: None,
77 expressions: vec![f.this, f.expression],
78 inferred_type: None,
79 }))),
80
81 Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
83 original_name: None,
84 expressions: vec![f.this, f.expression],
85 inferred_type: None,
86 }))),
87
88 Expression::Coalesce(mut f) => {
90 f.original_name = None;
91 Ok(Expression::Coalesce(f))
92 }
93
94 Expression::TryCast(c) => Ok(Expression::TryCast(c)),
96
97 Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
99
100 Expression::ILike(op) => {
102 let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
103 let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
104 Ok(Expression::Like(Box::new(LikeOp {
105 left: lower_left,
106 right: lower_right,
107 escape: op.escape,
108 quantifier: op.quantifier.clone(),
109 inferred_type: None,
110 })))
111 }
112
113 Expression::CountIf(f) => {
115 let case_expr = Expression::Case(Box::new(Case {
116 operand: None,
117 whens: vec![(f.this.clone(), Expression::number(1))],
118 else_: Some(Expression::number(0)),
119 comments: Vec::new(),
120 inferred_type: None,
121 }));
122 Ok(Expression::Sum(Box::new(AggFunc {
123 ignore_nulls: None,
124 having_max: None,
125 this: case_expr,
126 distinct: f.distinct,
127 filter: f.filter,
128 order_by: Vec::new(),
129 name: None,
130 limit: None,
131 inferred_type: None,
132 })))
133 }
134
135 Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
137 crate::expressions::UnnestFunc {
138 this: f.this,
139 expressions: Vec::new(),
140 with_ordinality: false,
141 alias: None,
142 offset_alias: None,
143 },
144 ))),
145
146 Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
148 crate::expressions::UnnestFunc {
149 this: f.this,
150 expressions: Vec::new(),
151 with_ordinality: false,
152 alias: None,
153 offset_alias: None,
154 },
155 ))),
156
157 Expression::Function(f) => self.transform_function(*f),
159
160 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
162
163 Expression::Cast(c) => self.transform_cast(*c),
165
166 _ => Ok(expr),
168 }
169 }
170}
171
172impl AthenaDialect {
173 fn transform_function(&self, f: Function) -> Result<Expression> {
174 let name_upper = f.name.to_uppercase();
175 match name_upper.as_str() {
176 "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
178 original_name: None,
179 expressions: f.args,
180 inferred_type: None,
181 }))),
182
183 "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
185 original_name: None,
186 expressions: f.args,
187 inferred_type: None,
188 }))),
189
190 "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
192 original_name: None,
193 expressions: f.args,
194 inferred_type: None,
195 }))),
196
197 "GETDATE" => Ok(Expression::CurrentTimestamp(
199 crate::expressions::CurrentTimestamp {
200 precision: None,
201 sysdate: false,
202 },
203 )),
204
205 "NOW" => Ok(Expression::CurrentTimestamp(
207 crate::expressions::CurrentTimestamp {
208 precision: None,
209 sysdate: false,
210 },
211 )),
212
213 "RAND" => Ok(Expression::Function(Box::new(Function::new(
215 "RANDOM".to_string(),
216 vec![],
217 )))),
218
219 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
221 Function::new("LISTAGG".to_string(), f.args),
222 ))),
223
224 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
226 Function::new("LISTAGG".to_string(), f.args),
227 ))),
228
229 "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
231 "SUBSTRING".to_string(),
232 f.args,
233 )))),
234
235 "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
237 f.args.into_iter().next().unwrap(),
238 )))),
239
240 "CHARINDEX" if f.args.len() >= 2 => {
242 let mut args = f.args;
243 let substring = args.remove(0);
244 let string = args.remove(0);
245 Ok(Expression::Function(Box::new(Function::new(
246 "STRPOS".to_string(),
247 vec![string, substring],
248 ))))
249 }
250
251 "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
253 "STRPOS".to_string(),
254 f.args,
255 )))),
256
257 "LOCATE" if f.args.len() >= 2 => {
259 let mut args = f.args;
260 let substring = args.remove(0);
261 let string = args.remove(0);
262 Ok(Expression::Function(Box::new(Function::new(
263 "STRPOS".to_string(),
264 vec![string, substring],
265 ))))
266 }
267
268 "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
270 Function::new("CARDINALITY".to_string(), f.args),
271 ))),
272
273 "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
275 "CARDINALITY".to_string(),
276 f.args,
277 )))),
278
279 "TO_DATE" if !f.args.is_empty() => {
281 if f.args.len() == 1 {
282 Ok(Expression::Cast(Box::new(Cast {
283 this: f.args.into_iter().next().unwrap(),
284 to: DataType::Date,
285 trailing_comments: Vec::new(),
286 double_colon_syntax: false,
287 format: None,
288 default: None,
289 inferred_type: None,
290 })))
291 } else {
292 Ok(Expression::Function(Box::new(Function::new(
293 "DATE_PARSE".to_string(),
294 f.args,
295 ))))
296 }
297 }
298
299 "TO_TIMESTAMP" if !f.args.is_empty() => {
301 if f.args.len() == 1 {
302 Ok(Expression::Cast(Box::new(Cast {
303 this: f.args.into_iter().next().unwrap(),
304 to: DataType::Timestamp {
305 precision: None,
306 timezone: false,
307 },
308 trailing_comments: Vec::new(),
309 double_colon_syntax: false,
310 format: None,
311 default: None,
312 inferred_type: None,
313 })))
314 } else {
315 Ok(Expression::Function(Box::new(Function::new(
316 "DATE_PARSE".to_string(),
317 f.args,
318 ))))
319 }
320 }
321
322 "STRFTIME" if f.args.len() >= 2 => {
324 let mut args = f.args;
325 let format = args.remove(0);
326 let date = args.remove(0);
327 Ok(Expression::Function(Box::new(Function::new(
328 "DATE_FORMAT".to_string(),
329 vec![date, format],
330 ))))
331 }
332
333 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
335 "DATE_FORMAT".to_string(),
336 f.args,
337 )))),
338
339 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
341 Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
342 ))),
343
344 "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
346 Function::new("ARRAY_AGG".to_string(), f.args),
347 ))),
348
349 _ => Ok(Expression::Function(Box::new(f))),
351 }
352 }
353
354 fn transform_aggregate_function(
355 &self,
356 f: Box<crate::expressions::AggregateFunction>,
357 ) -> Result<Expression> {
358 let name_upper = f.name.to_uppercase();
359 match name_upper.as_str() {
360 "COUNT_IF" if !f.args.is_empty() => {
362 let condition = f.args.into_iter().next().unwrap();
363 let case_expr = Expression::Case(Box::new(Case {
364 operand: None,
365 whens: vec![(condition, Expression::number(1))],
366 else_: Some(Expression::number(0)),
367 comments: Vec::new(),
368 inferred_type: None,
369 }));
370 Ok(Expression::Sum(Box::new(AggFunc {
371 ignore_nulls: None,
372 having_max: None,
373 this: case_expr,
374 distinct: f.distinct,
375 filter: f.filter,
376 order_by: Vec::new(),
377 name: None,
378 limit: None,
379 inferred_type: None,
380 })))
381 }
382
383 "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
385 "ARBITRARY".to_string(),
386 f.args,
387 )))),
388
389 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
391 Function::new("LISTAGG".to_string(), f.args),
392 ))),
393
394 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
396 Function::new("LISTAGG".to_string(), f.args),
397 ))),
398
399 _ => Ok(Expression::AggregateFunction(f)),
401 }
402 }
403
404 fn transform_cast(&self, c: Cast) -> Result<Expression> {
405 Ok(Expression::Cast(Box::new(c)))
407 }
408}
409
410fn should_use_hive_engine(expr: &Expression) -> bool {
423 match expr {
424 Expression::CreateTable(ct) => {
426 if let Some(ref modifier) = ct.table_modifier {
428 if modifier.to_uppercase() == "EXTERNAL" {
429 return true;
430 }
431 }
432 ct.as_select.is_none()
435 }
436
437 Expression::CreateView(_) => false,
439
440 Expression::CreateSchema(_) => true,
442 Expression::CreateDatabase(_) => true,
443
444 Expression::AlterTable(_) => true,
446 Expression::AlterView(_) => true,
447 Expression::AlterIndex(_) => true,
448 Expression::AlterSequence(_) => true,
449
450 Expression::DropView(_) => false,
452
453 Expression::DropTable(_) => true,
455 Expression::DropSchema(_) => true,
456 Expression::DropDatabase(_) => true,
457 Expression::DropIndex(_) => true,
458 Expression::DropFunction(_) => true,
459 Expression::DropProcedure(_) => true,
460 Expression::DropSequence(_) => true,
461
462 Expression::Describe(_) => true,
464 Expression::Show(_) => true,
465
466 _ => false,
468 }
469}