mcp_protocol_sdk/utils/
uri.rs

1//! URI handling utilities
2//!
3//! This module provides utilities for parsing, validating, and manipulating URIs
4//! used in the MCP protocol for resources and other operations.
5
6use crate::core::error::{McpError, McpResult};
7use std::collections::HashMap;
8use url::Url;
9
10/// Parse a URI and extract query parameters
11pub fn parse_uri_with_params(uri: &str) -> McpResult<(String, HashMap<String, String>)> {
12    if uri.starts_with("file://") || uri.contains("://") {
13        // Full URI
14        let parsed = Url::parse(uri)
15            .map_err(|e| McpError::InvalidUri(format!("Invalid URI '{}': {}", uri, e)))?;
16
17        let base_uri = format!(
18            "{}://{}{}",
19            parsed.scheme(),
20            parsed.host_str().unwrap_or(""),
21            parsed.path()
22        );
23
24        let mut params = HashMap::new();
25        for (key, value) in parsed.query_pairs() {
26            params.insert(key.to_string(), value.to_string());
27        }
28
29        Ok((base_uri, params))
30    } else if uri.starts_with('/') {
31        // Absolute path
32        if let Some((path, query)) = uri.split_once('?') {
33            let params = parse_query_string(query)?;
34            Ok((path.to_string(), params))
35        } else {
36            Ok((uri.to_string(), HashMap::new()))
37        }
38    } else {
39        // Relative path or simple identifier
40        if let Some((path, query)) = uri.split_once('?') {
41            let params = parse_query_string(query)?;
42            Ok((path.to_string(), params))
43        } else {
44            Ok((uri.to_string(), HashMap::new()))
45        }
46    }
47}
48
49/// Parse a query string into parameters
50pub fn parse_query_string(query: &str) -> McpResult<HashMap<String, String>> {
51    let mut params = HashMap::new();
52
53    for pair in query.split('&') {
54        if pair.is_empty() {
55            continue;
56        }
57
58        if let Some((key, value)) = pair.split_once('=') {
59            let decoded_key = percent_decode(key)?;
60            let decoded_value = percent_decode(value)?;
61            params.insert(decoded_key, decoded_value);
62        } else {
63            let decoded_key = percent_decode(pair)?;
64            params.insert(decoded_key, String::new());
65        }
66    }
67
68    Ok(params)
69}
70
71/// Simple percent decoding for URI components
72pub fn percent_decode(s: &str) -> McpResult<String> {
73    let mut result = String::new();
74    let mut chars = s.chars().peekable();
75
76    while let Some(ch) = chars.next() {
77        if ch == '%' {
78            let hex1 = chars
79                .next()
80                .ok_or_else(|| McpError::InvalidUri("Incomplete percent encoding".to_string()))?;
81            let hex2 = chars
82                .next()
83                .ok_or_else(|| McpError::InvalidUri("Incomplete percent encoding".to_string()))?;
84
85            let hex_str = format!("{}{}", hex1, hex2);
86            let byte = u8::from_str_radix(&hex_str, 16).map_err(|_| {
87                McpError::InvalidUri(format!("Invalid hex in percent encoding: {}", hex_str))
88            })?;
89
90            result.push(byte as char);
91        } else if ch == '+' {
92            result.push(' ');
93        } else {
94            result.push(ch);
95        }
96    }
97
98    Ok(result)
99}
100
/// Percent-encode a string for use as a URI component (form-encoding flavour).
///
/// The input is processed byte by byte over its UTF-8 representation:
///
/// * ASCII letters, digits and `-`, `_`, `.`, `~` are emitted unchanged;
/// * a space is emitted as `+`;
/// * every other byte is emitted as `%XX` with upper-case hexadecimal digits.
pub fn percent_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());

    for &octet in s.as_bytes() {
        let is_unreserved =
            octet.is_ascii_alphanumeric() || matches!(octet, b'-' | b'_' | b'.' | b'~');

        if is_unreserved {
            encoded.push(char::from(octet));
        } else if octet == b' ' {
            encoded.push('+');
        } else {
            // High nibble first, then low nibble.
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(octet >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(octet & 0x0F)]));
        }
    }

    encoded
}
121
122/// Validate that a string is a valid URI
123pub fn validate_uri(uri: &str) -> McpResult<()> {
124    if uri.is_empty() {
125        return Err(McpError::InvalidUri("URI cannot be empty".to_string()));
126    }
127
128    // Check for basic URI patterns
129    if uri.contains("://") {
130        // Full URI - try to parse with url crate
131        Url::parse(uri)
132            .map_err(|e| McpError::InvalidUri(format!("Invalid URI '{}': {}", uri, e)))?;
133    } else if uri.starts_with('/') {
134        // Absolute path - basic validation
135        if uri.contains('\0') || uri.contains('\n') || uri.contains('\r') {
136            return Err(McpError::InvalidUri(
137                "URI contains invalid characters".to_string(),
138            ));
139        }
140    } else {
141        // Relative path or identifier - allow most characters
142        if uri.contains('\0') || uri.contains('\n') || uri.contains('\r') {
143            return Err(McpError::InvalidUri(
144                "URI contains invalid characters".to_string(),
145            ));
146        }
147    }
148
149    Ok(())
150}
151
152/// Normalize a URI to a standard form
153pub fn normalize_uri(uri: &str) -> McpResult<String> {
154    validate_uri(uri)?;
155
156    if uri.contains("://") {
157        // Full URI - normalize with url crate
158        let parsed = Url::parse(uri)
159            .map_err(|e| McpError::InvalidUri(format!("Invalid URI '{}': {}", uri, e)))?;
160        let mut normalized = parsed.to_string();
161
162        // Remove duplicate slashes in path
163        if let Ok(mut url) = Url::parse(&normalized) {
164            let path = url.path();
165            let clean_path = path.replace("//", "/");
166            url.set_path(&clean_path);
167            normalized = url.to_string();
168        }
169
170        // Remove trailing slash unless it's the root
171        if normalized.ends_with('/') && !normalized.ends_with("://") {
172            let path_start = normalized.find("://").unwrap() + 3;
173            if let Some(path_start_slash) = normalized[path_start..].find('/') {
174                let full_path_start = path_start + path_start_slash;
175                if full_path_start + 1 < normalized.len() {
176                    normalized.pop();
177                }
178            }
179        }
180
181        Ok(normalized)
182    } else {
183        // Path - basic normalization
184        let mut normalized = uri.to_string();
185
186        // Remove duplicate slashes
187        while normalized.contains("//") {
188            normalized = normalized.replace("//", "/");
189        }
190
191        // Remove trailing slash unless it's the root
192        if normalized.len() > 1 && normalized.ends_with('/') {
193            normalized.pop();
194        }
195
196        Ok(normalized)
197    }
198}
199
200/// Join a base URI with a relative path
201pub fn join_uri(base: &str, relative: &str) -> McpResult<String> {
202    if relative.contains("://") {
203        // Relative is actually absolute
204        return Ok(relative.to_string());
205    }
206
207    if relative.starts_with('/') {
208        // Relative path is absolute, return it as-is
209        return Ok(relative.to_string());
210    }
211
212    if base.contains("://") {
213        // Full URI base
214        let base_url = Url::parse(base)
215            .map_err(|e| McpError::InvalidUri(format!("Invalid base URI '{}': {}", base, e)))?;
216        let joined = base_url.join(relative).map_err(|e| {
217            McpError::InvalidUri(format!("Cannot join '{}' to '{}': {}", relative, base, e))
218        })?;
219        Ok(joined.to_string())
220    } else {
221        // Path base
222        let mut result = base.to_string();
223        if !result.ends_with('/') && !relative.starts_with('/') {
224            result.push('/');
225        }
226        result.push_str(relative);
227        normalize_uri(&result)
228    }
229}
230
231/// Extract the file extension from a URI
232pub fn get_uri_extension(uri: &str) -> Option<String> {
233    let path = if uri.contains("://") {
234        Url::parse(uri).ok()?.path().to_string()
235    } else {
236        uri.to_string()
237    };
238
239    if let Some(dot_pos) = path.rfind('.') {
240        if let Some(slash_pos) = path.rfind('/') {
241            if dot_pos > slash_pos {
242                return Some(path[dot_pos + 1..].to_lowercase());
243            }
244        } else {
245            return Some(path[dot_pos + 1..].to_lowercase());
246        }
247    }
248
249    None
250}
251
252/// Guess MIME type from URI extension
253pub fn guess_mime_type(uri: &str) -> Option<String> {
254    match get_uri_extension(uri)?.as_str() {
255        "txt" => Some("text/plain".to_string()),
256        "html" | "htm" => Some("text/html".to_string()),
257        "css" => Some("text/css".to_string()),
258        "js" => Some("application/javascript".to_string()),
259        "json" => Some("application/json".to_string()),
260        "xml" => Some("application/xml".to_string()),
261        "pdf" => Some("application/pdf".to_string()),
262        "zip" => Some("application/zip".to_string()),
263        "png" => Some("image/png".to_string()),
264        "jpg" | "jpeg" => Some("image/jpeg".to_string()),
265        "gif" => Some("image/gif".to_string()),
266        "webp" => Some("image/webp".to_string()),
267        "svg" => Some("image/svg+xml".to_string()),
268        "mp3" => Some("audio/mpeg".to_string()),
269        "wav" => Some("audio/wav".to_string()),
270        "mp4" => Some("video/mp4".to_string()),
271        "webm" => Some("video/webm".to_string()),
272        "csv" => Some("text/csv".to_string()),
273        "md" => Some("text/markdown".to_string()),
274        "yaml" | "yml" => Some("application/x-yaml".to_string()),
275        "toml" => Some("application/toml".to_string()),
276        _ => None,
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// A full URL is split into its base and its decoded query parameters.
    #[test]
    fn test_parse_uri_with_params() {
        let (base, query) =
            parse_uri_with_params("https://example.com/path?key=value&foo=bar").unwrap();

        assert_eq!(base, "https://example.com/path");
        for (name, expected) in [("key", "value"), ("foo", "bar")].iter() {
            assert_eq!(query.get(*name), Some(&expected.to_string()));
        }
    }

    /// Key/value pairs are decoded, including a key with an empty value.
    #[test]
    fn test_parse_query_string() {
        let query = parse_query_string("key=value&foo=bar&empty=").unwrap();

        for (name, expected) in [("key", "value"), ("foo", "bar"), ("empty", "")].iter() {
            assert_eq!(query.get(*name), Some(&expected.to_string()));
        }
    }

    /// Encoding followed by decoding returns the original text.
    #[test]
    fn test_percent_encode_decode() {
        let original = "hello world!@#$%";
        let round_tripped = percent_decode(&percent_encode(original)).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// URLs, absolute paths and relative paths pass; empty or NUL-containing input fails.
    #[test]
    fn test_validate_uri() {
        for accepted in ["https://example.com", "/absolute/path", "relative/path"].iter() {
            assert!(validate_uri(accepted).is_ok());
        }
        for rejected in ["", "invalid\0uri"].iter() {
            assert!(validate_uri(rejected).is_err());
        }
    }

    /// Duplicate and trailing slashes are removed; the root path is kept.
    #[test]
    fn test_normalize_uri() {
        let cases = [
            ("https://example.com//path//", "https://example.com/path"),
            ("/path//to//file/", "/path/to/file"),
            ("/", "/"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(normalize_uri(input).unwrap(), *expected);
        }
    }

    /// Joining works for URL bases, path bases and absolute relative paths.
    #[test]
    fn test_join_uri() {
        let cases = [
            (
                "https://example.com",
                "path/to/file",
                "https://example.com/path/to/file",
            ),
            ("/base", "relative/path", "/base/relative/path"),
            ("/base/", "/absolute", "/absolute"),
        ];
        for (base, relative, expected) in cases.iter() {
            assert_eq!(join_uri(base, relative).unwrap(), *expected);
        }
    }

    /// The extension is the lower-cased text after the last dot of the file name.
    #[test]
    fn test_get_uri_extension() {
        let cases = [
            ("file.txt", Some("txt")),
            ("https://example.com/file.JSON", Some("json")),
            ("/path/to/file.tar.gz", Some("gz")),
            ("no-extension", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(get_uri_extension(input), expected.map(str::to_string));
        }
    }

    /// Known extensions map to their MIME type regardless of case; unknown ones to `None`.
    #[test]
    fn test_guess_mime_type() {
        let cases = [
            ("file.json", Some("application/json")),
            ("image.PNG", Some("image/png")),
            ("document.pdf", Some("application/pdf")),
            ("unknown.xyz", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(guess_mime_type(input), expected.map(str::to_string));
        }
    }
}
368}