Skip to main content

polyglot_sql/dialects/
athena.rs

//! Athena Dialect
//!
//! AWS Athena-specific transformations based on sqlglot patterns.
//! Athena routes statements between the Hive engine (DDL) and the Trino engine (DML):
//!
//! - **Hive** (backticks): CREATE EXTERNAL TABLE, CREATE TABLE (no AS SELECT),
//!   CREATE SCHEMA / CREATE DATABASE, ALTER, DROP (except VIEW), DESCRIBE, SHOW
//! - **Trino** (double quotes): CREATE VIEW, CREATE TABLE AS SELECT, DROP VIEW,
//!   SELECT, INSERT, UPDATE, DELETE, MERGE
10
11use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14    AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
/// SQL dialect for AWS Athena.
///
/// The dialect carries no state. By default it generates Trino-style SQL
/// (double-quoted identifiers) and switches to Hive-style SQL (backtick-quoted
/// identifiers) for the statement kinds selected by `should_use_hive_engine`.
pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Athena
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Athena uses double quotes for identifiers (Trino-style for DML)
30        config.identifiers.insert('"', '"');
31        // Also supports backticks (Hive-style for DDL)
32        config.identifiers.insert('`', '`');
33        config.nested_comments = false;
34        // Athena/Hive supports backslash escapes in string literals (e.g., \' for escaped quote)
35        config.string_escapes.push('\\');
36        config
37    }
38
39    fn generator_config(&self) -> GeneratorConfig {
40        // Default config uses Trino style (double quotes)
41        GeneratorConfig {
42            identifier_quote: '"',
43            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44            dialect: Some(DialectType::Athena),
45            schema_comment_with_eq: false,
46            ..Default::default()
47        }
48    }
49
50    fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51        if should_use_hive_engine(expr) {
52            // Hive mode: backticks for identifiers
53            GeneratorConfig {
54                identifier_quote: '`',
55                identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56                dialect: Some(DialectType::Athena),
57                schema_comment_with_eq: false,
58                ..Default::default()
59            }
60        } else {
61            // Trino mode: double quotes for identifiers
62            GeneratorConfig {
63                identifier_quote: '"',
64                identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65                dialect: Some(DialectType::Athena),
66                schema_comment_with_eq: false,
67                ..Default::default()
68            }
69        }
70    }
71
72    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73        match expr {
74            // IFNULL -> COALESCE in Athena
75            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
76                original_name: None,
77                expressions: vec![f.this, f.expression],
78                inferred_type: None,
79            }))),
80
81            // NVL -> COALESCE in Athena
82            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
83                original_name: None,
84                expressions: vec![f.this, f.expression],
85                inferred_type: None,
86            }))),
87
88            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
89            Expression::Coalesce(mut f) => {
90                f.original_name = None;
91                Ok(Expression::Coalesce(f))
92            }
93
94            // TryCast stays as TryCast (Athena/Trino supports TRY_CAST)
95            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
96
97            // SafeCast -> TRY_CAST in Athena
98            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
99
100            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
101            Expression::ILike(op) => {
102                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
103                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
104                Ok(Expression::Like(Box::new(LikeOp {
105                    left: lower_left,
106                    right: lower_right,
107                    escape: op.escape,
108                    quantifier: op.quantifier.clone(),
109                    inferred_type: None,
110                })))
111            }
112
113            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
114            Expression::CountIf(f) => {
115                let case_expr = Expression::Case(Box::new(Case {
116                    operand: None,
117                    whens: vec![(f.this.clone(), Expression::number(1))],
118                    else_: Some(Expression::number(0)),
119                    comments: Vec::new(),
120                    inferred_type: None,
121                }));
122                Ok(Expression::Sum(Box::new(AggFunc {
123                    ignore_nulls: None,
124                    having_max: None,
125                    this: case_expr,
126                    distinct: f.distinct,
127                    filter: f.filter,
128                    order_by: Vec::new(),
129                    name: None,
130                    limit: None,
131                    inferred_type: None,
132                })))
133            }
134
135            // EXPLODE -> UNNEST in Athena
136            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
137                crate::expressions::UnnestFunc {
138                    this: f.this,
139                    expressions: Vec::new(),
140                    with_ordinality: false,
141                    alias: None,
142                    offset_alias: None,
143                },
144            ))),
145
146            // ExplodeOuter -> UNNEST in Athena
147            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
148                crate::expressions::UnnestFunc {
149                    this: f.this,
150                    expressions: Vec::new(),
151                    with_ordinality: false,
152                    alias: None,
153                    offset_alias: None,
154                },
155            ))),
156
157            // Generic function transformations
158            Expression::Function(f) => self.transform_function(*f),
159
160            // Generic aggregate function transformations
161            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
162
163            // Cast transformations
164            Expression::Cast(c) => self.transform_cast(*c),
165
166            // Pass through everything else
167            _ => Ok(expr),
168        }
169    }
170}
171
172impl AthenaDialect {
173    fn transform_function(&self, f: Function) -> Result<Expression> {
174        let name_upper = f.name.to_uppercase();
175        match name_upper.as_str() {
176            // IFNULL -> COALESCE
177            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
178                original_name: None,
179                expressions: f.args,
180                inferred_type: None,
181            }))),
182
183            // NVL -> COALESCE
184            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
185                original_name: None,
186                expressions: f.args,
187                inferred_type: None,
188            }))),
189
190            // ISNULL -> COALESCE
191            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
192                original_name: None,
193                expressions: f.args,
194                inferred_type: None,
195            }))),
196
197            // GETDATE -> CURRENT_TIMESTAMP
198            "GETDATE" => Ok(Expression::CurrentTimestamp(
199                crate::expressions::CurrentTimestamp {
200                    precision: None,
201                    sysdate: false,
202                },
203            )),
204
205            // NOW -> CURRENT_TIMESTAMP
206            "NOW" => Ok(Expression::CurrentTimestamp(
207                crate::expressions::CurrentTimestamp {
208                    precision: None,
209                    sysdate: false,
210                },
211            )),
212
213            // RAND -> RANDOM in Athena
214            "RAND" => Ok(Expression::Function(Box::new(Function::new(
215                "RANDOM".to_string(),
216                vec![],
217            )))),
218
219            // GROUP_CONCAT -> LISTAGG in Athena (Trino-style)
220            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
221                Function::new("LISTAGG".to_string(), f.args),
222            ))),
223
224            // STRING_AGG -> LISTAGG in Athena
225            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
226                Function::new("LISTAGG".to_string(), f.args),
227            ))),
228
229            // SUBSTR -> SUBSTRING
230            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
231                "SUBSTRING".to_string(),
232                f.args,
233            )))),
234
235            // LEN -> LENGTH
236            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
237                f.args.into_iter().next().unwrap(),
238            )))),
239
240            // CHARINDEX -> STRPOS in Athena (with swapped args)
241            "CHARINDEX" if f.args.len() >= 2 => {
242                let mut args = f.args;
243                let substring = args.remove(0);
244                let string = args.remove(0);
245                Ok(Expression::Function(Box::new(Function::new(
246                    "STRPOS".to_string(),
247                    vec![string, substring],
248                ))))
249            }
250
251            // INSTR -> STRPOS
252            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
253                "STRPOS".to_string(),
254                f.args,
255            )))),
256
257            // LOCATE -> STRPOS in Athena (with swapped args)
258            "LOCATE" if f.args.len() >= 2 => {
259                let mut args = f.args;
260                let substring = args.remove(0);
261                let string = args.remove(0);
262                Ok(Expression::Function(Box::new(Function::new(
263                    "STRPOS".to_string(),
264                    vec![string, substring],
265                ))))
266            }
267
268            // ARRAY_LENGTH -> CARDINALITY in Athena
269            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
270                Function::new("CARDINALITY".to_string(), f.args),
271            ))),
272
273            // SIZE -> CARDINALITY in Athena
274            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
275                "CARDINALITY".to_string(),
276                f.args,
277            )))),
278
279            // TO_DATE -> CAST to DATE or DATE_PARSE
280            "TO_DATE" if !f.args.is_empty() => {
281                if f.args.len() == 1 {
282                    Ok(Expression::Cast(Box::new(Cast {
283                        this: f.args.into_iter().next().unwrap(),
284                        to: DataType::Date,
285                        trailing_comments: Vec::new(),
286                        double_colon_syntax: false,
287                        format: None,
288                        default: None,
289                        inferred_type: None,
290                    })))
291                } else {
292                    Ok(Expression::Function(Box::new(Function::new(
293                        "DATE_PARSE".to_string(),
294                        f.args,
295                    ))))
296                }
297            }
298
299            // TO_TIMESTAMP -> CAST or DATE_PARSE
300            "TO_TIMESTAMP" if !f.args.is_empty() => {
301                if f.args.len() == 1 {
302                    Ok(Expression::Cast(Box::new(Cast {
303                        this: f.args.into_iter().next().unwrap(),
304                        to: DataType::Timestamp {
305                            precision: None,
306                            timezone: false,
307                        },
308                        trailing_comments: Vec::new(),
309                        double_colon_syntax: false,
310                        format: None,
311                        default: None,
312                        inferred_type: None,
313                    })))
314                } else {
315                    Ok(Expression::Function(Box::new(Function::new(
316                        "DATE_PARSE".to_string(),
317                        f.args,
318                    ))))
319                }
320            }
321
322            // strftime -> DATE_FORMAT in Athena
323            "STRFTIME" if f.args.len() >= 2 => {
324                let mut args = f.args;
325                let format = args.remove(0);
326                let date = args.remove(0);
327                Ok(Expression::Function(Box::new(Function::new(
328                    "DATE_FORMAT".to_string(),
329                    vec![date, format],
330                ))))
331            }
332
333            // TO_CHAR -> DATE_FORMAT in Athena
334            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
335                "DATE_FORMAT".to_string(),
336                f.args,
337            )))),
338
339            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Athena
340            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
341                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
342            ))),
343
344            // COLLECT_LIST -> ARRAY_AGG
345            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
346                Function::new("ARRAY_AGG".to_string(), f.args),
347            ))),
348
349            // Pass through everything else
350            _ => Ok(Expression::Function(Box::new(f))),
351        }
352    }
353
354    fn transform_aggregate_function(
355        &self,
356        f: Box<crate::expressions::AggregateFunction>,
357    ) -> Result<Expression> {
358        let name_upper = f.name.to_uppercase();
359        match name_upper.as_str() {
360            // COUNT_IF -> SUM(CASE WHEN...)
361            "COUNT_IF" if !f.args.is_empty() => {
362                let condition = f.args.into_iter().next().unwrap();
363                let case_expr = Expression::Case(Box::new(Case {
364                    operand: None,
365                    whens: vec![(condition, Expression::number(1))],
366                    else_: Some(Expression::number(0)),
367                    comments: Vec::new(),
368                    inferred_type: None,
369                }));
370                Ok(Expression::Sum(Box::new(AggFunc {
371                    ignore_nulls: None,
372                    having_max: None,
373                    this: case_expr,
374                    distinct: f.distinct,
375                    filter: f.filter,
376                    order_by: Vec::new(),
377                    name: None,
378                    limit: None,
379                    inferred_type: None,
380                })))
381            }
382
383            // ANY_VALUE -> ARBITRARY in Athena (Trino)
384            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
385                "ARBITRARY".to_string(),
386                f.args,
387            )))),
388
389            // GROUP_CONCAT -> LISTAGG in Athena
390            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
391                Function::new("LISTAGG".to_string(), f.args),
392            ))),
393
394            // STRING_AGG -> LISTAGG in Athena
395            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
396                Function::new("LISTAGG".to_string(), f.args),
397            ))),
398
399            // Pass through everything else
400            _ => Ok(Expression::AggregateFunction(f)),
401        }
402    }
403
404    fn transform_cast(&self, c: Cast) -> Result<Expression> {
405        // Athena type mappings are handled in the generator
406        Ok(Expression::Cast(Box::new(c)))
407    }
408}
409
410/// Determine if an expression should be generated using Hive engine (backticks)
411/// or Trino engine (double quotes).
412///
413/// Hive is used for:
414/// - CREATE EXTERNAL TABLE
415/// - CREATE TABLE (without AS SELECT)
416/// - CREATE SCHEMA / CREATE DATABASE
417/// - ALTER statements
418/// - DROP statements (except DROP VIEW)
419/// - DESCRIBE / SHOW statements
420///
421/// Trino is used for everything else (DML, CREATE VIEW, etc.)
422fn should_use_hive_engine(expr: &Expression) -> bool {
423    match expr {
424        // CREATE TABLE: Hive if EXTERNAL or no AS SELECT
425        Expression::CreateTable(ct) => {
426            // CREATE EXTERNAL TABLE → Hive
427            if let Some(ref modifier) = ct.table_modifier {
428                if modifier.to_uppercase() == "EXTERNAL" {
429                    return true;
430                }
431            }
432            // CREATE TABLE ... AS SELECT → Trino
433            // CREATE TABLE (without query) → Hive
434            ct.as_select.is_none()
435        }
436
437        // CREATE VIEW → Trino
438        Expression::CreateView(_) => false,
439
440        // CREATE SCHEMA / DATABASE → Hive
441        Expression::CreateSchema(_) => true,
442        Expression::CreateDatabase(_) => true,
443
444        // ALTER statements → Hive
445        Expression::AlterTable(_) => true,
446        Expression::AlterView(_) => true,
447        Expression::AlterIndex(_) => true,
448        Expression::AlterSequence(_) => true,
449
450        // DROP VIEW → Trino (because CREATE VIEW is Trino)
451        Expression::DropView(_) => false,
452
453        // Other DROP statements → Hive
454        Expression::DropTable(_) => true,
455        Expression::DropSchema(_) => true,
456        Expression::DropDatabase(_) => true,
457        Expression::DropIndex(_) => true,
458        Expression::DropFunction(_) => true,
459        Expression::DropProcedure(_) => true,
460        Expression::DropSequence(_) => true,
461
462        // DESCRIBE / SHOW → Hive
463        Expression::Describe(_) => true,
464        Expression::Show(_) => true,
465
466        // Everything else (SELECT, INSERT, UPDATE, DELETE, MERGE, etc.) → Trino
467        _ => false,
468    }
469}