Skip to main content

heliosdb_proxy/
agent_contract.rs

1//! Per-agent SQL contracts + a contract validator with machine-readable
2//! repair hints.
3//!
4//! An agent contract is a scoped grant: which SQL verbs and tables an agent
5//! may touch, which predicates a query must carry (e.g. a tenant filter), and
6//! whether reads must be bounded by a LIMIT. Queries are validated against
7//! the contract BEFORE execution; a violation is returned as a structured
8//! [`Violation`] — a violation class, the offending fragment, and a suggested
9//! rewrite — so an LLM agent can read it and self-correct in one round trip
10//! instead of flailing against an opaque error.
11//!
12//! Validation is intentionally a lightweight static inspection (verb +
13//! table + predicate + LIMIT detection), the same altitude as a pg_hba /
14//! pgcat-style guard; it is a policy gate, not a full SQL parser.
15
16use once_cell::sync::Lazy;
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19
20use crate::protocol::{contains_ci, starts_with_ci};
21
22/// A predicate an agent's queries must carry when they touch `table`.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PredicateRule {
25    pub table: String,
26    /// Column that must appear in the query's WHERE clause.
27    pub column: String,
28}
29
30/// A scoped grant for one agent identity.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct AgentContract {
33    /// Identifier matched against the connecting agent.
34    pub id: String,
35    /// Reject write/DDL statements.
36    #[serde(default = "default_true")]
37    pub read_only: bool,
38    /// If set, only these SQL verbs are allowed (upper-case, e.g. "SELECT").
39    #[serde(default)]
40    pub allowed_verbs: Option<Vec<String>>,
41    /// If set, only these tables may be referenced.
42    #[serde(default)]
43    pub allowed_tables: Option<Vec<String>>,
44    /// Tables that may never be referenced (takes precedence over allow).
45    #[serde(default)]
46    pub denied_tables: Vec<String>,
47    /// Predicates that must be present when the named table is touched.
48    #[serde(default)]
49    pub require_predicate_on: Vec<PredicateRule>,
50    /// Require a LIMIT on SELECTs.
51    #[serde(default)]
52    pub require_limit: bool,
53    /// Suggested/enforced row cap (used in repair hints and to back
54    /// `require_limit`).
55    #[serde(default)]
56    pub max_rows: Option<u64>,
57}
58
/// Serde default for `AgentContract::read_only`: contracts are read-only
/// unless they explicitly opt out.
const fn default_true() -> bool {
    true
}
62
63/// A contract violation, serialized to the agent as a machine-readable hint.
64#[derive(Debug, Clone, Serialize)]
65pub struct Violation {
66    /// Stable class, e.g. "write_forbidden", "table_forbidden",
67    /// "missing_predicate", "missing_limit", "verb_forbidden".
68    pub violation: String,
69    /// Human/agent-readable explanation.
70    pub detail: String,
71    /// The offending input (the SQL).
72    pub offending: String,
73    /// A concrete corrected statement the agent can retry, when one can be
74    /// synthesised.
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub suggested_rewrite: Option<String>,
77}
78
79impl Violation {
80    pub fn to_json(&self) -> String {
81        serde_json::to_string(self).unwrap_or_else(|_| self.detail.clone())
82    }
83}
84
85static TABLE_RE: Lazy<Regex> = Lazy::new(|| {
86    // FROM/JOIN/INTO/UPDATE <table> — captures schema-qualified identifiers.
87    Regex::new(
88        r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)",
89    )
90    .expect("valid table regex")
91});
92
/// Extract the leading SQL verb (upper-case), e.g. "SELECT".
///
/// Leading whitespace, opening parentheses, `-- line` comments and
/// (nestable) `/* block */` comments are skipped first, and the verb ends at
/// the first character that is not an ASCII letter. This keeps inputs such as
/// `DELETE/**/FROM t`, `/* hint */ DROP TABLE t`, `(SELECT 1)` or
/// `SELECT*FROM t` from producing a token that no verb check recognises.
///
/// Returns an empty string when no verb is found (empty input, unterminated
/// comment, or a statement that does not start with a letter).
fn verb_of(sql: &str) -> String {
    /// Given the text right after an opening `/*`, return what follows the
    /// matching `*/`, honouring nesting. Unterminated comments yield "".
    fn after_block_comment(body: &str) -> &str {
        let bytes = body.as_bytes();
        let mut depth = 1usize;
        let mut i = 0;
        while i + 1 < bytes.len() {
            match (bytes[i], bytes[i + 1]) {
                (b'/', b'*') => {
                    depth += 1;
                    i += 2;
                }
                (b'*', b'/') => {
                    depth -= 1;
                    i += 2;
                    if depth == 0 {
                        // `i` sits right after an ASCII '/', so it is a
                        // valid char boundary.
                        return &body[i..];
                    }
                }
                _ => i += 1,
            }
        }
        ""
    }

    let mut rest = sql;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        if let Some(line) = rest.strip_prefix("--") {
            rest = match line.find('\n') {
                Some(newline) => &line[newline + 1..],
                None => "",
            };
        } else if let Some(body) = rest.strip_prefix("/*") {
            rest = after_block_comment(body);
        } else {
            break;
        }
    }

    rest.chars()
        .take_while(|c| c.is_ascii_alphabetic())
        .map(|c| c.to_ascii_uppercase())
        .collect()
}
101
102/// Extract referenced table names (lower-cased, bare name without schema).
103fn tables_of(sql: &str) -> Vec<String> {
104    TABLE_RE
105        .captures_iter(sql)
106        .filter_map(|c| c.get(1))
107        .map(|m| {
108            let full = m.as_str().to_ascii_lowercase();
109            // bare table name (strip schema prefix) for matching
110            full.rsplit('.').next().unwrap_or(&full).to_string()
111        })
112        .collect()
113}
114
/// Is `verb` a statement verb that writes data, changes schema/privileges or
/// otherwise has side effects a read-only agent must not trigger?
///
/// `verb` is expected in upper case; the comparison is exact.
///
/// Besides the common DML/DDL verbs this also covers `REFRESH`
/// (materialized views), `REASSIGN` (owned objects), `IMPORT` (foreign
/// schema) and `SECURITY` (label), which are PostgreSQL write/DDL commands
/// as well.
fn is_write_verb(verb: &str) -> bool {
    matches!(
        verb,
        "INSERT"
            | "UPDATE"
            | "DELETE"
            | "CREATE"
            | "DROP"
            | "ALTER"
            | "TRUNCATE"
            | "GRANT"
            | "REVOKE"
            | "MERGE"
            | "CALL"
            | "DO"
            | "COPY"
            | "VACUUM"
            | "REINDEX"
            | "CLUSTER"
            | "LOCK"
            | "COMMENT"
            | "REFRESH"
            | "REASSIGN"
            | "IMPORT"
            | "SECURITY"
    )
}
138
139/// Validate `sql` against `contract`. `Ok(())` admits the query; `Err`
140/// carries a structured repair hint.
141pub fn validate(sql: &str, contract: &AgentContract) -> Result<(), Violation> {
142    let trimmed = sql.trim();
143    let verb = verb_of(trimmed);
144
145    // 1. read-only
146    if contract.read_only && is_write_verb(&verb) {
147        return Err(Violation {
148            violation: "write_forbidden".into(),
149            detail: format!(
150                "agent '{}' is read-only; '{}' statements are not permitted",
151                contract.id, verb
152            ),
153            offending: sql.to_string(),
154            suggested_rewrite: None,
155        });
156    }
157
158    // 2. allowed verbs
159    if let Some(ref verbs) = contract.allowed_verbs {
160        if !verbs.iter().any(|v| v.eq_ignore_ascii_case(&verb)) {
161            return Err(Violation {
162                violation: "verb_forbidden".into(),
163                detail: format!(
164                    "verb '{}' not in this agent's allowed set {:?}",
165                    verb, verbs
166                ),
167                offending: sql.to_string(),
168                suggested_rewrite: None,
169            });
170        }
171    }
172
173    let tables = tables_of(trimmed);
174
175    // 3. denied tables (highest precedence)
176    for t in &tables {
177        if contract
178            .denied_tables
179            .iter()
180            .any(|d| d.eq_ignore_ascii_case(t))
181        {
182            return Err(Violation {
183                violation: "table_forbidden".into(),
184                detail: format!("table '{}' is denied to agent '{}'", t, contract.id),
185                offending: sql.to_string(),
186                suggested_rewrite: None,
187            });
188        }
189    }
190
191    // 4. allowed-tables allowlist
192    if let Some(ref allowed) = contract.allowed_tables {
193        for t in &tables {
194            if !allowed.iter().any(|a| a.eq_ignore_ascii_case(t)) {
195                return Err(Violation {
196                    violation: "table_not_allowed".into(),
197                    detail: format!(
198                        "table '{}' not in this agent's allowed set {:?}",
199                        t, allowed
200                    ),
201                    offending: sql.to_string(),
202                    suggested_rewrite: None,
203                });
204            }
205        }
206    }
207
208    // 5. required predicates per touched table
209    for rule in &contract.require_predicate_on {
210        if tables.iter().any(|t| t.eq_ignore_ascii_case(&rule.table))
211            && !mentions_predicate(trimmed, &rule.column)
212        {
213            let rewrite = inject_predicate(trimmed, &rule.column);
214            return Err(Violation {
215                violation: "missing_predicate".into(),
216                detail: format!(
217                    "queries on '{}' must filter by '{}'",
218                    rule.table, rule.column
219                ),
220                offending: sql.to_string(),
221                suggested_rewrite: Some(rewrite),
222            });
223        }
224    }
225
226    // 6. require LIMIT on SELECT
227    if contract.require_limit
228        && verb == "SELECT"
229        && !contains_ci(trimmed, " LIMIT ")
230        && !ends_with_limit(trimmed)
231    {
232        let cap = contract.max_rows.unwrap_or(1000);
233        return Err(Violation {
234            violation: "missing_limit".into(),
235            detail: format!("SELECTs must be bounded; add LIMIT {} or fewer", cap),
236            offending: sql.to_string(),
237            suggested_rewrite: Some(format!(
238                "{} LIMIT {}",
239                trimmed.trim_end_matches(';').trim_end(),
240                cap
241            )),
242        });
243    }
244
245    Ok(())
246}
247
/// Does the statement reference `column` in a WHERE-ish position? Heuristic:
/// there is a WHERE keyword and the column name appears somewhere after it.
///
/// Both the keyword and the column are matched case-insensitively and as
/// whole words (neighbouring characters must not be `[A-Za-z0-9_]`), so:
/// * `WHERE` is found regardless of the surrounding whitespace (newlines,
///   tabs), which matters for multi-line SQL;
/// * `tenant_id_old` does not satisfy a rule for `tenant_id`, while
///   qualified (`o.tenant_id`) or quoted (`"tenant_id"`) references do.
///
/// An empty `column` never matches.
fn mentions_predicate(sql: &str, column: &str) -> bool {
    fn is_ident(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    /// Offset just past the first case-insensitive, whole-word occurrence
    /// of `needle` in `hay`.
    fn word_end(hay: &[u8], needle: &[u8]) -> Option<usize> {
        if needle.is_empty() || needle.len() > hay.len() {
            return None;
        }
        (0..=hay.len() - needle.len())
            .find(|&at| {
                let end = at + needle.len();
                hay[at..end].eq_ignore_ascii_case(needle)
                    && (at == 0 || !is_ident(hay[at - 1]))
                    && hay.get(end).map_or(true, |&next| !is_ident(next))
            })
            .map(|at| at + needle.len())
    }

    let bytes = sql.as_bytes();
    match word_end(bytes, b"WHERE") {
        Some(after_where) => word_end(&bytes[after_where..], column.as_bytes()).is_some(),
        None => false,
    }
}
259
/// Does the statement end with a bounding LIMIT clause?
///
/// Recognises `... LIMIT <n>` and `... LIMIT <n> OFFSET <m>` at the tail,
/// ignoring trailing semicolons and regardless of the whitespace (spaces,
/// tabs, newlines) around the keywords. `LIMIT ALL` is not counted, since it
/// does not bound the result.
fn ends_with_limit(sql: &str) -> bool {
    let body = sql.trim_end_matches(';').trim_end();
    let words: Vec<&str> = body.split_whitespace().collect();

    // Drop a trailing `OFFSET <m>` so the LIMIT in front of it is inspected.
    let clause = match words.as_slice() {
        [head @ .., offset_kw, _] if offset_kw.eq_ignore_ascii_case("OFFSET") => head,
        all => all,
    };

    match clause {
        [.., limit_kw, count] => {
            limit_kw.eq_ignore_ascii_case("LIMIT") && !count.eq_ignore_ascii_case("ALL")
        }
        _ => false,
    }
}
268
/// Best-effort: build a retryable statement that filters on `column`.
///
/// For SELECT/UPDATE/DELETE statements the filter `<col> = $1` is placed in
/// the top-level WHERE clause:
/// * without an existing WHERE, `WHERE <col> = $1` is inserted before the
///   first trailing clause (GROUP BY, HAVING, WINDOW, ORDER BY, LIMIT,
///   OFFSET, FETCH, FOR, RETURNING) or appended;
/// * with an existing WHERE, the current condition is parenthesised and
///   `AND <col> = $1` is added right after it (still before any trailing
///   clause), so an `OR` in the original condition cannot escape the filter
///   and the result stays valid SQL.
///
/// Keywords are recognised as whole words at parenthesis depth 0 and outside
/// quotes, so identifiers such as `rate_limit`, string literals and
/// subqueries do not confuse the placement. Any other statement kind gets
/// the filter as a trailing comment. A trailing `;` is dropped.
fn inject_predicate(sql: &str, column: &str) -> String {
    /// Clauses that must stay after the WHERE clause.
    const TAIL_KEYWORDS: [&str; 9] = [
        "GROUP",
        "HAVING",
        "WINDOW",
        "ORDER",
        "LIMIT",
        "OFFSET",
        "FETCH",
        "FOR",
        "RETURNING",
    ];

    /// `(byte offset, word)` for every bare word at parenthesis depth 0 and
    /// outside single/double quotes.
    fn top_level_words(s: &str) -> Vec<(usize, &str)> {
        let bytes = s.as_bytes();
        let mut words = Vec::new();
        let mut depth = 0usize;
        let mut i = 0;
        while i < bytes.len() {
            let b = bytes[i];
            if b == b'\'' || b == b'"' {
                i += 1;
                while i < bytes.len() && bytes[i] != b {
                    i += 1;
                }
                i += 1;
            } else if b == b'(' {
                depth += 1;
                i += 1;
            } else if b == b')' {
                depth = depth.saturating_sub(1);
                i += 1;
            } else if b.is_ascii_alphabetic() || b == b'_' {
                let start = i;
                while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                    i += 1;
                }
                if depth == 0 {
                    // Both ends sit next to ASCII bytes: valid boundaries.
                    words.push((start, &s[start..i]));
                }
            } else {
                i += 1;
            }
        }
        words
    }

    let stmt = sql.trim().trim_end_matches(';').trim_end();
    let words = top_level_words(stmt);

    let filterable = match words.first() {
        Some(&(0, verb)) => ["SELECT", "UPDATE", "DELETE"]
            .iter()
            .any(|v| verb.eq_ignore_ascii_case(v)),
        _ => false,
    };
    if !filterable {
        return format!("{} /* add filter: {} = $1 */", stmt, column);
    }

    let where_at = words
        .iter()
        .find(|(_, w)| w.eq_ignore_ascii_case("WHERE"))
        .map(|&(at, _)| at);
    // Trailing clauses are only looked for after the WHERE keyword, if any.
    let scan_from = where_at.map_or(0, |at| at + "WHERE".len());
    let tail_at = words
        .iter()
        .filter(|&&(at, _)| at >= scan_from)
        .find(|(_, w)| TAIL_KEYWORDS.iter().any(|kw| w.eq_ignore_ascii_case(kw)))
        .map_or(stmt.len(), |&(at, _)| at);

    let (head, tail) = stmt.split_at(tail_at);
    let head = head.trim_end();

    let existing = where_at.map(|at| {
        (
            head[..at].trim_end(),
            head[at + "WHERE".len()..].trim(),
        )
    });
    let mut rewritten = match existing {
        Some((before, cond)) if !cond.is_empty() => {
            format!("{} WHERE ({}) AND {} = $1", before, cond, column)
        }
        // A dangling WHERE with no condition: just supply the filter.
        Some((before, _)) => format!("{} WHERE {} = $1", before, column),
        None => format!("{} WHERE {} = $1", head, column),
    };
    if !tail.is_empty() {
        rewritten.push(' ');
        rewritten.push_str(tail);
    }
    rewritten
}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-only "analyst" contract used by every test: `users` and `orders`
    /// are allowed, `secrets` is denied, `orders` needs a `tenant_id` filter
    /// and SELECTs need a LIMIT (cap 1000).
    fn contract() -> AgentContract {
        let tenant_rule = PredicateRule {
            table: "orders".into(),
            column: "tenant_id".into(),
        };
        AgentContract {
            id: "analyst".into(),
            read_only: true,
            allowed_verbs: None,
            allowed_tables: Some(vec!["users".into(), "orders".into()]),
            denied_tables: vec!["secrets".into()],
            require_predicate_on: vec![tenant_rule],
            require_limit: true,
            max_rows: Some(1000),
        }
    }

    /// Validate `sql` against the shared contract and return the violation,
    /// panicking if the query was admitted.
    fn rejected(sql: &str) -> Violation {
        validate(sql, &contract()).unwrap_err()
    }

    #[test]
    fn allows_compliant_query() {
        let outcome = validate("SELECT id FROM users WHERE id = 1 LIMIT 10", &contract());
        assert!(outcome.is_ok());
    }

    #[test]
    fn blocks_write_when_read_only() {
        assert_eq!(rejected("DELETE FROM users").violation, "write_forbidden");
    }

    #[test]
    fn blocks_denied_table() {
        assert_eq!(
            rejected("SELECT * FROM secrets LIMIT 1").violation,
            "table_forbidden"
        );
    }

    #[test]
    fn blocks_table_not_in_allowlist() {
        assert_eq!(
            rejected("SELECT * FROM invoices LIMIT 1").violation,
            "table_not_allowed"
        );
    }

    #[test]
    fn requires_predicate_with_rewrite() {
        let violation = rejected("SELECT * FROM orders LIMIT 5");
        assert_eq!(violation.violation, "missing_predicate");
        let rewrite = violation.suggested_rewrite.unwrap().to_lowercase();
        assert!(rewrite.contains("tenant_id"));
    }

    #[test]
    fn requires_limit_with_rewrite() {
        let violation = rejected("SELECT id FROM users WHERE id = 1");
        assert_eq!(violation.violation, "missing_limit");
        let rewrite = violation.suggested_rewrite.unwrap().to_uppercase();
        assert!(rewrite.contains("LIMIT 1000"));
    }

    #[test]
    fn table_extraction_handles_schema_and_joins() {
        let found = tables_of("SELECT * FROM public.users u JOIN orders o ON o.uid = u.id");
        for expected in ["users", "orders"] {
            assert!(found.contains(&expected.to_string()));
        }
    }
}