1use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// Marker type for the TiDB SQL dialect.
///
/// The type carries no state; all behavior lives in its [`DialectImpl`]
/// implementation and the private rewrite helpers on the type itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct TiDBDialect;
14
15impl DialectImpl for TiDBDialect {
16 fn dialect_type(&self) -> DialectType {
17 DialectType::TiDB
18 }
19
20 fn tokenizer_config(&self) -> TokenizerConfig {
21 let mut config = TokenizerConfig::default();
22 config.identifiers.insert('`', '`');
24 config.nested_comments = false;
25 config
26 }
27
28 fn generator_config(&self) -> GeneratorConfig {
29 use crate::generator::IdentifierQuoteStyle;
30 GeneratorConfig {
31 identifier_quote: '`',
32 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
33 dialect: Some(DialectType::TiDB),
34 ..Default::default()
35 }
36 }
37
38 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
39 match expr {
40 Expression::IfNull(f) => Ok(Expression::IfNull(f)),
42
43 Expression::Nvl(f) => Ok(Expression::IfNull(f)),
45
46 Expression::TryCast(c) => Ok(Expression::Cast(c)),
48
49 Expression::SafeCast(c) => Ok(Expression::Cast(c)),
51
52 Expression::CountIf(f) => {
54 let case_expr = Expression::Case(Box::new(Case {
55 operand: None,
56 whens: vec![(f.this.clone(), Expression::number(1))],
57 else_: Some(Expression::number(0)),
58 }));
59 Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
60 this: case_expr,
61 distinct: f.distinct,
62 filter: f.filter,
63 order_by: Vec::new(),
64 name: None,
65 limit: None,
66 })))
67 }
68
69 Expression::Rand(r) => Ok(Expression::Rand(r)),
71
72 Expression::Function(f) => self.transform_function(*f),
74
75 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
77
78 Expression::Cast(c) => self.transform_cast(*c),
80
81 _ => Ok(expr),
83 }
84 }
85}
86
87impl TiDBDialect {
88 fn transform_function(&self, f: Function) -> Result<Expression> {
89 let name_upper = f.name.to_uppercase();
90 match name_upper.as_str() {
91 "NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
93 "IFNULL".to_string(),
94 f.args,
95 )))),
96
97 "ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
99 "IFNULL".to_string(),
100 f.args,
101 )))),
102
103 "COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
105 expressions: f.args,
106 }))),
107
108 "NOW" => Ok(Expression::CurrentTimestamp(
110 crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
111 )),
112
113 "GETDATE" => Ok(Expression::CurrentTimestamp(
115 crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
116 )),
117
118 "GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
120
121 "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
123 "GROUP_CONCAT".to_string(),
124 f.args,
125 )))),
126
127 "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
129 "GROUP_CONCAT".to_string(),
130 f.args,
131 )))),
132
133 "SUBSTR" => Ok(Expression::Function(Box::new(f))),
135
136 "SUBSTRING" => Ok(Expression::Function(Box::new(f))),
138
139 "LENGTH" => Ok(Expression::Function(Box::new(f))),
141
142 "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
144 "LENGTH".to_string(),
145 f.args,
146 )))),
147
148 "CHARINDEX" if f.args.len() >= 2 => {
150 let mut args = f.args;
151 let substring = args.remove(0);
152 let string = args.remove(0);
153 Ok(Expression::Function(Box::new(Function::new(
154 "INSTR".to_string(),
155 vec![string, substring],
156 ))))
157 }
158
159 "STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
161 "INSTR".to_string(),
162 f.args,
163 )))),
164
165 "LOCATE" => Ok(Expression::Function(Box::new(f))),
167
168 "INSTR" => Ok(Expression::Function(Box::new(f))),
170
171 "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
173
174 "STRFTIME" if f.args.len() >= 2 => {
176 let mut args = f.args;
177 let format = args.remove(0);
178 let date = args.remove(0);
179 Ok(Expression::Function(Box::new(Function::new(
180 "DATE_FORMAT".to_string(),
181 vec![date, format],
182 ))))
183 }
184
185 "TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
187 "DATE_FORMAT".to_string(),
188 f.args,
189 )))),
190
191 "STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
193
194 "TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
196 "STR_TO_DATE".to_string(),
197 f.args,
198 )))),
199
200 "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
202
203 "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
205 Function::new("JSON_EXTRACT".to_string(), f.args),
206 ))),
207
208 "REGEXP" => Ok(Expression::Function(Box::new(f))),
210
211 "RLIKE" => Ok(Expression::Function(Box::new(f))),
213
214 "REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
216 "REGEXP".to_string(),
217 f.args,
218 )))),
219
220 _ => Ok(Expression::Function(Box::new(f))),
222 }
223 }
224
225 fn transform_aggregate_function(
226 &self,
227 f: Box<crate::expressions::AggregateFunction>,
228 ) -> Result<Expression> {
229 let name_upper = f.name.to_uppercase();
230 match name_upper.as_str() {
231 "COUNT_IF" if !f.args.is_empty() => {
233 let condition = f.args.into_iter().next().unwrap();
234 let case_expr = Expression::Case(Box::new(Case {
235 operand: None,
236 whens: vec![(condition, Expression::number(1))],
237 else_: Some(Expression::number(0)),
238 }));
239 Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
240 this: case_expr,
241 distinct: f.distinct,
242 filter: f.filter,
243 order_by: Vec::new(),
244 name: None,
245 limit: None,
246 })))
247 }
248
249 "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
251
252 _ => Ok(Expression::AggregateFunction(f)),
254 }
255 }
256
257 fn transform_cast(&self, c: Cast) -> Result<Expression> {
258 Ok(Expression::Cast(Box::new(c)))
260 }
261}