1use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// SQL dialect definition for TiDB.
///
/// A field-less marker type. Its behaviour (tokenizer and generator settings,
/// and expression rewrites) comes from its `DialectImpl` implementation and the
/// inherent helper methods defined alongside it.
pub struct TiDBDialect;
14
15impl DialectImpl for TiDBDialect {
16 fn dialect_type(&self) -> DialectType {
17 DialectType::TiDB
18 }
19
20 fn tokenizer_config(&self) -> TokenizerConfig {
21 let mut config = TokenizerConfig::default();
22 config.identifiers.insert('`', '`');
24 config.nested_comments = false;
25 config
26 }
27
28 fn generator_config(&self) -> GeneratorConfig {
29 use crate::generator::IdentifierQuoteStyle;
30 GeneratorConfig {
31 identifier_quote: '`',
32 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
33 dialect: Some(DialectType::TiDB),
34 ..Default::default()
35 }
36 }
37
38 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
39 match expr {
40 Expression::IfNull(f) => Ok(Expression::IfNull(f)),
42
43 Expression::Nvl(f) => Ok(Expression::IfNull(f)),
45
46 Expression::TryCast(c) => Ok(Expression::Cast(c)),
48
49 Expression::SafeCast(c) => Ok(Expression::Cast(c)),
51
52 Expression::CountIf(f) => {
54 let case_expr = Expression::Case(Box::new(Case {
55 operand: None,
56 whens: vec![(f.this.clone(), Expression::number(1))],
57 else_: Some(Expression::number(0)),
58 comments: Vec::new(),
59 }));
60 Ok(Expression::Sum(Box::new(AggFunc {
61 ignore_nulls: None,
62 having_max: None,
63 this: case_expr,
64 distinct: f.distinct,
65 filter: f.filter,
66 order_by: Vec::new(),
67 name: None,
68 limit: None,
69 })))
70 }
71
72 Expression::Rand(r) => Ok(Expression::Rand(r)),
74
75 Expression::Function(f) => self.transform_function(*f),
77
78 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
80
81 Expression::Cast(c) => self.transform_cast(*c),
83
84 _ => Ok(expr),
86 }
87 }
88}
89
90impl TiDBDialect {
91 fn transform_function(&self, f: Function) -> Result<Expression> {
92 let name_upper = f.name.to_uppercase();
93 match name_upper.as_str() {
94 "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
96 "IFNULL".to_string(),
97 f.args,
98 )))),
99
100 "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
102 "IFNULL".to_string(),
103 f.args,
104 )))),
105
106 "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
108 original_name: None,
109 expressions: f.args,
110 }))),
111
112 "NOW" => Ok(Expression::CurrentTimestamp(
114 crate::expressions::CurrentTimestamp {
115 precision: None,
116 sysdate: false,
117 },
118 )),
119
120 "GETDATE" => Ok(Expression::CurrentTimestamp(
122 crate::expressions::CurrentTimestamp {
123 precision: None,
124 sysdate: false,
125 },
126 )),
127
128 "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
130
131 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
133 Function::new("GROUP_CONCAT".to_string(), f.args),
134 ))),
135
136 "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
138 "GROUP_CONCAT".to_string(),
139 f.args,
140 )))),
141
142 "SUBSTR" => Ok(Expression::Function(Box::new(f))),
144
145 "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
147
148 "LENGTH" => Ok(Expression::Function(Box::new(f))),
150
151 "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
153 "LENGTH".to_string(),
154 f.args,
155 )))),
156
157 "CHARINDEX" if f.args.len() >= 2 => {
159 let mut args = f.args;
160 let substring = args.remove(0);
161 let string = args.remove(0);
162 Ok(Expression::Function(Box::new(Function::new(
163 "INSTR".to_string(),
164 vec![string, substring],
165 ))))
166 }
167
168 "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
170 "INSTR".to_string(),
171 f.args,
172 )))),
173
174 "LOCATE" => Ok(Expression::Function(Box::new(f))),
176
177 "INSTR" => Ok(Expression::Function(Box::new(f))),
179
180 "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
182
183 "STRFTIME" if f.args.len() >= 2 => {
185 let mut args = f.args;
186 let format = args.remove(0);
187 let date = args.remove(0);
188 Ok(Expression::Function(Box::new(Function::new(
189 "DATE_FORMAT".to_string(),
190 vec![date, format],
191 ))))
192 }
193
194 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
196 "DATE_FORMAT".to_string(),
197 f.args,
198 )))),
199
200 "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
202
203 "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
205 "STR_TO_DATE".to_string(),
206 f.args,
207 )))),
208
209 "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
211
212 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
214 Function::new("JSON_EXTRACT".to_string(), f.args),
215 ))),
216
217 "REGEXP" => Ok(Expression::Function(Box::new(f))),
219
220 "RLIKE" => Ok(Expression::Function(Box::new(f))),
222
223 "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
225 Function::new("REGEXP".to_string(), f.args),
226 ))),
227
228 _ => Ok(Expression::Function(Box::new(f))),
230 }
231 }
232
233 fn transform_aggregate_function(
234 &self,
235 f: Box<crate::expressions::AggregateFunction>,
236 ) -> Result<Expression> {
237 let name_upper = f.name.to_uppercase();
238 match name_upper.as_str() {
239 "COUNT_IF" if !f.args.is_empty() => {
241 let condition = f.args.into_iter().next().unwrap();
242 let case_expr = Expression::Case(Box::new(Case {
243 operand: None,
244 whens: vec![(condition, Expression::number(1))],
245 else_: Some(Expression::number(0)),
246 comments: Vec::new(),
247 }));
248 Ok(Expression::Sum(Box::new(AggFunc {
249 ignore_nulls: None,
250 having_max: None,
251 this: case_expr,
252 distinct: f.distinct,
253 filter: f.filter,
254 order_by: Vec::new(),
255 name: None,
256 limit: None,
257 })))
258 }
259
260 "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
262
263 _ => Ok(Expression::AggregateFunction(f)),
265 }
266 }
267
268 fn transform_cast(&self, c: Cast) -> Result<Expression> {
269 Ok(Expression::Cast(Box::new(c)))
271 }
272}