Skip to main content

rucora_tools/file/
config.rs

1//! 文件工具共享配置
2//!
3//! 用于提取 FileReadTool、FileWriteTool 和 FileEditTool 的公共逻辑
4
5use rucora_core::error::ToolError;
6use std::path::{Path, PathBuf};
7
8/// 允许的文件扩展名白名单
9const ALLOWED_EXTENSIONS: &[&str] = &[
10    "txt", "md", "rst", "rs", "py", "js", "ts", "jsx", "tsx", "json", "yaml", "yml", "toml", "cfg",
11    "ini", "sh", "bash", "zsh", "html", "css", "scss", "less", "xml", "csv",
12];
13
14/// 禁止访问的路径前缀
15const FORBIDDEN_PATH_PREFIXES: &[&str] = &[
16    "/etc/",
17    "/proc/",
18    "/sys/",
19    "/dev/",
20    "/boot/",
21    "/bin/",
22    "/sbin/",
23    "/usr/bin/",
24    "/usr/sbin/",
25    "C:\\Windows\\",
26    "C:\\Program Files\\",
27    "C:\\Program Files (x86)\\",
28];
29
30/// 文件工具的共享配置
31#[derive(Clone)]
32pub struct FileToolConfig {
33    /// 允许的工作目录(可选,限制文件访问范围)
34    pub allowed_dirs: Option<Vec<PathBuf>>,
35    /// 最大文件大小(字节)
36    pub max_file_size: u64,
37}
38
39impl FileToolConfig {
40    /// 创建默认配置
41    pub fn new() -> Self {
42        Self {
43            allowed_dirs: None,
44            max_file_size: 1024 * 1024, // 1MB
45        }
46    }
47
48    /// 设置允许的工作目录
49    pub fn with_allowed_dirs(mut self, dirs: Vec<PathBuf>) -> Self {
50        self.allowed_dirs = Some(
51            dirs.into_iter()
52                .map(|dir| dir.canonicalize().unwrap_or(dir))
53                .collect(),
54        );
55        self
56    }
57
58    /// 设置最大文件大小
59    pub fn with_max_file_size(mut self, size: u64) -> Self {
60        self.max_file_size = size;
61        self
62    }
63
64    /// 验证路径是否安全(用于读取)
65    pub fn validate_path_for_read(&self, path: &str) -> Result<PathBuf, ToolError> {
66        self.validate_path(path, false)
67    }
68
69    /// 验证路径是否安全(用于写入)
70    pub fn validate_path_for_write(&self, path: &str) -> Result<PathBuf, ToolError> {
71        self.validate_path(path, true)
72    }
73
74    /// 验证路径的共享逻辑
75    fn validate_path(&self, path: &str, is_write: bool) -> Result<PathBuf, ToolError> {
76        let path = Path::new(path);
77
78        // 检查是否为绝对路径且包含禁止前缀
79        if let Some(path_str) = path.to_str() {
80            let path_lower = path_str.to_lowercase();
81            for prefix in FORBIDDEN_PATH_PREFIXES {
82                if path_lower.starts_with(&prefix.to_lowercase()) {
83                    return Err(ToolError::Message(format!(
84                        "禁止访问系统敏感路径:{path_str}"
85                    )));
86                }
87            }
88        }
89
90        // 检查扩展名
91        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
92            let ext_lower = ext.to_lowercase();
93            if !ALLOWED_EXTENSIONS.contains(&ext_lower.as_str()) {
94                return Err(ToolError::Message(format!(
95                    "不支持的文件类型:{ext}(允许的类型:{ALLOWED_EXTENSIONS:?})"
96                )));
97            }
98        } else {
99            return Err(ToolError::Message("文件必须包含扩展名".to_string()));
100        }
101
102        // 如果配置了允许的目录,检查路径是否在其中
103        if let Some(allowed_dirs) = &self.allowed_dirs {
104            if is_write {
105                // 写入时检查父目录
106                let parent = path.parent().unwrap_or(path);
107                let canonical_path = parent
108                    .canonicalize()
109                    .unwrap_or_else(|_| parent.to_path_buf());
110                let is_allowed = allowed_dirs
111                    .iter()
112                    .any(|dir| canonical_path.starts_with(dir));
113                if !is_allowed {
114                    return Err(ToolError::Message(format!(
115                        "文件路径不在允许的工作目录内(允许的目录:{allowed_dirs:?})"
116                    )));
117                }
118            } else {
119                // 读取时检查文件本身
120                let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
121                let is_allowed = allowed_dirs
122                    .iter()
123                    .any(|dir| canonical_path.starts_with(dir));
124                if !is_allowed {
125                    return Err(ToolError::Message(format!(
126                        "文件路径不在允许的工作目录内(允许的目录:{allowed_dirs:?})"
127                    )));
128                }
129            }
130        }
131
132        Ok(path.to_path_buf())
133    }
134
135    /// 检查文件大小
136    pub fn check_file_size(&self, size: u64, operation: &str) -> Result<(), ToolError> {
137        if size > self.max_file_size {
138            return Err(ToolError::Message(format!(
139                "{}过大({} 字节),超过限制({} 字节)",
140                operation, size, self.max_file_size
141            )));
142        }
143        Ok(())
144    }
145}
146
147impl Default for FileToolConfig {
148    fn default() -> Self {
149        Self::new()
150    }
151}