mcp_protocol_sdk/utils/
uri.rs1use crate::core::error::{McpError, McpResult};
7use std::collections::HashMap;
8use url::Url;
9
10pub fn parse_uri_with_params(uri: &str) -> McpResult<(String, HashMap<String, String>)> {
12 if uri.starts_with("file://") || uri.contains("://") {
13 let parsed = Url::parse(uri)
15 .map_err(|e| McpError::InvalidUri(format!("Invalid URI '{uri}': {e}")))?;
16
17 let base_uri = format!(
18 "{}://{}{}",
19 parsed.scheme(),
20 parsed.host_str().unwrap_or(""),
21 parsed.path()
22 );
23
24 let mut params = HashMap::new();
25 for (key, value) in parsed.query_pairs() {
26 params.insert(key.to_string(), value.to_string());
27 }
28
29 Ok((base_uri, params))
30 } else if uri.starts_with('/') {
31 if let Some((path, query)) = uri.split_once('?') {
33 let params = parse_query_string(query)?;
34 Ok((path.to_string(), params))
35 } else {
36 Ok((uri.to_string(), HashMap::new()))
37 }
38 } else {
39 if let Some((path, query)) = uri.split_once('?') {
41 let params = parse_query_string(query)?;
42 Ok((path.to_string(), params))
43 } else {
44 Ok((uri.to_string(), HashMap::new()))
45 }
46 }
47}
48
49pub fn parse_query_string(query: &str) -> McpResult<HashMap<String, String>> {
51 let mut params = HashMap::new();
52
53 for pair in query.split('&') {
54 if pair.is_empty() {
55 continue;
56 }
57
58 if let Some((key, value)) = pair.split_once('=') {
59 let decoded_key = percent_decode(key)?;
60 let decoded_value = percent_decode(value)?;
61 params.insert(decoded_key, decoded_value);
62 } else {
63 let decoded_key = percent_decode(pair)?;
64 params.insert(decoded_key, String::new());
65 }
66 }
67
68 Ok(params)
69}
70
71pub fn percent_decode(s: &str) -> McpResult<String> {
73 let mut result = String::new();
74 let mut chars = s.chars().peekable();
75
76 while let Some(ch) = chars.next() {
77 if ch == '%' {
78 let hex1 = chars
79 .next()
80 .ok_or_else(|| McpError::InvalidUri("Incomplete percent encoding".to_string()))?;
81 let hex2 = chars
82 .next()
83 .ok_or_else(|| McpError::InvalidUri("Incomplete percent encoding".to_string()))?;
84
85 let hex_str = format!("{hex1}{hex2}");
86 let byte = u8::from_str_radix(&hex_str, 16).map_err(|_| {
87 McpError::InvalidUri(format!("Invalid hex in percent encoding: {hex_str}"))
88 })?;
89
90 result.push(byte as char);
91 } else if ch == '+' {
92 result.push(' ');
93 } else {
94 result.push(ch);
95 }
96 }
97
98 Ok(result)
99}
100
/// Percent-encodes `s` byte by byte over its UTF-8 representation.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are emitted unchanged, a space
/// becomes `+`, and every other byte is written as `%XX` with upper-case hex
/// digits.
pub fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for b in s.bytes() {
        let passes_through = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if passes_through {
            encoded.push(char::from(b));
        } else if b == b' ' {
            encoded.push('+');
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    }
    encoded
}
121
122pub fn validate_uri(uri: &str) -> McpResult<()> {
124 if uri.is_empty() {
125 return Err(McpError::InvalidUri("URI cannot be empty".to_string()));
126 }
127
128 if uri.contains("://") {
130 Url::parse(uri).map_err(|e| McpError::InvalidUri(format!("Invalid URI '{uri}': {e}")))?;
132 } else if uri.starts_with('/') {
133 if uri.contains('\0') || uri.contains('\n') || uri.contains('\r') {
135 return Err(McpError::InvalidUri(
136 "URI contains invalid characters".to_string(),
137 ));
138 }
139 } else {
140 if uri.contains('\0') || uri.contains('\n') || uri.contains('\r') {
142 return Err(McpError::InvalidUri(
143 "URI contains invalid characters".to_string(),
144 ));
145 }
146 }
147
148 Ok(())
149}
150
151pub fn normalize_uri(uri: &str) -> McpResult<String> {
153 validate_uri(uri)?;
154
155 if uri.contains("://") {
156 let parsed = Url::parse(uri)
158 .map_err(|e| McpError::InvalidUri(format!("Invalid URI '{uri}': {e}")))?;
159 let mut normalized = parsed.to_string();
160
161 if let Ok(mut url) = Url::parse(&normalized) {
163 let path = url.path();
164 let clean_path = path.replace("//", "/");
165 url.set_path(&clean_path);
166 normalized = url.to_string();
167 }
168
169 if normalized.ends_with('/') && !normalized.ends_with("://") {
171 let path_start = normalized.find("://").unwrap() + 3;
172 if let Some(path_start_slash) = normalized[path_start..].find('/') {
173 let full_path_start = path_start + path_start_slash;
174 if full_path_start + 1 < normalized.len() {
175 normalized.pop();
176 }
177 }
178 }
179
180 Ok(normalized)
181 } else {
182 let mut normalized = uri.to_string();
184
185 while normalized.contains("//") {
187 normalized = normalized.replace("//", "/");
188 }
189
190 if normalized.len() > 1 && normalized.ends_with('/') {
192 normalized.pop();
193 }
194
195 Ok(normalized)
196 }
197}
198
199pub fn join_uri(base: &str, relative: &str) -> McpResult<String> {
201 if relative.contains("://") {
202 return Ok(relative.to_string());
204 }
205
206 if relative.starts_with('/') {
207 return Ok(relative.to_string());
209 }
210
211 if base.contains("://") {
212 let base_url = Url::parse(base)
214 .map_err(|e| McpError::InvalidUri(format!("Invalid base URI '{base}': {e}")))?;
215 let joined = base_url.join(relative).map_err(|e| {
216 McpError::InvalidUri(format!("Cannot join '{relative}' to '{base}': {e}"))
217 })?;
218 Ok(joined.to_string())
219 } else {
220 let mut result = base.to_string();
222 if !result.ends_with('/') && !relative.starts_with('/') {
223 result.push('/');
224 }
225 result.push_str(relative);
226 normalize_uri(&result)
227 }
228}
229
230pub fn get_uri_extension(uri: &str) -> Option<String> {
232 let path = if uri.contains("://") {
233 Url::parse(uri).ok()?.path().to_string()
234 } else {
235 uri.to_string()
236 };
237
238 if let Some(dot_pos) = path.rfind('.') {
239 if let Some(slash_pos) = path.rfind('/') {
240 if dot_pos > slash_pos {
241 return Some(path[dot_pos + 1..].to_lowercase());
242 }
243 } else {
244 return Some(path[dot_pos + 1..].to_lowercase());
245 }
246 }
247
248 None
249}
250
251pub fn guess_mime_type(uri: &str) -> Option<String> {
253 match get_uri_extension(uri)?.as_str() {
254 "txt" => Some("text/plain".to_string()),
255 "html" | "htm" => Some("text/html".to_string()),
256 "css" => Some("text/css".to_string()),
257 "js" => Some("application/javascript".to_string()),
258 "json" => Some("application/json".to_string()),
259 "xml" => Some("application/xml".to_string()),
260 "pdf" => Some("application/pdf".to_string()),
261 "zip" => Some("application/zip".to_string()),
262 "png" => Some("image/png".to_string()),
263 "jpg" | "jpeg" => Some("image/jpeg".to_string()),
264 "gif" => Some("image/gif".to_string()),
265 "webp" => Some("image/webp".to_string()),
266 "svg" => Some("image/svg+xml".to_string()),
267 "mp3" => Some("audio/mpeg".to_string()),
268 "wav" => Some("audio/wav".to_string()),
269 "mp4" => Some("video/mp4".to_string()),
270 "webm" => Some("video/webm".to_string()),
271 "csv" => Some("text/csv".to_string()),
272 "md" => Some("text/markdown".to_string()),
273 "yaml" | "yml" => Some("application/x-yaml".to_string()),
274 "toml" => Some("application/toml".to_string()),
275 _ => None,
276 }
277}
278
#[cfg(test)]
mod tests {
    //! Unit tests for the URI helpers in this module.

    use super::*;

    #[test]
    fn test_parse_uri_with_params() {
        let (uri, params) =
            parse_uri_with_params("https://example.com/path?key=value&foo=bar").unwrap();
        assert_eq!(uri, "https://example.com/path");
        assert_eq!(params.get("key"), Some(&"value".to_string()));
        assert_eq!(params.get("foo"), Some(&"bar".to_string()));
    }

    #[test]
    fn test_parse_query_string() {
        let params = parse_query_string("key=value&foo=bar&empty=").unwrap();
        assert_eq!(params.get("key"), Some(&"value".to_string()));
        assert_eq!(params.get("foo"), Some(&"bar".to_string()));
        assert_eq!(params.get("empty"), Some(&"".to_string()));
    }

    #[test]
    fn test_parse_query_string_edge_cases() {
        assert!(parse_query_string("").unwrap().is_empty());

        // Bare keys get an empty value; empty segments (`&&`) are skipped.
        let params = parse_query_string("flag&&k=v").unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("flag"), Some(&String::new()));
        assert_eq!(params.get("k"), Some(&"v".to_string()));
    }

    #[test]
    fn test_percent_encode_decode() {
        let original = "hello world!@#$%";
        let encoded = percent_encode(original);
        let decoded = percent_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_percent_encode_reserved_characters() {
        assert_eq!(percent_encode("a b+c"), "a+b%2Bc");
        assert_eq!(percent_encode("AZaz09-_.~"), "AZaz09-_.~");
    }

    #[test]
    fn test_percent_decode_rejects_malformed_escapes() {
        assert!(percent_decode("%").is_err());
        assert!(percent_decode("%4").is_err());
        assert!(percent_decode("%zz").is_err());
        assert_eq!(percent_decode("a+b").unwrap(), "a b");
    }

    #[test]
    fn test_validate_uri() {
        assert!(validate_uri("https://example.com").is_ok());
        assert!(validate_uri("/absolute/path").is_ok());
        assert!(validate_uri("relative/path").is_ok());
        assert!(validate_uri("").is_err());
        assert!(validate_uri("invalid\0uri").is_err());
    }

    #[test]
    fn test_validate_uri_rejects_control_chars_in_paths() {
        assert!(validate_uri("/abs\rpath").is_err());
        assert!(validate_uri("relative\npath").is_err());
    }

    #[test]
    fn test_normalize_uri() {
        assert_eq!(
            normalize_uri("https://example.com//path//").unwrap(),
            "https://example.com/path"
        );
        assert_eq!(normalize_uri("/path//to//file/").unwrap(), "/path/to/file");
        assert_eq!(normalize_uri("/").unwrap(), "/");
        assert_eq!(normalize_uri("relative//path/").unwrap(), "relative/path");
    }

    #[test]
    fn test_join_uri() {
        assert_eq!(
            join_uri("https://example.com", "path/to/file").unwrap(),
            "https://example.com/path/to/file"
        );
        assert_eq!(
            join_uri("/base", "relative/path").unwrap(),
            "/base/relative/path"
        );
        assert_eq!(join_uri("/base/", "/absolute").unwrap(), "/absolute");
        assert_eq!(
            join_uri("/base", "https://other.com/x").unwrap(),
            "https://other.com/x"
        );
    }

    #[test]
    fn test_get_uri_extension() {
        assert_eq!(get_uri_extension("file.txt"), Some("txt".to_string()));
        assert_eq!(
            get_uri_extension("https://example.com/file.JSON"),
            Some("json".to_string())
        );
        assert_eq!(
            get_uri_extension("/path/to/file.tar.gz"),
            Some("gz".to_string())
        );
        assert_eq!(get_uri_extension("no-extension"), None);
        // A dot in a directory name must not be mistaken for an extension.
        assert_eq!(get_uri_extension("/dir.d/file"), None);
    }

    #[test]
    fn test_guess_mime_type() {
        assert_eq!(
            guess_mime_type("file.json"),
            Some("application/json".to_string())
        );
        assert_eq!(guess_mime_type("image.PNG"), Some("image/png".to_string()));
        assert_eq!(
            guess_mime_type("document.pdf"),
            Some("application/pdf".to_string())
        );
        assert_eq!(guess_mime_type("unknown.xyz"), None);
        assert_eq!(guess_mime_type("README"), None);
    }
}