use std::collections::HashSet;
use std::sync::LazyLock;
use regex::Regex;
/// Whole-string pattern for characters allowed in an identifier.
///
/// The pattern is anchored (`^…+$`), so it accepts only non-empty strings made of:
/// - ASCII letters, ASCII digits and `_`,
/// - CJK ideographs in the range U+4E00..=U+9FA5,
/// - fullwidth `（` (U+FF08), `）` (U+FF09) and `，` (U+FF0C),
/// - ASCII `.` and ASCII `"`.
static RE_VALID_IDENTIFIER: LazyLock<Regex> = LazyLock::new(|| {
    // The pattern is a compile-time constant, so a failure here is a programming error.
    Regex::new("^[a-zA-Z0-9_\u{4e00}-\u{9fa5}\u{ff08}\u{ff09}\u{ff0c}.\"]+$")
        .expect("RE_VALID_IDENTIFIER pattern is a constant, valid regex")
});
static UNSAFE_KEYWORDS_SET: LazyLock<HashSet<&'static str>> =
LazyLock::new(|| UNSAFE_FIELD_KEYWORDS.iter().cloned().collect());
pub const UNSAFE_FIELD_KEYWORDS: &[&str] = &[
"SELECT",
"FROM",
"WHERE",
"GROUP",
"ORDER",
"BY",
"HAVING",
"INSERT",
"INTO",
"VALUES",
"UPDATE",
"SET",
"DELETE",
"JOIN",
"LEFT",
"RIGHT",
"INNER",
"OUTER",
"ON",
"AS",
"LIMIT",
"DISTINCT",
"UNION",
"INT",
"INTEGER",
"BIGINT",
"SMALLINT",
"TINYINT",
"MEDIUMINT",
"DECIMAL",
"NUMERIC",
"FLOAT",
"DOUBLE",
"CHAR",
"VARCHAR",
"TEXT",
"TINYTEXT",
"MEDIUMTEXT",
"LONGTEXT",
"BLOB",
"LONGBLOB",
"MEDIUMBLOB",
"VARBINARY",
"BIT",
"REAL",
"DATE",
"TIME",
"DATETIME",
"TIMESTAMP",
"YEAR",
"IF",
"ELSE",
"THEN",
"CASE",
"WHEN",
"WHILE",
"LOOP",
"REPEAT",
"RETURN",
"BEGIN",
"END",
"DECLARE",
"CURSOR",
"OPEN",
"CLOSE",
"FETCH",
"EXIT",
"ITERATE",
"LEAVE",
"AND",
"OR",
"NOT",
"IN",
"IS",
"LIKE",
"BETWEEN",
"NULL",
"TRUE",
"FALSE",
"EXISTS",
"ANY",
"ALL",
"KEY",
"PRIMARY",
"FOREIGN",
"INDEX",
"UNIQUE",
"CHECK",
"REFERENCES",
"TABLE",
"DATABASE",
"TRIGGER",
"PROCEDURE",
"FUNCTION",
"GRANT",
"REVOKE",
"LOCK",
"UNLOCK",
"DESCRIBE",
"USE",
"CALL",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"RENAME",
"FLUSH",
"WITH",
"DEFAULT",
"AUTO_INCREMENT",
];
/// Escapes `s` so it can be placed between single quotes in an SQL string literal.
///
/// - `\` becomes `\\`
/// - `'` becomes `''`
/// - newline, carriage return and tab become the two-character sequences `\n`, `\r` and `\t`
///
/// All other characters are copied unchanged, and the surrounding quotes are NOT added.
/// NOTE(review): the backslash sequences assume a SQL dialect that interprets backslash
/// escapes inside string literals — confirm against the target database mode.
pub fn escape_sql_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '\'' => escaped.push_str("''"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Validates `name` for use as an identifier. The error messages suggest table or field names.
///
/// Backticks at either end of `name` are stripped before any check runs.
/// `ident_name` is a human-readable label (e.g. "字段名") that prefixes every error message.
///
/// # Errors
/// Returns an error when the stripped name:
/// - is empty;
/// - is longer than `MAX_IDENTIFIER_CHARS` (100) characters, counted as `char`s, not bytes;
/// - starts with an ASCII digit;
/// - contains a character not accepted by [`RE_VALID_IDENTIFIER`];
/// - consists only of underscores, or contains `__`;
/// - equals, after ASCII upper-casing, an entry of [`UNSAFE_FIELD_KEYWORDS`].
pub fn is_valid_identifier(name: &str, ident_name: &str) -> anyhow::Result<()> {
    // Single source of truth for the limit so the check and the error text cannot drift apart
    // (the message previously advertised 32 while 100 was enforced).
    const MAX_IDENTIFIER_CHARS: usize = 100;

    let name = name.trim_matches('`');
    let Some(first) = name.chars().next() else {
        return Err(anyhow::anyhow!("{ident_name}不能为空"));
    };
    if name.chars().count() > MAX_IDENTIFIER_CHARS {
        return Err(anyhow::anyhow!(
            "{ident_name}长度不能超过{MAX_IDENTIFIER_CHARS}个字符"
        ));
    }
    if first.is_ascii_digit() {
        return Err(anyhow::anyhow!("{ident_name}不能以数字开头"));
    }
    if !RE_VALID_IDENTIFIER.is_match(name) {
        return Err(anyhow::anyhow!("{ident_name}只能包含字母、中文、数字或下划线或全角(),: {name}"));
    }
    if name.chars().all(|c| c == '_') {
        return Err(anyhow::anyhow!("{ident_name}不能全部由下划线组成"));
    }
    if name.contains("__") {
        return Err(anyhow::anyhow!("{ident_name}不能包含连续两个及以上下划线"));
    }
    // Keywords are stored upper-case, so compare against the ASCII-uppercased name.
    let upper_name = name.to_ascii_uppercase();
    if UNSAFE_KEYWORDS_SET.contains(upper_name.as_str()) {
        return Err(anyhow::anyhow!("{ident_name} 不能使用数据库的关键字 {}", upper_name));
    }
    Ok(())
}
/// Runs [`is_valid_identifier`] on each entry of `identifiers` in order.
///
/// # Errors
/// Stops at and returns the first error produced by [`is_valid_identifier`]. `ident_name` is
/// forwarded unchanged. An empty slice is always `Ok(())`.
pub fn is_valid_identifiers(identifiers: &[String], ident_name: &str) -> anyhow::Result<()> {
    identifiers
        .iter()
        .try_for_each(|identifier| is_valid_identifier(identifier, ident_name))
}
/// Checks that `op` is one of the supported operator tokens.
///
/// The comparison is exact and case-sensitive, with no trimming.
///
/// # Errors
/// Returns an error naming `op` when it is not in the supported list.
pub fn is_valid_op(op: &str) -> anyhow::Result<()> {
    const SUPPORTED_OPS: &[&str] = &[
        "=", "!=", ">", "<", ">=", "<=",
        "LIKE", "LIKE_ANY",
        "IN", "NOT_IN",
        "BETWEEN", "NOT_BETWEEN",
        "IN_QUERY", "NOT_IN_QUERY",
        "NOT_NULL", "IS_NULL",
        "REGEXP",
    ];

    if SUPPORTED_OPS.contains(&op) {
        Ok(())
    } else {
        Err(anyhow::anyhow!("无效的运算符 {op}"))
    }
}