1use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14 AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
/// SQL dialect implementation for Athena.
///
/// The type carries no state. Its `DialectImpl` implementation reports
/// `DialectType::Athena`, supplies tokenizer and generator settings, and
/// rewrites expressions from other dialects into the forms this dialect emits.
pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23 fn dialect_type(&self) -> DialectType {
24 DialectType::Athena
25 }
26
27 fn tokenizer_config(&self) -> TokenizerConfig {
28 let mut config = TokenizerConfig::default();
29 config.identifiers.insert('"', '"');
31 config.identifiers.insert('`', '`');
33 config.nested_comments = false;
34 config.string_escapes.push('\\');
36 config
37 }
38
39 fn generator_config(&self) -> GeneratorConfig {
40 GeneratorConfig {
42 identifier_quote: '"',
43 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44 dialect: Some(DialectType::Athena),
45 schema_comment_with_eq: false,
46 ..Default::default()
47 }
48 }
49
50 fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51 if should_use_hive_engine(expr) {
52 GeneratorConfig {
54 identifier_quote: '`',
55 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56 dialect: Some(DialectType::Athena),
57 schema_comment_with_eq: false,
58 ..Default::default()
59 }
60 } else {
61 GeneratorConfig {
63 identifier_quote: '"',
64 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65 dialect: Some(DialectType::Athena),
66 schema_comment_with_eq: false,
67 ..Default::default()
68 }
69 }
70 }
71
72 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73 match expr {
74 Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
76 original_name: None,
77 expressions: vec![f.this, f.expression],
78 }))),
79
80 Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
82 original_name: None,
83 expressions: vec![f.this, f.expression],
84 }))),
85
86 Expression::Coalesce(mut f) => {
88 f.original_name = None;
89 Ok(Expression::Coalesce(f))
90 }
91
92 Expression::TryCast(c) => Ok(Expression::TryCast(c)),
94
95 Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
97
98 Expression::ILike(op) => {
100 let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
101 let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
102 Ok(Expression::Like(Box::new(LikeOp {
103 left: lower_left,
104 right: lower_right,
105 escape: op.escape,
106 quantifier: op.quantifier.clone(),
107 })))
108 }
109
110 Expression::CountIf(f) => {
112 let case_expr = Expression::Case(Box::new(Case {
113 operand: None,
114 whens: vec![(f.this.clone(), Expression::number(1))],
115 else_: Some(Expression::number(0)),
116 comments: Vec::new(),
117 }));
118 Ok(Expression::Sum(Box::new(AggFunc {
119 ignore_nulls: None,
120 having_max: None,
121 this: case_expr,
122 distinct: f.distinct,
123 filter: f.filter,
124 order_by: Vec::new(),
125 name: None,
126 limit: None,
127 })))
128 }
129
130 Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
132 crate::expressions::UnnestFunc {
133 this: f.this,
134 expressions: Vec::new(),
135 with_ordinality: false,
136 alias: None,
137 offset_alias: None,
138 },
139 ))),
140
141 Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
143 crate::expressions::UnnestFunc {
144 this: f.this,
145 expressions: Vec::new(),
146 with_ordinality: false,
147 alias: None,
148 offset_alias: None,
149 },
150 ))),
151
152 Expression::Function(f) => self.transform_function(*f),
154
155 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
157
158 Expression::Cast(c) => self.transform_cast(*c),
160
161 _ => Ok(expr),
163 }
164 }
165}
166
167impl AthenaDialect {
168 fn transform_function(&self, f: Function) -> Result<Expression> {
169 let name_upper = f.name.to_uppercase();
170 match name_upper.as_str() {
171 "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
173 original_name: None,
174 expressions: f.args,
175 }))),
176
177 "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
179 original_name: None,
180 expressions: f.args,
181 }))),
182
183 "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
185 original_name: None,
186 expressions: f.args,
187 }))),
188
189 "GETDATE" => Ok(Expression::CurrentTimestamp(
191 crate::expressions::CurrentTimestamp {
192 precision: None,
193 sysdate: false,
194 },
195 )),
196
197 "NOW" => Ok(Expression::CurrentTimestamp(
199 crate::expressions::CurrentTimestamp {
200 precision: None,
201 sysdate: false,
202 },
203 )),
204
205 "RAND" => Ok(Expression::Function(Box::new(Function::new(
207 "RANDOM".to_string(),
208 vec![],
209 )))),
210
211 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
213 Function::new("LISTAGG".to_string(), f.args),
214 ))),
215
216 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
218 Function::new("LISTAGG".to_string(), f.args),
219 ))),
220
221 "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
223 "SUBSTRING".to_string(),
224 f.args,
225 )))),
226
227 "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
229 f.args.into_iter().next().unwrap(),
230 )))),
231
232 "CHARINDEX" if f.args.len() >= 2 => {
234 let mut args = f.args;
235 let substring = args.remove(0);
236 let string = args.remove(0);
237 Ok(Expression::Function(Box::new(Function::new(
238 "STRPOS".to_string(),
239 vec![string, substring],
240 ))))
241 }
242
243 "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
245 "STRPOS".to_string(),
246 f.args,
247 )))),
248
249 "LOCATE" if f.args.len() >= 2 => {
251 let mut args = f.args;
252 let substring = args.remove(0);
253 let string = args.remove(0);
254 Ok(Expression::Function(Box::new(Function::new(
255 "STRPOS".to_string(),
256 vec![string, substring],
257 ))))
258 }
259
260 "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
262 Function::new("CARDINALITY".to_string(), f.args),
263 ))),
264
265 "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
267 "CARDINALITY".to_string(),
268 f.args,
269 )))),
270
271 "TO_DATE" if !f.args.is_empty() => {
273 if f.args.len() == 1 {
274 Ok(Expression::Cast(Box::new(Cast {
275 this: f.args.into_iter().next().unwrap(),
276 to: DataType::Date,
277 trailing_comments: Vec::new(),
278 double_colon_syntax: false,
279 format: None,
280 default: None,
281 })))
282 } else {
283 Ok(Expression::Function(Box::new(Function::new(
284 "DATE_PARSE".to_string(),
285 f.args,
286 ))))
287 }
288 }
289
290 "TO_TIMESTAMP" if !f.args.is_empty() => {
292 if f.args.len() == 1 {
293 Ok(Expression::Cast(Box::new(Cast {
294 this: f.args.into_iter().next().unwrap(),
295 to: DataType::Timestamp {
296 precision: None,
297 timezone: false,
298 },
299 trailing_comments: Vec::new(),
300 double_colon_syntax: false,
301 format: None,
302 default: None,
303 })))
304 } else {
305 Ok(Expression::Function(Box::new(Function::new(
306 "DATE_PARSE".to_string(),
307 f.args,
308 ))))
309 }
310 }
311
312 "STRFTIME" if f.args.len() >= 2 => {
314 let mut args = f.args;
315 let format = args.remove(0);
316 let date = args.remove(0);
317 Ok(Expression::Function(Box::new(Function::new(
318 "DATE_FORMAT".to_string(),
319 vec![date, format],
320 ))))
321 }
322
323 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
325 "DATE_FORMAT".to_string(),
326 f.args,
327 )))),
328
329 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
331 Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
332 ))),
333
334 "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
336 Function::new("ARRAY_AGG".to_string(), f.args),
337 ))),
338
339 _ => Ok(Expression::Function(Box::new(f))),
341 }
342 }
343
344 fn transform_aggregate_function(
345 &self,
346 f: Box<crate::expressions::AggregateFunction>,
347 ) -> Result<Expression> {
348 let name_upper = f.name.to_uppercase();
349 match name_upper.as_str() {
350 "COUNT_IF" if !f.args.is_empty() => {
352 let condition = f.args.into_iter().next().unwrap();
353 let case_expr = Expression::Case(Box::new(Case {
354 operand: None,
355 whens: vec![(condition, Expression::number(1))],
356 else_: Some(Expression::number(0)),
357 comments: Vec::new(),
358 }));
359 Ok(Expression::Sum(Box::new(AggFunc {
360 ignore_nulls: None,
361 having_max: None,
362 this: case_expr,
363 distinct: f.distinct,
364 filter: f.filter,
365 order_by: Vec::new(),
366 name: None,
367 limit: None,
368 })))
369 }
370
371 "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
373 "ARBITRARY".to_string(),
374 f.args,
375 )))),
376
377 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
379 Function::new("LISTAGG".to_string(), f.args),
380 ))),
381
382 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
384 Function::new("LISTAGG".to_string(), f.args),
385 ))),
386
387 _ => Ok(Expression::AggregateFunction(f)),
389 }
390 }
391
392 fn transform_cast(&self, c: Cast) -> Result<Expression> {
393 Ok(Expression::Cast(Box::new(c)))
395 }
396}
397
398fn should_use_hive_engine(expr: &Expression) -> bool {
411 match expr {
412 Expression::CreateTable(ct) => {
414 if let Some(ref modifier) = ct.table_modifier {
416 if modifier.to_uppercase() == "EXTERNAL" {
417 return true;
418 }
419 }
420 ct.as_select.is_none()
423 }
424
425 Expression::CreateView(_) => false,
427
428 Expression::CreateSchema(_) => true,
430 Expression::CreateDatabase(_) => true,
431
432 Expression::AlterTable(_) => true,
434 Expression::AlterView(_) => true,
435 Expression::AlterIndex(_) => true,
436 Expression::AlterSequence(_) => true,
437
438 Expression::DropView(_) => false,
440
441 Expression::DropTable(_) => true,
443 Expression::DropSchema(_) => true,
444 Expression::DropDatabase(_) => true,
445 Expression::DropIndex(_) => true,
446 Expression::DropFunction(_) => true,
447 Expression::DropProcedure(_) => true,
448 Expression::DropSequence(_) => true,
449
450 Expression::Describe(_) => true,
452 Expression::Show(_) => true,
453
454 _ => false,
456 }
457}