Skip to main content

heliosdb_proxy/
agent_contract.rs

1//! Per-agent SQL contracts + a contract validator with machine-readable
2//! repair hints.
3//!
4//! An agent contract is a scoped grant: which SQL verbs and tables an agent
5//! may touch, which predicates a query must carry (e.g. a tenant filter), and
6//! whether reads must be bounded by a LIMIT. Queries are validated against
7//! the contract BEFORE execution; a violation is returned as a structured
8//! [`Violation`] — a violation class, the offending fragment, and a suggested
9//! rewrite — so an LLM agent can read it and self-correct in one round trip
10//! instead of flailing against an opaque error.
11//!
12//! Validation is intentionally a lightweight static inspection (verb +
13//! table + predicate + LIMIT detection), the same altitude as a pg_hba /
14//! pgcat-style guard; it is a policy gate, not a full SQL parser.
15
16use once_cell::sync::Lazy;
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19
20use crate::protocol::{contains_ci, starts_with_ci};
21
/// A predicate an agent's queries must carry when they touch `table`.
///
/// `table` is compared case-insensitively against the table names extracted
/// from the statement; `column` is the name that has to show up in the
/// statement's WHERE clause for the query to be admitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateRule {
    /// Table the rule applies to.
    pub table: String,
    /// Column that must appear in the query's WHERE clause.
    pub column: String,
}
29
/// A scoped grant for one agent identity.
///
/// When deserialized, `read_only` defaults to `true` and every other field
/// except `id` falls back to its type default (no verb/table allowlist, no
/// denied tables, no predicate rules, no LIMIT requirement, no row cap).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContract {
    /// Identifier matched against the connecting agent.
    pub id: String,
    /// Reject write/DDL statements.
    #[serde(default = "default_true")]
    pub read_only: bool,
    /// If set, only these SQL verbs are allowed (e.g. "SELECT"). Entries are
    /// compared case-insensitively with the statement's leading verb.
    #[serde(default)]
    pub allowed_verbs: Option<Vec<String>>,
    /// If set, only these tables may be referenced. Entries are compared
    /// case-insensitively with the bare (schema-stripped) table name.
    #[serde(default)]
    pub allowed_tables: Option<Vec<String>>,
    /// Tables that may never be referenced (takes precedence over allow).
    #[serde(default)]
    pub denied_tables: Vec<String>,
    /// Predicates that must be present when the named table is touched.
    #[serde(default)]
    pub require_predicate_on: Vec<PredicateRule>,
    /// Require a LIMIT on SELECTs.
    #[serde(default)]
    pub require_limit: bool,
    /// Row cap quoted in the `missing_limit` hint and its suggested rewrite;
    /// the hint falls back to 1000 when this is unset.
    /// NOTE(review): `validate` only checks that a LIMIT is present — it does
    /// not compare the query's own LIMIT value against this cap.
    #[serde(default)]
    pub max_rows: Option<u64>,
}
58
/// Serde default for `AgentContract::read_only`: contracts that omit the
/// field are read-only.
fn default_true() -> bool {
    true
}
62
/// A contract violation, serialized to the agent as a machine-readable hint.
#[derive(Debug, Clone, Serialize)]
pub struct Violation {
    /// Stable class, e.g. "write_forbidden", "verb_forbidden",
    /// "table_forbidden", "table_not_allowed", "missing_predicate",
    /// "missing_limit".
    pub violation: String,
    /// Human/agent-readable explanation.
    pub detail: String,
    /// The offending input (the SQL).
    pub offending: String,
    /// A concrete corrected statement the agent can retry, when one can be
    /// synthesised. Left out of the serialized form entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suggested_rewrite: Option<String>,
}
78
79impl Violation {
80    pub fn to_json(&self) -> String {
81        serde_json::to_string(self).unwrap_or_else(|_| self.detail.clone())
82    }
83}
84
/// Matches a table reference introduced by `FROM`, `JOIN`, `INTO` or `UPDATE`
/// (case-insensitive). Capture group 1 is the identifier that follows the
/// keyword: an unquoted name with at most one `schema.` qualifier. The
/// pattern itself does not match quoted identifiers or the later entries of
/// a comma-separated table list.
static TABLE_RE: Lazy<Regex> = Lazy::new(|| {
    // FROM/JOIN/INTO/UPDATE <table> — captures schema-qualified identifiers.
    Regex::new(r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)")
        .expect("valid table regex")
});
90
/// Extract the leading SQL verb (upper-case), e.g. "SELECT".
///
/// Leading whitespace, opening parentheses and SQL comments (`-- …` up to the
/// end of the line, `/* … */` including nested ones) are skipped first, so a
/// comment placed in front of a statement cannot hide its verb from the
/// policy checks. The verb is the run of ASCII letters that follows, which
/// also means a glued-on `;` or `(` (`VACUUM;`, `SELECT(1)`) is not part of
/// it. Returns an empty string when no word is found or a block comment is
/// never closed.
fn verb_of(sql: &str) -> String {
    // `rest` starts just after an opening `/*`; returns what follows the
    // matching `*/`, honouring nested comments, or "" if it never closes.
    fn after_block_comment(mut rest: &str) -> &str {
        let mut depth = 1usize;
        while depth > 0 {
            match (rest.find("/*"), rest.find("*/")) {
                (Some(open), Some(close)) if open < close => {
                    depth += 1;
                    rest = &rest[open + 2..];
                }
                (_, Some(close)) => {
                    depth -= 1;
                    rest = &rest[close + 2..];
                }
                (_, None) => return "",
            }
        }
        rest
    }

    let mut rest = sql;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        if let Some(comment) = rest.strip_prefix("--") {
            // Line comment: resume after the newline, or give up at end of input.
            rest = comment.find('\n').map_or("", |nl| &comment[nl + 1..]);
        } else if let Some(comment) = rest.strip_prefix("/*") {
            rest = after_block_comment(comment);
        } else {
            break;
        }
    }

    let end = rest
        .find(|c: char| !c.is_ascii_alphabetic())
        .unwrap_or(rest.len());
    rest[..end].to_ascii_uppercase()
}
99
100/// Extract referenced table names (lower-cased, bare name without schema).
101fn tables_of(sql: &str) -> Vec<String> {
102    TABLE_RE
103        .captures_iter(sql)
104        .filter_map(|c| c.get(1))
105        .map(|m| {
106            let full = m.as_str().to_ascii_lowercase();
107            // bare table name (strip schema prefix) for matching
108            full.rsplit('.').next().unwrap_or(&full).to_string()
109        })
110        .collect()
111}
112
/// Is `verb` (already upper-cased) a statement kind that writes data, changes
/// schema/privileges, or runs arbitrary code?
///
/// Besides plain DML/DDL this covers the PostgreSQL maintenance and
/// administration commands that also modify state: `REFRESH` (materialized
/// view), `REASSIGN` (owned objects), `IMPORT` (foreign schema) and
/// `SECURITY` (label). The comparison is exact, so callers must pass the
/// verb in upper case.
fn is_write_verb(verb: &str) -> bool {
    matches!(
        verb,
        "INSERT"
            | "UPDATE"
            | "DELETE"
            | "MERGE"
            | "TRUNCATE"
            | "COPY"
            | "CREATE"
            | "DROP"
            | "ALTER"
            | "COMMENT"
            | "GRANT"
            | "REVOKE"
            | "REASSIGN"
            | "SECURITY"
            | "IMPORT"
            | "REFRESH"
            | "CALL"
            | "DO"
            | "VACUUM"
            | "REINDEX"
            | "CLUSTER"
            | "LOCK"
    )
}
121
122/// Validate `sql` against `contract`. `Ok(())` admits the query; `Err`
123/// carries a structured repair hint.
124pub fn validate(sql: &str, contract: &AgentContract) -> Result<(), Violation> {
125    let trimmed = sql.trim();
126    let verb = verb_of(trimmed);
127
128    // 1. read-only
129    if contract.read_only && is_write_verb(&verb) {
130        return Err(Violation {
131            violation: "write_forbidden".into(),
132            detail: format!("agent '{}' is read-only; '{}' statements are not permitted", contract.id, verb),
133            offending: sql.to_string(),
134            suggested_rewrite: None,
135        });
136    }
137
138    // 2. allowed verbs
139    if let Some(ref verbs) = contract.allowed_verbs {
140        if !verbs.iter().any(|v| v.eq_ignore_ascii_case(&verb)) {
141            return Err(Violation {
142                violation: "verb_forbidden".into(),
143                detail: format!("verb '{}' not in this agent's allowed set {:?}", verb, verbs),
144                offending: sql.to_string(),
145                suggested_rewrite: None,
146            });
147        }
148    }
149
150    let tables = tables_of(trimmed);
151
152    // 3. denied tables (highest precedence)
153    for t in &tables {
154        if contract.denied_tables.iter().any(|d| d.eq_ignore_ascii_case(t)) {
155            return Err(Violation {
156                violation: "table_forbidden".into(),
157                detail: format!("table '{}' is denied to agent '{}'", t, contract.id),
158                offending: sql.to_string(),
159                suggested_rewrite: None,
160            });
161        }
162    }
163
164    // 4. allowed-tables allowlist
165    if let Some(ref allowed) = contract.allowed_tables {
166        for t in &tables {
167            if !allowed.iter().any(|a| a.eq_ignore_ascii_case(t)) {
168                return Err(Violation {
169                    violation: "table_not_allowed".into(),
170                    detail: format!("table '{}' not in this agent's allowed set {:?}", t, allowed),
171                    offending: sql.to_string(),
172                    suggested_rewrite: None,
173                });
174            }
175        }
176    }
177
178    // 5. required predicates per touched table
179    for rule in &contract.require_predicate_on {
180        if tables.iter().any(|t| t.eq_ignore_ascii_case(&rule.table)) && !mentions_predicate(trimmed, &rule.column) {
181            let rewrite = inject_predicate(trimmed, &rule.column);
182            return Err(Violation {
183                violation: "missing_predicate".into(),
184                detail: format!(
185                    "queries on '{}' must filter by '{}'",
186                    rule.table, rule.column
187                ),
188                offending: sql.to_string(),
189                suggested_rewrite: Some(rewrite),
190            });
191        }
192    }
193
194    // 6. require LIMIT on SELECT
195    if contract.require_limit && verb == "SELECT" && !contains_ci(trimmed, " LIMIT ") && !ends_with_limit(trimmed) {
196        let cap = contract.max_rows.unwrap_or(1000);
197        return Err(Violation {
198            violation: "missing_limit".into(),
199            detail: format!("SELECTs must be bounded; add LIMIT {} or fewer", cap),
200            offending: sql.to_string(),
201            suggested_rewrite: Some(format!("{} LIMIT {}", trimmed.trim_end_matches(';').trim_end(), cap)),
202        });
203    }
204
205    Ok(())
206}
207
/// Does the statement reference `column` in a WHERE-ish position? Heuristic:
/// there is a WHERE keyword and the column name appears somewhere after it.
///
/// Both the keyword and the column are matched case-insensitively and only on
/// identifier boundaries, so `WHERE` may be surrounded by any whitespace or
/// punctuation (`\nWHERE(`), and a longer identifier that merely contains the
/// column name (`not_tenant_id`) does not count. A qualified reference such
/// as `o.tenant_id` does. An empty `column` never matches.
fn mentions_predicate(sql: &str, column: &str) -> bool {
    fn is_ident(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    // Byte offset of the first occurrence of `word` that is not glued to
    // identifier characters on either side.
    fn find_word(hay: &str, word: &str) -> Option<usize> {
        hay.match_indices(word).map(|(at, _)| at).find(|&at| {
            let before = hay[..at].chars().next_back();
            let after = hay[at + word.len()..].chars().next();
            !before.map_or(false, is_ident) && !after.map_or(false, is_ident)
        })
    }

    // ASCII lower-casing keeps byte offsets aligned with the original text.
    let hay = sql.to_ascii_lowercase();
    let needle = column.to_ascii_lowercase();
    if needle.is_empty() {
        return false;
    }

    match find_word(&hay, "where") {
        Some(at) => find_word(&hay[at + "where".len()..], &needle).is_some(),
        None => false,
    }
}
219
/// Whitespace-insensitive LIMIT detection backing the `" LIMIT "` substring
/// check in `validate`.
///
/// Returns `true` when the statement contains a `LIMIT` keyword (any case)
/// as its own whitespace-delimited token followed by at least one more
/// token. Unlike a literal-space substring search this recognises a LIMIT
/// that starts a line or is followed by further clauses
/// (`...\nLIMIT 10 OFFSET 5`) as well as one that closes the statement.
/// A trailing `;` is ignored.
fn ends_with_limit(sql: &str) -> bool {
    let stmt = sql.trim_end().trim_end_matches(';');
    let mut tokens = stmt.split_whitespace().peekable();
    while let Some(token) = tokens.next() {
        if token.eq_ignore_ascii_case("LIMIT") && tokens.peek().is_some() {
            return true;
        }
    }
    false
}
228
229/// Best-effort: add `WHERE <col> = $1` (or ` AND <col> = $1` when a WHERE
230/// already exists) so the agent has a concrete statement to retry.
231fn inject_predicate(sql: &str, column: &str) -> String {
232    let trimmed = sql.trim().trim_end_matches(';').trim_end();
233    if starts_with_ci(trimmed, "SELECT") || starts_with_ci(trimmed, "UPDATE") || starts_with_ci(trimmed, "DELETE") {
234        if contains_ci(trimmed, " WHERE ") {
235            format!("{} AND {} = $1", trimmed, column)
236        } else {
237            // Insert before ORDER BY / GROUP BY / LIMIT if present, else append.
238            let up = trimmed.to_ascii_uppercase();
239            let cut = ["ORDER BY", "GROUP BY", "LIMIT", "HAVING"]
240                .iter()
241                .filter_map(|kw| up.find(kw))
242                .min();
243            match cut {
244                Some(pos) => format!("{} WHERE {} = $1 {}", trimmed[..pos].trim_end(), column, &trimmed[pos..]),
245                None => format!("{} WHERE {} = $1", trimmed, column),
246            }
247        }
248    } else {
249        format!("{} /* add filter: {} = $1 */", trimmed, column)
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-only "analyst" fixture: `users` and `orders` are allowed,
    /// `secrets` is denied, `orders` needs a `tenant_id` filter, and SELECTs
    /// must carry a LIMIT (cap 1000).
    fn contract() -> AgentContract {
        let tenant_rule = PredicateRule {
            table: "orders".into(),
            column: "tenant_id".into(),
        };
        AgentContract {
            id: "analyst".into(),
            read_only: true,
            allowed_verbs: None,
            allowed_tables: Some(vec!["users".into(), "orders".into()]),
            denied_tables: vec!["secrets".into()],
            require_predicate_on: vec![tenant_rule],
            require_limit: true,
            max_rows: Some(1000),
        }
    }

    /// Validate `sql` against the fixture and return the violation it must
    /// produce.
    fn rejection(sql: &str) -> Violation {
        validate(sql, &contract()).expect_err("query should violate the contract")
    }

    #[test]
    fn allows_compliant_query() {
        let outcome = validate("SELECT id FROM users WHERE id = 1 LIMIT 10", &contract());
        assert!(outcome.is_ok());
    }

    #[test]
    fn blocks_write_when_read_only() {
        assert_eq!(rejection("DELETE FROM users").violation, "write_forbidden");
    }

    #[test]
    fn blocks_denied_table() {
        assert_eq!(
            rejection("SELECT * FROM secrets LIMIT 1").violation,
            "table_forbidden"
        );
    }

    #[test]
    fn blocks_table_not_in_allowlist() {
        assert_eq!(
            rejection("SELECT * FROM invoices LIMIT 1").violation,
            "table_not_allowed"
        );
    }

    #[test]
    fn requires_predicate_with_rewrite() {
        let hint = rejection("SELECT * FROM orders LIMIT 5");
        assert_eq!(hint.violation, "missing_predicate");
        let rewrite = hint.suggested_rewrite.unwrap();
        assert!(rewrite.to_lowercase().contains("tenant_id"));
    }

    #[test]
    fn requires_limit_with_rewrite() {
        let hint = rejection("SELECT id FROM users WHERE id = 1");
        assert_eq!(hint.violation, "missing_limit");
        let rewrite = hint.suggested_rewrite.unwrap();
        assert!(rewrite.to_uppercase().contains("LIMIT 1000"));
    }

    #[test]
    fn table_extraction_handles_schema_and_joins() {
        let found = tables_of("SELECT * FROM public.users u JOIN orders o ON o.uid = u.id");
        for expected in ["users", "orders"] {
            assert!(found.contains(&expected.to_string()));
        }
    }
}
315}