Skip to main content

agentic_memory_mcp/session/
workspace.rs

1//! Multi-context workspace manager for loading and querying multiple .amem files.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use agentic_memory::{AmemReader, MemoryGraph, QueryEngine, TextSearchParams};
8
9use crate::types::{McpError, McpResult};
10
/// The part a memory context plays inside a workspace.
///
/// Names are parsed case-insensitively by [`ContextRole::parse_str`];
/// [`ContextRole::label`] gives the canonical lowercase spelling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ContextRole {
    Primary,
    Secondary,
    Reference,
    Archive,
}

impl ContextRole {
    /// Every role paired with its canonical lowercase name.
    const NAMED: [(&'static str, ContextRole); 4] = [
        ("primary", ContextRole::Primary),
        ("secondary", ContextRole::Secondary),
        ("reference", ContextRole::Reference),
        ("archive", ContextRole::Archive),
    ];

    /// Parse a role name, ignoring letter case.
    ///
    /// Returns `None` when `s` is not one of `primary`, `secondary`,
    /// `reference` or `archive` (surrounding whitespace is not trimmed).
    pub fn parse_str(s: &str) -> Option<Self> {
        Self::NAMED
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, role)| role)
    }

    /// Canonical lowercase name of this role.
    pub fn label(&self) -> &'static str {
        match *self {
            ContextRole::Primary => "primary",
            ContextRole::Secondary => "secondary",
            ContextRole::Reference => "reference",
            ContextRole::Archive => "archive",
        }
    }
}
40
41/// A loaded memory context within a workspace.
42pub struct MemoryContext {
43    pub id: String,
44    pub role: ContextRole,
45    pub path: String,
46    pub label: Option<String>,
47    pub graph: MemoryGraph,
48}
49
50/// A multi-memory workspace.
51pub struct MemoryWorkspace {
52    pub id: String,
53    pub name: String,
54    pub contexts: Vec<MemoryContext>,
55    pub created_at: u64,
56}
57
58/// Result from querying across contexts.
59#[derive(Debug)]
60pub struct CrossContextResult {
61    pub context_id: String,
62    pub context_role: ContextRole,
63    pub matches: Vec<CrossContextMatch>,
64}
65
/// One node that matched a cross-context text search.
#[derive(Debug)]
pub struct CrossContextMatch {
    /// Id of the matching node within its context's graph.
    pub node_id: u64,
    /// The node's content text.
    pub content: String,
    /// Name of the node's event type.
    pub event_type: String,
    /// The confidence value stored on the node.
    pub confidence: f32,
    /// Relevance score assigned by the text search.
    pub score: f32,
}
75
76/// Comparison result across contexts.
77#[derive(Debug)]
78pub struct Comparison {
79    pub item: String,
80    pub found_in: Vec<String>,
81    pub missing_from: Vec<String>,
82    pub matches_per_context: Vec<(String, Vec<CrossContextMatch>)>,
83}
84
/// Presence/absence summary of a topic across a workspace's contexts.
#[derive(Debug)]
pub struct CrossReference {
    /// The topic that was searched for.
    pub item: String,
    /// Labels of contexts that contain the topic.
    pub present_in: Vec<String>,
    /// Labels of contexts that do not contain the topic.
    pub absent_from: Vec<String>,
}
92
93/// Manages multiple memory workspaces.
94#[derive(Default)]
95pub struct WorkspaceManager {
96    workspaces: HashMap<String, MemoryWorkspace>,
97    next_id: u64,
98}
99
100impl WorkspaceManager {
101    pub fn new() -> Self {
102        Self {
103            workspaces: HashMap::new(),
104            next_id: 1,
105        }
106    }
107
108    /// Create a new workspace.
109    pub fn create(&mut self, name: &str) -> String {
110        let id = format!("ws_{}", self.next_id);
111        self.next_id += 1;
112
113        let now = SystemTime::now()
114            .duration_since(UNIX_EPOCH)
115            .unwrap_or_default()
116            .as_micros() as u64;
117
118        let workspace = MemoryWorkspace {
119            id: id.clone(),
120            name: name.to_string(),
121            contexts: Vec::new(),
122            created_at: now,
123        };
124
125        self.workspaces.insert(id.clone(), workspace);
126        id
127    }
128
129    /// Add a context to a workspace by loading an .amem file.
130    pub fn add_context(
131        &mut self,
132        workspace_id: &str,
133        path: &str,
134        role: ContextRole,
135        label: Option<String>,
136    ) -> McpResult<String> {
137        let workspace = self.workspaces.get_mut(workspace_id).ok_or_else(|| {
138            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
139        })?;
140
141        // Load the .amem file
142        let file_path = Path::new(path);
143        if !file_path.exists() {
144            return Err(McpError::InvalidParams(format!("File not found: {path}")));
145        }
146
147        let graph = AmemReader::read_from_file(file_path)
148            .map_err(|e| McpError::AgenticMemory(format!("Failed to parse {path}: {e}")))?;
149
150        let ctx_id = format!("ctx_{}_{}", workspace.contexts.len() + 1, workspace_id);
151
152        let context = MemoryContext {
153            id: ctx_id.clone(),
154            role,
155            path: path.to_string(),
156            label: label.or_else(|| {
157                file_path
158                    .file_stem()
159                    .and_then(|s| s.to_str())
160                    .map(|s| s.to_string())
161            }),
162            graph,
163        };
164
165        workspace.contexts.push(context);
166        Ok(ctx_id)
167    }
168
169    /// List contexts in a workspace.
170    pub fn list(&self, workspace_id: &str) -> McpResult<&[MemoryContext]> {
171        let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
172            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
173        })?;
174        Ok(&workspace.contexts)
175    }
176
177    /// Get a workspace reference.
178    pub fn get(&self, workspace_id: &str) -> Option<&MemoryWorkspace> {
179        self.workspaces.get(workspace_id)
180    }
181
182    /// Query across all contexts in a workspace.
183    pub fn query_all(
184        &self,
185        workspace_id: &str,
186        query: &str,
187        max_per_context: usize,
188    ) -> McpResult<Vec<CrossContextResult>> {
189        let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
190            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
191        })?;
192
193        let engine = QueryEngine::new();
194        let mut results = Vec::new();
195
196        for ctx in &workspace.contexts {
197            let text_matches = engine
198                .text_search(
199                    &ctx.graph,
200                    ctx.graph.term_index.as_ref(),
201                    ctx.graph.doc_lengths.as_ref(),
202                    TextSearchParams {
203                        query: query.to_string(),
204                        max_results: max_per_context,
205                        event_types: Vec::new(),
206                        session_ids: Vec::new(),
207                        min_score: 0.0,
208                    },
209                )
210                .unwrap_or_default();
211
212            let matches: Vec<CrossContextMatch> = text_matches
213                .iter()
214                .filter_map(|m| {
215                    ctx.graph.get_node(m.node_id).map(|node| CrossContextMatch {
216                        node_id: node.id,
217                        content: node.content.clone(),
218                        event_type: node.event_type.name().to_string(),
219                        confidence: node.confidence,
220                        score: m.score,
221                    })
222                })
223                .collect();
224
225            results.push(CrossContextResult {
226                context_id: ctx.id.clone(),
227                context_role: ctx.role,
228                matches,
229            });
230        }
231
232        Ok(results)
233    }
234
235    /// Compare a topic across all contexts.
236    pub fn compare(
237        &self,
238        workspace_id: &str,
239        item: &str,
240        max_per_context: usize,
241    ) -> McpResult<Comparison> {
242        let results = self.query_all(workspace_id, item, max_per_context)?;
243        let workspace = self.workspaces.get(workspace_id).unwrap();
244
245        let mut found_in = Vec::new();
246        let mut missing_from = Vec::new();
247        let mut matches_per_context = Vec::new();
248
249        for (i, ctx_result) in results.into_iter().enumerate() {
250            let label = workspace.contexts[i]
251                .label
252                .clone()
253                .unwrap_or_else(|| ctx_result.context_id.clone());
254
255            if ctx_result.matches.is_empty() {
256                missing_from.push(label);
257            } else {
258                found_in.push(label.clone());
259                matches_per_context.push((label, ctx_result.matches));
260            }
261        }
262
263        Ok(Comparison {
264            item: item.to_string(),
265            found_in,
266            missing_from,
267            matches_per_context,
268        })
269    }
270
271    /// Cross-reference: find which contexts have/lack a topic.
272    pub fn cross_reference(&self, workspace_id: &str, item: &str) -> McpResult<CrossReference> {
273        let comparison = self.compare(workspace_id, item, 5)?;
274        Ok(CrossReference {
275            item: comparison.item,
276            present_in: comparison.found_in,
277            absent_from: comparison.missing_from,
278        })
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role, used for round-trip checks.
    const ALL_ROLES: [ContextRole; 4] = [
        ContextRole::Primary,
        ContextRole::Secondary,
        ContextRole::Reference,
        ContextRole::Archive,
    ];

    #[test]
    fn test_context_role_roundtrip() {
        assert_eq!(
            ContextRole::parse_str("primary"),
            Some(ContextRole::Primary)
        );
        assert_eq!(
            ContextRole::parse_str("SECONDARY"),
            Some(ContextRole::Secondary)
        );
        assert_eq!(
            ContextRole::parse_str("reference"),
            Some(ContextRole::Reference)
        );
        assert_eq!(
            ContextRole::parse_str("archive"),
            Some(ContextRole::Archive)
        );
        assert_eq!(ContextRole::parse_str("unknown"), None);
        assert_eq!(ContextRole::parse_str(""), None);
    }

    #[test]
    fn test_context_role_label_roundtrip() {
        for role in ALL_ROLES {
            assert_eq!(ContextRole::parse_str(role.label()), Some(role));
        }
    }

    #[test]
    fn test_workspace_create() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        assert!(id.starts_with("ws_"));
        assert!(mgr.get(&id).is_some());
        assert_eq!(mgr.get(&id).unwrap().name, "test");
    }

    #[test]
    fn test_workspace_ids_are_unique() {
        let mut mgr = WorkspaceManager::new();
        let a = mgr.create("a");
        let b = mgr.create("b");
        assert_ne!(a, b);
        assert_eq!(mgr.get(&a).unwrap().name, "a");
        assert_eq!(mgr.get(&b).unwrap().name, "b");
    }

    #[test]
    fn test_workspace_not_found() {
        let mut mgr = WorkspaceManager::new();
        assert!(mgr.list("nonexistent").is_err());
        assert!(mgr.get("nonexistent").is_none());
        assert!(mgr.query_all("nonexistent", "q", 5).is_err());
        assert!(mgr.compare("nonexistent", "q", 5).is_err());
        assert!(mgr.cross_reference("nonexistent", "q").is_err());
        assert!(mgr
            .add_context("nonexistent", "/nonexistent.amem", ContextRole::Primary, None)
            .is_err());
    }

    #[test]
    fn test_empty_workspace_queries() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("empty");

        assert!(matches!(mgr.list(&id), Ok(contexts) if contexts.is_empty()));
        assert!(matches!(mgr.query_all(&id, "anything", 5), Ok(results) if results.is_empty()));

        match mgr.compare(&id, "anything", 5) {
            Ok(cmp) => {
                assert_eq!(cmp.item, "anything");
                assert!(cmp.found_in.is_empty());
                assert!(cmp.missing_from.is_empty());
                assert!(cmp.matches_per_context.is_empty());
            }
            Err(_) => panic!("compare on an existing workspace should succeed"),
        }

        match mgr.cross_reference(&id, "anything") {
            Ok(xref) => {
                assert_eq!(xref.item, "anything");
                assert!(xref.present_in.is_empty());
                assert!(xref.absent_from.is_empty());
            }
            Err(_) => panic!("cross_reference on an existing workspace should succeed"),
        }
    }

    #[test]
    fn test_workspace_file_not_found() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        let result = mgr.add_context(&id, "/nonexistent.amem", ContextRole::Primary, None);
        assert!(result.is_err());
        // A failed load must not leave a partial context behind.
        assert!(matches!(mgr.list(&id), Ok(contexts) if contexts.is_empty()));
    }
}