Skip to main content

mcp_authorization/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use schemars::JsonSchema;
5use serde_json::Value;
6
7use crate::capability::AuthContext;
8use crate::metadata::AuthSchemaMetadata;
9
10/// A tool definition with authorization metadata attached.
11pub struct AuthToolDef {
12    /// The base rmcp tool (name, description, full schemas)
13    pub base_tool: rmcp::model::Tool,
14    /// Tool-level gate: entire tool hidden if user lacks this capability
15    pub authorization: Option<&'static str>,
16    /// Field-level requirements for input schema shaping
17    pub input_requirements: &'static [(&'static str, &'static str)],
18    /// Variant-level requirements for output schema shaping
19    pub output_requirements: &'static [(&'static str, &'static str)],
20}
21
22/// Registry of tools with authorization metadata.
23///
24/// On each request, `materialize` produces per-user tool lists with
25/// shaped schemas — the same concept as `ToolRegistry.tool_classes_for`
26/// in the Ruby gem.
27pub struct AuthToolRegistry {
28    tools: HashMap<String, AuthToolDef>,
29    /// Insertion order for deterministic tool listing
30    order: Vec<String>,
31}
32
33impl AuthToolRegistry {
34    pub fn new() -> Self {
35        Self {
36            tools: HashMap::new(),
37            order: Vec::new(),
38        }
39    }
40
41    /// Register a tool with its authorization metadata.
42    pub fn register(&mut self, def: AuthToolDef) {
43        let name = def.base_tool.name.to_string();
44        if !self.tools.contains_key(&name) {
45            self.order.push(name.clone());
46        }
47        self.tools.insert(name, def);
48    }
49
50    /// Register a tool using type information for schema and auth metadata.
51    ///
52    /// This is the ergonomic builder method:
53    /// ```ignore
54    /// registry.register_typed::<AdvanceStepInput, AdvanceStepOutput>(
55    ///     "advance_step",
56    ///     "Advance an applicant in their workflow",
57    /// );
58    /// ```
59    pub fn register_typed<I, O>(
60        &mut self,
61        name: impl Into<String>,
62        description: impl Into<String>,
63    ) where
64        I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
65        O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
66    {
67        let name = name.into();
68        let full_input = rmcp::handler::server::tool::schema_for_type::<I>();
69        let full_output = rmcp::handler::server::tool::schema_for_output::<O>().ok();
70
71        let mut tool = rmcp::model::Tool::new(name.clone(), description.into(), full_input);
72        if let Some(output) = full_output {
73            tool = tool.with_raw_output_schema(output);
74        }
75
76        self.register(AuthToolDef {
77            base_tool: tool,
78            authorization: None,
79            input_requirements: I::requirements(),
80            output_requirements: O::requirements(),
81        });
82    }
83
84    /// Set tool-level authorization for a named tool.
85    pub fn set_authorization(&mut self, tool_name: &str, capability: &'static str) {
86        if let Some(def) = self.tools.get_mut(tool_name) {
87            def.authorization = Some(capability);
88        }
89    }
90
91    /// Materialize tools for a specific user: filter hidden tools, shape schemas.
92    ///
93    /// This is the per-request method, equivalent to `ToolRegistry.tool_classes_for`
94    /// in the Ruby gem.
95    pub fn materialize(&self, auth: &AuthContext) -> Vec<rmcp::model::Tool> {
96        self.order
97            .iter()
98            .filter_map(|name| {
99                let def = self.tools.get(name)?;
100
101                // Tool-level gate
102                if let Some(required) = def.authorization {
103                    if !auth.has(required) {
104                        return None;
105                    }
106                }
107
108                let mut tool = def.base_tool.clone();
109
110                // Shape input schema
111                if !def.input_requirements.is_empty() {
112                    let fields_to_remove: Vec<&str> = def
113                        .input_requirements
114                        .iter()
115                        .filter(|(_, cap)| !auth.has(cap))
116                        .map(|(field, _)| *field)
117                        .collect();
118
119                    if !fields_to_remove.is_empty() {
120                        let mut schema = (*tool.input_schema).clone();
121                        remove_properties(&mut schema, &fields_to_remove);
122                        tool.input_schema = Arc::new(schema);
123                    }
124                }
125
126                // Shape output schema
127                if !def.output_requirements.is_empty() {
128                    if let Some(ref output) = tool.output_schema {
129                        let variants_to_remove: Vec<&str> = def
130                            .output_requirements
131                            .iter()
132                            .filter(|(_, cap)| !auth.has(cap))
133                            .map(|(variant, _)| *variant)
134                            .collect();
135
136                        if !variants_to_remove.is_empty() {
137                            let mut schema = (**output).clone();
138                            remove_variants(&mut schema, &variants_to_remove);
139                            tool.output_schema = Some(Arc::new(schema));
140                        }
141                    }
142                }
143
144                Some(tool)
145            })
146            .collect()
147    }
148
149    /// Check if a tool is visible to the given auth context.
150    pub fn is_visible(&self, tool_name: &str, auth: &AuthContext) -> bool {
151        self.tools.get(tool_name).map_or(false, |def| {
152            def.authorization.map_or(true, |cap| auth.has(cap))
153        })
154    }
155
156    /// Get a tool definition by name (unshaped).
157    pub fn get(&self, tool_name: &str) -> Option<&AuthToolDef> {
158        self.tools.get(tool_name)
159    }
160}
161
162impl Default for AuthToolRegistry {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168// Schema manipulation helpers (duplicated from schema.rs to avoid
169// coupling the registry to the generic SchemaShaper)
170
171fn remove_properties(schema: &mut serde_json::Map<String, Value>, fields: &[&str]) {
172    if let Some(Value::Object(props)) = schema.get_mut("properties") {
173        for field in fields {
174            props.remove(*field);
175        }
176    }
177    if let Some(Value::Array(required)) = schema.get_mut("required") {
178        required.retain(|v| v.as_str().map_or(true, |name| !fields.contains(&name)));
179    }
180}
181
182fn remove_variants(schema: &mut serde_json::Map<String, Value>, variants: &[&str]) {
183    for key in &["oneOf", "anyOf"] {
184        if let Some(Value::Array(items)) = schema.get_mut(*key) {
185            items.retain(|item| {
186                let name = variant_name(item);
187                match name {
188                    Some(n) => !variants.contains(&n.as_str()),
189                    None => true,
190                }
191            });
192        }
193    }
194}
195
196fn variant_name(item: &Value) -> Option<String> {
197    let obj = item.as_object()?;
198    if let Some(Value::String(title)) = obj.get("title") {
199        return Some(title.clone());
200    }
201    if let Some(Value::String(ref_str)) = obj.get("$ref") {
202        return ref_str.rsplit('/').next().map(String::from);
203    }
204    if let Some(Value::Object(props)) = obj.get("properties") {
205        if let Some(Value::Object(type_prop)) = props.get("type") {
206            if let Some(Value::String(const_val)) = type_prop.get("const") {
207                return Some(const_val.clone());
208            }
209        }
210    }
211    None
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Input with one public field and one admin-only field.
    #[derive(serde::Deserialize, JsonSchema)]
    #[allow(dead_code)] // fields exist only to drive schema generation
    struct Input {
        pub name: String,
        pub secret: Option<String>,
    }

    impl AuthSchemaMetadata for Input {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("secret", "admin")]
        }
    }

    /// Internally tagged output with one admin-only variant.
    #[derive(serde::Serialize, JsonSchema)]
    #[serde(tag = "type")]
    #[allow(dead_code)] // variants exist only to drive schema generation
    enum Output {
        Ok { id: String },
        AdminOk { id: String, detail: String },
    }

    impl AuthSchemaMetadata for Output {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("AdminOk", "admin")]
        }
    }

    #[test]
    fn materialize_hides_unauthorized_tools() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");
        reg.set_authorization("my_tool", "admin");

        let no_auth = AuthContext::new(Vec::<String>::new());
        assert!(reg.materialize(&no_auth).is_empty());

        let admin = AuthContext::new(vec!["admin"]);
        assert_eq!(reg.materialize(&admin).len(), 1);
    }

    #[test]
    fn materialize_shapes_input_schema() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");

        let no_auth = AuthContext::new(Vec::<String>::new());
        let tools = reg.materialize(&no_auth);
        let schema = &tools[0].input_schema;
        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(!props.contains_key("secret"));
    }

    #[test]
    fn materialize_keeps_gated_fields_for_authorized_users() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");

        let admin = AuthContext::new(vec!["admin"]);
        let tools = reg.materialize(&admin);
        let props = tools[0]
            .input_schema
            .get("properties")
            .unwrap()
            .as_object()
            .unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("secret"));
    }

    #[test]
    fn materialize_lists_tools_in_first_registration_order() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("first", "A tool");
        reg.register_typed::<Input, Output>("second", "A tool");
        // Re-registering replaces the definition but keeps the original slot.
        reg.register_typed::<Input, Output>("first", "Replaced");

        let no_auth = AuthContext::new(Vec::<String>::new());
        let names: Vec<String> = reg
            .materialize(&no_auth)
            .iter()
            .map(|tool| tool.name.to_string())
            .collect();
        assert_eq!(names, ["first", "second"]);
    }

    #[test]
    fn is_visible_checks_tool_authorization() {
        let mut reg = AuthToolRegistry::new();
        reg.register_typed::<Input, Output>("my_tool", "A tool");
        reg.set_authorization("my_tool", "admin");

        let no_auth = AuthContext::new(Vec::<String>::new());
        assert!(!reg.is_visible("my_tool", &no_auth));

        let admin = AuthContext::new(vec!["admin"]);
        assert!(reg.is_visible("my_tool", &admin));
    }

    #[test]
    fn is_visible_is_false_for_unregistered_tools() {
        let reg = AuthToolRegistry::new();
        let admin = AuthContext::new(vec!["admin"]);
        assert!(!reg.is_visible("missing", &admin));
    }

    #[test]
    fn set_authorization_ignores_unknown_tools() {
        let mut reg = AuthToolRegistry::new();
        reg.set_authorization("missing", "admin");
        assert!(reg.get("missing").is_none());
    }

    #[test]
    fn remove_properties_drops_field_and_required_entry() {
        let mut value = serde_json::json!({
            "type": "object",
            "properties": {"name": {"type": "string"}, "secret": {"type": "string"}},
            "required": ["name", "secret"]
        });
        let schema = value.as_object_mut().unwrap();
        remove_properties(schema, &["secret"]);

        let props = schema.get("properties").and_then(Value::as_object).unwrap();
        assert!(props.contains_key("name"));
        assert!(!props.contains_key("secret"));
        assert_eq!(schema.get("required"), Some(&serde_json::json!(["name"])));
    }

    #[test]
    fn remove_variants_drops_matching_tagged_variant() {
        let mut value = serde_json::json!({
            "oneOf": [
                {"type": "object", "properties": {"type": {"type": "string", "const": "Ok"}}},
                {"type": "object", "properties": {"type": {"type": "string", "const": "AdminOk"}}}
            ]
        });
        let schema = value.as_object_mut().unwrap();
        remove_variants(schema, &["AdminOk"]);

        let remaining = schema.get("oneOf").and_then(Value::as_array).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(variant_name(&remaining[0]).as_deref(), Some("Ok"));
    }
}