// mermaid_cli/agents/filesystem.rs
use anyhow::{Context, Result};
use base64::{engine::general_purpose, Engine as _};
use std::fs;
use std::path::{Path, PathBuf};
5
6pub fn read_file(path: &str) -> Result<String> {
8 let path = normalize_path_for_read(path)?;
9
10 validate_path_for_read(&path)?;
12
13 fs::read_to_string(&path).with_context(|| format!("Failed to read file: {}", path.display()))
14}
15
16pub async fn read_file_async(path: String) -> Result<String> {
18 tokio::task::spawn_blocking(move || {
19 read_file(&path)
20 })
21 .await
22 .context("Failed to spawn blocking task for file read")?
23}
24
/// Reports whether `path` has one of the known binary (PDF/image) extensions.
///
/// Only the extension is inspected; the file is never opened. The comparison
/// is case-insensitive because the extension is lowercased first. Paths
/// without an extension yield `false`.
pub fn is_binary_file(path: &str) -> bool {
    const BINARY_EXTENSIONS: [&str; 9] = [
        "pdf", "png", "jpg", "jpeg", "gif", "webp", "bmp", "ico", "tiff",
    ];

    match Path::new(path).extension() {
        Some(extension) => {
            let lowered = extension.to_string_lossy().to_lowercase();
            BINARY_EXTENSIONS.contains(&lowered.as_str())
        }
        None => false,
    }
}
38
39pub fn read_binary_file(path: &str) -> Result<String> {
41 let path = normalize_path_for_read(path)?;
42
43 validate_path_for_read(&path)?;
45
46 let bytes = fs::read(&path)
47 .with_context(|| format!("Failed to read binary file: {}", path.display()))?;
48
49 Ok(general_purpose::STANDARD.encode(&bytes))
50}
51
52pub fn write_file(path: &str, content: &str) -> Result<()> {
54 let path = normalize_path(path)?;
55
56 validate_path(&path)?;
58
59 if let Some(parent) = path.parent() {
61 fs::create_dir_all(parent).with_context(|| {
62 format!(
63 "Failed to create parent directories for: {}",
64 path.display()
65 )
66 })?;
67 }
68
69 if path.exists() {
71 create_timestamped_backup(&path)?;
72 }
73
74 let temp_path = format!("{}.tmp.{}", path.display(), std::process::id());
76 let temp_path = std::path::PathBuf::from(&temp_path);
77
78 fs::write(&temp_path, content).with_context(|| {
80 format!("Failed to write to temporary file: {}", temp_path.display())
81 })?;
82
83 fs::rename(&temp_path, &path).with_context(|| {
85 format!(
86 "Failed to finalize write to: {} (temp file: {})",
87 path.display(),
88 temp_path.display()
89 )
90 })?;
91
92 Ok(())
93}
94
95fn create_timestamped_backup(path: &std::path::Path) -> Result<()> {
98 let timestamp = chrono::Local::now().format("%Y-%m-%d-%H-%M-%S");
99 let backup_path = format!("{}.backup.{}", path.display(), timestamp);
100
101 fs::copy(path, &backup_path).with_context(|| {
102 format!(
103 "Failed to create backup of: {} to {}",
104 path.display(),
105 backup_path
106 )
107 })?;
108
109 Ok(())
110}
111
112pub fn delete_file(path: &str) -> Result<()> {
114 let path = normalize_path(path)?;
115
116 validate_path(&path)?;
118
119 if path.exists() {
121 create_timestamped_backup(&path)?;
122 }
123
124 fs::remove_file(&path).with_context(|| format!("Failed to delete file: {}", path.display()))
125}
126
127pub fn create_directory(path: &str) -> Result<()> {
129 let path = normalize_path(path)?;
130
131 validate_path(&path)?;
133
134 fs::create_dir_all(&path)
135 .with_context(|| format!("Failed to create directory: {}", path.display()))
136}
137
138fn normalize_path_for_read(path: &str) -> Result<PathBuf> {
140 let path = Path::new(path);
141
142 if path.is_absolute() {
143 Ok(path.to_path_buf())
145 } else {
146 let current_dir = std::env::current_dir()?;
148 Ok(current_dir.join(path))
149 }
150}
151
152fn normalize_path(path: &str) -> Result<PathBuf> {
154 let path = Path::new(path);
155
156 if path.is_absolute() {
157 let current_dir = std::env::current_dir()?;
159 if !path.starts_with(¤t_dir) {
160 anyhow::bail!("Access denied: path outside of project directory");
161 }
162 Ok(path.to_path_buf())
163 } else {
164 let current_dir = std::env::current_dir()?;
166 Ok(current_dir.join(path))
167 }
168}
169
170fn validate_path_for_read(path: &Path) -> Result<()> {
172 let sensitive_patterns = [
174 ".ssh",
175 ".aws",
176 ".env",
177 "id_rsa",
178 "id_ed25519",
179 ".git/config",
180 ".npmrc",
181 ".pypirc",
182 ];
183
184 let path_str = path.to_string_lossy();
185 for pattern in &sensitive_patterns {
186 if path_str.contains(pattern) {
187 anyhow::bail!(
188 "Security error: attempted to access potentially sensitive file: {}",
189 path.display()
190 );
191 }
192 }
193
194 Ok(())
195}
196
197fn validate_path(path: &Path) -> Result<()> {
199 let current_dir = std::env::current_dir()?;
200
201 let canonical = if path.exists() {
204 path.canonicalize()?
205 } else {
206 let mut ancestors_to_join = Vec::new();
208 let mut current = path;
209
210 while let Some(parent) = current.parent() {
211 if let Some(name) = current.file_name() {
212 ancestors_to_join.push(name.to_os_string());
213 }
214 if parent.as_os_str().is_empty() {
215 break;
217 }
218 if parent.exists() {
219 let mut result = parent.canonicalize()?;
221 for component in ancestors_to_join.iter().rev() {
222 result = result.join(component);
223 }
224 return validate_canonical_path(&result, ¤t_dir);
225 }
226 current = parent;
227 }
228
229 let mut result = current_dir.canonicalize().unwrap_or_else(|_| current_dir.clone());
231 for component in ancestors_to_join.iter().rev() {
232 result = result.join(component);
233 }
234 result
235 };
236
237 validate_canonical_path(&canonical, ¤t_dir)
238}
239
240fn validate_canonical_path(canonical: &Path, current_dir: &Path) -> Result<()> {
242 let current_dir_canonical = current_dir.canonicalize().unwrap_or_else(|_| current_dir.to_path_buf());
244
245 if !canonical.starts_with(¤t_dir_canonical) {
247 anyhow::bail!(
248 "Security error: attempted to access path outside of project directory: {}",
249 canonical.display()
250 );
251 }
252
253 let sensitive_patterns = [
255 ".ssh",
256 ".aws",
257 ".env",
258 "id_rsa",
259 "id_ed25519",
260 ".git/config",
261 ".npmrc",
262 ".pypirc",
263 ];
264
265 let path_str = canonical.to_string_lossy();
266 for pattern in &sensitive_patterns {
267 if path_str.contains(pattern) {
268 anyhow::bail!(
269 "Security error: attempted to access potentially sensitive file: {}",
270 canonical.display()
271 );
272 }
273 }
274
275 Ok(())
276}
277
// Unit tests for the filesystem helpers.
//
// NOTE(review): several tests use paths relative to the working directory
// (`Cargo.toml`, `target/...`); they presumably expect to run from the crate
// root, as `cargo test` does — confirm.
#[cfg(test)]
mod tests {
    use super::*;

    /// Reading an existing project file succeeds and yields content.
    #[test]
    fn test_read_file_valid() {
        let result = read_file("Cargo.toml");
        assert!(
            result.is_ok(),
            "Should successfully read valid file from project"
        );
        let content = result.unwrap();
        assert!(
            content.contains("[package]") || !content.is_empty(),
            "Content should be reasonable"
        );
    }

    /// Reading a missing file fails with the "Failed to read file" context.
    #[test]
    fn test_read_file_not_found() {
        let result = read_file("this_file_definitely_does_not_exist_12345.txt");
        assert!(result.is_err(), "Should fail to read non-existent file");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Failed to read file"),
            "Error message should indicate read failure, got: {}",
            err_msg
        );
    }

    /// Placeholder: only checks that a `Result` value can be constructed; it
    /// does not call `write_file`.
    #[test]
    fn test_write_file_returns_result() {
        let _result: Result<(), _> = Err("placeholder");

        let ok_result: Result<&str> = Ok("success");
        assert!(ok_result.is_ok());
    }

    /// `write_file` creates the file (and its missing parent directory) and
    /// stores the given content.
    ///
    /// The file lives under `target/` and is removed afterwards so the test
    /// leaves nothing behind in the source tree.
    #[test]
    fn test_write_file_can_create_files() {
        let scratch_dir = "target/fs_write_file_test";
        let file_path = "target/fs_write_file_test/nested/file.txt";

        let result = write_file(file_path, "content");
        assert!(
            result.is_ok(),
            "Should write a new file under target/: {:?}",
            result.err()
        );

        let written = fs::read_to_string(file_path).ok();
        fs::remove_dir_all(scratch_dir).ok();

        assert_eq!(
            written.as_deref(),
            Some("content"),
            "Written file should contain the given content"
        );
    }

    /// Sanity check on sample nested paths only; does not call `write_file`.
    #[test]
    fn test_write_file_creates_parent_dirs_logic() {
        let nested_paths = vec![
            "src/agents/test.rs",
            "tests/data/file.txt",
            "docs/api/guide.md",
        ];

        for path in nested_paths {
            assert!(path.contains('/'), "Paths should have directory components");
        }
    }

    /// Exercises a local closure only. Note that `create_timestamped_backup`
    /// names backups `<path>.backup.<timestamp>`, not `<path>.backup`.
    #[test]
    fn test_write_file_backup_logic() {
        let backup_format = |path: &str| -> String { format!("{}.backup", path) };

        let original_path = "src/main.rs";
        let backup_path = backup_format(original_path);

        assert_eq!(
            backup_path, "src/main.rs.backup",
            "Backup path should have .backup suffix"
        );
    }

    /// Exercises a local closure only. Note that `delete_file` backs files up
    /// through `create_timestamped_backup`, which uses a `.backup.<timestamp>`
    /// suffix rather than `.deleted`.
    #[test]
    fn test_delete_file_creates_backup_logic() {
        let deleted_backup = |path: &str| -> String { format!("{}.deleted", path) };

        let test_file = "src/test.rs";
        let backup_path = deleted_backup(test_file);

        assert_eq!(
            backup_path, "src/test.rs.deleted",
            "Deleted backup should have .deleted suffix"
        );
    }

    /// Deleting a missing file is an error.
    #[test]
    fn test_delete_file_not_found() {
        let result = delete_file("this_definitely_should_not_exist_xyz123.txt");
        assert!(result.is_err(), "Should fail to delete non-existent file");
    }

    /// A single directory under `target/` can be created.
    #[test]
    fn test_create_directory_simple() {
        let dir_path = "target/test_dir_creation";

        let result = create_directory(dir_path);
        assert!(result.is_ok(), "Should successfully create directory");

        let full_path = Path::new(dir_path);
        assert!(full_path.exists(), "Directory should exist");
        assert!(full_path.is_dir(), "Should be a directory");

        // Best-effort cleanup.
        fs::remove_dir(dir_path).ok();
    }

    /// Nested directories are created in one call.
    #[test]
    fn test_create_nested_directories_all() {
        let nested_path = "target/level1/level2/level3";

        let result = create_directory(nested_path);
        assert!(
            result.is_ok(),
            "Should create nested directories: {}",
            result.unwrap_err()
        );

        let full_path = Path::new(nested_path);
        assert!(full_path.exists(), "Nested directory should exist");
        assert!(full_path.is_dir(), "Should be a directory");

        // Best-effort cleanup.
        fs::remove_dir_all("target/level1").ok();
    }

    /// `.env` is rejected by the read-path validation.
    #[test]
    fn test_path_validation_blocks_dotenv() {
        let result = read_file(".env");
        assert!(result.is_err(), "Should reject .env file access");
        let error = result.unwrap_err().to_string();
        assert!(
            error.contains("sensitive") || error.contains("Security"),
            "Error should mention sensitivity: {}",
            error
        );
    }

    /// `.ssh/id_rsa` is rejected by the read-path validation.
    #[test]
    fn test_path_validation_blocks_ssh_keys() {
        let result = read_file(".ssh/id_rsa");
        assert!(result.is_err(), "Should reject .ssh/id_rsa access");
        let error = result.unwrap_err().to_string();
        assert!(
            error.contains("sensitive") || error.contains("Security"),
            "Error should mention sensitivity: {}",
            error
        );
    }

    /// `.aws/credentials` is rejected by the read-path validation.
    #[test]
    fn test_path_validation_blocks_aws_credentials() {
        let result = read_file(".aws/credentials");
        assert!(result.is_err(), "Should reject .aws/credentials access");
        let error = result.unwrap_err().to_string();
        assert!(
            error.contains("sensitive") || error.contains("Security"),
            "Error should mention sensitivity: {}",
            error
        );
    }
}