1use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// Marker type for the TiDB SQL dialect.
///
/// It carries no state; all behavior comes from its `DialectImpl`
/// implementation and its private expression-rewriting helpers.
pub struct TiDBDialect;
14
15impl DialectImpl for TiDBDialect {
16 fn dialect_type(&self) -> DialectType {
17 DialectType::TiDB
18 }
19
20 fn tokenizer_config(&self) -> TokenizerConfig {
21 let mut config = TokenizerConfig::default();
22 config.identifiers.insert('`', '`');
24 config.nested_comments = false;
25 config
26 }
27
28 fn generator_config(&self) -> GeneratorConfig {
29 use crate::generator::IdentifierQuoteStyle;
30 GeneratorConfig {
31 identifier_quote: '`',
32 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
33 dialect: Some(DialectType::TiDB),
34 ..Default::default()
35 }
36 }
37
38 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
39 match expr {
40 Expression::IfNull(f) => Ok(Expression::IfNull(f)),
42
43 Expression::Nvl(f) => Ok(Expression::IfNull(f)),
45
46 Expression::TryCast(c) => Ok(Expression::Cast(c)),
48
49 Expression::SafeCast(c) => Ok(Expression::Cast(c)),
51
52 Expression::CountIf(f) => {
54 let case_expr = Expression::Case(Box::new(Case {
55 operand: None,
56 whens: vec![(f.this.clone(), Expression::number(1))],
57 else_: Some(Expression::number(0)),
58 comments: Vec::new(),
59 inferred_type: None,
60 }));
61 Ok(Expression::Sum(Box::new(AggFunc {
62 ignore_nulls: None,
63 having_max: None,
64 this: case_expr,
65 distinct: f.distinct,
66 filter: f.filter,
67 order_by: Vec::new(),
68 name: None,
69 limit: None,
70 inferred_type: None,
71 })))
72 }
73
74 Expression::Rand(r) => Ok(Expression::Rand(r)),
76
77 Expression::Function(f) => self.transform_function(*f),
79
80 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
82
83 Expression::Cast(c) => self.transform_cast(*c),
85
86 _ => Ok(expr),
88 }
89 }
90}
91
92impl TiDBDialect {
93 fn transform_function(&self, f: Function) -> Result<Expression> {
94 let name_upper = f.name.to_uppercase();
95 match name_upper.as_str() {
96 "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
98 "IFNULL".to_string(),
99 f.args,
100 )))),
101
102 "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
104 "IFNULL".to_string(),
105 f.args,
106 )))),
107
108 "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
110 original_name: None,
111 expressions: f.args,
112 inferred_type: None,
113 }))),
114
115 "NOW" => Ok(Expression::CurrentTimestamp(
117 crate::expressions::CurrentTimestamp {
118 precision: None,
119 sysdate: false,
120 },
121 )),
122
123 "GETDATE" => Ok(Expression::CurrentTimestamp(
125 crate::expressions::CurrentTimestamp {
126 precision: None,
127 sysdate: false,
128 },
129 )),
130
131 "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
133
134 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
136 Function::new("GROUP_CONCAT".to_string(), f.args),
137 ))),
138
139 "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
141 "GROUP_CONCAT".to_string(),
142 f.args,
143 )))),
144
145 "SUBSTR" => Ok(Expression::Function(Box::new(f))),
147
148 "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
150
151 "LENGTH" => Ok(Expression::Function(Box::new(f))),
153
154 "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
156 "LENGTH".to_string(),
157 f.args,
158 )))),
159
160 "CHARINDEX" if f.args.len() >= 2 => {
162 let mut args = f.args;
163 let substring = args.remove(0);
164 let string = args.remove(0);
165 Ok(Expression::Function(Box::new(Function::new(
166 "INSTR".to_string(),
167 vec![string, substring],
168 ))))
169 }
170
171 "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
173 "INSTR".to_string(),
174 f.args,
175 )))),
176
177 "LOCATE" => Ok(Expression::Function(Box::new(f))),
179
180 "INSTR" => Ok(Expression::Function(Box::new(f))),
182
183 "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
185
186 "STRFTIME" if f.args.len() >= 2 => {
188 let mut args = f.args;
189 let format = args.remove(0);
190 let date = args.remove(0);
191 Ok(Expression::Function(Box::new(Function::new(
192 "DATE_FORMAT".to_string(),
193 vec![date, format],
194 ))))
195 }
196
197 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
199 "DATE_FORMAT".to_string(),
200 f.args,
201 )))),
202
203 "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
205
206 "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
208 "STR_TO_DATE".to_string(),
209 f.args,
210 )))),
211
212 "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
214
215 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
217 Function::new("JSON_EXTRACT".to_string(), f.args),
218 ))),
219
220 "REGEXP" => Ok(Expression::Function(Box::new(f))),
222
223 "RLIKE" => Ok(Expression::Function(Box::new(f))),
225
226 "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
228 Function::new("REGEXP".to_string(), f.args),
229 ))),
230
231 _ => Ok(Expression::Function(Box::new(f))),
233 }
234 }
235
236 fn transform_aggregate_function(
237 &self,
238 f: Box<crate::expressions::AggregateFunction>,
239 ) -> Result<Expression> {
240 let name_upper = f.name.to_uppercase();
241 match name_upper.as_str() {
242 "COUNT_IF" if !f.args.is_empty() => {
244 let condition = f.args.into_iter().next().unwrap();
245 let case_expr = Expression::Case(Box::new(Case {
246 operand: None,
247 whens: vec![(condition, Expression::number(1))],
248 else_: Some(Expression::number(0)),
249 comments: Vec::new(),
250 inferred_type: None,
251 }));
252 Ok(Expression::Sum(Box::new(AggFunc {
253 ignore_nulls: None,
254 having_max: None,
255 this: case_expr,
256 distinct: f.distinct,
257 filter: f.filter,
258 order_by: Vec::new(),
259 name: None,
260 limit: None,
261 inferred_type: None,
262 })))
263 }
264
265 "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
267
268 _ => Ok(Expression::AggregateFunction(f)),
270 }
271 }
272
273 fn transform_cast(&self, c: Cast) -> Result<Expression> {
274 Ok(Expression::Cast(Box::new(c)))
276 }
277}