use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt;
/// A single permission that can be granted to, or requested by, a component.
///
/// Unit variants are plain on/off permissions. The `String`-carrying network and file
/// variants hold a *pattern* when granted and a concrete target when requested; see
/// [`CapabilitySet::has_matching`] for how the two are compared. The `Display` rendering
/// (`audio.speak`, `network.fetch[...]`, ...) is what [`CapabilitySet::has_by_name`] matches on.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Capability {
    /// Rendered as `audio.speak`.
    AudioSpeak,
    /// Rendered as `audio.play`.
    AudioPlay,
    /// Rendered as `audio.record`.
    AudioRecord,
    /// Rendered as `display.chart`.
    DisplayChart,
    /// Rendered as `display.text`.
    DisplayText,
    /// Rendered as `display.image`.
    DisplayImage,
    /// Rendered as `display.video`.
    DisplayVideo,
    /// Network fetch: a URL pattern when granted, a concrete URL when requested.
    /// Rendered as `network.fetch[<string>]`.
    NetworkFetch(String),
    /// WebSocket connection: a URL pattern when granted, a concrete URL when requested.
    /// Rendered as `network.websocket[<string>]`.
    NetworkWebSocket(String),
    /// File read: a path pattern when granted, a concrete path when requested.
    /// Rendered as `file.read[<string>]`.
    FileRead(String),
    /// File write: a path pattern when granted, a concrete path when requested.
    /// Rendered as `file.write[<string>]`.
    FileWrite(String),
    /// Rendered as `system.time`.
    SystemTime,
    /// Rendered as `system.random`.
    SystemRandom,
    /// Rendered as `system.env[<string>]`; presumably an environment-variable name
    /// (the `Display` impl calls it `var`) — compared by exact equality.
    SystemEnvironment(String),
    /// Rendered as `compute.unlimited`.
    ComputeUnlimited,
    /// Rendered as `compute.limited[<n>]`; the `Display` impl calls the value `ops`
    /// (units not established here — confirm with the enforcing code).
    ComputeLimited(u64),
    /// Rendered as `memory.unlimited`.
    MemoryUnlimited,
    /// Rendered as `memory.limited[<n>]`; the `Display` impl calls the value `bytes`.
    MemoryLimited(usize),
}
impl fmt::Display for Capability {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Capability::AudioSpeak => write!(f, "audio.speak"),
Capability::AudioPlay => write!(f, "audio.play"),
Capability::AudioRecord => write!(f, "audio.record"),
Capability::DisplayChart => write!(f, "display.chart"),
Capability::DisplayText => write!(f, "display.text"),
Capability::DisplayImage => write!(f, "display.image"),
Capability::DisplayVideo => write!(f, "display.video"),
Capability::NetworkFetch(pattern) => write!(f, "network.fetch[{pattern}]"),
Capability::NetworkWebSocket(pattern) => write!(f, "network.websocket[{pattern}]"),
Capability::FileRead(pattern) => write!(f, "file.read[{pattern}]"),
Capability::FileWrite(pattern) => write!(f, "file.write[{pattern}]"),
Capability::SystemTime => write!(f, "system.time"),
Capability::SystemRandom => write!(f, "system.random"),
Capability::SystemEnvironment(var) => write!(f, "system.env[{var}]"),
Capability::ComputeUnlimited => write!(f, "compute.unlimited"),
Capability::ComputeLimited(ops) => write!(f, "compute.limited[{ops}]"),
Capability::MemoryUnlimited => write!(f, "memory.unlimited"),
Capability::MemoryLimited(bytes) => write!(f, "memory.limited[{bytes}]"),
}
}
}
/// An unordered set of granted [`Capability`] values.
///
/// Three kinds of membership check are offered:
/// - [`has`](Self::has): exact equality; pattern strings are not interpreted.
/// - [`has_by_name`](Self::has_by_name): compares against the `Display` rendering.
/// - [`has_matching`](Self::has_matching): interprets granted network/file strings as patterns.
#[derive(Debug, Clone, Default)]
pub struct CapabilitySet {
    capabilities: HashSet<Capability>,
}

impl CapabilitySet {
    /// Creates an empty set.
    pub fn new() -> Self {
        CapabilitySet {
            capabilities: HashSet::new(),
        }
    }

    /// Builds a set from any collection of capabilities; duplicates collapse into one.
    pub fn from_capabilities<I>(caps: I) -> Self
    where
        I: IntoIterator<Item = Capability>,
    {
        CapabilitySet {
            capabilities: caps.into_iter().collect(),
        }
    }

    /// Grants `cap`. Granting a capability that is already held is a no-op.
    pub fn grant(&mut self, cap: Capability) {
        self.capabilities.insert(cap);
    }

    /// Revokes exactly `cap` (by equality) and returns whether it was held.
    pub fn revoke(&mut self, cap: &Capability) -> bool {
        self.capabilities.remove(cap)
    }

    /// Returns `true` if exactly `cap` was granted. Pattern strings are compared verbatim;
    /// use [`has_matching`](Self::has_matching) to test coverage by a pattern.
    pub fn has(&self, cap: &Capability) -> bool {
        self.capabilities.contains(cap)
    }

    /// Returns `true` if some granted capability renders (via `Display`) either exactly as
    /// `name`, or as `name` followed by a bracketed argument.
    ///
    /// For example, `"network.fetch"` matches a granted `network.fetch[https://...]`, but
    /// `"network"` does not.
    pub fn has_by_name(&self, name: &str) -> bool {
        self.capabilities.iter().any(|cap| {
            // Render once per capability; the old form rendered twice and rebuilt
            // `format!("{name}[")` on every iteration.
            let rendered = cap.to_string();
            match rendered.strip_prefix(name) {
                Some("") => true,
                Some(rest) => rest.starts_with('['),
                None => false,
            }
        })
    }

    /// Returns `true` if `requested` is covered by some granted capability.
    ///
    /// - For `NetworkFetch` / `NetworkWebSocket`, granted strings are URL patterns
    ///   (see `url_matches_pattern`).
    /// - For `FileRead` / `FileWrite`, granted strings are path patterns
    ///   (see `path_matches_pattern`).
    /// - A grant only covers requests of the same variant.
    /// - Every other variant needs an exact match. For instance, a requested
    ///   `ComputeLimited(100)` is *not* satisfied by a granted `ComputeLimited(1000)`
    ///   or `ComputeUnlimited`.
    pub fn has_matching(&self, requested: &Capability) -> bool {
        let granted = &self.capabilities;
        match requested {
            Capability::NetworkFetch(url) => granted.iter().any(|cap| {
                matches!(cap, Capability::NetworkFetch(pattern) if url_matches_pattern(url, pattern))
            }),
            Capability::NetworkWebSocket(url) => granted.iter().any(|cap| {
                matches!(cap, Capability::NetworkWebSocket(pattern) if url_matches_pattern(url, pattern))
            }),
            Capability::FileRead(path) => granted.iter().any(|cap| {
                matches!(cap, Capability::FileRead(pattern) if path_matches_pattern(path, pattern))
            }),
            Capability::FileWrite(path) => granted.iter().any(|cap| {
                matches!(cap, Capability::FileWrite(pattern) if path_matches_pattern(path, pattern))
            }),
            _ => self.has(requested),
        }
    }

    /// Capabilities present in both sets, by exact equality. Patterns are compared as
    /// strings, not by coverage.
    pub fn intersection(&self, other: &CapabilitySet) -> CapabilitySet {
        CapabilitySet {
            capabilities: self
                .capabilities
                .intersection(&other.capabilities)
                .cloned()
                .collect(),
        }
    }

    /// Capabilities present in either set.
    pub fn union(&self, other: &CapabilitySet) -> CapabilitySet {
        CapabilitySet {
            capabilities: self
                .capabilities
                .union(&other.capabilities)
                .cloned()
                .collect(),
        }
    }

    /// Returns `true` if every capability in `self` is also in `other`, by exact equality.
    /// A narrow pattern is not considered a subset of a broader one.
    pub fn is_subset(&self, other: &CapabilitySet) -> bool {
        self.capabilities.is_subset(&other.capabilities)
    }

    /// Iterates over the granted capabilities in unspecified (hash) order.
    pub fn iter(&self) -> impl Iterator<Item = &Capability> {
        self.capabilities.iter()
    }

    /// Number of distinct granted capabilities.
    pub fn len(&self) -> usize {
        self.capabilities.len()
    }

    /// Returns `true` if nothing has been granted.
    pub fn is_empty(&self) -> bool {
        self.capabilities.is_empty()
    }
}
/// Returns `true` if `url` is covered by the granted URL `pattern`.
///
/// Pattern forms:
/// - `"*"` matches every URL.
/// - `"<prefix>*"` matches any URL that starts with `<prefix>` (a plain string prefix, so
///   include the trailing `/` after the host to pin the origin).
/// - anything else must equal `url` exactly.
///
/// For prefix patterns, a URL whose path contains a `..` segment is rejected. This covers
/// the percent-encoded `%2e` forms and `\` separators. URL parsers normalise such segments
/// away before the request is sent, so `https://host/public/../admin` would otherwise pass
/// a `https://host/public/*` check while actually requesting `/admin`.
fn url_matches_pattern(url: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    match pattern.strip_suffix('*') {
        Some(prefix) => url.starts_with(prefix) && !url_has_dot_dot_segment(url),
        None => url == pattern,
    }
}

/// Returns `true` if the part of `url` before any `?` or `#` contains a `..` segment,
/// either literal or percent-encoded (`%2e`, any case).
///
/// Backslashes count as separators because WHATWG URL parsing treats them as `/` for
/// special schemes such as http(s) and ws(s).
fn url_has_dot_dot_segment(url: &str) -> bool {
    // `split` always yields at least one item, so this never falls back in practice.
    let path_part = url
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(url);
    path_part
        .split(|c: char| c == '/' || c == '\\')
        .any(|segment| {
            ["..", ".%2e", "%2e.", "%2e%2e"]
                .iter()
                .any(|dots| segment.eq_ignore_ascii_case(dots))
        })
}
/// Returns `true` if `path` is covered by the granted path `pattern`.
///
/// Pattern forms:
/// - `"*"` matches every path.
/// - `"<dir>/*"` matches `<dir>` itself and anything beneath it, i.e. `<dir>` followed by
///   a path separator. Sibling entries that merely share the string prefix are excluded:
///   `/data/*` does not cover `/database`.
/// - `"<prefix>*"` (no slash before the star) matches any path starting with `<prefix>`.
/// - anything else must equal `path` exactly.
///
/// For both prefix forms, a path containing a `..` component is rejected so that
/// `/data/../../etc/passwd` cannot escape a `/data/*` grant. Matching is purely lexical:
/// no filesystem access happens, so symlinks are not resolved.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if let Some(dir) = pattern.strip_suffix("/*") {
        let within_dir = path == dir
            || matches!(
                path.strip_prefix(dir),
                Some(rest) if rest.starts_with(std::path::is_separator)
            );
        within_dir && !path_has_parent_dir(path)
    } else if let Some(prefix) = pattern.strip_suffix('*') {
        path.starts_with(prefix) && !path_has_parent_dir(path)
    } else {
        path == pattern
    }
}

/// Returns `true` if `path` contains a `..` (parent-directory) component.
fn path_has_parent_dir(path: &str) -> bool {
    std::path::Path::new(path)
        .components()
        .any(|component| component == std::path::Component::ParentDir)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Grant, query and revoke exact capabilities.
    #[test]
    fn test_capability_set() {
        let mut set = CapabilitySet::new();
        for granted in [Capability::AudioSpeak, Capability::DisplayText] {
            set.grant(granted);
        }

        assert!(set.has(&Capability::AudioSpeak));
        assert!(set.has(&Capability::DisplayText));
        assert!(!set.has(&Capability::AudioPlay));

        let was_held = set.revoke(&Capability::AudioSpeak);
        assert!(was_held);
        assert!(!set.has(&Capability::AudioSpeak));
    }

    /// Granted URL/path patterns cover matching requests and reject foreign ones.
    #[test]
    fn test_pattern_matching() {
        let set = CapabilitySet::from_capabilities(vec![
            Capability::NetworkFetch("https://api.example.com/*".to_string()),
            Capability::FileRead("/home/user/data/*".to_string()),
        ]);
        let can_fetch = |url: &str| set.has_matching(&Capability::NetworkFetch(url.to_string()));
        let can_read = |path: &str| set.has_matching(&Capability::FileRead(path.to_string()));

        assert!(can_fetch("https://api.example.com/users"));
        assert!(!can_fetch("https://evil.com/"));
        assert!(can_read("/home/user/data/file.txt"));
        assert!(!can_read("/etc/passwd"));
    }

    /// Intersection keeps only shared capabilities; union keeps all of them.
    #[test]
    fn test_set_operations() {
        let left =
            CapabilitySet::from_capabilities(vec![Capability::AudioSpeak, Capability::DisplayText]);
        let right =
            CapabilitySet::from_capabilities(vec![Capability::DisplayText, Capability::SystemTime]);

        let shared = left.intersection(&right);
        assert!(shared.has(&Capability::DisplayText));
        assert!(!shared.has(&Capability::AudioSpeak));
        assert!(!shared.has(&Capability::SystemTime));

        let combined = left.union(&right);
        for cap in [Capability::AudioSpeak, Capability::DisplayText, Capability::SystemTime] {
            assert!(combined.has(&cap), "union is missing {cap}");
        }
    }
}