use std::collections::HashMap;
use std::sync::Mutex;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use chio_core::capability::Constraint;
use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
use crate::action::{extract_action, ToolAction};
/// Errors that can occur while building a [`MemoryGovernanceGuard`].
#[derive(Debug, thiserror::Error)]
pub enum MemoryGovernanceError {
/// A configured deny pattern failed to compile as a regular expression.
#[error("invalid deny pattern `{pattern}`: {source}")]
InvalidPattern {
/// The pattern string exactly as it appeared in the configuration.
pattern: String,
/// The compile error reported by the `regex` crate.
#[source]
source: regex::Error,
},
}
/// Serializable configuration for [`MemoryGovernanceGuard`].
///
/// Unknown fields are rejected during deserialization. Every field has a
/// default, so an empty document deserializes to an enabled guard with no
/// allowlist, no limits, and no deny patterns.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct MemoryGovernanceConfig {
/// Master switch; when `false` the guard allows every request. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Store-name patterns that memory reads and writes may target.
///
/// A pattern is an exact store name, or a prefix followed by a trailing `*`
/// (a lone `*` matches every store). When this list and the matched grant's
/// allowlist constraints are both empty, no store restriction is applied.
#[serde(default)]
pub store_allowlist: Vec<String>,
/// Maximum number of writes counted per `(agent, capability)` pair; `None` means no limit.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_memory_entries: Option<u64>,
/// Largest retention TTL a write may request; `None` means no limit.
///
/// When set, a write that carries no recognizable TTL argument is denied.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_retention_ttl_secs: Option<u64>,
/// Largest content size a write may carry; `None` means no limit.
///
/// A write whose size cannot be determined from its arguments is not limited.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_content_size_bytes: Option<u64>,
/// Regular expressions; a write whose content text matches any of them is denied.
#[serde(default)]
pub deny_patterns: Vec<String>,
}
/// Serde default helper: a field that names this function deserializes to
/// `true` when it is omitted from the input.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
impl Default for MemoryGovernanceConfig {
    /// Returns an enabled configuration with an empty store allowlist, no
    /// numeric limits, and no deny patterns.
    fn default() -> Self {
        // Every numeric limit starts out unset.
        let unlimited: Option<u64> = None;
        Self {
            enabled: true,
            store_allowlist: Vec::default(),
            max_memory_entries: unlimited,
            max_retention_ttl_secs: unlimited,
            max_content_size_bytes: unlimited,
            deny_patterns: Vec::default(),
        }
    }
}
/// Key for the per-session write counters: an `(agent_id, capability_id)` pair.
type SessionKey = (String, String);
/// Guard that applies memory-governance policy to memory read and write tool actions.
///
/// Built from a [`MemoryGovernanceConfig`]; deny patterns are stored compiled,
/// and per-session write counts are kept in memory behind a mutex.
pub struct MemoryGovernanceGuard {
/// When `false`, evaluation allows every request without inspecting it.
enabled: bool,
/// Guard-level store-name patterns (exact name or trailing-`*` prefix).
store_allowlist: Vec<String>,
/// Maximum counted writes per [`SessionKey`]; `None` disables the check.
max_memory_entries: Option<u64>,
/// Largest retention TTL a write may request; `None` disables the check.
max_retention_ttl_secs: Option<u64>,
/// Largest content size a write may carry; `None` disables the check.
max_content_size_bytes: Option<u64>,
/// Compiled form of the configured deny patterns.
deny_patterns: Vec<Regex>,
/// Number of writes that reached the entry-count check, per session key.
counters: Mutex<HashMap<SessionKey, u64>>,
}
impl MemoryGovernanceGuard {
pub fn new() -> Self {
Self::with_config(MemoryGovernanceConfig::default()).unwrap_or_else(|_| Self {
enabled: true,
store_allowlist: Vec::new(),
max_memory_entries: None,
max_retention_ttl_secs: None,
max_content_size_bytes: None,
deny_patterns: Vec::new(),
counters: Mutex::new(HashMap::new()),
})
}
pub fn with_config(config: MemoryGovernanceConfig) -> Result<Self, MemoryGovernanceError> {
let mut deny_patterns = Vec::with_capacity(config.deny_patterns.len());
for pat in &config.deny_patterns {
let re = Regex::new(pat).map_err(|e| MemoryGovernanceError::InvalidPattern {
pattern: pat.clone(),
source: e,
})?;
deny_patterns.push(re);
}
Ok(Self {
enabled: config.enabled,
store_allowlist: config.store_allowlist,
max_memory_entries: config.max_memory_entries,
max_retention_ttl_secs: config.max_retention_ttl_secs,
max_content_size_bytes: config.max_content_size_bytes,
deny_patterns,
counters: Mutex::new(HashMap::new()),
})
}
pub fn session_count(&self, agent_id: &str, capability_id: &str) -> u64 {
self.counters
.lock()
.ok()
.and_then(|g| {
g.get(&(agent_id.to_string(), capability_id.to_string()))
.copied()
})
.unwrap_or(0)
}
fn effective_store_allowlist<'a>(&'a self, ctx: &'a GuardContext<'a>) -> Option<Vec<String>> {
let mut combined: Vec<String> = self.store_allowlist.clone();
if let Some(grant) = ctx
.matched_grant_index
.and_then(|i| ctx.scope.grants.get(i))
{
for c in &grant.constraints {
if let Constraint::MemoryStoreAllowlist(list) = c {
combined.extend(list.iter().cloned());
}
}
}
if combined.is_empty() {
None
} else {
Some(combined)
}
}
fn bump_counter(&self, key: SessionKey) -> Result<u64, KernelError> {
let mut guard = self.counters.lock().map_err(|_| {
KernelError::Internal("memory-governance guard counter mutex poisoned".to_string())
})?;
let entry = guard.entry(key).or_insert(0);
*entry = entry.saturating_add(1);
Ok(*entry)
}
}
impl Default for MemoryGovernanceGuard {
fn default() -> Self {
Self::new()
}
}
impl Guard for MemoryGovernanceGuard {
    /// Name under which this guard identifies itself.
    fn name(&self) -> &str {
        "memory-governance"
    }

    /// Classifies the request with `extract_action` and applies the write or
    /// read policy to memory actions.
    ///
    /// A disabled guard, and any action that is neither a memory write nor a
    /// memory read, yields `Verdict::Allow`.
    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
        if self.enabled {
            match extract_action(&ctx.request.tool_name, &ctx.request.arguments) {
                ToolAction::MemoryWrite { store, .. } => self.evaluate_write(ctx, &store),
                ToolAction::MemoryRead { store, .. } => self.evaluate_read(ctx, &store),
                _ => Ok(Verdict::Allow),
            }
        } else {
            Ok(Verdict::Allow)
        }
    }
}
impl MemoryGovernanceGuard {
    /// Returns `true` when `store` matches at least one pattern of the
    /// effective allowlist, or when there is no effective allowlist at all.
    fn store_permitted(&self, ctx: &GuardContext, store: &str) -> bool {
        match self.effective_store_allowlist(ctx) {
            Some(patterns) => patterns.iter().any(|pattern| store_matches(pattern, store)),
            None => true,
        }
    }

    /// Applies the write policy to a memory write targeting `store`.
    ///
    /// Checks run in this order and the first failure yields `Verdict::Deny`:
    /// store allowlist, retention TTL, content size, deny patterns, and the
    /// per-session entry count. The entry counter is only incremented once all
    /// earlier checks have passed.
    ///
    /// # Errors
    ///
    /// Propagates the error from `bump_counter` when the counter cannot be
    /// updated.
    fn evaluate_write(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
        let arguments = &ctx.request.arguments;

        if !self.store_permitted(ctx, store) {
            return Ok(Verdict::Deny);
        }

        // With a TTL cap configured, a write must state a TTL and stay within the cap.
        if let Some(ttl_cap) = self.max_retention_ttl_secs {
            let within_cap = match extract_retention_ttl(arguments) {
                Some(ttl) => ttl <= ttl_cap,
                None => false,
            };
            if !within_cap {
                return Ok(Verdict::Deny);
            }
        }

        // A write whose size cannot be determined is not rejected by this check.
        if let Some(size_cap) = self.max_content_size_bytes {
            let oversized = match extract_content_size_bytes(arguments) {
                Some(size) => size > size_cap,
                None => false,
            };
            if oversized {
                return Ok(Verdict::Deny);
            }
        }

        // Skip text extraction entirely when no deny patterns are configured.
        if !self.deny_patterns.is_empty() {
            if let Some(text) = extract_content_text(arguments) {
                if self.deny_patterns.iter().any(|re| re.is_match(&text)) {
                    return Ok(Verdict::Deny);
                }
            }
        }

        if let Some(entry_cap) = self.max_memory_entries {
            let session = (ctx.agent_id.to_string(), ctx.request.capability.id.clone());
            if self.bump_counter(session)? > entry_cap {
                return Ok(Verdict::Deny);
            }
        }

        Ok(Verdict::Allow)
    }

    /// Applies the read policy to a memory read targeting `store`: only the
    /// store allowlist is consulted.
    fn evaluate_read(&self, ctx: &GuardContext, store: &str) -> Result<Verdict, KernelError> {
        let verdict = if self.store_permitted(ctx, store) {
            Verdict::Allow
        } else {
            Verdict::Deny
        };
        Ok(verdict)
    }
}
/// Tests a store name against a single allowlist pattern.
///
/// A pattern ending in `*` matches every store that starts with the text
/// before that final `*` (so a lone `*` matches everything). Any other
/// pattern must equal the store name exactly; a `*` elsewhere in the pattern
/// is treated literally.
fn store_matches(pattern: &str, store: &str) -> bool {
    match pattern.strip_suffix('*') {
        // `"*"` leaves an empty prefix, which every store name starts with.
        Some(prefix) => store.starts_with(prefix),
        None => pattern == store,
    }
}
fn extract_retention_ttl(arguments: &Value) -> Option<u64> {
for key in [
"retention_ttl",
"retentionTtl",
"retention_ttl_secs",
"retentionTtlSecs",
"ttl",
"ttl_secs",
"expires_in",
"expiresIn",
] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
return Some(v);
}
}
None
}
fn extract_content_size_bytes(arguments: &Value) -> Option<u64> {
for key in ["content_size", "contentSize", "content_bytes", "size"] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
return Some(v);
}
}
extract_content_text(arguments).map(|s| s.len() as u64)
}
fn extract_content_text(arguments: &Value) -> Option<String> {
for key in ["content", "text", "value", "vector_text", "payload"] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_str()) {
if !v.is_empty() {
return Some(v.to_string());
}
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone `*`, a trailing-`*` prefix, and an exact name are all honoured.
    #[test]
    fn store_matches_wildcards() {
        let cases = [
            ("*", "anything", true),
            ("agent-*", "agent-notes", true),
            ("agent-*", "other", false),
            ("agent-notes", "agent-notes", true),
        ];
        for &(pattern, store, expected) in cases.iter() {
            assert_eq!(
                store_matches(pattern, store),
                expected,
                "pattern {:?} against store {:?}",
                pattern,
                store
            );
        }
    }

    /// Snake-case and camel-case TTL keys are read; missing keys yield `None`.
    #[test]
    fn extract_retention_ttl_reads_common_keys() {
        let snake = serde_json::json!({"ttl": 600});
        let camel = serde_json::json!({"retentionTtl": 120});
        let empty = serde_json::json!({});

        assert_eq!(extract_retention_ttl(&snake), Some(600));
        assert_eq!(extract_retention_ttl(&camel), Some(120));
        assert_eq!(extract_retention_ttl(&empty), None);
    }

    /// Without a declared size, the byte length of the content text is used.
    #[test]
    fn content_size_falls_back_to_text_length() {
        let only_text = serde_json::json!({"content": "hello"});
        assert_eq!(extract_content_size_bytes(&only_text), Some(5));
    }
}