//! `mcp_authorization/capability.rs` — capability marker trait, zero-sized
//! `Proof` tokens, and the per-request `AuthContext` that issues them.

1use std::collections::HashSet;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
/// Marker trait identifying a single authorization capability.
///
/// Intended to be implemented on zero-sized types (ZSTs), one per capability:
///
/// ```
/// use mcp_authorization::Capability;
///
/// struct ManageWorkflows;
/// impl Capability for ManageWorkflows {
///     const NAME: &'static str = "manage_workflows";
/// }
/// ```
pub trait Capability: 'static + Send + Sync {
    /// Wire name of the capability (matches JWT claims, DB flags, etc.).
    ///
    /// [`AuthContext`] compares this string against the names it holds.
    const NAME: &'static str;
}
21
22/// A zero-sized, compile-time proof that a capability was verified.
23///
24/// `Proof<C>` cannot be constructed outside this crate — the only way to
25/// obtain one is through [`AuthContext::check`] or [`AuthContext::require`].
26/// This means any function that demands `Proof<C>` in its signature is
27/// *statically guaranteed* to have been preceded by an authorization check.
28///
29/// ```compile_fail
30/// use mcp_authorization::{Capability, Proof};
31///
32/// struct Admin;
33/// impl Capability for Admin { const NAME: &'static str = "admin"; }
34///
35/// // This will not compile — Proof's fields are private:
36/// let fake = Proof::<Admin>::new();
37/// ```
38///
39/// At runtime, `Proof<C>` compiles away entirely:
40/// ```
41/// use mcp_authorization::{Capability, Proof};
42/// assert_eq!(std::mem::size_of::<Proof<Admin>>(), 0);
43///
44/// struct Admin;
45/// impl Capability for Admin { const NAME: &'static str = "admin"; }
46/// ```
47pub struct Proof<C: Capability> {
48    _marker: PhantomData<C>,
49}
50
51// Proof is Copy — it's zero-sized, so this is free.
52impl<C: Capability> Clone for Proof<C> {
53    fn clone(&self) -> Self {
54        *self
55    }
56}
57
58impl<C: Capability> Copy for Proof<C> {}
59
60impl<C: Capability> std::fmt::Debug for Proof<C> {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "Proof<{}>", C::NAME)
63    }
64}
65
/// Authorization context for a single request.
///
/// Middleware builds one (e.g. from JWT claims) and stores it in rmcp's
/// `RequestContext::extensions`; tool handlers read it back through rmcp's
/// `Extension<AuthContext>` extractor.
///
/// The capability names live behind an [`Arc`], so cloning a context shares
/// the same set instead of copying it.
///
/// ```
/// use mcp_authorization::{AuthContext, Capability, Proof};
///
/// struct BackwardRouting;
/// impl Capability for BackwardRouting {
///     const NAME: &'static str = "backward_routing";
/// }
///
/// let auth = AuthContext::new(vec!["backward_routing"]);
/// let proof: Proof<BackwardRouting> = auth.require::<BackwardRouting>().unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct AuthContext {
    // Wire names of every capability granted to this request.
    capabilities: Arc<HashSet<String>>,
}
87
88impl AuthContext {
89    /// Create a new `AuthContext` from an iterable of capability names.
90    pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
91        Self {
92            capabilities: Arc::new(caps.into_iter().map(Into::into).collect()),
93        }
94    }
95
96    /// An `AuthContext` with no capabilities — the deny-by-default identity.
97    ///
98    /// A request resolved to an empty context sees only ungated tools, never
99    /// any tool guarded by [`authorize`](crate::AuthorizedServer::authorize) or
100    /// any `#[requires(...)]` field/variant.
101    pub fn empty() -> Self {
102        Self::new(Vec::<String>::new())
103    }
104
105    /// Try to obtain a `Proof<C>`. Returns `Some(Proof)` if the user has
106    /// the capability, `None` otherwise.
107    ///
108    /// This is the **only way** to construct a `Proof<C>`.
109    pub fn check<C: Capability>(&self) -> Option<Proof<C>> {
110        if self.capabilities.contains(C::NAME) {
111            Some(Proof {
112                _marker: PhantomData,
113            })
114        } else {
115            None
116        }
117    }
118
119    /// Like [`check`](Self::check), but returns an `McpError` on failure.
120    pub fn require<C: Capability>(&self) -> Result<Proof<C>, rmcp::ErrorData> {
121        self.check::<C>().ok_or_else(|| {
122            rmcp::ErrorData::invalid_params(
123                format!("missing required capability: {}", C::NAME),
124                None,
125            )
126        })
127    }
128
129    /// String-based capability query for runtime schema shaping.
130    pub fn has(&self, name: &str) -> bool {
131        self.capabilities.contains(name)
132    }
133
134    /// Returns the set of capability names this context holds.
135    pub fn capability_names(&self) -> &HashSet<String> {
136        &self.capabilities
137    }
138}
139
#[cfg(test)]
mod tests {
    //! Unit tests for `Proof` and `AuthContext`.

    use super::*;

    // Capability markers used only by these tests.
    struct ManageWorkflows;
    impl Capability for ManageWorkflows {
        const NAME: &'static str = "manage_workflows";
    }

    struct BackwardRouting;
    impl Capability for BackwardRouting {
        const NAME: &'static str = "backward_routing";
    }

    struct Admin;
    impl Capability for Admin {
        const NAME: &'static str = "admin";
    }

    #[test]
    fn proof_is_zero_sized() {
        assert_eq!(std::mem::size_of::<Proof<ManageWorkflows>>(), 0);
        assert_eq!(std::mem::size_of::<Proof<BackwardRouting>>(), 0);
    }

    #[test]
    fn proof_debug_shows_capability_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        assert_eq!(format!("{:?}", proof), "Proof<manage_workflows>");
    }

    #[test]
    fn check_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["manage_workflows", "backward_routing"]);
        assert!(auth.check::<ManageWorkflows>().is_some());
        assert!(auth.check::<BackwardRouting>().is_some());
    }

    #[test]
    fn check_returns_none_when_not_capable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.check::<BackwardRouting>().is_none());
    }

    #[test]
    fn require_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["admin"]);
        assert!(auth.require::<Admin>().is_ok());
    }

    #[test]
    fn require_returns_error_when_not_capable() {
        let auth = AuthContext::empty();
        let err = auth.require::<Admin>().unwrap_err();
        assert!(err.message.contains("missing required capability: admin"));
    }

    #[test]
    fn has_checks_by_string_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.has("manage_workflows"));
        assert!(!auth.has("admin"));
    }

    #[test]
    fn empty_context_has_no_capabilities() {
        // Exercise the dedicated constructor rather than `new` with an empty list.
        let auth = AuthContext::empty();
        assert!(auth.capability_names().is_empty());
        assert!(!auth.has("anything"));
        assert!(auth.check::<Admin>().is_none());
    }

    #[test]
    fn new_with_no_names_matches_empty() {
        let auth = AuthContext::new(Vec::<String>::new());
        assert_eq!(auth.capability_names(), AuthContext::empty().capability_names());
    }

    #[test]
    fn new_accepts_owned_names_and_collapses_duplicates() {
        let auth = AuthContext::new(vec![String::from("admin"), String::from("admin")]);
        assert_eq!(auth.capability_names().len(), 1);
        assert!(auth.capability_names().contains("admin"));
        assert!(auth.check::<Admin>().is_some());
    }

    #[test]
    fn cloned_context_keeps_capabilities() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let copy = auth.clone();
        assert_eq!(copy.capability_names(), auth.capability_names());
        assert!(copy.check::<ManageWorkflows>().is_some());
    }

    #[test]
    fn proof_is_copyable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        let _copy = proof;
        let _another = proof; // still valid — Copy
    }
}