mcp_host/
utils.rs

1//! Utility functions for file handling, path security, and encoding
2//!
3//! Provides common utilities needed by MCP servers: file collection with
4//! gitignore support, path security validation, byte offset conversion,
5//! and base64 encoding for pagination cursors.
6
7use std::path::{Path, PathBuf};
8
9/// Collect files recursively from a path, respecting .gitignore
10///
11/// Uses `ignore` crate to respect:
12/// - `.gitignore` files
13/// - `.git/info/exclude`
14/// - Global git ignore configuration
15///
16/// # Arguments
17///
18/// * `path` - File or directory to collect from
19///
20/// # Returns
21///
22/// Vector of file paths (directories are excluded)
23///
24/// # Example
25///
26/// ```rust,ignore
27/// let files = collect_files(Path::new("src"));
28/// for file in files {
29///     println!("Found: {:?}", file);
30/// }
31/// ```
32pub fn collect_files(path: &Path) -> Vec<PathBuf> {
33    use ignore::WalkBuilder;
34
35    let mut files = Vec::new();
36    if path.is_file() {
37        files.push(path.to_path_buf());
38    } else if path.is_dir() {
39        let walker = WalkBuilder::new(path)
40            .standard_filters(true) // enables .gitignore, .git/info/exclude, global config
41            .build();
42
43        for result in walker.flatten() {
44            if result.file_type().is_some_and(|ft| ft.is_file()) {
45                files.push(result.path().to_path_buf());
46            }
47        }
48    }
49    files
50}
51
/// Convert byte offset to line and column numbers
///
/// Useful for converting byte-based spans from parsers to human-readable
/// line/column positions.
///
/// # Arguments
///
/// * `src` - Source text
/// * `byte_idx` - Byte offset into source
///
/// # Returns
///
/// Tuple of (line, column), both 1-indexed. Columns are counted in
/// characters (`char`s), not bytes.
///
/// An offset that points into the middle of a multi-byte UTF-8 character is
/// reported as the position of that character. An offset at or beyond the end
/// of `src` yields the position just past the last character.
///
/// # Example
///
/// ```rust
/// use mcp_host::utils::byte_to_line_col;
///
/// let src = "hello\nworld";
/// let (line, col) = byte_to_line_col(src, 6);
/// assert_eq!((line, col), (2, 1)); // First char of "world"
/// ```
pub fn byte_to_line_col(src: &str, byte_idx: usize) -> (usize, usize) {
    let mut line = 1;
    let mut col = 1;
    for (start, ch) in src.char_indices() {
        // `ch` occupies bytes `start..start + ch.len_utf8()`. Matching on the
        // whole range (rather than `start == byte_idx`) keeps offsets that land
        // inside a multi-byte character from falling through to end-of-text.
        if byte_idx < start + ch.len_utf8() {
            return (line, col);
        }
        if ch == '\n' {
            line += 1;
            col = 1;
        } else {
            col += 1;
        }
    }
    // Offset is at or past the end of the text.
    (line, col)
}
91
/// Check if a path is within a base directory (security boundary)
///
/// Prevents path traversal attacks by ensuring paths don't escape the
/// base directory using `..`, absolute paths, or symlinks.
///
/// # Arguments
///
/// * `path` - Path to validate (relative paths are interpreted against `base_dir`)
/// * `base_dir` - Base directory (security boundary)
///
/// # Returns
///
/// `true` if path is within base_dir, `false` otherwise
///
/// # Security
///
/// This function prevents:
/// - Absolute paths outside base_dir, including ones that lexically begin
///   with base_dir but climb out again via `..`
/// - Relative paths with `..` that escape base_dir
/// - Symlink escapes (via canonicalization), also when the final target does
///   not exist yet but one of its existing ancestors is a symlink
/// - Symlinks that cannot be resolved (dangling or looping), which are rejected
///
/// The answer reflects the filesystem at the time of the call.
///
/// If `base_dir` itself cannot be canonicalized (e.g. it does not exist),
/// absolute paths are rejected and relative paths are judged purely on
/// whether their `..` components would climb above the starting point.
///
/// # Example
///
/// ```rust,ignore
/// let cwd = std::env::current_dir()?;
/// assert!(is_safe_path(Path::new("foo/bar.txt"), &cwd));
/// assert!(!is_safe_path(Path::new("../etc/passwd"), &cwd));
/// ```
pub fn is_safe_path(path: &Path, base_dir: &Path) -> bool {
    use std::path::Component;

    let Ok(base_canon) = base_dir.canonicalize() else {
        // No resolvable boundary to compare against: an absolute path cannot
        // be shown to be inside it, and a relative path can only be checked
        // lexically for `..` components that climb above the start.
        if path.is_absolute() {
            return false;
        }
        let mut depth: i32 = 0;
        for component in path.components() {
            match component {
                Component::ParentDir => depth -= 1,
                Component::Normal(_) => depth += 1,
                _ => {}
            }
            if depth < 0 {
                return false; // Escapes base with ..
            }
        }
        return true;
    };

    let resolved = if path.is_absolute() {
        path.to_path_buf()
    } else {
        base_dir.join(path)
    };

    // Fast path: the whole path exists, so canonicalization resolves every
    // symlink and `..` in one go.
    if let Ok(resolved_canon) = resolved.canonicalize() {
        return resolved_canon.starts_with(&base_canon);
    }

    // Some part of the path does not exist (yet). Walk it component by
    // component, resolving every prefix that does exist so that symlinked
    // ancestors are followed, and applying the rest lexically. A lexical `..`
    // is only ever applied on top of an already-resolved or non-existent
    // prefix, so it cannot be fooled by a symlink.
    let mut current = if path.is_absolute() {
        PathBuf::new()
    } else {
        base_canon.clone()
    };
    for component in path.components() {
        match component {
            Component::Prefix(_) | Component::RootDir => current.push(component.as_os_str()),
            Component::CurDir => {}
            Component::ParentDir => {
                // At the filesystem root `pop` is a no-op, matching `/..` == `/`.
                current.pop();
            }
            Component::Normal(name) => {
                current.push(name);
                match current.canonicalize() {
                    Ok(real) => current = real,
                    Err(_) => {
                        // A symlink that cannot be resolved could point
                        // anywhere once its target appears; refuse it.
                        let unresolved_symlink = current
                            .symlink_metadata()
                            .map(|meta| meta.file_type().is_symlink())
                            .unwrap_or(false);
                        if unresolved_symlink {
                            return false;
                        }
                    }
                }
            }
        }
    }
    current.starts_with(&base_canon)
}
157
158/// Base64 encode bytes
159///
160/// Useful for creating pagination cursors or encoding binary data
161/// for JSON transmission.
162///
163/// # Arguments
164///
165/// * `data` - Bytes to encode
166///
167/// # Returns
168///
169/// Base64-encoded string
170pub fn base64_encode(data: &[u8]) -> String {
171    use base64::{Engine, engine::general_purpose::STANDARD};
172    STANDARD.encode(data)
173}
174
175/// Base64 decode string
176///
177/// Decodes base64-encoded data. Returns empty vector on invalid input.
178///
179/// # Arguments
180///
181/// * `s` - Base64-encoded string
182///
183/// # Returns
184///
185/// Decoded bytes, or empty vector if invalid
186pub fn base64_decode(s: &str) -> Vec<u8> {
187    use base64::{Engine, engine::general_purpose::STANDARD};
188    STANDARD.decode(s).unwrap_or_default()
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-line fixture shared by the `byte_to_line_col` tests.
    const TWO_LINES: &str = "hello\nworld\n";

    // ---------- is_safe_path ----------

    /// Plain relative paths stay inside the working directory.
    #[test]
    fn test_is_safe_path_relative() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["foo", "foo/bar", "./foo"] {
            assert!(is_safe_path(Path::new(candidate), &cwd));
        }
    }

    /// Relative paths that climb above the base via `..` are rejected.
    #[test]
    fn test_is_safe_path_parent_escape() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["../foo", "foo/../../bar", ".."] {
            assert!(!is_safe_path(Path::new(candidate), &cwd));
        }
    }

    /// Absolute paths pointing outside the base are rejected.
    #[test]
    fn test_is_safe_path_absolute_outside() {
        let cwd = std::env::current_dir().unwrap();
        for candidate in ["/tmp", "/etc/passwd", "/home"] {
            assert!(!is_safe_path(Path::new(candidate), &cwd));
        }
    }

    // ---------- base64 helpers ----------

    /// Encoding followed by decoding returns the original text.
    #[test]
    fn test_base64_roundtrip() {
        let original = "42";
        let bytes = base64_decode(&base64_encode(original.as_bytes()));
        let decoded = String::from_utf8(bytes).unwrap();
        assert_eq!(original, decoded);
    }

    /// Numeric offsets survive a stringify -> encode -> decode -> parse cycle.
    #[test]
    fn test_base64_offset_encoding() {
        for offset in [0usize, 10, 100, 12345] {
            let cursor = base64_encode(offset.to_string().as_bytes());
            let text = String::from_utf8(base64_decode(&cursor)).unwrap();
            let decoded: usize = text.parse().unwrap();
            assert_eq!(offset, decoded);
        }
    }

    /// Invalid base64 decodes to an empty vector rather than failing.
    #[test]
    fn test_base64_invalid_input() {
        assert!(base64_decode("!!!invalid!!!").is_empty());
    }

    // ---------- byte_to_line_col ----------

    /// Offset 0 is the first column of the first line.
    #[test]
    fn test_byte_to_line_col_start() {
        assert_eq!(byte_to_line_col(TWO_LINES, 0), (1, 1));
    }

    /// Offset 2 is the first 'l' of "hello".
    #[test]
    fn test_byte_to_line_col_middle_first_line() {
        assert_eq!(byte_to_line_col(TWO_LINES, 2), (1, 3));
    }

    /// Offset 6 is the 'w' that starts line 2.
    #[test]
    fn test_byte_to_line_col_second_line() {
        assert_eq!(byte_to_line_col(TWO_LINES, 6), (2, 1));
    }

    /// Offset 11 is the newline that terminates "world".
    #[test]
    fn test_byte_to_line_col_end() {
        assert_eq!(byte_to_line_col(TWO_LINES, 11), (2, 6));
    }

    /// Offsets past the end report the position just after the last char.
    #[test]
    fn test_byte_to_line_col_beyond_end() {
        assert_eq!(byte_to_line_col("hi", 100), (1, 3));
    }
}
286        assert_eq!((line, col), (1, 3));
287    }
288}