1use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14 AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
/// Marker type for the Athena SQL dialect.
///
/// The type carries no state; its tokenizer settings, generator settings and
/// expression rewrites are supplied by its [`DialectImpl`] implementation.
pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23 fn dialect_type(&self) -> DialectType {
24 DialectType::Athena
25 }
26
27 fn tokenizer_config(&self) -> TokenizerConfig {
28 let mut config = TokenizerConfig::default();
29 config.identifiers.insert('"', '"');
31 config.identifiers.insert('`', '`');
33 config.nested_comments = false;
34 config.string_escapes.push('\\');
36 config
37 }
38
39 fn generator_config(&self) -> GeneratorConfig {
40 GeneratorConfig {
42 identifier_quote: '"',
43 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44 dialect: Some(DialectType::Athena),
45 schema_comment_with_eq: false,
46 ..Default::default()
47 }
48 }
49
50 fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51 if should_use_hive_engine(expr) {
52 GeneratorConfig {
54 identifier_quote: '`',
55 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56 dialect: Some(DialectType::Athena),
57 schema_comment_with_eq: false,
58 ..Default::default()
59 }
60 } else {
61 GeneratorConfig {
63 identifier_quote: '"',
64 identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65 dialect: Some(DialectType::Athena),
66 schema_comment_with_eq: false,
67 ..Default::default()
68 }
69 }
70 }
71
72 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73 match expr {
74 Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
76 expressions: vec![f.this, f.expression],
77 }))),
78
79 Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
81 expressions: vec![f.this, f.expression],
82 }))),
83
84 Expression::Coalesce(mut f) => {
86 f.original_name = None;
87 Ok(Expression::Coalesce(f))
88 }
89
90 Expression::TryCast(c) => Ok(Expression::TryCast(c)),
92
93 Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
95
96 Expression::ILike(op) => {
98 let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
99 let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
100 Ok(Expression::Like(Box::new(LikeOp {
101 left: lower_left,
102 right: lower_right,
103 escape: op.escape,
104 quantifier: op.quantifier.clone(),
105 })))
106 }
107
108 Expression::CountIf(f) => {
110 let case_expr = Expression::Case(Box::new(Case {
111 operand: None,
112 whens: vec![(f.this.clone(), Expression::number(1))],
113 else_: Some(Expression::number(0)),
114 }));
115 Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
116 this: case_expr,
117 distinct: f.distinct,
118 filter: f.filter,
119 order_by: Vec::new(),
120 name: None,
121 limit: None,
122 })))
123 }
124
125 Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
127 crate::expressions::UnnestFunc {
128 this: f.this,
129 expressions: Vec::new(),
130 with_ordinality: false,
131 alias: None,
132 offset_alias: None,
133 },
134 ))),
135
136 Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
138 crate::expressions::UnnestFunc {
139 this: f.this,
140 expressions: Vec::new(),
141 with_ordinality: false,
142 alias: None,
143 offset_alias: None,
144 },
145 ))),
146
147 Expression::Function(f) => self.transform_function(*f),
149
150 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
152
153 Expression::Cast(c) => self.transform_cast(*c),
155
156 _ => Ok(expr),
158 }
159 }
160}
161
162impl AthenaDialect {
163 fn transform_function(&self, f: Function) -> Result<Expression> {
164 let name_upper = f.name.to_uppercase();
165 match name_upper.as_str() {
166 "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
168 expressions: f.args,
169 }))),
170
171 "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
173 expressions: f.args,
174 }))),
175
176 "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
178 expressions: f.args,
179 }))),
180
181 "GETDATE" => Ok(Expression::CurrentTimestamp(
183 crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
184 )),
185
186 "NOW" => Ok(Expression::CurrentTimestamp(
188 crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
189 )),
190
191 "RAND" => Ok(Expression::Function(Box::new(Function::new(
193 "RANDOM".to_string(),
194 vec![],
195 )))),
196
197 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
199 Function::new("LISTAGG".to_string(), f.args),
200 ))),
201
202 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
204 Function::new("LISTAGG".to_string(), f.args),
205 ))),
206
207 "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
209 "SUBSTRING".to_string(),
210 f.args,
211 )))),
212
213 "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
215 f.args.into_iter().next().unwrap(),
216 )))),
217
218 "CHARINDEX" if f.args.len() >= 2 => {
220 let mut args = f.args;
221 let substring = args.remove(0);
222 let string = args.remove(0);
223 Ok(Expression::Function(Box::new(Function::new(
224 "STRPOS".to_string(),
225 vec![string, substring],
226 ))))
227 }
228
229 "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
231 "STRPOS".to_string(),
232 f.args,
233 )))),
234
235 "LOCATE" if f.args.len() >= 2 => {
237 let mut args = f.args;
238 let substring = args.remove(0);
239 let string = args.remove(0);
240 Ok(Expression::Function(Box::new(Function::new(
241 "STRPOS".to_string(),
242 vec![string, substring],
243 ))))
244 }
245
246 "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
248 Function::new("CARDINALITY".to_string(), f.args),
249 ))),
250
251 "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
253 "CARDINALITY".to_string(),
254 f.args,
255 )))),
256
257 "TO_DATE" if !f.args.is_empty() => {
259 if f.args.len() == 1 {
260 Ok(Expression::Cast(Box::new(Cast {
261 this: f.args.into_iter().next().unwrap(),
262 to: DataType::Date,
263 trailing_comments: Vec::new(),
264 double_colon_syntax: false,
265 format: None,
266 default: None,
267 })))
268 } else {
269 Ok(Expression::Function(Box::new(Function::new(
270 "DATE_PARSE".to_string(),
271 f.args,
272 ))))
273 }
274 }
275
276 "TO_TIMESTAMP" if !f.args.is_empty() => {
278 if f.args.len() == 1 {
279 Ok(Expression::Cast(Box::new(Cast {
280 this: f.args.into_iter().next().unwrap(),
281 to: DataType::Timestamp {
282 precision: None,
283 timezone: false,
284 },
285 trailing_comments: Vec::new(),
286 double_colon_syntax: false,
287 format: None,
288 default: None,
289 })))
290 } else {
291 Ok(Expression::Function(Box::new(Function::new(
292 "DATE_PARSE".to_string(),
293 f.args,
294 ))))
295 }
296 }
297
298 "STRFTIME" if f.args.len() >= 2 => {
300 let mut args = f.args;
301 let format = args.remove(0);
302 let date = args.remove(0);
303 Ok(Expression::Function(Box::new(Function::new(
304 "DATE_FORMAT".to_string(),
305 vec![date, format],
306 ))))
307 }
308
309 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
311 "DATE_FORMAT".to_string(),
312 f.args,
313 )))),
314
315 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
317 Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
318 ))),
319
320 "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
322 Function::new("ARRAY_AGG".to_string(), f.args),
323 ))),
324
325 _ => Ok(Expression::Function(Box::new(f))),
327 }
328 }
329
330 fn transform_aggregate_function(
331 &self,
332 f: Box<crate::expressions::AggregateFunction>,
333 ) -> Result<Expression> {
334 let name_upper = f.name.to_uppercase();
335 match name_upper.as_str() {
336 "COUNT_IF" if !f.args.is_empty() => {
338 let condition = f.args.into_iter().next().unwrap();
339 let case_expr = Expression::Case(Box::new(Case {
340 operand: None,
341 whens: vec![(condition, Expression::number(1))],
342 else_: Some(Expression::number(0)),
343 }));
344 Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
345 this: case_expr,
346 distinct: f.distinct,
347 filter: f.filter,
348 order_by: Vec::new(),
349 name: None,
350 limit: None,
351 })))
352 }
353
354 "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
356 "ARBITRARY".to_string(),
357 f.args,
358 )))),
359
360 "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
362 Function::new("LISTAGG".to_string(), f.args),
363 ))),
364
365 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
367 Function::new("LISTAGG".to_string(), f.args),
368 ))),
369
370 _ => Ok(Expression::AggregateFunction(f)),
372 }
373 }
374
375 fn transform_cast(&self, c: Cast) -> Result<Expression> {
376 Ok(Expression::Cast(Box::new(c)))
378 }
379}
380
381fn should_use_hive_engine(expr: &Expression) -> bool {
394 match expr {
395 Expression::CreateTable(ct) => {
397 if let Some(ref modifier) = ct.table_modifier {
399 if modifier.to_uppercase() == "EXTERNAL" {
400 return true;
401 }
402 }
403 ct.as_select.is_none()
406 }
407
408 Expression::CreateView(_) => false,
410
411 Expression::CreateSchema(_) => true,
413 Expression::CreateDatabase(_) => true,
414
415 Expression::AlterTable(_) => true,
417 Expression::AlterView(_) => true,
418 Expression::AlterIndex(_) => true,
419 Expression::AlterSequence(_) => true,
420
421 Expression::DropView(_) => false,
423
424 Expression::DropTable(_) => true,
426 Expression::DropSchema(_) => true,
427 Expression::DropDatabase(_) => true,
428 Expression::DropIndex(_) => true,
429 Expression::DropFunction(_) => true,
430 Expression::DropProcedure(_) => true,
431 Expression::DropSequence(_) => true,
432
433 Expression::Describe(_) => true,
435 Expression::Show(_) => true,
436
437 _ => false,
439 }
440}