use once_cell::sync::Lazy;
use regex::Regex;
use sqlparser::ast::{Delete, FromTable, Query, SetExpr, Statement, TableWithJoins};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use thiserror::Error;
#[cfg(feature = "permission")]
pub use crate::permission::PermissionAction;
/// Fallback permission action, compiled only when the `permission` feature is disabled.
///
/// When the feature is enabled, `crate::permission::PermissionAction` is re-exported
/// under the same name instead of this definition.
#[cfg(not(feature = "permission"))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PermissionAction {
/// Reported for `SELECT` statements; `SqlParser::parse_operation` also uses it as the
/// fallback for every statement that is not an insert, update or delete.
Select,
/// Reported for `INSERT` statements.
Insert,
/// Reported for `UPDATE` statements.
Update,
/// Reported for `DELETE` statements.
Delete,
}
/// Errors produced while validating and parsing a SQL string.
#[derive(Debug, Error)]
pub enum SqlParseError {
/// The underlying SQL parser rejected the input; carries the parser's error text.
#[error("Failed to parse SQL: {0}")]
ParseError(String),
/// A statement kind that is not supported. Not constructed by the code visible in this file.
#[error("Unsupported SQL statement type: {0}")]
UnsupportedStatement(String),
/// The input was empty after trimming whitespace.
#[error("Empty SQL statement")]
EmptyStatement,
/// The input appears to contain (or parsed into) more than one statement.
#[error("Multiple statements not allowed")]
MultipleStatements,
/// The input contains variable/placeholder-like tokens outside of quoted text.
#[error("SQL statement contains variables: {0}")]
ContainsVariables(String),
}
/// Result of classifying a single SQL statement.
#[derive(Debug, Clone)]
pub struct ParsedSqlOperation {
/// Category the statement was classified into.
pub operation_type: SqlOperationType,
/// Name of the (first) table the statement targets, or `None` when it could not be determined.
pub table_name: Option<String>,
/// The SQL text that was parsed, trimmed of surrounding whitespace.
pub sql: String,
}
/// Coarse category assigned to a parsed SQL statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlOperationType {
/// A query (`SELECT ...`).
Select,
/// An `INSERT` statement.
Insert,
/// An `UPDATE` statement.
Update,
/// A `DELETE` statement.
Delete,
/// Schema changes (create/alter/drop/truncate/create index) and `SET` of DDL-related variables.
Ddl,
/// Access control statements (`GRANT`, `REVOKE`).
Dcl,
/// Transaction control (start, commit, rollback).
Transaction,
/// Any statement that does not fall into one of the categories above.
Other,
}
/// Validates SQL text and classifies it into a [`SqlOperationType`].
pub struct SqlParser {
// Dialect handed to the underlying SQL parser on every parse call.
dialect: GenericDialect,
}
impl Default for SqlParser {
fn default() -> Self {
Self::new()
}
}
impl SqlParser {
#[inline]
pub fn new() -> Self {
Self {
dialect: GenericDialect {},
}
}
#[inline]
pub fn with_dialect(_db_type: &str) -> Self {
Self::new()
}
pub fn parse_single(&self, sql: &str) -> Result<ParsedSqlOperation, SqlParseError> {
let sql = sql.trim();
if sql.is_empty() {
return Err(SqlParseError::EmptyStatement);
}
if sql.contains(';') {
if sql.starts_with("SET ") {
let parts: Vec<&str> = sql.split(';').collect();
for part in parts.iter() {
let trimmed = part.trim();
if trimmed.is_empty() {
continue;
}
if trimmed.starts_with("--") {
continue;
}
if !trimmed.starts_with("SET ") && !trimmed.starts_with("SET") {
return Err(SqlParseError::MultipleStatements);
}
}
} else {
let last_non_comment = self.get_last_non_comment(sql);
if last_non_comment.ends_with(';') {
let sql_without_semicolon = sql[..sql.len() - 1].trim();
if sql_without_semicolon.contains(';') {
return Err(SqlParseError::MultipleStatements);
}
} else {
return Err(SqlParseError::MultipleStatements);
}
}
}
let is_set_statement = sql.trim().starts_with("SET ") || sql.trim().starts_with("SET");
if contains_variables(sql, is_set_statement) {
return Err(SqlParseError::ContainsVariables(
"SQL contains potentially dangerous variables. Use parameterized queries instead.".to_string(),
));
}
let statements = Parser::parse_sql(&self.dialect, sql).map_err(|e| SqlParseError::ParseError(e.to_string()))?;
if !is_set_statement && statements.len() != 1 {
return Err(SqlParseError::MultipleStatements);
}
let statement = statements
.into_iter()
.next()
.ok_or_else(|| SqlParseError::ParseError("No statement found".to_string()))?;
self.classify_statement(statement, sql.to_string())
}
fn get_last_non_comment(&self, sql: &str) -> String {
let lines: Vec<&str> = sql.lines().collect();
let mut result = String::new();
for line in lines.iter().rev() {
let trimmed = line.trim();
if trimmed.is_empty() || trimmed.starts_with("--") {
continue;
}
result = trimmed.to_string();
break;
}
result
}
pub fn parse_operation(&self, sql: &str) -> Option<(String, PermissionAction)> {
self.parse_single(sql).ok().and_then(|parsed| {
let action = match parsed.operation_type {
SqlOperationType::Select => PermissionAction::Select,
SqlOperationType::Insert => PermissionAction::Insert,
SqlOperationType::Update => PermissionAction::Update,
SqlOperationType::Delete => PermissionAction::Delete,
SqlOperationType::Ddl => PermissionAction::Select, SqlOperationType::Dcl => PermissionAction::Select, SqlOperationType::Transaction => PermissionAction::Select, SqlOperationType::Other => PermissionAction::Select, };
parsed.table_name.map(|table_name| (table_name, action))
})
}
fn classify_statement(&self, statement: Statement, sql: String) -> Result<ParsedSqlOperation, SqlParseError> {
let (operation_type, table_name) = match statement {
Statement::Query(query) => (SqlOperationType::Select, extract_table_from_query(&query)),
Statement::Insert(insert) => (SqlOperationType::Insert, Some(insert.table_name.to_string())),
Statement::Update { table, .. } => (
SqlOperationType::Update,
extract_table_name_from_table_with_joins(&table),
),
Statement::Delete(delete) => {
let table_name = extract_table_from_delete(&delete);
(SqlOperationType::Delete, table_name)
}
Statement::CreateTable { name, .. } => (SqlOperationType::Ddl, Some(name.to_string())),
Statement::AlterTable { name, .. } => (SqlOperationType::Ddl, Some(name.to_string())),
Statement::Drop { names, object_type, .. } => {
let is_table = format!("{:?}", object_type).contains("Table");
let table_name = if is_table && !names.is_empty() {
Some(names[0].to_string())
} else {
None
};
(SqlOperationType::Ddl, table_name)
}
Statement::Truncate { table_name, .. } => (SqlOperationType::Ddl, Some(table_name.to_string())),
Statement::CreateIndex { table_name, .. } => (SqlOperationType::Ddl, Some(table_name.to_string())),
Statement::Grant { .. } => (SqlOperationType::Dcl, None),
Statement::Revoke { .. } => (SqlOperationType::Dcl, None),
Statement::StartTransaction { .. } | Statement::Commit { .. } | Statement::Rollback { .. } => {
(SqlOperationType::Transaction, None)
}
Statement::SetVariable {
local: _,
hivevar: _,
variables,
..
} => {
let var_name = variables.to_string().to_lowercase();
if is_ddl_related_variable(&var_name) {
(SqlOperationType::Ddl, None)
} else {
(SqlOperationType::Other, None)
}
}
_ => (SqlOperationType::Other, None),
};
Ok(ParsedSqlOperation {
operation_type,
table_name,
sql,
})
}
}
/// Reports whether `sql` contains variable- or placeholder-like tokens outside of quoted text.
///
/// Quoted text is blanked with `remove_string_literals` before matching. The tokens
/// looked for are `@name`, `:name`, `$name` / `${name}`, `%name%` and `0x…` hex literals.
/// When `is_set_statement` is `true`, `@name` matches are tolerated.
///
/// Note that the `:name` pattern also matches the tail of a `::type` cast.
fn contains_variables(sql: &str, is_set_statement: bool) -> bool {
    static PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
        [
            r"@[\w]+",
            r":[a-zA-Z_][\w]*",
            r"\$\{?[\w]+\}?",
            r"%[\w]+%",
            r"0x[0-9A-Fa-f]+",
        ]
        .iter()
        .map(|source| Regex::new(source).expect("Regex pattern should be valid"))
        .collect()
    });

    let scrubbed = remove_string_literals(sql);
    PATTERNS.iter().any(|pattern| {
        // `SET @var = ...` legitimately names a user variable, so the `@` pattern is exempt there.
        let exempt = is_set_statement && pattern.as_str().starts_with('@');
        !exempt && pattern.is_match(&scrubbed)
    })
}
/// Returns a copy of `sql` in which every quoted section is replaced by spaces.
///
/// Text enclosed in `'`, `"` or `` ` `` (delimiters included) is blanked one space per
/// character, so the result has the same number of characters as the input. A backslash
/// escapes the following character, both inside and outside quotes. An unterminated
/// quote blanks everything up to the end of the input.
///
/// `--` line comments and `/* ... */` block comments (nesting supported) are copied
/// verbatim, and quote characters inside them do not open a quoted section. Without
/// that, an apostrophe in a comment would blank the real SQL that follows it.
fn remove_string_literals(sql: &str) -> String {
    #[derive(Clone, Copy)]
    enum State {
        /// Plain SQL text.
        Code,
        /// Character right after a backslash in plain SQL text.
        CodeEscape,
        /// Inside a quoted section opened by the given delimiter.
        Quoted(char),
        /// Character right after a backslash inside a quoted section.
        QuotedEscape(char),
        /// Inside a `--` comment, up to and including the newline.
        LineComment,
        /// Inside a block comment at the given nesting depth.
        BlockComment(usize),
    }

    // Every blanked character becomes a one-byte space, so the output never outgrows the input.
    let mut result = String::with_capacity(sql.len());
    let mut state = State::Code;
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        let next = chars.peek().copied();
        state = match state {
            State::Code => match (ch, next) {
                ('\\', _) => {
                    result.push(ch);
                    State::CodeEscape
                }
                ('\'' | '"' | '`', _) => {
                    result.push(' ');
                    State::Quoted(ch)
                }
                ('-', Some('-')) => {
                    result.push_str("--");
                    chars.next();
                    State::LineComment
                }
                ('/', Some('*')) => {
                    result.push_str("/*");
                    chars.next();
                    State::BlockComment(1)
                }
                _ => {
                    result.push(ch);
                    State::Code
                }
            },
            State::CodeEscape => {
                result.push(ch);
                State::Code
            }
            State::Quoted(delimiter) => {
                result.push(' ');
                if ch == '\\' {
                    State::QuotedEscape(delimiter)
                } else if ch == delimiter {
                    State::Code
                } else {
                    State::Quoted(delimiter)
                }
            }
            State::QuotedEscape(delimiter) => {
                result.push(' ');
                State::Quoted(delimiter)
            }
            State::LineComment => {
                result.push(ch);
                if ch == '\n' {
                    State::Code
                } else {
                    State::LineComment
                }
            }
            State::BlockComment(depth) => match (ch, next) {
                ('*', Some('/')) => {
                    result.push_str("*/");
                    chars.next();
                    if depth <= 1 {
                        State::Code
                    } else {
                        State::BlockComment(depth - 1)
                    }
                }
                ('/', Some('*')) => {
                    result.push_str("/*");
                    chars.next();
                    State::BlockComment(depth.saturating_add(1))
                }
                _ => {
                    result.push(ch);
                    State::BlockComment(depth)
                }
            },
        };
    }
    result
}
/// Returns the name of the primary relation of `table_with_joins` when it is a plain
/// table reference; any other kind of relation yields `None`. Joined tables are ignored.
fn extract_table_name_from_table_with_joins(table_with_joins: &TableWithJoins) -> Option<String> {
    match &table_with_joins.relation {
        sqlparser::ast::TableFactor::Table { name, .. } => Some(name.to_string()),
        _ => None,
    }
}
fn extract_table_from_delete(delete: &Delete) -> Option<String> {
if !delete.tables.is_empty() {
return Some(delete.tables[0].to_string());
}
match &delete.from {
FromTable::WithFromKeyword(tables) => {
if !tables.is_empty() {
extract_table_name_from_table_with_joins(&tables[0])
} else {
None
}
}
FromTable::WithoutKeyword(tables) => {
if !tables.is_empty() {
extract_table_name_from_table_with_joins(&tables[0])
} else {
None
}
}
}
}
fn extract_table_from_query(query: &Query) -> Option<String> {
let SetExpr::Select(select) = query.body.as_ref() else {
return None;
};
if select.from.is_empty() {
return None;
}
extract_table_name_from_table_with_joins(&select.from[0])
}
/// Reports whether `var_name` contains one of the variable-name fragments that are
/// treated as schema-affecting.
///
/// The comparison is a case-sensitive substring match against lowercase fragments.
fn is_ddl_related_variable(var_name: &str) -> bool {
    const DDL_VARIABLE_FRAGMENTS: [&str; 5] = [
        "foreign_keys",
        "auto_increment_increment",
        "sql_mode",
        "character_set",
        "collation",
    ];

    for fragment in DDL_VARIABLE_FRAGMENTS.iter() {
        if var_name.contains(*fragment) {
            return true;
        }
    }
    false
}
/// Reports whether `sql` parses as a single DDL or DCL statement.
///
/// Input that fails validation or parsing yields `false`.
pub fn is_ddl_operation(sql: &str) -> bool {
    match SqlParser::new().parse_single(sql) {
        Ok(parsed) => {
            parsed.operation_type == SqlOperationType::Ddl || parsed.operation_type == SqlOperationType::Dcl
        }
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `sql` through a freshly constructed parser.
    fn parse(sql: &str) -> Result<ParsedSqlOperation, SqlParseError> {
        SqlParser::new().parse_single(sql)
    }

    /// Asserts that `sql` parses successfully into the given operation type and table name.
    fn assert_operation(sql: &str, operation_type: SqlOperationType, table_name: &str) {
        let parsed = parse(sql).expect("statement should parse");
        assert_eq!(parsed.operation_type, operation_type);
        assert_eq!(parsed.table_name, Some(table_name.to_string()));
    }

    #[test]
    fn test_parse_select() {
        assert_operation("SELECT * FROM users WHERE id = 1", SqlOperationType::Select, "users");
    }

    #[test]
    fn test_parse_insert() {
        assert_operation(
            "INSERT INTO users (name) VALUES ('test')",
            SqlOperationType::Insert,
            "users",
        );
    }

    #[test]
    fn test_parse_update() {
        assert_operation(
            "UPDATE users SET name = 'test' WHERE id = 1",
            SqlOperationType::Update,
            "users",
        );
    }

    #[test]
    fn test_parse_delete() {
        assert_operation("DELETE FROM users WHERE id = 1", SqlOperationType::Delete, "users");
    }

    #[test]
    fn test_parse_create_table() {
        assert_operation(
            "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255))",
            SqlOperationType::Ddl,
            "users",
        );
    }

    #[test]
    fn test_parse_drop_table() {
        assert_operation("DROP TABLE users", SqlOperationType::Ddl, "users");
    }

    #[test]
    fn test_parse_grant() {
        // Only the operation type is checked for GRANT; no table name is asserted.
        let parsed = parse("GRANT ALL PRIVILEGES ON users TO user1").expect("statement should parse");
        assert_eq!(parsed.operation_type, SqlOperationType::Dcl);
    }

    #[test]
    fn test_multiple_statements_rejected() {
        let error = parse("SELECT * FROM users; SELECT * FROM posts").expect_err("two statements must be rejected");
        assert!(matches!(error, SqlParseError::MultipleStatements));
    }

    #[test]
    fn test_empty_statement_rejected() {
        let error = parse("").expect_err("empty input must be rejected");
        assert!(matches!(error, SqlParseError::EmptyStatement));
    }

    #[test]
    fn test_variables_detected() {
        let error = parse("SELECT * FROM users WHERE id = @userId").expect_err("variables must be rejected");
        assert!(matches!(error, SqlParseError::ContainsVariables(..)));
    }

    #[test]
    fn test_is_ddl_operation() {
        let ddl_statements = [
            "CREATE TABLE users (id INT)",
            "DROP TABLE users",
            "ALTER TABLE users ADD COLUMN name VARCHAR(255)",
        ];
        for sql in ddl_statements.iter() {
            assert!(is_ddl_operation(sql));
        }

        let non_ddl_statements = ["SELECT * FROM users", "INSERT INTO users (name) VALUES ('test')"];
        for sql in non_ddl_statements.iter() {
            assert!(!is_ddl_operation(sql));
        }
    }
}