Skip to main content

agentic_vision_mcp/session/
workspace.rs

1//! Multi-context workspace manager for loading and querying multiple .avis files.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use agentic_vision::{AvisReader, VisualMemoryStore};
8
9use crate::types::{McpError, McpResult};
10
/// The part a loaded context plays inside a workspace.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextRole {
    Primary,
    Secondary,
    Reference,
    Archive,
}

impl ContextRole {
    /// Every role in declaration order; used for name lookup in `parse_str`.
    const ALL: [Self; 4] = [
        Self::Primary,
        Self::Secondary,
        Self::Reference,
        Self::Archive,
    ];

    /// Resolves a role from its name, ignoring letter case.
    ///
    /// Returns `None` when the lowercased input is not one of the names
    /// produced by [`ContextRole::label`]. Surrounding whitespace is not
    /// stripped.
    pub fn parse_str(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        Self::ALL
            .iter()
            .copied()
            .find(|role| role.label() == wanted)
    }

    /// Returns the lowercase name of this role.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Primary => "primary",
            Self::Secondary => "secondary",
            Self::Reference => "reference",
            Self::Archive => "archive",
        }
    }
}
40
41/// A loaded vision context within a workspace.
42pub struct VisionContext {
43    pub id: String,
44    pub role: ContextRole,
45    pub path: String,
46    pub label: Option<String>,
47    pub store: VisualMemoryStore,
48}
49
50/// A multi-vision workspace.
51pub struct VisionWorkspace {
52    pub id: String,
53    pub name: String,
54    pub contexts: Vec<VisionContext>,
55    pub created_at: u64,
56}
57
/// A single observation that matched a workspace query.
#[derive(Debug)]
pub struct CrossContextMatch {
    /// Identifier of the matching observation within its context's store.
    pub observation_id: u64,
    /// Copy of the observation's description, if it has one.
    pub description: Option<String>,
    /// Copy of the observation's labels.
    pub labels: Vec<String>,
    /// Relevance score; higher means a better match.
    pub score: f32,
}
66
67/// Cross-context result.
68#[derive(Debug)]
69pub struct CrossContextResult {
70    pub context_id: String,
71    pub context_role: ContextRole,
72    pub matches: Vec<CrossContextMatch>,
73}
74
75/// Comparison.
76#[derive(Debug)]
77pub struct Comparison {
78    pub item: String,
79    pub found_in: Vec<String>,
80    pub missing_from: Vec<String>,
81    pub matches_per_context: Vec<(String, Vec<CrossContextMatch>)>,
82}
83
/// Presence/absence summary of one item across a workspace's contexts.
#[derive(Debug)]
pub struct CrossReference {
    /// The item text that was searched for.
    pub item: String,
    /// Display names of the contexts in which the item was found.
    pub present_in: Vec<String>,
    /// Display names of the contexts in which the item was not found.
    pub absent_from: Vec<String>,
}
91
92/// Manages multiple vision workspaces.
93#[derive(Default)]
94pub struct VisionWorkspaceManager {
95    workspaces: HashMap<String, VisionWorkspace>,
96    next_id: u64,
97}
98
99impl VisionWorkspaceManager {
100    pub fn new() -> Self {
101        Self {
102            workspaces: HashMap::new(),
103            next_id: 1,
104        }
105    }
106
107    pub fn create(&mut self, name: &str) -> String {
108        let id = format!("vws_{}", self.next_id);
109        self.next_id += 1;
110        let now = SystemTime::now()
111            .duration_since(UNIX_EPOCH)
112            .unwrap_or_default()
113            .as_micros() as u64;
114        self.workspaces.insert(
115            id.clone(),
116            VisionWorkspace {
117                id: id.clone(),
118                name: name.to_string(),
119                contexts: Vec::new(),
120                created_at: now,
121            },
122        );
123        id
124    }
125
126    pub fn add_context(
127        &mut self,
128        workspace_id: &str,
129        path: &str,
130        role: ContextRole,
131        label: Option<String>,
132    ) -> McpResult<String> {
133        let workspace = self.workspaces.get_mut(workspace_id).ok_or_else(|| {
134            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
135        })?;
136
137        let file_path = Path::new(path);
138        if !file_path.exists() {
139            return Err(McpError::InvalidParams(format!("File not found: {path}")));
140        }
141
142        let store = AvisReader::read_from_file(file_path)
143            .map_err(|e| McpError::VisionError(format!("Failed to parse {path}: {e}")))?;
144
145        let ctx_id = format!("vctx_{}_{}", workspace.contexts.len() + 1, workspace_id);
146
147        workspace.contexts.push(VisionContext {
148            id: ctx_id.clone(),
149            role,
150            path: path.to_string(),
151            label: label.or_else(|| {
152                file_path
153                    .file_stem()
154                    .and_then(|s| s.to_str())
155                    .map(|s| s.to_string())
156            }),
157            store,
158        });
159
160        Ok(ctx_id)
161    }
162
163    pub fn list(&self, workspace_id: &str) -> McpResult<&[VisionContext]> {
164        let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
165            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
166        })?;
167        Ok(&workspace.contexts)
168    }
169
170    pub fn get(&self, workspace_id: &str) -> Option<&VisionWorkspace> {
171        self.workspaces.get(workspace_id)
172    }
173
174    pub fn query_all(
175        &self,
176        workspace_id: &str,
177        query: &str,
178        max_per_context: usize,
179    ) -> McpResult<Vec<CrossContextResult>> {
180        let workspace = self.workspaces.get(workspace_id).ok_or_else(|| {
181            McpError::InvalidParams(format!("Workspace not found: {workspace_id}"))
182        })?;
183
184        let query_lower = query.to_lowercase();
185        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
186        let mut results = Vec::new();
187
188        for ctx in &workspace.contexts {
189            let mut matches = Vec::new();
190            for obs in &ctx.store.observations {
191                let mut score = 0.0f32;
192
193                if let Some(ref desc) = obs.metadata.description {
194                    let desc_lower = desc.to_lowercase();
195                    let overlap = query_words
196                        .iter()
197                        .filter(|w| desc_lower.contains(**w))
198                        .count();
199                    score += overlap as f32 / query_words.len().max(1) as f32;
200                }
201
202                for label in &obs.metadata.labels {
203                    if query_lower.contains(&label.to_lowercase()) {
204                        score += 0.3;
205                    }
206                }
207
208                if score > 0.0 {
209                    matches.push(CrossContextMatch {
210                        observation_id: obs.id,
211                        description: obs.metadata.description.clone(),
212                        labels: obs.metadata.labels.clone(),
213                        score,
214                    });
215                }
216            }
217
218            matches.sort_by(|a, b| {
219                b.score
220                    .partial_cmp(&a.score)
221                    .unwrap_or(std::cmp::Ordering::Equal)
222            });
223            matches.truncate(max_per_context);
224
225            results.push(CrossContextResult {
226                context_id: ctx.id.clone(),
227                context_role: ctx.role,
228                matches,
229            });
230        }
231
232        Ok(results)
233    }
234
235    pub fn compare(
236        &self,
237        workspace_id: &str,
238        item: &str,
239        max_per_context: usize,
240    ) -> McpResult<Comparison> {
241        let results = self.query_all(workspace_id, item, max_per_context)?;
242        let workspace = self.workspaces.get(workspace_id).unwrap();
243
244        let mut found_in = Vec::new();
245        let mut missing_from = Vec::new();
246        let mut matches_per_context = Vec::new();
247
248        for (i, cr) in results.into_iter().enumerate() {
249            let label = workspace.contexts[i]
250                .label
251                .clone()
252                .unwrap_or_else(|| cr.context_id.clone());
253            if cr.matches.is_empty() {
254                missing_from.push(label);
255            } else {
256                found_in.push(label.clone());
257                matches_per_context.push((label, cr.matches));
258            }
259        }
260
261        Ok(Comparison {
262            item: item.to_string(),
263            found_in,
264            missing_from,
265            matches_per_context,
266        })
267    }
268
269    pub fn cross_reference(&self, workspace_id: &str, item: &str) -> McpResult<CrossReference> {
270        let c = self.compare(workspace_id, item, 5)?;
271        Ok(CrossReference {
272            item: c.item,
273            present_in: c.found_in,
274            absent_from: c.missing_from,
275        })
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Role names parse regardless of case; unrecognised names yield `None`.
    #[test]
    fn test_context_role_roundtrip() {
        let cases = [
            ("primary", Some(ContextRole::Primary)),
            ("ARCHIVE", Some(ContextRole::Archive)),
            ("unknown", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(ContextRole::parse_str(input), expected);
        }
    }

    /// A created workspace gets a `vws_` identifier and can be looked up by it.
    #[test]
    fn test_workspace_create() {
        let mut manager = VisionWorkspaceManager::new();
        let workspace_id = manager.create("test");
        assert!(workspace_id.starts_with("vws_"));
        assert!(manager.get(&workspace_id).is_some());
    }

    /// Listing an unknown workspace is an error.
    #[test]
    fn test_workspace_not_found() {
        let manager = VisionWorkspaceManager::new();
        let outcome = manager.list("nonexistent");
        assert!(outcome.is_err());
    }

    /// Adding a context from a path that does not exist is an error.
    #[test]
    fn test_workspace_file_not_found() {
        let mut manager = VisionWorkspaceManager::new();
        let workspace_id = manager.create("test");
        let outcome = manager.add_context(
            &workspace_id,
            "/nonexistent.avis",
            ContextRole::Primary,
            None,
        );
        assert!(outcome.is_err());
    }
}