use sqlparser::dialect::{GenericDialect, MySqlDialect, PostgreSqlDialect};
use sqlparser::parser::Parser;
/// SQL dialect used when parsing text in this module.
///
/// Each variant is mapped onto the corresponding `sqlparser` dialect
/// (`GenericDialect`, `MySqlDialect`, `PostgreSqlDialect`) by the validation functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseDialect {
    /// Dialect-agnostic parsing.
    Generic,
    /// MySQL-flavoured parsing.
    MySql,
    /// PostgreSQL-flavoured parsing.
    PostgreSql,
}
#[must_use]
pub fn is_valid_expression_injection(fragment: &str, dialect: DatabaseDialect) -> bool {
let query = format!("SELECT * FROM mock_table WHERE id = {fragment}");
let result = match dialect {
DatabaseDialect::Generic => Parser::parse_sql(&GenericDialect {}, &query),
DatabaseDialect::MySql => Parser::parse_sql(&MySqlDialect {}, &query),
DatabaseDialect::PostgreSql => Parser::parse_sql(&PostgreSqlDialect {}, &query),
};
result.is_ok()
}
/// Reports whether `query` parses as one or more SQL statements under `dialect`.
///
/// Returns `false` when the parser rejects the input. It also returns `false` when the
/// input contains no statement at all, such as empty, whitespace-only, comment-only or
/// `;`-only text. `Parser::parse_sql` returns `Ok` with an empty statement list for such
/// input, and that must not be mistaken for a valid query.
#[must_use]
pub fn is_valid_query(query: &str, dialect: DatabaseDialect) -> bool {
    let result = match dialect {
        DatabaseDialect::Generic => Parser::parse_sql(&GenericDialect {}, query),
        DatabaseDialect::MySql => Parser::parse_sql(&MySqlDialect {}, query),
        DatabaseDialect::PostgreSql => Parser::parse_sql(&PostgreSqlDialect {}, query),
    };
    // An empty statement list means there was nothing to run, not a valid query.
    matches!(result, Ok(statements) if !statements.is_empty())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tautology-style fragments parse in each dialect, including comment-obfuscated ones.
    #[test]
    fn valid_sql_fragment_parses() {
        assert!(is_valid_expression_injection(
            "1 OR 1=1 --",
            DatabaseDialect::Generic
        ));
        assert!(is_valid_expression_injection(
            "1/**/OR/**/1=1",
            DatabaseDialect::MySql
        ));
        assert!(is_valid_expression_injection(
            "1 OR 1=1",
            DatabaseDialect::PostgreSql
        ));
    }

    /// Fragments that leave the spliced query syntactically broken are rejected.
    #[test]
    fn invalid_sql_fragment_fails() {
        assert!(!is_valid_expression_injection(
            "1 OR 1=/**/",
            DatabaseDialect::Generic
        ));
        assert!(!is_valid_expression_injection(
            "1 O R 1=1",
            DatabaseDialect::Generic
        ));
        // An empty fragment leaves `WHERE id = ` without a right-hand operand.
        assert!(!is_valid_expression_injection(
            "",
            DatabaseDialect::Generic
        ));
    }

    /// Fragments that append a second `;`-separated statement still parse.
    #[test]
    fn stacked_statement_fragment_parses() {
        assert!(is_valid_expression_injection(
            "1; DROP TABLE users",
            DatabaseDialect::Generic
        ));
    }

    /// A well-formed standalone query is accepted by every supported dialect.
    #[test]
    fn valid_query_parses_in_every_dialect() {
        for dialect in [
            DatabaseDialect::Generic,
            DatabaseDialect::MySql,
            DatabaseDialect::PostgreSql,
        ] {
            assert!(is_valid_query("SELECT id FROM users WHERE id = 1", dialect));
        }
    }

    /// Misspelled keywords and truncated statements are rejected.
    #[test]
    fn malformed_query_fails() {
        assert!(!is_valid_query("SELEC id FROM users", DatabaseDialect::Generic));
        assert!(!is_valid_query("SELECT id FROM", DatabaseDialect::Generic));
    }
}