use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// SQL dialect that a [`SqlGuardConfig`] targets.
///
/// `#[serde(rename_all = "snake_case")]` makes the wire names `"generic"`, `"postgres"`,
/// `"my_sql"`, `"sqlite"`, `"ms_sql"`, `"snowflake"` and `"big_query"`. A config that
/// omits its `dialect` field falls back to the `#[default]` variant, `Generic`.
///
/// NOTE(review): the common spellings `"mysql"`, `"mssql"` and `"bigquery"` will NOT
/// deserialize; only the underscored names above do. Consider `#[serde(alias = "...")]`
/// if configs are hand-written.
/// NOTE(review): how the dialect affects parsing or checking is not visible here; confirm
/// at the call sites.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SqlDialect {
#[default]
Generic,
Postgres,
MySql,
Sqlite,
MsSql,
Snowflake,
BigQuery,
}
/// Coarse category of a SQL statement, as listed in a config's `operation_allowlist`.
///
/// Serialized in snake_case (`"select"`, `"insert"`, `"update"`, `"delete"`, `"ddl"`,
/// `"other"`). [`SqlOperation::as_str`] returns the upper-case label instead.
/// Derives `Eq` and `Hash`, so values can be used as set or map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SqlOperation {
Select,
Insert,
Update,
Delete,
/// Schema-changing statements, presumably CREATE/ALTER/DROP and similar. The
/// classifier that produces this variant is not in this file; confirm there.
Ddl,
/// Catch-all for statements not classified as any variant above (presumably; confirm
/// against the classifier).
Other,
}
impl SqlOperation {
pub fn as_str(self) -> &'static str {
match self {
Self::Select => "SELECT",
Self::Insert => "INSERT",
Self::Update => "UPDATE",
Self::Delete => "DELETE",
Self::Ddl => "DDL",
Self::Other => "OTHER",
}
}
}
/// Configuration for the SQL guard: which operations, tables and columns are allowed,
/// and which extra checks apply.
///
/// Every field is optional when deserializing. The list fields default to empty,
/// `column_allowlist` to `None`, `dialect` to `SqlDialect::Generic`, `allow_all` to
/// `false` and `require_where_for_mutations` to `true`.
///
/// Derives `PartialEq`/`Eq` so configs can be compared directly (e.g. to detect reloads
/// or in tests); every field type already supports equality.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlGuardConfig {
    /// SQL dialect of the guarded statements.
    #[serde(default)]
    pub dialect: SqlDialect,
    /// Statement categories that are permitted.
    /// NOTE(review): whether an empty list means "allow nothing" or "no restriction" is
    /// decided by the caller, not here; confirm.
    #[serde(default)]
    pub operation_allowlist: Vec<SqlOperation>,
    /// Permitted table names, compared ignoring ASCII case by `table_allowed`.
    #[serde(default)]
    pub table_allowlist: Vec<String>,
    /// Optional per-table column allowlist, keyed by table name (ASCII case-insensitive).
    /// A literal `"*"` entry allows every column of that table. `None` means no column
    /// restrictions are configured.
    #[serde(default)]
    pub column_allowlist: Option<HashMap<String, Vec<String>>>,
    /// Predicate strings to reject.
    /// NOTE(review): the expected format and matching rules are applied elsewhere; confirm.
    #[serde(default)]
    pub denylisted_predicates: Vec<String>,
    /// Defaults to `true` when absent (via `default_require_where_for_mutations`).
    /// Presumably requires UPDATE/DELETE statements to carry a WHERE clause; enforcement
    /// is not in this file.
    #[serde(default = "default_require_where_for_mutations")]
    pub require_where_for_mutations: bool,
    /// Defaults to `false`.
    /// NOTE(review): presumably bypasses all checks when set; note that `is_empty` does
    /// not consult it. Confirm at the call sites.
    #[serde(default)]
    pub allow_all: bool,
}
/// Default value for `SqlGuardConfig::require_where_for_mutations`: enabled.
///
/// This is a named function because `#[serde(default = "...")]` takes a function path.
/// `SqlGuardConfig::default` calls it as well, so both construction paths agree.
fn default_require_where_for_mutations() -> bool {
    const REQUIRE_WHERE_BY_DEFAULT: bool = true;
    REQUIRE_WHERE_BY_DEFAULT
}
impl Default for SqlGuardConfig {
    /// Builds a config with no restrictions listed.
    ///
    /// All lists are empty, there is no column allowlist, the dialect is its default and
    /// `allow_all` is off. `require_where_for_mutations` comes from the same function
    /// serde uses for a missing field, so the two paths cannot disagree.
    fn default() -> Self {
        let require_where_for_mutations = default_require_where_for_mutations();
        Self {
            dialect: Default::default(),
            operation_allowlist: vec![],
            table_allowlist: vec![],
            column_allowlist: None,
            denylisted_predicates: vec![],
            require_where_for_mutations,
            allow_all: false,
        }
    }
}
impl SqlGuardConfig {
    /// Returns `true` when no list-based constraint is configured.
    ///
    /// Only the list-valued fields are inspected: `operation_allowlist`,
    /// `table_allowlist`, `column_allowlist` (absent or an empty map both count as empty)
    /// and `denylisted_predicates`. `dialect`, `require_where_for_mutations` and
    /// `allow_all` are not considered.
    pub fn is_empty(&self) -> bool {
        let no_columns = match &self.column_allowlist {
            Some(map) => map.is_empty(),
            None => true,
        };
        self.operation_allowlist.is_empty()
            && self.table_allowlist.is_empty()
            && no_columns
            && self.denylisted_predicates.is_empty()
    }

    /// Returns `true` if `name` equals some `table_allowlist` entry, ignoring ASCII case.
    ///
    /// An empty `table_allowlist` matches nothing, so this returns `false`.
    pub fn table_allowed(&self, name: &str) -> bool {
        // `eq_ignore_ascii_case` gives the same result as comparing two
        // `to_ascii_lowercase()` copies, without allocating per entry.
        self.table_allowlist
            .iter()
            .any(|entry| entry.eq_ignore_ascii_case(name))
    }

    /// Checks `column` against the column allowlist configured for `table`.
    ///
    /// Returns:
    /// - `None` when `column_allowlist` is unset or has no key equal to `table`
    ///   (ignoring ASCII case);
    /// - `Some(true)` when a matching entry lists `column` (ignoring ASCII case) or
    ///   contains the literal wildcard `"*"`;
    /// - `Some(false)` otherwise.
    ///
    /// If several keys match `table` case-insensitively (e.g. `"Orders"` and `"orders"`),
    /// their column lists are combined. The result therefore does not depend on
    /// `HashMap` iteration order; previously only the first match encountered was used.
    pub fn column_allowed(&self, table: &str, column: &str) -> Option<bool> {
        let map = self.column_allowlist.as_ref()?;
        let mut table_matched = false;
        for (tbl, cols) in map {
            if !tbl.eq_ignore_ascii_case(table) {
                continue;
            }
            table_matched = true;
            if cols
                .iter()
                .any(|c| c == "*" || c.eq_ignore_ascii_case(column))
            {
                return Some(true);
            }
        }
        // A matched table with no allowing entry is an explicit denial; no match at all
        // means this table has no column restrictions configured.
        table_matched.then_some(false)
    }

    /// Returns `true` if `column_allowlist` is set and has a key equal to `table`,
    /// ignoring ASCII case.
    pub fn table_has_column_allowlist(&self, table: &str) -> bool {
        let Some(map) = self.column_allowlist.as_ref() else {
            return false;
        };
        map.keys().any(|k| k.eq_ignore_ascii_case(table))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config whose only non-default field is a single-table column allowlist.
    fn config_with_columns(table: &str, columns: &[&str]) -> SqlGuardConfig {
        let mut column_allowlist = HashMap::new();
        column_allowlist.insert(
            table.to_string(),
            columns.iter().map(|c| c.to_string()).collect::<Vec<String>>(),
        );
        SqlGuardConfig {
            column_allowlist: Some(column_allowlist),
            ..SqlGuardConfig::default()
        }
    }

    #[test]
    fn empty_config_is_detected() {
        assert!(SqlGuardConfig::default().is_empty());
    }

    #[test]
    fn table_allowlist_is_case_insensitive() {
        let cfg = SqlGuardConfig {
            table_allowlist: vec![String::from("Orders")],
            ..SqlGuardConfig::default()
        };
        for spelling in ["orders", "ORDERS"] {
            assert!(cfg.table_allowed(spelling), "{spelling} should be allowed");
        }
        assert!(!cfg.table_allowed("users"));
    }

    #[test]
    fn column_allowlist_returns_none_when_unset() {
        assert_eq!(SqlGuardConfig::default().column_allowed("orders", "id"), None);
    }

    #[test]
    fn column_allowlist_hit_and_miss() {
        let cfg = config_with_columns("orders", &["id", "total"]);
        assert_eq!(cfg.column_allowed("orders", "id"), Some(true));
        assert_eq!(cfg.column_allowed("ORDERS", "TOTAL"), Some(true));
        assert_eq!(cfg.column_allowed("orders", "email"), Some(false));
        assert_eq!(cfg.column_allowed("other_table", "id"), None);
    }

    #[test]
    fn wildcard_column_allows_everything_on_that_table() {
        let cfg = config_with_columns("orders", &["*"]);
        assert_eq!(cfg.column_allowed("orders", "anything"), Some(true));
    }

    #[test]
    fn table_has_column_allowlist_checks_keys() {
        let cfg = config_with_columns("Orders", &["id"]);
        assert!(cfg.table_has_column_allowlist("orders"));
        assert!(!cfg.table_has_column_allowlist("users"));
    }
}