ai_agents_runtime/spawner/
storage.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use ai_agents_core::{AgentSnapshot, AgentStorage, Result};
8
9pub struct NamespacedStorage {
11 inner: Arc<dyn AgentStorage>,
12 prefix: String,
13}
14
15impl NamespacedStorage {
16 pub fn new(inner: Arc<dyn AgentStorage>, prefix: impl Into<String>) -> Self {
17 Self {
18 inner,
19 prefix: prefix.into(),
20 }
21 }
22
23 fn namespaced_key(&self, session_id: &str) -> String {
25 format!("{}/{}", self.prefix, session_id)
26 }
27}
28
29#[async_trait]
30impl AgentStorage for NamespacedStorage {
31 async fn save(&self, session_id: &str, snapshot: &AgentSnapshot) -> Result<()> {
32 self.inner
33 .save(&self.namespaced_key(session_id), snapshot)
34 .await
35 }
36
37 async fn load(&self, session_id: &str) -> Result<Option<AgentSnapshot>> {
38 self.inner.load(&self.namespaced_key(session_id)).await
39 }
40
41 async fn delete(&self, session_id: &str) -> Result<()> {
42 self.inner.delete(&self.namespaced_key(session_id)).await
43 }
44
45 async fn list_sessions(&self) -> Result<Vec<String>> {
46 let all = self.inner.list_sessions().await?;
47 let prefix_slash = format!("{}/", self.prefix);
48 Ok(all
49 .into_iter()
50 .filter_map(|s| s.strip_prefix(&prefix_slash).map(|rest| rest.to_string()))
51 .collect())
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::AgentSnapshot;
    use parking_lot::RwLock;
    use std::collections::HashMap;

    /// Minimal in-memory `AgentStorage` used as the backing store in these
    /// tests; snapshots are kept in a map keyed by the exact id it receives.
    struct MemStorage {
        data: RwLock<HashMap<String, AgentSnapshot>>,
    }

    impl MemStorage {
        /// Creates an empty store.
        fn new() -> Self {
            let data = RwLock::new(HashMap::new());
            Self { data }
        }
    }

    #[async_trait]
    impl AgentStorage for MemStorage {
        async fn save(&self, session_id: &str, snapshot: &AgentSnapshot) -> Result<()> {
            let key = session_id.to_string();
            let value = snapshot.clone();
            self.data.write().insert(key, value);
            Ok(())
        }

        async fn load(&self, session_id: &str) -> Result<Option<AgentSnapshot>> {
            let found = self.data.read().get(session_id).cloned();
            Ok(found)
        }

        async fn delete(&self, session_id: &str) -> Result<()> {
            self.data.write().remove(session_id);
            Ok(())
        }

        async fn list_sessions(&self) -> Result<Vec<String>> {
            let ids: Vec<String> = self.data.read().keys().cloned().collect();
            Ok(ids)
        }
    }

    /// Fresh shared backing store for a test.
    fn shared_store() -> Arc<MemStorage> {
        Arc::new(MemStorage::new())
    }

    /// Snapshot fixture for the given agent id.
    fn snapshot_for(agent_id: &str) -> AgentSnapshot {
        AgentSnapshot::new(agent_id.to_string())
    }

    #[tokio::test]
    async fn test_namespaced_save_load() {
        let inner = shared_store();
        let ns = NamespacedStorage::new(inner.clone(), "agent_1");

        ns.save("session_a", &snapshot_for("agent_1")).await.unwrap();

        // The backing store must see the prefixed key.
        let raw = inner.load("agent_1/session_a").await.unwrap();
        assert!(raw.is_some());

        // Through the wrapper the plain id resolves; an unknown id does not.
        let known = ns.load("session_a").await.unwrap();
        assert!(known.is_some());
        let unknown = ns.load("session_b").await.unwrap();
        assert!(unknown.is_none());
    }

    #[tokio::test]
    async fn test_namespaced_list_sessions() {
        let inner = shared_store();

        let ns1 = NamespacedStorage::new(inner.clone(), "npc_a");
        let ns2 = NamespacedStorage::new(inner.clone(), "npc_b");

        let snap_a = snapshot_for("npc_a");
        let snap_b = snapshot_for("npc_b");
        ns1.save("s1", &snap_a).await.unwrap();
        ns1.save("s2", &snap_a).await.unwrap();
        ns2.save("s1", &snap_b).await.unwrap();

        // Map iteration order is unspecified, so sort before comparing.
        let mut sessions1 = ns1.list_sessions().await.unwrap();
        sessions1.sort();
        assert_eq!(sessions1, vec!["s1", "s2"]);

        let sessions2 = ns2.list_sessions().await.unwrap();
        assert_eq!(sessions2, vec!["s1"]);
    }

    #[tokio::test]
    async fn test_namespaced_delete() {
        let inner = shared_store();
        let ns = NamespacedStorage::new(inner.clone(), "agent_x");

        ns.save("sess", &snapshot_for("agent_x")).await.unwrap();
        let before = ns.load("sess").await.unwrap();
        assert!(before.is_some());

        ns.delete("sess").await.unwrap();
        let after = ns.load("sess").await.unwrap();
        assert!(after.is_none());
    }
}