1use std::collections::HashMap;
35use std::sync::Mutex;
36
37use regex::Regex;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40
41use chio_core::capability::Constraint;
42use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
43
44use crate::action::{extract_action, ToolAction};
45
/// Errors produced while building a [`MemoryGovernanceGuard`] from a
/// [`MemoryGovernanceConfig`].
#[derive(Debug, thiserror::Error)]
pub enum MemoryGovernanceError {
    /// One of the configured `deny_patterns` failed to compile as a regular
    /// expression.
    #[error("invalid deny pattern `{pattern}`: {source}")]
    InvalidPattern {
        /// The pattern text exactly as it appeared in the configuration.
        pattern: String,
        /// The compilation error reported by the `regex` crate.
        #[source]
        source: regex::Error,
    },
}
57
/// Operator-supplied settings for [`MemoryGovernanceGuard`].
///
/// Every field may be omitted when deserializing, and unknown keys are
/// rejected (`deny_unknown_fields`). `None` / empty values disable the
/// corresponding check.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct MemoryGovernanceConfig {
    /// Master switch. When `false` the guard allows every request without
    /// inspecting it. Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Store-name patterns permitted for memory reads and writes. A pattern
    /// ending in `*` matches by prefix (so `"*"` matches any store). Any other
    /// pattern must match the store name exactly. Entries are merged with
    /// `MemoryStoreAllowlist` constraints on the matched grant. When the merged
    /// list is empty, no store restriction applies.
    #[serde(default)]
    pub store_allowlist: Vec<String>,
    /// Maximum number of memory writes per (agent id, capability id) pair.
    /// Writes are counted in memory by the guard instance.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_memory_entries: Option<u64>,
    /// Upper bound compared against the retention TTL found in a write's
    /// arguments (the name indicates seconds; the request-side unit is assumed
    /// to match — TODO confirm). When set, a write with no recognisable TTL
    /// argument is also denied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_retention_ttl_secs: Option<u64>,
    /// Upper bound on a write's content size in bytes. The size is derived from
    /// the request arguments (a declared size field and/or the length of the
    /// content text). Writes whose size cannot be determined are not limited.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_content_size_bytes: Option<u64>,
    /// Regular expressions matched against a write's content text. Any match
    /// denies the write. Patterns are compiled when the guard is built, and an
    /// invalid one makes `MemoryGovernanceGuard::with_config` fail with
    /// `MemoryGovernanceError::InvalidPattern`.
    #[serde(default)]
    pub deny_patterns: Vec<String>,
}
87
88fn default_true() -> bool {
89 true
90}
91
92impl Default for MemoryGovernanceConfig {
93 fn default() -> Self {
94 Self {
95 enabled: true,
96 store_allowlist: Vec::new(),
97 max_memory_entries: None,
98 max_retention_ttl_secs: None,
99 max_content_size_bytes: None,
100 deny_patterns: Vec::new(),
101 }
102 }
103}
104
/// Key for the per-session write counter: `(agent id, capability id)`.
type SessionKey = (String, String);

/// Guard that enforces a [`MemoryGovernanceConfig`] on memory read and write
/// tool calls.
///
/// Built via `new`, `with_config`, or `Default`. Holds the compiled deny
/// patterns and an in-memory write counter per [`SessionKey`].
pub struct MemoryGovernanceGuard {
    /// Copied from `MemoryGovernanceConfig::enabled`. `false` allows everything.
    enabled: bool,
    /// Operator-configured store patterns (merged with grant constraints at
    /// evaluation time).
    store_allowlist: Vec<String>,
    /// Per-session write limit. `None` disables counting.
    max_memory_entries: Option<u64>,
    /// Maximum requested retention TTL. `None` disables the TTL check.
    max_retention_ttl_secs: Option<u64>,
    /// Maximum write content size in bytes. `None` disables the size check.
    max_content_size_bytes: Option<u64>,
    /// Deny patterns, compiled once at construction.
    deny_patterns: Vec<Regex>,
    /// Write counts keyed by `(agent id, capability id)`. Guarded by a mutex
    /// because `Guard::evaluate` takes `&self`.
    counters: Mutex<HashMap<SessionKey, u64>>,
}
118
119impl MemoryGovernanceGuard {
120 pub fn new() -> Self {
124 Self::with_config(MemoryGovernanceConfig::default()).unwrap_or_else(|_| Self {
125 enabled: true,
126 store_allowlist: Vec::new(),
127 max_memory_entries: None,
128 max_retention_ttl_secs: None,
129 max_content_size_bytes: None,
130 deny_patterns: Vec::new(),
131 counters: Mutex::new(HashMap::new()),
132 })
133 }
134
135 pub fn with_config(config: MemoryGovernanceConfig) -> Result<Self, MemoryGovernanceError> {
137 let mut deny_patterns = Vec::with_capacity(config.deny_patterns.len());
138 for pat in &config.deny_patterns {
139 let re = Regex::new(pat).map_err(|e| MemoryGovernanceError::InvalidPattern {
140 pattern: pat.clone(),
141 source: e,
142 })?;
143 deny_patterns.push(re);
144 }
145 Ok(Self {
146 enabled: config.enabled,
147 store_allowlist: config.store_allowlist,
148 max_memory_entries: config.max_memory_entries,
149 max_retention_ttl_secs: config.max_retention_ttl_secs,
150 max_content_size_bytes: config.max_content_size_bytes,
151 deny_patterns,
152 counters: Mutex::new(HashMap::new()),
153 })
154 }
155
156 pub fn session_count(&self, agent_id: &str, capability_id: &str) -> u64 {
158 self.counters
159 .lock()
160 .ok()
161 .and_then(|g| {
162 g.get(&(agent_id.to_string(), capability_id.to_string()))
163 .copied()
164 })
165 .unwrap_or(0)
166 }
167
168 fn effective_store_allowlist<'a>(&'a self, ctx: &'a GuardContext<'a>) -> Option<Vec<String>> {
172 let mut combined: Vec<String> = self.store_allowlist.clone();
173 if let Some(grant) = ctx
174 .matched_grant_index
175 .and_then(|i| ctx.scope.grants.get(i))
176 {
177 for c in &grant.constraints {
178 if let Constraint::MemoryStoreAllowlist(list) = c {
179 combined.extend(list.iter().cloned());
180 }
181 }
182 }
183 if combined.is_empty() {
184 None
185 } else {
186 Some(combined)
187 }
188 }
189
190 fn bump_counter(&self, key: SessionKey) -> Result<u64, KernelError> {
193 let mut guard = self.counters.lock().map_err(|_| {
194 KernelError::Internal("memory-governance guard counter mutex poisoned".to_string())
195 })?;
196 let entry = guard.entry(key).or_insert(0);
197 *entry = entry.saturating_add(1);
198 Ok(*entry)
199 }
200}
201
202impl Default for MemoryGovernanceGuard {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl Guard for MemoryGovernanceGuard {
209 fn name(&self) -> &str {
210 "memory-governance"
211 }
212
213 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
214 if !self.enabled {
215 return Ok(Verdict::Allow);
216 }
217
218 let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
219
220 match action {
221 ToolAction::MemoryWrite { store, .. } => self.evaluate_write(ctx, &store),
222 ToolAction::MemoryRead { store, .. } => self.evaluate_read(ctx, &store),
223 _ => Ok(Verdict::Allow),
224 }
225 }
226}
227
228impl MemoryGovernanceGuard {
229 fn evaluate_write(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
230 if let Some(allow) = self.effective_store_allowlist(ctx) {
232 if !allow.iter().any(|s| store_matches(s, store)) {
233 return Ok(Verdict::Deny);
234 }
235 }
236
237 if let Some(max_ttl) = self.max_retention_ttl_secs {
239 let requested = extract_retention_ttl(&ctx.request.arguments);
240 match requested {
241 None => {
242 return Ok(Verdict::Deny);
245 }
246 Some(ttl) if ttl > max_ttl => {
247 return Ok(Verdict::Deny);
248 }
249 Some(_) => {}
250 }
251 }
252
253 if let Some(max_bytes) = self.max_content_size_bytes {
255 if let Some(size) = extract_content_size_bytes(&ctx.request.arguments) {
256 if size > max_bytes {
257 return Ok(Verdict::Deny);
258 }
259 }
260 }
261
262 if !self.deny_patterns.is_empty() {
264 if let Some(content) = extract_content_text(&ctx.request.arguments) {
265 for re in &self.deny_patterns {
266 if re.is_match(&content) {
267 return Ok(Verdict::Deny);
268 }
269 }
270 }
271 }
272
273 if let Some(max_entries) = self.max_memory_entries {
276 let key = (ctx.agent_id.to_string(), ctx.request.capability.id.clone());
277 let count = self.bump_counter(key)?;
278 if count > max_entries {
279 return Ok(Verdict::Deny);
280 }
281 }
282
283 Ok(Verdict::Allow)
284 }
285
286 fn evaluate_read(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
287 if let Some(allow) = self.effective_store_allowlist(ctx) {
290 if !allow.iter().any(|s| store_matches(s, store)) {
291 return Ok(Verdict::Deny);
292 }
293 }
294 Ok(Verdict::Allow)
295 }
296}
297
/// Tests whether `store` satisfies one allowlist `pattern`.
///
/// A single trailing `*` turns the pattern into a prefix match, so `"*"` (empty
/// prefix) matches every store. Any other pattern must equal `store` exactly.
/// Only the final `*` is special; earlier `*` characters are literal.
fn store_matches(pattern: &str, store: &str) -> bool {
    match pattern.strip_suffix('*') {
        Some(prefix) => store.starts_with(prefix),
        None => pattern == store,
    }
}
308
309fn extract_retention_ttl(arguments: &Value) -> Option<u64> {
311 for key in [
312 "retention_ttl",
313 "retentionTtl",
314 "retention_ttl_secs",
315 "retentionTtlSecs",
316 "ttl",
317 "ttl_secs",
318 "expires_in",
319 "expiresIn",
320 ] {
321 if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
322 return Some(v);
323 }
324 }
325 None
326}
327
328fn extract_content_size_bytes(arguments: &Value) -> Option<u64> {
331 for key in ["content_size", "contentSize", "content_bytes", "size"] {
332 if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
333 return Some(v);
334 }
335 }
336 extract_content_text(arguments).map(|s| s.len() as u64)
337}
338
339fn extract_content_text(arguments: &Value) -> Option<String> {
341 for key in ["content", "text", "value", "vector_text", "payload"] {
342 if let Some(v) = arguments.get(key).and_then(|v| v.as_str()) {
343 if !v.is_empty() {
344 return Some(v.to_string());
345 }
346 }
347 }
348 None
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn store_matches_wildcards() {
        // (pattern, store, should match)
        let cases = [
            ("*", "anything", true),
            ("agent-*", "agent-notes", true),
            ("agent-*", "other", false),
            ("agent-notes", "agent-notes", true),
        ];
        for (pattern, store, expected) in cases {
            assert_eq!(
                store_matches(pattern, store),
                expected,
                "pattern {:?} against store {:?}",
                pattern,
                store
            );
        }
    }

    #[test]
    fn extract_retention_ttl_reads_common_keys() {
        let plain_key = serde_json::json!({ "ttl": 600 });
        let camel_key = serde_json::json!({ "retentionTtl": 120 });
        let no_key = serde_json::json!({});
        assert_eq!(extract_retention_ttl(&plain_key), Some(600));
        assert_eq!(extract_retention_ttl(&camel_key), Some(120));
        assert_eq!(extract_retention_ttl(&no_key), None);
    }

    #[test]
    fn content_size_falls_back_to_text_length() {
        let text_only = serde_json::json!({ "content": "hello" });
        assert_eq!(extract_content_size_bytes(&text_only), Some(5));
    }
}