Skip to main content

chio_guards/
memory_governance.rs

1//! MemoryGovernanceGuard -- enforce memory store allowlist, retention
2//! TTL ceilings, and per-session memory-entry counts on
3//! [`ToolAction::MemoryWrite`] and [`ToolAction::MemoryRead`] actions.
4//!
5//! Roadmap phase 18.1 (see `docs/protocols/STRUCTURAL-SECURITY-FIXES.md`
6//! section 3).  The guard sources its policy from two places:
7//!
8//! 1. **Capability constraints** on the matched grant
9//!    ([`Constraint::MemoryStoreAllowlist`]): when present, writes and
10//!    reads targeting a store outside the allowlist are denied.
11//! 2. **Guard configuration** ([`MemoryGovernanceConfig`]): provides
12//!    deployment-wide defaults for `max_memory_entries`,
13//!    `max_retention_ttl_secs`, and per-store overrides.  Operators can
14//!    use these even when the current capability grammar does not
15//!    surface the equivalent constraints (see ADR-TYPE-EVOLUTION for
16//!    future expansion to first-class constraints).
17//!
18//! The guard keeps an in-memory per-session counter of memory writes so
19//! it can enforce [`MemoryGovernanceConfig::max_memory_entries`]
20//! deterministically without touching shared kernel state.
21//!
22//! # Fail-closed semantics
23//!
24//! - memory writes without a parseable store key are denied when the
25//!   matched grant carries a non-empty `MemoryStoreAllowlist`;
26//! - malformed deny-pattern regex input causes
27//!   [`MemoryGovernanceGuard::with_config`] to return
28//!   [`MemoryGovernanceError::InvalidPattern`];
29//! - writes with an explicit retention TTL above `max_retention_ttl_secs`
30//!   are denied;
//! - writes that would push a session's accepted-write count past
//!   `max_memory_entries` are denied (fail-closed on counter mutex
//!   poisoning).
33
34use std::collections::HashMap;
35use std::sync::Mutex;
36
37use regex::Regex;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40
41use chio_core::capability::Constraint;
42use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
43
44use crate::action::{extract_action, ToolAction};
45
46/// Errors produced when building a [`MemoryGovernanceGuard`].
47#[derive(Debug, thiserror::Error)]
48pub enum MemoryGovernanceError {
49    /// A `deny_patterns` entry was not a valid regex.
50    #[error("invalid deny pattern `{pattern}`: {source}")]
51    InvalidPattern {
52        pattern: String,
53        #[source]
54        source: regex::Error,
55    },
56}
57
58/// Configuration for [`MemoryGovernanceGuard`].
59#[derive(Clone, Debug, Deserialize, Serialize)]
60#[serde(deny_unknown_fields)]
61pub struct MemoryGovernanceConfig {
62    /// Enable/disable the guard entirely.
63    #[serde(default = "default_true")]
64    pub enabled: bool,
65    /// Hard-coded store allowlist applied on top of the capability-level
66    /// [`Constraint::MemoryStoreAllowlist`].  Empty means "no additional
67    /// allowlist" (capability-level list still applies).
68    #[serde(default)]
69    pub store_allowlist: Vec<String>,
70    /// Maximum memory-entry count per agent + session combination.  When
71    /// `Some(n)`, the `n`-th write is denied.
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    pub max_memory_entries: Option<u64>,
74    /// Maximum retention TTL (seconds) allowed on a single write.  When
75    /// `Some(ttl)`, writes requesting a larger TTL -- or indefinite
76    /// retention (missing TTL) -- are denied.
77    #[serde(default, skip_serializing_if = "Option::is_none")]
78    pub max_retention_ttl_secs: Option<u64>,
79    /// Maximum content size (bytes) for a single memory write.  `None`
80    /// disables the check.
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub max_content_size_bytes: Option<u64>,
83    /// Extra regex patterns that deny a write when the content matches.
84    #[serde(default)]
85    pub deny_patterns: Vec<String>,
86}
87
/// Serde default helper: yields `true`, so a config that omits
/// `enabled` leaves the guard switched on.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
91
92impl Default for MemoryGovernanceConfig {
93    fn default() -> Self {
94        Self {
95            enabled: true,
96            store_allowlist: Vec::new(),
97            max_memory_entries: None,
98            max_retention_ttl_secs: None,
99            max_content_size_bytes: None,
100            deny_patterns: Vec::new(),
101        }
102    }
103}
104
/// Key under which per-session write counts are stored: the pair
/// `(agent_id, capability_id)`.
type SessionKey = (String, String);
107
108/// Guard implementing memory governance (phase 18.1).
109pub struct MemoryGovernanceGuard {
110    enabled: bool,
111    store_allowlist: Vec<String>,
112    max_memory_entries: Option<u64>,
113    max_retention_ttl_secs: Option<u64>,
114    max_content_size_bytes: Option<u64>,
115    deny_patterns: Vec<Regex>,
116    counters: Mutex<HashMap<SessionKey, u64>>,
117}
118
119impl MemoryGovernanceGuard {
120    /// Build a guard with default configuration (no limits).  Non-guard
121    /// code paths remain fully permissive until a capability constraint
122    /// or config field is supplied.
123    pub fn new() -> Self {
124        Self::with_config(MemoryGovernanceConfig::default()).unwrap_or_else(|_| Self {
125            enabled: true,
126            store_allowlist: Vec::new(),
127            max_memory_entries: None,
128            max_retention_ttl_secs: None,
129            max_content_size_bytes: None,
130            deny_patterns: Vec::new(),
131            counters: Mutex::new(HashMap::new()),
132        })
133    }
134
135    /// Build a guard with explicit configuration.
136    pub fn with_config(config: MemoryGovernanceConfig) -> Result<Self, MemoryGovernanceError> {
137        let mut deny_patterns = Vec::with_capacity(config.deny_patterns.len());
138        for pat in &config.deny_patterns {
139            let re = Regex::new(pat).map_err(|e| MemoryGovernanceError::InvalidPattern {
140                pattern: pat.clone(),
141                source: e,
142            })?;
143            deny_patterns.push(re);
144        }
145        Ok(Self {
146            enabled: config.enabled,
147            store_allowlist: config.store_allowlist,
148            max_memory_entries: config.max_memory_entries,
149            max_retention_ttl_secs: config.max_retention_ttl_secs,
150            max_content_size_bytes: config.max_content_size_bytes,
151            deny_patterns,
152            counters: Mutex::new(HashMap::new()),
153        })
154    }
155
156    /// Current counter value for a session (test / observability helper).
157    pub fn session_count(&self, agent_id: &str, capability_id: &str) -> u64 {
158        self.counters
159            .lock()
160            .ok()
161            .and_then(|g| {
162                g.get(&(agent_id.to_string(), capability_id.to_string()))
163                    .copied()
164            })
165            .unwrap_or(0)
166    }
167
168    /// Gather the effective store allowlist from the matched grant plus
169    /// the guard-level config.  Returns `None` if neither source supplies
170    /// a non-empty allowlist.
171    fn effective_store_allowlist<'a>(&'a self, ctx: &'a GuardContext<'a>) -> Option<Vec<String>> {
172        let mut combined: Vec<String> = self.store_allowlist.clone();
173        if let Some(grant) = ctx
174            .matched_grant_index
175            .and_then(|i| ctx.scope.grants.get(i))
176        {
177            for c in &grant.constraints {
178                if let Constraint::MemoryStoreAllowlist(list) = c {
179                    combined.extend(list.iter().cloned());
180                }
181            }
182        }
183        if combined.is_empty() {
184            None
185        } else {
186            Some(combined)
187        }
188    }
189
190    /// Increment the per-session write counter and return the new value.
191    /// Fails closed (treats poisoning as "over limit") on mutex poisoning.
192    fn bump_counter(&self, key: SessionKey) -> Result<u64, KernelError> {
193        let mut guard = self.counters.lock().map_err(|_| {
194            KernelError::Internal("memory-governance guard counter mutex poisoned".to_string())
195        })?;
196        let entry = guard.entry(key).or_insert(0);
197        *entry = entry.saturating_add(1);
198        Ok(*entry)
199    }
200}
201
202impl Default for MemoryGovernanceGuard {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl Guard for MemoryGovernanceGuard {
209    fn name(&self) -> &str {
210        "memory-governance"
211    }
212
213    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
214        if !self.enabled {
215            return Ok(Verdict::Allow);
216        }
217
218        let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
219
220        match action {
221            ToolAction::MemoryWrite { store, .. } => self.evaluate_write(ctx, &store),
222            ToolAction::MemoryRead { store, .. } => self.evaluate_read(ctx, &store),
223            _ => Ok(Verdict::Allow),
224        }
225    }
226}
227
228impl MemoryGovernanceGuard {
229    fn evaluate_write(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
230        // 1. Store allowlist (capability + guard config).
231        if let Some(allow) = self.effective_store_allowlist(ctx) {
232            if !allow.iter().any(|s| store_matches(s, store)) {
233                return Ok(Verdict::Deny);
234            }
235        }
236
237        // 2. Retention TTL ceiling.
238        if let Some(max_ttl) = self.max_retention_ttl_secs {
239            let requested = extract_retention_ttl(&ctx.request.arguments);
240            match requested {
241                None => {
242                    // Missing TTL with a configured ceiling is treated
243                    // as a request for indefinite retention and denied.
244                    return Ok(Verdict::Deny);
245                }
246                Some(ttl) if ttl > max_ttl => {
247                    return Ok(Verdict::Deny);
248                }
249                Some(_) => {}
250            }
251        }
252
253        // 3. Content size.
254        if let Some(max_bytes) = self.max_content_size_bytes {
255            if let Some(size) = extract_content_size_bytes(&ctx.request.arguments) {
256                if size > max_bytes {
257                    return Ok(Verdict::Deny);
258                }
259            }
260        }
261
262        // 4. Deny patterns on content.
263        if !self.deny_patterns.is_empty() {
264            if let Some(content) = extract_content_text(&ctx.request.arguments) {
265                for re in &self.deny_patterns {
266                    if re.is_match(&content) {
267                        return Ok(Verdict::Deny);
268                    }
269                }
270            }
271        }
272
273        // 5. Per-session entry limit.  We bump the counter only after
274        //    the previous gates pass; denials do not consume quota.
275        if let Some(max_entries) = self.max_memory_entries {
276            let key = (ctx.agent_id.to_string(), ctx.request.capability.id.clone());
277            let count = self.bump_counter(key)?;
278            if count > max_entries {
279                return Ok(Verdict::Deny);
280            }
281        }
282
283        Ok(Verdict::Allow)
284    }
285
286    fn evaluate_read(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
287        // Reads respect the store allowlist so an agent cannot read from
288        // a forbidden store even when the write path is blocked.
289        if let Some(allow) = self.effective_store_allowlist(ctx) {
290            if !allow.iter().any(|s| store_matches(s, store)) {
291                return Ok(Verdict::Deny);
292            }
293        }
294        Ok(Verdict::Allow)
295    }
296}
297
/// Allowlist entry match.  An entry ending in `*` matches every store
/// that starts with the text before the `*` (so a bare `*` matches
/// everything); any other entry must equal the store name exactly.
fn store_matches(pattern: &str, store: &str) -> bool {
    match pattern.strip_suffix('*') {
        // A bare `*` leaves an empty prefix, which every store starts with.
        Some(prefix) => store.starts_with(prefix),
        None => pattern == store,
    }
}
308
309/// Read an explicit retention TTL (seconds) from the arguments.
310fn extract_retention_ttl(arguments: &Value) -> Option<u64> {
311    for key in [
312        "retention_ttl",
313        "retentionTtl",
314        "retention_ttl_secs",
315        "retentionTtlSecs",
316        "ttl",
317        "ttl_secs",
318        "expires_in",
319        "expiresIn",
320    ] {
321        if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
322            return Some(v);
323        }
324    }
325    None
326}
327
328/// Read an explicit content byte size from the arguments, falling back
329/// to the length of the `content` / `text` string when present.
330fn extract_content_size_bytes(arguments: &Value) -> Option<u64> {
331    for key in ["content_size", "contentSize", "content_bytes", "size"] {
332        if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
333            return Some(v);
334        }
335    }
336    extract_content_text(arguments).map(|s| s.len() as u64)
337}
338
339/// Extract the text body of a memory write for regex / size checks.
340fn extract_content_text(arguments: &Value) -> Option<String> {
341    for key in ["content", "text", "value", "vector_text", "payload"] {
342        if let Some(v) = arguments.get(key).and_then(|v| v.as_str()) {
343            if !v.is_empty() {
344                return Some(v.to_string());
345            }
346        }
347    }
348    None
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare `*`, trailing-`*` prefix, and exact entries behave as documented.
    #[test]
    fn store_matches_wildcards() {
        let cases = [
            ("*", "anything", true),
            ("agent-*", "agent-notes", true),
            ("agent-*", "other", false),
            ("agent-notes", "agent-notes", true),
        ];
        for (pattern, store, expected) in cases.iter() {
            assert_eq!(
                store_matches(pattern, store),
                *expected,
                "pattern `{}` against store `{}`",
                pattern,
                store
            );
        }
    }

    /// Snake-case and camel-case TTL keys are both recognized; no key means `None`.
    #[test]
    fn extract_retention_ttl_reads_common_keys() {
        let cases = [
            (serde_json::json!({"ttl": 600}), Some(600)),
            (serde_json::json!({"retentionTtl": 120}), Some(120)),
            (serde_json::json!({}), None),
        ];
        for (arguments, expected) in cases.iter() {
            assert_eq!(extract_retention_ttl(arguments), *expected);
        }
    }

    /// With no declared size, the size is the byte length of the text content.
    #[test]
    fn content_size_falls_back_to_text_length() {
        let arguments = serde_json::json!({"content": "hello"});
        let size = extract_content_size_bytes(&arguments);
        assert_eq!(size, Some(5));
    }
}