// mcp_authorization/registry.rs
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use schemars::JsonSchema;
use serde_json::Value;

use crate::capability::AuthContext;
use crate::metadata::AuthSchemaMetadata;

10pub struct AuthToolDef {
12 pub base_tool: rmcp::model::Tool,
14 pub authorization: Option<&'static str>,
16 pub input_requirements: &'static [(&'static str, &'static str)],
18 pub output_requirements: &'static [(&'static str, &'static str)],
20}
21
22pub struct AuthToolRegistry {
28 tools: HashMap<String, AuthToolDef>,
29 order: Vec<String>,
31}
32
33impl AuthToolRegistry {
34 pub fn new() -> Self {
35 Self {
36 tools: HashMap::new(),
37 order: Vec::new(),
38 }
39 }
40
41 pub fn register(&mut self, def: AuthToolDef) {
43 let name = def.base_tool.name.to_string();
44 if !self.tools.contains_key(&name) {
45 self.order.push(name.clone());
46 }
47 self.tools.insert(name, def);
48 }
49
50 pub fn register_typed<I, O>(
60 &mut self,
61 name: impl Into<String>,
62 description: impl Into<String>,
63 ) where
64 I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
65 O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
66 {
67 let name = name.into();
68 let full_input = rmcp::handler::server::tool::schema_for_type::<I>();
69 let full_output = rmcp::handler::server::tool::schema_for_output::<O>().ok();
70
71 let mut tool = rmcp::model::Tool::new(name.clone(), description.into(), full_input);
72 if let Some(output) = full_output {
73 tool = tool.with_raw_output_schema(output);
74 }
75
76 self.register(AuthToolDef {
77 base_tool: tool,
78 authorization: None,
79 input_requirements: I::requirements(),
80 output_requirements: O::requirements(),
81 });
82 }
83
84 pub fn set_authorization(&mut self, tool_name: &str, capability: &'static str) {
86 if let Some(def) = self.tools.get_mut(tool_name) {
87 def.authorization = Some(capability);
88 }
89 }
90
91 pub fn materialize(&self, auth: &AuthContext) -> Vec<rmcp::model::Tool> {
96 self.order
97 .iter()
98 .filter_map(|name| {
99 let def = self.tools.get(name)?;
100
101 if let Some(required) = def.authorization {
103 if !auth.has(required) {
104 return None;
105 }
106 }
107
108 let mut tool = def.base_tool.clone();
109
110 if !def.input_requirements.is_empty() {
112 let fields_to_remove: Vec<&str> = def
113 .input_requirements
114 .iter()
115 .filter(|(_, cap)| !auth.has(cap))
116 .map(|(field, _)| *field)
117 .collect();
118
119 if !fields_to_remove.is_empty() {
120 let mut schema = (*tool.input_schema).clone();
121 remove_properties(&mut schema, &fields_to_remove);
122 tool.input_schema = Arc::new(schema);
123 }
124 }
125
126 if !def.output_requirements.is_empty() {
128 if let Some(ref output) = tool.output_schema {
129 let variants_to_remove: Vec<&str> = def
130 .output_requirements
131 .iter()
132 .filter(|(_, cap)| !auth.has(cap))
133 .map(|(variant, _)| *variant)
134 .collect();
135
136 if !variants_to_remove.is_empty() {
137 let mut schema = (**output).clone();
138 remove_variants(&mut schema, &variants_to_remove);
139 tool.output_schema = Some(Arc::new(schema));
140 }
141 }
142 }
143
144 Some(tool)
145 })
146 .collect()
147 }
148
149 pub fn is_visible(&self, tool_name: &str, auth: &AuthContext) -> bool {
151 self.tools.get(tool_name).map_or(false, |def| {
152 def.authorization.map_or(true, |cap| auth.has(cap))
153 })
154 }
155
156 pub fn get(&self, tool_name: &str) -> Option<&AuthToolDef> {
158 self.tools.get(tool_name)
159 }
160}
161
162impl Default for AuthToolRegistry {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168fn remove_properties(schema: &mut serde_json::Map<String, Value>, fields: &[&str]) {
172 if let Some(Value::Object(props)) = schema.get_mut("properties") {
173 for field in fields {
174 props.remove(*field);
175 }
176 }
177 if let Some(Value::Array(required)) = schema.get_mut("required") {
178 required.retain(|v| v.as_str().map_or(true, |name| !fields.contains(&name)));
179 }
180}
181
182fn remove_variants(schema: &mut serde_json::Map<String, Value>, variants: &[&str]) {
183 for key in &["oneOf", "anyOf"] {
184 if let Some(Value::Array(items)) = schema.get_mut(*key) {
185 items.retain(|item| {
186 let name = variant_name(item);
187 match name {
188 Some(n) => !variants.contains(&n.as_str()),
189 None => true,
190 }
191 });
192 }
193 }
194}
195
196fn variant_name(item: &Value) -> Option<String> {
197 let obj = item.as_object()?;
198 if let Some(Value::String(title)) = obj.get("title") {
199 return Some(title.clone());
200 }
201 if let Some(Value::String(ref_str)) = obj.get("$ref") {
202 return ref_str.rsplit('/').next().map(String::from);
203 }
204 if let Some(Value::Object(props)) = obj.get("properties") {
205 if let Some(Value::Object(type_prop)) = props.get("type") {
206 if let Some(Value::String(const_val)) = type_prop.get("const") {
207 return Some(const_val.clone());
208 }
209 }
210 }
211 None
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Unwraps a `json!` object literal into the map type the schema helpers operate on.
    fn object(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other}"),
        }
    }

    #[derive(serde::Deserialize, JsonSchema)]
    // Fields exist only to drive schema generation; they are never read directly.
    #[allow(dead_code)]
    struct Input {
        pub name: String,
        pub secret: Option<String>,
    }

    impl AuthSchemaMetadata for Input {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("secret", "admin")]
        }
    }

    #[derive(serde::Serialize, JsonSchema)]
    #[serde(tag = "type")]
    // Variants exist only to drive schema generation; they are never constructed.
    #[allow(dead_code)]
    enum Output {
        Ok { id: String },
        AdminOk { id: String, detail: String },
    }

    impl AuthSchemaMetadata for Output {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("AdminOk", "admin")]
        }
    }

    #[test]
    fn materialize_hides_unauthorized_tools() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");
        reg.set_authorization("my_tool", "admin");

        let no_auth = AuthContext::new(Vec::<String>::new());
        assert!(reg.materialize(&no_auth).is_empty());

        let admin = AuthContext::new(vec!["admin"]);
        assert_eq!(reg.materialize(&admin).len(), 1);
    }

    #[test]
    fn materialize_shapes_input_schema() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");

        let no_auth = AuthContext::new(Vec::<String>::new());
        let tools = reg.materialize(&no_auth);
        let schema = &tools[0].input_schema;
        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(!props.contains_key("secret"));
    }

    #[test]
    fn materialize_keeps_gated_fields_for_authorized_callers() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");

        let admin = AuthContext::new(vec!["admin"]);
        let tools = reg.materialize(&admin);
        let props = tools[0]
            .input_schema
            .get("properties")
            .unwrap()
            .as_object()
            .unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("secret"));
    }

    #[test]
    fn materialize_lists_tools_in_registration_order() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("zeta", "first");
        reg.register_typed::<Input, Output>("alpha", "second");

        let no_auth = AuthContext::new(Vec::<String>::new());
        let names: Vec<String> = reg
            .materialize(&no_auth)
            .iter()
            .map(|tool| tool.name.to_string())
            .collect();
        assert_eq!(names, ["zeta", "alpha"]);
    }

    #[test]
    fn reregistering_keeps_original_position() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("alpha", "first");
        reg.register_typed::<Input, Output>("beta", "second");
        reg.register_typed::<Input, Output>("alpha", "replacement");

        let no_auth = AuthContext::new(Vec::<String>::new());
        let names: Vec<String> = reg
            .materialize(&no_auth)
            .iter()
            .map(|tool| tool.name.to_string())
            .collect();
        assert_eq!(names, ["alpha", "beta"]);
        assert!(reg.get("alpha").is_some());
    }

    #[test]
    fn is_visible_checks_tool_authorization() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");
        reg.set_authorization("my_tool", "admin");

        let no_auth = AuthContext::new(Vec::<String>::new());
        assert!(!reg.is_visible("my_tool", &no_auth));

        let admin = AuthContext::new(vec!["admin"]);
        assert!(reg.is_visible("my_tool", &admin));
        assert!(!reg.is_visible("missing", &admin));
    }

    #[test]
    fn remove_properties_prunes_properties_and_required() {
        let mut schema = object(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "secret": {"type": "string"}
            },
            "required": ["name", "secret"]
        }));
        remove_properties(&mut schema, &["secret"]);

        let props = schema.get("properties").and_then(Value::as_object).unwrap();
        assert!(props.contains_key("name"));
        assert!(!props.contains_key("secret"));
        assert_eq!(schema.get("required"), Some(&json!(["name"])));
    }

    #[test]
    fn remove_variants_drops_matching_tagged_variants() {
        let mut schema = object(json!({
            "oneOf": [
                {"type": "object", "properties": {"type": {"type": "string", "const": "Ok"}}},
                {"type": "object", "properties": {"type": {"type": "string", "const": "AdminOk"}}}
            ]
        }));
        remove_variants(&mut schema, &["AdminOk"]);

        let remaining = schema.get("oneOf").and_then(Value::as_array).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(variant_name(&remaining[0]).as_deref(), Some("Ok"));
    }

    #[test]
    fn variant_name_reads_title_ref_and_tag() {
        assert_eq!(
            variant_name(&json!({"title": "Titled"})).as_deref(),
            Some("Titled")
        );
        assert_eq!(
            variant_name(&json!({"$ref": "#/$defs/Referenced"})).as_deref(),
            Some("Referenced")
        );
        assert_eq!(
            variant_name(&json!({"properties": {"type": {"const": "Tagged"}}})).as_deref(),
            Some("Tagged")
        );
        assert_eq!(variant_name(&json!({"type": "object"})), None);
        assert_eq!(variant_name(&json!("not an object")), None);
    }
}