Skip to main content

heartbit_core/memory/
shared_tools.rs

1//! Shared-memory tool definitions for inter-agent institutional memory access.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use chrono::Utc;
8use serde::Deserialize;
9use serde_json::json;
10use uuid::Uuid;
11
12use crate::auth::TenantScope;
13use crate::error::Error;
14use crate::llm::types::ToolDefinition;
15use crate::tool::{Tool, ToolOutput};
16
17use super::{Memory, MemoryEntry, MemoryQuery};
18
19/// Create shared memory tools for cross-agent memory access within the caller's tenant.
20///
21/// - `shared_memory_read`: read memories from any agent's namespace in this tenant
22/// - `shared_memory_write`: write to a shared namespace visible to all agents in this tenant
23///   (only included when `include_write` is `true`)
24pub fn shared_memory_tools(
25    memory: Arc<dyn Memory>,
26    agent_name: &str,
27    scope: TenantScope,
28    include_write: bool,
29) -> Vec<Arc<dyn Tool>> {
30    let mut tools: Vec<Arc<dyn Tool>> = vec![Arc::new(SharedMemoryReadTool {
31        memory: memory.clone(),
32        scope: scope.clone(),
33    })];
34    if include_write {
35        tools.push(Arc::new(SharedMemoryWriteTool {
36            memory,
37            agent_name: agent_name.into(),
38            scope,
39        }));
40    }
41    tools
42}
43
44// --- shared_memory_read ---
45
/// LLM-facing `shared_memory_read` tool: recalls memories written by any agent
/// in one tenant.
///
/// Built by `shared_memory_tools`.
struct SharedMemoryReadTool {
    /// Store queried on every `execute` call.
    memory: Arc<dyn Memory>,
    /// Tenant boundary passed to every `recall`; results never leave this scope.
    scope: TenantScope,
}
50
/// Arguments of a `shared_memory_read` call, deserialized from the tool-call JSON.
///
/// Every field has a serde default, so `{}` is a valid input.
#[derive(Deserialize)]
struct SharedReadInput {
    /// Free-text search, forwarded as `MemoryQuery::text`.
    #[serde(default)]
    query: Option<String>,
    /// Restrict results to one agent's entries; `None` searches all agents in the tenant.
    #[serde(default)]
    agent: Option<String>,
    /// Restrict results to one category.
    #[serde(default)]
    category: Option<String>,
    /// Tag filter, forwarded as `MemoryQuery::tags`; empty when omitted.
    #[serde(default)]
    tags: Vec<String>,
    /// Maximum number of results; defaults to `super::default_recall_limit()`
    /// (presumably 10, per the tool schema — TODO confirm).
    #[serde(default = "super::default_recall_limit")]
    limit: usize,
}
64
65impl Tool for SharedMemoryReadTool {
66    fn definition(&self) -> ToolDefinition {
67        ToolDefinition {
68            name: "shared_memory_read".into(),
69            description: "Read memories from any agent's namespace. Use this to access \
70                          knowledge that other agents have stored."
71                .into(),
72            input_schema: json!({
73                "type": "object",
74                "properties": {
75                    "query": {
76                        "type": "string",
77                        "description": "Text to search for"
78                    },
79                    "agent": {
80                        "type": "string",
81                        "description": "Filter by agent name (omit for all agents)"
82                    },
83                    "category": {
84                        "type": "string",
85                        "description": "Filter by category"
86                    },
87                    "tags": {
88                        "type": "array",
89                        "items": {"type": "string"},
90                        "description": "Filter by tags"
91                    },
92                    "limit": {
93                        "type": "integer",
94                        "description": "Max results (default: 10)"
95                    }
96                }
97            }),
98        }
99    }
100
101    fn execute(
102        &self,
103        _ctx: &crate::ExecutionContext,
104        input: serde_json::Value,
105    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
106        Box::pin(async move {
107            let input: SharedReadInput =
108                serde_json::from_value(input).map_err(|e| Error::Memory(e.to_string()))?;
109
110            // SECURITY (F-MEM-2): cap recalled confidentiality at `Internal`.
111            // `Confidentiality::Restricted` is documented as "never in LLM
112            // context" — this tool short-circuits NamespacedMemory's per-agent
113            // cap and would leak Restricted entries cross-agent if any code
114            // path stored them. Cap defensively here so even a malformed
115            // store does not become a leak.
116            let results = self
117                .memory
118                .recall(
119                    &self.scope,
120                    MemoryQuery {
121                        text: input.query,
122                        category: input.category,
123                        tags: input.tags,
124                        agent: input.agent, // None = all agents within this tenant
125                        limit: input.limit,
126                        max_confidentiality: Some(crate::memory::Confidentiality::Internal),
127                        ..Default::default()
128                    },
129                )
130                .await?;
131
132            if results.is_empty() {
133                return Ok(ToolOutput::success("No shared memories found."));
134            }
135
136            let formatted: Vec<String> = results
137                .iter()
138                .map(|e| {
139                    let mt = match e.memory_type {
140                        crate::memory::MemoryType::Episodic => "episodic",
141                        crate::memory::MemoryType::Semantic => "semantic",
142                        crate::memory::MemoryType::Reflection => "reflection",
143                    };
144                    format!(
145                        "- [{}] @{} ({}, {}, importance:{}, strength:{:.2}) {}",
146                        e.id, e.agent, e.category, mt, e.importance, e.strength, e.content,
147                    )
148                })
149                .collect();
150
151            let count = results.len();
152            let noun = if count == 1 { "memory" } else { "memories" };
153            Ok(ToolOutput::success(format!(
154                "Found {count} shared {noun}:\n{}",
155                formatted.join("\n")
156            )))
157        })
158    }
159}
160
161// --- shared_memory_write ---
162
/// LLM-facing `shared_memory_write` tool: stores a memory in the tenant scope,
/// attributed to the owning agent.
///
/// Built by `shared_memory_tools` only when `include_write` is `true`.
struct SharedMemoryWriteTool {
    /// Store written to on every `execute` call.
    memory: Arc<dyn Memory>,
    /// Copied into `MemoryEntry::agent` as provenance for every stored entry.
    agent_name: String,
    /// Tenant scope passed to `Memory::store`.
    scope: TenantScope,
}
168
/// Arguments of a `shared_memory_write` call, deserialized from the tool-call JSON.
///
/// Only `content` is mandatory; every other field has a serde default.
#[derive(Deserialize)]
struct SharedWriteInput {
    /// Text body of the memory.
    content: String,
    /// Category label; defaults to `super::default_category()` (presumably
    /// "fact", per the tool schema — TODO confirm). Any string is accepted
    /// here; the schema's `enum` is not enforced at deserialization time.
    #[serde(default = "super::default_category")]
    category: String,
    /// Free-form tags; empty when omitted.
    #[serde(default)]
    tags: Vec<String>,
    /// Importance score; defaults to `super::default_importance()` (presumably
    /// 5, per the tool schema — TODO confirm). JSON numbers outside the `u8`
    /// range fail deserialization.
    #[serde(default = "super::default_importance")]
    importance: u8,
    /// Extra retrieval keywords; empty when omitted.
    #[serde(default)]
    keywords: Vec<String>,
    /// Optional one-sentence summary.
    #[serde(default)]
    summary: Option<String>,
}
183
184impl Tool for SharedMemoryWriteTool {
185    fn definition(&self) -> ToolDefinition {
186        ToolDefinition {
187            name: "shared_memory_write".into(),
188            description: "Write a memory to the shared namespace, visible to all agents. \
189                          Use this to share important findings with other agents."
190                .into(),
191            input_schema: json!({
192                "type": "object",
193                "properties": {
194                    "content": {
195                        "type": "string",
196                        "description": "Content to share"
197                    },
198                    "category": {
199                        "type": "string",
200                        "enum": ["fact", "observation", "preference", "procedure"],
201                        "description": "Category (default: fact)"
202                    },
203                    "tags": {
204                        "type": "array",
205                        "items": {"type": "string"},
206                        "description": "Tags for organization"
207                    },
208                    "importance": {
209                        "type": "integer",
210                        "minimum": 1,
211                        "maximum": 10,
212                        "description": "Importance score 1-10 (default: 5)"
213                    },
214                    "keywords": {
215                        "type": "array",
216                        "items": {"type": "string"},
217                        "description": "Keywords for improved retrieval (BM25 boost)"
218                    },
219                    "summary": {
220                        "type": "string",
221                        "description": "One-sentence summary for context"
222                    }
223                },
224                "required": ["content"]
225            }),
226        }
227    }
228
229    fn execute(
230        &self,
231        _ctx: &crate::ExecutionContext,
232        input: serde_json::Value,
233    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
234        Box::pin(async move {
235            let input: SharedWriteInput =
236                serde_json::from_value(input).map_err(|e| Error::Memory(e.to_string()))?;
237
238            let id = format!("shared:{}", Uuid::new_v4());
239            let now = Utc::now();
240            let entry = MemoryEntry {
241                id: id.clone(),
242                agent: self.agent_name.clone(),
243                content: input.content,
244                category: input.category,
245                tags: input.tags,
246                created_at: now,
247                last_accessed: now,
248                access_count: 0,
249                importance: input.importance.clamp(1, 10),
250                memory_type: crate::memory::MemoryType::default(),
251                keywords: input.keywords,
252                summary: input.summary,
253                strength: 1.0,
254                related_ids: vec![],
255                source_ids: vec![],
256                embedding: None,
257                confidentiality: crate::memory::Confidentiality::default(),
258                author_user_id: None,
259                author_tenant_id: None,
260            };
261
262            self.memory.store(&self.scope, entry).await?;
263            Ok(ToolOutput::success(format!(
264                "Shared memory stored with id: {id}"
265            )))
266        })
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryStore;

    /// Default tenant scope shared by every test.
    fn test_scope() -> TenantScope {
        TenantScope::default()
    }

    /// Fresh in-memory store plus read+write tools owned by `agent_a`.
    fn setup() -> (Arc<dyn Memory>, Vec<Arc<dyn Tool>>) {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        (store, tools)
    }

    /// Look up a tool by its definition name, panicking if it is absent.
    fn find_tool<'a>(tools: &'a [Arc<dyn Tool>], name: &str) -> &'a Arc<dyn Tool> {
        tools
            .iter()
            .find(|t| t.definition().name == name)
            .unwrap_or_else(|| panic!("tool {name} not found"))
    }

    #[test]
    fn creates_two_tools() {
        let (_store, tools) = setup();
        assert_eq!(tools.len(), 2);
        let names: Vec<String> = tools.iter().map(|t| t.definition().name).collect();
        assert!(names.contains(&"shared_memory_read".to_string()));
        assert!(names.contains(&"shared_memory_write".to_string()));
    }

    /// With `include_write = false`, only the read tool is exposed.
    #[test]
    fn omits_write_tool_when_disabled() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools = shared_memory_tools(store, "agent_a", test_scope(), false);
        let names: Vec<String> = tools.iter().map(|t| t.definition().name).collect();
        assert_eq!(names, vec!["shared_memory_read".to_string()]);
    }

    #[tokio::test]
    async fn write_and_read_shared_memory() {
        let (_store, tools) = setup();
        let write_tool = find_tool(&tools, "shared_memory_write");
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "content": "Important finding",
                    "category": "fact",
                    "tags": ["important"]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Important finding"));
        assert!(result.content.contains("agent_a")); // provenance
    }

    #[tokio::test]
    async fn read_empty_shared_memory() {
        let (_store, tools) = setup();
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert_eq!(result.content, "No shared memories found.");
    }

    #[tokio::test]
    async fn shared_memory_visible_to_all_agents() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools_a = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        let tools_b = shared_memory_tools(store.clone(), "agent_b", test_scope(), true);

        // Agent A writes
        let write_a = find_tool(&tools_a, "shared_memory_write");
        write_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "shared from A"}),
            )
            .await
            .unwrap();

        // Agent B can read it
        let read_b = find_tool(&tools_b, "shared_memory_read");
        let result = read_b
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(result.content.contains("shared from A"));
    }

    #[tokio::test]
    async fn filter_by_agent() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools_a = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        let tools_b = shared_memory_tools(store.clone(), "agent_b", test_scope(), true);

        let write_a = find_tool(&tools_a, "shared_memory_write");
        let write_b = find_tool(&tools_b, "shared_memory_write");

        write_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "data from A"}),
            )
            .await
            .unwrap();
        write_b
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "data from B"}),
            )
            .await
            .unwrap();

        // Filter by agent_a only
        let read_a = find_tool(&tools_a, "shared_memory_read");
        let result = read_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"agent": "agent_a"}),
            )
            .await
            .unwrap();
        assert!(result.content.contains("data from A"));
        assert!(!result.content.contains("data from B"));
    }

    #[tokio::test]
    async fn write_with_keywords_and_summary() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let scope = test_scope();
        let tools = shared_memory_tools(store.clone(), "agent_a", scope.clone(), true);
        let write_tool = find_tool(&tools, "shared_memory_write");

        write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "content": "Rust has zero-cost abstractions",
                    "keywords": ["rust", "performance", "abstractions"],
                    "summary": "Key Rust language feature"
                }),
            )
            .await
            .unwrap();

        // Verify keywords and summary were stored
        let entries = store
            .recall(
                &scope,
                MemoryQuery {
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(
            entries[0].keywords,
            vec!["rust", "performance", "abstractions"]
        );
        assert_eq!(
            entries[0].summary.as_deref(),
            Some("Key Rust language feature")
        );
    }

    /// Out-of-range importance values are clamped into `1..=10` before storing.
    #[tokio::test]
    async fn write_clamps_importance_into_range() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let scope = test_scope();
        let tools = shared_memory_tools(store.clone(), "agent_a", scope.clone(), true);
        let write_tool = find_tool(&tools, "shared_memory_write");

        for (content, importance) in [("too low", 0u8), ("too high", 200u8)] {
            write_tool
                .execute(
                    &crate::ExecutionContext::default(),
                    json!({"content": content, "importance": importance}),
                )
                .await
                .unwrap();
        }

        let entries = store
            .recall(
                &scope,
                MemoryQuery {
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(entries.len(), 2);
        for entry in &entries {
            match entry.content.as_str() {
                "too low" => assert_eq!(entry.importance, 1),
                "too high" => assert_eq!(entry.importance, 10),
                other => panic!("unexpected entry: {other}"),
            }
        }
    }

    /// SECURITY (F-MEM-2): `shared_memory_read` MUST cap recall confidentiality
    /// at `Internal`. Even if a `Restricted` (or `Confidential`) entry was
    /// somehow stored cross-namespace, the LLM-facing tool must not surface it.
    #[tokio::test]
    async fn shared_memory_read_filters_confidential_and_restricted() {
        use chrono::Utc;
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());

        // Stash a Confidential entry directly via the store API (bypassing
        // the tool's own cap, since the read-cap is the boundary we test).
        let mut entry = MemoryEntry {
            id: Uuid::new_v4().to_string(),
            agent: "sensor".into(),
            content: "secret-token=abc".into(),
            category: "secret".into(),
            tags: vec![],
            created_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 0,
            importance: 5,
            memory_type: crate::memory::MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: 1.0,
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: crate::memory::Confidentiality::Confidential,
            author_user_id: None,
            author_tenant_id: None,
        };
        store.store(&test_scope(), entry.clone()).await.unwrap();

        // Also stash a Restricted entry for completeness.
        entry.id = Uuid::new_v4().to_string();
        entry.confidentiality = crate::memory::Confidentiality::Restricted;
        store.store(&test_scope(), entry).await.unwrap();

        let tools = shared_memory_tools(store.clone(), "agent_a", test_scope(), false);
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(
            !result.content.contains("secret-token"),
            "shared_memory_read must filter Confidential+Restricted; got: {}",
            result.content
        );
    }
}