Skip to main content

mcp_kit/server/
roots.rs

1//! Roots support for file system access.
2//!
3//! Clients can provide a list of root URIs that the server is allowed to access.
4//! This is useful for sandboxing file operations and providing context about
5//! the user's workspace.
6
7use std::collections::HashSet;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::types::Root;
12
13/// Manages the client's declared root URIs.
14///
15/// Roots are directories or files that the client has exposed to the server.
16/// The server should only access files within these roots.
17#[derive(Clone, Default)]
18pub struct RootsManager {
19    inner: Arc<RwLock<RootsState>>,
20}
21
22#[derive(Default)]
23struct RootsState {
24    roots: Vec<Root>,
25    uris: HashSet<String>,
26}
27
28impl RootsManager {
29    /// Create a new roots manager.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Set the list of roots from the client.
35    pub async fn set_roots(&self, roots: Vec<Root>) {
36        let mut state = self.inner.write().await;
37        state.uris = roots.iter().map(|r| r.uri.clone()).collect();
38        state.roots = roots;
39    }
40
41    /// Get all declared roots.
42    pub async fn roots(&self) -> Vec<Root> {
43        let state = self.inner.read().await;
44        state.roots.clone()
45    }
46
47    /// Check if a URI is within any declared root.
48    ///
49    /// Returns `true` if the URI starts with any root URI prefix.
50    pub async fn is_within_roots(&self, uri: &str) -> bool {
51        let state = self.inner.read().await;
52        if state.roots.is_empty() {
53            // No roots declared = allow all (backwards compatibility)
54            return true;
55        }
56        state.uris.iter().any(|root| uri.starts_with(root))
57    }
58
59    /// Find the root that contains a given URI.
60    pub async fn find_root(&self, uri: &str) -> Option<Root> {
61        let state = self.inner.read().await;
62        state
63            .roots
64            .iter()
65            .find(|r| uri.starts_with(&r.uri))
66            .cloned()
67    }
68
69    /// Get the number of declared roots.
70    pub async fn count(&self) -> usize {
71        let state = self.inner.read().await;
72        state.roots.len()
73    }
74
75    /// Check if any roots are declared.
76    pub async fn has_roots(&self) -> bool {
77        let state = self.inner.read().await;
78        !state.roots.is_empty()
79    }
80
81    /// Clear all roots.
82    pub async fn clear(&self) {
83        let mut state = self.inner.write().await;
84        state.roots.clear();
85        state.uris.clear();
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`Root`] from a URI and an optional display name.
    fn root(uri: &str, name: Option<&str>) -> Root {
        Root {
            uri: uri.to_string(),
            name: name.map(str::to_string),
        }
    }

    #[tokio::test]
    async fn test_set_and_get_roots() {
        let manager = RootsManager::new();
        manager
            .set_roots(vec![
                root("file:///home/user/project", Some("Project")),
                root("file:///tmp", None),
            ])
            .await;

        assert_eq!(manager.count().await, 2);
        assert!(manager.has_roots().await);
    }

    #[tokio::test]
    async fn test_is_within_roots() {
        let manager = RootsManager::new();
        manager
            .set_roots(vec![root("file:///home/user/project", None)])
            .await;

        let inside = manager
            .is_within_roots("file:///home/user/project/src/main.rs")
            .await;
        let outside = manager.is_within_roots("file:///etc/passwd").await;

        assert!(inside);
        assert!(!outside);
    }

    #[tokio::test]
    async fn test_no_roots_allows_all() {
        // A fresh manager has no roots, so nothing is restricted.
        let manager = RootsManager::new();
        assert!(manager.is_within_roots("file:///anywhere").await);
    }
}