1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Root {
15 pub uri: String,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub name: Option<String>,
20}
21
22#[derive(Debug, Clone)]
24pub struct RootInfo {
25 pub uri: String,
27 pub name: Option<String>,
29 pub exists: bool,
31 pub absolute_path: Option<PathBuf>,
33 pub permissions: Option<RootPermissions>,
35}
36
/// Best-effort snapshot of the access this process had to a root's local path,
/// taken when the root was last parsed or refreshed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RootPermissions {
    /// Whether the path appeared readable.
    pub read: bool,
    /// Whether the path appeared writable.
    pub write: bool,
}
45
46#[derive(Debug, Clone)]
48pub struct RootsConfig {
49 pub roots: Vec<Root>,
51 pub allow_dynamic_roots: bool,
53 pub validate_paths: bool,
55}
56
57impl Default for RootsConfig {
58 fn default() -> Self {
59 Self {
60 roots: Vec::new(),
61 allow_dynamic_roots: true,
62 validate_paths: true,
63 }
64 }
65}
66
67#[derive(Debug, Clone)]
69pub enum RootEvent {
70 RootAdded { root: RootInfo },
72 RootRemoved { root: RootInfo },
74 RootUpdated { root: RootInfo, previous: RootInfo },
76 RootsCleared { count: usize },
78 RootsRefreshed { count: usize },
80}
81
82pub struct McpRootsManager {
84 roots: Arc<RwLock<HashMap<String, RootInfo>>>,
85 allow_dynamic_roots: bool,
86 validate_paths: bool,
87 event_sender: broadcast::Sender<RootEvent>,
88}
89
90impl McpRootsManager {
91 pub fn new(config: RootsConfig) -> Self {
93 let (event_sender, _) = broadcast::channel(64);
94 let manager = Self {
95 roots: Arc::new(RwLock::new(HashMap::new())),
96 allow_dynamic_roots: config.allow_dynamic_roots,
97 validate_paths: config.validate_paths,
98 event_sender,
99 };
100
101 for root in config.roots {
103 let root_info = manager.parse_root_sync(&root);
104 manager
105 .roots
106 .blocking_write()
107 .insert(root.uri.clone(), root_info);
108 }
109
110 manager
111 }
112
113 pub fn subscribe(&self) -> broadcast::Receiver<RootEvent> {
115 self.event_sender.subscribe()
116 }
117
118 pub async fn add_root(&self, root: Root) -> RootInfo {
120 let root_info = self.parse_root(&root);
121 self.roots
122 .write()
123 .await
124 .insert(root.uri.clone(), root_info.clone());
125 let _ = self.event_sender.send(RootEvent::RootAdded {
126 root: root_info.clone(),
127 });
128 root_info
129 }
130
131 pub async fn remove_root(&self, uri: &str) -> Option<RootInfo> {
133 let root = self.roots.write().await.remove(uri);
134 if let Some(ref r) = root {
135 let _ = self
136 .event_sender
137 .send(RootEvent::RootRemoved { root: r.clone() });
138 }
139 root
140 }
141
142 pub async fn update_root(&self, uri: &str, updates: Root) -> Option<RootInfo> {
144 let mut roots = self.roots.write().await;
145 let existing = roots.get(uri)?.clone();
146 let updated = self.parse_root(&updates);
147 roots.insert(uri.to_string(), updated.clone());
148 let _ = self.event_sender.send(RootEvent::RootUpdated {
149 root: updated.clone(),
150 previous: existing,
151 });
152 Some(updated)
153 }
154
155 pub async fn get_root(&self, uri: &str) -> Option<RootInfo> {
157 self.roots.read().await.get(uri).cloned()
158 }
159
160 pub async fn get_roots(&self) -> Vec<RootInfo> {
162 self.roots.read().await.values().cloned().collect()
163 }
164
165 pub async fn get_roots_for_protocol(&self) -> Vec<Root> {
167 self.roots
168 .read()
169 .await
170 .values()
171 .map(|r| Root {
172 uri: r.uri.clone(),
173 name: r.name.clone(),
174 })
175 .collect()
176 }
177
178 pub async fn clear_roots(&self) {
180 let mut roots = self.roots.write().await;
181 let count = roots.len();
182 roots.clear();
183 let _ = self.event_sender.send(RootEvent::RootsCleared { count });
184 }
185
186 pub async fn has_root(&self, uri: &str) -> bool {
188 self.roots.read().await.contains_key(uri)
189 }
190
191 fn parse_root(&self, root: &Root) -> RootInfo {
193 self.parse_root_sync(root)
194 }
195
196 fn parse_root_sync(&self, root: &Root) -> RootInfo {
197 let mut absolute_path = None;
198 let mut exists = false;
199 let mut permissions = None;
200
201 if root.uri.starts_with("file://") {
202 if let Some(path) = self.uri_to_path(&root.uri) {
203 absolute_path = Some(path.clone());
204
205 if self.validate_paths {
206 exists = path.exists();
207 if exists {
208 let read = path
209 .metadata()
210 .map(|m| !m.permissions().readonly())
211 .unwrap_or(false);
212 let write = std::fs::OpenOptions::new().write(true).open(&path).is_ok();
213 permissions = Some(RootPermissions { read, write });
214 }
215 }
216 }
217 }
218
219 RootInfo {
220 uri: root.uri.clone(),
221 name: root.name.clone(),
222 exists,
223 absolute_path,
224 permissions,
225 }
226 }
227
228 fn uri_to_path(&self, uri: &str) -> Option<PathBuf> {
230 if !uri.starts_with("file://") {
231 return None;
232 }
233
234 let path_str = uri.get(7..)?; #[cfg(windows)]
237 let path_str = if path_str.starts_with('/') && path_str.chars().nth(2) == Some(':') {
238 path_str.get(1..)? } else {
240 path_str
241 };
242
243 let decoded = urlencoding::decode(path_str).ok()?;
244 Some(PathBuf::from(decoded.into_owned()))
245 }
246
247 fn path_to_uri(&self, path: &Path) -> String {
249 let absolute = if path.is_absolute() {
250 path.to_path_buf()
251 } else {
252 std::env::current_dir().unwrap_or_default().join(path)
253 };
254
255 let path_str = absolute.to_string_lossy();
256
257 #[cfg(windows)]
258 let uri = format!("file:///{}", path_str.replace('\\', "/"));
259
260 #[cfg(not(windows))]
261 let uri = format!("file://{}", path_str);
262
263 uri
264 }
265
266 pub async fn is_path_in_roots(&self, path: &Path) -> bool {
268 let absolute = if path.is_absolute() {
269 path.to_path_buf()
270 } else {
271 std::env::current_dir().unwrap_or_default().join(path)
272 };
273
274 for root in self.roots.read().await.values() {
275 if let Some(ref root_path) = root.absolute_path {
276 if self.is_path_in_root(&absolute, root_path) {
277 return true;
278 }
279 }
280 }
281 false
282 }
283
284 fn is_path_in_root(&self, path: &Path, root_path: &Path) -> bool {
286 let normalized_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
287 let normalized_root = root_path
288 .canonicalize()
289 .unwrap_or_else(|_| root_path.to_path_buf());
290 normalized_path.starts_with(&normalized_root)
291 }
292
293 pub async fn get_root_for_path(&self, path: &Path) -> Option<RootInfo> {
295 let absolute = if path.is_absolute() {
296 path.to_path_buf()
297 } else {
298 std::env::current_dir().unwrap_or_default().join(path)
299 };
300
301 for root in self.roots.read().await.values() {
302 if let Some(ref root_path) = root.absolute_path {
303 if self.is_path_in_root(&absolute, root_path) {
304 return Some(root.clone());
305 }
306 }
307 }
308 None
309 }
310
311 pub async fn add_root_from_path(
313 &self,
314 path: &Path,
315 name: Option<String>,
316 ) -> Result<RootInfo, &'static str> {
317 if !self.allow_dynamic_roots {
318 return Err("Dynamic roots are not allowed");
319 }
320
321 let uri = self.path_to_uri(path);
322 let root = Root { uri, name };
323 Ok(self.add_root(root).await)
324 }
325
326 pub async fn add_cwd_root(&self, name: Option<String>) -> Result<RootInfo, &'static str> {
328 let cwd = std::env::current_dir().map_err(|_| "Could not get current directory")?;
329 self.add_root_from_path(&cwd, name.or(Some("Current Directory".to_string())))
330 .await
331 }
332
333 pub async fn add_home_root(&self, name: Option<String>) -> Result<RootInfo, &'static str> {
335 let home = dirs::home_dir().ok_or("Could not determine home directory")?;
336 self.add_root_from_path(&home, name.or(Some("Home Directory".to_string())))
337 .await
338 }
339
340 pub async fn get_stats(&self) -> RootsStats {
342 let roots = self.get_roots().await;
343 RootsStats {
344 total_roots: roots.len(),
345 existing_roots: roots.iter().filter(|r| r.exists).count(),
346 readable_roots: roots
347 .iter()
348 .filter(|r| r.permissions.map(|p| p.read).unwrap_or(false))
349 .count(),
350 writable_roots: roots
351 .iter()
352 .filter(|r| r.permissions.map(|p| p.write).unwrap_or(false))
353 .count(),
354 allow_dynamic_roots: self.allow_dynamic_roots,
355 validate_paths: self.validate_paths,
356 }
357 }
358
359 pub async fn refresh_roots(&self) {
361 let roots: Vec<_> = self.roots.read().await.values().cloned().collect();
362 let count = roots.len();
363
364 for root in roots {
365 let refreshed = self.parse_root(&Root {
366 uri: root.uri.clone(),
367 name: root.name.clone(),
368 });
369 self.roots.write().await.insert(root.uri, refreshed);
370 }
371
372 let _ = self.event_sender.send(RootEvent::RootsRefreshed { count });
373 }
374}
375
376impl Default for McpRootsManager {
377 fn default() -> Self {
378 Self::new(RootsConfig::default())
379 }
380}
381
/// Aggregate counts describing the roots held by a `McpRootsManager`,
/// together with copies of its settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RootsStats {
    /// Number of registered roots.
    pub total_roots: usize,
    /// Roots whose local path existed when last parsed.
    pub existing_roots: usize,
    /// Roots whose permission probe reported read access.
    pub readable_roots: usize,
    /// Roots whose permission probe reported write access.
    pub writable_roots: usize,
    /// Copy of the manager's `allow_dynamic_roots` setting.
    pub allow_dynamic_roots: bool,
    /// Copy of the manager's `validate_paths` setting.
    pub validate_paths: bool,
}
398
399pub fn create_root_from_path(path: &Path, name: Option<String>) -> Root {
401 let absolute = if path.is_absolute() {
402 path.to_path_buf()
403 } else {
404 std::env::current_dir().unwrap_or_default().join(path)
405 };
406
407 #[cfg(windows)]
408 let uri = format!("file:///{}", absolute.to_string_lossy().replace('\\', "/"));
409
410 #[cfg(not(windows))]
411 let uri = format!("file://{}", absolute.to_string_lossy());
412
413 Root { uri, name }
414}
415
416pub fn get_default_roots_config() -> RootsConfig {
418 let cwd = std::env::current_dir().unwrap_or_default();
419 RootsConfig {
420 roots: vec![create_root_from_path(
421 &cwd,
422 Some("Current Directory".to_string()),
423 )],
424 allow_dynamic_roots: true,
425 validate_paths: true,
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a protocol-level root from borrowed strings.
    fn root(uri: &str, name: Option<&str>) -> Root {
        Root {
            uri: uri.to_owned(),
            name: name.map(str::to_owned),
        }
    }

    /// Returns a default manager with one unnamed root registered per URI.
    async fn manager_with(uris: &[&str]) -> McpRootsManager {
        let manager = McpRootsManager::default();
        for &uri in uris {
            manager.add_root(root(uri, None)).await;
        }
        manager
    }

    #[test]
    fn test_root_creation() {
        let r = root("file:///tmp/test", Some("Test Root"));
        assert_eq!(r.uri, "file:///tmp/test");
        assert_eq!(r.name, Some("Test Root".to_string()));
    }

    #[test]
    fn test_create_root_from_path() {
        let created = create_root_from_path(Path::new("/tmp/test"), Some("Test".into()));
        assert!(created.uri.starts_with("file://"));
        assert!(created.uri.contains("tmp"));
    }

    #[tokio::test]
    async fn test_manager_add_root() {
        let manager = McpRootsManager::default();
        let info = manager.add_root(root("file:///tmp/test", Some("Test"))).await;
        assert_eq!(info.uri, "file:///tmp/test");
        assert!(manager.has_root("file:///tmp/test").await);
    }

    #[tokio::test]
    async fn test_manager_remove_root() {
        let uri = "file:///tmp/test";
        let manager = manager_with(&[uri]).await;
        assert!(manager.has_root(uri).await);

        manager.remove_root(uri).await;
        assert!(!manager.has_root(uri).await);
    }

    #[tokio::test]
    async fn test_manager_get_roots() {
        let manager = manager_with(&["file:///tmp/a", "file:///tmp/b"]).await;
        assert_eq!(manager.get_roots().await.len(), 2);
    }

    #[tokio::test]
    async fn test_manager_clear_roots() {
        let manager = manager_with(&["file:///tmp/a", "file:///tmp/b"]).await;
        manager.clear_roots().await;
        assert!(manager.get_roots().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_roots_for_protocol() {
        let manager = McpRootsManager::default();
        manager.add_root(root("file:///tmp/test", Some("Test"))).await;

        let roots = manager.get_roots_for_protocol().await;
        assert_eq!(roots.len(), 1);
        assert_eq!(roots[0].uri, "file:///tmp/test");
        assert_eq!(roots[0].name.as_deref(), Some("Test"));
    }

    #[tokio::test]
    async fn test_get_stats() {
        let stats = manager_with(&["file:///tmp/a", "file:///tmp/b"])
            .await
            .get_stats()
            .await;
        assert_eq!(stats.total_roots, 2);
        assert!(stats.allow_dynamic_roots);
        assert!(stats.validate_paths);
    }

    #[test]
    fn test_get_default_roots_config() {
        let config = get_default_roots_config();
        assert!(!config.roots.is_empty());
        assert!(config.allow_dynamic_roots);
        assert!(config.validate_paths);
    }
}