mcp_authorization/
schema.rs1use std::sync::Arc;
2
3use schemars::JsonSchema;
4use serde_json::Value;
5
6use crate::capability::AuthContext;
7use crate::metadata::AuthSchemaMetadata;
8
/// Produces per-caller views of MCP tool input/output JSON schemas, hiding the
/// fields and variants that the caller's capabilities do not unlock.
///
/// The type carries no state; all functionality is exposed through associated
/// functions, so the derives below are free.
#[derive(Debug, Clone, Copy, Default)]
pub struct SchemaShaper;
16
17impl SchemaShaper {
18 pub fn shape_input<T: JsonSchema + AuthSchemaMetadata + 'static>(
24 auth: &AuthContext,
25 ) -> Arc<serde_json::Map<String, Value>> {
26 let full_schema = rmcp::handler::server::tool::schema_for_type::<T>();
27 let requirements = T::requirements();
28
29 if requirements.is_empty() || requirements.iter().all(|(_, cap)| auth.has(cap)) {
31 return full_schema;
32 }
33
34 let fields_to_remove: Vec<&str> = requirements
35 .iter()
36 .filter(|(_, cap)| !auth.has(cap))
37 .map(|(field, _)| *field)
38 .collect();
39
40 let mut schema = (*full_schema).clone();
41 remove_properties(&mut schema, &fields_to_remove);
42 Arc::new(schema)
43 }
44
45 pub fn shape_output<T: JsonSchema + AuthSchemaMetadata + 'static>(
50 auth: &AuthContext,
51 ) -> Option<Arc<serde_json::Map<String, Value>>> {
52 let full_schema = rmcp::handler::server::tool::schema_for_output::<T>().ok()?;
53 let requirements = T::requirements();
54
55 if requirements.is_empty() || requirements.iter().all(|(_, cap)| auth.has(cap)) {
56 return Some(full_schema);
57 }
58
59 let variants_to_remove: Vec<&str> = requirements
60 .iter()
61 .filter(|(_, cap)| !auth.has(cap))
62 .map(|(variant, _)| *variant)
63 .collect();
64
65 let mut schema = (*full_schema).clone();
66 remove_variants(&mut schema, &variants_to_remove);
67 Some(Arc::new(schema))
68 }
69}
70
71fn remove_properties(schema: &mut serde_json::Map<String, Value>, fields: &[&str]) {
73 if let Some(Value::Object(props)) = schema.get_mut("properties") {
75 for field in fields {
76 props.remove(*field);
77 }
78 }
79
80 if let Some(Value::Array(required)) = schema.get_mut("required") {
82 required.retain(|v| {
83 v.as_str()
84 .map_or(true, |name| !fields.contains(&name))
85 });
86 }
87
88 for key in &["allOf", "anyOf", "oneOf"] {
91 if let Some(Value::Array(variants)) = schema.get_mut(*key) {
92 for variant in variants.iter_mut() {
93 if let Value::Object(obj) = variant {
94 remove_properties(obj, fields);
95 }
96 }
97 }
98 }
99}
100
101fn remove_variants(schema: &mut serde_json::Map<String, Value>, variants: &[&str]) {
108 for key in &["oneOf", "anyOf"] {
109 if let Some(Value::Array(items)) = schema.get_mut(*key) {
110 items.retain(|item| {
111 let name = variant_name(item);
112 match name {
113 Some(n) => !variants.contains(&n.as_str()),
114 None => true, }
116 });
117 }
118 }
119
120 if let Some(Value::Object(defs)) = schema.get("$defs") {
122 let def_names: Vec<String> = defs.keys().cloned().collect();
123 let schema_str = serde_json::to_string(&schema).unwrap_or_default();
124 let unused: Vec<String> = def_names
125 .into_iter()
126 .filter(|name| {
127 let ref_str = format!("#/$defs/{}", name);
128 !schema_str.contains(&ref_str) || variants.contains(&name.as_str())
129 })
130 .collect();
131
132 if !unused.is_empty() {
133 if let Some(Value::Object(defs)) = schema.get_mut("$defs") {
134 for name in &unused {
135 if variants.contains(&name.as_str()) {
137 defs.remove(name);
138 }
139 }
140 }
141 }
142 }
143}
144
145fn variant_name(item: &Value) -> Option<String> {
147 let obj = item.as_object()?;
148
149 if let Some(Value::String(title)) = obj.get("title") {
151 return Some(title.clone());
152 }
153
154 if let Some(Value::String(ref_str)) = obj.get("$ref") {
156 return ref_str.rsplit('/').next().map(String::from);
157 }
158
159 if let Some(Value::Object(props)) = obj.get("properties") {
161 if let Some(Value::Object(type_prop)) = props.get("type") {
162 if let Some(Value::String(const_val)) = type_prop.get("const") {
163 return Some(const_val.clone());
164 }
165 }
166 }
167
168 None
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthSchemaMetadata;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Unwraps a `json!` literal that the test knows to be an object.
    fn object(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other}"),
        }
    }

    #[derive(Deserialize, JsonSchema)]
    #[allow(dead_code)] // fields are never read; the struct exists for its schema
    struct TestInput {
        pub name: String,
        pub public_field: String,
        pub secret_field: Option<String>,
        pub admin_field: Option<i32>,
    }

    impl AuthSchemaMetadata for TestInput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[
                ("secret_field", "view_secrets"),
                ("admin_field", "admin"),
            ]
        }
    }

    #[test]
    fn shape_input_removes_unauthorized_fields() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("public_field"));
        assert!(!props.contains_key("secret_field"));
        assert!(!props.contains_key("admin_field"));
    }

    #[test]
    fn shape_input_keeps_authorized_fields() {
        let auth = AuthContext::new(vec!["view_secrets", "admin"]);
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("name"));
        assert!(props.contains_key("secret_field"));
        assert!(props.contains_key("admin_field"));
    }

    #[test]
    fn shape_input_partial_authorization() {
        let auth = AuthContext::new(vec!["view_secrets"]);
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        let props = schema.get("properties").unwrap().as_object().unwrap();
        assert!(props.contains_key("secret_field"));
        assert!(!props.contains_key("admin_field"));
    }

    #[derive(Deserialize, JsonSchema)]
    #[allow(dead_code)] // fields are never read; the struct exists for its schema
    struct NoAuthInput {
        pub name: String,
    }

    impl AuthSchemaMetadata for NoAuthInput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[]
        }
    }

    #[test]
    fn shape_input_no_requirements_returns_full_schema() {
        let auth = AuthContext::new(Vec::<String>::new());
        let shaped = SchemaShaper::shape_input::<NoAuthInput>(&auth);
        let full = rmcp::handler::server::tool::schema_for_type::<NoAuthInput>();
        // No requirements means no copy: the very same `Arc` comes back.
        assert!(Arc::ptr_eq(&shaped, &full));
    }

    #[test]
    fn shape_input_removes_from_required_array() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_input::<TestInput>(&auth);

        // `name` and `public_field` are not `Option`s, so `required` must exist;
        // asserting that keeps this test from passing vacuously.
        let names: Vec<&str> = schema
            .get("required")
            .and_then(Value::as_array)
            .expect("TestInput has non-optional fields")
            .iter()
            .filter_map(Value::as_str)
            .collect();
        assert!(names.contains(&"name"));
        assert!(!names.contains(&"secret_field"));
        assert!(!names.contains(&"admin_field"));
    }

    #[test]
    fn remove_properties_recurses_into_composed_schemas() {
        let mut schema = object(json!({
            "properties": { "a": {}, "b": {} },
            "required": ["a", "b"],
            "allOf": [
                { "properties": { "b": {}, "c": {} }, "required": ["c", "b"] }
            ]
        }));

        remove_properties(&mut schema, &["b"]);

        assert_eq!(
            Value::Object(schema),
            json!({
                "properties": { "a": {} },
                "required": ["a"],
                "allOf": [
                    { "properties": { "c": {} }, "required": ["c"] }
                ]
            })
        );
    }

    #[derive(Serialize, JsonSchema)]
    #[serde(tag = "type")]
    #[allow(dead_code)] // variants are never constructed; the enum exists for its schema
    enum TestOutput {
        Success { id: String },
        AdminDetail { id: String, secret: String },
        Error { message: String },
    }

    impl AuthSchemaMetadata for TestOutput {
        fn requirements() -> &'static [(&'static str, &'static str)] {
            &[("AdminDetail", "admin")]
        }
    }

    #[test]
    fn shape_output_removes_unauthorized_variants() {
        let auth = AuthContext::new(Vec::<String>::new());
        let schema = SchemaShaper::shape_output::<TestOutput>(&auth);

        // `shape_output` yields `None` when rmcp's `schema_for_output` rejects the
        // type. The variant filtering itself is covered independently by the
        // `remove_variants_*` tests below.
        if let Some(schema) = schema {
            let schema_str = serde_json::to_string(&*schema).unwrap();
            assert!(!schema_str.contains("AdminDetail"));
            assert!(schema_str.contains("Success"));
            assert!(schema_str.contains("Error"));
        }
    }

    #[test]
    fn shape_output_keeps_all_when_authorized() {
        let auth = AuthContext::new(vec!["admin"]);
        let schema = SchemaShaper::shape_output::<TestOutput>(&auth);

        // See `shape_output_removes_unauthorized_variants` for why this may be `None`.
        if let Some(schema) = schema {
            let schema_str = serde_json::to_string(&*schema).unwrap();
            assert!(schema_str.contains("AdminDetail"));
            assert!(schema_str.contains("Success"));
        }
    }

    #[test]
    fn remove_variants_drops_denied_tagged_variant() {
        let mut schema = object(json!({
            "type": "object",
            "oneOf": [
                { "type": "object", "properties": { "type": { "type": "string", "const": "Success" } } },
                { "type": "object", "properties": { "type": { "type": "string", "const": "AdminDetail" } } },
                { "type": "object", "properties": { "type": { "type": "string", "const": "Error" } } }
            ]
        }));

        remove_variants(&mut schema, &["AdminDetail"]);

        let names: Vec<String> = schema["oneOf"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(variant_name)
            .collect();
        assert_eq!(names, ["Success", "Error"]);
    }

    #[test]
    fn remove_variants_drops_definition_of_removed_ref_variant() {
        let mut schema = object(json!({
            "type": "object",
            "anyOf": [
                { "$ref": "#/$defs/Success" },
                { "$ref": "#/$defs/AdminDetail" }
            ],
            "$defs": {
                "Success": { "type": "object" },
                "AdminDetail": { "type": "object" }
            }
        }));

        remove_variants(&mut schema, &["AdminDetail"]);

        assert_eq!(schema["anyOf"], json!([{ "$ref": "#/$defs/Success" }]));
        let defs = schema["$defs"].as_object().unwrap();
        assert!(defs.contains_key("Success"));
        assert!(!defs.contains_key("AdminDetail"));
    }

    #[test]
    fn variant_name_prefers_title_then_ref_then_type_const() {
        let both = json!({ "title": "A", "$ref": "#/$defs/B" });
        assert_eq!(variant_name(&both).as_deref(), Some("A"));
        assert_eq!(variant_name(&json!({ "$ref": "#/$defs/B" })).as_deref(), Some("B"));
        let tagged = json!({ "properties": { "type": { "const": "C" } } });
        assert_eq!(variant_name(&tagged).as_deref(), Some("C"));
        assert_eq!(variant_name(&json!({ "type": "string" })), None);
        assert_eq!(variant_name(&json!("Success")), None);
    }
}