Skip to main content

polyglot_sql/dialects/
tidb.rs

1//! TiDB Dialect
2//!
3//! TiDB-specific transformations based on sqlglot patterns.
4//! TiDB is MySQL-compatible with distributed database extensions.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9#[cfg(feature = "generate")]
10use crate::generator::GeneratorConfig;
11use crate::tokens::TokenizerConfig;
12
13/// TiDB dialect (MySQL-compatible distributed database)
14pub struct TiDBDialect;
15
16impl DialectImpl for TiDBDialect {
17    fn dialect_type(&self) -> DialectType {
18        DialectType::TiDB
19    }
20
21    fn tokenizer_config(&self) -> TokenizerConfig {
22        let mut config = TokenizerConfig::default();
23        // TiDB uses backticks for identifiers (MySQL-style)
24        config.identifiers.insert('`', '`');
25        config.nested_comments = false;
26        config
27    }
28
29    #[cfg(feature = "generate")]
30
31    fn generator_config(&self) -> GeneratorConfig {
32        use crate::generator::IdentifierQuoteStyle;
33        GeneratorConfig {
34            identifier_quote: '`',
35            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
36            dialect: Some(DialectType::TiDB),
37            ..Default::default()
38        }
39    }
40
41    #[cfg(feature = "transpile")]
42
43    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
44        match expr {
45            // IFNULL is native in TiDB (MySQL-style)
46            Expression::IfNull(f) => Ok(Expression::IfNull(f)),
47
48            // NVL -> IFNULL in TiDB
49            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
50
51            // TryCast -> not directly supported, use CAST
52            Expression::TryCast(c) => Ok(Expression::Cast(c)),
53
54            // SafeCast -> CAST in TiDB
55            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
56
57            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
58            Expression::CountIf(f) => {
59                let case_expr = Expression::Case(Box::new(Case {
60                    operand: None,
61                    whens: vec![(f.this.clone(), Expression::number(1))],
62                    else_: Some(Expression::number(0)),
63                    comments: Vec::new(),
64                    inferred_type: None,
65                }));
66                Ok(Expression::Sum(Box::new(AggFunc {
67                    ignore_nulls: None,
68                    having_max: None,
69                    this: case_expr,
70                    distinct: f.distinct,
71                    filter: f.filter,
72                    order_by: Vec::new(),
73                    name: None,
74                    limit: None,
75                    inferred_type: None,
76                })))
77            }
78
79            // RAND is native in TiDB
80            Expression::Rand(r) => Ok(Expression::Rand(r)),
81
82            // Generic function transformations
83            Expression::Function(f) => self.transform_function(*f),
84
85            // Generic aggregate function transformations
86            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
87
88            // Cast transformations
89            Expression::Cast(c) => self.transform_cast(*c),
90
91            // Pass through everything else
92            _ => Ok(expr),
93        }
94    }
95}
96
97#[cfg(feature = "transpile")]
98impl TiDBDialect {
99    fn transform_function(&self, f: Function) -> Result<Expression> {
100        let name_upper = f.name.to_uppercase();
101        match name_upper.as_str() {
102            // NVL -> IFNULL
103            "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
104                "IFNULL".to_string(),
105                f.args,
106            )))),
107
108            // ISNULL -> IFNULL
109            "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
110                "IFNULL".to_string(),
111                f.args,
112            )))),
113
114            // COALESCE is native in TiDB
115            "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
116                original_name: None,
117                expressions: f.args,
118                inferred_type: None,
119            }))),
120
121            // NOW is native in TiDB
122            "NOW" => Ok(Expression::CurrentTimestamp(
123                crate::expressions::CurrentTimestamp {
124                    precision: None,
125                    sysdate: false,
126                },
127            )),
128
129            // GETDATE -> NOW
130            "GETDATE" => Ok(Expression::CurrentTimestamp(
131                crate::expressions::CurrentTimestamp {
132                    precision: None,
133                    sysdate: false,
134                },
135            )),
136
137            // GROUP_CONCAT is native in TiDB
138            "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
139
140            // STRING_AGG -> GROUP_CONCAT
141            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
142                Function::new("GROUP_CONCAT".to_string(), f.args),
143            ))),
144
145            // LISTAGG -> GROUP_CONCAT
146            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
147                "GROUP_CONCAT".to_string(),
148                f.args,
149            )))),
150
151            // SUBSTR is native in TiDB
152            "SUBSTR" => Ok(Expression::Function(Box::new(f))),
153
154            // SUBSTRING is native in TiDB
155            "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
156
157            // LENGTH is native in TiDB
158            "LENGTH" => Ok(Expression::Function(Box::new(f))),
159
160            // LEN -> LENGTH
161            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
162                "LENGTH".to_string(),
163                f.args,
164            )))),
165
166            // CHARINDEX -> INSTR (with swapped args)
167            "CHARINDEX" if f.args.len() >= 2 => {
168                let mut args = f.args;
169                let substring = args.remove(0);
170                let string = args.remove(0);
171                Ok(Expression::Function(Box::new(Function::new(
172                    "INSTR".to_string(),
173                    vec![string, substring],
174                ))))
175            }
176
177            // STRPOS -> INSTR
178            "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
179                "INSTR".to_string(),
180                f.args,
181            )))),
182
183            // LOCATE is native in TiDB
184            "LOCATE" => Ok(Expression::Function(Box::new(f))),
185
186            // INSTR is native in TiDB
187            "INSTR" => Ok(Expression::Function(Box::new(f))),
188
189            // DATE_FORMAT is native in TiDB
190            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
191
192            // strftime -> DATE_FORMAT
193            "STRFTIME" if f.args.len() >= 2 => {
194                let mut args = f.args;
195                let format = args.remove(0);
196                let date = args.remove(0);
197                Ok(Expression::Function(Box::new(Function::new(
198                    "DATE_FORMAT".to_string(),
199                    vec![date, format],
200                ))))
201            }
202
203            // TO_CHAR -> DATE_FORMAT
204            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
205                "DATE_FORMAT".to_string(),
206                f.args,
207            )))),
208
209            // STR_TO_DATE is native in TiDB
210            "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
211
212            // TO_DATE -> STR_TO_DATE
213            "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
214                "STR_TO_DATE".to_string(),
215                f.args,
216            )))),
217
218            // JSON_EXTRACT is native in TiDB
219            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
220
221            // GET_JSON_OBJECT -> JSON_EXTRACT
222            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
223                Function::new("JSON_EXTRACT".to_string(), f.args),
224            ))),
225
226            // REGEXP is native in TiDB
227            "REGEXP" => Ok(Expression::Function(Box::new(f))),
228
229            // RLIKE is native in TiDB
230            "RLIKE" => Ok(Expression::Function(Box::new(f))),
231
232            // REGEXP_LIKE -> REGEXP
233            "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
234                Function::new("REGEXP".to_string(), f.args),
235            ))),
236
237            // Pass through everything else
238            _ => Ok(Expression::Function(Box::new(f))),
239        }
240    }
241
242    fn transform_aggregate_function(
243        &self,
244        f: Box<crate::expressions::AggregateFunction>,
245    ) -> Result<Expression> {
246        let name_upper = f.name.to_uppercase();
247        match name_upper.as_str() {
248            // COUNT_IF -> SUM(CASE WHEN...)
249            "COUNT_IF" if !f.args.is_empty() => {
250                let condition = f.args.into_iter().next().unwrap();
251                let case_expr = Expression::Case(Box::new(Case {
252                    operand: None,
253                    whens: vec![(condition, Expression::number(1))],
254                    else_: Some(Expression::number(0)),
255                    comments: Vec::new(),
256                    inferred_type: None,
257                }));
258                Ok(Expression::Sum(Box::new(AggFunc {
259                    ignore_nulls: None,
260                    having_max: None,
261                    this: case_expr,
262                    distinct: f.distinct,
263                    filter: f.filter,
264                    order_by: Vec::new(),
265                    name: None,
266                    limit: None,
267                    inferred_type: None,
268                })))
269            }
270
271            // APPROX_COUNT_DISTINCT is native in TiDB
272            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
273
274            // Pass through everything else
275            _ => Ok(Expression::AggregateFunction(f)),
276        }
277    }
278
279    fn transform_cast(&self, c: Cast) -> Result<Expression> {
280        // TiDB type mappings are handled in the generator
281        Ok(Expression::Cast(Box::new(c)))
282    }
283}