1use std::collections::HashSet;
4
/// A single permission that can be granted or withheld.
///
/// Each variant has a stable textual name, returned by [`Capability::name`]
/// and parsed back by [`Capability::from_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Capability {
    /// Textual name: `fs:read`.
    FsRead,
    /// Textual name: `fs:write`. Classified as dangerous.
    FsWrite,
    /// Textual name: `fs:execute`.
    FsExecute,
    /// Textual name: `net:request`.
    NetRequest,
    /// Textual name: `net:listen`. Classified as dangerous.
    NetListen,
    /// Textual name: `process:exec`. Classified as dangerous.
    ProcessExec,
    /// Textual name: `env:read`.
    EnvRead,
    /// Textual name: `env:write`. Classified as dangerous.
    EnvWrite,
    /// Textual name: `time:read`.
    TimeRead,
    /// Textual name: `random`.
    Random,
    /// Textual name: `stdin:read`.
    StdinRead,
    /// Textual name: `stdout:write`.
    StdoutWrite,
    /// Textual name: `stderr:write`.
    StderrWrite,
    /// Textual name: `metrics`.
    Metrics,
    /// Textual name: `logging`.
    Logging,
    /// Textual name: `async:spawn`.
    AsyncSpawn,
    /// Textual name: `crypto`.
    Crypto,
    /// Textual name: `serialize`.
    Serialize,
}

impl Capability {
    /// Returns the stable textual name of this capability, e.g. `"fs:read"`.
    ///
    /// This is the single source of truth for capability names;
    /// [`Capability::from_name`] is derived from it.
    pub fn name(&self) -> &'static str {
        match self {
            Capability::FsRead => "fs:read",
            Capability::FsWrite => "fs:write",
            Capability::FsExecute => "fs:execute",
            Capability::NetRequest => "net:request",
            Capability::NetListen => "net:listen",
            Capability::ProcessExec => "process:exec",
            Capability::EnvRead => "env:read",
            Capability::EnvWrite => "env:write",
            Capability::TimeRead => "time:read",
            Capability::Random => "random",
            Capability::StdinRead => "stdin:read",
            Capability::StdoutWrite => "stdout:write",
            Capability::StderrWrite => "stderr:write",
            Capability::Metrics => "metrics",
            Capability::Logging => "logging",
            Capability::AsyncSpawn => "async:spawn",
            Capability::Crypto => "crypto",
            Capability::Serialize => "serialize",
        }
    }

    /// Parses a textual name as produced by [`Capability::name`].
    ///
    /// Matching is exact: it is case-sensitive and performs no trimming.
    /// Returns `None` when `name` does not belong to any capability in
    /// [`Capability::all`].
    pub fn from_name(name: &str) -> Option<Self> {
        // Look the name up through `name()` instead of keeping a second,
        // hand-written string table that could drift out of sync with it.
        Self::all().iter().copied().find(|cap| cap.name() == name)
    }

    /// Returns `true` for the capabilities classified as dangerous:
    /// `FsWrite`, `ProcessExec`, `NetListen` and `EnvWrite`.
    pub fn is_dangerous(&self) -> bool {
        // NOTE(review): `FsExecute` is not in this list although `ProcessExec`
        // is — confirm that this is intentional.
        matches!(
            self,
            Capability::FsWrite
                | Capability::ProcessExec
                | Capability::NetListen
                | Capability::EnvWrite
        )
    }

    /// Returns every capability, in declaration order.
    ///
    /// The compiler does not check this list for completeness, so it must be
    /// extended by hand whenever a variant is added; `from_name` relies on it.
    pub fn all() -> &'static [Capability] {
        &[
            Capability::FsRead,
            Capability::FsWrite,
            Capability::FsExecute,
            Capability::NetRequest,
            Capability::NetListen,
            Capability::ProcessExec,
            Capability::EnvRead,
            Capability::EnvWrite,
            Capability::TimeRead,
            Capability::Random,
            Capability::StdinRead,
            Capability::StdoutWrite,
            Capability::StderrWrite,
            Capability::Metrics,
            Capability::Logging,
            Capability::AsyncSpawn,
            Capability::Crypto,
            Capability::Serialize,
        ]
    }
}
132
133#[derive(Debug, Clone, Default)]
135pub struct Capabilities {
136 granted: HashSet<Capability>,
137}
138
139impl Capabilities {
140 pub fn none() -> Self {
142 Self::default()
143 }
144
145 pub fn all() -> Self {
147 Self {
148 granted: Capability::all().iter().copied().collect(),
149 }
150 }
151
152 pub fn safe_defaults() -> Self {
156 Self::none()
157 .with(Capability::TimeRead)
158 .with(Capability::Random)
159 .with(Capability::StdoutWrite)
160 .with(Capability::StderrWrite)
161 .with(Capability::Logging)
162 .with(Capability::Serialize)
163 }
164
165 pub fn with(mut self, cap: Capability) -> Self {
167 self.granted.insert(cap);
168 self
169 }
170
171 pub fn with_all<I: IntoIterator<Item = Capability>>(mut self, caps: I) -> Self {
173 self.granted.extend(caps);
174 self
175 }
176
177 pub fn without(mut self, cap: Capability) -> Self {
179 self.granted.remove(&cap);
180 self
181 }
182
183 pub fn grant(&mut self, cap: Capability) {
185 self.granted.insert(cap);
186 }
187
188 pub fn revoke(&mut self, cap: Capability) {
190 self.granted.remove(&cap);
191 }
192
193 pub fn has(&self, cap: Capability) -> bool {
195 self.granted.contains(&cap)
196 }
197
198 pub fn require(&self, cap: Capability) -> crate::Result<()> {
200 if self.has(cap) {
201 Ok(())
202 } else {
203 Err(crate::Error::capability_denied(cap.name()))
204 }
205 }
206
207 pub fn granted(&self) -> impl Iterator<Item = &Capability> {
209 self.granted.iter()
210 }
211
212 pub fn len(&self) -> usize {
214 self.granted.len()
215 }
216
217 pub fn is_empty(&self) -> bool {
219 self.granted.is_empty()
220 }
221
222 pub fn has_dangerous(&self) -> bool {
224 self.granted.iter().any(|c| c.is_dangerous())
225 }
226
227 pub fn from_names<'a, I: IntoIterator<Item = &'a str>>(names: I) -> Self {
229 let granted = names
230 .into_iter()
231 .filter_map(Capability::from_name)
232 .collect();
233 Self { granted }
234 }
235
236 pub fn to_names(&self) -> Vec<&'static str> {
238 self.granted.iter().map(|c| c.name()).collect()
239 }
240
241 pub fn merge(&self, other: &Capabilities) -> Capabilities {
243 let granted = self.granted.union(&other.granted).copied().collect();
244 Capabilities { granted }
245 }
246
247 pub fn intersect(&self, other: &Capabilities) -> Capabilities {
249 let granted = self.granted.intersection(&other.granted).copied().collect();
250 Capabilities { granted }
251 }
252}
253
254impl FromIterator<Capability> for Capabilities {
255 fn from_iter<I: IntoIterator<Item = Capability>>(iter: I) -> Self {
256 Self {
257 granted: iter.into_iter().collect(),
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Every capability's textual name must parse back to the same variant.
    #[test]
    fn test_capability_name_roundtrip() {
        for cap in Capability::all() {
            let name = cap.name();
            let parsed = Capability::from_name(name);
            assert_eq!(parsed, Some(*cap), "Failed roundtrip for {:?}", cap);
        }
    }

    /// Name parsing is exact: unknown, differently cased or padded names fail.
    #[test]
    fn test_capability_from_name_rejects_unknown() {
        assert_eq!(Capability::from_name("invalid"), None);
        assert_eq!(Capability::from_name(""), None);
        assert_eq!(Capability::from_name("FS:READ"), None);
        assert_eq!(Capability::from_name(" fs:read"), None);
    }

    #[test]
    fn test_capabilities_none() {
        let caps = Capabilities::none();
        assert!(caps.is_empty());
        assert!(!caps.has(Capability::FsRead));
    }

    /// `Capabilities::all` must hold each entry of `Capability::all` exactly
    /// once, so the lengths agree only if that list has no duplicates.
    #[test]
    fn test_capabilities_all() {
        let caps = Capabilities::all();
        assert_eq!(caps.len(), Capability::all().len());
        assert!(caps.has(Capability::FsRead));
        assert!(caps.has(Capability::NetRequest));
    }

    #[test]
    fn test_capabilities_safe_defaults() {
        let caps = Capabilities::safe_defaults();
        assert!(caps.has(Capability::TimeRead));
        assert!(caps.has(Capability::Logging));
        assert!(!caps.has(Capability::FsWrite));
        assert!(!caps.has(Capability::ProcessExec));
    }

    #[test]
    fn test_capabilities_builder() {
        let caps = Capabilities::none()
            .with(Capability::FsRead)
            .with(Capability::NetRequest)
            .without(Capability::FsRead);

        assert!(!caps.has(Capability::FsRead));
        assert!(caps.has(Capability::NetRequest));
    }

    /// In-place `grant`/`revoke` mirror the builder methods; revoking a
    /// capability that is absent is a no-op.
    #[test]
    fn test_capabilities_grant_revoke() {
        let mut caps = Capabilities::none();
        caps.grant(Capability::EnvRead);
        caps.grant(Capability::EnvRead);
        assert!(caps.has(Capability::EnvRead));
        assert_eq!(caps.len(), 1);

        caps.revoke(Capability::EnvRead);
        caps.revoke(Capability::EnvRead);
        assert!(caps.is_empty());
    }

    /// Bulk construction paths deduplicate their input.
    #[test]
    fn test_capabilities_with_all_and_from_iter() {
        let bulk = Capabilities::none().with_all(vec![
            Capability::Crypto,
            Capability::Metrics,
            Capability::Crypto,
        ]);
        assert_eq!(bulk.len(), 2);
        assert!(bulk.has(Capability::Crypto));
        assert!(bulk.has(Capability::Metrics));

        let collected: Capabilities = vec![Capability::Crypto, Capability::Crypto]
            .into_iter()
            .collect();
        assert_eq!(collected.len(), 1);
        assert_eq!(collected.granted().count(), 1);
    }

    #[test]
    fn test_capabilities_require() {
        let caps = Capabilities::none().with(Capability::FsRead);

        assert!(caps.require(Capability::FsRead).is_ok());
        assert!(caps.require(Capability::FsWrite).is_err());
    }

    /// Unknown names are dropped rather than reported.
    #[test]
    fn test_capabilities_from_names() {
        let caps = Capabilities::from_names(["fs:read", "net:request", "invalid"]);
        assert!(caps.has(Capability::FsRead));
        assert!(caps.has(Capability::NetRequest));
        assert_eq!(caps.len(), 2);
    }

    #[test]
    fn test_capabilities_to_names() {
        let caps = Capabilities::none()
            .with(Capability::NetRequest)
            .with(Capability::FsRead);

        let mut names = caps.to_names();
        // Sort locally so the assertion does not depend on iteration order.
        names.sort_unstable();
        assert_eq!(names, vec!["fs:read", "net:request"]);

        // Names produced by `to_names` rebuild an equivalent set.
        let rebuilt = Capabilities::from_names(names);
        assert_eq!(rebuilt.len(), caps.len());
        assert!(rebuilt.has(Capability::FsRead));
        assert!(rebuilt.has(Capability::NetRequest));
    }

    #[test]
    fn test_dangerous_capabilities() {
        assert!(Capability::FsWrite.is_dangerous());
        assert!(Capability::ProcessExec.is_dangerous());
        assert!(!Capability::FsRead.is_dangerous());
        assert!(!Capability::TimeRead.is_dangerous());

        let caps = Capabilities::none().with(Capability::FsWrite);
        assert!(caps.has_dangerous());

        let safe = Capabilities::safe_defaults();
        assert!(!safe.has_dangerous());
    }

    #[test]
    fn test_capabilities_merge() {
        let a = Capabilities::none().with(Capability::FsRead);
        let b = Capabilities::none().with(Capability::NetRequest);
        let merged = a.merge(&b);

        assert!(merged.has(Capability::FsRead));
        assert!(merged.has(Capability::NetRequest));
    }

    #[test]
    fn test_capabilities_intersect() {
        let a = Capabilities::none()
            .with(Capability::FsRead)
            .with(Capability::NetRequest);
        let b = Capabilities::none()
            .with(Capability::NetRequest)
            .with(Capability::TimeRead);
        let intersected = a.intersect(&b);

        assert!(!intersected.has(Capability::FsRead));
        assert!(intersected.has(Capability::NetRequest));
        assert!(!intersected.has(Capability::TimeRead));
    }
}