Skip to main content

mcp_authorization/
capability.rs

1use std::collections::HashSet;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
/// Marker trait describing one authorization capability.
///
/// Each capability is usually modelled as its own zero-sized type (ZST), whose
/// only job is to carry a wire name through [`Capability::NAME`]:
///
/// ```
/// use mcp_authorization::Capability;
///
/// struct ManageWorkflows;
///
/// impl Capability for ManageWorkflows {
///     const NAME: &'static str = "manage_workflows";
/// }
///
/// assert_eq!(ManageWorkflows::NAME, "manage_workflows");
/// ```
pub trait Capability: 'static + Send + Sync {
    /// Wire-level identifier of this capability.
    ///
    /// This string is what gets looked up in an [`AuthContext`]'s set of
    /// granted names, so it should match the external source of those names
    /// (for example JWT claims or DB flags).
    const NAME: &'static str;
}
21
22/// A zero-sized, compile-time proof that a capability was verified.
23///
24/// `Proof<C>` cannot be constructed outside this crate — the only way to
25/// obtain one is through [`AuthContext::check`] or [`AuthContext::require`].
26/// This means any function that demands `Proof<C>` in its signature is
27/// *statically guaranteed* to have been preceded by an authorization check.
28///
29/// ```compile_fail
30/// use mcp_authorization::{Capability, Proof};
31///
32/// struct Admin;
33/// impl Capability for Admin { const NAME: &'static str = "admin"; }
34///
35/// // This will not compile — Proof's fields are private:
36/// let fake = Proof::<Admin>::new();
37/// ```
38///
39/// At runtime, `Proof<C>` compiles away entirely:
40/// ```
41/// use mcp_authorization::{Capability, Proof};
42/// assert_eq!(std::mem::size_of::<Proof<Admin>>(), 0);
43///
44/// struct Admin;
45/// impl Capability for Admin { const NAME: &'static str = "admin"; }
46/// ```
47pub struct Proof<C: Capability> {
48    _marker: PhantomData<C>,
49}
50
51// Proof is Copy — it's zero-sized, so this is free.
52impl<C: Capability> Clone for Proof<C> {
53    fn clone(&self) -> Self {
54        *self
55    }
56}
57
58impl<C: Capability> Copy for Proof<C> {}
59
60impl<C: Capability> std::fmt::Debug for Proof<C> {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "Proof<{}>", C::NAME)
63    }
64}
65
66/// Per-request authorization context.
67///
68/// Built by middleware (e.g. from JWT claims) and stored in rmcp's
69/// `RequestContext::extensions`. Use rmcp's `Extension<AuthContext>`
70/// extractor to access it in tool handlers.
71///
72/// ```
73/// use mcp_authorization::{AuthContext, Capability, Proof};
74///
75/// struct BackwardRouting;
76/// impl Capability for BackwardRouting {
77///     const NAME: &'static str = "backward_routing";
78/// }
79///
80/// let auth = AuthContext::new(vec!["backward_routing"]);
81/// let proof: Proof<BackwardRouting> = auth.require::<BackwardRouting>().unwrap();
82/// ```
83#[derive(Clone, Debug)]
84pub struct AuthContext {
85    capabilities: Arc<HashSet<String>>,
86}
87
88impl AuthContext {
89    /// Create a new `AuthContext` from an iterable of capability names.
90    pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
91        Self {
92            capabilities: Arc::new(caps.into_iter().map(Into::into).collect()),
93        }
94    }
95
96    /// Try to obtain a `Proof<C>`. Returns `Some(Proof)` if the user has
97    /// the capability, `None` otherwise.
98    ///
99    /// This is the **only way** to construct a `Proof<C>`.
100    pub fn check<C: Capability>(&self) -> Option<Proof<C>> {
101        if self.capabilities.contains(C::NAME) {
102            Some(Proof {
103                _marker: PhantomData,
104            })
105        } else {
106            None
107        }
108    }
109
110    /// Like [`check`](Self::check), but returns an `McpError` on failure.
111    pub fn require<C: Capability>(&self) -> Result<Proof<C>, rmcp::ErrorData> {
112        self.check::<C>().ok_or_else(|| {
113            rmcp::ErrorData::invalid_params(
114                format!("missing required capability: {}", C::NAME),
115                None,
116            )
117        })
118    }
119
120    /// String-based capability query for runtime schema shaping.
121    pub fn has(&self, name: &str) -> bool {
122        self.capabilities.contains(name)
123    }
124
125    /// Returns the set of capability names this context holds.
126    pub fn capability_names(&self) -> &HashSet<String> {
127        &self.capabilities
128    }
129}
130
#[cfg(test)]
mod tests {
    // Unit tests for `Proof` and `AuthContext`. Beyond the basic grant/deny
    // checks, these pin `require`'s success path and exact error message,
    // set semantics in `new`, exact-match lookups, `Debug` output, and the
    // cheap (shared) `Clone` of `AuthContext`.
    use super::*;

    struct ManageWorkflows;
    impl Capability for ManageWorkflows {
        const NAME: &'static str = "manage_workflows";
    }

    struct BackwardRouting;
    impl Capability for BackwardRouting {
        const NAME: &'static str = "backward_routing";
    }

    struct Admin;
    impl Capability for Admin {
        const NAME: &'static str = "admin";
    }

    #[test]
    fn proof_is_zero_sized() {
        assert_eq!(std::mem::size_of::<Proof<ManageWorkflows>>(), 0);
        assert_eq!(std::mem::size_of::<Proof<BackwardRouting>>(), 0);
    }

    #[test]
    fn check_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["manage_workflows", "backward_routing"]);
        assert!(auth.check::<ManageWorkflows>().is_some());
        assert!(auth.check::<BackwardRouting>().is_some());
    }

    #[test]
    fn check_returns_none_when_not_capable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.check::<BackwardRouting>().is_none());
    }

    #[test]
    fn require_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["admin"]);
        assert!(auth.require::<Admin>().is_ok());
    }

    #[test]
    fn require_returns_error_when_not_capable() {
        let auth = AuthContext::new(Vec::<String>::new());
        let err = auth.require::<Admin>().unwrap_err();
        assert!(err.message.contains("admin"));
        assert_eq!(err.message, "missing required capability: admin");
    }

    #[test]
    fn has_checks_by_string_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.has("manage_workflows"));
        assert!(!auth.has("admin"));
    }

    #[test]
    fn names_must_match_exactly() {
        // Lookups are exact string matches: no case folding.
        let auth = AuthContext::new(vec!["ADMIN"]);
        assert!(auth.check::<Admin>().is_none());
        assert!(!auth.has("admin"));
    }

    #[test]
    fn empty_context_has_no_capabilities() {
        let auth = AuthContext::new(Vec::<String>::new());
        assert!(!auth.has("anything"));
        assert!(auth.check::<Admin>().is_none());
        assert!(auth.capability_names().is_empty());
    }

    #[test]
    fn new_accepts_owned_strings() {
        let auth = AuthContext::new(vec![String::from("admin")]);
        assert!(auth.check::<Admin>().is_some());
    }

    #[test]
    fn new_deduplicates_names() {
        let auth = AuthContext::new(vec!["admin", "admin", "manage_workflows"]);
        let names = auth.capability_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains("admin"));
        assert!(names.contains("manage_workflows"));
    }

    #[test]
    fn clones_share_the_capability_set() {
        let auth = AuthContext::new(vec!["admin"]);
        let cloned = auth.clone();
        assert!(std::sync::Arc::ptr_eq(&auth.capabilities, &cloned.capabilities));
        assert!(cloned.check::<Admin>().is_some());
    }

    #[test]
    fn proof_debug_shows_capability_name() {
        let auth = AuthContext::new(vec!["backward_routing"]);
        let proof = auth.check::<BackwardRouting>().unwrap();
        assert_eq!(format!("{proof:?}"), "Proof<backward_routing>");
    }

    #[test]
    fn proof_is_copyable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        let _copy = proof;
        let _another = proof; // still valid — Copy
    }
}