use std::collections::HashSet;
use std::marker::PhantomData;
use std::sync::Arc;
/// Type-level marker for a single named capability.
///
/// Implementors carry no data of their own here; each one only supplies the
/// string that `AuthContext` looks up in its set of granted capability names.
pub trait Capability: Send + Sync + 'static {
/// Name compared against the capability strings stored in an `AuthContext`.
const NAME: &'static str;
}
/// Zero-sized token typed by the capability `C`.
///
/// The only field is private, so code outside this module cannot create a
/// `Proof` with a struct literal. In the visible part of this file the only
/// place one is constructed is `AuthContext::check`, after the capability name
/// was found in the context.
pub struct Proof<C: Capability> {
// Ties the token to `C` at the type level without storing a value of `C`.
_marker: PhantomData<C>,
}
/// Hand-written `Clone` for `Proof<C>`.
///
/// NOTE(review): presumably written by hand rather than derived because
/// `#[derive(Clone)]` would add a `C: Clone` bound, which `Capability` does not
/// require of its implementors.
impl<C: Capability> Clone for Proof<C> {
fn clone(&self) -> Self {
// `Proof<C>` is `Copy`, so a dereference yields the duplicate.
*self
}
}
/// A `Proof<C>` holds only `PhantomData`, so it can be copied freely by value.
impl<C: Capability> Copy for Proof<C> {}
/// Debug rendering of a proof: the literal `Proof<…>` wrapper around the
/// capability's `NAME`.
impl<C: Capability> std::fmt::Debug for Proof<C> {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The token has no runtime state; everything shown comes from the type.
        let capability = C::NAME;
        out.write_fmt(format_args!("Proof<{}>", capability))
    }
}
/// A set of granted capability names.
///
/// The set is held behind an `Arc`, so the derived `Clone` shares the same
/// underlying set instead of copying the strings.
#[derive(Clone, Debug)]
pub struct AuthContext {
// Granted capability names; lookups compare against `Capability::NAME`.
capabilities: Arc<HashSet<String>>,
}
impl AuthContext {
    /// Builds a context from any iterable of values convertible to `String`.
    ///
    /// Names are collected into a set, so repeated names are stored once.
    pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let granted: HashSet<String> = caps.into_iter().map(|cap| cap.into()).collect();
        Self {
            capabilities: Arc::new(granted),
        }
    }

    /// Returns a `Proof<C>` when `C::NAME` is among the granted names, and
    /// `None` otherwise.
    pub fn check<C: Capability>(&self) -> Option<Proof<C>> {
        // The proof is zero-sized, so building it eagerly costs nothing.
        self.has(C::NAME).then_some(Proof {
            _marker: PhantomData,
        })
    }

    /// Like [`AuthContext::check`], but reports a missing capability as an error.
    ///
    /// # Errors
    ///
    /// Returns an `rmcp::ErrorData` built with `invalid_params` whose message
    /// names the missing capability.
    pub fn require<C: Capability>(&self) -> Result<Proof<C>, rmcp::ErrorData> {
        match self.check::<C>() {
            Some(proof) => Ok(proof),
            None => Err(rmcp::ErrorData::invalid_params(
                format!("missing required capability: {}", C::NAME),
                None,
            )),
        }
    }

    /// Reports whether `name` is one of the granted capability names.
    pub fn has(&self, name: &str) -> bool {
        self.capabilities.contains(name)
    }

    /// Borrows the full set of granted capability names.
    pub fn capability_names(&self) -> &HashSet<String> {
        &*self.capabilities
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Capability markers used only by the tests below.
    struct ManageWorkflows;
    impl Capability for ManageWorkflows {
        const NAME: &'static str = "manage_workflows";
    }

    struct BackwardRouting;
    impl Capability for BackwardRouting {
        const NAME: &'static str = "backward_routing";
    }

    struct Admin;
    impl Capability for Admin {
        const NAME: &'static str = "admin";
    }

    /// A proof occupies no memory, whichever capability it is typed by.
    #[test]
    fn proof_is_zero_sized() {
        let workflows_size = std::mem::size_of::<Proof<ManageWorkflows>>();
        let routing_size = std::mem::size_of::<Proof<BackwardRouting>>();
        assert_eq!(workflows_size, 0);
        assert_eq!(routing_size, 0);
    }

    /// Every granted capability yields a proof.
    #[test]
    fn check_returns_proof_when_capable() {
        let granted = vec!["manage_workflows", "backward_routing"];
        let auth = AuthContext::new(granted);
        let workflows = auth.check::<ManageWorkflows>();
        let routing = auth.check::<BackwardRouting>();
        assert!(workflows.is_some());
        assert!(routing.is_some());
    }

    /// A capability that was not granted yields no proof.
    #[test]
    fn check_returns_none_when_not_capable() {
        let granted = vec!["manage_workflows"];
        let auth = AuthContext::new(granted);
        let routing = auth.check::<BackwardRouting>();
        assert!(routing.is_none());
    }

    /// `require` fails with a message that names the missing capability.
    #[test]
    fn require_returns_error_when_not_capable() {
        let auth = AuthContext::new(Vec::<String>::new());
        let outcome = auth.require::<Admin>();
        let err = outcome.unwrap_err();
        assert!(err.message.contains("admin"));
    }

    /// `has` answers by plain string name, with no capability type involved.
    #[test]
    fn has_checks_by_string_name() {
        let granted = vec!["manage_workflows"];
        let auth = AuthContext::new(granted);
        assert!(auth.has("manage_workflows"));
        assert!(!auth.has("admin"));
    }

    /// A context built from nothing grants nothing.
    #[test]
    fn empty_context_has_no_capabilities() {
        let auth = AuthContext::new(Vec::<String>::new());
        assert!(!auth.has("anything"));
        let admin = auth.check::<Admin>();
        assert!(admin.is_none());
    }

    /// Using the same proof binding twice by value compiles only if it is `Copy`.
    #[test]
    fn proof_is_copyable() {
        let granted = vec!["manage_workflows"];
        let auth = AuthContext::new(granted);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        let _first = proof;
        let _second = proof;
    }
}