mcp_host/utils.rs
1//! Utility functions for file handling, path security, and encoding
2//!
3//! Provides common utilities needed by MCP servers: file collection with
4//! gitignore support, path security validation, byte offset conversion,
5//! and base64 encoding for pagination cursors.
6
7use std::path::{Path, PathBuf};
8
9/// Collect files recursively from a path, respecting .gitignore
10///
11/// Uses `ignore` crate to respect:
12/// - `.gitignore` files
13/// - `.git/info/exclude`
14/// - Global git ignore configuration
15///
16/// # Arguments
17///
18/// * `path` - File or directory to collect from
19///
20/// # Returns
21///
22/// Vector of file paths (directories are excluded)
23///
24/// # Example
25///
26/// ```rust,ignore
27/// let files = collect_files(Path::new("src"));
28/// for file in files {
29/// println!("Found: {:?}", file);
30/// }
31/// ```
32pub fn collect_files(path: &Path) -> Vec<PathBuf> {
33 use ignore::WalkBuilder;
34
35 let mut files = Vec::new();
36 if path.is_file() {
37 files.push(path.to_path_buf());
38 } else if path.is_dir() {
39 let walker = WalkBuilder::new(path)
40 .standard_filters(true) // enables .gitignore, .git/info/exclude, global config
41 .build();
42
43 for result in walker.flatten() {
44 if result.file_type().is_some_and(|ft| ft.is_file()) {
45 files.push(result.path().to_path_buf());
46 }
47 }
48 }
49 files
50}
51
/// Convert byte offset to line and column numbers
///
/// Useful for converting byte-based spans from parsers to human-readable
/// line/column positions.
///
/// Lines are separated by `'\n'`. Columns count characters (Unicode scalar
/// values), not bytes, so a multi-byte character advances the column by one.
///
/// # Arguments
///
/// * `src` - Source text
/// * `byte_idx` - Byte offset into source
///
/// # Returns
///
/// Tuple of (line, column), both 1-indexed.
///
/// * An offset that falls inside a multi-byte character resolves to the
///   position of that character.
/// * An offset at or beyond the end of `src` resolves to the position just
///   past the last character.
///
/// # Example
///
/// ```rust
/// use mcp_host::utils::byte_to_line_col;
///
/// let src = "hello\nworld";
/// let (line, col) = byte_to_line_col(src, 6);
/// assert_eq!((line, col), (2, 1)); // First char of "world"
/// ```
pub fn byte_to_line_col(src: &str, byte_idx: usize) -> (usize, usize) {
    let mut line = 1;
    let mut col = 1;
    for (start, ch) in src.char_indices() {
        // `ch` occupies bytes `start..start + ch.len_utf8()`. Comparing against
        // the end of that range (instead of `start == byte_idx`) also catches
        // offsets that land in the middle of a multi-byte character, which
        // would otherwise never match and be reported as end-of-input.
        if byte_idx < start + ch.len_utf8() {
            return (line, col);
        }
        if ch == '\n' {
            line += 1;
            col = 1;
        } else {
            col += 1;
        }
    }
    // Offset is at or past the end of the text.
    (line, col)
}
91
/// Check if a path is within a base directory (security boundary)
///
/// Prevents path traversal attacks by ensuring paths don't escape the
/// base directory using `..`, absolute paths, or symlinks.
///
/// # Arguments
///
/// * `path` - Path to validate; relative paths are interpreted relative to
///   `base_dir`
/// * `base_dir` - Base directory (security boundary)
///
/// # Returns
///
/// `true` if path is within base_dir, `false` otherwise
///
/// # Security
///
/// This function prevents:
/// - Absolute paths outside base_dir
/// - Paths (relative or absolute) with `..` that end up outside base_dir
/// - Symlink escapes (via canonicalization)
///
/// The target itself does not have to exist. The deepest ancestor that can be
/// canonicalized is resolved through the filesystem (following symlinks), and
/// the remaining, not-yet-existing components are applied on top of it. A
/// symlink among those remaining components cannot be resolved (for example a
/// dangling link) and makes the path unsafe.
///
/// If `base_dir` itself cannot be canonicalized, absolute paths are rejected
/// and relative paths are judged purely by their components: any `..` that
/// climbs above the starting point, or any root/prefix component, is rejected.
///
/// # Example
///
/// ```rust,ignore
/// let cwd = std::env::current_dir()?;
/// assert!(is_safe_path(Path::new("foo/bar.txt"), &cwd));
/// assert!(!is_safe_path(Path::new("../etc/passwd"), &cwd));
/// ```
pub fn is_safe_path(path: &Path, base_dir: &Path) -> bool {
    use std::path::Component;

    let Ok(base_canon) = base_dir.canonicalize() else {
        // Without a resolvable base nothing can be checked against the
        // filesystem. Absolute paths are rejected outright; relative paths get
        // a conservative component-only check.
        if path.is_absolute() {
            return false;
        }
        let mut depth: usize = 0;
        for component in path.components() {
            match component {
                Component::Normal(_) => depth += 1,
                Component::CurDir => {}
                Component::ParentDir => {
                    if depth == 0 {
                        return false; // Escapes base with ..
                    }
                    depth -= 1;
                }
                // A root or prefix on a non-absolute path (e.g. `\foo` or
                // `C:foo` on Windows) would replace the base when joined.
                Component::RootDir | Component::Prefix(_) => return false,
            }
        }
        return true;
    };

    let resolved = if path.is_absolute() {
        path.to_path_buf()
    } else {
        base_dir.join(path)
    };

    // Find the deepest ancestor (starting with the path itself) that exists and
    // canonicalize it; this resolves symlinks and `..` in the existing part.
    // `remainder` holds the trailing components that do not exist yet.
    let located = resolved.ancestors().find_map(|ancestor| {
        let canon = ancestor.canonicalize().ok()?;
        let rest = resolved.strip_prefix(ancestor).ok()?;
        Some((canon, rest))
    });
    let Some((anchor, remainder)) = located else {
        return false;
    };

    // Apply the non-existent tail lexically on top of the canonical anchor.
    // The anchor contains no symlinks, so popping a component really is its
    // parent directory.
    let mut candidate: PathBuf = anchor;
    for component in remainder.components() {
        match component {
            Component::Normal(name) => {
                candidate.push(name);
                // Canonicalization failed at or before this component, so a
                // symlink here has a target we cannot verify.
                if candidate.is_symlink() {
                    return false;
                }
            }
            Component::ParentDir => {
                candidate.pop();
            }
            Component::CurDir => {}
            Component::RootDir | Component::Prefix(_) => return false,
        }
    }

    candidate.starts_with(&base_canon)
}
157
158/// Base64 encode bytes
159///
160/// Useful for creating pagination cursors or encoding binary data
161/// for JSON transmission.
162///
163/// # Arguments
164///
165/// * `data` - Bytes to encode
166///
167/// # Returns
168///
169/// Base64-encoded string
170pub fn base64_encode(data: &[u8]) -> String {
171 use base64::{Engine, engine::general_purpose::STANDARD};
172 STANDARD.encode(data)
173}
174
175/// Base64 decode string
176///
177/// Decodes base64-encoded data. Returns empty vector on invalid input.
178///
179/// # Arguments
180///
181/// * `s` - Base64-encoded string
182///
183/// # Returns
184///
185/// Decoded bytes, or empty vector if invalid
186pub fn base64_decode(s: &str) -> Vec<u8> {
187 use base64::{Engine, engine::general_purpose::STANDARD};
188 STANDARD.decode(s).unwrap_or_default()
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    // ========== is_safe_path tests ==========

    /// Relative paths that stay below the base directory are accepted.
    #[test]
    fn test_is_safe_path_relative() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["foo", "foo/bar", "./foo"] {
            assert!(is_safe_path(Path::new(candidate), &cwd));
        }
    }

    /// Relative paths that climb above the base directory are rejected.
    #[test]
    fn test_is_safe_path_parent_escape() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["../foo", "foo/../../bar", ".."] {
            assert!(!is_safe_path(Path::new(candidate), &cwd));
        }
    }

    /// Absolute paths outside the base directory are rejected.
    #[test]
    fn test_is_safe_path_absolute_outside() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["/tmp", "/etc/passwd", "/home"] {
            assert!(!is_safe_path(Path::new(candidate), &cwd));
        }
    }

    // ========== Base64 encoding tests ==========

    /// Encoding followed by decoding yields the original text.
    #[test]
    fn test_base64_roundtrip() {
        let original = "42";
        let decoded_bytes = base64_decode(&base64_encode(original.as_bytes()));
        assert_eq!(original, String::from_utf8(decoded_bytes).unwrap());
    }

    /// Numeric offsets survive being stringified, encoded and decoded again.
    #[test]
    fn test_base64_offset_encoding() {
        let offsets = [0, 10, 100, 12345];
        for offset in offsets {
            let cursor = base64_encode(offset.to_string().as_bytes());
            let text = String::from_utf8(base64_decode(&cursor)).unwrap();
            let decoded: usize = text.parse().unwrap();
            assert_eq!(offset, decoded);
        }
    }

    /// Invalid base64 should return empty vec.
    #[test]
    fn test_base64_invalid_input() {
        assert!(base64_decode("!!!invalid!!!").is_empty());
    }

    // ========== byte_to_line_col tests ==========

    /// Two-line sample text shared by the positional tests below.
    const SAMPLE: &str = "hello\nworld\n";

    #[test]
    fn test_byte_to_line_col_start() {
        assert_eq!(byte_to_line_col(SAMPLE, 0), (1, 1));
    }

    #[test]
    fn test_byte_to_line_col_middle_first_line() {
        // 'l' at position 2
        assert_eq!(byte_to_line_col(SAMPLE, 2), (1, 3));
    }

    #[test]
    fn test_byte_to_line_col_second_line() {
        // 'w' at start of line 2
        assert_eq!(byte_to_line_col(SAMPLE, 6), (2, 1));
    }

    #[test]
    fn test_byte_to_line_col_end() {
        // newline at end of 'world'
        assert_eq!(byte_to_line_col(SAMPLE, 11), (2, 6));
    }

    #[test]
    fn test_byte_to_line_col_beyond_end() {
        // Should return last position
        assert_eq!(byte_to_line_col("hi", 100), (1, 3));
    }
}