1use once_cell::sync::Lazy;
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19
20use crate::protocol::{contains_ci, starts_with_ci};
21
/// A "this table must always be filtered by this column" rule.
///
/// `validate` applies the rule to any statement that references `table`
/// (compared case-insensitively) and rejects it with `missing_predicate`
/// unless `column` is mentioned after the statement's `WHERE`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateRule {
    /// Table the rule applies to.
    pub table: String,
    /// Column that must be mentioned in the `WHERE` clause.
    pub column: String,
}
29
/// The set of restrictions `validate` enforces on an agent's SQL.
///
/// Every field except `id` has a serde default, so a serialized contract may
/// consist of just an `id`; such a contract is read-only and has no other
/// restrictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContract {
    /// Agent identifier; quoted in violation messages.
    pub id: String,
    /// When `true`, statements whose leading verb is a write verb are
    /// rejected. Defaults to `true` when absent from the serialized form.
    #[serde(default = "default_true")]
    pub read_only: bool,
    /// Optional verb allowlist, compared case-insensitively against the
    /// statement's leading verb. `None` means no verb allowlist is applied.
    #[serde(default)]
    pub allowed_verbs: Option<Vec<String>>,
    /// Optional table allowlist. When present, every table detected in the
    /// statement must match one of the entries (case-insensitively).
    #[serde(default)]
    pub allowed_tables: Option<Vec<String>>,
    /// Tables the agent may never reference; checked before `allowed_tables`.
    #[serde(default)]
    pub denied_tables: Vec<String>,
    /// Tables that must always be filtered by a given column.
    #[serde(default)]
    pub require_predicate_on: Vec<PredicateRule>,
    /// When `true`, `SELECT` statements without a `LIMIT` are rejected.
    #[serde(default)]
    pub require_limit: bool,
    /// Row cap quoted in the `missing_limit` message and suggested rewrite
    /// (1000 is used when `None`).
    // NOTE(review): `validate` in this file never compares the statement's
    // actual LIMIT value against this cap — confirm whether enforcement
    // happens elsewhere.
    #[serde(default)]
    pub max_rows: Option<u64>,
}
58
/// Serde default for `AgentContract::read_only`: a contract that does not
/// mention the field is treated as read-only.
fn default_true() -> bool {
    true
}
62
/// Describes why `validate` rejected a statement.
#[derive(Debug, Clone, Serialize)]
pub struct Violation {
    /// Machine-readable kind. `validate` emits `write_forbidden`,
    /// `verb_forbidden`, `table_forbidden`, `table_not_allowed`,
    /// `missing_predicate` and `missing_limit`.
    pub violation: String,
    /// Human-readable explanation of the rejection.
    pub detail: String,
    /// The SQL text exactly as it was passed to `validate`.
    pub offending: String,
    /// A corrected statement, when one can be proposed. Omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suggested_rewrite: Option<String>,
}
78
79impl Violation {
80 pub fn to_json(&self) -> String {
81 serde_json::to_string(self).unwrap_or_else(|_| self.detail.clone())
82 }
83}
84
85static TABLE_RE: Lazy<Regex> = Lazy::new(|| {
86 Regex::new(r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)")
88 .expect("valid table regex")
89});
90
/// Returns the upper-cased leading keyword of `sql` (e.g. `"SELECT"`), or an
/// empty string when the statement does not start with a word.
///
/// Anything that can legally precede the verb is skipped first: whitespace,
/// opening parentheses, `-- line` comments and `/* block */` comments.
/// Without this, `/* x */ DELETE FROM t` would report `/*` as its verb and
/// slip past the read-only check. The verb itself is the leading run of
/// ASCII letters, digits and underscores, so `VACUUM;` and `DELETE/**/FROM t`
/// yield `VACUUM` and `DELETE`.
///
/// Block comments are treated as non-nesting; an unterminated comment
/// swallows the rest of the input.
fn verb_of(sql: &str) -> String {
    let mut rest = sql;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        if let Some(comment) = rest.strip_prefix("--") {
            // Line comment: resume after the newline, if there is one.
            rest = comment.split_once('\n').map_or("", |(_, tail)| tail);
        } else if let Some(comment) = rest.strip_prefix("/*") {
            // Block comment: resume after the closing marker, if there is one.
            rest = comment.split_once("*/").map_or("", |(_, tail)| tail);
        } else {
            break;
        }
    }

    rest.chars()
        .take_while(|c| c.is_ascii_alphanumeric() || *c == '_')
        .collect::<String>()
        .to_ascii_uppercase()
}
99
100fn tables_of(sql: &str) -> Vec<String> {
102 TABLE_RE
103 .captures_iter(sql)
104 .filter_map(|c| c.get(1))
105 .map(|m| {
106 let full = m.as_str().to_ascii_lowercase();
107 full.rsplit('.').next().unwrap_or(&full).to_string()
109 })
110 .collect()
111}
112
/// Reports whether `verb` is one of the statement verbs treated as a write.
///
/// The comparison is exact, so callers are expected to pass an upper-cased
/// verb.
fn is_write_verb(verb: &str) -> bool {
    const WRITE_VERBS: &[&str] = &[
        "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
        "MERGE", "CALL", "DO", "COPY", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT",
    ];
    WRITE_VERBS.iter().any(|candidate| *candidate == verb)
}
121
122pub fn validate(sql: &str, contract: &AgentContract) -> Result<(), Violation> {
125 let trimmed = sql.trim();
126 let verb = verb_of(trimmed);
127
128 if contract.read_only && is_write_verb(&verb) {
130 return Err(Violation {
131 violation: "write_forbidden".into(),
132 detail: format!("agent '{}' is read-only; '{}' statements are not permitted", contract.id, verb),
133 offending: sql.to_string(),
134 suggested_rewrite: None,
135 });
136 }
137
138 if let Some(ref verbs) = contract.allowed_verbs {
140 if !verbs.iter().any(|v| v.eq_ignore_ascii_case(&verb)) {
141 return Err(Violation {
142 violation: "verb_forbidden".into(),
143 detail: format!("verb '{}' not in this agent's allowed set {:?}", verb, verbs),
144 offending: sql.to_string(),
145 suggested_rewrite: None,
146 });
147 }
148 }
149
150 let tables = tables_of(trimmed);
151
152 for t in &tables {
154 if contract.denied_tables.iter().any(|d| d.eq_ignore_ascii_case(t)) {
155 return Err(Violation {
156 violation: "table_forbidden".into(),
157 detail: format!("table '{}' is denied to agent '{}'", t, contract.id),
158 offending: sql.to_string(),
159 suggested_rewrite: None,
160 });
161 }
162 }
163
164 if let Some(ref allowed) = contract.allowed_tables {
166 for t in &tables {
167 if !allowed.iter().any(|a| a.eq_ignore_ascii_case(t)) {
168 return Err(Violation {
169 violation: "table_not_allowed".into(),
170 detail: format!("table '{}' not in this agent's allowed set {:?}", t, allowed),
171 offending: sql.to_string(),
172 suggested_rewrite: None,
173 });
174 }
175 }
176 }
177
178 for rule in &contract.require_predicate_on {
180 if tables.iter().any(|t| t.eq_ignore_ascii_case(&rule.table)) && !mentions_predicate(trimmed, &rule.column) {
181 let rewrite = inject_predicate(trimmed, &rule.column);
182 return Err(Violation {
183 violation: "missing_predicate".into(),
184 detail: format!(
185 "queries on '{}' must filter by '{}'",
186 rule.table, rule.column
187 ),
188 offending: sql.to_string(),
189 suggested_rewrite: Some(rewrite),
190 });
191 }
192 }
193
194 if contract.require_limit && verb == "SELECT" && !contains_ci(trimmed, " LIMIT ") && !ends_with_limit(trimmed) {
196 let cap = contract.max_rows.unwrap_or(1000);
197 return Err(Violation {
198 violation: "missing_limit".into(),
199 detail: format!("SELECTs must be bounded; add LIMIT {} or fewer", cap),
200 offending: sql.to_string(),
201 suggested_rewrite: Some(format!("{} LIMIT {}", trimmed.trim_end_matches(';').trim_end(), cap)),
202 });
203 }
204
205 Ok(())
206}
207
/// Reports whether `column` is mentioned after the first `WHERE` of `sql`.
///
/// Both the keyword and the column are matched case-insensitively (ASCII) and
/// as whole words:
///
/// * `WHERE` is found regardless of the whitespace or punctuation around it
///   (`\nWHERE\t`, `WHERE(`), not only when written as `" WHERE "`.
/// * the column must not be part of a longer identifier, so requiring
///   `tenant_id` is not satisfied by `other_tenant_id`; qualified
///   (`o.tenant_id`) and quoted (`"tenant_id"`) forms still count.
///
/// Returns `false` when there is no `WHERE`. An empty `column` is satisfied by
/// any statement that has a `WHERE`.
///
/// This is a textual check: string literals and comments are not excluded,
/// and the first `WHERE` may belong to a subquery.
fn mentions_predicate(sql: &str, column: &str) -> bool {
    // Byte offset of the first occurrence of `word` in `haystack` that is not
    // embedded in a longer identifier.
    fn find_word(haystack: &str, word: &str) -> Option<usize> {
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
        let bytes = haystack.as_bytes();
        haystack.match_indices(word).map(|(at, _)| at).find(|&at| {
            let end = at + word.len();
            (at == 0 || !is_ident(bytes[at - 1])) && (end == bytes.len() || !is_ident(bytes[end]))
        })
    }

    let upper = sql.to_ascii_uppercase();
    let where_pos = match find_word(&upper, "WHERE") {
        Some(pos) => pos,
        None => return false,
    };

    let after = &upper[where_pos + "WHERE".len()..];
    let needle = column.to_ascii_uppercase();
    needle.is_empty() || find_word(after, &needle).is_some()
}
219
/// Reports whether the statement ends in `LIMIT <something>`.
///
/// Trailing semicolons and whitespace are ignored; the check is that the
/// second-to-last whitespace-separated word is `LIMIT` (any ASCII case). This
/// recognises a `LIMIT` preceded by a newline or tab rather than a space.
fn ends_with_limit(sql: &str) -> bool {
    sql.trim_end_matches(';')
        .trim_end()
        .split_whitespace()
        .rev()
        .nth(1)
        .map_or(false, |word| word.eq_ignore_ascii_case("LIMIT"))
}
228
229fn inject_predicate(sql: &str, column: &str) -> String {
232 let trimmed = sql.trim().trim_end_matches(';').trim_end();
233 if starts_with_ci(trimmed, "SELECT") || starts_with_ci(trimmed, "UPDATE") || starts_with_ci(trimmed, "DELETE") {
234 if contains_ci(trimmed, " WHERE ") {
235 format!("{} AND {} = $1", trimmed, column)
236 } else {
237 let up = trimmed.to_ascii_uppercase();
239 let cut = ["ORDER BY", "GROUP BY", "LIMIT", "HAVING"]
240 .iter()
241 .filter_map(|kw| up.find(kw))
242 .min();
243 match cut {
244 Some(pos) => format!("{} WHERE {} = $1 {}", trimmed[..pos].trim_end(), column, &trimmed[pos..]),
245 None => format!("{} WHERE {} = $1", trimmed, column),
246 }
247 }
248 } else {
249 format!("{} /* add filter: {} = $1 */", trimmed, column)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Contract shared by every test: read-only, `users`/`orders` allowed,
    /// `secrets` denied, `orders` must be filtered by `tenant_id`, and
    /// SELECTs must carry a LIMIT.
    fn contract() -> AgentContract {
        let tenant_rule = PredicateRule {
            table: "orders".into(),
            column: "tenant_id".into(),
        };
        AgentContract {
            id: "analyst".into(),
            read_only: true,
            allowed_verbs: None,
            allowed_tables: Some(vec!["users".into(), "orders".into()]),
            denied_tables: vec!["secrets".into()],
            require_predicate_on: vec![tenant_rule],
            require_limit: true,
            max_rows: Some(1000),
        }
    }

    /// Validates `sql` against the shared contract and returns the rejection;
    /// panics if the statement is accepted.
    fn rejection(sql: &str) -> Violation {
        validate(sql, &contract()).unwrap_err()
    }

    #[test]
    fn allows_compliant_query() {
        let outcome = validate("SELECT id FROM users WHERE id = 1 LIMIT 10", &contract());
        assert!(outcome.is_ok());
    }

    #[test]
    fn blocks_write_when_read_only() {
        assert_eq!(rejection("DELETE FROM users").violation, "write_forbidden");
    }

    #[test]
    fn blocks_denied_table() {
        assert_eq!(
            rejection("SELECT * FROM secrets LIMIT 1").violation,
            "table_forbidden"
        );
    }

    #[test]
    fn blocks_table_not_in_allowlist() {
        assert_eq!(
            rejection("SELECT * FROM invoices LIMIT 1").violation,
            "table_not_allowed"
        );
    }

    #[test]
    fn requires_predicate_with_rewrite() {
        let Violation {
            violation,
            suggested_rewrite,
            ..
        } = rejection("SELECT * FROM orders LIMIT 5");
        assert_eq!(violation, "missing_predicate");
        let rewrite = suggested_rewrite.expect("a rewrite should be suggested");
        assert!(rewrite.to_lowercase().contains("tenant_id"));
    }

    #[test]
    fn requires_limit_with_rewrite() {
        let Violation {
            violation,
            suggested_rewrite,
            ..
        } = rejection("SELECT id FROM users WHERE id = 1");
        assert_eq!(violation, "missing_limit");
        let rewrite = suggested_rewrite.expect("a rewrite should be suggested");
        assert!(rewrite.to_uppercase().contains("LIMIT 1000"));
    }

    #[test]
    fn table_extraction_handles_schema_and_joins() {
        let found = tables_of("SELECT * FROM public.users u JOIN orders o ON o.uid = u.id");
        for expected in ["users", "orders"] {
            assert!(found.iter().any(|name| name == expected));
        }
    }
}
315}