use crate::error::{QuickDbError, QuickDbResult};
use crate::types::DatabaseType;
pub struct DatabaseSecurityValidator {
db_type: DatabaseType,
}
impl DatabaseSecurityValidator {
pub fn new(db_type: DatabaseType) -> Self {
Self { db_type }
}
pub fn validate_field_name(&self, field_name: &str) -> QuickDbResult<()> {
if field_name.is_empty() {
return Err(QuickDbError::ValidationError {
field: "field_name".to_string(),
message: "字段名不能为空".to_string(),
});
}
if field_name.len() > 64 {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: "字段名长度不能超过64个字符".to_string(),
});
}
match self.db_type {
DatabaseType::PostgreSQL | DatabaseType::MySQL | DatabaseType::SQLite => {
self.validate_sql_field_name(field_name)
}
DatabaseType::MongoDB => self.validate_nosql_field_name(field_name),
}
}
pub fn validate_table_name(&self, table_name: &str) -> QuickDbResult<()> {
if table_name.is_empty() {
return Err(QuickDbError::ValidationError {
field: "table_name".to_string(),
message: "表名不能为空".to_string(),
});
}
if table_name.len() > 64 {
return Err(QuickDbError::ValidationError {
field: table_name.to_string(),
message: "表名长度不能超过64个字符".to_string(),
});
}
match self.db_type {
DatabaseType::PostgreSQL | DatabaseType::MySQL | DatabaseType::SQLite => {
self.validate_sql_table_name(table_name)
}
DatabaseType::MongoDB => self.validate_nosql_collection_name(table_name),
}
}
pub fn get_safe_field_identifier(&self, field_name: &str) -> QuickDbResult<String> {
self.validate_field_name(field_name)?;
match self.db_type {
DatabaseType::PostgreSQL => Ok(format!("\"{}\"", field_name)),
DatabaseType::MySQL => Ok(format!("`{}`", field_name)),
DatabaseType::SQLite => Ok(format!("\"{}\"", field_name)),
DatabaseType::MongoDB => Ok(field_name.to_string()), }
}
pub fn get_safe_table_identifier(&self, table_name: &str) -> QuickDbResult<String> {
self.validate_table_name(table_name)?;
match self.db_type {
DatabaseType::PostgreSQL => Ok(format!("\"{}\"", table_name)),
DatabaseType::MySQL => Ok(format!("`{}`", table_name)),
DatabaseType::SQLite => Ok(format!("\"{}\"", table_name)),
DatabaseType::MongoDB => Ok(table_name.to_string()), }
}
fn validate_sql_field_name(&self, field_name: &str) -> QuickDbResult<()> {
if field_name.chars().next().unwrap().is_ascii_digit() {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: "SQL字段名不能以数字开头".to_string(),
});
}
for (i, ch) in field_name.chars().enumerate() {
if !ch.is_ascii_alphanumeric() && ch != '_' {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: format!("SQL字段名包含非法字符 '{}' 在位置 {}", ch, i),
});
}
}
let upper_name = field_name.to_uppercase();
let sql_keywords = [
"SELECT",
"FROM",
"WHERE",
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"DROP",
"ALTER",
"TABLE",
"INDEX",
"AND",
"OR",
"NOT",
"NULL",
"IS",
"IN",
"EXISTS",
"BETWEEN",
"LIKE",
"REGEXP",
"UNION",
"JOIN",
"INNER",
"LEFT",
"RIGHT",
"OUTER",
"GROUP",
"BY",
"HAVING",
"ORDER",
"LIMIT",
"OFFSET",
"DISTINCT",
"COUNT",
"SUM",
"AVG",
"MIN",
"MAX",
"AS",
"ON",
"PRIMARY",
"KEY",
"FOREIGN",
"REFERENCES",
"CASE",
"WHEN",
"THEN",
"ELSE",
"END",
"IF",
"COALESCE",
"CAST",
"CONVERT",
];
if sql_keywords.contains(&upper_name.as_str()) {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: format!("字段名不能使用SQL关键字: {}", field_name),
});
}
Ok(())
}
fn validate_nosql_field_name(&self, field_name: &str) -> QuickDbResult<()> {
if field_name.starts_with('$') {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: "NoSQL字段名不能以$开头".to_string(),
});
}
if field_name.contains('.') {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: "NoSQL字段名不能包含点号".to_string(),
});
}
let mongo_reserved_names = [
"_id",
"id",
"ns",
"system",
"op",
"query",
"update",
"fields",
"new",
"upsert",
"multi",
"writeConcern",
"collation",
"arrayFilters",
"hint",
];
if mongo_reserved_names.contains(&field_name) {
return Err(QuickDbError::ValidationError {
field: field_name.to_string(),
message: format!("字段名不能使用MongoDB保留字: {}", field_name),
});
}
Ok(())
}
fn validate_sql_table_name(&self, table_name: &str) -> QuickDbResult<()> {
if table_name.chars().next().unwrap().is_ascii_digit() {
return Err(QuickDbError::ValidationError {
field: table_name.to_string(),
message: "SQL表名不能以数字开头".to_string(),
});
}
for (i, ch) in table_name.chars().enumerate() {
if !ch.is_ascii_alphanumeric() && ch != '_' {
return Err(QuickDbError::ValidationError {
field: table_name.to_string(),
message: format!("SQL表名包含非法字符 '{}' 在位置 {}", ch, i),
});
}
}
let upper_name = table_name.to_uppercase();
let sql_keywords = [
"SELECT",
"FROM",
"WHERE",
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"DROP",
"ALTER",
"TABLE",
"INDEX",
"DATABASE",
"SCHEMA",
"USER",
"ROLE",
"GRANT",
"REVOKE",
"COMMIT",
"ROLLBACK",
"TRANSACTION",
"VIEW",
"TRIGGER",
"PROCEDURE",
"FUNCTION",
"SEQUENCE",
"CONSTRAINT",
"PRIMARY",
"FOREIGN",
"REFERENCES",
];
if sql_keywords.contains(&upper_name.as_str()) {
return Err(QuickDbError::ValidationError {
field: table_name.to_string(),
message: format!("表名不能使用SQL关键字: {}", table_name),
});
}
Ok(())
}
fn validate_nosql_collection_name(&self, collection_name: &str) -> QuickDbResult<()> {
if collection_name.starts_with('$') {
return Err(QuickDbError::ValidationError {
field: collection_name.to_string(),
message: "集合名不能以$开头".to_string(),
});
}
if collection_name.contains('\0') {
return Err(QuickDbError::ValidationError {
field: collection_name.to_string(),
message: "集合名不能包含空字符".to_string(),
});
}
if collection_name.starts_with("system.") {
return Err(QuickDbError::ValidationError {
field: collection_name.to_string(),
message: "集合名不能以system.开头".to_string(),
});
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `validate_field_name` accepts every entry of `names`.
    fn expect_fields_ok(validator: &DatabaseSecurityValidator, names: &[&str]) {
        for &name in names {
            assert!(
                validator.validate_field_name(name).is_ok(),
                "expected field name {:?} to be accepted",
                name
            );
        }
    }

    /// Fails the test unless `validate_field_name` rejects every entry of `names`.
    fn expect_fields_err(validator: &DatabaseSecurityValidator, names: &[&str]) {
        for &name in names {
            assert!(
                validator.validate_field_name(name).is_err(),
                "expected field name {:?} to be rejected",
                name
            );
        }
    }

    #[test]
    fn test_sql_field_validation() {
        let pg = DatabaseSecurityValidator::new(DatabaseType::PostgreSQL);
        expect_fields_ok(&pg, &["name", "user_name", "createdAt"]);
        // Empty, leading digit, illegal characters, and reserved words (any case).
        expect_fields_err(&pg, &["", "123name", "na-me", "na me", "select", "WHERE"]);
    }

    #[test]
    fn test_nosql_field_validation() {
        let mongo = DatabaseSecurityValidator::new(DatabaseType::MongoDB);
        // MongoDB accepts dashes and leading digits that SQL would reject.
        expect_fields_ok(&mongo, &["name", "user-name", "123name"]);
        expect_fields_err(&mongo, &["", "$name", "nested.field", "_id"]);
    }

    #[test]
    fn test_safe_identifier_generation() {
        let pg = DatabaseSecurityValidator::new(DatabaseType::PostgreSQL);
        let mysql = DatabaseSecurityValidator::new(DatabaseType::MySQL);

        let pg_quoted = pg.get_safe_field_identifier("name").unwrap();
        assert_eq!(pg_quoted, "\"name\"");
        let mysql_quoted = mysql.get_safe_field_identifier("name").unwrap();
        assert_eq!(mysql_quoted, "`name`");

        assert!(pg.get_safe_field_identifier("select").is_err());
        assert!(mysql.get_safe_field_identifier("123name").is_err());
    }
}