Skip to main content

polyglot_sql/dialects/
tidb.rs

//! TiDB Dialect
//!
//! TiDB-specific transformations based on sqlglot patterns.
//! TiDB is MySQL-compatible with distributed database extensions.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// TiDB dialect (MySQL-compatible distributed database).
///
/// This is a stateless marker type: all dialect behaviour lives in its
/// trait and inherent `impl` blocks. The common traits are derived so the
/// value can be logged, copied, and default-constructed by callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct TiDBDialect;
14
15impl DialectImpl for TiDBDialect {
16    fn dialect_type(&self) -> DialectType {
17        DialectType::TiDB
18    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        let mut config = TokenizerConfig::default();
22        // TiDB uses backticks for identifiers (MySQL-style)
23        config.identifiers.insert('`', '`');
24        config.nested_comments = false;
25        config
26    }
27
28    fn generator_config(&self) -> GeneratorConfig {
29        use crate::generator::IdentifierQuoteStyle;
30        GeneratorConfig {
31            identifier_quote: '`',
32            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
33            dialect: Some(DialectType::TiDB),
34            ..Default::default()
35        }
36    }
37
38    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
39        match expr {
40            // IFNULL is native in TiDB (MySQL-style)
41            Expression::IfNull(f) => Ok(Expression::IfNull(f)),
42
43            // NVL -> IFNULL in TiDB
44            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
45
46            // TryCast -> not directly supported, use CAST
47            Expression::TryCast(c) => Ok(Expression::Cast(c)),
48
49            // SafeCast -> CAST in TiDB
50            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
51
52            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
53            Expression::CountIf(f) => {
54                let case_expr = Expression::Case(Box::new(Case {
55                    operand: None,
56                    whens: vec![(f.this.clone(), Expression::number(1))],
57                    else_: Some(Expression::number(0)),
58                    comments: Vec::new(),
59                    inferred_type: None,
60                }));
61                Ok(Expression::Sum(Box::new(AggFunc {
62                    ignore_nulls: None,
63                    having_max: None,
64                    this: case_expr,
65                    distinct: f.distinct,
66                    filter: f.filter,
67                    order_by: Vec::new(),
68                    name: None,
69                    limit: None,
70                    inferred_type: None,
71                })))
72            }
73
74            // RAND is native in TiDB
75            Expression::Rand(r) => Ok(Expression::Rand(r)),
76
77            // Generic function transformations
78            Expression::Function(f) => self.transform_function(*f),
79
80            // Generic aggregate function transformations
81            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
82
83            // Cast transformations
84            Expression::Cast(c) => self.transform_cast(*c),
85
86            // Pass through everything else
87            _ => Ok(expr),
88        }
89    }
90}
91
92impl TiDBDialect {
93    fn transform_function(&self, f: Function) -> Result<Expression> {
94        let name_upper = f.name.to_uppercase();
95        match name_upper.as_str() {
96            // NVL -> IFNULL
97            "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
98                "IFNULL".to_string(),
99                f.args,
100            )))),
101
102            // ISNULL -> IFNULL
103            "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
104                "IFNULL".to_string(),
105                f.args,
106            )))),
107
108            // COALESCE is native in TiDB
109            "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
110                original_name: None,
111                expressions: f.args,
112                inferred_type: None,
113            }))),
114
115            // NOW is native in TiDB
116            "NOW" => Ok(Expression::CurrentTimestamp(
117                crate::expressions::CurrentTimestamp {
118                    precision: None,
119                    sysdate: false,
120                },
121            )),
122
123            // GETDATE -> NOW
124            "GETDATE" => Ok(Expression::CurrentTimestamp(
125                crate::expressions::CurrentTimestamp {
126                    precision: None,
127                    sysdate: false,
128                },
129            )),
130
131            // GROUP_CONCAT is native in TiDB
132            "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
133
134            // STRING_AGG -> GROUP_CONCAT
135            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
136                Function::new("GROUP_CONCAT".to_string(), f.args),
137            ))),
138
139            // LISTAGG -> GROUP_CONCAT
140            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
141                "GROUP_CONCAT".to_string(),
142                f.args,
143            )))),
144
145            // SUBSTR is native in TiDB
146            "SUBSTR" => Ok(Expression::Function(Box::new(f))),
147
148            // SUBSTRING is native in TiDB
149            "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
150
151            // LENGTH is native in TiDB
152            "LENGTH" => Ok(Expression::Function(Box::new(f))),
153
154            // LEN -> LENGTH
155            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
156                "LENGTH".to_string(),
157                f.args,
158            )))),
159
160            // CHARINDEX -> INSTR (with swapped args)
161            "CHARINDEX" if f.args.len() >= 2 => {
162                let mut args = f.args;
163                let substring = args.remove(0);
164                let string = args.remove(0);
165                Ok(Expression::Function(Box::new(Function::new(
166                    "INSTR".to_string(),
167                    vec![string, substring],
168                ))))
169            }
170
171            // STRPOS -> INSTR
172            "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
173                "INSTR".to_string(),
174                f.args,
175            )))),
176
177            // LOCATE is native in TiDB
178            "LOCATE" => Ok(Expression::Function(Box::new(f))),
179
180            // INSTR is native in TiDB
181            "INSTR" => Ok(Expression::Function(Box::new(f))),
182
183            // DATE_FORMAT is native in TiDB
184            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
185
186            // strftime -> DATE_FORMAT
187            "STRFTIME" if f.args.len() >= 2 => {
188                let mut args = f.args;
189                let format = args.remove(0);
190                let date = args.remove(0);
191                Ok(Expression::Function(Box::new(Function::new(
192                    "DATE_FORMAT".to_string(),
193                    vec![date, format],
194                ))))
195            }
196
197            // TO_CHAR -> DATE_FORMAT
198            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
199                "DATE_FORMAT".to_string(),
200                f.args,
201            )))),
202
203            // STR_TO_DATE is native in TiDB
204            "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
205
206            // TO_DATE -> STR_TO_DATE
207            "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
208                "STR_TO_DATE".to_string(),
209                f.args,
210            )))),
211
212            // JSON_EXTRACT is native in TiDB
213            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
214
215            // GET_JSON_OBJECT -> JSON_EXTRACT
216            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
217                Function::new("JSON_EXTRACT".to_string(), f.args),
218            ))),
219
220            // REGEXP is native in TiDB
221            "REGEXP" => Ok(Expression::Function(Box::new(f))),
222
223            // RLIKE is native in TiDB
224            "RLIKE" => Ok(Expression::Function(Box::new(f))),
225
226            // REGEXP_LIKE -> REGEXP
227            "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
228                Function::new("REGEXP".to_string(), f.args),
229            ))),
230
231            // Pass through everything else
232            _ => Ok(Expression::Function(Box::new(f))),
233        }
234    }
235
236    fn transform_aggregate_function(
237        &self,
238        f: Box<crate::expressions::AggregateFunction>,
239    ) -> Result<Expression> {
240        let name_upper = f.name.to_uppercase();
241        match name_upper.as_str() {
242            // COUNT_IF -> SUM(CASE WHEN...)
243            "COUNT_IF" if !f.args.is_empty() => {
244                let condition = f.args.into_iter().next().unwrap();
245                let case_expr = Expression::Case(Box::new(Case {
246                    operand: None,
247                    whens: vec![(condition, Expression::number(1))],
248                    else_: Some(Expression::number(0)),
249                    comments: Vec::new(),
250                    inferred_type: None,
251                }));
252                Ok(Expression::Sum(Box::new(AggFunc {
253                    ignore_nulls: None,
254                    having_max: None,
255                    this: case_expr,
256                    distinct: f.distinct,
257                    filter: f.filter,
258                    order_by: Vec::new(),
259                    name: None,
260                    limit: None,
261                    inferred_type: None,
262                })))
263            }
264
265            // APPROX_COUNT_DISTINCT is native in TiDB
266            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
267
268            // Pass through everything else
269            _ => Ok(Expression::AggregateFunction(f)),
270        }
271    }
272
273    fn transform_cast(&self, c: Cast) -> Result<Expression> {
274        // TiDB type mappings are handled in the generator
275        Ok(Expression::Cast(Box::new(c)))
276    }
277}