1use once_cell::sync::Lazy;
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19
20use crate::protocol::{contains_ci, starts_with_ci};
21
/// Requires that statements touching `table` filter on `column`.
///
/// Enforced by `validate`: a statement that references `table` but does not
/// mention `column` after a `WHERE` is rejected with a `missing_predicate`
/// violation whose suggested rewrite adds `column = $1`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateRule {
    /// Table the rule applies to; compared case-insensitively with the
    /// unqualified (schema-stripped) table names found in the statement.
    pub table: String,
    /// Column that must be mentioned in the statement's `WHERE` clause.
    pub column: String,
}
29
/// Per-agent policy that `validate` enforces on SQL submitted by that agent.
///
/// Every field except `id` has a serde default, so a contract can be
/// deserialized from a document that only names the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContract {
    /// Agent identifier; interpolated into violation messages.
    pub id: String,
    /// When `true`, statements whose leading verb is a write verb (`INSERT`,
    /// `UPDATE`, `DELETE`, `DROP`, ...) are rejected with `write_forbidden`.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub read_only: bool,
    /// If `Some`, only these leading verbs are accepted (compared
    /// case-insensitively); any other verb yields `verb_forbidden`.
    /// `None` (the default) allows every verb.
    #[serde(default)]
    pub allowed_verbs: Option<Vec<String>>,
    /// If `Some`, every referenced table must appear here (compared
    /// case-insensitively), otherwise `table_not_allowed`. `None` (the default)
    /// allows every table not listed in `denied_tables`.
    #[serde(default)]
    pub allowed_tables: Option<Vec<String>>,
    /// Tables that may never be referenced (`table_forbidden`). Checked before
    /// `allowed_tables`. Defaults to empty.
    #[serde(default)]
    pub denied_tables: Vec<String>,
    /// Tables whose statements must filter on a given column
    /// (`missing_predicate`). Defaults to empty.
    #[serde(default)]
    pub require_predicate_on: Vec<PredicateRule>,
    /// When `true`, `SELECT` statements without a `LIMIT` are rejected with
    /// `missing_limit`. Defaults to `false`.
    #[serde(default)]
    pub require_limit: bool,
    /// Row count used in the `LIMIT` of the suggested rewrite for
    /// `missing_limit`; `validate` falls back to 1000 when this is `None`.
    // NOTE(review): the LIMIT value already present in a query is not compared
    // with this cap — confirm whether it is meant to be enforced.
    #[serde(default)]
    pub max_rows: Option<u64>,
}
58
/// Serde default for `AgentContract::read_only`.
///
/// Agents are treated as read-only unless their contract explicitly opts out.
const fn default_true() -> bool {
    true
}
62
/// A contract violation reported by `validate`.
///
/// It derives `Serialize`; `to_json` renders it as a JSON object.
#[derive(Debug, Clone, Serialize)]
pub struct Violation {
    /// Machine-readable violation code. `validate` produces `write_forbidden`,
    /// `verb_forbidden`, `table_forbidden`, `table_not_allowed`,
    /// `missing_predicate` and `missing_limit`.
    pub violation: String,
    /// Human-readable explanation of what was violated.
    pub detail: String,
    /// The SQL text that was rejected.
    pub offending: String,
    /// A compliant rewrite of `offending`, when one can be suggested
    /// (`validate` sets it for `missing_predicate` and `missing_limit`).
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suggested_rewrite: Option<String>,
}
78
79impl Violation {
80 pub fn to_json(&self) -> String {
81 serde_json::to_string(self).unwrap_or_else(|_| self.detail.clone())
82 }
83}
84
85static TABLE_RE: Lazy<Regex> = Lazy::new(|| {
86 Regex::new(
88 r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)",
89 )
90 .expect("valid table regex")
91});
92
93fn verb_of(sql: &str) -> String {
95 sql.trim_start()
96 .split(|c: char| c.is_whitespace() || c == '(')
97 .next()
98 .unwrap_or("")
99 .to_ascii_uppercase()
100}
101
102fn tables_of(sql: &str) -> Vec<String> {
104 TABLE_RE
105 .captures_iter(sql)
106 .filter_map(|c| c.get(1))
107 .map(|m| {
108 let full = m.as_str().to_ascii_lowercase();
109 full.rsplit('.').next().unwrap_or(&full).to_string()
111 })
112 .collect()
113}
114
/// Returns `true` if `verb` starts a statement that can change data, schema,
/// privileges or server state, or that runs stored code which might.
///
/// `verb` must already be upper-cased (as `verb_of` returns it); the comparison
/// is case-sensitive.
fn is_write_verb(verb: &str) -> bool {
    matches!(
        verb,
        // Data modification (REPLACE/UPSERT are MySQL/CockroachDB-style writes).
        "INSERT"
            | "UPDATE"
            | "DELETE"
            | "MERGE"
            | "UPSERT"
            | "REPLACE"
            // Schema and object metadata.
            | "CREATE"
            | "DROP"
            | "ALTER"
            | "TRUNCATE"
            | "RENAME"
            | "COMMENT"
            | "SECURITY"
            | "IMPORT"
            // Privileges and ownership.
            | "GRANT"
            | "REVOKE"
            | "REASSIGN"
            // Bulk load/export, maintenance and locking.
            | "COPY"
            | "LOAD"
            | "VACUUM"
            | "REINDEX"
            | "CLUSTER"
            | "REFRESH"
            | "LOCK"
            // Execution of code whose body cannot be inspected here.
            | "CALL"
            | "DO"
            | "PREPARE"
            | "EXECUTE"
    )
}
138
139pub fn validate(sql: &str, contract: &AgentContract) -> Result<(), Violation> {
142 let trimmed = sql.trim();
143 let verb = verb_of(trimmed);
144
145 if contract.read_only && is_write_verb(&verb) {
147 return Err(Violation {
148 violation: "write_forbidden".into(),
149 detail: format!(
150 "agent '{}' is read-only; '{}' statements are not permitted",
151 contract.id, verb
152 ),
153 offending: sql.to_string(),
154 suggested_rewrite: None,
155 });
156 }
157
158 if let Some(ref verbs) = contract.allowed_verbs {
160 if !verbs.iter().any(|v| v.eq_ignore_ascii_case(&verb)) {
161 return Err(Violation {
162 violation: "verb_forbidden".into(),
163 detail: format!(
164 "verb '{}' not in this agent's allowed set {:?}",
165 verb, verbs
166 ),
167 offending: sql.to_string(),
168 suggested_rewrite: None,
169 });
170 }
171 }
172
173 let tables = tables_of(trimmed);
174
175 for t in &tables {
177 if contract
178 .denied_tables
179 .iter()
180 .any(|d| d.eq_ignore_ascii_case(t))
181 {
182 return Err(Violation {
183 violation: "table_forbidden".into(),
184 detail: format!("table '{}' is denied to agent '{}'", t, contract.id),
185 offending: sql.to_string(),
186 suggested_rewrite: None,
187 });
188 }
189 }
190
191 if let Some(ref allowed) = contract.allowed_tables {
193 for t in &tables {
194 if !allowed.iter().any(|a| a.eq_ignore_ascii_case(t)) {
195 return Err(Violation {
196 violation: "table_not_allowed".into(),
197 detail: format!(
198 "table '{}' not in this agent's allowed set {:?}",
199 t, allowed
200 ),
201 offending: sql.to_string(),
202 suggested_rewrite: None,
203 });
204 }
205 }
206 }
207
208 for rule in &contract.require_predicate_on {
210 if tables.iter().any(|t| t.eq_ignore_ascii_case(&rule.table))
211 && !mentions_predicate(trimmed, &rule.column)
212 {
213 let rewrite = inject_predicate(trimmed, &rule.column);
214 return Err(Violation {
215 violation: "missing_predicate".into(),
216 detail: format!(
217 "queries on '{}' must filter by '{}'",
218 rule.table, rule.column
219 ),
220 offending: sql.to_string(),
221 suggested_rewrite: Some(rewrite),
222 });
223 }
224 }
225
226 if contract.require_limit
228 && verb == "SELECT"
229 && !contains_ci(trimmed, " LIMIT ")
230 && !ends_with_limit(trimmed)
231 {
232 let cap = contract.max_rows.unwrap_or(1000);
233 return Err(Violation {
234 violation: "missing_limit".into(),
235 detail: format!("SELECTs must be bounded; add LIMIT {} or fewer", cap),
236 offending: sql.to_string(),
237 suggested_rewrite: Some(format!(
238 "{} LIMIT {}",
239 trimmed.trim_end_matches(';').trim_end(),
240 cap
241 )),
242 });
243 }
244
245 Ok(())
246}
247
248fn mentions_predicate(sql: &str, column: &str) -> bool {
251 let upper = sql.to_ascii_uppercase();
252 if let Some(where_pos) = upper.find(" WHERE ") {
253 let after = &sql[where_pos..];
254 contains_ci(after, column)
255 } else {
256 false
257 }
258}
259
/// Returns `true` when `sql` ends with a `LIMIT n` clause, optionally followed by
/// `OFFSET m`. A trailing `;` is ignored and keywords are case-insensitive.
///
/// Words are split on any whitespace, so clauses placed on their own lines
/// (`...\nLIMIT 10\nOFFSET 20`) are recognised.
fn ends_with_limit(sql: &str) -> bool {
    let upper = sql.trim_end_matches(';').trim_end().to_ascii_uppercase();
    let words: Vec<&str> = upper.split_whitespace().collect();
    matches!(
        words.as_slice(),
        [.., "LIMIT", _] | [.., "LIMIT", _, "OFFSET", _]
    )
}
268
269fn inject_predicate(sql: &str, column: &str) -> String {
272 let trimmed = sql.trim().trim_end_matches(';').trim_end();
273 if starts_with_ci(trimmed, "SELECT")
274 || starts_with_ci(trimmed, "UPDATE")
275 || starts_with_ci(trimmed, "DELETE")
276 {
277 if contains_ci(trimmed, " WHERE ") {
278 format!("{} AND {} = $1", trimmed, column)
279 } else {
280 let up = trimmed.to_ascii_uppercase();
282 let cut = ["ORDER BY", "GROUP BY", "LIMIT", "HAVING"]
283 .iter()
284 .filter_map(|kw| up.find(kw))
285 .min();
286 match cut {
287 Some(pos) => format!(
288 "{} WHERE {} = $1 {}",
289 trimmed[..pos].trim_end(),
290 column,
291 &trimmed[pos..]
292 ),
293 None => format!("{} WHERE {} = $1", trimmed, column),
294 }
295 }
296 } else {
297 format!("{} /* add filter: {} = $1 */", trimmed, column)
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// The contract shared by every test, an `analyst` agent with these rules:
    /// - read-only;
    /// - restricted to `users` and `orders`, and denied `secrets`;
    /// - must filter `orders` by `tenant_id`;
    /// - must bound SELECTs, with a 1000-row cap.
    fn contract() -> AgentContract {
        let tenant_rule = PredicateRule {
            table: String::from("orders"),
            column: String::from("tenant_id"),
        };
        AgentContract {
            id: String::from("analyst"),
            read_only: true,
            allowed_verbs: None,
            allowed_tables: Some(vec![String::from("users"), String::from("orders")]),
            denied_tables: vec![String::from("secrets")],
            require_predicate_on: vec![tenant_rule],
            require_limit: true,
            max_rows: Some(1000),
        }
    }

    /// Runs `sql` through `validate` and returns the violation, failing the test
    /// if the statement was accepted.
    fn rejection(sql: &str) -> Violation {
        validate(sql, &contract()).expect_err("statement should have been rejected")
    }

    #[test]
    fn allows_compliant_query() {
        let outcome = validate("SELECT id FROM users WHERE id = 1 LIMIT 10", &contract());
        assert!(outcome.is_ok(), "unexpected violation: {:?}", outcome);
    }

    #[test]
    fn blocks_write_when_read_only() {
        assert_eq!(rejection("DELETE FROM users").violation, "write_forbidden");
    }

    #[test]
    fn blocks_denied_table() {
        assert_eq!(
            rejection("SELECT * FROM secrets LIMIT 1").violation,
            "table_forbidden"
        );
    }

    #[test]
    fn blocks_table_not_in_allowlist() {
        assert_eq!(
            rejection("SELECT * FROM invoices LIMIT 1").violation,
            "table_not_allowed"
        );
    }

    #[test]
    fn requires_predicate_with_rewrite() {
        let found = rejection("SELECT * FROM orders LIMIT 5");
        assert_eq!(found.violation, "missing_predicate");
        let rewrite = found
            .suggested_rewrite
            .expect("missing_predicate should carry a rewrite");
        assert!(rewrite.to_lowercase().contains("tenant_id"));
    }

    #[test]
    fn requires_limit_with_rewrite() {
        let found = rejection("SELECT id FROM users WHERE id = 1");
        assert_eq!(found.violation, "missing_limit");
        let rewrite = found
            .suggested_rewrite
            .expect("missing_limit should carry a rewrite");
        assert!(rewrite.to_uppercase().contains("LIMIT 1000"));
    }

    #[test]
    fn table_extraction_handles_schema_and_joins() {
        let found = tables_of("SELECT * FROM public.users u JOIN orders o ON o.uid = u.id");
        for expected in &["users", "orders"] {
            assert!(
                found.iter().any(|t| t == expected),
                "{} not found in {:?}",
                expected,
                found
            );
        }
    }
}