use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::protocol::{contains_ci, starts_with_ci};
/// A mandatory-filter rule: queries that touch `table` must mention `column`
/// after their `WHERE` keyword, otherwise `validate` rejects them with a
/// `missing_predicate` violation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateRule {
/// Table the rule applies to; compared ASCII-case-insensitively against the
/// table names extracted from the query.
pub table: String,
/// Column that has to appear in the query text following `WHERE`.
pub column: String,
}
/// The set of restrictions `validate` enforces on SQL submitted by one agent.
///
/// Every field except `id` has a serde default, so a serialized contract may
/// omit any of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContract {
/// Identifier of the agent; interpolated into violation messages.
pub id: String,
/// When `true`, statements whose leading verb is a write verb are rejected.
/// Defaults to `true` when the field is absent from the serialized form.
#[serde(default = "default_true")]
pub read_only: bool,
/// Optional verb allow-list. `None` places no restriction on the verb;
/// `Some(list)` requires the statement's leading verb to match an entry
/// (ASCII-case-insensitive).
#[serde(default)]
pub allowed_verbs: Option<Vec<String>>,
/// Optional table allow-list. `None` places no restriction; `Some(list)`
/// requires every extracted table name to match an entry
/// (ASCII-case-insensitive).
#[serde(default)]
pub allowed_tables: Option<Vec<String>>,
/// Tables the agent may never reference. Checked before `allowed_tables`.
#[serde(default)]
pub denied_tables: Vec<String>,
/// Per-table filters that must be present; see [`PredicateRule`].
#[serde(default)]
pub require_predicate_on: Vec<PredicateRule>,
/// When `true`, `SELECT` statements without a recognisable `LIMIT` are rejected.
#[serde(default)]
pub require_limit: bool,
/// Row cap quoted in the `missing_limit` message and used as the `LIMIT` in
/// its suggested rewrite; `validate` falls back to 1000 when this is `None`.
/// NOTE(review): within this file an explicit `LIMIT` larger than this value
/// is not rejected — confirm whether the cap is enforced elsewhere.
#[serde(default)]
pub max_rows: Option<u64>,
}
/// Serde default for `AgentContract::read_only`: a contract that does not
/// mention the field is treated as read-only.
fn default_true() -> bool {
true
}
/// Describes why a statement was rejected by `validate`.
#[derive(Debug, Clone, Serialize)]
pub struct Violation {
/// Machine-readable violation kind. `validate` in this file emits
/// `write_forbidden`, `verb_forbidden`, `table_forbidden`,
/// `table_not_allowed`, `missing_predicate` and `missing_limit`.
pub violation: String,
/// Human-readable explanation of the rejection.
pub detail: String,
/// The SQL text exactly as it was passed to `validate`.
pub offending: String,
/// A rewritten statement that would satisfy the violated rule, when one can
/// be proposed. Left out of the serialized form entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub suggested_rewrite: Option<String>,
}
impl Violation {
    /// Renders the violation as a JSON object string.
    ///
    /// If serialization fails, the plain `detail` text is returned instead;
    /// note that this fallback is not itself JSON.
    pub fn to_json(&self) -> String {
        match serde_json::to_string(self) {
            Ok(json) => json,
            Err(_) => self.detail.clone(),
        }
    }
}
/// Lazily compiled, case-insensitive pattern that captures the identifier
/// directly following a `FROM`, `JOIN`, `INTO` or `UPDATE` keyword.
///
/// Capture group 1 is an unquoted identifier with an optional single
/// `schema.` qualifier. Only the one identifier immediately after the keyword
/// is captured, and quoted identifiers do not match the character classes used
/// here.
static TABLE_RE: Lazy<Regex> = Lazy::new(|| {
Regex::new(r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)")
// The pattern is a compile-time constant, so a failure here is a programming error.
.expect("valid table regex")
});
/// Extracts the leading SQL keyword of `sql`, upper-cased.
///
/// Leading whitespace, opening parentheses, `-- line` comments and
/// `/* block */` comments (which may nest, as in PostgreSQL) are skipped before
/// the keyword is read. The keyword is the longest run of ASCII letters and
/// underscores at that point, so `delete/**/from t`, `SELECT(1)` and
/// `/* note */ DROP TABLE t` yield `DELETE`, `SELECT` and `DROP` respectively.
///
/// Returns an empty string when no keyword is found (empty input, input that
/// is only comments, or input starting with a non-letter).
fn verb_of(sql: &str) -> String {
    // Returns the text after the `*/` that closes a block comment whose opening
    // `/*` has already been consumed, honouring nested comments. An
    // unterminated comment swallows the rest of the input.
    fn after_block_comment(body: &str) -> &str {
        let bytes = body.as_bytes();
        let mut depth = 1usize;
        let mut i = 0;
        while i + 1 < bytes.len() {
            match (bytes[i], bytes[i + 1]) {
                (b'/', b'*') => {
                    depth += 1;
                    i += 2;
                }
                (b'*', b'/') => {
                    depth -= 1;
                    i += 2;
                    if depth == 0 {
                        // `i` sits right after an ASCII '/', so it is a char boundary.
                        return &body[i..];
                    }
                }
                _ => i += 1,
            }
        }
        ""
    }

    let mut rest = sql;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        if rest.starts_with("--") {
            // A line comment runs to the next newline, or to the end of input.
            rest = rest.find('\n').map_or("", |nl| &rest[nl + 1..]);
        } else if rest.starts_with("/*") {
            rest = after_block_comment(&rest[2..]);
        } else {
            break;
        }
    }

    // Stop at the first character that cannot be part of a keyword so that
    // comments or punctuation glued to the verb do not disguise it.
    let end = rest
        .find(|c: char| !(c.is_ascii_alphabetic() || c == '_'))
        .unwrap_or(rest.len());
    rest[..end].to_ascii_uppercase()
}
/// Lists the table names referenced by `sql`, in order of appearance.
///
/// Names come from capture group 1 of `TABLE_RE`; each is lower-cased and
/// stripped of its schema qualifier, so `public.Users` becomes `users`.
/// Duplicates are kept.
fn tables_of(sql: &str) -> Vec<String> {
    let mut found = Vec::new();
    for caps in TABLE_RE.captures_iter(sql) {
        if let Some(ident) = caps.get(1) {
            let lowered = ident.as_str().to_ascii_lowercase();
            // Keep only the segment after the last '.', i.e. drop any schema prefix.
            let bare = match lowered.rfind('.') {
                Some(dot) => &lowered[dot + 1..],
                None => lowered.as_str(),
            };
            found.push(bare.to_string());
        }
    }
    found
}
/// Reports whether `verb` is one of the statement verbs treated as a write.
///
/// The comparison is exact, so callers must pass the verb already upper-cased.
fn is_write_verb(verb: &str) -> bool {
    const WRITE_VERBS: &[&str] = &[
        "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
        "MERGE", "CALL", "DO", "COPY", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT",
    ];
    WRITE_VERBS.iter().any(|candidate| *candidate == verb)
}
pub fn validate(sql: &str, contract: &AgentContract) -> Result<(), Violation> {
let trimmed = sql.trim();
let verb = verb_of(trimmed);
if contract.read_only && is_write_verb(&verb) {
return Err(Violation {
violation: "write_forbidden".into(),
detail: format!("agent '{}' is read-only; '{}' statements are not permitted", contract.id, verb),
offending: sql.to_string(),
suggested_rewrite: None,
});
}
if let Some(ref verbs) = contract.allowed_verbs {
if !verbs.iter().any(|v| v.eq_ignore_ascii_case(&verb)) {
return Err(Violation {
violation: "verb_forbidden".into(),
detail: format!("verb '{}' not in this agent's allowed set {:?}", verb, verbs),
offending: sql.to_string(),
suggested_rewrite: None,
});
}
}
let tables = tables_of(trimmed);
for t in &tables {
if contract.denied_tables.iter().any(|d| d.eq_ignore_ascii_case(t)) {
return Err(Violation {
violation: "table_forbidden".into(),
detail: format!("table '{}' is denied to agent '{}'", t, contract.id),
offending: sql.to_string(),
suggested_rewrite: None,
});
}
}
if let Some(ref allowed) = contract.allowed_tables {
for t in &tables {
if !allowed.iter().any(|a| a.eq_ignore_ascii_case(t)) {
return Err(Violation {
violation: "table_not_allowed".into(),
detail: format!("table '{}' not in this agent's allowed set {:?}", t, allowed),
offending: sql.to_string(),
suggested_rewrite: None,
});
}
}
}
for rule in &contract.require_predicate_on {
if tables.iter().any(|t| t.eq_ignore_ascii_case(&rule.table)) && !mentions_predicate(trimmed, &rule.column) {
let rewrite = inject_predicate(trimmed, &rule.column);
return Err(Violation {
violation: "missing_predicate".into(),
detail: format!(
"queries on '{}' must filter by '{}'",
rule.table, rule.column
),
offending: sql.to_string(),
suggested_rewrite: Some(rewrite),
});
}
}
if contract.require_limit && verb == "SELECT" && !contains_ci(trimmed, " LIMIT ") && !ends_with_limit(trimmed) {
let cap = contract.max_rows.unwrap_or(1000);
return Err(Violation {
violation: "missing_limit".into(),
detail: format!("SELECTs must be bounded; add LIMIT {} or fewer", cap),
offending: sql.to_string(),
suggested_rewrite: Some(format!("{} LIMIT {}", trimmed.trim_end_matches(';').trim_end(), cap)),
});
}
Ok(())
}
/// Reports whether `column` is mentioned anywhere after the first `WHERE`
/// keyword in `sql`.
///
/// `WHERE` is matched case-insensitively as a whole word: it may be bounded by
/// any non-identifier character (space, newline, tab, parenthesis, start or
/// end of input), so multi-line statements are handled. Returns `false` when
/// the statement has no `WHERE` at all.
///
/// This is a textual check only: it does not verify that the column is
/// actually used in a comparison, and it does not parse string literals or
/// comments.
fn mentions_predicate(sql: &str, column: &str) -> bool {
    const KEYWORD: &str = "WHERE";
    // ASCII upper-casing never changes byte lengths, so offsets found in `upper`
    // are valid offsets into `sql` as well.
    let upper = sql.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let where_pos = upper
        .match_indices(KEYWORD)
        .map(|(start, _)| start)
        .find(|&start| {
            let end = start + KEYWORD.len();
            // Reject matches embedded in a longer identifier such as `nowhere_id`.
            let open = start == 0 || !is_ident(bytes[start - 1]);
            let closed = end == bytes.len() || !is_ident(bytes[end]);
            open && closed
        });

    match where_pos {
        Some(pos) => contains_ci(&sql[pos..], column),
        None => false,
    }
}
/// Reports whether `sql` ends with a `LIMIT <value>` clause.
///
/// Trailing semicolons and whitespace are ignored; the statement qualifies
/// when its second-to-last whitespace-separated word is `LIMIT` in any ASCII
/// case. A `LIMIT` followed by more than one word (for example
/// `LIMIT 10 OFFSET 5`) does not qualify.
fn ends_with_limit(sql: &str) -> bool {
    let body = sql.trim_end_matches(';').trim_end();
    // Walking the words backwards, index 1 is the word before the last one.
    body.split_whitespace()
        .rev()
        .nth(1)
        .map_or(false, |word| word.eq_ignore_ascii_case("LIMIT"))
}
/// Builds a suggested rewrite of `sql` that filters on `column = $1`.
///
/// For `SELECT`, `UPDATE` and `DELETE` statements the filter is spliced in
/// ahead of the first top-level `ORDER BY`, `GROUP BY`, `LIMIT` or `HAVING`
/// clause (or appended when there is none):
///
/// * no `WHERE` yet: `... WHERE column = $1 [tail]`
/// * existing `WHERE`: `... AND column = $1 [tail]`
/// * existing `WHERE` with a top-level `OR`: the existing condition is wrapped
///   in parentheses first, so the new filter applies to the whole condition
///   rather than only to the last disjunct.
///
/// Any other statement is returned with a trailing comment asking for the
/// filter. Trailing semicolons and surrounding whitespace are dropped.
///
/// Keywords are matched case-insensitively as whole words outside parentheses,
/// so identifiers such as `rate_limit` and clauses inside sub-queries are left
/// alone. String literals and comments are not parsed, and `ORDER BY` /
/// `GROUP BY` are only recognised when written with a single space.
fn inject_predicate(sql: &str, column: &str) -> String {
    // Byte offset of the first occurrence of `word` in `hay` at or after `from`
    // that is a whole word (not embedded in an identifier) and sits at
    // parenthesis depth zero. `from` must lie on a char boundary.
    fn find_top_level(hay: &str, word: &str, from: usize) -> Option<usize> {
        let bytes = hay.as_bytes();
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
        hay[from..]
            .match_indices(word)
            .map(|(offset, _)| from + offset)
            .find(|&start| {
                let end = start + word.len();
                let bounded = (start == 0 || !is_ident(bytes[start - 1]))
                    && (end == bytes.len() || !is_ident(bytes[end]));
                let depth = bytes[..start].iter().fold(0i32, |acc, &b| match b {
                    b'(' => acc + 1,
                    b')' => acc - 1,
                    _ => acc,
                });
                bounded && depth == 0
            })
    }

    let trimmed = sql.trim().trim_end_matches(';').trim_end();
    let rewritable = starts_with_ci(trimmed, "SELECT")
        || starts_with_ci(trimmed, "UPDATE")
        || starts_with_ci(trimmed, "DELETE");
    if !rewritable {
        return format!("{} /* add filter: {} = $1 */", trimmed, column);
    }

    // ASCII upper-casing preserves byte offsets, so positions found in `up`
    // index `trimmed` directly.
    let up = trimmed.to_ascii_uppercase();
    let where_end = find_top_level(&up, "WHERE", 0).map(|start| start + "WHERE".len());

    // The filter must land before any trailing clause; only clauses after the
    // WHERE (when there is one) are considered.
    let tail_start = ["ORDER BY", "GROUP BY", "LIMIT", "HAVING"]
        .iter()
        .filter_map(|&kw| find_top_level(&up, kw, where_end.unwrap_or(0)))
        .min()
        .unwrap_or(trimmed.len());

    let filtered = match where_end {
        Some(cond_start) => {
            // AND binds tighter than OR, so an unparenthesised `a OR b AND col = $1`
            // would leave the `a` branch unfiltered.
            let needs_parens = find_top_level(&up[..tail_start], "OR", cond_start).is_some();
            if needs_parens {
                let condition = trimmed[cond_start..tail_start].trim();
                format!(
                    "{} ({}) AND {} = $1",
                    &trimmed[..cond_start],
                    condition,
                    column
                )
            } else {
                format!("{} AND {} = $1", trimmed[..tail_start].trim_end(), column)
            }
        }
        None => format!("{} WHERE {} = $1", trimmed[..tail_start].trim_end(), column),
    };

    if tail_start == trimmed.len() {
        filtered
    } else {
        format!("{} {}", filtered, &trimmed[tail_start..])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Contract shared by the tests: a read-only agent limited to `users` and
    /// `orders`, denied `secrets`, required to filter `orders` by `tenant_id`
    /// and to bound every SELECT.
    fn contract() -> AgentContract {
        let tenant_rule = PredicateRule {
            table: "orders".into(),
            column: "tenant_id".into(),
        };
        AgentContract {
            id: "analyst".into(),
            read_only: true,
            allowed_verbs: None,
            allowed_tables: Some(vec!["users".into(), "orders".into()]),
            denied_tables: vec!["secrets".into()],
            require_predicate_on: vec![tenant_rule],
            require_limit: true,
            max_rows: Some(1000),
        }
    }

    /// Validates `sql` against the shared contract and returns the violation,
    /// panicking if the statement is accepted.
    fn rejection(sql: &str) -> Violation {
        validate(sql, &contract()).expect_err("statement should be rejected")
    }

    #[test]
    fn allows_compliant_query() {
        let verdict = validate("SELECT id FROM users WHERE id = 1 LIMIT 10", &contract());
        assert!(verdict.is_ok());
    }

    #[test]
    fn blocks_write_when_read_only() {
        assert_eq!(rejection("DELETE FROM users").violation, "write_forbidden");
    }

    #[test]
    fn blocks_denied_table() {
        assert_eq!(
            rejection("SELECT * FROM secrets LIMIT 1").violation,
            "table_forbidden"
        );
    }

    #[test]
    fn blocks_table_not_in_allowlist() {
        assert_eq!(
            rejection("SELECT * FROM invoices LIMIT 1").violation,
            "table_not_allowed"
        );
    }

    #[test]
    fn requires_predicate_with_rewrite() {
        let found = rejection("SELECT * FROM orders LIMIT 5");
        assert_eq!(found.violation, "missing_predicate");
        let rewrite = found
            .suggested_rewrite
            .expect("a rewrite should be suggested");
        assert!(rewrite.to_lowercase().contains("tenant_id"));
    }

    #[test]
    fn requires_limit_with_rewrite() {
        let found = rejection("SELECT id FROM users WHERE id = 1");
        assert_eq!(found.violation, "missing_limit");
        let rewrite = found
            .suggested_rewrite
            .expect("a rewrite should be suggested");
        assert!(rewrite.to_uppercase().contains("LIMIT 1000"));
    }

    #[test]
    fn table_extraction_handles_schema_and_joins() {
        let found = tables_of("SELECT * FROM public.users u JOIN orders o ON o.uid = u.id");
        assert!(found.iter().any(|t| t.as_str() == "users"));
        assert!(found.iter().any(|t| t.as_str() == "orders"));
    }
}