Skip to main content

agentic_memory_mcp/session/
workspace.rs

1//! Multi-context workspace manager for loading and querying multiple .amem files.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use agentic_memory::{AmemReader, MemoryGraph, QueryEngine, TextSearchParams};
8
9use crate::types::{McpError, McpResult};
10
/// Role of a context within a workspace.
///
/// Roles parse ASCII case-insensitively from their canonical lowercase names
/// (see [`ContextRole::label`]), and display as those same names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContextRole {
    Primary,
    Secondary,
    Reference,
    Archive,
}

impl ContextRole {
    /// Parse a role name, ignoring ASCII case (`"primary"`, `"SECONDARY"`, ...).
    ///
    /// Returns `None` for any other input, including the empty string.
    // Kept as an inherent method for existing callers that expect `Option`;
    // the `FromStr` impl below delegates to it.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        // Compare in place instead of allocating a lowercased copy of `s`.
        [Self::Primary, Self::Secondary, Self::Reference, Self::Archive]
            .iter()
            .copied()
            .find(|role| s.eq_ignore_ascii_case(role.label()))
    }

    /// Canonical lowercase name of the role; `from_str(role.label())` yields `role`.
    pub fn label(&self) -> &'static str {
        match self {
            Self::Primary => "primary",
            Self::Secondary => "secondary",
            Self::Reference => "reference",
            Self::Archive => "archive",
        }
    }
}

impl std::str::FromStr for ContextRole {
    type Err = String;

    /// Parse a role name, ignoring ASCII case.
    ///
    /// The error message names the rejected input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Inherent methods take precedence, so this calls the `Option`-returning parser.
        ContextRole::from_str(s).ok_or_else(|| format!("unknown context role: {s}"))
    }
}

impl std::fmt::Display for ContextRole {
    /// Writes the role's canonical lowercase label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.label())
    }
}
40
41/// A loaded memory context within a workspace.
42pub struct MemoryContext {
43    pub id: String,
44    pub role: ContextRole,
45    pub path: String,
46    pub label: Option<String>,
47    pub graph: MemoryGraph,
48}
49
50/// A multi-memory workspace.
51pub struct MemoryWorkspace {
52    pub id: String,
53    pub name: String,
54    pub contexts: Vec<MemoryContext>,
55    pub created_at: u64,
56}
57
58/// Result from querying across contexts.
59#[derive(Debug)]
60pub struct CrossContextResult {
61    pub context_id: String,
62    pub context_role: ContextRole,
63    pub matches: Vec<CrossContextMatch>,
64}
65
/// One search hit inside a single context, copied out of that context's graph.
#[derive(Debug)]
pub struct CrossContextMatch {
    /// Id of the matching node.
    pub node_id: u64,
    /// The node's content text.
    pub content: String,
    /// Name of the node's event type.
    pub event_type: String,
    /// The node's stored confidence.
    pub confidence: f32,
    /// Text-search score of this hit.
    pub score: f32,
}
75
76/// Comparison result across contexts.
77#[derive(Debug)]
78pub struct Comparison {
79    pub item: String,
80    pub found_in: Vec<String>,
81    pub missing_from: Vec<String>,
82    pub matches_per_context: Vec<(String, Vec<CrossContextMatch>)>,
83}
84
/// Presence summary produced by `WorkspaceManager::cross_reference`.
#[derive(Debug)]
pub struct CrossReference {
    /// Topic that was looked up.
    pub item: String,
    /// Names of contexts that contain the topic.
    pub present_in: Vec<String>,
    /// Names of contexts that lack it.
    pub absent_from: Vec<String>,
}
92
93/// Manages multiple memory workspaces.
94#[derive(Default)]
95pub struct WorkspaceManager {
96    workspaces: HashMap<String, MemoryWorkspace>,
97    next_id: u64,
98}
99
100impl WorkspaceManager {
101    pub fn new() -> Self {
102        Self {
103            workspaces: HashMap::new(),
104            next_id: 1,
105        }
106    }
107
108    /// Create a new workspace.
109    pub fn create(&mut self, name: &str) -> String {
110        let id = format!("ws_{}", self.next_id);
111        self.next_id += 1;
112
113        let now = SystemTime::now()
114            .duration_since(UNIX_EPOCH)
115            .unwrap_or_default()
116            .as_micros() as u64;
117
118        let workspace = MemoryWorkspace {
119            id: id.clone(),
120            name: name.to_string(),
121            contexts: Vec::new(),
122            created_at: now,
123        };
124
125        self.workspaces.insert(id.clone(), workspace);
126        id
127    }
128
129    /// Add a context to a workspace by loading an .amem file.
130    pub fn add_context(
131        &mut self,
132        workspace_id: &str,
133        path: &str,
134        role: ContextRole,
135        label: Option<String>,
136    ) -> McpResult<String> {
137        let workspace = self
138            .workspaces
139            .get_mut(workspace_id)
140            .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
141
142        // Load the .amem file
143        let file_path = Path::new(path);
144        if !file_path.exists() {
145            return Err(McpError::InvalidParams(format!("File not found: {path}")));
146        }
147
148        let graph = AmemReader::read_from_file(file_path)
149            .map_err(|e| McpError::AgenticMemory(format!("Failed to parse {path}: {e}")))?;
150
151        let ctx_id = format!("ctx_{}_{}", workspace.contexts.len() + 1, workspace_id);
152
153        let context = MemoryContext {
154            id: ctx_id.clone(),
155            role,
156            path: path.to_string(),
157            label: label.or_else(|| {
158                file_path
159                    .file_stem()
160                    .and_then(|s| s.to_str())
161                    .map(|s| s.to_string())
162            }),
163            graph,
164        };
165
166        workspace.contexts.push(context);
167        Ok(ctx_id)
168    }
169
170    /// List contexts in a workspace.
171    pub fn list(&self, workspace_id: &str) -> McpResult<&[MemoryContext]> {
172        let workspace = self
173            .workspaces
174            .get(workspace_id)
175            .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
176        Ok(&workspace.contexts)
177    }
178
179    /// Get a workspace reference.
180    pub fn get(&self, workspace_id: &str) -> Option<&MemoryWorkspace> {
181        self.workspaces.get(workspace_id)
182    }
183
184    /// Query across all contexts in a workspace.
185    pub fn query_all(
186        &self,
187        workspace_id: &str,
188        query: &str,
189        max_per_context: usize,
190    ) -> McpResult<Vec<CrossContextResult>> {
191        let workspace = self
192            .workspaces
193            .get(workspace_id)
194            .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
195
196        let engine = QueryEngine::new();
197        let mut results = Vec::new();
198
199        for ctx in &workspace.contexts {
200            let text_matches = engine
201                .text_search(
202                    &ctx.graph,
203                    ctx.graph.term_index.as_ref(),
204                    ctx.graph.doc_lengths.as_ref(),
205                    TextSearchParams {
206                        query: query.to_string(),
207                        max_results: max_per_context,
208                        event_types: Vec::new(),
209                        session_ids: Vec::new(),
210                        min_score: 0.0,
211                    },
212                )
213                .unwrap_or_default();
214
215            let matches: Vec<CrossContextMatch> = text_matches
216                .iter()
217                .filter_map(|m| {
218                    ctx.graph.get_node(m.node_id).map(|node| CrossContextMatch {
219                        node_id: node.id,
220                        content: node.content.clone(),
221                        event_type: node.event_type.name().to_string(),
222                        confidence: node.confidence,
223                        score: m.score,
224                    })
225                })
226                .collect();
227
228            results.push(CrossContextResult {
229                context_id: ctx.id.clone(),
230                context_role: ctx.role,
231                matches,
232            });
233        }
234
235        Ok(results)
236    }
237
238    /// Compare a topic across all contexts.
239    pub fn compare(
240        &self,
241        workspace_id: &str,
242        item: &str,
243        max_per_context: usize,
244    ) -> McpResult<Comparison> {
245        let results = self.query_all(workspace_id, item, max_per_context)?;
246        let workspace = self.workspaces.get(workspace_id).unwrap();
247
248        let mut found_in = Vec::new();
249        let mut missing_from = Vec::new();
250        let mut matches_per_context = Vec::new();
251
252        for (i, ctx_result) in results.into_iter().enumerate() {
253            let label = workspace.contexts[i]
254                .label
255                .clone()
256                .unwrap_or_else(|| ctx_result.context_id.clone());
257
258            if ctx_result.matches.is_empty() {
259                missing_from.push(label);
260            } else {
261                found_in.push(label.clone());
262                matches_per_context.push((label, ctx_result.matches));
263            }
264        }
265
266        Ok(Comparison {
267            item: item.to_string(),
268            found_in,
269            missing_from,
270            matches_per_context,
271        })
272    }
273
274    /// Cross-reference: find which contexts have/lack a topic.
275    pub fn cross_reference(
276        &self,
277        workspace_id: &str,
278        item: &str,
279    ) -> McpResult<CrossReference> {
280        let comparison = self.compare(workspace_id, item, 5)?;
281        Ok(CrossReference {
282            item: comparison.item,
283            present_in: comparison.found_in,
284            absent_from: comparison.missing_from,
285        })
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_ROLES: [ContextRole; 4] = [
        ContextRole::Primary,
        ContextRole::Secondary,
        ContextRole::Reference,
        ContextRole::Archive,
    ];

    #[test]
    fn test_context_role_roundtrip() {
        assert_eq!(ContextRole::from_str("primary"), Some(ContextRole::Primary));
        assert_eq!(ContextRole::from_str("SECONDARY"), Some(ContextRole::Secondary));
        assert_eq!(ContextRole::from_str("reference"), Some(ContextRole::Reference));
        assert_eq!(ContextRole::from_str("archive"), Some(ContextRole::Archive));
        assert_eq!(ContextRole::from_str("unknown"), None);
        assert_eq!(ContextRole::from_str(""), None);

        // Parsing a role's label must give back the same role.
        for role in ALL_ROLES.iter().copied() {
            assert_eq!(ContextRole::from_str(role.label()), Some(role));
        }
    }

    #[test]
    fn test_workspace_create() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        assert!(id.starts_with("ws_"));

        let ws = mgr.get(&id).expect("workspace was just created");
        assert_eq!(ws.name, "test");
        assert_eq!(ws.id, id);
        assert!(ws.contexts.is_empty());

        let other = mgr.create("test");
        assert_ne!(id, other, "each workspace gets a distinct id");
    }

    #[test]
    fn test_workspace_not_found() {
        let mut mgr = WorkspaceManager::new();
        assert!(mgr.list("nonexistent").is_err());
        assert!(mgr.get("nonexistent").is_none());
        assert!(mgr.query_all("nonexistent", "anything", 5).is_err());
        assert!(mgr.compare("nonexistent", "anything", 5).is_err());
        assert!(mgr.cross_reference("nonexistent", "anything").is_err());
        assert!(mgr
            .add_context("nonexistent", "/nonexistent.amem", ContextRole::Primary, None)
            .is_err());
    }

    #[test]
    fn test_workspace_file_not_found() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        let result = mgr.add_context(&id, "/nonexistent.amem", ContextRole::Primary, None);
        assert!(result.is_err());
        // A failed load must not leave a partial context behind.
        assert!(matches!(mgr.list(&id), Ok(ctxs) if ctxs.is_empty()));
    }

    #[test]
    fn test_empty_workspace_queries() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("empty");

        assert!(matches!(mgr.query_all(&id, "topic", 5), Ok(r) if r.is_empty()));

        match mgr.compare(&id, "topic", 5) {
            Ok(comparison) => {
                assert_eq!(comparison.item, "topic");
                assert!(comparison.found_in.is_empty());
                assert!(comparison.missing_from.is_empty());
                assert!(comparison.matches_per_context.is_empty());
            }
            Err(_) => panic!("compare on an existing workspace should succeed"),
        }

        match mgr.cross_reference(&id, "topic") {
            Ok(xref) => {
                assert_eq!(xref.item, "topic");
                assert!(xref.present_in.is_empty());
                assert!(xref.absent_from.is_empty());
            }
            Err(_) => panic!("cross_reference on an existing workspace should succeed"),
        }
    }
}
320        let id = mgr.create("test");
321        let result = mgr.add_context(&id, "/nonexistent.amem", ContextRole::Primary, None);
322        assert!(result.is_err());
323    }
324}