Skip to main content

polyglot_sql/dialects/
athena.rs

//! Athena Dialect
//!
//! AWS Athena-specific transformations based on sqlglot patterns.
//! Athena routes statements between the Hive (DDL) and Trino (DML) engines:
//!
//! - **Hive** (backticks): CREATE EXTERNAL TABLE, CREATE TABLE (no AS SELECT),
//!   CREATE SCHEMA / DATABASE, ALTER, DROP (except VIEW), DESCRIBE, SHOW
//! - **Trino** (double quotes): CREATE VIEW, CREATE TABLE AS SELECT, DROP VIEW,
//!   SELECT, INSERT, UPDATE, DELETE, MERGE
10
11use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14    AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
17use crate::tokens::TokenizerConfig;
18
19/// Athena dialect (based on Trino for DML operations)
20pub struct AthenaDialect;
21
22impl DialectImpl for AthenaDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Athena
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Athena uses double quotes for identifiers (Trino-style for DML)
30        config.identifiers.insert('"', '"');
31        // Also supports backticks (Hive-style for DDL)
32        config.identifiers.insert('`', '`');
33        config.nested_comments = false;
34        // Athena/Hive supports backslash escapes in string literals (e.g., \' for escaped quote)
35        config.string_escapes.push('\\');
36        config
37    }
38
39    fn generator_config(&self) -> GeneratorConfig {
40        // Default config uses Trino style (double quotes)
41        GeneratorConfig {
42            identifier_quote: '"',
43            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
44            dialect: Some(DialectType::Athena),
45            schema_comment_with_eq: false,
46            ..Default::default()
47        }
48    }
49
50    fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
51        if should_use_hive_engine(expr) {
52            // Hive mode: backticks for identifiers
53            GeneratorConfig {
54                identifier_quote: '`',
55                identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
56                dialect: Some(DialectType::Athena),
57                schema_comment_with_eq: false,
58                ..Default::default()
59            }
60        } else {
61            // Trino mode: double quotes for identifiers
62            GeneratorConfig {
63                identifier_quote: '"',
64                identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
65                dialect: Some(DialectType::Athena),
66                schema_comment_with_eq: false,
67                ..Default::default()
68            }
69        }
70    }
71
72    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
73        match expr {
74            // IFNULL -> COALESCE in Athena
75            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
76                original_name: None,
77                expressions: vec![f.this, f.expression],
78            }))),
79
80            // NVL -> COALESCE in Athena
81            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
82                original_name: None,
83                expressions: vec![f.this, f.expression],
84            }))),
85
86            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
87            Expression::Coalesce(mut f) => {
88                f.original_name = None;
89                Ok(Expression::Coalesce(f))
90            }
91
92            // TryCast stays as TryCast (Athena/Trino supports TRY_CAST)
93            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
94
95            // SafeCast -> TRY_CAST in Athena
96            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
97
98            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
99            Expression::ILike(op) => {
100                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
101                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
102                Ok(Expression::Like(Box::new(LikeOp {
103                    left: lower_left,
104                    right: lower_right,
105                    escape: op.escape,
106                    quantifier: op.quantifier.clone(),
107                })))
108            }
109
110            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
111            Expression::CountIf(f) => {
112                let case_expr = Expression::Case(Box::new(Case {
113                    operand: None,
114                    whens: vec![(f.this.clone(), Expression::number(1))],
115                    else_: Some(Expression::number(0)),
116                    comments: Vec::new(),
117                }));
118                Ok(Expression::Sum(Box::new(AggFunc {
119                    ignore_nulls: None,
120                    having_max: None,
121                    this: case_expr,
122                    distinct: f.distinct,
123                    filter: f.filter,
124                    order_by: Vec::new(),
125                    name: None,
126                    limit: None,
127                })))
128            }
129
130            // EXPLODE -> UNNEST in Athena
131            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
132                crate::expressions::UnnestFunc {
133                    this: f.this,
134                    expressions: Vec::new(),
135                    with_ordinality: false,
136                    alias: None,
137                    offset_alias: None,
138                },
139            ))),
140
141            // ExplodeOuter -> UNNEST in Athena
142            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
143                crate::expressions::UnnestFunc {
144                    this: f.this,
145                    expressions: Vec::new(),
146                    with_ordinality: false,
147                    alias: None,
148                    offset_alias: None,
149                },
150            ))),
151
152            // Generic function transformations
153            Expression::Function(f) => self.transform_function(*f),
154
155            // Generic aggregate function transformations
156            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
157
158            // Cast transformations
159            Expression::Cast(c) => self.transform_cast(*c),
160
161            // Pass through everything else
162            _ => Ok(expr),
163        }
164    }
165}
166
167impl AthenaDialect {
168    fn transform_function(&self, f: Function) -> Result<Expression> {
169        let name_upper = f.name.to_uppercase();
170        match name_upper.as_str() {
171            // IFNULL -> COALESCE
172            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
173                original_name: None,
174                expressions: f.args,
175            }))),
176
177            // NVL -> COALESCE
178            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
179                original_name: None,
180                expressions: f.args,
181            }))),
182
183            // ISNULL -> COALESCE
184            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
185                original_name: None,
186                expressions: f.args,
187            }))),
188
189            // GETDATE -> CURRENT_TIMESTAMP
190            "GETDATE" => Ok(Expression::CurrentTimestamp(
191                crate::expressions::CurrentTimestamp {
192                    precision: None,
193                    sysdate: false,
194                },
195            )),
196
197            // NOW -> CURRENT_TIMESTAMP
198            "NOW" => Ok(Expression::CurrentTimestamp(
199                crate::expressions::CurrentTimestamp {
200                    precision: None,
201                    sysdate: false,
202                },
203            )),
204
205            // RAND -> RANDOM in Athena
206            "RAND" => Ok(Expression::Function(Box::new(Function::new(
207                "RANDOM".to_string(),
208                vec![],
209            )))),
210
211            // GROUP_CONCAT -> LISTAGG in Athena (Trino-style)
212            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
213                Function::new("LISTAGG".to_string(), f.args),
214            ))),
215
216            // STRING_AGG -> LISTAGG in Athena
217            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
218                Function::new("LISTAGG".to_string(), f.args),
219            ))),
220
221            // SUBSTR -> SUBSTRING
222            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
223                "SUBSTRING".to_string(),
224                f.args,
225            )))),
226
227            // LEN -> LENGTH
228            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
229                f.args.into_iter().next().unwrap(),
230            )))),
231
232            // CHARINDEX -> STRPOS in Athena (with swapped args)
233            "CHARINDEX" if f.args.len() >= 2 => {
234                let mut args = f.args;
235                let substring = args.remove(0);
236                let string = args.remove(0);
237                Ok(Expression::Function(Box::new(Function::new(
238                    "STRPOS".to_string(),
239                    vec![string, substring],
240                ))))
241            }
242
243            // INSTR -> STRPOS
244            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
245                "STRPOS".to_string(),
246                f.args,
247            )))),
248
249            // LOCATE -> STRPOS in Athena (with swapped args)
250            "LOCATE" if f.args.len() >= 2 => {
251                let mut args = f.args;
252                let substring = args.remove(0);
253                let string = args.remove(0);
254                Ok(Expression::Function(Box::new(Function::new(
255                    "STRPOS".to_string(),
256                    vec![string, substring],
257                ))))
258            }
259
260            // ARRAY_LENGTH -> CARDINALITY in Athena
261            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
262                Function::new("CARDINALITY".to_string(), f.args),
263            ))),
264
265            // SIZE -> CARDINALITY in Athena
266            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
267                "CARDINALITY".to_string(),
268                f.args,
269            )))),
270
271            // TO_DATE -> CAST to DATE or DATE_PARSE
272            "TO_DATE" if !f.args.is_empty() => {
273                if f.args.len() == 1 {
274                    Ok(Expression::Cast(Box::new(Cast {
275                        this: f.args.into_iter().next().unwrap(),
276                        to: DataType::Date,
277                        trailing_comments: Vec::new(),
278                        double_colon_syntax: false,
279                        format: None,
280                        default: None,
281                    })))
282                } else {
283                    Ok(Expression::Function(Box::new(Function::new(
284                        "DATE_PARSE".to_string(),
285                        f.args,
286                    ))))
287                }
288            }
289
290            // TO_TIMESTAMP -> CAST or DATE_PARSE
291            "TO_TIMESTAMP" if !f.args.is_empty() => {
292                if f.args.len() == 1 {
293                    Ok(Expression::Cast(Box::new(Cast {
294                        this: f.args.into_iter().next().unwrap(),
295                        to: DataType::Timestamp {
296                            precision: None,
297                            timezone: false,
298                        },
299                        trailing_comments: Vec::new(),
300                        double_colon_syntax: false,
301                        format: None,
302                        default: None,
303                    })))
304                } else {
305                    Ok(Expression::Function(Box::new(Function::new(
306                        "DATE_PARSE".to_string(),
307                        f.args,
308                    ))))
309                }
310            }
311
312            // strftime -> DATE_FORMAT in Athena
313            "STRFTIME" if f.args.len() >= 2 => {
314                let mut args = f.args;
315                let format = args.remove(0);
316                let date = args.remove(0);
317                Ok(Expression::Function(Box::new(Function::new(
318                    "DATE_FORMAT".to_string(),
319                    vec![date, format],
320                ))))
321            }
322
323            // TO_CHAR -> DATE_FORMAT in Athena
324            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
325                "DATE_FORMAT".to_string(),
326                f.args,
327            )))),
328
329            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR in Athena
330            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
331                Function::new("JSON_EXTRACT_SCALAR".to_string(), f.args),
332            ))),
333
334            // COLLECT_LIST -> ARRAY_AGG
335            "COLLECT_LIST" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
336                Function::new("ARRAY_AGG".to_string(), f.args),
337            ))),
338
339            // Pass through everything else
340            _ => Ok(Expression::Function(Box::new(f))),
341        }
342    }
343
344    fn transform_aggregate_function(
345        &self,
346        f: Box<crate::expressions::AggregateFunction>,
347    ) -> Result<Expression> {
348        let name_upper = f.name.to_uppercase();
349        match name_upper.as_str() {
350            // COUNT_IF -> SUM(CASE WHEN...)
351            "COUNT_IF" if !f.args.is_empty() => {
352                let condition = f.args.into_iter().next().unwrap();
353                let case_expr = Expression::Case(Box::new(Case {
354                    operand: None,
355                    whens: vec![(condition, Expression::number(1))],
356                    else_: Some(Expression::number(0)),
357                    comments: Vec::new(),
358                }));
359                Ok(Expression::Sum(Box::new(AggFunc {
360                    ignore_nulls: None,
361                    having_max: None,
362                    this: case_expr,
363                    distinct: f.distinct,
364                    filter: f.filter,
365                    order_by: Vec::new(),
366                    name: None,
367                    limit: None,
368                })))
369            }
370
371            // ANY_VALUE -> ARBITRARY in Athena (Trino)
372            "ANY_VALUE" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
373                "ARBITRARY".to_string(),
374                f.args,
375            )))),
376
377            // GROUP_CONCAT -> LISTAGG in Athena
378            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
379                Function::new("LISTAGG".to_string(), f.args),
380            ))),
381
382            // STRING_AGG -> LISTAGG in Athena
383            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
384                Function::new("LISTAGG".to_string(), f.args),
385            ))),
386
387            // Pass through everything else
388            _ => Ok(Expression::AggregateFunction(f)),
389        }
390    }
391
392    fn transform_cast(&self, c: Cast) -> Result<Expression> {
393        // Athena type mappings are handled in the generator
394        Ok(Expression::Cast(Box::new(c)))
395    }
396}
397
398/// Determine if an expression should be generated using Hive engine (backticks)
399/// or Trino engine (double quotes).
400///
401/// Hive is used for:
402/// - CREATE EXTERNAL TABLE
403/// - CREATE TABLE (without AS SELECT)
404/// - CREATE SCHEMA / CREATE DATABASE
405/// - ALTER statements
406/// - DROP statements (except DROP VIEW)
407/// - DESCRIBE / SHOW statements
408///
409/// Trino is used for everything else (DML, CREATE VIEW, etc.)
410fn should_use_hive_engine(expr: &Expression) -> bool {
411    match expr {
412        // CREATE TABLE: Hive if EXTERNAL or no AS SELECT
413        Expression::CreateTable(ct) => {
414            // CREATE EXTERNAL TABLE → Hive
415            if let Some(ref modifier) = ct.table_modifier {
416                if modifier.to_uppercase() == "EXTERNAL" {
417                    return true;
418                }
419            }
420            // CREATE TABLE ... AS SELECT → Trino
421            // CREATE TABLE (without query) → Hive
422            ct.as_select.is_none()
423        }
424
425        // CREATE VIEW → Trino
426        Expression::CreateView(_) => false,
427
428        // CREATE SCHEMA / DATABASE → Hive
429        Expression::CreateSchema(_) => true,
430        Expression::CreateDatabase(_) => true,
431
432        // ALTER statements → Hive
433        Expression::AlterTable(_) => true,
434        Expression::AlterView(_) => true,
435        Expression::AlterIndex(_) => true,
436        Expression::AlterSequence(_) => true,
437
438        // DROP VIEW → Trino (because CREATE VIEW is Trino)
439        Expression::DropView(_) => false,
440
441        // Other DROP statements → Hive
442        Expression::DropTable(_) => true,
443        Expression::DropSchema(_) => true,
444        Expression::DropDatabase(_) => true,
445        Expression::DropIndex(_) => true,
446        Expression::DropFunction(_) => true,
447        Expression::DropProcedure(_) => true,
448        Expression::DropSequence(_) => true,
449
450        // DESCRIBE / SHOW → Hive
451        Expression::Describe(_) => true,
452        Expression::Show(_) => true,
453
454        // Everything else (SELECT, INSERT, UPDATE, DELETE, MERGE, etc.) → Trino
455        _ => false,
456    }
457}