modelcontextprotocol_server/resources/
mod.rs

1// mcp-server/src/resources/mod.rs
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4use anyhow::{anyhow, Result};
5use tokio::sync::{RwLock, broadcast};
6use mcp_protocol::types::resource::{
7    Resource, ResourceContent, ResourceTemplate
8};
9use mcp_protocol::types::completion::CompletionItem;
10
11const DEFAULT_PAGE_SIZE: usize = 50;
12
13/// Resource content provider function type
14pub type ResourceContentProvider = Arc<dyn Fn() -> Result<Vec<ResourceContent>> + Send + Sync>;
15
16/// Template completion provider function type
17pub type TemplateCompletionProvider = Arc<dyn Fn(String, String, Option<String>) -> Result<Vec<CompletionItem>> + Send + Sync>;
18
19/// Template expansion function type
20pub type TemplateExpanderFn = Arc<dyn Fn(String, HashMap<String, String>) -> Result<String> + Send + Sync>;
21
22/// Resource manager for registering and accessing resources
23pub struct ResourceManager {
24    resources: Arc<RwLock<HashMap<String, (Resource, ResourceContentProvider)>>>,
25    templates: Arc<RwLock<HashMap<String, (ResourceTemplate, TemplateExpanderFn)>>>,
26    subscriptions: Arc<RwLock<HashMap<String, HashSet<String>>>>, // Maps resource URI to set of client IDs
27    update_tx: broadcast::Sender<String>, // Channel for notifying resource updates
28    completion_providers: Arc<RwLock<HashMap<String, TemplateCompletionProvider>>>,
29}
30
31impl ResourceManager {
32    /// Create a new resource manager
33    pub fn new() -> Self {
34        let (update_tx, _) = broadcast::channel(100);
35        Self {
36            resources: Arc::new(RwLock::new(HashMap::new())),
37            templates: Arc::new(RwLock::new(HashMap::new())),
38            subscriptions: Arc::new(RwLock::new(HashMap::new())),
39            update_tx,
40            completion_providers: Arc::new(RwLock::new(HashMap::new())),
41        }
42    }
43    
44    /// Register a new resource
45    pub fn register_resource(
46        &self, 
47        resource: Resource, 
48        content_provider: impl Fn() -> Result<Vec<ResourceContent>> + Send + Sync + 'static
49    ) {
50        let resources = self.resources.clone();
51        let content_provider = Arc::new(content_provider);
52        
53        tokio::spawn(async move {
54            let mut resources = resources.write().await;
55            resources.insert(resource.uri.clone(), (resource, content_provider));
56        });
57    }
58    
59    /// Get registered resources with pagination
60    pub async fn list_resources(&self, cursor: Option<String>) -> (Vec<Resource>, Option<String>) {
61        let resources = self.resources.read().await;
62        let all_resources: Vec<Resource> = resources.values().map(|(resource, _)| resource.clone()).collect();
63        
64        // If we have a cursor, find its position
65        let start_pos = match cursor {
66            Some(cursor) => {
67                // Find the index of the resource after the cursor
68                let pos = all_resources.iter().position(|r| r.uri == cursor);
69                pos.map(|p| p + 1).unwrap_or(0)
70            },
71            None => 0,
72        };
73        
74        // Get a page of resources
75        let end_pos = std::cmp::min(start_pos + DEFAULT_PAGE_SIZE, all_resources.len());
76        let page = all_resources[start_pos..end_pos].to_vec();
77        
78        // Set the next cursor if there are more resources
79        let next_cursor = if end_pos < all_resources.len() {
80            Some(all_resources[end_pos - 1].uri.clone())
81        } else {
82            None
83        };
84        
85        (page, next_cursor)
86    }
87    
88    /// Get a specific resource's content
89    pub async fn get_resource_content(&self, uri: &str) -> Result<Vec<ResourceContent>> {
90        // First check if this is a direct resource
91        let resources = self.resources.read().await;
92        if let Some((_, content_provider)) = resources.get(uri) {
93            return content_provider();
94        }
95        
96        // If not a direct resource, check if it matches a template
97        let templates = self.templates.read().await;
98        for (template_uri, (_, _expander)) in templates.iter() {
99            // Check if the URI could be from this template (simple prefix check)
100            // In a real implementation, you'd want a more sophisticated matching algorithm
101            if uri.starts_with(template_uri.split('{').next().unwrap_or("")) {
102                // Try to find a resource provider for the expanded URI
103                if let Some((_, content_provider)) = resources.get(uri) {
104                    return content_provider();
105                }
106            }
107        }
108        
109        Err(anyhow!("Resource not found: {}", uri))
110    }
111    
112    /// Register a template
113    pub fn register_template(
114        &self,
115        template: ResourceTemplate,
116        expander: impl Fn(String, HashMap<String, String>) -> Result<String> + Send + Sync + 'static,
117    ) {
118        let templates = self.templates.clone();
119        let expander = Arc::new(expander);
120        
121        tokio::spawn(async move {
122            let mut templates = templates.write().await;
123            templates.insert(template.uri_template.clone(), (template, expander));
124        });
125    }
126    
127    /// Register a completion provider for a template parameter
128    pub fn register_completion_provider(
129        &self,
130        template_uri: &str,
131        provider: impl Fn(String, String, Option<String>) -> Result<Vec<CompletionItem>> + Send + Sync + 'static,
132    ) {
133        let providers = self.completion_providers.clone();
134        let template_uri = template_uri.to_string();
135        let provider = Arc::new(provider);
136        
137        tokio::spawn(async move {
138            let mut providers = providers.write().await;
139            providers.insert(template_uri, provider);
140        });
141    }
142    
143    /// Get completion items for a template parameter
144    pub async fn get_completions(
145        &self,
146        template_uri: &str,
147        parameter: &str,
148        value: Option<String>,
149    ) -> Result<Vec<CompletionItem>> {
150        let providers = self.completion_providers.read().await;
151        
152        if let Some(provider) = providers.get(template_uri) {
153            return provider(template_uri.to_string(), parameter.to_string(), value);
154        }
155        
156        // Return empty results if no provider is registered
157        Ok(Vec::new())
158    }
159    
160    /// Get all registered templates with pagination
161    pub async fn list_templates(&self, cursor: Option<String>) -> (Vec<ResourceTemplate>, Option<String>) {
162        let templates = self.templates.read().await;
163        let all_templates: Vec<ResourceTemplate> = templates.values().map(|(template, _)| template.clone()).collect();
164        
165        // If we have a cursor, find its position
166        let start_pos = match cursor {
167            Some(cursor) => {
168                // Find the index of the template after the cursor
169                let pos = all_templates.iter().position(|t| t.uri_template == cursor);
170                pos.map(|p| p + 1).unwrap_or(0)
171            },
172            None => 0,
173        };
174        
175        // Get a page of templates
176        let end_pos = std::cmp::min(start_pos + DEFAULT_PAGE_SIZE, all_templates.len());
177        let page = all_templates[start_pos..end_pos].to_vec();
178        
179        // Set the next cursor if there are more templates
180        let next_cursor = if end_pos < all_templates.len() {
181            Some(all_templates[end_pos - 1].uri_template.clone())
182        } else {
183            None
184        };
185        
186        (page, next_cursor)
187    }
188    
189    /// Subscribe to resource updates
190    pub async fn subscribe(&self, client_id: &str, uri: &str) -> Result<()> {
191        // Check if resource exists
192        {
193            let resources = self.resources.read().await;
194            if !resources.contains_key(uri) {
195                return Err(anyhow::anyhow!("Resource not found: {}", uri));
196            }
197        }
198        
199        // Add subscription
200        let mut subscriptions = self.subscriptions.write().await;
201        let subscribers = subscriptions.entry(uri.to_string()).or_insert_with(HashSet::new);
202        subscribers.insert(client_id.to_string());
203        
204        Ok(())
205    }
206    
207    /// Unsubscribe from resource updates
208    pub async fn unsubscribe(&self, client_id: &str, uri: &str) -> Result<()> {
209        let mut subscriptions = self.subscriptions.write().await;
210        if let Some(subscribers) = subscriptions.get_mut(uri) {
211            subscribers.remove(client_id);
212            if subscribers.is_empty() {
213                subscriptions.remove(uri);
214            }
215        }
216        
217        Ok(())
218    }
219    
220    /// Update a resource and notify subscribers
221    pub async fn update_resource(
222        &self, 
223        resource: Resource, 
224        content_provider: impl Fn() -> Result<Vec<ResourceContent>> + Send + Sync + 'static
225    ) -> Result<()> {
226        // Update resource
227        {
228            let mut resources = self.resources.write().await;
229            resources.insert(resource.uri.clone(), (resource.clone(), Arc::new(content_provider)));
230        }
231        
232        // Notify subscribers
233        let _ = self.update_tx.send(resource.uri.clone());
234        
235        Ok(())
236    }
237    
238    /// Get a channel for subscribing to resource updates
239    pub fn subscribe_to_updates(&self) -> broadcast::Receiver<String> {
240        self.update_tx.subscribe()
241    }
242    
243    /// Parse template parameters from a URI
244    /// This is a simple implementation - a production version would need more robust parsing
245    pub fn parse_template_parameters(&self, template: &str, uri: &str) -> HashMap<String, String> {
246        let mut params = HashMap::new();
247        
248        // Extract template parts - this is a very simple implementation
249        // A real implementation would parse RFC 6570 URI templates properly
250        let template_parts: Vec<&str> = template.split('{')
251            .flat_map(|part| part.split('}')).collect();
252        
253        let mut uri_cursor = uri;
254        
255        for (i, part) in template_parts.iter().enumerate() {
256            if i % 2 == 0 {
257                // This is a literal part
258                if uri_cursor.starts_with(part) {
259                    uri_cursor = &uri_cursor[part.len()..];
260                }
261            } else {
262                // This is a parameter name
263                let param_name = *part;
264                
265                // Find the next literal part, if any
266                let next_literal = if i + 1 < template_parts.len() {
267                    template_parts[i + 1]
268                } else {
269                    ""
270                };
271                
272                // Extract the parameter value
273                let param_value = if next_literal.is_empty() {
274                    uri_cursor.to_string()
275                } else if let Some(pos) = uri_cursor.find(next_literal) {
276                    let value = &uri_cursor[..pos];
277                    uri_cursor = &uri_cursor[pos + next_literal.len()..];
278                    value.to_string()
279                } else {
280                    uri_cursor.to_string()
281                };
282                
283                params.insert(param_name.to_string(), param_value);
284            }
285        }
286        
287        params
288    }
289    
290    /// Expand a template with parameters
291    pub async fn expand_template(&self, template_uri: &str, params: HashMap<String, String>) -> Result<String> {
292        let templates = self.templates.read().await;
293        
294        if let Some((_, expander)) = templates.get(template_uri) {
295            return expander(template_uri.to_string(), params);
296        }
297        
298        // Fallback to simple expansion if no custom expander is registered
299        let mut result = template_uri.to_string();
300        
301        for (name, value) in params {
302            result = result.replace(&format!("{{{}}}", name), &value);
303        }
304        
305        Ok(result)
306    }
307}
308
309impl Default for ResourceManager {
310    fn default() -> Self {
311        Self::new()
312    }
313}