agentic_memory_mcp/session/
workspace.rs1use std::collections::HashMap;
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use agentic_memory::{AmemReader, MemoryGraph, QueryEngine, TextSearchParams};
8
9use crate::types::{McpError, McpResult};
10
/// How a memory context participates in a workspace.
///
/// Roles are parsed case-insensitively by [`ContextRole::from_str`] and
/// rendered as lowercase names by [`ContextRole::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextRole {
    Primary,
    Secondary,
    Reference,
    Archive,
}

impl ContextRole {
    /// Every role paired with the lowercase name that `from_str` accepts.
    const NAMED: [(&'static str, Self); 4] = [
        ("primary", Self::Primary),
        ("secondary", Self::Secondary),
        ("reference", Self::Reference),
        ("archive", Self::Archive),
    ];

    /// Parses a role name, ignoring case.
    ///
    /// Returns `None` when `s` (after lowercasing) is not one of
    /// `primary`, `secondary`, `reference` or `archive`.
    pub fn from_str(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        Self::NAMED
            .iter()
            .find(|entry| entry.0 == wanted)
            .map(|entry| entry.1)
    }

    /// Returns the canonical lowercase name of this role.
    pub fn label(&self) -> &'static str {
        match *self {
            ContextRole::Primary => "primary",
            ContextRole::Secondary => "secondary",
            ContextRole::Reference => "reference",
            ContextRole::Archive => "archive",
        }
    }
}
40
41pub struct MemoryContext {
43 pub id: String,
44 pub role: ContextRole,
45 pub path: String,
46 pub label: Option<String>,
47 pub graph: MemoryGraph,
48}
49
50pub struct MemoryWorkspace {
52 pub id: String,
53 pub name: String,
54 pub contexts: Vec<MemoryContext>,
55 pub created_at: u64,
56}
57
58#[derive(Debug)]
60pub struct CrossContextResult {
61 pub context_id: String,
62 pub context_role: ContextRole,
63 pub matches: Vec<CrossContextMatch>,
64}
65
/// One memory node that matched a text search.
#[derive(Debug)]
pub struct CrossContextMatch {
    /// Id of the matched node within its context's graph.
    pub node_id: u64,
    /// Copy of the node's content.
    pub content: String,
    /// Name of the node's event type.
    pub event_type: String,
    /// The node's stored confidence value.
    pub confidence: f32,
    /// Score the text search assigned to this match.
    pub score: f32,
}
75
76#[derive(Debug)]
78pub struct Comparison {
79 pub item: String,
80 pub found_in: Vec<String>,
81 pub missing_from: Vec<String>,
82 pub matches_per_context: Vec<(String, Vec<CrossContextMatch>)>,
83}
84
/// Presence/absence summary of a search item across a workspace's contexts.
#[derive(Debug)]
pub struct CrossReference {
    /// The search text that was looked up.
    pub item: String,
    /// Names of contexts where the item was found.
    pub present_in: Vec<String>,
    /// Names of contexts where the item was not found.
    pub absent_from: Vec<String>,
}
92
93#[derive(Default)]
95pub struct WorkspaceManager {
96 workspaces: HashMap<String, MemoryWorkspace>,
97 next_id: u64,
98}
99
100impl WorkspaceManager {
101 pub fn new() -> Self {
102 Self {
103 workspaces: HashMap::new(),
104 next_id: 1,
105 }
106 }
107
108 pub fn create(&mut self, name: &str) -> String {
110 let id = format!("ws_{}", self.next_id);
111 self.next_id += 1;
112
113 let now = SystemTime::now()
114 .duration_since(UNIX_EPOCH)
115 .unwrap_or_default()
116 .as_micros() as u64;
117
118 let workspace = MemoryWorkspace {
119 id: id.clone(),
120 name: name.to_string(),
121 contexts: Vec::new(),
122 created_at: now,
123 };
124
125 self.workspaces.insert(id.clone(), workspace);
126 id
127 }
128
129 pub fn add_context(
131 &mut self,
132 workspace_id: &str,
133 path: &str,
134 role: ContextRole,
135 label: Option<String>,
136 ) -> McpResult<String> {
137 let workspace = self
138 .workspaces
139 .get_mut(workspace_id)
140 .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
141
142 let file_path = Path::new(path);
144 if !file_path.exists() {
145 return Err(McpError::InvalidParams(format!("File not found: {path}")));
146 }
147
148 let graph = AmemReader::read_from_file(file_path)
149 .map_err(|e| McpError::AgenticMemory(format!("Failed to parse {path}: {e}")))?;
150
151 let ctx_id = format!("ctx_{}_{}", workspace.contexts.len() + 1, workspace_id);
152
153 let context = MemoryContext {
154 id: ctx_id.clone(),
155 role,
156 path: path.to_string(),
157 label: label.or_else(|| {
158 file_path
159 .file_stem()
160 .and_then(|s| s.to_str())
161 .map(|s| s.to_string())
162 }),
163 graph,
164 };
165
166 workspace.contexts.push(context);
167 Ok(ctx_id)
168 }
169
170 pub fn list(&self, workspace_id: &str) -> McpResult<&[MemoryContext]> {
172 let workspace = self
173 .workspaces
174 .get(workspace_id)
175 .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
176 Ok(&workspace.contexts)
177 }
178
179 pub fn get(&self, workspace_id: &str) -> Option<&MemoryWorkspace> {
181 self.workspaces.get(workspace_id)
182 }
183
184 pub fn query_all(
186 &self,
187 workspace_id: &str,
188 query: &str,
189 max_per_context: usize,
190 ) -> McpResult<Vec<CrossContextResult>> {
191 let workspace = self
192 .workspaces
193 .get(workspace_id)
194 .ok_or_else(|| McpError::InvalidParams(format!("Workspace not found: {workspace_id}")))?;
195
196 let engine = QueryEngine::new();
197 let mut results = Vec::new();
198
199 for ctx in &workspace.contexts {
200 let text_matches = engine
201 .text_search(
202 &ctx.graph,
203 ctx.graph.term_index.as_ref(),
204 ctx.graph.doc_lengths.as_ref(),
205 TextSearchParams {
206 query: query.to_string(),
207 max_results: max_per_context,
208 event_types: Vec::new(),
209 session_ids: Vec::new(),
210 min_score: 0.0,
211 },
212 )
213 .unwrap_or_default();
214
215 let matches: Vec<CrossContextMatch> = text_matches
216 .iter()
217 .filter_map(|m| {
218 ctx.graph.get_node(m.node_id).map(|node| CrossContextMatch {
219 node_id: node.id,
220 content: node.content.clone(),
221 event_type: node.event_type.name().to_string(),
222 confidence: node.confidence,
223 score: m.score,
224 })
225 })
226 .collect();
227
228 results.push(CrossContextResult {
229 context_id: ctx.id.clone(),
230 context_role: ctx.role,
231 matches,
232 });
233 }
234
235 Ok(results)
236 }
237
238 pub fn compare(
240 &self,
241 workspace_id: &str,
242 item: &str,
243 max_per_context: usize,
244 ) -> McpResult<Comparison> {
245 let results = self.query_all(workspace_id, item, max_per_context)?;
246 let workspace = self.workspaces.get(workspace_id).unwrap();
247
248 let mut found_in = Vec::new();
249 let mut missing_from = Vec::new();
250 let mut matches_per_context = Vec::new();
251
252 for (i, ctx_result) in results.into_iter().enumerate() {
253 let label = workspace.contexts[i]
254 .label
255 .clone()
256 .unwrap_or_else(|| ctx_result.context_id.clone());
257
258 if ctx_result.matches.is_empty() {
259 missing_from.push(label);
260 } else {
261 found_in.push(label.clone());
262 matches_per_context.push((label, ctx_result.matches));
263 }
264 }
265
266 Ok(Comparison {
267 item: item.to_string(),
268 found_in,
269 missing_from,
270 matches_per_context,
271 })
272 }
273
274 pub fn cross_reference(
276 &self,
277 workspace_id: &str,
278 item: &str,
279 ) -> McpResult<CrossReference> {
280 let comparison = self.compare(workspace_id, item, 5)?;
281 Ok(CrossReference {
282 item: comparison.item,
283 present_in: comparison.found_in,
284 absent_from: comparison.missing_from,
285 })
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role variant, for exhaustive round-trip checks.
    const ALL_ROLES: [ContextRole; 4] = [
        ContextRole::Primary,
        ContextRole::Secondary,
        ContextRole::Reference,
        ContextRole::Archive,
    ];

    #[test]
    fn test_context_role_roundtrip() {
        assert_eq!(ContextRole::from_str("primary"), Some(ContextRole::Primary));
        assert_eq!(ContextRole::from_str("SECONDARY"), Some(ContextRole::Secondary));
        assert_eq!(ContextRole::from_str("reference"), Some(ContextRole::Reference));
        assert_eq!(ContextRole::from_str("archive"), Some(ContextRole::Archive));
        assert_eq!(ContextRole::from_str("unknown"), None);

        // Parsing a role's own label must give the same role back.
        for role in ALL_ROLES {
            assert_eq!(ContextRole::from_str(role.label()), Some(role));
        }
    }

    #[test]
    fn test_workspace_create() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        assert!(id.starts_with("ws_"));
        assert!(mgr.get(&id).is_some());
        assert_eq!(mgr.get(&id).unwrap().name, "test");
    }

    #[test]
    fn test_workspace_ids_are_sequential() {
        let mut mgr = WorkspaceManager::new();
        let first = mgr.create("a");
        let second = mgr.create("b");
        assert_eq!(first, "ws_1");
        assert_eq!(second, "ws_2");
        assert!(matches!(mgr.list(&first), Ok(ctxs) if ctxs.is_empty()));
    }

    #[test]
    fn test_workspace_not_found() {
        let mgr = WorkspaceManager::new();
        assert!(mgr.list("nonexistent").is_err());
        assert!(mgr.get("nonexistent").is_none());
        assert!(mgr.query_all("nonexistent", "anything", 5).is_err());
        assert!(mgr.compare("nonexistent", "anything", 5).is_err());
        assert!(mgr.cross_reference("nonexistent", "anything").is_err());
    }

    #[test]
    fn test_add_context_unknown_workspace() {
        let mut mgr = WorkspaceManager::new();
        let result = mgr.add_context("nonexistent", "/nonexistent.amem", ContextRole::Primary, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_workspace_file_not_found() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        let result = mgr.add_context(&id, "/nonexistent.amem", ContextRole::Primary, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_workspace_queries() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("empty");

        assert!(matches!(mgr.query_all(&id, "anything", 5), Ok(results) if results.is_empty()));

        match mgr.cross_reference(&id, "anything") {
            Ok(xref) => {
                assert_eq!(xref.item, "anything");
                assert!(xref.present_in.is_empty());
                assert!(xref.absent_from.is_empty());
            }
            Err(_) => panic!("cross_reference failed on an existing workspace"),
        }
    }
}
324}