mcp_authorization/
capability.rs1use std::collections::HashSet;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
/// Marker trait implemented by every capability type.
///
/// An implementor carries no data of its own; it is identified solely by the
/// string in [`Capability::NAME`], which [`AuthContext`] compares against its
/// set of granted capability names.
///
/// The `Send + Sync + 'static` supertraits are required of every implementor.
pub trait Capability: Send + Sync + 'static {
    /// String identifier for this capability (for example `"admin"`).
    const NAME: &'static str;
}
21
22pub struct Proof<C: Capability> {
48 _marker: PhantomData<C>,
49}
50
51impl<C: Capability> Clone for Proof<C> {
53 fn clone(&self) -> Self {
54 *self
55 }
56}
57
58impl<C: Capability> Copy for Proof<C> {}
59
60impl<C: Capability> std::fmt::Debug for Proof<C> {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "Proof<{}>", C::NAME)
63 }
64}
65
/// Set of capability names granted to a caller.
///
/// The names are held behind an [`Arc`], so cloning an `AuthContext` shares
/// the same underlying set instead of copying the strings.
///
/// The `Default` value is a context with no capabilities.
#[derive(Clone, Debug, Default)]
pub struct AuthContext {
    // Granted capability names; matched against `Capability::NAME`.
    capabilities: Arc<HashSet<String>>,
}
87
88impl AuthContext {
89 pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
91 Self {
92 capabilities: Arc::new(caps.into_iter().map(Into::into).collect()),
93 }
94 }
95
96 pub fn check<C: Capability>(&self) -> Option<Proof<C>> {
101 if self.capabilities.contains(C::NAME) {
102 Some(Proof {
103 _marker: PhantomData,
104 })
105 } else {
106 None
107 }
108 }
109
110 pub fn require<C: Capability>(&self) -> Result<Proof<C>, rmcp::ErrorData> {
112 self.check::<C>().ok_or_else(|| {
113 rmcp::ErrorData::invalid_params(
114 format!("missing required capability: {}", C::NAME),
115 None,
116 )
117 })
118 }
119
120 pub fn has(&self, name: &str) -> bool {
122 self.capabilities.contains(name)
123 }
124
125 pub fn capability_names(&self) -> &HashSet<String> {
127 &self.capabilities
128 }
129}
130
/// Unit tests for [`Capability`], [`Proof`] and [`AuthContext`].
#[cfg(test)]
mod tests {
    use super::*;

    // Marker capabilities used only by these tests.
    struct ManageWorkflows;
    impl Capability for ManageWorkflows {
        const NAME: &'static str = "manage_workflows";
    }

    struct BackwardRouting;
    impl Capability for BackwardRouting {
        const NAME: &'static str = "backward_routing";
    }

    struct Admin;
    impl Capability for Admin {
        const NAME: &'static str = "admin";
    }

    #[test]
    fn proof_is_zero_sized() {
        assert_eq!(std::mem::size_of::<Proof<ManageWorkflows>>(), 0);
        assert_eq!(std::mem::size_of::<Proof<BackwardRouting>>(), 0);
    }

    #[test]
    fn check_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["manage_workflows", "backward_routing"]);
        assert!(auth.check::<ManageWorkflows>().is_some());
        assert!(auth.check::<BackwardRouting>().is_some());
    }

    #[test]
    fn check_returns_none_when_not_capable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.check::<BackwardRouting>().is_none());
    }

    #[test]
    fn require_returns_proof_when_capable() {
        let auth = AuthContext::new(vec!["admin"]);
        assert!(auth.require::<Admin>().is_ok());
    }

    #[test]
    fn require_returns_error_when_not_capable() {
        let auth = AuthContext::new(Vec::<String>::new());
        let err = auth.require::<Admin>().unwrap_err();
        assert!(err.message.contains("admin"));
    }

    #[test]
    fn require_error_has_expected_message() {
        let auth = AuthContext::new(Vec::<String>::new());
        let err = auth.require::<Admin>().unwrap_err();
        assert_eq!(&*err.message, "missing required capability: admin");
    }

    #[test]
    fn has_checks_by_string_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        assert!(auth.has("manage_workflows"));
        assert!(!auth.has("admin"));
    }

    #[test]
    fn empty_context_has_no_capabilities() {
        let auth = AuthContext::new(Vec::<String>::new());
        assert!(!auth.has("anything"));
        assert!(auth.check::<Admin>().is_none());
        assert!(auth.capability_names().is_empty());
    }

    #[test]
    fn capability_names_lists_granted_capabilities() {
        // The duplicate entry must collapse because names are stored in a set.
        let auth = AuthContext::new(vec!["manage_workflows", "admin", "admin"]);
        let names = auth.capability_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains("manage_workflows"));
        assert!(names.contains("admin"));
    }

    #[test]
    fn cloned_context_grants_the_same_capabilities() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let copy = auth.clone();
        assert!(copy.check::<ManageWorkflows>().is_some());
        assert!(copy.check::<Admin>().is_none());
    }

    #[test]
    fn proof_debug_includes_capability_name() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        assert_eq!(format!("{:?}", proof), "Proof<manage_workflows>");
    }

    #[test]
    fn proof_is_copyable() {
        let auth = AuthContext::new(vec!["manage_workflows"]);
        let proof = auth.check::<ManageWorkflows>().unwrap();
        // Using `proof` twice only compiles because `Proof` is `Copy`.
        let _copy = proof;
        let _another = proof;
    }
}