Skip to main content

aster/mcp/
roots.rs

1//! MCP Roots Module
2//!
3//! Manages root directories for MCP servers. Roots define the base directories
4//! that servers can access, providing a sandboxing mechanism for file operations.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
12/// Root directory for MCP protocol
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Root {
15    /// URI of the root (file:// format)
16    pub uri: String,
17    /// Optional human-readable name
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub name: Option<String>,
20}
21
22/// Root directory with metadata
23#[derive(Debug, Clone)]
24pub struct RootInfo {
25    /// URI of the root
26    pub uri: String,
27    /// Optional name
28    pub name: Option<String>,
29    /// Whether the path exists
30    pub exists: bool,
31    /// Absolute path (if file:// URI)
32    pub absolute_path: Option<PathBuf>,
33    /// Permissions
34    pub permissions: Option<RootPermissions>,
35}
36
/// Result of a best-effort access probe on a root's path.
///
/// The flags reflect what was observed when the root was last parsed or refreshed. They are
/// informational only: nothing in this module consults them when deciding whether a path is
/// inside a root.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct RootPermissions {
    /// Whether the path appeared readable.
    pub read: bool,
    /// Whether the path appeared writable.
    pub write: bool,
}
45
46/// Roots configuration
47#[derive(Debug, Clone)]
48pub struct RootsConfig {
49    /// Initial roots
50    pub roots: Vec<Root>,
51    /// Allow dynamic root addition
52    pub allow_dynamic_roots: bool,
53    /// Validate paths exist
54    pub validate_paths: bool,
55}
56
57impl Default for RootsConfig {
58    fn default() -> Self {
59        Self {
60            roots: Vec::new(),
61            allow_dynamic_roots: true,
62            validate_paths: true,
63        }
64    }
65}
66
67/// Root event for broadcasting
68#[derive(Debug, Clone)]
69pub enum RootEvent {
70    /// Root added
71    RootAdded { root: RootInfo },
72    /// Root removed
73    RootRemoved { root: RootInfo },
74    /// Root updated
75    RootUpdated { root: RootInfo, previous: RootInfo },
76    /// Roots cleared
77    RootsCleared { count: usize },
78    /// Roots refreshed
79    RootsRefreshed { count: usize },
80}
81
82/// Manages root directories for MCP servers
83pub struct McpRootsManager {
84    roots: Arc<RwLock<HashMap<String, RootInfo>>>,
85    allow_dynamic_roots: bool,
86    validate_paths: bool,
87    event_sender: broadcast::Sender<RootEvent>,
88}
89
90impl McpRootsManager {
91    /// Create a new roots manager
92    pub fn new(config: RootsConfig) -> Self {
93        let (event_sender, _) = broadcast::channel(64);
94        let manager = Self {
95            roots: Arc::new(RwLock::new(HashMap::new())),
96            allow_dynamic_roots: config.allow_dynamic_roots,
97            validate_paths: config.validate_paths,
98            event_sender,
99        };
100
101        // Initialize with provided roots (blocking for simplicity)
102        for root in config.roots {
103            let root_info = manager.parse_root_sync(&root);
104            manager
105                .roots
106                .blocking_write()
107                .insert(root.uri.clone(), root_info);
108        }
109
110        manager
111    }
112
113    /// Subscribe to root events
114    pub fn subscribe(&self) -> broadcast::Receiver<RootEvent> {
115        self.event_sender.subscribe()
116    }
117
118    /// Add a root directory
119    pub async fn add_root(&self, root: Root) -> RootInfo {
120        let root_info = self.parse_root(&root);
121        self.roots
122            .write()
123            .await
124            .insert(root.uri.clone(), root_info.clone());
125        let _ = self.event_sender.send(RootEvent::RootAdded {
126            root: root_info.clone(),
127        });
128        root_info
129    }
130
131    /// Remove a root directory
132    pub async fn remove_root(&self, uri: &str) -> Option<RootInfo> {
133        let root = self.roots.write().await.remove(uri);
134        if let Some(ref r) = root {
135            let _ = self
136                .event_sender
137                .send(RootEvent::RootRemoved { root: r.clone() });
138        }
139        root
140    }
141
142    /// Update a root directory
143    pub async fn update_root(&self, uri: &str, updates: Root) -> Option<RootInfo> {
144        let mut roots = self.roots.write().await;
145        let existing = roots.get(uri)?.clone();
146        let updated = self.parse_root(&updates);
147        roots.insert(uri.to_string(), updated.clone());
148        let _ = self.event_sender.send(RootEvent::RootUpdated {
149            root: updated.clone(),
150            previous: existing,
151        });
152        Some(updated)
153    }
154
155    /// Get a root by URI
156    pub async fn get_root(&self, uri: &str) -> Option<RootInfo> {
157        self.roots.read().await.get(uri).cloned()
158    }
159
160    /// Get all roots
161    pub async fn get_roots(&self) -> Vec<RootInfo> {
162        self.roots.read().await.values().cloned().collect()
163    }
164
165    /// Get all roots as plain Root objects (for MCP protocol)
166    pub async fn get_roots_for_protocol(&self) -> Vec<Root> {
167        self.roots
168            .read()
169            .await
170            .values()
171            .map(|r| Root {
172                uri: r.uri.clone(),
173                name: r.name.clone(),
174            })
175            .collect()
176    }
177
178    /// Clear all roots
179    pub async fn clear_roots(&self) {
180        let mut roots = self.roots.write().await;
181        let count = roots.len();
182        roots.clear();
183        let _ = self.event_sender.send(RootEvent::RootsCleared { count });
184    }
185
186    /// Check if a URI is registered as a root
187    pub async fn has_root(&self, uri: &str) -> bool {
188        self.roots.read().await.contains_key(uri)
189    }
190
191    /// Parse a root and extract information
192    fn parse_root(&self, root: &Root) -> RootInfo {
193        self.parse_root_sync(root)
194    }
195
196    fn parse_root_sync(&self, root: &Root) -> RootInfo {
197        let mut absolute_path = None;
198        let mut exists = false;
199        let mut permissions = None;
200
201        if root.uri.starts_with("file://") {
202            if let Some(path) = self.uri_to_path(&root.uri) {
203                absolute_path = Some(path.clone());
204
205                if self.validate_paths {
206                    exists = path.exists();
207                    if exists {
208                        let read = path
209                            .metadata()
210                            .map(|m| !m.permissions().readonly())
211                            .unwrap_or(false);
212                        let write = std::fs::OpenOptions::new().write(true).open(&path).is_ok();
213                        permissions = Some(RootPermissions { read, write });
214                    }
215                }
216            }
217        }
218
219        RootInfo {
220            uri: root.uri.clone(),
221            name: root.name.clone(),
222            exists,
223            absolute_path,
224            permissions,
225        }
226    }
227
228    /// Convert file:// URI to local path
229    fn uri_to_path(&self, uri: &str) -> Option<PathBuf> {
230        if !uri.starts_with("file://") {
231            return None;
232        }
233
234        let path_str = uri.get(7..)?; // Remove "file://"
235
236        #[cfg(windows)]
237        let path_str = if path_str.starts_with('/') && path_str.chars().nth(2) == Some(':') {
238            path_str.get(1..)? // Remove leading / for Windows paths like /C:/
239        } else {
240            path_str
241        };
242
243        let decoded = urlencoding::decode(path_str).ok()?;
244        Some(PathBuf::from(decoded.into_owned()))
245    }
246
247    /// Convert local path to file:// URI
248    fn path_to_uri(&self, path: &Path) -> String {
249        let absolute = if path.is_absolute() {
250            path.to_path_buf()
251        } else {
252            std::env::current_dir().unwrap_or_default().join(path)
253        };
254
255        let path_str = absolute.to_string_lossy();
256
257        #[cfg(windows)]
258        let uri = format!("file:///{}", path_str.replace('\\', "/"));
259
260        #[cfg(not(windows))]
261        let uri = format!("file://{}", path_str);
262
263        uri
264    }
265
266    /// Check if a path is within any root
267    pub async fn is_path_in_roots(&self, path: &Path) -> bool {
268        let absolute = if path.is_absolute() {
269            path.to_path_buf()
270        } else {
271            std::env::current_dir().unwrap_or_default().join(path)
272        };
273
274        for root in self.roots.read().await.values() {
275            if let Some(ref root_path) = root.absolute_path {
276                if self.is_path_in_root(&absolute, root_path) {
277                    return true;
278                }
279            }
280        }
281        false
282    }
283
284    /// Check if a path is within a specific root
285    fn is_path_in_root(&self, path: &Path, root_path: &Path) -> bool {
286        let normalized_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
287        let normalized_root = root_path
288            .canonicalize()
289            .unwrap_or_else(|_| root_path.to_path_buf());
290        normalized_path.starts_with(&normalized_root)
291    }
292
293    /// Get the root that contains a path
294    pub async fn get_root_for_path(&self, path: &Path) -> Option<RootInfo> {
295        let absolute = if path.is_absolute() {
296            path.to_path_buf()
297        } else {
298            std::env::current_dir().unwrap_or_default().join(path)
299        };
300
301        for root in self.roots.read().await.values() {
302            if let Some(ref root_path) = root.absolute_path {
303                if self.is_path_in_root(&absolute, root_path) {
304                    return Some(root.clone());
305                }
306            }
307        }
308        None
309    }
310
311    /// Add a root from a local path
312    pub async fn add_root_from_path(
313        &self,
314        path: &Path,
315        name: Option<String>,
316    ) -> Result<RootInfo, &'static str> {
317        if !self.allow_dynamic_roots {
318            return Err("Dynamic roots are not allowed");
319        }
320
321        let uri = self.path_to_uri(path);
322        let root = Root { uri, name };
323        Ok(self.add_root(root).await)
324    }
325
326    /// Add the current working directory as a root
327    pub async fn add_cwd_root(&self, name: Option<String>) -> Result<RootInfo, &'static str> {
328        let cwd = std::env::current_dir().map_err(|_| "Could not get current directory")?;
329        self.add_root_from_path(&cwd, name.or(Some("Current Directory".to_string())))
330            .await
331    }
332
333    /// Add home directory as a root
334    pub async fn add_home_root(&self, name: Option<String>) -> Result<RootInfo, &'static str> {
335        let home = dirs::home_dir().ok_or("Could not determine home directory")?;
336        self.add_root_from_path(&home, name.or(Some("Home Directory".to_string())))
337            .await
338    }
339
340    /// Get statistics about roots
341    pub async fn get_stats(&self) -> RootsStats {
342        let roots = self.get_roots().await;
343        RootsStats {
344            total_roots: roots.len(),
345            existing_roots: roots.iter().filter(|r| r.exists).count(),
346            readable_roots: roots
347                .iter()
348                .filter(|r| r.permissions.map(|p| p.read).unwrap_or(false))
349                .count(),
350            writable_roots: roots
351                .iter()
352                .filter(|r| r.permissions.map(|p| p.write).unwrap_or(false))
353                .count(),
354            allow_dynamic_roots: self.allow_dynamic_roots,
355            validate_paths: self.validate_paths,
356        }
357    }
358
359    /// Refresh root information
360    pub async fn refresh_roots(&self) {
361        let roots: Vec<_> = self.roots.read().await.values().cloned().collect();
362        let count = roots.len();
363
364        for root in roots {
365            let refreshed = self.parse_root(&Root {
366                uri: root.uri.clone(),
367                name: root.name.clone(),
368            });
369            self.roots.write().await.insert(root.uri, refreshed);
370        }
371
372        let _ = self.event_sender.send(RootEvent::RootsRefreshed { count });
373    }
374}
375
376impl Default for McpRootsManager {
377    fn default() -> Self {
378        Self::new(RootsConfig::default())
379    }
380}
381
/// Summary counts over the registered roots, as produced by `McpRootsManager::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RootsStats {
    /// Total number of registered roots.
    pub total_roots: usize,
    /// Roots whose path was found on disk.
    pub existing_roots: usize,
    /// Roots whose access probe reported read access.
    pub readable_roots: usize,
    /// Roots whose access probe reported write access.
    pub writable_roots: usize,
    /// Whether dynamic roots are allowed.
    pub allow_dynamic_roots: bool,
    /// Whether paths are validated.
    pub validate_paths: bool,
}
398
399/// Create a root from a file path
400pub fn create_root_from_path(path: &Path, name: Option<String>) -> Root {
401    let absolute = if path.is_absolute() {
402        path.to_path_buf()
403    } else {
404        std::env::current_dir().unwrap_or_default().join(path)
405    };
406
407    #[cfg(windows)]
408    let uri = format!("file:///{}", absolute.to_string_lossy().replace('\\', "/"));
409
410    #[cfg(not(windows))]
411    let uri = format!("file://{}", absolute.to_string_lossy());
412
413    Root { uri, name }
414}
415
416/// Get default roots configuration
417pub fn get_default_roots_config() -> RootsConfig {
418    let cwd = std::env::current_dir().unwrap_or_default();
419    RootsConfig {
420        roots: vec![create_root_from_path(
421            &cwd,
422            Some("Current Directory".to_string()),
423        )],
424        allow_dynamic_roots: true,
425        validate_paths: true,
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_root_creation() {
        let root = Root {
            uri: "file:///tmp/test".to_string(),
            name: Some("Test Root".to_string()),
        };
        assert_eq!(root.uri, "file:///tmp/test");
        assert_eq!(root.name, Some("Test Root".to_string()));
    }

    #[test]
    fn test_create_root_from_path() {
        let path = PathBuf::from("/tmp/test");
        let root = create_root_from_path(&path, Some("Test".to_string()));
        assert!(root.uri.starts_with("file://"));
        assert!(root.uri.contains("tmp"));
    }

    #[tokio::test]
    async fn test_manager_add_root() {
        let manager = McpRootsManager::default();
        let root = Root {
            uri: "file:///tmp/test".to_string(),
            name: Some("Test".to_string()),
        };

        let info = manager.add_root(root).await;
        assert_eq!(info.uri, "file:///tmp/test");
        assert!(manager.has_root("file:///tmp/test").await);
    }

    #[tokio::test]
    async fn test_manager_remove_root() {
        let manager = McpRootsManager::default();
        let root = Root {
            uri: "file:///tmp/test".to_string(),
            name: None,
        };

        manager.add_root(root).await;
        assert!(manager.has_root("file:///tmp/test").await);

        manager.remove_root("file:///tmp/test").await;
        assert!(!manager.has_root("file:///tmp/test").await);
    }

    #[tokio::test]
    async fn test_manager_get_roots() {
        let manager = McpRootsManager::default();

        manager
            .add_root(Root {
                uri: "file:///tmp/a".to_string(),
                name: None,
            })
            .await;
        manager
            .add_root(Root {
                uri: "file:///tmp/b".to_string(),
                name: None,
            })
            .await;

        let roots = manager.get_roots().await;
        assert_eq!(roots.len(), 2);
    }

    #[tokio::test]
    async fn test_manager_clear_roots() {
        let manager = McpRootsManager::default();

        manager
            .add_root(Root {
                uri: "file:///tmp/a".to_string(),
                name: None,
            })
            .await;
        manager
            .add_root(Root {
                uri: "file:///tmp/b".to_string(),
                name: None,
            })
            .await;

        manager.clear_roots().await;
        assert!(manager.get_roots().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_roots_for_protocol() {
        let manager = McpRootsManager::default();

        manager
            .add_root(Root {
                uri: "file:///tmp/test".to_string(),
                name: Some("Test".to_string()),
            })
            .await;

        let roots = manager.get_roots_for_protocol().await;
        assert_eq!(roots.len(), 1);
        assert_eq!(roots[0].uri, "file:///tmp/test");
        assert_eq!(roots[0].name, Some("Test".to_string()));
    }

    #[tokio::test]
    async fn test_get_stats() {
        let manager = McpRootsManager::default();

        manager
            .add_root(Root {
                uri: "file:///tmp/a".to_string(),
                name: None,
            })
            .await;
        manager
            .add_root(Root {
                uri: "file:///tmp/b".to_string(),
                name: None,
            })
            .await;

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_roots, 2);
        assert!(stats.allow_dynamic_roots);
        assert!(stats.validate_paths);
    }

    #[test]
    fn test_get_default_roots_config() {
        let config = get_default_roots_config();
        assert!(!config.roots.is_empty());
        assert!(config.allow_dynamic_roots);
        assert!(config.validate_paths);
    }

    /// Initial roots from the config are registered by the constructor.
    #[test]
    fn test_new_registers_initial_roots() {
        let config = RootsConfig {
            roots: vec![Root {
                uri: "file:///tmp/init".to_string(),
                name: Some("Init".to_string()),
            }],
            validate_paths: false,
            ..RootsConfig::default()
        };
        let manager = McpRootsManager::new(config);

        let roots = manager
            .roots
            .try_read()
            .expect("no other task holds the lock");
        let info = roots
            .get("file:///tmp/init")
            .expect("initial root is registered");
        assert_eq!(info.name.as_deref(), Some("Init"));
        // With validation disabled nothing is probed on disk.
        assert!(!info.exists);
        assert!(info.permissions.is_none());
    }

    #[tokio::test]
    async fn test_update_unknown_root_returns_none() {
        let manager = McpRootsManager::default();
        let result = manager
            .update_root(
                "file:///tmp/missing",
                Root {
                    uri: "file:///tmp/replacement".to_string(),
                    name: None,
                },
            )
            .await;
        assert!(result.is_none());
        assert!(manager.get_roots().await.is_empty());
    }

    #[tokio::test]
    async fn test_update_root_replaces_name() {
        let manager = McpRootsManager::default();
        manager
            .add_root(Root {
                uri: "file:///tmp/test".to_string(),
                name: Some("Old".to_string()),
            })
            .await;

        let updated = manager
            .update_root(
                "file:///tmp/test",
                Root {
                    uri: "file:///tmp/test".to_string(),
                    name: Some("New".to_string()),
                },
            )
            .await
            .expect("root exists");
        assert_eq!(updated.name.as_deref(), Some("New"));

        let stored = manager
            .get_root("file:///tmp/test")
            .await
            .expect("root still registered");
        assert_eq!(stored.name.as_deref(), Some("New"));
    }

    #[tokio::test]
    async fn test_subscribe_receives_add_event() {
        let manager = McpRootsManager::default();
        let mut events = manager.subscribe();

        manager
            .add_root(Root {
                uri: "file:///tmp/evt".to_string(),
                name: None,
            })
            .await;

        match events.try_recv() {
            Ok(RootEvent::RootAdded { root }) => assert_eq!(root.uri, "file:///tmp/evt"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_add_root_from_path_rejected_when_dynamic_roots_disabled() {
        let manager = McpRootsManager::new(RootsConfig {
            allow_dynamic_roots: false,
            ..RootsConfig::default()
        });

        let result = manager.add_root_from_path(Path::new("/tmp"), None).await;
        assert!(result.is_err());
        assert!(manager.get_roots().await.is_empty());
    }

    #[tokio::test]
    async fn test_existing_directory_is_in_roots() {
        let manager = McpRootsManager::default();
        let dir = std::env::temp_dir();
        manager
            .add_root_from_path(&dir, Some("Temp".to_string()))
            .await
            .expect("dynamic roots are allowed by default");

        assert!(manager.is_path_in_roots(&dir).await);
        let root = manager
            .get_root_for_path(&dir)
            .await
            .expect("temp dir is covered by its own root");
        assert_eq!(root.name.as_deref(), Some("Temp"));

        // The parent of a root is not inside it.
        if let Some(parent) = dir.parent() {
            assert!(!manager.is_path_in_roots(parent).await);
        }
    }
}