mcp_authorization/
capability.rs1use std::collections::HashSet;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
/// A named permission that can be demanded at the type level.
///
/// Implementors are marker types (unit structs in this file's tests). The
/// associated [`NAME`](Capability::NAME) is the string that
/// [`AuthContext::check`] / [`AuthContext::require`] look up in the context's
/// capability set.
///
/// The `'static + Send + Sync` supertraits make `Proof<C>` (which only holds a
/// `PhantomData<C>`) `Send + Sync + 'static` for every capability `C`.
pub trait Capability: 'static + Send + Sync {
    /// Identifier compared verbatim (case-sensitive) against the names an
    /// [`AuthContext`] was constructed with.
    const NAME: &'static str;
}
21
22pub struct Proof<C: Capability> {
48 _marker: PhantomData<C>,
49}
50
51impl<C: Capability> Clone for Proof<C> {
53 fn clone(&self) -> Self {
54 *self
55 }
56}
57
58impl<C: Capability> Copy for Proof<C> {}
59
60impl<C: Capability> std::fmt::Debug for Proof<C> {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "Proof<{}>", C::NAME)
63 }
64}
65
/// A set of capability names against which [`Capability`] requirements are
/// checked.
///
/// The set lives behind an [`Arc`], so cloning an `AuthContext` only bumps a
/// reference count and every clone shares the same set.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Granted capability identifiers, matched against [`Capability::NAME`].
    capabilities: Arc<HashSet<String>>,
}
87
88impl AuthContext {
89 pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
91 Self {
92 capabilities: Arc::new(caps.into_iter().map(Into::into).collect()),
93 }
94 }
95
96 pub fn empty() -> Self {
102 Self::new(Vec::<String>::new())
103 }
104
105 pub fn check<C: Capability>(&self) -> Option<Proof<C>> {
110 if self.capabilities.contains(C::NAME) {
111 Some(Proof {
112 _marker: PhantomData,
113 })
114 } else {
115 None
116 }
117 }
118
119 pub fn require<C: Capability>(&self) -> Result<Proof<C>, rmcp::ErrorData> {
121 self.check::<C>().ok_or_else(|| {
122 rmcp::ErrorData::invalid_params(
123 format!("missing required capability: {}", C::NAME),
124 None,
125 )
126 })
127 }
128
129 pub fn has(&self, name: &str) -> bool {
131 self.capabilities.contains(name)
132 }
133
134 pub fn capability_names(&self) -> &HashSet<String> {
136 &self.capabilities
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    struct ManageWorkflows;
    impl Capability for ManageWorkflows {
        const NAME: &'static str = "manage_workflows";
    }

    struct BackwardRouting;
    impl Capability for BackwardRouting {
        const NAME: &'static str = "backward_routing";
    }

    struct Admin;
    impl Capability for Admin {
        const NAME: &'static str = "admin";
    }

    #[test]
    fn proof_is_zero_sized() {
        assert_eq!(std::mem::size_of::<Proof<ManageWorkflows>>(), 0);
        assert_eq!(std::mem::size_of::<Proof<BackwardRouting>>(), 0);
    }

    #[test]
    fn check_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["manage_workflows", "backward_routing"]);
        assert!(auth.check::<ManageWorkflows>().is_some());
        assert!(auth.check::<BackwardRouting>().is_some());
    }

    #[test]
    fn check_returns_none_when_not_capable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.check::<BackwardRouting>().is_none());
    }

    #[test]
    fn check_is_case_sensitive() {
        // Names are compared verbatim against `Capability::NAME`.
        let auth = AuthContext::new(vec!["Admin"]);
        assert!(auth.check::<Admin>().is_none());
    }

    #[test]
    fn require_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["admin"]);
        assert!(auth.require::<Admin>().is_ok());
    }

    #[test]
    fn require_returns_error_when_not_capable() {
        let auth = AuthContext::new(Vec::<String>::new());
        let err = auth.require::<Admin>().unwrap_err();
        assert!(err.message.contains("admin"));
        // Pin the full message so changes to the wording are deliberate.
        assert_eq!(err.message, "missing required capability: admin");
    }

    #[test]
    fn has_checks_by_string_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.has("manage_workflows"));
        assert!(!auth.has("admin"));
    }

    #[test]
    fn empty_context_has_no_capabilities() {
        let auth = AuthContext::new(Vec::<String>::new());
        assert!(!auth.has("anything"));
        assert!(auth.check::<Admin>().is_none());
    }

    #[test]
    fn empty_constructor_grants_nothing() {
        // Exercises `AuthContext::empty()` itself, not just an empty `new`.
        let auth = AuthContext::empty();
        assert!(auth.capability_names().is_empty());
        assert!(!auth.has("admin"));
        assert!(auth.check::<Admin>().is_none());
        assert!(auth.require::<Admin>().is_err());
    }

    #[test]
    fn new_accepts_owned_strings_and_collapses_duplicates() {
        let auth = AuthContext::new(vec![
            String::from("admin"),
            String::from("admin"),
            String::from("manage_workflows"),
        ]);
        let names = auth.capability_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains("admin"));
        assert!(names.contains("manage_workflows"));
    }

    #[test]
    fn clones_share_the_capability_set() {
        let auth = AuthContext::new(vec!["admin"]);
        let copy = auth.clone();
        assert!(Arc::ptr_eq(&auth.capabilities, &copy.capabilities));
        assert!(copy.has("admin"));
    }

    #[test]
    fn proof_debug_names_the_capability() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        assert_eq!(format!("{:?}", proof), "Proof<manage_workflows>");
    }

    #[test]
    fn proof_is_copyable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        let _copy = proof;
        let _another = proof;
    }
}