1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use chrono::Utc;
8use serde::Deserialize;
9use serde_json::json;
10use uuid::Uuid;
11
12use crate::auth::TenantScope;
13use crate::error::Error;
14use crate::llm::types::ToolDefinition;
15use crate::tool::{Tool, ToolOutput};
16
17use super::{Memory, MemoryEntry, MemoryQuery};
18
19pub fn shared_memory_tools(
25 memory: Arc<dyn Memory>,
26 agent_name: &str,
27 scope: TenantScope,
28 include_write: bool,
29) -> Vec<Arc<dyn Tool>> {
30 let mut tools: Vec<Arc<dyn Tool>> = vec![Arc::new(SharedMemoryReadTool {
31 memory: memory.clone(),
32 scope: scope.clone(),
33 })];
34 if include_write {
35 tools.push(Arc::new(SharedMemoryWriteTool {
36 memory,
37 agent_name: agent_name.into(),
38 scope,
39 }));
40 }
41 tools
42}
43
44struct SharedMemoryReadTool {
47 memory: Arc<dyn Memory>,
48 scope: TenantScope,
49}
50
/// Arguments of a `shared_memory_read` call, deserialized from the tool-call JSON.
///
/// Every field has a serde default, so an empty object `{}` is valid input.
#[derive(Deserialize)]
struct SharedReadInput {
    /// Free-text search string, forwarded as `MemoryQuery::text`.
    #[serde(default)]
    query: Option<String>,
    /// Only return entries whose author agent has this name; `None` = all agents.
    #[serde(default)]
    agent: Option<String>,
    /// Only return entries in this category.
    #[serde(default)]
    category: Option<String>,
    /// Tag filter forwarded as `MemoryQuery::tags`; empty when omitted.
    /// NOTE(review): any-vs-all tag matching is decided by the store — confirm.
    #[serde(default)]
    tags: Vec<String>,
    /// Maximum number of results. The tool schema advertises a default of 10;
    /// NOTE(review): confirm `super::default_recall_limit` still returns 10.
    #[serde(default = "super::default_recall_limit")]
    limit: usize,
}
64
65impl Tool for SharedMemoryReadTool {
66 fn definition(&self) -> ToolDefinition {
67 ToolDefinition {
68 name: "shared_memory_read".into(),
69 description: "Read memories from any agent's namespace. Use this to access \
70 knowledge that other agents have stored."
71 .into(),
72 input_schema: json!({
73 "type": "object",
74 "properties": {
75 "query": {
76 "type": "string",
77 "description": "Text to search for"
78 },
79 "agent": {
80 "type": "string",
81 "description": "Filter by agent name (omit for all agents)"
82 },
83 "category": {
84 "type": "string",
85 "description": "Filter by category"
86 },
87 "tags": {
88 "type": "array",
89 "items": {"type": "string"},
90 "description": "Filter by tags"
91 },
92 "limit": {
93 "type": "integer",
94 "description": "Max results (default: 10)"
95 }
96 }
97 }),
98 }
99 }
100
101 fn execute(
102 &self,
103 _ctx: &crate::ExecutionContext,
104 input: serde_json::Value,
105 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
106 Box::pin(async move {
107 let input: SharedReadInput =
108 serde_json::from_value(input).map_err(|e| Error::Memory(e.to_string()))?;
109
110 let results = self
117 .memory
118 .recall(
119 &self.scope,
120 MemoryQuery {
121 text: input.query,
122 category: input.category,
123 tags: input.tags,
124 agent: input.agent, limit: input.limit,
126 max_confidentiality: Some(crate::memory::Confidentiality::Internal),
127 ..Default::default()
128 },
129 )
130 .await?;
131
132 if results.is_empty() {
133 return Ok(ToolOutput::success("No shared memories found."));
134 }
135
136 let formatted: Vec<String> = results
137 .iter()
138 .map(|e| {
139 let mt = match e.memory_type {
140 crate::memory::MemoryType::Episodic => "episodic",
141 crate::memory::MemoryType::Semantic => "semantic",
142 crate::memory::MemoryType::Reflection => "reflection",
143 };
144 format!(
145 "- [{}] @{} ({}, {}, importance:{}, strength:{:.2}) {}",
146 e.id, e.agent, e.category, mt, e.importance, e.strength, e.content,
147 )
148 })
149 .collect();
150
151 let count = results.len();
152 let noun = if count == 1 { "memory" } else { "memories" };
153 Ok(ToolOutput::success(format!(
154 "Found {count} shared {noun}:\n{}",
155 formatted.join("\n")
156 )))
157 })
158 }
159}
160
161struct SharedMemoryWriteTool {
164 memory: Arc<dyn Memory>,
165 agent_name: String,
166 scope: TenantScope,
167}
168
/// Arguments of a `shared_memory_write` call, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct SharedWriteInput {
    /// Text to store. This is the only field without a serde default, so it must be present.
    content: String,
    /// Category label. Serde accepts any string here; the schema's enum is not
    /// enforced during deserialization. Defaults via `super::default_category`
    /// (the schema advertises "fact" — NOTE(review): confirm).
    #[serde(default = "super::default_category")]
    category: String,
    /// Free-form tags; empty when omitted.
    #[serde(default)]
    tags: Vec<String>,
    /// Importance score. Values that do not fit in a `u8` fail deserialization;
    /// the write tool clamps the rest to 1..=10. Defaults via
    /// `super::default_importance` (schema advertises 5 — NOTE(review): confirm).
    #[serde(default = "super::default_importance")]
    importance: u8,
    /// Extra retrieval keywords; empty when omitted.
    #[serde(default)]
    keywords: Vec<String>,
    /// Optional one-sentence summary.
    #[serde(default)]
    summary: Option<String>,
}
183
184impl Tool for SharedMemoryWriteTool {
185 fn definition(&self) -> ToolDefinition {
186 ToolDefinition {
187 name: "shared_memory_write".into(),
188 description: "Write a memory to the shared namespace, visible to all agents. \
189 Use this to share important findings with other agents."
190 .into(),
191 input_schema: json!({
192 "type": "object",
193 "properties": {
194 "content": {
195 "type": "string",
196 "description": "Content to share"
197 },
198 "category": {
199 "type": "string",
200 "enum": ["fact", "observation", "preference", "procedure"],
201 "description": "Category (default: fact)"
202 },
203 "tags": {
204 "type": "array",
205 "items": {"type": "string"},
206 "description": "Tags for organization"
207 },
208 "importance": {
209 "type": "integer",
210 "minimum": 1,
211 "maximum": 10,
212 "description": "Importance score 1-10 (default: 5)"
213 },
214 "keywords": {
215 "type": "array",
216 "items": {"type": "string"},
217 "description": "Keywords for improved retrieval (BM25 boost)"
218 },
219 "summary": {
220 "type": "string",
221 "description": "One-sentence summary for context"
222 }
223 },
224 "required": ["content"]
225 }),
226 }
227 }
228
229 fn execute(
230 &self,
231 _ctx: &crate::ExecutionContext,
232 input: serde_json::Value,
233 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
234 Box::pin(async move {
235 let input: SharedWriteInput =
236 serde_json::from_value(input).map_err(|e| Error::Memory(e.to_string()))?;
237
238 let id = format!("shared:{}", Uuid::new_v4());
239 let now = Utc::now();
240 let entry = MemoryEntry {
241 id: id.clone(),
242 agent: self.agent_name.clone(),
243 content: input.content,
244 category: input.category,
245 tags: input.tags,
246 created_at: now,
247 last_accessed: now,
248 access_count: 0,
249 importance: input.importance.clamp(1, 10),
250 memory_type: crate::memory::MemoryType::default(),
251 keywords: input.keywords,
252 summary: input.summary,
253 strength: 1.0,
254 related_ids: vec![],
255 source_ids: vec![],
256 embedding: None,
257 confidentiality: crate::memory::Confidentiality::default(),
258 author_user_id: None,
259 author_tenant_id: None,
260 };
261
262 self.memory.store(&self.scope, entry).await?;
263 Ok(ToolOutput::success(format!(
264 "Shared memory stored with id: {id}"
265 )))
266 })
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryStore;

    fn test_scope() -> TenantScope {
        TenantScope::default()
    }

    /// Fresh in-memory store plus the read and write tools for `agent_a`.
    fn setup() -> (Arc<dyn Memory>, Vec<Arc<dyn Tool>>) {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        (store, tools)
    }

    /// Looks a tool up by its definition name, panicking if it is missing.
    fn find_tool<'a>(tools: &'a [Arc<dyn Tool>], name: &str) -> &'a Arc<dyn Tool> {
        tools
            .iter()
            .find(|t| t.definition().name == name)
            .unwrap_or_else(|| panic!("tool {name} not found"))
    }

    #[test]
    fn creates_two_tools() {
        let (_store, tools) = setup();
        assert_eq!(tools.len(), 2);
        let names: Vec<String> = tools.iter().map(|t| t.definition().name).collect();
        assert!(names.contains(&"shared_memory_read".to_string()));
        assert!(names.contains(&"shared_memory_write".to_string()));
    }

    #[test]
    fn read_only_when_write_excluded() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools = shared_memory_tools(store, "agent_a", test_scope(), false);
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "shared_memory_read");
    }

    #[tokio::test]
    async fn write_and_read_shared_memory() {
        let (_store, tools) = setup();
        let write_tool = find_tool(&tools, "shared_memory_write");
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "content": "Important finding",
                    "category": "fact",
                    "tags": ["important"]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Important finding"));
        assert!(result.content.contains("agent_a"));
    }

    #[tokio::test]
    async fn read_empty_shared_memory() {
        let (_store, tools) = setup();
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert_eq!(result.content, "No shared memories found.");
    }

    #[tokio::test]
    async fn shared_memory_visible_to_all_agents() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools_a = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        let tools_b = shared_memory_tools(store.clone(), "agent_b", test_scope(), true);

        let write_a = find_tool(&tools_a, "shared_memory_write");
        write_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "shared from A"}),
            )
            .await
            .unwrap();

        let read_b = find_tool(&tools_b, "shared_memory_read");
        let result = read_b
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(result.content.contains("shared from A"));
    }

    #[tokio::test]
    async fn filter_by_agent() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let tools_a = shared_memory_tools(store.clone(), "agent_a", test_scope(), true);
        let tools_b = shared_memory_tools(store.clone(), "agent_b", test_scope(), true);

        let write_a = find_tool(&tools_a, "shared_memory_write");
        let write_b = find_tool(&tools_b, "shared_memory_write");

        write_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "data from A"}),
            )
            .await
            .unwrap();
        write_b
            .execute(
                &crate::ExecutionContext::default(),
                json!({"content": "data from B"}),
            )
            .await
            .unwrap();

        let read_a = find_tool(&tools_a, "shared_memory_read");
        let result = read_a
            .execute(
                &crate::ExecutionContext::default(),
                json!({"agent": "agent_a"}),
            )
            .await
            .unwrap();
        assert!(result.content.contains("data from A"));
        assert!(!result.content.contains("data from B"));
    }

    #[tokio::test]
    async fn write_with_keywords_and_summary() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let scope = test_scope();
        let tools = shared_memory_tools(store.clone(), "agent_a", scope.clone(), true);
        let write_tool = find_tool(&tools, "shared_memory_write");

        write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "content": "Rust has zero-cost abstractions",
                    "keywords": ["rust", "performance", "abstractions"],
                    "summary": "Key Rust language feature"
                }),
            )
            .await
            .unwrap();

        let entries = store
            .recall(
                &scope,
                MemoryQuery {
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(
            entries[0].keywords,
            vec!["rust", "performance", "abstractions"]
        );
        assert_eq!(
            entries[0].summary.as_deref(),
            Some("Key Rust language feature")
        );
    }

    #[tokio::test]
    async fn shared_memory_read_filters_confidential_and_restricted() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());

        let mut entry = MemoryEntry {
            id: Uuid::new_v4().to_string(),
            agent: "sensor".into(),
            content: "secret-token=abc".into(),
            category: "secret".into(),
            tags: vec![],
            created_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 0,
            importance: 5,
            memory_type: crate::memory::MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: 1.0,
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: crate::memory::Confidentiality::Confidential,
            author_user_id: None,
            author_tenant_id: None,
        };
        store.store(&test_scope(), entry.clone()).await.unwrap();

        entry.id = Uuid::new_v4().to_string();
        entry.confidentiality = crate::memory::Confidentiality::Restricted;
        store.store(&test_scope(), entry.clone()).await.unwrap();

        // Positive control: an Internal entry must still come back. Without it,
        // the negative assertion below would pass even if the read returned nothing.
        entry.id = Uuid::new_v4().to_string();
        entry.content = "internal-note".into();
        entry.confidentiality = crate::memory::Confidentiality::Internal;
        store.store(&test_scope(), entry).await.unwrap();

        let tools = shared_memory_tools(store.clone(), "agent_a", test_scope(), false);
        let read_tool = find_tool(&tools, "shared_memory_read");

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(
            result.content.contains("internal-note"),
            "shared_memory_read must still return Internal entries; got: {}",
            result.content
        );
        assert!(
            !result.content.contains("secret-token"),
            "shared_memory_read must filter Confidential+Restricted; got: {}",
            result.content
        );
    }
}