use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Identifier of a user; the `user_id` half of a [`Principal`].
///
/// A thin newtype over `String`. Build it with the `From<&str>` / `From<String>`
/// impls below, or wrap a `String` directly.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct UserId(pub String);
impl fmt::Display for UserId {
    /// Writes the raw identifier string.
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `{:>12}`) are honoured exactly as they
    /// are for a plain `str`; `write_str` silently ignores them. Plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
impl From<&str> for UserId {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for UserId {
fn from(s: String) -> Self {
Self(s)
}
}
/// Identifier of an agent; the `agent_id` half of a [`Principal`], also carried
/// by [`AgentHint::AgentId`].
///
/// A thin newtype over `String`. Build it with the `From<&str>` / `From<String>`
/// impls below, or wrap a `String` directly.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AgentId(pub String);
impl fmt::Display for AgentId {
    /// Writes the raw identifier string.
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `{:<16}` in tabular logs) are honoured
    /// as they are for a plain `str`; `write_str` silently ignores them. Plain
    /// `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
impl From<&str> for AgentId {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for AgentId {
fn from(s: String) -> Self {
Self(s)
}
}
/// Privilege tier of an operation, ordered lowest to highest.
///
/// Variant order is load-bearing: the derived `PartialOrd`/`Ord` compare variants
/// by declaration order, giving `Read < Write < Execute < Destructive < External`.
/// Reordering or inserting variants changes every tier comparison.
///
/// Serialized in `snake_case` (`"read"`, `"write"`, `"execute"`, `"destructive"`,
/// `"external"`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Tier {
Read,
Write,
Execute,
/// `requires_confirmation()` returns `true` for this tier.
Destructive,
/// Highest tier; `requires_confirmation()` returns `true` for this tier.
External,
}
impl Tier {
    /// Returns `true` for the tiers flagged as needing confirmation:
    /// `Destructive` and `External`.
    pub fn requires_confirmation(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Tier::Destructive | Tier::External => true,
            Tier::Read | Tier::Write | Tier::Execute => false,
        }
    }

    /// Default timeout for an operation at this tier.
    ///
    /// `Read` 30s, `Write` 45s, `Execute` 60s, `Destructive` 90s, `External` 60s.
    /// `External` is intentionally shorter than `Destructive` even though it
    /// ranks higher.
    pub fn default_timeout(self) -> std::time::Duration {
        let seconds = match self {
            Tier::Read => 30,
            Tier::Write => 45,
            Tier::Execute | Tier::External => 60,
            Tier::Destructive => 90,
        };
        std::time::Duration::from_secs(seconds)
    }
}
impl fmt::Display for Tier {
    /// Writes the tier's lowercase name, identical to its serde representation.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision flags
    /// (e.g. `{:<12}` when tabulating tiers) are honoured like they are for `str`;
    /// the previous `write_str` silently ignored them. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Tier::Read => "read",
            Tier::Write => "write",
            Tier::Execute => "execute",
            Tier::Destructive => "destructive",
            Tier::External => "external",
        };
        f.pad(name)
    }
}
/// The identity a request is evaluated against: who (user + agent), which verbs
/// they are scoped to, and their tier.
///
/// `scopes` entries are verb patterns understood by [`Principal::has_scope`]:
/// `"*"` (every verb), `"<ns>.*"` (every action in namespace `<ns>`), or
/// `"<ns>.<action>"` (exactly one verb).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Principal {
pub user_id: UserId,
pub agent_id: AgentId,
pub scopes: Vec<String>,
// NOTE(review): presumably the highest tier this principal may act at (compared
// as `tier >= required` by `IdentityStore::check` implementations) — confirm.
pub tier: Tier,
}
impl Principal {
    /// Returns `true` when at least one scope grants the verb `verb_ns.verb_action`.
    ///
    /// A scope grants the verb when it is:
    /// - the global wildcard `"*"`,
    /// - a namespace wildcard `"<verb_ns>.*"`, or
    /// - exactly `"<verb_ns>.<verb_action>"`.
    ///
    /// An empty scope list grants nothing.
    pub fn has_scope(&self, verb_ns: &str, verb_action: &str) -> bool {
        self.scopes.iter().map(String::as_str).any(|scope| {
            let is_global = scope == "*";
            let is_namespace = scope.strip_suffix(".*") == Some(verb_ns);
            // Piecewise comparison with "<ns>.<action>"; no joined string is built.
            let is_exact = scope
                .strip_prefix(verb_ns)
                .and_then(|rest| rest.strip_prefix('.'))
                == Some(verb_action);
            is_global || is_namespace || is_exact
        })
    }
}
/// How a [`ModifierConstraint`] compares a modifier value against its `allow` entries.
///
/// Defaults to [`MatchKind::Exact`]. Serialized in `snake_case`
/// (`"exact"`, `"prefix"`, `"host_suffix"`).
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MatchKind {
/// Value must equal an entry exactly (case-sensitive).
#[default]
Exact,
/// Value must start with an entry (plain, case-sensitive string prefix; an
/// empty entry therefore matches every value).
Prefix,
/// Value must be the entry's host or a subdomain of it, compared
/// ASCII-case-insensitively on label boundaries; leading `*.` on the entry is
/// stripped, so `*.example.com` and `example.com` behave the same.
HostSuffix,
}
/// Restricts the allowed values of one modifier on the verbs matched by `verb`.
///
/// When deserializing, a missing `match_kind` defaults to `Exact` and a missing
/// `allow` defaults to an empty list, and an empty `allow` list permits nothing.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModifierConstraint {
/// Verb pattern this constraint applies to: `"*"`, `"<ns>.*"`, or `"<ns>.<action>"`.
pub verb: String,
/// Name of the modifier whose value is checked.
/// NOTE(review): presumably a key of `AuthorizationRequest::modifiers`, looked
/// up by the caller (e.g. via `modifier_str`) — confirm.
pub modifier: String,
/// Comparison used by `permits`.
#[serde(default)]
pub match_kind: MatchKind,
/// Accepted values/patterns, interpreted according to `match_kind`.
#[serde(default)]
pub allow: Vec<String>,
}
impl ModifierConstraint {
    /// Returns `true` when this constraint's `verb` pattern covers `verb_ns.verb_action`.
    ///
    /// `"*"` covers every verb, `"<ns>.*"` covers every action in exactly namespace
    /// `<ns>`, and any other pattern must equal `"<verb_ns>.<verb_action>"`.
    pub fn applies_to(&self, verb_ns: &str, verb_action: &str) -> bool {
        let pattern = self.verb.as_str();
        if pattern == "*" {
            return true;
        }
        match pattern.strip_suffix(".*") {
            // A namespace wildcard is decisive; it never falls back to the exact check.
            Some(namespace) => namespace == verb_ns,
            // Piecewise comparison with "<ns>.<action>"; no joined string is built.
            None => {
                pattern
                    .strip_prefix(verb_ns)
                    .and_then(|rest| rest.strip_prefix('.'))
                    == Some(verb_action)
            }
        }
    }

    /// Returns `true` when at least one `allow` entry accepts `value` under
    /// `match_kind`; an empty `allow` list accepts nothing.
    pub fn permits(&self, value: &str) -> bool {
        let mut entries = self.allow.iter().map(String::as_str);
        match self.match_kind {
            MatchKind::Exact => entries.any(|entry| entry == value),
            MatchKind::Prefix => entries.any(|entry| value.starts_with(entry)),
            MatchKind::HostSuffix => entries.any(|entry| host_suffix_match(entry, value)),
        }
    }
}
/// Returns `true` when `value` is the host named by `entry` or a subdomain of it.
///
/// Leading `*.` markers on `entry` are stripped, so `"*.example.com"` and
/// `"example.com"` are equivalent and both match the apex itself. Comparison is
/// ASCII-case-insensitive and only succeeds on a label boundary, so
/// `"github.com"` does not match `"evilgithub.com"` or `"github.com.evil.com"`.
/// An entry that is empty after stripping matches nothing.
fn host_suffix_match(entry: &str, value: &str) -> bool {
    let base = entry.trim_start_matches("*.");
    if base.is_empty() {
        return false;
    }
    let (host, suffix) = (value.as_bytes(), base.as_bytes());
    if host.len() < suffix.len() {
        return false;
    }
    // Byte-level split: never panics on non-char boundaries, and ASCII case
    // folding leaves non-ASCII bytes untouched, matching `to_ascii_lowercase`.
    let (head, tail) = host.split_at(host.len() - suffix.len());
    tail.eq_ignore_ascii_case(suffix) && (head.is_empty() || head.last() == Some(&b'.'))
}
/// Hint passed to [`IdentityStore::principal_for`] to select which principal to resolve.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AgentHint {
/// A specific agent identifier.
/// NOTE(review): an unrecognised id presumably yields `IdentityError::UnknownAgent` — confirm.
AgentId(AgentId),
/// No agent identity supplied; how this resolves is up to the `IdentityStore` implementation.
Anonymous,
}
/// A single verb invocation to authorize, split into namespace and action
/// (e.g. `fs` + `read`), plus its arguments.
#[derive(Clone, Debug)]
pub struct AuthorizationRequest {
/// Verb namespace, e.g. `"fs"`.
pub verb_ns: String,
/// Action within the namespace, e.g. `"read"`.
pub verb_action: String,
/// Verb arguments as JSON; `AuthorizationRequest::new` sets this to `Value::Null`.
pub modifiers: serde_json::Value,
}
impl AuthorizationRequest {
    /// Creates a request for `verb_ns.verb_action` with no modifiers (`Value::Null`).
    pub fn new(verb_ns: impl Into<String>, verb_action: impl Into<String>) -> Self {
        let verb_ns = verb_ns.into();
        let verb_action = verb_action.into();
        Self {
            verb_ns,
            verb_action,
            modifiers: serde_json::Value::Null,
        }
    }

    /// Builder-style setter that replaces the request's modifiers.
    pub fn with_modifiers(self, modifiers: serde_json::Value) -> Self {
        Self { modifiers, ..self }
    }

    /// Returns modifier `key` when it is present and a JSON string.
    ///
    /// Returns `None` for missing keys, non-string values (numbers are not
    /// stringified), and when `modifiers` is not an object (e.g. `Null`).
    pub fn modifier_str(&self, key: &str) -> Option<&str> {
        self.modifiers.get(key).and_then(serde_json::Value::as_str)
    }
}
/// Decision returned by [`IdentityStore::check`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CheckOutcome {
/// The request may proceed.
Allow,
/// The request needs a decision from the user; `reason` explains why.
/// NOTE(review): presumably produced when `Tier::requires_confirmation` holds — confirm.
EscalateToUser { reason: String },
/// The request is refused; `reason` explains why.
Deny { reason: String },
}
/// Errors from resolving identities, e.g. in [`IdentityStore::principal_for`].
///
/// `#[from]` on `Io` and `Yaml` lets `?` convert `std::io::Error` and
/// `serde_yaml::Error` into this type automatically.
#[derive(Debug, Error)]
pub enum IdentityError {
/// No principal is known for the given agent id (the `String`).
#[error("unknown agent: {0}")]
UnknownAgent(String),
/// Configuration problem described by the `String`.
#[error("config error: {0}")]
Config(String),
/// Underlying I/O failure.
#[error("io: {0}")]
Io(#[from] std::io::Error),
/// YAML (de)serialization failure.
#[error("yaml: {0}")]
Yaml(#[from] serde_yaml::Error),
}
/// Source of principals and authorization decisions.
///
/// Implementations must be `Send + Sync`. `#[async_trait]` rewrites the `async fn`s
/// into methods returning boxed futures, so the trait is usable as `dyn IdentityStore`.
#[async_trait]
pub trait IdentityStore: Send + Sync {
/// Resolves the [`Principal`] selected by `agent_hint`.
///
/// # Errors
/// Returns an [`IdentityError`] when no principal can be produced.
/// NOTE(review): e.g. `UnknownAgent` for an unrecognised id — confirm per implementation.
async fn principal_for(&self, agent_hint: &AgentHint) -> Result<Principal, IdentityError>;
/// Decides whether principal `p` may perform `req`, which the caller has classified
/// as needing tier `required`.
///
/// Infallible at the type level: refusals are expressed as `CheckOutcome::Deny`
/// rather than an error.
async fn check(
&self,
p: &Principal,
req: &AuthorizationRequest,
required: Tier,
) -> CheckOutcome;
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ----- Tier --------------------------------------------------------------

    #[test]
    fn tier_is_strictly_ordered() {
        assert!(Tier::Read < Tier::Write);
        assert!(Tier::Write < Tier::Execute);
        assert!(Tier::Execute < Tier::Destructive);
        assert!(Tier::Destructive < Tier::External);
    }

    #[test]
    fn tier_satisfies_via_ge() {
        assert!(Tier::Execute >= Tier::Write);
        assert!(Tier::Read < Tier::Execute);
    }

    #[test]
    fn requires_confirmation_only_destructive_external() {
        assert!(!Tier::Read.requires_confirmation());
        assert!(!Tier::Write.requires_confirmation());
        assert!(!Tier::Execute.requires_confirmation());
        assert!(Tier::Destructive.requires_confirmation());
        assert!(Tier::External.requires_confirmation());
    }

    #[test]
    fn default_timeout_increases_with_tier() {
        // External is intentionally excluded; see the next test.
        assert!(Tier::Read.default_timeout() < Tier::Write.default_timeout());
        assert!(Tier::Write.default_timeout() < Tier::Execute.default_timeout());
        assert!(Tier::Execute.default_timeout() < Tier::Destructive.default_timeout());
    }

    #[test]
    fn external_timeout_is_shorter_than_destructive() {
        assert!(Tier::External.default_timeout() < Tier::Destructive.default_timeout());
    }

    #[test]
    fn display_matches_serde_repr() {
        assert_eq!(Tier::Destructive.to_string(), "destructive");
        let json = serde_json::to_string(&Tier::Destructive).unwrap();
        assert_eq!(json, "\"destructive\"");
        // Every variant: Display must agree with the serde (snake_case) name.
        for tier in [
            Tier::Read,
            Tier::Write,
            Tier::Execute,
            Tier::Destructive,
            Tier::External,
        ] {
            let json = serde_json::to_string(&tier).unwrap();
            assert_eq!(json, format!("\"{tier}\""));
        }
    }

    // ----- Principal scopes --------------------------------------------------

    #[test]
    fn has_scope_exact_match() {
        let p = Principal {
            user_id: "k".into(),
            agent_id: "claude-code".into(),
            scopes: vec!["shell.exec".into()],
            tier: Tier::Execute,
        };
        assert!(p.has_scope("shell", "exec"));
        assert!(!p.has_scope("shell", "kill"));
        assert!(!p.has_scope("fs", "read"));
    }

    #[test]
    fn has_scope_namespace_wildcard() {
        let p = Principal {
            user_id: "k".into(),
            agent_id: "claude-code".into(),
            scopes: vec!["fs.*".into()],
            tier: Tier::Write,
        };
        assert!(p.has_scope("fs", "read"));
        assert!(p.has_scope("fs", "write"));
        assert!(!p.has_scope("shell", "exec"));
    }

    #[test]
    fn has_scope_global_wildcard() {
        let p = Principal {
            user_id: "k".into(),
            agent_id: "root".into(),
            scopes: vec!["*".into()],
            tier: Tier::External,
        };
        assert!(p.has_scope("shell", "exec"));
        assert!(p.has_scope("fs", "read"));
        assert!(p.has_scope("net", "http"));
    }

    #[test]
    fn has_scope_with_no_scopes_denies() {
        let p = principal_with(Vec::new());
        assert!(!p.has_scope("fs", "read"));
    }

    // ----- AuthorizationRequest ----------------------------------------------

    #[test]
    fn auth_request_modifier_str_reads_paths() {
        let req = AuthorizationRequest::new("fs", "read")
            .with_modifiers(serde_json::json!({ "path": "/tmp/x", "limit": 10 }));
        assert_eq!(req.modifier_str("path"), Some("/tmp/x"));
        // Non-string values are not stringified.
        assert_eq!(req.modifier_str("limit"), None);
        assert_eq!(req.modifier_str("missing"), None);
        // Without modifiers (Null) every lookup is None.
        assert_eq!(AuthorizationRequest::new("fs", "read").modifier_str("path"), None);
    }

    // ----- ModifierConstraint ------------------------------------------------

    fn constraint(
        verb: &str,
        modifier: &str,
        match_kind: MatchKind,
        allow: &[&str],
    ) -> ModifierConstraint {
        ModifierConstraint {
            verb: verb.into(),
            modifier: modifier.into(),
            match_kind,
            allow: allow.iter().map(|s| s.to_string()).collect(),
        }
    }

    #[test]
    fn constraint_applies_to_exact_namespace_and_global() {
        let exact = constraint("net.http", "host", MatchKind::HostSuffix, &[]);
        assert!(exact.applies_to("net", "http"));
        assert!(!exact.applies_to("net", "connect"));
        assert!(!exact.applies_to("fs", "read"));

        let ns = constraint("net.*", "host", MatchKind::HostSuffix, &[]);
        assert!(ns.applies_to("net", "http"));
        assert!(ns.applies_to("net", "connect"));
        assert!(!ns.applies_to("fs", "read"));

        let global = constraint("*", "host", MatchKind::HostSuffix, &[]);
        assert!(global.applies_to("anything", "at-all"));
    }

    #[test]
    fn constraint_exact_match() {
        let c = constraint(
            "mcp.mount",
            "name",
            MatchKind::Exact,
            &["github", "filesystem"],
        );
        assert!(c.permits("github"));
        assert!(c.permits("filesystem"));
        assert!(!c.permits("git"));
        assert!(!c.permits("evil"));
    }

    #[test]
    fn constraint_prefix_match() {
        let c = constraint(
            "shell.exec",
            "command",
            MatchKind::Prefix,
            &["git ", "cargo "],
        );
        assert!(c.permits("git status"));
        assert!(c.permits("cargo build"));
        assert!(!c.permits("rm -rf /"));
    }

    #[test]
    fn constraint_host_suffix_match() {
        let c = constraint(
            "net.http",
            "host",
            MatchKind::HostSuffix,
            &["github.com", "*.internal.example"],
        );
        assert!(c.permits("github.com"));
        assert!(c.permits("api.github.com"));
        assert!(c.permits("svc.internal.example"));
        assert!(c.permits("API.GitHub.com"));
        assert!(!c.permits("github.com.evil.com"));
        assert!(!c.permits("evil.com"));
    }

    #[test]
    fn constraint_empty_allow_denies_everything() {
        let c = constraint("net.http", "host", MatchKind::HostSuffix, &[]);
        assert!(!c.permits("github.com"));
        assert!(!c.permits(""));
    }

    // ----- Property tests: scopes and constraints ----------------------------

    fn principal_with(scopes: Vec<String>) -> Principal {
        Principal {
            user_id: "u".into(),
            agent_id: "a".into(),
            scopes,
            tier: Tier::Execute,
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig { cases: 512, ..ProptestConfig::default() })]

        #[test]
        fn global_scope_covers_everything(ns in "[a-z]{1,8}", action in "[a-z]{1,8}") {
            prop_assert!(principal_with(vec!["*".into()]).has_scope(&ns, &action));
        }

        #[test]
        fn namespace_wildcard_matches_only_its_namespace(
            scope_ns in "[a-z]{1,8}",
            ns in "[a-z]{1,8}",
            action in "[a-z]{1,8}",
        ) {
            let p = principal_with(vec![format!("{scope_ns}.*")]);
            prop_assert_eq!(p.has_scope(&ns, &action), ns == scope_ns);
        }

        #[test]
        fn exact_scope_matches_only_exact_verb(
            sn in "[a-z]{1,8}",
            sa in "[a-z]{1,8}",
            ns in "[a-z]{1,8}",
            action in "[a-z]{1,8}",
        ) {
            let p = principal_with(vec![format!("{sn}.{sa}")]);
            prop_assert_eq!(p.has_scope(&ns, &action), ns == sn && action == sa);
        }

        #[test]
        fn constraint_namespace_wildcard_is_exact(
            verb_ns in "[a-z]{1,8}",
            ns in "[a-z]{1,8}",
            action in "[a-z]{1,8}",
        ) {
            let c = constraint(&format!("{verb_ns}.*"), "host", MatchKind::Exact, &[]);
            prop_assert_eq!(c.applies_to(&ns, &action), ns == verb_ns);
        }

        #[test]
        fn host_suffix_respects_label_boundary(
            base in "[a-z]{1,6}\\.[a-z]{2,4}",
            label in "[a-z]{1,5}",
        ) {
            let c = constraint("net.http", "host", MatchKind::HostSuffix, &[&base]);
            let subdomain = format!("{label}.{base}");
            let sibling = format!("{label}{base}");
            let upper = base.to_uppercase();
            prop_assert!(c.permits(&base));
            prop_assert!(c.permits(&subdomain));
            prop_assert!(c.permits(&upper));
            prop_assert!(!c.permits(&sibling));

            // A leading "*." on the entry must not change the verdict.
            let starred_entry = format!("*.{base}");
            let starred = constraint("net.http", "host", MatchKind::HostSuffix, &[&starred_entry]);
            for v in [base.clone(), subdomain.clone(), sibling.clone()] {
                prop_assert_eq!(c.permits(&v), starred.permits(&v));
            }
        }
    }

    // ----- Property tests: tier ladder ---------------------------------------

    fn any_tier() -> impl Strategy<Value = Tier> {
        prop_oneof![
            Just(Tier::Read),
            Just(Tier::Write),
            Just(Tier::Execute),
            Just(Tier::Destructive),
            Just(Tier::External),
        ]
    }

    #[test]
    fn tier_ladder_is_canonical() {
        let mut all = [
            Tier::External,
            Tier::Read,
            Tier::Destructive,
            Tier::Write,
            Tier::Execute,
        ];
        all.sort();
        assert_eq!(
            all,
            [
                Tier::Read,
                Tier::Write,
                Tier::Execute,
                Tier::Destructive,
                Tier::External,
            ]
        );
    }

    proptest! {
        #![proptest_config(ProptestConfig { cases: 256, ..ProptestConfig::default() })]

        #[test]
        fn confirmation_gate_is_the_destructive_cut(t in any_tier()) {
            prop_assert_eq!(t.requires_confirmation(), t >= Tier::Destructive);
        }

        #[test]
        fn confirmation_requirement_is_upward_closed(a in any_tier(), b in any_tier()) {
            let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
            if lo.requires_confirmation() {
                prop_assert!(hi.requires_confirmation());
            }
        }

        #[test]
        fn authorization_is_upward_closed(
            required in any_tier(),
            p in any_tier(),
            q in any_tier(),
        ) {
            let (lo, hi) = if p <= q { (p, q) } else { (q, p) };
            if lo >= required {
                prop_assert!(hi >= required);
            }
        }

        // Checks that the derived `>=` and `<=` are consistent duals.
        #[test]
        fn principal_satisfies_iff_required_at_or_below(
            principal in any_tier(),
            required in any_tier(),
        ) {
            prop_assert_eq!(principal >= required, required <= principal);
        }

        #[test]
        fn default_timeout_is_positive_and_bounded(t in any_tier()) {
            let secs = t.default_timeout().as_secs();
            prop_assert!((1..=300).contains(&secs));
        }
    }
}