rexis_rag/agent/memory/
shared.rs

1//! Shared knowledge base - cross-agent memory
2//!
3//! Shared knowledge allows multiple agents to read and write to a common memory space.
4//! It's global-scoped and enables agent collaboration and information sharing.
5
6use crate::error::RragResult;
7use crate::storage::{Memory, MemoryValue};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10
/// A single entry in the shared knowledge base.
///
/// [`SharedKnowledgeBase`] serializes entries to JSON and stores them under their `key`
/// (prefixed with its namespace); `id` is informational and is not used for lookups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEntry {
    /// Unique identifier; [`KnowledgeEntry::new`] assigns a random UUID v4 string.
    pub id: String,

    /// Lookup key; entries are stored and retrieved by this key, not by `id`.
    pub key: String,

    /// The stored value/content.
    pub value: MemoryValue,

    /// Agent that created this entry. When an ACL is set, the creator can always read the
    /// entry; [`SharedKnowledgeBase::delete`] only lets the creator delete it.
    pub created_by: String,

    /// Creation time (UTC), set by [`KnowledgeEntry::new`].
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Agent that last wrote this entry; overwritten with the writing agent on every store
    /// through [`SharedKnowledgeBase`].
    pub updated_by: String,

    /// Time (UTC) of the last write; overwritten on every store through [`SharedKnowledgeBase`].
    pub updated_at: chrono::DateTime<chrono::Utc>,

    /// Free-form tags for categorization, matched exactly by `find_by_tag`.
    pub tags: Vec<String>,

    /// Optional read allow-list of agent IDs. `None` means every agent may read the entry;
    /// `Some(list)` admits the listed agents plus the creator (see [`KnowledgeEntry::has_access`]).
    pub acl: Option<Vec<String>>,

    /// Arbitrary string key/value metadata.
    pub metadata: std::collections::HashMap<String, String>,
}
44
45impl KnowledgeEntry {
46    /// Create a new knowledge entry
47    pub fn new(
48        key: impl Into<String>,
49        value: impl Into<MemoryValue>,
50        created_by: impl Into<String>,
51    ) -> Self {
52        let now = chrono::Utc::now();
53        let created_by = created_by.into();
54
55        Self {
56            id: uuid::Uuid::new_v4().to_string(),
57            key: key.into(),
58            value: value.into(),
59            created_by: created_by.clone(),
60            created_at: now,
61            updated_by: created_by,
62            updated_at: now,
63            tags: Vec::new(),
64            acl: None,
65            metadata: std::collections::HashMap::new(),
66        }
67    }
68
69    /// Set tags
70    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
71        self.tags = tags;
72        self
73    }
74
75    /// Set access control list
76    pub fn with_acl(mut self, acl: Vec<String>) -> Self {
77        self.acl = Some(acl);
78        self
79    }
80
81    /// Add metadata
82    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
83        self.metadata.insert(key.into(), value.into());
84        self
85    }
86
87    /// Check if an agent has access
88    pub fn has_access(&self, agent_id: &str) -> bool {
89        match &self.acl {
90            None => true, // No ACL means public access
91            Some(acl) => acl.contains(&agent_id.to_string()) || agent_id == self.created_by,
92        }
93    }
94}
95
/// Per-agent handle onto the shared, cross-agent knowledge store.
///
/// Several handles (one per agent) may be built over the same `storage`; they all use the
/// same namespace and therefore see the same entries. Each handle carries the identity of
/// the agent using it, which is stamped onto writes and checked against entry ACLs on reads.
pub struct SharedKnowledgeBase {
    /// Storage backend; may be shared between several handles.
    storage: Arc<dyn Memory>,

    /// Identity of the agent using this handle (recorded as creator/updater on writes and
    /// checked against each entry's ACL on reads).
    agent_id: String,

    /// Namespace prefix for every entry key; always `global::knowledge` (set in `new`).
    namespace: String,
}
107
108impl SharedKnowledgeBase {
109    /// Create a new shared knowledge base
110    pub fn new(storage: Arc<dyn Memory>, agent_id: String) -> Self {
111        Self {
112            storage,
113            agent_id,
114            namespace: "global::knowledge".to_string(),
115        }
116    }
117
118    /// Store a knowledge entry
119    pub async fn store(
120        &self,
121        key: impl Into<String>,
122        value: impl Into<MemoryValue>,
123    ) -> RragResult<KnowledgeEntry> {
124        let entry = KnowledgeEntry::new(key, value, self.agent_id.clone());
125        self.store_entry(entry.clone()).await?;
126        Ok(entry)
127    }
128
129    /// Store a knowledge entry with tags
130    pub async fn store_with_tags(
131        &self,
132        key: impl Into<String>,
133        value: impl Into<MemoryValue>,
134        tags: Vec<String>,
135    ) -> RragResult<KnowledgeEntry> {
136        let entry = KnowledgeEntry::new(key, value, self.agent_id.clone()).with_tags(tags);
137        self.store_entry(entry.clone()).await?;
138        Ok(entry)
139    }
140
141    /// Store a full knowledge entry
142    pub async fn store_entry(&self, mut entry: KnowledgeEntry) -> RragResult<()> {
143        // Update metadata
144        entry.updated_by = self.agent_id.clone();
145        entry.updated_at = chrono::Utc::now();
146
147        let storage_key = self.entry_key(&entry.key);
148        let value = serde_json::to_value(&entry).map_err(|e| {
149            crate::error::RragError::storage(
150                "serialize_entry",
151                std::io::Error::new(std::io::ErrorKind::Other, e),
152            )
153        })?;
154
155        self.storage
156            .set(&storage_key, MemoryValue::Json(value))
157            .await
158    }
159
160    /// Get a knowledge entry
161    pub async fn get(&self, key: &str) -> RragResult<Option<KnowledgeEntry>> {
162        let storage_key = self.entry_key(key);
163        if let Some(value) = self.storage.get(&storage_key).await? {
164            if let Some(json) = value.as_json() {
165                let entry: KnowledgeEntry = serde_json::from_value(json.clone()).map_err(|e| {
166                    crate::error::RragError::storage(
167                        "deserialize_entry",
168                        std::io::Error::new(std::io::ErrorKind::Other, e),
169                    )
170                })?;
171
172                // Check ACL
173                if entry.has_access(&self.agent_id) {
174                    return Ok(Some(entry));
175                }
176            }
177        }
178        Ok(None)
179    }
180
181    /// Get just the value (without metadata)
182    pub async fn get_value(&self, key: &str) -> RragResult<Option<MemoryValue>> {
183        if let Some(entry) = self.get(key).await? {
184            Ok(Some(entry.value))
185        } else {
186            Ok(None)
187        }
188    }
189
190    /// Delete a knowledge entry
191    pub async fn delete(&self, key: &str) -> RragResult<bool> {
192        // Check if the current agent has permission to delete
193        if let Some(entry) = self.get(key).await? {
194            if entry.created_by != self.agent_id {
195                // Only creator can delete (or implement more sophisticated permissions)
196                return Ok(false);
197            }
198        }
199
200        let storage_key = self.entry_key(key);
201        self.storage.delete(&storage_key).await
202    }
203
204    /// Check if a key exists and is accessible
205    pub async fn exists(&self, key: &str) -> RragResult<bool> {
206        Ok(self.get(key).await?.is_some())
207    }
208
209    /// Find entries by tag
210    pub async fn find_by_tag(&self, tag: &str) -> RragResult<Vec<KnowledgeEntry>> {
211        let all_entries = self.get_all_entries().await?;
212
213        let matching = all_entries
214            .into_iter()
215            .filter(|e| e.has_access(&self.agent_id) && e.tags.contains(&tag.to_string()))
216            .collect();
217
218        Ok(matching)
219    }
220
221    /// Find entries created by a specific agent
222    pub async fn find_by_creator(&self, creator_agent_id: &str) -> RragResult<Vec<KnowledgeEntry>> {
223        let all_entries = self.get_all_entries().await?;
224
225        let matching = all_entries
226            .into_iter()
227            .filter(|e| e.has_access(&self.agent_id) && e.created_by == creator_agent_id)
228            .collect();
229
230        Ok(matching)
231    }
232
233    /// Get all accessible entries
234    pub async fn get_all_entries(&self) -> RragResult<Vec<KnowledgeEntry>> {
235        let all_keys = self.list_entry_keys().await?;
236        let mut entries = Vec::new();
237
238        for key in all_keys {
239            if let Some(entry) = self.get(&key).await? {
240                entries.push(entry);
241            }
242        }
243
244        Ok(entries)
245    }
246
247    /// Count accessible entries
248    pub async fn count(&self) -> RragResult<usize> {
249        // This counts all entries; for accurate count, filter by ACL
250        self.storage.count(Some(&self.namespace)).await
251    }
252
253    /// Clear all entries (requires appropriate permissions)
254    pub async fn clear(&self) -> RragResult<()> {
255        self.storage.clear(Some(&self.namespace)).await
256    }
257
258    /// Generate entry key
259    fn entry_key(&self, key: &str) -> String {
260        format!("{}::{}", self.namespace, key)
261    }
262
263    /// List all entry keys
264    async fn list_entry_keys(&self) -> RragResult<Vec<String>> {
265        use crate::storage::MemoryQuery;
266
267        let query = MemoryQuery::new().with_namespace(self.namespace.clone());
268        let all_keys = self.storage.keys(&query).await?;
269
270        // Extract entry keys
271        let prefix = format!("{}::", self.namespace);
272        let keys = all_keys
273            .into_iter()
274            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
275            .collect();
276
277        Ok(keys)
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Build one knowledge-base handle per agent ID, all over a single fresh in-memory store.
    fn handles_for(agents: &[&str]) -> Vec<SharedKnowledgeBase> {
        let backend = Arc::new(InMemoryStorage::new());
        agents
            .iter()
            .map(|agent| SharedKnowledgeBase::new(backend.clone(), String::from(*agent)))
            .collect()
    }

    #[tokio::test]
    async fn test_shared_knowledge_store_and_retrieve() {
        let kbs = handles_for(&["agent1"]);
        let kb = &kbs[0];

        kb.store("api_key", MemoryValue::from("secret123"))
            .await
            .unwrap();

        let stored = kb
            .get_value("api_key")
            .await
            .unwrap()
            .expect("value was just stored");
        assert_eq!(stored.as_string(), Some("secret123"));
    }

    #[tokio::test]
    async fn test_shared_knowledge_cross_agent_access() {
        let kbs = handles_for(&["agent1", "agent2"]);
        let (writer, reader) = (&kbs[0], &kbs[1]);

        writer
            .store("shared_config", MemoryValue::from("config_value"))
            .await
            .unwrap();

        let seen = reader
            .get_value("shared_config")
            .await
            .unwrap()
            .expect("agent2 should see agent1's public entry");
        assert_eq!(seen.as_string(), Some("config_value"));
    }

    #[tokio::test]
    async fn test_shared_knowledge_with_acl() {
        let kbs = handles_for(&["agent1", "agent2", "agent3"]);
        let (owner, allowed, outsider) = (&kbs[0], &kbs[1], &kbs[2]);

        // Readable by agent1 and agent2 only.
        let restricted = KnowledgeEntry::new("private_data", MemoryValue::from("secret"), "agent1")
            .with_acl(vec!["agent1".to_string(), "agent2".to_string()]);
        owner.store_entry(restricted).await.unwrap();

        assert!(allowed.get("private_data").await.unwrap().is_some());
        assert!(outsider.get("private_data").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_shared_knowledge_with_tags() {
        let kbs = handles_for(&["agent1"]);
        let kb = &kbs[0];

        let fixtures = [
            ("config1", "value1", "production"),
            ("config2", "value2", "development"),
        ];
        for &(key, value, env) in fixtures.iter() {
            kb.store_with_tags(
                key,
                MemoryValue::from(value),
                vec!["config".to_string(), env.to_string()],
            )
            .await
            .unwrap();
        }

        assert_eq!(kb.find_by_tag("config").await.unwrap().len(), 2);
        assert_eq!(kb.find_by_tag("production").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_shared_knowledge_delete_permissions() {
        let kbs = handles_for(&["agent1", "agent2"]);
        let (creator, other) = (&kbs[0], &kbs[1]);

        creator.store("data", MemoryValue::from("value")).await.unwrap();

        assert!(
            !other.delete("data").await.unwrap(),
            "a non-creator must not be able to delete"
        );
        assert!(
            creator.delete("data").await.unwrap(),
            "the creator may delete its own entry"
        );
    }
}