use super::{DialectImpl, DialectType};
use crate::error::Result;
use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
use crate::generator::GeneratorConfig;
use crate::tokens::TokenizerConfig;
pub struct TiDBDialect;
impl DialectImpl for TiDBDialect {
fn dialect_type(&self) -> DialectType {
DialectType::TiDB
}
fn tokenizer_config(&self) -> TokenizerConfig {
let mut config = TokenizerConfig::default();
config.identifiers.insert('`', '`');
config.nested_comments = false;
config
}
fn generator_config(&self) -> GeneratorConfig {
use crate::generator::IdentifierQuoteStyle;
GeneratorConfig {
identifier_quote: '`',
identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
dialect: Some(DialectType::TiDB),
..Default::default()
}
}
fn transform_expr(&self, expr: Expression) -> Result<Expression> {
match expr {
Expression::IfNull(f) => Ok(Expression::IfNull(f)),
Expression::Nvl(f) => Ok(Expression::IfNull(f)),
Expression::TryCast(c) => Ok(Expression::Cast(c)),
Expression::SafeCast(c) => Ok(Expression::Cast(c)),
Expression::CountIf(f) => {
let case_expr = Expression::Case(Box::new(Case {
operand: None,
whens: vec![(f.this.clone(), Expression::number(1))],
else_: Some(Expression::number(0)),
comments: Vec::new(),
inferred_type: None,
}));
Ok(Expression::Sum(Box::new(AggFunc {
ignore_nulls: None,
having_max: None,
this: case_expr,
distinct: f.distinct,
filter: f.filter,
order_by: Vec::new(),
name: None,
limit: None,
inferred_type: None,
})))
}
Expression::Rand(r) => Ok(Expression::Rand(r)),
Expression::Function(f) => self.transform_function(*f),
Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
Expression::Cast(c) => self.transform_cast(*c),
_ => Ok(expr),
}
}
}
impl TiDBDialect {
fn transform_function(&self, f: Function) -> Result<Expression> {
let name_upper = f.name.to_uppercase();
match name_upper.as_str() {
"NVL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
"IFNULL".to_string(),
f.args,
)))),
"ISNULL" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
"IFNULL".to_string(),
f.args,
)))),
"COALESCE" => Ok(Expression::Coalesce(Box::new(VarArgFunc {
original_name: None,
expressions: f.args,
inferred_type: None,
}))),
"NOW" => Ok(Expression::CurrentTimestamp(
crate::expressions::CurrentTimestamp {
precision: None,
sysdate: false,
},
)),
"GETDATE" => Ok(Expression::CurrentTimestamp(
crate::expressions::CurrentTimestamp {
precision: None,
sysdate: false,
},
)),
"GROUP_CONCAT" => Ok(Expression::Function(Box::new(f))),
"STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
Function::new("GROUP_CONCAT".to_string(), f.args),
))),
"LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
"GROUP_CONCAT".to_string(),
f.args,
)))),
"SUBSTR" => Ok(Expression::Function(Box::new(f))),
"SUBSTRING" => Ok(Expression::Function(Box::new(f))),
"LENGTH" => Ok(Expression::Function(Box::new(f))),
"LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
"LENGTH".to_string(),
f.args,
)))),
"CHARINDEX" if f.args.len() >= 2 => {
let mut args = f.args;
let substring = args.remove(0);
let string = args.remove(0);
Ok(Expression::Function(Box::new(Function::new(
"INSTR".to_string(),
vec![string, substring],
))))
}
"STRPOS" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
"INSTR".to_string(),
f.args,
)))),
"LOCATE" => Ok(Expression::Function(Box::new(f))),
"INSTR" => Ok(Expression::Function(Box::new(f))),
"DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
"STRFTIME" if f.args.len() >= 2 => {
let mut args = f.args;
let format = args.remove(0);
let date = args.remove(0);
Ok(Expression::Function(Box::new(Function::new(
"DATE_FORMAT".to_string(),
vec![date, format],
))))
}
"TO_CHAR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
"DATE_FORMAT".to_string(),
f.args,
)))),
"STR_TO_DATE" => Ok(Expression::Function(Box::new(f))),
"TO_DATE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
"STR_TO_DATE".to_string(),
f.args,
)))),
"JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
"GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
Function::new("JSON_EXTRACT".to_string(), f.args),
))),
"REGEXP" => Ok(Expression::Function(Box::new(f))),
"RLIKE" => Ok(Expression::Function(Box::new(f))),
"REGEXP_LIKE" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
Function::new("REGEXP".to_string(), f.args),
))),
_ => Ok(Expression::Function(Box::new(f))),
}
}
fn transform_aggregate_function(
&self,
f: Box<crate::expressions::AggregateFunction>,
) -> Result<Expression> {
let name_upper = f.name.to_uppercase();
match name_upper.as_str() {
"COUNT_IF" if !f.args.is_empty() => {
let condition = f.args.into_iter().next().unwrap();
let case_expr = Expression::Case(Box::new(Case {
operand: None,
whens: vec![(condition, Expression::number(1))],
else_: Some(Expression::number(0)),
comments: Vec::new(),
inferred_type: None,
}));
Ok(Expression::Sum(Box::new(AggFunc {
ignore_nulls: None,
having_max: None,
this: case_expr,
distinct: f.distinct,
filter: f.filter,
order_by: Vec::new(),
name: None,
limit: None,
inferred_type: None,
})))
}
"APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
_ => Ok(Expression::AggregateFunction(f)),
}
}
fn transform_cast(&self, c: Cast) -> Result<Expression> {
Ok(Expression::Cast(Box::new(c)))
}
}