agentic_memory_mcp/session/
workspace.rs1use std::collections::HashMap;
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use agentic_memory::{AmemReader, MemoryGraph, QueryEngine, TextSearchParams};
8
9use crate::types::{McpError, McpResult};
10
/// Role tag attached to a memory context inside a workspace.
///
/// Within this file the role is only stored on the context and echoed back in
/// query results; it does not influence how a context is searched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextRole {
    /// Textual form: `"primary"`.
    Primary,
    /// Textual form: `"secondary"`.
    Secondary,
    /// Textual form: `"reference"`.
    Reference,
    /// Textual form: `"archive"`.
    Archive,
}
19
20impl ContextRole {
21 pub fn parse_str(s: &str) -> Option<Self> {
22 match s.to_lowercase().as_str() {
23 "primary" => Some(Self::Primary),
24 "secondary" => Some(Self::Secondary),
25 "reference" => Some(Self::Reference),
26 "archive" => Some(Self::Archive),
27 _ => None,
28 }
29 }
30
31 pub fn label(&self) -> &'static str {
32 match self {
33 Self::Primary => "primary",
34 Self::Secondary => "secondary",
35 Self::Reference => "reference",
36 Self::Archive => "archive",
37 }
38 }
39}
40
41pub struct MemoryContext {
43 pub id: String,
44 pub role: ContextRole,
45 pub path: String,
46 pub label: Option<String>,
47 pub graph: MemoryGraph,
48}
49
50pub struct MemoryWorkspace {
52 pub id: String,
53 pub name: String,
54 pub contexts: Vec<MemoryContext>,
55 pub created_at: u64,
56}
57
58#[derive(Debug)]
60pub struct CrossContextResult {
61 pub context_id: String,
62 pub context_role: ContextRole,
63 pub matches: Vec<CrossContextMatch>,
64}
65
/// A single node matched by a workspace-wide text search.
#[derive(Debug)]
pub struct CrossContextMatch {
    /// Identifier of the matched node within its graph.
    pub node_id: u64,
    /// Content of the matched node.
    pub content: String,
    /// Name of the matched node's event type.
    pub event_type: String,
    /// Confidence value stored on the matched node.
    pub confidence: f32,
    /// Score the text search assigned to this match.
    pub score: f32,
}
75
76#[derive(Debug)]
78pub struct Comparison {
79 pub item: String,
80 pub found_in: Vec<String>,
81 pub missing_from: Vec<String>,
82 pub matches_per_context: Vec<(String, Vec<CrossContextMatch>)>,
83}
84
/// Presence/absence summary of one item across the contexts of a workspace.
#[derive(Debug)]
pub struct CrossReference {
    /// The text that was searched for.
    pub item: String,
    /// Labels of the contexts in which the item was found.
    pub present_in: Vec<String>,
    /// Labels of the contexts in which the item was not found.
    pub absent_from: Vec<String>,
}
92
93#[derive(Default)]
95pub struct WorkspaceManager {
96 workspaces: HashMap<String, MemoryWorkspace>,
97 next_id: u64,
98}
99
100impl WorkspaceManager {
101 pub fn new() -> Self {
102 Self {
103 workspaces: HashMap::new(),
104 next_id: 1,
105 }
106 }
107
108 pub fn create(&mut self, name: &str) -> String {
110 let id = format!("ws_{}", self.next_id);
111 self.next_id += 1;
112
113 let now = SystemTime::now()
114 .duration_since(UNIX_EPOCH)
115 .unwrap_or_default()
116 .as_micros() as u64;
117
118 let workspace = MemoryWorkspace {
119 id: id.clone(),
120 name: name.to_string(),
121 contexts: Vec::new(),
122 created_at: now,
123 };
124
125 self.workspaces.insert(id.clone(), workspace);
126 id
127 }
128
129 pub fn add_context(
131 &mut self,
132 workspace_id: &str,
133 path: &str,
134 role: ContextRole,
135 label: Option<String>,
136 ) -> McpResult<String> {
137 let workspace = self.workspaces.get_mut(workspace_id).ok_or_else(|| {
138 McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
139 })?;
140
141 let file_path = Path::new(path);
143 if !file_path.exists() {
144 return Err(McpError::InvalidParams(format!("File not found: {path}")));
145 }
146
147 let graph = AmemReader::read_from_file(file_path)
148 .map_err(|e| McpError::AgenticMemory(format!("Failed to parse {path}: {e}")))?;
149
150 let ctx_id = format!("ctx_{}_{}", workspace.contexts.len() + 1, workspace_id);
151
152 let context = MemoryContext {
153 id: ctx_id.clone(),
154 role,
155 path: path.to_string(),
156 label: label.or_else(|| {
157 file_path
158 .file_stem()
159 .and_then(|s| s.to_str())
160 .map(|s| s.to_string())
161 }),
162 graph,
163 };
164
165 workspace.contexts.push(context);
166 Ok(ctx_id)
167 }
168
169 pub fn list(&self, workspace_id: &str) -> McpResult<&[MemoryContext]> {
171 let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
172 McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
173 })?;
174 Ok(&workspace.contexts)
175 }
176
177 pub fn get(&self, workspace_id: &str) -> Option<&MemoryWorkspace> {
179 self.workspaces.get(workspace_id)
180 }
181
182 pub fn query_all(
184 &self,
185 workspace_id: &str,
186 query: &str,
187 max_per_context: usize,
188 ) -> McpResult<Vec<CrossContextResult>> {
189 let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
190 McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
191 })?;
192
193 let engine = QueryEngine::new();
194 let mut results = Vec::new();
195
196 for ctx in &workspace.contexts {
197 let text_matches = engine
198 .text_search(
199 &ctx.graph,
200 ctx.graph.term_index.as_ref(),
201 ctx.graph.doc_lengths.as_ref(),
202 TextSearchParams {
203 query: query.to_string(),
204 max_results: max_per_context,
205 event_types: Vec::new(),
206 session_ids: Vec::new(),
207 min_score: 0.0,
208 },
209 )
210 .unwrap_or_default();
211
212 let matches: Vec<CrossContextMatch> = text_matches
213 .iter()
214 .filter_map(|m| {
215 ctx.graph.get_node(m.node_id).map(|node| CrossContextMatch {
216 node_id: node.id,
217 content: node.content.clone(),
218 event_type: node.event_type.name().to_string(),
219 confidence: node.confidence,
220 score: m.score,
221 })
222 })
223 .collect();
224
225 results.push(CrossContextResult {
226 context_id: ctx.id.clone(),
227 context_role: ctx.role,
228 matches,
229 });
230 }
231
232 Ok(results)
233 }
234
235 pub fn compare(
237 &self,
238 workspace_id: &str,
239 item: &str,
240 max_per_context: usize,
241 ) -> McpResult<Comparison> {
242 let results = self.query_all(workspace_id, item, max_per_context)?;
243 let workspace = self.workspaces.get(workspace_id).unwrap();
244
245 let mut found_in = Vec::new();
246 let mut missing_from = Vec::new();
247 let mut matches_per_context = Vec::new();
248
249 for (i, ctx_result) in results.into_iter().enumerate() {
250 let label = workspace.contexts[i]
251 .label
252 .clone()
253 .unwrap_or_else(|| ctx_result.context_id.clone());
254
255 if ctx_result.matches.is_empty() {
256 missing_from.push(label);
257 } else {
258 found_in.push(label.clone());
259 matches_per_context.push((label, ctx_result.matches));
260 }
261 }
262
263 Ok(Comparison {
264 item: item.to_string(),
265 found_in,
266 missing_from,
267 matches_per_context,
268 })
269 }
270
271 pub fn cross_reference(&self, workspace_id: &str, item: &str) -> McpResult<CrossReference> {
273 let comparison = self.compare(workspace_id, item, 5)?;
274 Ok(CrossReference {
275 item: comparison.item,
276 present_in: comparison.found_in,
277 absent_from: comparison.missing_from,
278 })
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role parses from its name regardless of case, and the label of a
    /// role parses back to that same role.
    #[test]
    fn test_context_role_roundtrip() {
        let cases = [
            ("primary", ContextRole::Primary),
            ("SECONDARY", ContextRole::Secondary),
            ("reference", ContextRole::Reference),
            ("archive", ContextRole::Archive),
        ];

        for (text, role) in cases.iter().copied() {
            assert_eq!(ContextRole::parse_str(text), Some(role));
            // The actual round trip: label -> parse_str -> same role.
            assert_eq!(ContextRole::parse_str(role.label()), Some(role));
        }

        assert_eq!(ContextRole::parse_str("unknown"), None);
    }

    /// A created workspace is retrievable, empty, and has a unique `ws_` id.
    #[test]
    fn test_workspace_create() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        assert!(id.starts_with("ws_"));

        let workspace = mgr.get(&id).expect("created workspace should be retrievable");
        assert_eq!(workspace.name, "test");
        assert!(workspace.contexts.is_empty());

        let other = mgr.create("other");
        assert_ne!(id, other);
    }

    /// Listing an unknown workspace is an error.
    #[test]
    fn test_workspace_not_found() {
        let mgr = WorkspaceManager::new();
        assert!(mgr.list("nonexistent").is_err());
    }

    /// Adding a context from a path that does not exist is an error.
    #[test]
    fn test_workspace_file_not_found() {
        let mut mgr = WorkspaceManager::new();
        let id = mgr.create("test");
        let result = mgr.add_context(&id, "/nonexistent.amem", ContextRole::Primary, None);
        assert!(result.is_err());
    }
}