mcp_host/
utils.rs

1//! Utility functions for file handling, path security, and encoding
2//!
3//! Provides common utilities needed by MCP servers: file collection with
4//! gitignore support, path security validation, byte offset conversion,
5//! and base64 encoding for pagination cursors.
6
7use std::path::{Path, PathBuf};
8
9/// Collect files recursively from a path, respecting .gitignore
10///
11/// Uses `ignore` crate to respect:
12/// - `.gitignore` files
13/// - `.git/info/exclude`
14/// - Global git ignore configuration
15///
16/// # Arguments
17///
18/// * `path` - File or directory to collect from
19///
20/// # Returns
21///
22/// Vector of file paths (directories are excluded)
23///
24/// # Example
25///
26/// ```rust,ignore
27/// let files = collect_files(Path::new("src"));
28/// for file in files {
29///     println!("Found: {:?}", file);
30/// }
31/// ```
32pub fn collect_files(path: &Path) -> Vec<PathBuf> {
33    use ignore::WalkBuilder;
34
35    let mut files = Vec::new();
36    if path.is_file() {
37        files.push(path.to_path_buf());
38    } else if path.is_dir() {
39        let walker = WalkBuilder::new(path)
40            .standard_filters(true) // enables .gitignore, .git/info/exclude, global config
41            .build();
42
43        for result in walker.flatten() {
44            if result.file_type().is_some_and(|ft| ft.is_file()) {
45                files.push(result.path().to_path_buf());
46            }
47        }
48    }
49    files
50}
51
/// Convert byte offset to line and column numbers
///
/// Useful for converting byte-based spans from parsers to human-readable
/// line/column positions.
///
/// # Arguments
///
/// * `src` - Source text
/// * `byte_idx` - Byte offset into source
///
/// # Returns
///
/// Tuple of (line, column), both 1-indexed. Columns count characters
/// (Unicode scalar values), not bytes, and only `\n` starts a new line.
///
/// * An offset that falls inside a multi-byte character maps to the
///   position of that character.
/// * An offset at or beyond the end of `src` maps to the position just
///   past the last character.
///
/// # Example
///
/// ```rust
/// use mcp_host::utils::byte_to_line_col;
///
/// let src = "hello\nworld";
/// let (line, col) = byte_to_line_col(src, 6);
/// assert_eq!((line, col), (2, 1)); // First char of "world"
/// ```
pub fn byte_to_line_col(src: &str, byte_idx: usize) -> (usize, usize) {
    let mut line = 1;
    let mut col = 1;
    for (start, ch) in src.char_indices() {
        // Characters are visited in increasing byte order, so the first one
        // whose byte range ends after `byte_idx` is the one containing it.
        // Comparing against the end of the range (instead of `start ==
        // byte_idx`) keeps offsets that land mid-character from running off
        // to the end of the text.
        if byte_idx < start + ch.len_utf8() {
            return (line, col);
        }
        if ch == '\n' {
            line += 1;
            col = 1;
        } else {
            col += 1;
        }
    }
    // Offset is at or past the end of the text.
    (line, col)
}
91
/// Check if a path is within a base directory (security boundary)
///
/// Prevents path traversal attacks by ensuring paths don't escape the
/// base directory using `..`, absolute paths, or symlinks.
///
/// # Arguments
///
/// * `path` - Path to validate; relative paths are interpreted relative to `base_dir`
/// * `base_dir` - Base directory (security boundary)
///
/// # Returns
///
/// `true` if path is within base_dir, `false` otherwise
///
/// # Security
///
/// This function prevents:
/// - Absolute paths outside base_dir
/// - Relative paths with `..` that escape base_dir
/// - Symlink escapes (via canonicalization)
///
/// How the decision is made:
/// 1. If the whole path exists, it is canonicalized and must lie under the
///    canonical `base_dir`.
/// 2. If the path does not (fully) exist, the deepest existing ancestor is
///    canonicalized and must lie under the canonical `base_dir`; the
///    remaining, not-yet-existing components are then applied lexically and
///    must never step above `base_dir`. Anything that exists but cannot be
///    resolved (for example a dangling symlink) is rejected.
/// 3. Relative paths that do not fully exist are additionally rejected if
///    they lexically climb above `base_dir` at any point.
/// 4. If `base_dir` itself cannot be canonicalized, absolute paths are
///    rejected and relative paths only get the lexical check from step 3.
///
/// The check reflects the filesystem at the time of the call; it cannot
/// protect against the filesystem being changed afterwards.
///
/// # Example
///
/// ```rust,ignore
/// let cwd = std::env::current_dir()?;
/// assert!(is_safe_path(Path::new("foo/bar.txt"), &cwd));
/// assert!(!is_safe_path(Path::new("../etc/passwd"), &cwd));
/// ```
pub fn is_safe_path(path: &Path, base_dir: &Path) -> bool {
    use std::path::Component;

    let Ok(base_canon) = base_dir.canonicalize() else {
        // Without a resolvable base nothing can be verified against the
        // filesystem: refuse absolute paths, vet relative ones lexically.
        return !path.is_absolute() && stays_within_lexically(path);
    };

    let resolved = if path.is_absolute() {
        path.to_path_buf()
    } else {
        base_dir.join(path)
    };

    // Fully existing path: canonicalization handles symlinks and `..`.
    if let Ok(resolved_canon) = resolved.canonicalize() {
        return resolved_canon.starts_with(&base_canon);
    }

    // Not fully existing. Keep the conservative lexical rule for relative
    // paths: they may never climb above the base, even temporarily.
    if !path.is_absolute() && !stays_within_lexically(path) {
        return false;
    }

    // Resolve the deepest ancestor that exists, so that symlinks and `..`
    // inside the existing part are followed exactly as the OS would.
    let components: Vec<_> = resolved.components().collect();
    for split in (1..=components.len()).rev() {
        let prefix: PathBuf = components[..split].iter().collect();
        let Ok(mut candidate) = prefix.canonicalize() else {
            // Present on disk but unresolvable (dangling or looping symlink,
            // unreadable entry): its real target cannot be verified.
            if std::fs::symlink_metadata(&prefix).is_ok() {
                return false;
            }
            continue;
        };

        if !candidate.starts_with(&base_canon) {
            return false;
        }

        // The remaining components do not exist yet, so they cannot be
        // symlinks; applying them lexically is accurate.
        for component in &components[split..] {
            match component {
                Component::Normal(name) => candidate.push(name),
                Component::CurDir => {}
                Component::ParentDir => {
                    candidate.pop();
                    if !candidate.starts_with(&base_canon) {
                        return false;
                    }
                }
                // A root or prefix can only appear at the very start of a
                // path, never in the unresolved tail.
                Component::Prefix(_) | Component::RootDir => return false,
            }
        }
        return true;
    }

    // No ancestor could be resolved at all.
    false
}

/// Lexical check that a relative path never climbs above its starting directory
fn stays_within_lexically(path: &Path) -> bool {
    use std::path::Component;

    let mut depth: i32 = 0;
    for component in path.components() {
        match component {
            Component::ParentDir => depth -= 1,
            Component::Normal(_) => depth += 1,
            Component::CurDir | Component::RootDir | Component::Prefix(_) => {}
        }
        if depth < 0 {
            return false; // Escapes base with ..
        }
    }
    true
}
157
158/// Base64 encode bytes
159///
160/// Useful for creating pagination cursors or encoding binary data
161/// for JSON transmission.
162///
163/// # Arguments
164///
165/// * `data` - Bytes to encode
166///
167/// # Returns
168///
169/// Base64-encoded string
170pub fn base64_encode(data: &[u8]) -> String {
171    use base64::{engine::general_purpose::STANDARD, Engine};
172    STANDARD.encode(data)
173}
174
175/// Base64 decode string
176///
177/// Decodes base64-encoded data. Returns empty vector on invalid input.
178///
179/// # Arguments
180///
181/// * `s` - Base64-encoded string
182///
183/// # Returns
184///
185/// Decoded bytes, or empty vector if invalid
186pub fn base64_decode(s: &str) -> Vec<u8> {
187    use base64::{engine::general_purpose::STANDARD, Engine};
188    STANDARD.decode(s).unwrap_or_default()
189}
190
191// ============================================================================
192// Pagination Utilities
193// ============================================================================
194
/// Default number of items per page for list operations
pub const DEFAULT_PAGE_SIZE: usize = 100;
197
/// One page produced by a paginated list operation
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    /// The items that make up this page
    pub items: Vec<T>,
    /// Cursor to request the following page; `None` when no more items exist
    pub next_cursor: Option<String>,
}

impl<T> PaginatedResult<T> {
    /// Build a page from its items and the cursor of the page after it
    pub fn new(items: Vec<T>, next_cursor: Option<String>) -> Self {
        PaginatedResult { items, next_cursor }
    }

    /// Build a page with no items and no follow-up cursor
    pub fn empty() -> Self {
        Self::new(Vec::new(), None)
    }
}
221
222/// Encode a pagination cursor
223///
224/// Creates an opaque base64-encoded cursor from an offset.
225/// The cursor format is implementation-specific and clients
226/// should treat it as opaque.
227pub fn encode_cursor(offset: usize) -> String {
228    base64_encode(offset.to_string().as_bytes())
229}
230
231/// Decode a pagination cursor
232///
233/// Decodes an opaque cursor back to an offset.
234/// Returns None for invalid cursors.
235pub fn decode_cursor(cursor: &str) -> Option<usize> {
236    let bytes = base64_decode(cursor);
237    let s = String::from_utf8(bytes).ok()?;
238    s.parse().ok()
239}
240
241/// Paginate a list of items
242///
243/// Generic pagination helper that works with any list of items.
244/// Uses offset-based pagination with opaque cursors.
245///
246/// # Arguments
247///
248/// * `items` - Full list of items to paginate
249/// * `cursor` - Optional cursor from previous page
250/// * `page_size` - Maximum items per page
251///
252/// # Returns
253///
254/// PaginatedResult with the current page and optional next cursor
255pub fn paginate<T: Clone>(
256    items: &[T],
257    cursor: Option<&str>,
258    page_size: usize,
259) -> PaginatedResult<T> {
260    if items.is_empty() {
261        return PaginatedResult::empty();
262    }
263
264    // Decode cursor to get starting offset
265    let start = cursor.and_then(decode_cursor).unwrap_or(0).min(items.len());
266
267    let end = (start + page_size).min(items.len());
268    let page_items: Vec<T> = items[start..end].to_vec();
269
270    // Generate next cursor if there are more items
271    let next_cursor = if end < items.len() {
272        Some(encode_cursor(end))
273    } else {
274        None
275    };
276
277    PaginatedResult::new(page_items, next_cursor)
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    // ========== is_safe_path tests ==========

    /// Relative paths that stay below the base directory are accepted.
    #[test]
    fn test_is_safe_path_relative() {
        let cwd = std::env::current_dir().unwrap();
        for inside in ["foo", "foo/bar", "./foo"] {
            assert!(is_safe_path(Path::new(inside), &cwd));
        }
    }

    /// Relative paths that climb out of the base directory are rejected.
    #[test]
    fn test_is_safe_path_parent_escape() {
        let cwd = std::env::current_dir().unwrap();
        for escaping in ["../foo", "foo/../../bar", ".."] {
            assert!(!is_safe_path(Path::new(escaping), &cwd));
        }
    }

    /// Absolute paths outside the base directory are rejected.
    #[test]
    fn test_is_safe_path_absolute_outside() {
        let cwd = std::env::current_dir().unwrap();
        for outside in ["/tmp", "/etc/passwd", "/home"] {
            assert!(!is_safe_path(Path::new(outside), &cwd));
        }
    }

    // ========== Base64 encoding tests ==========

    /// Encoding followed by decoding returns the original text.
    #[test]
    fn test_base64_roundtrip() {
        let original = "42";
        let encoded = base64_encode(original.as_bytes());
        let decoded_bytes = base64_decode(&encoded);
        let decoded = String::from_utf8(decoded_bytes).unwrap();
        assert_eq!(original, decoded);
    }

    /// Offsets written as decimal text survive a base64 round trip.
    #[test]
    fn test_base64_offset_encoding() {
        for offset in [0, 10, 100, 12345] {
            let encoded = base64_encode(offset.to_string().as_bytes());
            let text = String::from_utf8(base64_decode(&encoded)).unwrap();
            let decoded: usize = text.parse().unwrap();
            assert_eq!(offset, decoded);
        }
    }

    /// Invalid base64 should return empty vec.
    #[test]
    fn test_base64_invalid_input() {
        assert!(base64_decode("!!!invalid!!!").is_empty());
    }

    // ========== byte_to_line_col tests ==========

    #[test]
    fn test_byte_to_line_col_start() {
        assert_eq!(byte_to_line_col("hello\nworld\n", 0), (1, 1));
    }

    #[test]
    fn test_byte_to_line_col_middle_first_line() {
        // 'l' at position 2
        assert_eq!(byte_to_line_col("hello\nworld\n", 2), (1, 3));
    }

    #[test]
    fn test_byte_to_line_col_second_line() {
        // 'w' at start of line 2
        assert_eq!(byte_to_line_col("hello\nworld\n", 6), (2, 1));
    }

    #[test]
    fn test_byte_to_line_col_end() {
        // newline at end of 'world'
        assert_eq!(byte_to_line_col("hello\nworld\n", 11), (2, 6));
    }

    #[test]
    fn test_byte_to_line_col_beyond_end() {
        // Should return last position
        assert_eq!(byte_to_line_col("hi", 100), (1, 3));
    }

    // ========== pagination tests ==========

    #[test]
    fn test_encode_decode_cursor() {
        let token = encode_cursor(42);
        assert_eq!(decode_cursor(&token), Some(42));
    }

    #[test]
    fn test_decode_invalid_cursor() {
        for bogus in ["invalid", "!!!"] {
            assert_eq!(decode_cursor(bogus), None);
        }
    }

    #[test]
    fn test_paginate_empty() {
        let nothing: Vec<i32> = Vec::new();
        let page = paginate(&nothing, None, 10);
        assert!(page.items.is_empty());
        assert!(page.next_cursor.is_none());
    }

    #[test]
    fn test_paginate_no_cursor_within_limit() {
        let items: Vec<i32> = (1..=5).collect();
        let page = paginate(&items, None, 10);
        assert_eq!(page.items, vec![1, 2, 3, 4, 5]);
        assert!(page.next_cursor.is_none());
    }

    #[test]
    fn test_paginate_no_cursor_exceeds_limit() {
        let items: Vec<i32> = (1..=10).collect();
        let page = paginate(&items, None, 3);
        assert_eq!(page.items, vec![1, 2, 3]);
        assert!(page.next_cursor.is_some());
    }

    /// Following each page's cursor walks the whole list in order.
    #[test]
    fn test_paginate_with_cursor() {
        let items: Vec<i32> = (1..=10).collect();
        let expected_pages: [&[i32]; 4] = [&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10]];

        let mut cursor: Option<String> = None;
        for expected in expected_pages {
            let page = paginate(&items, cursor.as_deref(), 3);
            assert_eq!(page.items, expected);
            cursor = page.next_cursor;
        }

        // Last page carries no cursor
        assert!(cursor.is_none());
    }

    #[test]
    fn test_paginate_invalid_cursor() {
        let items: Vec<i32> = (1..=5).collect();
        // Invalid cursor should start from beginning
        let page = paginate(&items, Some("invalid"), 10);
        assert_eq!(page.items, vec![1, 2, 3, 4, 5]);
    }
}