Skip to main content

polyglot_sql/dialects/
athena.rs

1//! Athena Dialect
2//!
3//! AWS Athena-specific transformations based on sqlglot patterns.
4//! Athena routes between Hive (DDL) and Trino (DML) engines:
5//!
6//! - **Hive** (backticks): CREATE EXTERNAL TABLE, CREATE TABLE (no AS SELECT),
7//!   ALTER, DROP (except VIEW), DESCRIBE, SHOW
8//! - **Trino** (double quotes): CREATE VIEW, CREATE TABLE AS SELECT, DROP VIEW,
9//!   SELECT, INSERT, UPDATE, DELETE, MERGE
10
11use super::{DialectImpl, DialectType};
12use crate::error::Result;
13use crate::expressions::{
14    AggFunc, Case, Cast, DataType, Expression, Function, LikeOp, UnaryFunc, VarArgFunc,
15};
16#[cfg(feature = "generate")]
17use crate::generator::{GeneratorConfig, IdentifierQuoteStyle};
18use crate::tokens::TokenizerConfig;
19
/// Athena dialect (based on Trino for DML operations).
///
/// This is a field-less marker type: it carries no state of its own, and the
/// dialect's behavior is supplied by the `impl` blocks defined for it in this
/// file.
pub struct AthenaDialect;
22
23impl DialectImpl for AthenaDialect {
24    fn dialect_type(&self) -> DialectType {
25        DialectType::Athena
26    }
27
28    fn tokenizer_config(&self) -> TokenizerConfig {
29        let mut config = TokenizerConfig::default();
30        // Athena uses double quotes for identifiers (Trino-style for DML)
31        config.identifiers.insert('"', '"');
32        // Also supports backticks (Hive-style for DDL)
33        config.identifiers.insert('`', '`');
34        config.nested_comments = false;
35        // Athena/Hive supports backslash escapes in string literals (e.g., \' for escaped quote)
36        config.string_escapes.push('\\');
37        config
38    }
39
40    #[cfg(feature = "generate")]
41
42    fn generator_config(&self) -> GeneratorConfig {
43        // Default config uses Trino style (double quotes)
44        GeneratorConfig {
45            identifier_quote: '"',
46            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
47            dialect: Some(DialectType::Athena),
48            schema_comment_with_eq: false,
49            ..Default::default()
50        }
51    }
52
53    #[cfg(feature = "generate")]
54
55    fn generator_config_for_expr(&self, expr: &Expression) -> GeneratorConfig {
56        if should_use_hive_engine(expr) {
57            // Hive mode: backticks for identifiers
58            GeneratorConfig {
59                identifier_quote: '`',
60                identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
61                dialect: Some(DialectType::Athena),
62                schema_comment_with_eq: false,
63                ..Default::default()
64            }
65        } else {
66            // Trino mode: double quotes for identifiers
67            GeneratorConfig {
68                identifier_quote: '"',
69                identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
70                dialect: Some(DialectType::Athena),
71                schema_comment_with_eq: false,
72                ..Default::default()
73            }
74        }
75    }
76
77    #[cfg(feature = "transpile")]
78
79    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
80        match expr {
81            // IFNULL -> COALESCE in Athena
82            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
83                original_name: None,
84                expressions: vec![f.this, f.expression],
85                inferred_type: None,
86            }))),
87
88            // NVL -> COALESCE in Athena
89            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
90                original_name: None,
91                expressions: vec![f.this, f.expression],
92                inferred_type: None,
93            }))),
94
95            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
96            Expression::Coalesce(mut f) => {
97                f.original_name = None;
98                Ok(Expression::Coalesce(f))
99            }
100
101            // TryCast stays as TryCast (Athena/Trino supports TRY_CAST)
102            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
103
104            // SafeCast -> TRY_CAST in Athena
105            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
106
107            // ILike -> LOWER() LIKE LOWER() (Trino doesn't support ILIKE)
108            Expression::ILike(op) => {
109                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left.clone())));
110                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right.clone())));
111                Ok(Expression::Like(Box::new(LikeOp {
112                    left: lower_left,
113                    right: lower_right,
114                    escape: op.escape,
115                    quantifier: op.quantifier.clone(),
116                    inferred_type: None,
117                })))
118            }
119
120            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
121            Expression::CountIf(f) => {
122                let case_expr = Expression::Case(Box::new(Case {
123                    operand: None,
124                    whens: vec![(f.this.clone(), Expression::number(1))],
125                    else_: Some(Expression::number(0)),
126                    comments: Vec::new(),
127                    inferred_type: None,
128                }));
129                Ok(Expression::Sum(Box::new(AggFunc {
130                    ignore_nulls: None,
131                    having_max: None,
132                    this: case_expr,
133                    distinct: f.distinct,
134                    filter: f.filter,
135                    order_by: Vec::new(),
136                    name: None,
137                    limit: None,
138                    inferred_type: None,
139                })))
140            }
141
142            // EXPLODE -> UNNEST in Athena
143            Expression::Explode(f) => Ok(Expression::Unnest(Box::new(
144                crate::expressions::UnnestFunc {
145                    this: f.this,
146                    expressions: Vec::new(),
147                    with_ordinality: false,
148                    alias: None,
149                    offset_alias: None,
150                },
151            ))),
152
153            // ExplodeOuter -> UNNEST in Athena
154            Expression::ExplodeOuter(f) => Ok(Expression::Unnest(Box::new(
155                crate::expressions::UnnestFunc {
156                    this: f.this,
157                    expressions: Vec::new(),
158                    with_ordinality: false,
159                    alias: None,
160                    offset_alias: None,
161                },
162            ))),
163
164            // Generic function transformations
165            Expression::Function(f) => self.transform_function(*f),
166
167            // Generic aggregate function transformations
168            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
169
170            // Cast transformations
171            Expression::Cast(c) => self.transform_cast(*c),
172
173            // Pass through everything else
174            _ => Ok(expr),
175        }
176    }
177}
178
#[cfg(feature = "transpile")]
impl AthenaDialect {
    /// Builds a plain function-call expression named `name` over `args`.
    fn renamed_call(name: &str, args: Vec<Expression>) -> Result<Expression> {
        let call = Function::new(name.to_string(), args);
        Ok(Expression::Function(Box::new(call)))
    }

    /// Takes the leading argument out of `args`.
    ///
    /// Only called from match arms whose arity pattern guarantees at least one
    /// argument is present.
    fn first_arg(args: Vec<Expression>) -> Expression {
        args.into_iter()
            .next()
            .expect("arity was matched as >= 1 before taking the first argument")
    }

    /// Keeps only the first two arguments and exchanges their positions.
    ///
    /// Only called from match arms whose arity pattern guarantees at least two
    /// arguments are present.
    fn first_two_swapped(mut args: Vec<Expression>) -> Vec<Expression> {
        args.truncate(2);
        args.swap(0, 1);
        args
    }

    /// Wraps `this` in a function-style `CAST` to `to`, with no format,
    /// default, or trailing comments.
    fn plain_cast(this: Expression, to: DataType) -> Result<Expression> {
        Ok(Expression::Cast(Box::new(Cast {
            this,
            to,
            trailing_comments: Vec::new(),
            double_colon_syntax: false,
            format: None,
            default: None,
            inferred_type: None,
        })))
    }

    /// Rewrites a generic (name + argument list) function call for Athena.
    ///
    /// The name is compared case-insensitively via its uppercase form, and
    /// each rewrite applies only for the argument counts listed in its arm.
    /// Calls that match no arm are returned unchanged.
    fn transform_function(&self, f: Function) -> Result<Expression> {
        let upper = f.name.to_uppercase();
        match (upper.as_str(), f.args.len()) {
            // Two-argument null-fallback functions all become COALESCE.
            ("IFNULL" | "NVL" | "ISNULL", 2) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
                original_name: None,
                expressions: f.args,
                inferred_type: None,
            }))),

            // GETDATE / NOW -> CURRENT_TIMESTAMP (arguments are not carried over).
            ("GETDATE" | "NOW", _) => Ok(Expression::CurrentTimestamp(
                crate::expressions::CurrentTimestamp {
                    precision: None,
                    sysdate: false,
                },
            )),

            // RAND -> RANDOM with an empty argument list.
            ("RAND", _) => Self::renamed_call("RANDOM", Vec::new()),

            // GROUP_CONCAT / STRING_AGG -> LISTAGG (Trino-style).
            ("GROUP_CONCAT" | "STRING_AGG", 1..) => Self::renamed_call("LISTAGG", f.args),

            // SUBSTR -> SUBSTRING
            ("SUBSTR", _) => Self::renamed_call("SUBSTRING", f.args),

            // LEN -> LENGTH
            ("LEN", 1) => {
                let operand = Self::first_arg(f.args);
                Ok(Expression::Length(Box::new(UnaryFunc::new(operand))))
            }

            // CHARINDEX / LOCATE take (substring, string); STRPOS wants
            // (string, substring). Any further arguments are dropped.
            ("CHARINDEX" | "LOCATE", 2..) => {
                Self::renamed_call("STRPOS", Self::first_two_swapped(f.args))
            }

            // INSTR -> STRPOS, arguments passed through in the same order.
            ("INSTR", 2..) => Self::renamed_call("STRPOS", f.args),

            // ARRAY_LENGTH / SIZE -> CARDINALITY
            ("ARRAY_LENGTH" | "SIZE", 1) => Self::renamed_call("CARDINALITY", f.args),

            // One-argument TO_DATE -> CAST(... AS DATE)
            ("TO_DATE", 1) => Self::plain_cast(Self::first_arg(f.args), DataType::Date),

            // One-argument TO_TIMESTAMP -> CAST(... AS TIMESTAMP)
            ("TO_TIMESTAMP", 1) => Self::plain_cast(
                Self::first_arg(f.args),
                DataType::Timestamp {
                    precision: None,
                    timezone: false,
                },
            ),

            // TO_DATE / TO_TIMESTAMP with extra arguments -> DATE_PARSE
            ("TO_DATE" | "TO_TIMESTAMP", 2..) => Self::renamed_call("DATE_PARSE", f.args),

            // strftime(format, date) -> DATE_FORMAT(date, format)
            ("STRFTIME", 2..) => {
                Self::renamed_call("DATE_FORMAT", Self::first_two_swapped(f.args))
            }

            // TO_CHAR -> DATE_FORMAT, arguments passed through in the same order.
            ("TO_CHAR", 2..) => Self::renamed_call("DATE_FORMAT", f.args),

            // GET_JSON_OBJECT -> JSON_EXTRACT_SCALAR
            ("GET_JSON_OBJECT", 2) => Self::renamed_call("JSON_EXTRACT_SCALAR", f.args),

            // COLLECT_LIST -> ARRAY_AGG
            ("COLLECT_LIST", 1..) => Self::renamed_call("ARRAY_AGG", f.args),

            // Anything else is left untouched.
            _ => Ok(Expression::Function(Box::new(f))),
        }
    }

    /// Rewrites a generic aggregate function call for Athena.
    ///
    /// Every rewrite here requires at least one argument; aggregate calls that
    /// match no arm are returned unchanged.
    fn transform_aggregate_function(
        &self,
        f: Box<crate::expressions::AggregateFunction>,
    ) -> Result<Expression> {
        let upper = f.name.to_uppercase();
        match (upper.as_str(), f.args.len()) {
            // COUNT_IF(cond) -> SUM(CASE WHEN cond THEN 1 ELSE 0 END); only the
            // first argument is used, DISTINCT and FILTER are carried over.
            ("COUNT_IF", 1..) => {
                let condition = Self::first_arg(f.args);
                let one_or_zero = Expression::Case(Box::new(Case {
                    operand: None,
                    whens: vec![(condition, Expression::number(1))],
                    else_: Some(Expression::number(0)),
                    comments: Vec::new(),
                    inferred_type: None,
                }));
                let sum = AggFunc {
                    ignore_nulls: None,
                    having_max: None,
                    this: one_or_zero,
                    distinct: f.distinct,
                    filter: f.filter,
                    order_by: Vec::new(),
                    name: None,
                    limit: None,
                    inferred_type: None,
                };
                Ok(Expression::Sum(Box::new(sum)))
            }

            // ANY_VALUE -> ARBITRARY (Trino)
            ("ANY_VALUE", 1..) => Self::renamed_call("ARBITRARY", f.args),

            // GROUP_CONCAT / STRING_AGG -> LISTAGG
            ("GROUP_CONCAT" | "STRING_AGG", 1..) => Self::renamed_call("LISTAGG", f.args),

            // Anything else is left untouched.
            _ => Ok(Expression::AggregateFunction(f)),
        }
    }

    /// Returns the cast unchanged, re-wrapped as an [`Expression::Cast`].
    ///
    /// Athena type mappings are handled in the generator, so nothing is
    /// rewritten at this stage.
    fn transform_cast(&self, c: Cast) -> Result<Expression> {
        let unchanged = Box::new(c);
        Ok(Expression::Cast(unchanged))
    }
}
417
418/// Determine if an expression should be generated using Hive engine (backticks)
419/// or Trino engine (double quotes).
420///
421/// Hive is used for:
422/// - CREATE EXTERNAL TABLE
423/// - CREATE TABLE (without AS SELECT)
424/// - CREATE SCHEMA / CREATE DATABASE
425/// - ALTER statements
426/// - DROP statements (except DROP VIEW)
427/// - DESCRIBE / SHOW statements
428///
429/// Trino is used for everything else (DML, CREATE VIEW, etc.)
430fn should_use_hive_engine(expr: &Expression) -> bool {
431    match expr {
432        // CREATE TABLE: Hive if EXTERNAL or no AS SELECT
433        Expression::CreateTable(ct) => {
434            // CREATE EXTERNAL TABLE → Hive
435            if let Some(ref modifier) = ct.table_modifier {
436                if modifier.to_uppercase() == "EXTERNAL" {
437                    return true;
438                }
439            }
440            // CREATE TABLE ... AS SELECT → Trino
441            // CREATE TABLE (without query) → Hive
442            ct.as_select.is_none()
443        }
444
445        // CREATE VIEW → Trino
446        Expression::CreateView(_) => false,
447
448        // CREATE SCHEMA / DATABASE → Hive
449        Expression::CreateSchema(_) => true,
450        Expression::CreateDatabase(_) => true,
451
452        // ALTER statements → Hive
453        Expression::AlterTable(_) => true,
454        Expression::AlterView(_) => true,
455        Expression::AlterIndex(_) => true,
456        Expression::AlterSequence(_) => true,
457
458        // DROP VIEW → Trino (because CREATE VIEW is Trino)
459        Expression::DropView(_) => false,
460
461        // Other DROP statements → Hive
462        Expression::DropTable(_) => true,
463        Expression::DropSchema(_) => true,
464        Expression::DropDatabase(_) => true,
465        Expression::DropIndex(_) => true,
466        Expression::DropFunction(_) => true,
467        Expression::DropProcedure(_) => true,
468        Expression::DropSequence(_) => true,
469
470        // DESCRIBE / SHOW → Hive
471        Expression::Describe(_) => true,
472        Expression::Show(_) => true,
473
474        // Everything else (SELECT, INSERT, UPDATE, DELETE, MERGE, etc.) → Trino
475        _ => false,
476    }
477}