Skip to main content

mcp_authorization/
schema.rs

1use std::sync::Arc;
2
3use schemars::JsonSchema;
4use serde_json::Value;
5
6use crate::capability::AuthContext;
7use crate::metadata::AuthSchemaMetadata;
8
/// Shapes JSON Schema at runtime by removing fields or variants that the
/// current user lacks capabilities for.
///
/// This is the Rust equivalent of `RbsSchemaCompiler` in the Ruby gem. The
/// `schemars` derive produces the schema-generation code at compile time;
/// the full schema itself is obtained at runtime from rmcp's schema helpers
/// (which hand back a shared `Arc`) and is then filtered per request based on
/// the user's [`AuthContext`].
///
/// `SchemaShaper` carries no state; all functionality lives in its associated
/// functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct SchemaShaper;
16
17impl SchemaShaper {
18    /// Generate a shaped input schema for type `T`.
19    ///
20    /// Starts from the full schemars-generated schema, then removes
21    /// properties whose `AuthSchemaMetadata` requirements are not
22    /// satisfied by the given `AuthContext`.
23    pub fn shape_input<T: JsonSchema + AuthSchemaMetadata + 'static>(
24        auth: &AuthContext,
25    ) -> Arc<serde_json::Map<String, Value>> {
26        let full_schema = rmcp::handler::server::tool::schema_for_type::<T>();
27        let requirements = T::requirements();
28
29        // Fast path: no auth-gated fields
30        if requirements.is_empty() || requirements.iter().all(|(_, cap)| auth.has(cap)) {
31            return full_schema;
32        }
33
34        let fields_to_remove: Vec<&str> = requirements
35            .iter()
36            .filter(|(_, cap)| !auth.has(cap))
37            .map(|(field, _)| *field)
38            .collect();
39
40        let mut schema = (*full_schema).clone();
41        remove_properties(&mut schema, &fields_to_remove);
42        Arc::new(schema)
43    }
44
45    /// Generate a shaped output schema for type `T`.
46    ///
47    /// Removes `oneOf`/`anyOf` variants whose `AuthSchemaMetadata`
48    /// requirements are not satisfied.
49    pub fn shape_output<T: JsonSchema + AuthSchemaMetadata + 'static>(
50        auth: &AuthContext,
51    ) -> Option<Arc<serde_json::Map<String, Value>>> {
52        let full_schema = rmcp::handler::server::tool::schema_for_output::<T>().ok()?;
53        let requirements = T::requirements();
54
55        if requirements.is_empty() || requirements.iter().all(|(_, cap)| auth.has(cap)) {
56            return Some(full_schema);
57        }
58
59        let variants_to_remove: Vec<&str> = requirements
60            .iter()
61            .filter(|(_, cap)| !auth.has(cap))
62            .map(|(variant, _)| *variant)
63            .collect();
64
65        let mut schema = (*full_schema).clone();
66        remove_variants(&mut schema, &variants_to_remove);
67        Some(Arc::new(schema))
68    }
69}
70
71/// Remove properties from a JSON Schema object and its `required` array.
72fn remove_properties(schema: &mut serde_json::Map<String, Value>, fields: &[&str]) {
73    // Remove from top-level "properties"
74    if let Some(Value::Object(props)) = schema.get_mut("properties") {
75        for field in fields {
76            props.remove(*field);
77        }
78    }
79
80    // Remove from "required" array
81    if let Some(Value::Array(required)) = schema.get_mut("required") {
82        required.retain(|v| {
83            v.as_str()
84                .map_or(true, |name| !fields.contains(&name))
85        });
86    }
87
88    // Handle $defs-based schemas: if properties are referenced via $ref,
89    // also check nested allOf/anyOf/oneOf at the top level
90    for key in &["allOf", "anyOf", "oneOf"] {
91        if let Some(Value::Array(variants)) = schema.get_mut(*key) {
92            for variant in variants.iter_mut() {
93                if let Value::Object(obj) = variant {
94                    remove_properties(obj, fields);
95                }
96            }
97        }
98    }
99}
100
101/// Remove variants from oneOf/anyOf in a JSON Schema.
102///
103/// Matches variants by checking:
104/// 1. `title` field
105/// 2. `$ref` ending (e.g. `#/$defs/ReroutedSuccess`)
106/// 3. Internally tagged enum discriminator value
107fn remove_variants(schema: &mut serde_json::Map<String, Value>, variants: &[&str]) {
108    for key in &["oneOf", "anyOf"] {
109        if let Some(Value::Array(items)) = schema.get_mut(*key) {
110            items.retain(|item| {
111                let name = variant_name(item);
112                match name {
113                    Some(n) => !variants.contains(&n.as_str()),
114                    None => true, // keep unrecognized variants
115                }
116            });
117        }
118    }
119
120    // Also clean up $defs — remove definitions that are no longer referenced
121    if let Some(Value::Object(defs)) = schema.get("$defs") {
122        let def_names: Vec<String> = defs.keys().cloned().collect();
123        let schema_str = serde_json::to_string(&schema).unwrap_or_default();
124        let unused: Vec<String> = def_names
125            .into_iter()
126            .filter(|name| {
127                let ref_str = format!("#/$defs/{}", name);
128                !schema_str.contains(&ref_str) || variants.contains(&name.as_str())
129            })
130            .collect();
131
132        if !unused.is_empty() {
133            if let Some(Value::Object(defs)) = schema.get_mut("$defs") {
134                for name in &unused {
135                    // Only remove if it's a variant we're filtering
136                    if variants.contains(&name.as_str()) {
137                        defs.remove(name);
138                    }
139                }
140            }
141        }
142    }
143}
144
145/// Extract a variant's name from its JSON Schema representation.
146fn variant_name(item: &Value) -> Option<String> {
147    let obj = item.as_object()?;
148
149    // Check "title" field first
150    if let Some(Value::String(title)) = obj.get("title") {
151        return Some(title.clone());
152    }
153
154    // Check $ref (e.g. "#/$defs/ReroutedSuccess")
155    if let Some(Value::String(ref_str)) = obj.get("$ref") {
156        return ref_str.rsplit('/').next().map(String::from);
157    }
158
159    // Check internally tagged enum: {"properties": {"type": {"const": "VariantName"}}}
160    if let Some(Value::Object(props)) = obj.get("properties") {
161        if let Some(Value::Object(type_prop)) = props.get("type") {
162            if let Some(Value::String(const_val)) = type_prop.get("const") {
163                return Some(const_val.clone());
164            }
165        }
166    }
167
168    None
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthSchemaMetadata;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Unwrap a `serde_json::Value` that the test knows is an object.
    fn as_map(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {:?}", other),
        }
    }

    #[derive(Deserialize, JsonSchema)]
    #[allow(dead_code)] // fields exist only to shape the generated schema
    struct TestInput {
        pub name: String,
        pub public_field: String,
        pub secret_field: Option<String>,
        pub admin_field: Option<i32>,
    }

    impl AuthSchemaMetadata for TestInput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[
                ("secret_field", "view_secrets"),
                ("admin_field", "admin"),
            ]
        }
    }

    #[test]
    fn shape_input_removes_unauthorized_fields() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("public_field"));
        assert!(!props.contains_key("secret_field"));
        assert!(!props.contains_key("admin_field"));
    }

    #[test]
    fn shape_input_keeps_authorized_fields() {
        let auth = AuthContext::new(vec!["view_secrets", "admin"]);
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("secret_field"));
        assert!(props.contains_key("admin_field"));
    }

    #[test]
    fn shape_input_partial_authorization() {
        let auth = AuthContext::new(vec!["view_secrets"]);
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("secret_field"));
        assert!(!props.contains_key("admin_field"));
    }

    #[derive(Deserialize, JsonSchema)]
    #[allow(dead_code)] // field exists only to shape the generated schema
    struct NoAuthInput {
        pub name: String,
    }

    impl AuthSchemaMetadata for NoAuthInput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[]
        }
    }

    #[test]
    fn shape_input_no_requirements_returns_full_schema() {
        let auth = AuthContext::new(Vec::<String>::new());
        let shaped = SchemaShaper::shape_input::<NoAuthInput>(&auth);
        let full = rmcp::handler::server::tool::schema_for_type::<NoAuthInput>();
        // Same Arc — no clone happened
        assert!(Arc::ptr_eq(&shaped, &full));
    }

    #[test]
    fn shape_input_removes_from_required_array() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        if let Some(Value::Array(required)) = schema.get("required") {
            let names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
            assert!(!names.contains(&"secret_field"));
            assert!(!names.contains(&"admin_field"));
        }
    }

    // Output variant filtering tests

    #[derive(Serialize, JsonSchema)]
    #[serde(tag = "type")]
    #[allow(dead_code)] // variants exist only to shape the generated schema
    enum TestOutput {
        Success { id: String },
        AdminDetail { id: String, secret: String },
        Error { message: String },
    }

    impl AuthSchemaMetadata for TestOutput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("AdminDetail", "admin")]
        }
    }

    // NOTE: the two shape_output tests below only assert when rmcp returns
    // `Some(schema)` for `TestOutput`; on `None` they pass vacuously. The
    // helper-level tests further down exercise variant filtering directly.

    #[test]
    fn shape_output_removes_unauthorized_variants() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_output::<TestOutput>(&auth);

        if let Some(schema) = schema {
            let schema_str = serde_json::to_string(&*schema).unwrap();
            assert!(!schema_str.contains("AdminDetail"));
            assert!(schema_str.contains("Success"));
            assert!(schema_str.contains("Error"));
        }
    }

    #[test]
    fn shape_output_keeps_all_when_authorized() {
        let auth = AuthContext::new(vec!["admin"]);
        let schema = SchemaShaper::shape_output::<TestOutput>(&auth);

        if let Some(schema) = schema {
            let schema_str = serde_json::to_string(&*schema).unwrap();
            assert!(schema_str.contains("AdminDetail"));
            assert!(schema_str.contains("Success"));
        }
    }

    // Helper-level tests (independent of rmcp's schema generation)

    #[test]
    fn remove_properties_strips_properties_required_and_combinator_branches() {
        let mut schema = as_map(serde_json::json!({
            "properties": {"name": {}, "secret": {}},
            "required": ["name", "secret"],
            "allOf": [
                {"properties": {"secret": {}, "other": {}}, "required": ["secret", "other"]}
            ]
        }));
        remove_properties(&mut schema, &["secret"]);

        assert_eq!(
            Value::Object(schema),
            serde_json::json!({
                "properties": {"name": {}},
                "required": ["name"],
                "allOf": [{"properties": {"other": {}}, "required": ["other"]}]
            })
        );
    }

    #[test]
    fn remove_variants_matches_title_ref_and_tag() {
        let mut schema = as_map(serde_json::json!({
            "oneOf": [
                {"title": "ByTitle"},
                {"$ref": "#/$defs/ByRef"},
                {"properties": {"type": {"const": "ByTag"}}},
                {"title": "Kept"},
                {"description": "unnamed variants are kept"}
            ],
            "$defs": {"ByRef": {"type": "object"}}
        }));
        remove_variants(&mut schema, &["ByTitle", "ByRef", "ByTag"]);

        assert_eq!(
            schema.get("oneOf"),
            Some(&serde_json::json!([
                {"title": "Kept"},
                {"description": "unnamed variants are kept"}
            ]))
        );
        let defs = schema.get("$defs").and_then(Value::as_object);
        assert!(defs.map_or(true, |d| !d.contains_key("ByRef")));
    }

    #[test]
    fn remove_variants_keeps_definitions_of_kept_variants() {
        let mut schema = as_map(serde_json::json!({
            "anyOf": [{"$ref": "#/$defs/Public"}, {"$ref": "#/$defs/Secret"}],
            "$defs": {"Public": {"type": "object"}, "Secret": {"type": "object"}}
        }));
        remove_variants(&mut schema, &["Secret"]);

        assert_eq!(
            schema.get("anyOf"),
            Some(&serde_json::json!([{"$ref": "#/$defs/Public"}]))
        );
        let defs = schema
            .get("$defs")
            .and_then(Value::as_object)
            .expect("$defs should still be present");
        assert!(defs.contains_key("Public"));
        assert!(!defs.contains_key("Secret"));
    }

    #[test]
    fn variant_name_recognises_title_ref_and_internal_tag() {
        use serde_json::json;

        assert_eq!(
            variant_name(&json!({"title": "T", "$ref": "#/$defs/R"})),
            Some("T".to_string())
        );
        assert_eq!(
            variant_name(&json!({"$ref": "#/$defs/ReroutedSuccess"})),
            Some("ReroutedSuccess".to_string())
        );
        assert_eq!(
            variant_name(&json!({"properties": {"type": {"const": "Tagged"}}})),
            Some("Tagged".to_string())
        );
        assert_eq!(variant_name(&json!({"type": "null"})), None);
        assert_eq!(variant_name(&json!("not an object")), None);
    }
}