Skip to main content

ai_agents_runtime/spawner/
storage.rs

1//! Storage adapter that namespaces session keys with an agent-specific prefix.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use ai_agents_core::{AgentSnapshot, AgentStorage, Result};
8
9/// Wraps a shared `AgentStorage` backend and transparently prepends `{prefix}/` to every session key, isolating each agent's data.
10pub struct NamespacedStorage {
11    inner: Arc<dyn AgentStorage>,
12    prefix: String,
13}
14
15impl NamespacedStorage {
16    pub fn new(inner: Arc<dyn AgentStorage>, prefix: impl Into<String>) -> Self {
17        Self {
18            inner,
19            prefix: prefix.into(),
20        }
21    }
22
23    /// Build the namespaced key: `"{prefix}/{session_id}"`.
24    fn namespaced_key(&self, session_id: &str) -> String {
25        format!("{}/{}", self.prefix, session_id)
26    }
27}
28
29#[async_trait]
30impl AgentStorage for NamespacedStorage {
31    async fn save(&self, session_id: &str, snapshot: &AgentSnapshot) -> Result<()> {
32        self.inner
33            .save(&self.namespaced_key(session_id), snapshot)
34            .await
35    }
36
37    async fn load(&self, session_id: &str) -> Result<Option<AgentSnapshot>> {
38        self.inner.load(&self.namespaced_key(session_id)).await
39    }
40
41    async fn delete(&self, session_id: &str) -> Result<()> {
42        self.inner.delete(&self.namespaced_key(session_id)).await
43    }
44
45    async fn list_sessions(&self) -> Result<Vec<String>> {
46        let all = self.inner.list_sessions().await?;
47        let prefix_slash = format!("{}/", self.prefix);
48        Ok(all
49            .into_iter()
50            .filter_map(|s| s.strip_prefix(&prefix_slash).map(|rest| rest.to_string()))
51            .collect())
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::AgentSnapshot;
    use parking_lot::RwLock;
    use std::collections::HashMap;

    /// Minimal in-memory storage for testing.
    struct MemStorage {
        data: RwLock<HashMap<String, AgentSnapshot>>,
    }

    impl MemStorage {
        fn new() -> Self {
            Self {
                data: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl AgentStorage for MemStorage {
        async fn save(&self, session_id: &str, snapshot: &AgentSnapshot) -> Result<()> {
            self.data
                .write()
                .insert(session_id.to_string(), snapshot.clone());
            Ok(())
        }

        async fn load(&self, session_id: &str) -> Result<Option<AgentSnapshot>> {
            Ok(self.data.read().get(session_id).cloned())
        }

        async fn delete(&self, session_id: &str) -> Result<()> {
            self.data.write().remove(session_id);
            Ok(())
        }

        async fn list_sessions(&self) -> Result<Vec<String>> {
            Ok(self.data.read().keys().cloned().collect())
        }
    }

    #[tokio::test]
    async fn test_namespaced_save_load() {
        let inner = Arc::new(MemStorage::new());
        let ns = NamespacedStorage::new(inner.clone(), "agent_1");

        let snapshot = AgentSnapshot::new("agent_1".to_string());
        ns.save("session_a", &snapshot).await.unwrap();

        // Underlying storage should have the prefixed key
        assert!(inner.load("agent_1/session_a").await.unwrap().is_some());

        // Namespaced load should work with the unprefixed key
        assert!(ns.load("session_a").await.unwrap().is_some());
        assert!(ns.load("session_b").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_namespaced_list_sessions() {
        let inner = Arc::new(MemStorage::new());

        let ns1 = NamespacedStorage::new(inner.clone(), "npc_a");
        let ns2 = NamespacedStorage::new(inner.clone(), "npc_b");

        ns1.save("s1", &AgentSnapshot::new("npc_a".into()))
            .await
            .unwrap();
        ns1.save("s2", &AgentSnapshot::new("npc_a".into()))
            .await
            .unwrap();
        ns2.save("s1", &AgentSnapshot::new("npc_b".into()))
            .await
            .unwrap();

        let mut sessions1 = ns1.list_sessions().await.unwrap();
        sessions1.sort();
        assert_eq!(sessions1, vec!["s1", "s2"]);

        let sessions2 = ns2.list_sessions().await.unwrap();
        assert_eq!(sessions2, vec!["s1"]);
    }

    #[tokio::test]
    async fn test_namespaced_delete() {
        let inner = Arc::new(MemStorage::new());
        let ns = NamespacedStorage::new(inner.clone(), "agent_x");

        ns.save("sess", &AgentSnapshot::new("agent_x".into()))
            .await
            .unwrap();
        assert!(ns.load("sess").await.unwrap().is_some());

        ns.delete("sess").await.unwrap();
        assert!(ns.load("sess").await.unwrap().is_none());
    }

    /// A prefix that is a string prefix of another (`npc` vs `npc_a`) must
    /// not see the other namespace's sessions, because keys are matched on
    /// `"{prefix}/"`, not on the bare prefix.
    #[tokio::test]
    async fn test_similar_prefixes_are_isolated() {
        let inner = Arc::new(MemStorage::new());
        let short = NamespacedStorage::new(inner.clone(), "npc");
        let long = NamespacedStorage::new(inner.clone(), "npc_a");

        long.save("s1", &AgentSnapshot::new("npc_a".into()))
            .await
            .unwrap();

        assert!(short.list_sessions().await.unwrap().is_empty());
        assert!(short.load("s1").await.unwrap().is_none());
        assert_eq!(long.list_sessions().await.unwrap(), vec!["s1"]);
    }

    /// Deleting a session in one namespace must leave the same session id in
    /// another namespace untouched.
    #[tokio::test]
    async fn test_namespaced_delete_is_isolated() {
        let inner = Arc::new(MemStorage::new());
        let ns1 = NamespacedStorage::new(inner.clone(), "agent_1");
        let ns2 = NamespacedStorage::new(inner.clone(), "agent_2");

        ns1.save("shared", &AgentSnapshot::new("agent_1".into()))
            .await
            .unwrap();
        ns2.save("shared", &AgentSnapshot::new("agent_2".into()))
            .await
            .unwrap();

        ns1.delete("shared").await.unwrap();

        assert!(ns1.load("shared").await.unwrap().is_none());
        assert!(ns2.load("shared").await.unwrap().is_some());
        assert!(inner.load("agent_2/shared").await.unwrap().is_some());
    }

    /// Keys written directly to the backend without any namespace must not be
    /// reported by a namespaced `list_sessions`.
    #[tokio::test]
    async fn test_list_sessions_ignores_unprefixed_keys() {
        let inner = Arc::new(MemStorage::new());
        let ns = NamespacedStorage::new(inner.clone(), "agent_y");

        inner
            .save("raw_key", &AgentSnapshot::new("raw".into()))
            .await
            .unwrap();

        assert!(ns.list_sessions().await.unwrap().is_empty());
        assert!(ns.load("raw_key").await.unwrap().is_none());
    }
}