1use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5use std::time::Duration;
6
7use crate::error::{Error, Result};
8
9#[derive(Debug, Clone, Default)]
11pub struct PathAllowlist {
12 pub read: HashSet<PathBuf>,
14 pub write: HashSet<PathBuf>,
16 pub deny: HashSet<PathBuf>,
18}
19
20impl PathAllowlist {
21 pub fn none() -> Self {
23 Self::default()
24 }
25
26 pub fn all() -> Self {
28 Self {
29 read: [PathBuf::from("/")].into_iter().collect(),
30 write: [PathBuf::from("/")].into_iter().collect(),
31 deny: HashSet::new(),
32 }
33 }
34
35 pub fn allow_read(mut self, path: impl Into<PathBuf>) -> Self {
37 self.read.insert(path.into());
38 self
39 }
40
41 pub fn allow_write(mut self, path: impl Into<PathBuf>) -> Self {
43 self.write.insert(path.into());
44 self
45 }
46
47 pub fn allow_rw(self, path: impl Into<PathBuf>) -> Self {
49 let path = path.into();
50 self.allow_read(path.clone()).allow_write(path)
51 }
52
53 pub fn deny(mut self, path: impl Into<PathBuf>) -> Self {
55 self.deny.insert(path.into());
56 self
57 }
58
59 pub fn can_read(&self, path: &Path) -> bool {
61 if self.is_denied(path) {
62 return false;
63 }
64 self.read.iter().any(|allowed| path.starts_with(allowed))
65 }
66
67 pub fn can_write(&self, path: &Path) -> bool {
69 if self.is_denied(path) {
70 return false;
71 }
72 self.write.iter().any(|allowed| path.starts_with(allowed))
73 }
74
75 fn is_denied(&self, path: &Path) -> bool {
77 self.deny.iter().any(|denied| path.starts_with(denied))
78 }
79
80 pub fn check_read(&self, path: &Path) -> Result<()> {
82 if self.can_read(path) {
83 Ok(())
84 } else {
85 Err(Error::path_not_allowed(path.display().to_string()))
86 }
87 }
88
89 pub fn check_write(&self, path: &Path) -> Result<()> {
91 if self.can_write(path) {
92 Ok(())
93 } else {
94 Err(Error::path_not_allowed(path.display().to_string()))
95 }
96 }
97}
98
99#[derive(Debug, Clone, Default)]
101pub struct HostAllowlist {
102 pub allowed: HashSet<String>,
104 pub denied: HashSet<String>,
106}
107
108impl HostAllowlist {
109 pub fn none() -> Self {
111 Self::default()
112 }
113
114 pub fn all() -> Self {
116 Self {
117 allowed: ["*".to_string()].into_iter().collect(),
118 denied: HashSet::new(),
119 }
120 }
121
122 pub fn allow(mut self, host: impl Into<String>) -> Self {
124 self.allowed.insert(host.into());
125 self
126 }
127
128 pub fn deny(mut self, host: impl Into<String>) -> Self {
130 self.denied.insert(host.into());
131 self
132 }
133
134 pub fn can_access(&self, host: &str) -> bool {
136 let host = host.to_lowercase();
137
138 for denied in &self.denied {
140 if Self::host_matches(&host, denied) {
141 return false;
142 }
143 }
144
145 for allowed in &self.allowed {
147 if Self::host_matches(&host, allowed) {
148 return true;
149 }
150 }
151
152 false
153 }
154
155 fn host_matches(host: &str, pattern: &str) -> bool {
156 let pattern = pattern.to_lowercase();
157
158 if pattern == "*" {
159 return true;
160 }
161
162 if pattern.starts_with("*.") {
163 let suffix = &pattern[1..];
164 host.ends_with(suffix) || host == &pattern[2..]
165 } else {
166 host == pattern
167 }
168 }
169
170 pub fn check(&self, host: &str) -> Result<()> {
172 if self.can_access(host) {
173 Ok(())
174 } else {
175 Err(Error::host_not_allowed(host))
176 }
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct SafetyConfig {
183 pub paths: PathAllowlist,
185 pub hosts: HostAllowlist,
187 pub env_vars: Option<HashSet<String>>,
189 pub allow_process: bool,
191 pub allowed_commands: Option<HashSet<String>>,
193 pub default_timeout: Duration,
195 pub max_timeout: Duration,
197}
198
199impl Default for SafetyConfig {
200 fn default() -> Self {
201 Self {
202 paths: PathAllowlist::none(),
203 hosts: HostAllowlist::none(),
204 env_vars: Some(HashSet::new()),
205 allow_process: false,
206 allowed_commands: None,
207 default_timeout: Duration::from_secs(30),
208 max_timeout: Duration::from_secs(300),
209 }
210 }
211}
212
213impl SafetyConfig {
214 pub fn new() -> Self {
216 Self::default()
217 }
218
219 pub fn permissive() -> Self {
221 Self {
222 paths: PathAllowlist::all(),
223 hosts: HostAllowlist::all(),
224 env_vars: None,
225 allow_process: true,
226 allowed_commands: None,
227 default_timeout: Duration::from_secs(60),
228 max_timeout: Duration::from_secs(3600),
229 }
230 }
231
232 pub fn strict() -> Self {
234 Self {
235 paths: PathAllowlist::none(),
236 hosts: HostAllowlist::none(),
237 env_vars: Some(HashSet::new()),
238 allow_process: false,
239 allowed_commands: Some(HashSet::new()),
240 default_timeout: Duration::from_secs(10),
241 max_timeout: Duration::from_secs(30),
242 }
243 }
244
245 pub fn with_paths(mut self, paths: PathAllowlist) -> Self {
247 self.paths = paths;
248 self
249 }
250
251 pub fn with_hosts(mut self, hosts: HostAllowlist) -> Self {
253 self.hosts = hosts;
254 self
255 }
256
257 pub fn with_env_vars<I, S>(mut self, vars: I) -> Self
259 where
260 I: IntoIterator<Item = S>,
261 S: Into<String>,
262 {
263 self.env_vars = Some(vars.into_iter().map(Into::into).collect());
264 self
265 }
266
267 pub fn allow_all_env(mut self) -> Self {
269 self.env_vars = None;
270 self
271 }
272
273 pub fn with_allow_process(mut self, allow: bool) -> Self {
275 self.allow_process = allow;
276 self
277 }
278
279 pub fn with_allowed_commands<I, S>(mut self, commands: I) -> Self
281 where
282 I: IntoIterator<Item = S>,
283 S: Into<String>,
284 {
285 self.allowed_commands = Some(commands.into_iter().map(Into::into).collect());
286 self
287 }
288
289 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
291 self.default_timeout = timeout;
292 self
293 }
294
295 pub fn with_max_timeout(mut self, timeout: Duration) -> Self {
297 self.max_timeout = timeout;
298 self
299 }
300
301 pub fn can_access_env(&self, name: &str) -> bool {
303 match &self.env_vars {
304 None => true,
305 Some(allowed) => allowed.contains(name),
306 }
307 }
308
309 pub fn check_env(&self, name: &str) -> Result<()> {
311 if self.can_access_env(name) {
312 Ok(())
313 } else {
314 Err(Error::not_permitted(format!(
315 "environment variable access denied: {}",
316 name
317 )))
318 }
319 }
320
321 pub fn can_execute(&self, command: &str) -> bool {
323 if !self.allow_process {
324 return false;
325 }
326
327 match &self.allowed_commands {
328 None => true,
329 Some(allowed) => allowed.contains(command),
330 }
331 }
332
333 pub fn check_execute(&self, command: &str) -> Result<()> {
335 if !self.allow_process {
336 return Err(Error::not_permitted("process execution not allowed"));
337 }
338
339 if let Some(ref allowed) = self.allowed_commands {
340 if !allowed.contains(command) {
341 return Err(Error::not_permitted(format!(
342 "command not allowed: {}",
343 command
344 )));
345 }
346 }
347
348 Ok(())
349 }
350
351 pub fn clamp_timeout(&self, timeout: Duration) -> Duration {
353 timeout.min(self.max_timeout)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_allowlist() {
        let paths = PathAllowlist::none()
            .allow_read("/tmp")
            .allow_rw("/home/user/data")
            .deny("/home/user/data/secret");

        // Read-only grant.
        assert!(paths.can_read(Path::new("/tmp/file.txt")));
        assert!(!paths.can_write(Path::new("/tmp/file.txt")));
        assert!(paths.check_read(Path::new("/tmp/file.txt")).is_ok());
        assert!(paths.check_write(Path::new("/tmp/file.txt")).is_err());

        // Read-write grant.
        assert!(paths.can_read(Path::new("/home/user/data/file.txt")));
        assert!(paths.can_write(Path::new("/home/user/data/file.txt")));

        // Deny overrides the enclosing grant.
        assert!(!paths.can_read(Path::new("/home/user/data/secret/key")));
        assert!(!paths.can_write(Path::new("/home/user/data/secret/key")));

        // Prefixes match whole path components, not raw string prefixes.
        assert!(!paths.can_read(Path::new("/tmpfoo/file.txt")));

        // Nothing outside the grants.
        assert!(!paths.can_read(Path::new("/etc/passwd")));
    }

    #[test]
    fn test_path_allowlist_presets() {
        assert!(!PathAllowlist::none().can_read(Path::new("/")));
        assert!(PathAllowlist::all().can_write(Path::new("/var/log/app.log")));
    }

    #[test]
    fn test_host_allowlist() {
        let hosts = HostAllowlist::none()
            .allow("api.example.com")
            .allow("*.trusted.org")
            .deny("evil.trusted.org");

        assert!(hosts.can_access("api.example.com"));
        assert!(hosts.can_access("sub.trusted.org"));
        assert!(hosts.can_access("trusted.org"));
        assert!(!hosts.can_access("evil.trusted.org"));
        assert!(!hosts.can_access("other.com"));

        // Matching is case-insensitive.
        assert!(hosts.can_access("API.Example.COM"));
        // A wildcard needs a dot boundary; an exact pattern covers no subdomains.
        assert!(!hosts.can_access("nottrusted.org"));
        assert!(!hosts.can_access("sub.api.example.com"));

        assert!(hosts.check("api.example.com").is_ok());
        assert!(hosts.check("other.com").is_err());
    }

    #[test]
    fn test_host_allowlist_presets() {
        assert!(HostAllowlist::all().can_access("anything.example"));
        assert!(!HostAllowlist::none().can_access("anything.example"));
    }

    #[test]
    fn test_safety_config() {
        let config = SafetyConfig::new()
            .with_env_vars(["PATH", "HOME"])
            .with_allow_process(true)
            .with_allowed_commands(["ls", "cat"]);

        assert!(config.can_access_env("PATH"));
        assert!(!config.can_access_env("SECRET"));
        assert!(config.check_env("HOME").is_ok());
        assert!(config.check_env("SECRET").is_err());

        assert!(config.can_execute("ls"));
        assert!(!config.can_execute("rm"));
        assert!(config.check_execute("ls").is_ok());
        assert!(config.check_execute("rm").is_err());
    }

    #[test]
    fn test_safety_config_presets() {
        // Process execution is off by default, whatever the command list says.
        assert!(!SafetyConfig::new().can_execute("ls"));
        assert!(SafetyConfig::new().check_execute("ls").is_err());

        // Strict keeps an empty command list even once processes are enabled.
        assert!(!SafetyConfig::strict()
            .with_allow_process(true)
            .can_execute("ls"));

        let open = SafetyConfig::permissive();
        assert!(open.can_execute("anything"));
        assert!(open.can_access_env("ANY_VAR"));
    }

    #[test]
    fn test_timeout_clamping() {
        let config = SafetyConfig::new().with_max_timeout(Duration::from_secs(60));

        assert_eq!(
            config.clamp_timeout(Duration::from_secs(30)),
            Duration::from_secs(30)
        );
        assert_eq!(
            config.clamp_timeout(Duration::from_secs(60)),
            Duration::from_secs(60)
        );
        assert_eq!(
            config.clamp_timeout(Duration::from_secs(120)),
            Duration::from_secs(60)
        );
    }
}