Skip to main content

polyglot_sql/dialects/
athena.rs

//! Athena Dialect
//!
//! AWS Athena-specific transformations based on sqlglot patterns.
//! Athena routes each statement to one of two engines, which differ in how
//! identifiers are quoted:
//!
//! - **Hive** (backticks): CREATE EXTERNAL TABLE, CREATE TABLE (no AS SELECT),
//!   CREATE SCHEMA / CREATE DATABASE, ALTER, DROP (except VIEW), DESCRIBE, SHOW
//! - **Trino** (double quotes): CREATE VIEW, CREATE TABLE AS SELECT, DROP VIEW,
//!   SELECT, INSERT, UPDATE, DELETE, MERGE
10
11use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14    AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
/// Athena dialect (based on Trino for DML operations).
///
/// This is a field-less unit type, so it carries no state of its own. The
/// standard derives make it printable in diagnostics, freely copyable, and
/// constructible through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Athena
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Athena uses double quotes for identifiers (Trino-style for DML)
30        config.identifiers.insert('"', '"');
31        // Also supports backticks (Hive-style for DDL)
32        config.identifiers.insert('`', '`');
33        config.nested_comments = false;
34        // Athena/Hive supports backslash escapes in string literals (e.g., \' for escaped quote)
35        config.string_escapes.push('\\');
36        config
37    }
38
39    fn generator_config(&self) -> GeneratorConfig {
40        // Default config uses Trino style (double quotes)
41        GeneratorConfig {
42            identifier_quote: '"',
43            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44            dialect: Some(DialectType::Athena),
45            schema_comment_with_eq: false,
46            ..Default::default()
47        }
48    }
49
50    fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51        if should_use_hive_engine(expr) {
52            // Hive mode: backticks for identifiers
53            GeneratorConfig {
54                identifier_quote: '`',
55                identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56                dialect: Some(DialectType::Athena),
57                schema_comment_with_eq: false,
58                ..Default::default()
59            }
60        } else {
61            // Trino mode: double quotes for identifiers
62            GeneratorConfig {
63                identifier_quote: '"',
64                identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65                dialect: Some(DialectType::Athena),
66                schema_comment_with_eq: false,
67                ..Default::default()
68            }
69        }
70    }
71
72    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73        match expr {
74            // IFNULL -> COALESCE in Athena
75            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
76                expressions: vec![f.this, f.expression],
77            }))),
78
79            // NVL -> COALESCE in Athena
80            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
81                expressions: vec![f.this, f.expression],
82            }))),
83
84            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
85            Expression::Coalesce(mut f) => {
86                f.original_name = None;
87                Ok(Expression::Coalesce(f))
88            }
89
90            // TryCast stays as TryCast (Athena/Trino supports TRY_CAST)
91            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
92
93            // SafeCast -> TRY_CAST in Athena
94            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
95
96            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
97            Expression::ILike(op) => {
98                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
99                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
100                Ok(Expression::Like(Box::new(LikeOp {
101                    left: lower_left,
102                    right: lower_right,
103                    escape: op.escape,
104                    quantifier: op.quantifier.clone(),
105                })))
106            }
107
108            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
109            Expression::CountIf(f) => {
110                let case_expr = Expression::Case(Box::new(Case {
111                    operand: None,
112                    whens: vec![(f.this.clone(), Expression::number(1))],
113                    else_: Some(Expression::number(0)),
114                }));
115                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
116                    this: case_expr,
117                    distinct: f.distinct,
118                    filter: f.filter,
119                    order_by: Vec::new(),
120                name: None,
121                limit: None,
122                })))
123            }
124
125            // EXPLODE -> UNNEST in Athena
126            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
127                crate::expressions::UnnestFunc {
128                    this: f.this,
129                    expressions: Vec::new(),
130                    with_ordinality: false,
131                    alias: None,
132                    offset_alias: None,
133                },
134            ))),
135
136            // ExplodeOuter -> UNNEST in Athena
137            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
138                crate::expressions::UnnestFunc {
139                    this: f.this,
140                    expressions: Vec::new(),
141                    with_ordinality: false,
142                    alias: None,
143                    offset_alias: None,
144                },
145            ))),
146
147            // Generic function transformations
148            Expression::Function(f) => self.transform_function(*f),
149
150            // Generic aggregate function transformations
151            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
152
153            // Cast transformations
154            Expression::Cast(c) => self.transform_cast(*c),
155
156            // Pass through everything else
157            _ => Ok(expr),
158        }
159    }
160}
161
162impl AthenaDialect {
163    fn transform_function(&self, f: Function) -> Result<Expression> {
164        let name_upper = f.name.to_uppercase();
165        match name_upper.as_str() {
166            // IFNULL -> COALESCE
167            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
168                expressions: f.args,
169            }))),
170
171            // NVL -> COALESCE
172            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
173                expressions: f.args,
174            }))),
175
176            // ISNULL -> COALESCE
177            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
178                expressions: f.args,
179            }))),
180
181            // GETDATE -> CURRENT_TIMESTAMP
182            "GETDATE" => Ok(Expression::CurrentTimestamp(
183                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
184            )),
185
186            // NOW -> CURRENT_TIMESTAMP
187            "NOW" => Ok(Expression::CurrentTimestamp(
188                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
189            )),
190
191            // RAND -> RANDOM in Athena
192            "RAND" => Ok(Expression::Function(Box::new(Function::new(
193                "RANDOM".to_string(),
194                vec![],
195            )))),
196
197            // GROUP_CONCAT -> LISTAGG in Athena (Trino-style)
198            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
199                Function::new("LISTAGG".to_string(), f.args),
200            ))),
201
202            // STRING_AGG -> LISTAGG in Athena
203            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
204                Function::new("LISTAGG".to_string(), f.args),
205            ))),
206
207            // SUBSTR -> SUBSTRING
208            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
209                "SUBSTRING".to_string(),
210                f.args,
211            )))),
212
213            // LEN -> LENGTH
214            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
215                f.args.into_iter().next().unwrap(),
216            )))),
217
218            // CHARINDEX -> STRPOS in Athena (with swapped args)
219            "CHARINDEX" if f.args.len() >= 2 => {
220                let mut args = f.args;
221                let substring = args.remove(0);
222                let string = args.remove(0);
223                Ok(Expression::Function(Box::new(Function::new(
224                    "STRPOS".to_string(),
225                    vec![string, substring],
226                ))))
227            }
228
229            // INSTR -> STRPOS
230            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
231                "STRPOS".to_string(),
232                f.args,
233            )))),
234
235            // LOCATE -> STRPOS in Athena (with swapped args)
236            "LOCATE" if f.args.len() >= 2 => {
237                let mut args = f.args;
238                let substring = args.remove(0);
239                let string = args.remove(0);
240                Ok(Expression::Function(Box::new(Function::new(
241                    "STRPOS".to_string(),
242                    vec![string, substring],
243                ))))
244            }
245
246            // ARRAY_LENGTH -> CARDINALITY in Athena
247            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
248                Function::new("CARDINALITY".to_string(), f.args),
249            ))),
250
251            // SIZE -> CARDINALITY in Athena
252            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
253                "CARDINALITY".to_string(),
254                f.args,
255            )))),
256
257            // TO_DATE -> CAST to DATE or DATE_PARSE
258            "TO_DATE" if !f.args.is_empty() => {
259                if f.args.len() == 1 {
260                    Ok(Expression::Cast(Box::new(Cast {
261                        this: f.args.into_iter().next().unwrap(),
262                        to: DataType::Date,
263                        trailing_comments: Vec::new(),
264                        double_colon_syntax: false,
265                        format: None,
266                        default: None,
267                    })))
268                } else {
269                    Ok(Expression::Function(Box::new(Function::new(
270                        "DATE_PARSE".to_string(),
271                        f.args,
272                    ))))
273                }
274            }
275
276            // TO_TIMESTAMP -> CAST or DATE_PARSE
277            "TO_TIMESTAMP" if !f.args.is_empty() => {
278                if f.args.len() == 1 {
279                    Ok(Expression::Cast(Box::new(Cast {
280                        this: f.args.into_iter().next().unwrap(),
281                        to: DataType::Timestamp {
282                            precision: None,
283                            timezone: false,
284                        },
285                        trailing_comments: Vec::new(),
286                        double_colon_syntax: false,
287                        format: None,
288                        default: None,
289                    })))
290                } else {
291                    Ok(Expression::Function(Box::new(Function::new(
292                        "DATE_PARSE".to_string(),
293                        f.args,
294                    ))))
295                }
296            }
297
298            // strftime -> DATE_FORMAT in Athena
299            "STRFTIME" if f.args.len() >= 2 => {
300                let mut args = f.args;
301                let format = args.remove(0);
302                let date = args.remove(0);
303                Ok(Expression::Function(Box::new(Function::new(
304                    "DATE_FORMAT".to_string(),
305                    vec![date, format],
306                ))))
307            }
308
309            // TO_CHAR -> DATE_FORMAT in Athena
310            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
311                "DATE_FORMAT".to_string(),
312                f.args,
313            )))),
314
315            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Athena
316            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
317                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
318            ))),
319
320            // COLLECT_LIST -> ARRAY_AGG
321            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
322                Function::new("ARRAY_AGG".to_string(), f.args),
323            ))),
324
325            // Pass through everything else
326            _ => Ok(Expression::Function(Box::new(f))),
327        }
328    }
329
330    fn transform_aggregate_function(
331        &self,
332        f: Box<crate::expressions::AggregateFunction>,
333    ) -> Result<Expression> {
334        let name_upper = f.name.to_uppercase();
335        match name_upper.as_str() {
336            // COUNT_IF -> SUM(CASE WHEN...)
337            "COUNT_IF" if !f.args.is_empty() => {
338                let condition = f.args.into_iter().next().unwrap();
339                let case_expr = Expression::Case(Box::new(Case {
340                    operand: None,
341                    whens: vec![(condition, Expression::number(1))],
342                    else_: Some(Expression::number(0)),
343                }));
344                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
345                    this: case_expr,
346                    distinct: f.distinct,
347                    filter: f.filter,
348                    order_by: Vec::new(),
349                name: None,
350                limit: None,
351                })))
352            }
353
354            // ANY_VALUE -> ARBITRARY in Athena (Trino)
355            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
356                "ARBITRARY".to_string(),
357                f.args,
358            )))),
359
360            // GROUP_CONCAT -> LISTAGG in Athena
361            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
362                Function::new("LISTAGG".to_string(), f.args),
363            ))),
364
365            // STRING_AGG -> LISTAGG in Athena
366            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
367                Function::new("LISTAGG".to_string(), f.args),
368            ))),
369
370            // Pass through everything else
371            _ => Ok(Expression::AggregateFunction(f)),
372        }
373    }
374
375    fn transform_cast(&self, c: Cast) -> Result<Expression> {
376        // Athena type mappings are handled in the generator
377        Ok(Expression::Cast(Box::new(c)))
378    }
379}
380
381/// Determine if an expression should be generated using Hive engine (backticks)
382/// or Trino engine (double quotes).
383///
384/// Hive is used for:
385/// - CREATE EXTERNAL TABLE
386/// - CREATE TABLE (without AS SELECT)
387/// - CREATE SCHEMA / CREATE DATABASE
388/// - ALTER statements
389/// - DROP statements (except DROP VIEW)
390/// - DESCRIBE / SHOW statements
391///
392/// Trino is used for everything else (DML, CREATE VIEW, etc.)
393fn should_use_hive_engine(expr: &Expression) -> bool {
394    match expr {
395        // CREATE TABLE: Hive if EXTERNAL or no AS SELECT
396        Expression::CreateTable(ct) => {
397            // CREATE EXTERNAL TABLE → Hive
398            if let Some(ref modifier) = ct.table_modifier {
399                if modifier.to_uppercase() == "EXTERNAL" {
400                    return true;
401                }
402            }
403            // CREATE TABLE ... AS SELECT → Trino
404            // CREATE TABLE (without query) → Hive
405            ct.as_select.is_none()
406        }
407
408        // CREATE VIEW → Trino
409        Expression::CreateView(_) => false,
410
411        // CREATE SCHEMA / DATABASE → Hive
412        Expression::CreateSchema(_) => true,
413        Expression::CreateDatabase(_) => true,
414
415        // ALTER statements → Hive
416        Expression::AlterTable(_) => true,
417        Expression::AlterView(_) => true,
418        Expression::AlterIndex(_) => true,
419        Expression::AlterSequence(_) => true,
420
421        // DROP VIEW → Trino (because CREATE VIEW is Trino)
422        Expression::DropView(_) => false,
423
424        // Other DROP statements → Hive
425        Expression::DropTable(_) => true,
426        Expression::DropSchema(_) => true,
427        Expression::DropDatabase(_) => true,
428        Expression::DropIndex(_) => true,
429        Expression::DropFunction(_) => true,
430        Expression::DropProcedure(_) => true,
431        Expression::DropSequence(_) => true,
432
433        // DESCRIBE / SHOW → Hive
434        Expression::Describe(_) => true,
435        Expression::Show(_) => true,
436
437        // Everything else (SELECT, INSERT, UPDATE, DELETE, MERGE, etc.) → Trino
438        _ => false,
439    }
440}