use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
/// A single capability (permission or resource limit) that can be held in a
/// [`CapabilitySet`].
///
/// Serialized with serde's internally tagged representation: a `kind` field carries the
/// kebab-case variant name, e.g. `{"kind": "scalar-fn"}` or
/// `{"kind": "network", "allow": ["https://api.example/**"]}`. The container-level
/// `rename_all` renames variants only, so struct-variant fields keep their Rust spelling
/// (`read_only`, `key_ids`, `max_concurrent`).
///
/// The derived `Ord` compares variant declaration position first and fields second, so
/// reordering variants changes the iteration order of [`CapabilitySet`].
///
/// NOTE(review): serde's internally tagged representation does not support newtype
/// variants that wrap a primitive. `MemoryBytes(u64)` .. `MaxResultRows(u64)` therefore
/// compile but fail at runtime when serialized or deserialized through serde. Confirm
/// whether these are meant to round-trip; if so they need struct variants
/// (e.g. `MemoryBytes { bytes: u64 }`), which is a breaking API change.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
#[non_exhaustive]
pub enum Capability {
/// Network access. `allow` holds URL wildcard patterns checked by
/// [`Capability::network_allows`]; an empty (default) list matches no URL.
Network {
#[serde(default)]
allow: Vec<SmolStr>,
},
/// Filesystem access. `read` / `write` hold path wildcard patterns checked by
/// [`Capability::filesystem_read_allows`] / [`Capability::filesystem_write_allows`].
Filesystem {
#[serde(default)]
read: Vec<SmolStr>,
#[serde(default)]
write: Vec<SmolStr>,
},
/// Host query access. `read_only` defaults to `false` and `scopes` to empty when omitted.
HostQuery {
#[serde(default)]
read_only: bool,
#[serde(default)]
scopes: Vec<SmolStr>,
},
/// KMS access; `key_ids` are wildcard patterns checked by [`Capability::kms_allows`].
Kms {
#[serde(default)]
key_ids: Vec<SmolStr>,
},
/// Secret access; `ids` are wildcard patterns checked by [`Capability::secret_allows`].
Secret {
#[serde(default)]
ids: Vec<SmolStr>,
},
/// Locking. `granularity` has no serde default, so it must always be supplied; a bare
/// `"lock"` manifest entry fails to parse.
Lock {
granularity: LockGranularity,
},
/// Configuration access; `keys` defaults to empty when omitted.
Config {
#[serde(default)]
keys: Vec<SmolStr>,
},
// Data-less capabilities: presence in a set is the whole grant.
PluginStorage,
ScalarFn,
AggregateFn,
WindowFn,
Procedure,
ProcedureWrites,
ProcedureSchema,
ProcedureDbms,
LocyAggregate,
LocyPredicate,
Operator,
Index,
Storage,
Algorithm,
Crdt,
Hook,
Trigger,
/// Background jobs. `max_concurrent` has no serde default and must be supplied.
BackgroundJob {
max_concurrent: u32,
},
Type,
Auth,
Authz,
Connector,
Collation,
Cdc,
Catalog,
PluginDeclare,
// Resource limits. NOTE(review): units/semantics are inferred from the variant names
// only (bytes, fuel units, milliseconds, counts, rows) — confirm against the enforcer.
// See the type-level note about serde support for these newtype variants.
MemoryBytes(u64),
FuelPerCall(u64),
WallClockMillisPerCall(u64),
ConcurrentInstances(u32),
TotalMemoryBytes(u64),
MaxResultRows(u64),
}
/// Granularity carried by [`Capability::Lock`].
///
/// Serialized as a kebab-case string: `"nodes"`, `"edges"`, `"both"`, `"global"`.
/// The derived `Ord` follows declaration order, which in turn affects the ordering of
/// `Capability::Lock` entries inside a [`CapabilitySet`].
/// NOTE(review): the precise meaning of each level is not visible here — confirm with the
/// lock implementation before relying on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum LockGranularity {
Nodes,
Edges,
Both,
Global,
}
/// An ordered, de-duplicated collection of [`Capability`] values.
///
/// Duplicates are detected by full equality, so two `Network` entries with different
/// `allow` lists are distinct members. `#[serde(transparent)]` makes the set serialize
/// exactly like the inner `BTreeSet`, i.e. as a sequence of tagged capability objects.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CapabilitySet {
set: BTreeSet<Capability>,
}
impl CapabilitySet {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn from_iter_of(caps: impl IntoIterator<Item = Capability>) -> Self {
Self {
set: caps.into_iter().collect(),
}
}
#[must_use]
pub fn from_manifest(caps: impl IntoIterator<Item = ManifestCapability>) -> Self {
Self::from_iter_of(caps.into_iter().map(|m| m.0))
}
pub fn insert(&mut self, cap: Capability) -> bool {
self.set.insert(cap)
}
#[must_use]
pub fn contains(&self, cap: &Capability) -> bool {
self.set.contains(cap)
}
#[must_use]
pub fn contains_variant(&self, target: &Capability) -> bool {
self.set.iter().any(|c| variant_matches(c, target))
}
#[must_use]
pub fn intersect(&self, other: &Self) -> Self {
let mut out = Self::new();
for c in &self.set {
if other.contains_variant(c) {
out.insert(c.clone());
}
}
out
}
pub fn iter(&self) -> impl Iterator<Item = &Capability> {
self.set.iter()
}
#[must_use]
pub fn len(&self) -> usize {
self.set.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.set.is_empty()
}
}
/// Returns `true` when `a` and `b` are the same [`Capability`] variant, regardless of
/// any data the variants carry.
fn variant_matches(a: &Capability, b: &Capability) -> bool {
    use std::mem::discriminant;
    let (lhs, rhs) = (discriminant(a), discriminant(b));
    lhs == rhs
}
impl Capability {
#[must_use]
pub fn network_allows(&self, url: &str) -> bool {
matches!(self, Capability::Network { allow } if allow.iter().any(|p| wildcard_match(p, url)))
}
#[must_use]
pub fn kms_allows(&self, key_id: &str) -> bool {
matches!(self, Capability::Kms { key_ids } if key_ids.iter().any(|p| wildcard_match(p, key_id)))
}
#[must_use]
pub fn secret_allows(&self, id: &str) -> bool {
matches!(self, Capability::Secret { ids } if ids.iter().any(|p| wildcard_match(p, id)))
}
#[must_use]
pub fn filesystem_read_allows(&self, path: &str) -> bool {
matches!(self, Capability::Filesystem { read, .. } if read.iter().any(|p| wildcard_match(p, path)))
}
#[must_use]
pub fn filesystem_write_allows(&self, path: &str) -> bool {
matches!(self, Capability::Filesystem { write, .. } if write.iter().any(|p| wildcard_match(p, path)))
}
}
/// A [`Capability`] as written in a manifest.
///
/// Its custom `Deserialize` impl accepts either a bare kebab-case name (`"scalar-fn"`) or
/// a full tagged object (`{"kind": "network", "allow": [...]}`). Only deserialization is
/// implemented; convert into a [`CapabilitySet`] with [`CapabilitySet::from_manifest`].
#[derive(Clone, Debug)]
pub struct ManifestCapability(pub Capability);
impl<'de> Deserialize<'de> for ManifestCapability {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum Repr {
Bare(String),
Full(Capability),
}
let cap = match Repr::deserialize(deserializer)? {
Repr::Full(c) => c,
Repr::Bare(name) => {
let tagged = serde_json::json!({ "kind": name });
Capability::deserialize(tagged).map_err(serde::de::Error::custom)?
}
};
Ok(ManifestCapability(cap))
}
}
/// Glob-style match of the whole of `text` against `pattern`.
///
/// `*` matches any (possibly empty) run of bytes, including `/`; every other byte must
/// match literally. Consecutive stars act like one star, so `**` is not a distinct
/// "recursive" wildcard.
///
/// Uses the greedy single-backtrack-point algorithm. On a mismatch it retries from the
/// most recent `*`, letting that star absorb one more byte of `text`.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let pat = pattern.as_bytes();
    let txt = text.as_bytes();
    let (mut p, mut t) = (0usize, 0usize);
    // (pattern index just past the latest `*` run, text index where matching resumes
    // after that star)
    let mut retry: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p) {
            Some(&b'*') => {
                while pat.get(p) == Some(&b'*') {
                    p += 1;
                }
                if p == pat.len() {
                    // A trailing star swallows whatever text remains.
                    return true;
                }
                retry = Some((p, t));
            }
            Some(&byte) if byte == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match retry {
                Some((resume, start)) => {
                    let start = start + 1;
                    retry = Some((resume, start));
                    p = resume;
                    t = start;
                }
                None => return false,
            },
        }
    }

    // Text is exhausted: only stars may remain in the pattern.
    pat[p..].iter().all(|&b| b == b'*')
}
/// Determinism classification.
///
/// Serialized as a kebab-case string: `"pure"`, `"session-scoped"`, `"nondeterministic"`.
/// `Default` yields [`Determinism::Nondeterministic`].
/// NOTE(review): what each level permits (e.g. caching or constant folding) is not visible
/// in this file — confirm with the consumer.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Determinism {
Pure,
SessionScoped,
#[default]
Nondeterministic,
}
/// Side-effect classification.
///
/// Serialized as a kebab-case string: `"read-only"`, `"writes"`, `"external-io"`.
/// `Default` yields [`SideEffects::ReadOnly`].
/// NOTE(review): how this interacts with capabilities such as `ProcedureWrites` or
/// `Network` is not shown here — confirm with the enforcing code.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum SideEffects {
#[default]
ReadOnly,
Writes,
ExternalIo,
}
/// Scope classification.
///
/// Serialized as a kebab-case string: `"instance"`, `"session"`.
/// `Default` yields [`Scope::Instance`].
/// NOTE(review): the lifetime each scope implies is not visible in this file — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Scope {
#[default]
Instance,
Session,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn capability_set_default_empty() {
        let s = CapabilitySet::new();
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn capability_set_insert_dedup() {
        let mut s = CapabilitySet::new();
        assert!(s.insert(Capability::ScalarFn));
        assert!(!s.insert(Capability::ScalarFn));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn intersect_keeps_matching_variants() {
        let a = CapabilitySet::from_iter_of([
            Capability::ScalarFn,
            Capability::Storage,
            Capability::Network {
                allow: vec![SmolStr::new("https://api.example/**")],
            },
        ]);
        let b = CapabilitySet::from_iter_of([
            Capability::ScalarFn,
            Capability::Network {
                allow: vec![SmolStr::new("https://api.example/**")],
            },
        ]);
        let inter = a.intersect(&b);
        assert!(inter.contains(&Capability::ScalarFn));
        assert!(!inter.contains_variant(&Capability::Storage));
        assert!(inter.contains_variant(&Capability::Network { allow: vec![] }));
    }

    #[test]
    fn contains_variant_ignores_attenuation() {
        let s = CapabilitySet::from_iter_of([Capability::Network {
            allow: vec![SmolStr::new("https://x.example/*")],
        }]);
        assert!(s.contains_variant(&Capability::Network { allow: vec![] }));
        assert!(!s.contains(&Capability::Network { allow: vec![] }));
    }

    #[test]
    fn determinism_default_is_nondeterministic() {
        assert_eq!(Determinism::default(), Determinism::Nondeterministic);
    }

    #[test]
    fn wildcard_match_basics() {
        assert!(wildcard_match("*", "anything"));
        assert!(wildcard_match("**", "any/thing"));
        assert!(wildcard_match(
            "https://api.example/**",
            "https://api.example/v1/x"
        ));
        assert!(wildcard_match("exact", "exact"));
        assert!(!wildcard_match("exact", "other"));
        assert!(!wildcard_match(
            "https://api.example/**",
            "https://evil.example/x"
        ));
        assert!(wildcard_match("a*c", "abbbc"));
        assert!(!wildcard_match("a*c", "abbb"));
    }

    // Empty inputs, multiple stars, and the requirement that all of the text is consumed.
    #[test]
    fn wildcard_match_edge_cases() {
        assert!(wildcard_match("", ""));
        assert!(!wildcard_match("", "x"));
        assert!(wildcard_match("*", ""));
        assert!(!wildcard_match("a", ""));
        assert!(wildcard_match("a**", "a"));
        assert!(wildcard_match("a*b*c", "aXbYc"));
        assert!(!wildcard_match("a*b*c", "aXcYb"));
        assert!(!wildcard_match("exact", "exactly"));
    }

    #[test]
    fn network_allows_matches_only_network_variant() {
        let net = Capability::Network {
            allow: vec![SmolStr::new("https://api.example/**")],
        };
        assert!(net.network_allows("https://api.example/v1/data"));
        assert!(!net.network_allows("https://evil.example/x"));
        assert!(!Capability::ScalarFn.network_allows("https://api.example/x"));
    }

    #[test]
    fn kms_and_secret_allow_wildcard_and_exact() {
        let kms = Capability::Kms {
            key_ids: vec![SmolStr::new("*")],
        };
        assert!(kms.kms_allows("signing-key-1"));
        let secret = Capability::Secret {
            ids: vec![SmolStr::new("db-password")],
        };
        assert!(secret.secret_allows("db-password"));
        assert!(!secret.secret_allows("other"));
    }

    #[test]
    fn manifest_capability_parses_bare_and_structured() {
        let bare: ManifestCapability = serde_json::from_str("\"network\"").unwrap();
        assert!(matches!(&bare.0, Capability::Network { allow } if allow.is_empty()));
        assert!(!bare.0.network_allows("https://api.example/x"));
        let scalar: ManifestCapability = serde_json::from_str("\"scalar-fn\"").unwrap();
        assert_eq!(scalar.0, Capability::ScalarFn);
        let structured: ManifestCapability =
            serde_json::from_str(r#"{"kind":"network","allow":["https://api.example/**"]}"#)
                .unwrap();
        assert!(structured.0.network_allows("https://api.example/v1/x"));
        assert!(!structured.0.network_allows("https://evil.example/x"));
        let set = CapabilitySet::from_manifest([bare, scalar, structured]);
        assert!(set.contains_variant(&Capability::Network { allow: vec![] }));
        assert!(set.contains(&Capability::ScalarFn));
    }

    // A bare name expands to `{"kind": name}`, so variants with required fields and
    // unknown names must both fail.
    #[test]
    fn manifest_capability_rejects_bare_names_requiring_fields() {
        assert!(serde_json::from_str::<ManifestCapability>("\"lock\"").is_err());
        assert!(serde_json::from_str::<ManifestCapability>("\"no-such-capability\"").is_err());
    }

    // `rename_all` on the enum renames variants (kebab-case `kind`) but not fields.
    #[test]
    fn structured_capabilities_use_kebab_kind_and_snake_fields() {
        let lock: ManifestCapability =
            serde_json::from_str(r#"{"kind":"lock","granularity":"nodes"}"#).unwrap();
        assert_eq!(
            lock.0,
            Capability::Lock {
                granularity: LockGranularity::Nodes
            }
        );
        let hq: ManifestCapability =
            serde_json::from_str(r#"{"kind":"host-query","read_only":true,"scopes":["a"]}"#)
                .unwrap();
        assert_eq!(
            hq.0,
            Capability::HostQuery {
                read_only: true,
                scopes: vec![SmolStr::new("a")],
            }
        );
    }

    #[test]
    fn capability_set_serde_round_trip() {
        let s = CapabilitySet::from_iter_of([
            Capability::ScalarFn,
            Capability::Kms {
                key_ids: vec![SmolStr::new("k1")],
            },
        ]);
        let json = serde_json::to_string(&s).unwrap();
        let back: CapabilitySet = serde_json::from_str(&json).unwrap();
        assert_eq!(back, s);
    }

    #[test]
    fn filesystem_allows_read_and_write_separately() {
        let fs = Capability::Filesystem {
            read: vec![SmolStr::new("/data/**")],
            write: vec![SmolStr::new("/tmp/out/**")],
        };
        assert!(fs.filesystem_read_allows("/data/x/y.txt"));
        assert!(!fs.filesystem_read_allows("/etc/passwd"));
        assert!(fs.filesystem_write_allows("/tmp/out/log"));
        assert!(!fs.filesystem_write_allows("/data/x/y.txt"));
        assert!(!Capability::ScalarFn.filesystem_read_allows("/data/x"));
    }
}