use sqlparser::{
ast::{
Expr, LimitClause, ObjectName, ObjectType,
Statement, ValueWithSpan,OffsetRows, Value as AstValue,Offset, Query, SetExpr
},
dialect::{MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect,GenericDialect},
tokenizer::{Location as TokenizeLocation, Span},
parser::{Parser, ParserError}
};
use crate::common::DatabaseKind;
/// Classification of a single SQL statement, as produced by `ParseSQL::parse`.
#[derive(Debug, Clone, Default)]
pub struct ParseSQL {
    /// `true` when executing the statement yields rows (queries, `… RETURNING`,
    /// `SHOW …`, `EXPLAIN`, result-returning `PRAGMA`s, …).
    pub returns_result: bool,
    /// Set for transaction control, some DDL and `DROP DATABASE`/`DROP TABLE`;
    /// presumably tells the caller not to wrap the statement in a transaction.
    pub forbidden_in_txn: bool,
    /// For plain `SELECT`s, the first table of the `FROM` clause (with its quoting).
    pub table_name: Option<String>,
    /// Coarse category: `"query"`, `"insert"`, `"update"`, `"delete"`, `"ddl"`,
    /// `"other"`, or `""` when the statement was not classified.
    pub kind: String,
    /// Effective row offset of a paginated query.
    pub offset: Option<u64>,
    /// Effective row limit of a paginated query.
    pub limit: Option<u64>,
    /// The query re-rendered with its enforced row limit.
    pub update_sql: Option<String>,
}
impl ParseSQL {
pub fn other_new(
returns_result: bool,
forbidden_in_txn: bool,
table_name: Option<&str>,
kind: &str,
) -> Self {
Self {
returns_result,
forbidden_in_txn,
table_name: table_name.filter(|s| !s.is_empty()).map(|s| s.to_string()),
kind: kind.to_string(),
offset: None,
limit: None,
update_sql: None,
}
}
pub fn select_new(
update_sql: Option<String>,
table_name: Option<String>,
limit: Option<u64>,
offset: Option<u64>,
) -> Self {
Self {
returns_result: true,
forbidden_in_txn: false,
table_name,
kind: "query".to_string(),
offset,
limit,
update_sql,
}
}
/// Builds a `SELECT COUNT(*)` query that reports how many rows `sql` would return
/// without pagination.
///
/// Returns `Ok(None)` when `sql` is not a query, or when it has no pagination clause
/// (neither `LIMIT`/`OFFSET` nor `FETCH`).
///
/// Otherwise the query is wrapped as a derived table after removing:
/// * the `LIMIT`/`OFFSET` clause and any `FETCH` clause. If `FETCH` were left in
///   place, the count would be capped, and SQL Server rejects `FETCH` without `OFFSET`.
/// * `ORDER BY`. Once paging is gone it cannot change the row count, and SQL Server
///   rejects `ORDER BY` in a derived table without `OFFSET`/`TOP`.
///
/// # Errors
/// Propagates parse errors from `ParseSQL::to_statement`.
pub fn to_count_sql(sql: &str, kind: DatabaseKind) -> anyhow::Result<Option<String>> {
    let mut stmt = Self::to_statement(sql, &kind)?;
    let Statement::Query(query) = &mut stmt else {
        return Ok(None);
    };
    if query.limit_clause.is_none() && query.fetch.is_none() {
        return Ok(None);
    }
    query.limit_clause = None;
    query.fetch = None;
    query.order_by = None;
    Ok(Some(format!("SELECT COUNT(*) FROM ({stmt}) t")))
}
/// Parses the first SQL statement in `sql` with the dialect matching `kind`.
///
/// SQL Server and Access both use `MsSqlDialect`, so bracketed identifiers such as
/// `[dbo].[t]` parse. `split_sql` already tokenizes SQL Server input with that
/// dialect. Kinds without a dedicated dialect fall back to `GenericDialect`.
///
/// `Parser::parse_statement` stops after one statement, so anything that follows
/// it is ignored rather than rejected. The whole input is still tokenized first.
///
/// # Errors
/// Returns the tokenizer or parser error if `sql` is not valid in the chosen dialect.
pub fn to_statement(sql: &str, kind: &DatabaseKind) -> Result<Statement, ParserError> {
    let mut parser = match kind {
        DatabaseKind::Postgres => Parser::new(&PostgreSqlDialect {}),
        DatabaseKind::Sqlite => Parser::new(&SQLiteDialect {}),
        DatabaseKind::MySql => Parser::new(&MySqlDialect {}),
        DatabaseKind::SqlServer | DatabaseKind::Access => Parser::new(&MsSqlDialect {}),
        // Includes `DatabaseKind::Api`.
        _ => Parser::new(&GenericDialect {}),
    };
    parser = parser.try_with_sql(sql)?;
    parser.parse_statement()
}
/// Renders an object name (`table`, `schema.table`, …) as SQL text, re-applying
/// each part's original quoting.
///
/// Quoting rules:
/// * Bracket-quoted parts are closed with `]`, so `[dbo]` is rendered as `[dbo]`
///   rather than `[dbo[`.
/// * A quote character inside a quoted part is doubled, so the result can be fed
///   back into SQL.
/// * Parts that are not plain identifiers are skipped.
pub fn full_table_name(object_name: &ObjectName) -> String {
    object_name
        .0
        .iter()
        .filter_map(|part| part.as_ident())
        .map(|ident| match ident.quote_style {
            // The closing delimiter of bracket quoting differs from the opening one.
            Some('[') => format!("[{}]", ident.value.replace(']', "]]")),
            Some(q) => {
                let doubled = q.to_string().repeat(2);
                format!("{q}{}{q}", ident.value.replace(q, &doubled))
            }
            None => ident.value.clone(),
        })
        .collect::<Vec<_>>()
        .join(".")
}
pub fn modify_limit(query: &mut Box<Query>) -> (u64, u64) {
let default_limit = 1000;
let default_offset = 0;
match &mut query.limit_clause {
Some(LimitClause::LimitOffset { limit, offset, .. }) => {
if limit.is_none() {
*limit = Some(Expr::Value(ValueWithSpan {
value: AstValue::Number(default_limit.to_string(), false),
span: Span {
start: TokenizeLocation { line: 0, column: 0 },
end: TokenizeLocation { line: 0, column: 0 },
},
}));
}
let offset_val = offset
.as_ref()
.and_then(|o| {
if let Expr::Value(ValueWithSpan {
value: AstValue::Number(n, _),
..
}) = &o.value
{
n.parse::<u64>().ok()
} else {
None
}
})
.unwrap_or(default_offset);
let limit_val = limit
.as_ref()
.and_then(|l| {
if let Expr::Value(ValueWithSpan {
value: AstValue::Number(n, _),
..
}) = l
{
n.parse::<u64>().ok()
} else {
None
}
})
.unwrap_or(default_limit);
(limit_val, offset_val)
}
Some(LimitClause::OffsetCommaLimit { offset, limit }) => {
let offset_val = if let Expr::Value(ValueWithSpan {
value: AstValue::Number(n, _),
..
}) = offset
{
n.parse::<u64>().unwrap_or(default_offset)
} else {
default_offset
};
let limit_val = if let Expr::Value(ValueWithSpan {
value: AstValue::Number(n, _),
..
}) = limit
{
n.parse::<u64>().unwrap_or(default_limit)
} else {
default_limit
};
(limit_val, offset_val)
}
None => {
query.limit_clause = Some(LimitClause::LimitOffset {
limit: Some(Expr::Value(ValueWithSpan {
value: AstValue::Number(default_limit.to_string(), false),
span: Span {
start: TokenizeLocation { line: 0, column: 0 },
end: TokenizeLocation { line: 0, column: 0 },
},
})),
offset: Some(Offset {
value: Expr::Value(ValueWithSpan {
value: AstValue::Number(default_offset.to_string(), false),
span: Span {
start: TokenizeLocation { line: 0, column: 0 },
end: TokenizeLocation { line: 0, column: 0 },
},
}),
rows: OffsetRows::None,
}),
limit_by: vec![],
});
(default_limit, default_offset)
}
}
}
pub fn parse(sql: &str, kind: DatabaseKind) -> anyhow::Result<Self> {
let mut stmt = Self::to_statement(sql, &kind)?;
match &mut stmt {
Statement::Query(query) => {
let mut table_name = None;
match &mut *query.body {
SetExpr::Select(select) => {
for table_with_joins in &select.from {
if let sqlparser::ast::TableFactor::Table { name, .. } =
&table_with_joins.relation
{
table_name = Some(Self::full_table_name(name));
break;
}
}
let (limit, offset) = Self::modify_limit(query); return Ok(Self::select_new(
Some(stmt.to_string()),
table_name,
Some(limit),
Some(offset),
));
}
SetExpr::SetOperation { .. } => {
let (limit, offset) = Self::modify_limit(query); return Ok(Self::select_new(
Some(stmt.to_string()),
table_name,
Some(limit),
Some(offset),
));
}
_ => {}
}
Ok(Self::select_new(None, table_name, None, None))
}
Statement::Delete(d) if d.returning.is_some() => {
Ok(Self::other_new(true, false, None, "delete"))
}
Statement::Insert(i) if i.returning.is_some() => {
Ok(Self::other_new(true, false, None, "insert"))
}
Statement::Update {
returning: Some(_), ..
} => Ok(Self::other_new(true, false, None, "update")),
Statement::ShowDatabases { .. }
| Statement::ShowSchemas { .. }
| Statement::ShowObjects { .. }
| Statement::ShowVariable { .. }
| Statement::ShowVariables { .. }
| Statement::ShowStatus { .. }
| Statement::ShowTables { .. }
| Statement::ShowColumns { .. }
| Statement::ShowViews { .. }
| Statement::ShowCollation { .. }
| Statement::ShowCreate { .. }
| Statement::Explain { .. }
| Statement::LISTEN { .. }
| Statement::List { .. }
| Statement::Fetch { .. }
| Statement::Execute { .. }
| Statement::Call { .. }
| Statement::Merge {
output: Some(_), ..
} => Ok(Self::other_new(true, false, None, "other")),
Statement::StartTransaction { .. }
| Statement::Commit { .. }
| Statement::Rollback { .. }
| Statement::CreateDatabase { .. }
| Statement::CreateSchema { .. }
| Statement::AlterType { .. }
| Statement::CreateIndex { .. }
| Statement::Flush { .. } => Ok(Self::other_new(false, true, None, "ddl")),
Statement::Drop { object_type, .. } => {
let forbidden_in_txn =
matches!(object_type, ObjectType::Database | ObjectType::Table);
Ok(Self::other_new(false, forbidden_in_txn, None, "other"))
}
Statement::Pragma {
value: _,
name,
is_eq,
} => {
if *is_eq {
Ok(Self::other_new(false, false, None, "other"))
} else {
let is_res = matches!(
name.to_string().to_lowercase().as_str(),
"table_info"
| "index_list"
| "index_info"
| "page_size"
| "encoding"
| "journal_mode"
| "foreign_key_list"
| "database_list"
| "collation_list"
| "pragma_list"
| "table_xinfo"
);
Ok(Self::other_new(is_res, false, None, "other"))
}
}
_ => Ok(Self::other_new(false, false, None, "")),
}
}
/// Splits a script into individual statements at top-level `;` tokens.
///
/// The input is tokenized with the dialect matching `kind`; SQL Server and Access
/// both use the MS SQL dialect, so `[bracketed]` identifiers are single tokens.
/// Semicolons inside string literals, quoted identifiers and comments therefore do
/// not split.
///
/// Each statement is rebuilt from its tokens' textual form and trimmed.
/// Whitespace-only pieces are dropped.
///
/// Limitation: bodies that contain their own `;` (e.g. SQLite
/// `CREATE TRIGGER … BEGIN … END`) are split at those inner semicolons.
///
/// # Errors
/// Returns the tokenizer error if `sql` cannot be tokenized.
pub fn split_sql(sql: &str, kind: &DatabaseKind) -> anyhow::Result<Vec<String>> {
    use sqlparser::tokenizer::{Token, Tokenizer};
    use std::fmt::Write as _;

    let mut tokenizer = match kind {
        DatabaseKind::Postgres => Tokenizer::new(&PostgreSqlDialect {}, sql),
        DatabaseKind::MySql => Tokenizer::new(&MySqlDialect {}, sql),
        DatabaseKind::SqlServer | DatabaseKind::Access => {
            Tokenizer::new(&MsSqlDialect {}, sql)
        }
        DatabaseKind::Sqlite => Tokenizer::new(&SQLiteDialect {}, sql),
        _ => Tokenizer::new(&GenericDialect {}, sql),
    };
    let tokens = tokenizer.tokenize()?;

    let mut stmts = Vec::new();
    let mut current = String::new();
    for token in tokens {
        match token {
            Token::SemiColon => {
                let stmt = current.trim();
                if !stmt.is_empty() {
                    stmts.push(stmt.to_owned());
                }
                current.clear();
            }
            other => {
                // Formatting straight into the buffer avoids a String per token;
                // writing into a `String` cannot fail.
                let _ = write!(current, "{other}");
            }
        }
    }
    let tail = current.trim();
    if !tail.is_empty() {
        stmts.push(tail.to_owned());
    }
    Ok(stmts)
}
}
/// Checks that `ParseSQL::parse` accepts and classifies a representative set of
/// Postgres statements.
#[test]
fn test_parse_sql() {
    let samples: &[&str] = &[
        "SELECT id, name FROM users",
        "SELECT id, name FROM users WHERE age > 18 ORDER BY name ASC",
        "SELECT * FROM products LIMIT 50 OFFSET 10",
        "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT user_id FROM orders WHERE total > 100)",
        "SELECT id FROM users WHERE active = true UNION SELECT id FROM admins",
        "INSERT INTO users (name, age) VALUES ('Alice', 30) RETURNING id",
        "UPDATE products SET price = price * 1.1 WHERE category = 'books' RETURNING id, price",
        "DELETE FROM sessions WHERE expires < NOW() RETURNING id",
        "CREATE TABLE logs (id SERIAL PRIMARY KEY, message TEXT, created_at TIMESTAMP)",
        "SHOW TABLES",
    ];
    for &sql in samples {
        if let Err(e) = ParseSQL::parse(sql, DatabaseKind::Postgres) {
            panic!("failed to parse {sql:?}: {e}");
        }
    }

    let parse = |sql: &str| ParseSQL::parse(sql, DatabaseKind::Postgres).unwrap();

    // Unpaginated SELECT: the default limit is injected into the SQL.
    let plain = parse("SELECT id, name FROM users");
    assert!(plain.returns_result);
    assert_eq!(plain.kind, "query");
    assert_eq!(plain.table_name.as_deref(), Some("users"));
    assert_eq!((plain.limit, plain.offset), (Some(1000), Some(0)));
    assert!(plain
        .update_sql
        .as_deref()
        .is_some_and(|s| s.contains("LIMIT 1000")));

    // Explicit LIMIT/OFFSET values are reported as written.
    let paged = parse("SELECT * FROM products LIMIT 50 OFFSET 10");
    assert_eq!(paged.table_name.as_deref(), Some("products"));
    assert_eq!((paged.limit, paged.offset), (Some(50), Some(10)));

    // The outer FROM table wins over tables in subqueries.
    let nested = parse(
        "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT user_id FROM orders WHERE total > 100)",
    );
    assert_eq!(nested.table_name.as_deref(), Some("users"));

    // Set operations are paginated but have no single table.
    let union = parse("SELECT id FROM users WHERE active = true UNION SELECT id FROM admins");
    assert_eq!(union.table_name, None);
    assert_eq!(union.limit, Some(1000));

    // DML with RETURNING produces rows.
    let insert = parse("INSERT INTO users (name, age) VALUES ('Alice', 30) RETURNING id");
    assert!(insert.returns_result);
    assert_eq!(insert.kind, "insert");
    let update = parse(
        "UPDATE products SET price = price * 1.1 WHERE category = 'books' RETURNING id, price",
    );
    assert!(update.returns_result);
    assert_eq!(update.kind, "update");
    let delete = parse("DELETE FROM sessions WHERE expires < NOW() RETURNING id");
    assert!(delete.returns_result);
    assert_eq!(delete.kind, "delete");

    // CREATE TABLE is not specially classified.
    let create =
        parse("CREATE TABLE logs (id SERIAL PRIMARY KEY, message TEXT, created_at TIMESTAMP)");
    assert!(!create.returns_result);
    assert!(!create.forbidden_in_txn);
    assert_eq!(create.kind, "");

    assert!(parse("SHOW TABLES").returns_result);
}