Skip to main content

polyglot_sql/dialects/
tidb.rs

//! TiDB Dialect
//!
//! TiDB-specific transformations based on sqlglot patterns.
//! TiDB is MySQL-compatible, with extensions for distributed deployments.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
12/// TiDB dialect (MySQL-compatible distributed database)
13pub struct TiDBDialect;
14
15impl DialectImpl for TiDBDialect {
16    fn dialect_type(&self) -> DialectType {
17        DialectType::TiDB
18    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        let mut config = TokenizerConfig::default();
22        // TiDB uses backticks for identifiers (MySQL-style)
23        config.identifiers.insert('`', '`');
24        config.nested_comments = false;
25        config
26    }
27
28    fn generator_config(&self) -> GeneratorConfig {
29        use crate::generator::IdentifierQuoteStyle;
30        GeneratorConfig {
31            identifier_quote: '`',
32            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
33            dialect: Some(DialectType::TiDB),
34            ..Default::default()
35        }
36    }
37
38    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
39        match expr {
40            // IFNULL is native in TiDB (MySQL-style)
41            Expression::IfNull(f) => Ok(Expression::IfNull(f)),
42
43            // NVL -> IFNULL in TiDB
44            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
45
46            // TryCast -> not directly supported, use CAST
47            Expression::TryCast(c) => Ok(Expression::Cast(c)),
48
49            // SafeCast -> CAST in TiDB
50            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
51
52            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
53            Expression::CountIf(f) => {
54                let case_expr = Expression::Case(Box::new(Case {
55                    operand: None,
56                    whens: vec![(f.this.clone(), Expression::number(1))],
57                    else_: Some(Expression::number(0)),
58                }));
59                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
60                    this: case_expr,
61                    distinct: f.distinct,
62                    filter: f.filter,
63                    order_by: Vec::new(),
64                    name: None,
65                    limit: None,
66                })))
67            }
68
69            // RAND is native in TiDB
70            Expression::Rand(r) => Ok(Expression::Rand(r)),
71
72            // Generic function transformations
73            Expression::Function(f) => self.transform_function(*f),
74
75            // Generic aggregate function transformations
76            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
77
78            // Cast transformations
79            Expression::Cast(c) => self.transform_cast(*c),
80
81            // Pass through everything else
82            _ => Ok(expr),
83        }
84    }
85}
86
87impl TiDBDialect {
88    fn transform_function(&self, f: Function) -> Result<Expression> {
89        let name_upper = f.name.to_uppercase();
90        match name_upper.as_str() {
91            // NVL -> IFNULL
92            "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
93                "IFNULL".to_string(),
94                f.args,
95            )))),
96
97            // ISNULL -> IFNULL
98            "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
99                "IFNULL".to_string(),
100                f.args,
101            )))),
102
103            // COALESCE is native in TiDB
104            "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
105                expressions: f.args,
106            }))),
107
108            // NOW is native in TiDB
109            "NOW" => Ok(Expression::CurrentTimestamp(
110                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
111            )),
112
113            // GETDATE -> NOW
114            "GETDATE" => Ok(Expression::CurrentTimestamp(
115                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
116            )),
117
118            // GROUP_CONCAT is native in TiDB
119            "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
120
121            // STRING_AGG -> GROUP_CONCAT
122            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
123                "GROUP_CONCAT".to_string(),
124                f.args,
125            )))),
126
127            // LISTAGG -> GROUP_CONCAT
128            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
129                "GROUP_CONCAT".to_string(),
130                f.args,
131            )))),
132
133            // SUBSTR is native in TiDB
134            "SUBSTR" => Ok(Expression::Function(Box::new(f))),
135
136            // SUBSTRING is native in TiDB
137            "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
138
139            // LENGTH is native in TiDB
140            "LENGTH" => Ok(Expression::Function(Box::new(f))),
141
142            // LEN -> LENGTH
143            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
144                "LENGTH".to_string(),
145                f.args,
146            )))),
147
148            // CHARINDEX -> INSTR (with swapped args)
149            "CHARINDEX" if f.args.len() >= 2 => {
150                let mut args = f.args;
151                let substring = args.remove(0);
152                let string = args.remove(0);
153                Ok(Expression::Function(Box::new(Function::new(
154                    "INSTR".to_string(),
155                    vec![string, substring],
156                ))))
157            }
158
159            // STRPOS -> INSTR
160            "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
161                "INSTR".to_string(),
162                f.args,
163            )))),
164
165            // LOCATE is native in TiDB
166            "LOCATE" => Ok(Expression::Function(Box::new(f))),
167
168            // INSTR is native in TiDB
169            "INSTR" => Ok(Expression::Function(Box::new(f))),
170
171            // DATE_FORMAT is native in TiDB
172            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
173
174            // strftime -> DATE_FORMAT
175            "STRFTIME" if f.args.len() >= 2 => {
176                let mut args = f.args;
177                let format = args.remove(0);
178                let date = args.remove(0);
179                Ok(Expression::Function(Box::new(Function::new(
180                    "DATE_FORMAT".to_string(),
181                    vec![date, format],
182                ))))
183            }
184
185            // TO_CHAR -> DATE_FORMAT
186            "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
187                "DATE_FORMAT".to_string(),
188                f.args,
189            )))),
190
191            // STR_TO_DATE is native in TiDB
192            "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
193
194            // TO_DATE -> STR_TO_DATE
195            "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
196                "STR_TO_DATE".to_string(),
197                f.args,
198            )))),
199
200            // JSON_EXTRACT is native in TiDB
201            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
202
203            // GET_JSON_OBJECT -> JSON_EXTRACT
204            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
205                Function::new("JSON_EXTRACT".to_string(), f.args),
206            ))),
207
208            // REGEXP is native in TiDB
209            "REGEXP" => Ok(Expression::Function(Box::new(f))),
210
211            // RLIKE is native in TiDB
212            "RLIKE" => Ok(Expression::Function(Box::new(f))),
213
214            // REGEXP_LIKE -> REGEXP
215            "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
216                "REGEXP".to_string(),
217                f.args,
218            )))),
219
220            // Pass through everything else
221            _ => Ok(Expression::Function(Box::new(f))),
222        }
223    }
224
225    fn transform_aggregate_function(
226        &self,
227        f: Box<crate::expressions::AggregateFunction>,
228    ) -> Result<Expression> {
229        let name_upper = f.name.to_uppercase();
230        match name_upper.as_str() {
231            // COUNT_IF -> SUM(CASE WHEN...)
232            "COUNT_IF" if !f.args.is_empty() => {
233                let condition = f.args.into_iter().next().unwrap();
234                let case_expr = Expression::Case(Box::new(Case {
235                    operand: None,
236                    whens: vec![(condition, Expression::number(1))],
237                    else_: Some(Expression::number(0)),
238                }));
239                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
240                    this: case_expr,
241                    distinct: f.distinct,
242                    filter: f.filter,
243                    order_by: Vec::new(),
244                    name: None,
245                    limit: None,
246                })))
247            }
248
249            // APPROX_COUNT_DISTINCT is native in TiDB
250            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
251
252            // Pass through everything else
253            _ => Ok(Expression::AggregateFunction(f)),
254        }
255    }
256
257    fn transform_cast(&self, c: Cast) -> Result<Expression> {
258        // TiDB type mappings are handled in the generator
259        Ok(Expression::Cast(Box::new(c)))
260    }
261}