Skip to main content

rmcp_soddygo/handler/server/
common.rs

1//! Common utilities shared between tool and prompt handlers
2
3use std::{any::TypeId, collections::HashMap, sync::Arc};
4
5use schemars::JsonSchema;
6
7use crate::{
8    RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext,
9};
10
11/// Generates a JSON schema for a type
12pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
13    thread_local! {
14        static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
15    };
16    CACHE_FOR_TYPE.with(|cache| {
17        if let Some(x) = cache
18            .read()
19            .expect("schema cache lock poisoned")
20            .get(&TypeId::of::<T>())
21        {
22            x.clone()
23        } else {
24            // explicitly to align json schema version to official specifications.
25            // refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
26            let settings = SchemaSettings::draft2020_12();
27            // Note: AddNullable is intentionally NOT used here because the `nullable` keyword
28            // is an OpenAPI 3.0 extension, not part of JSON Schema 2020-12. Using it would
29            // cause validation failures with strict JSON Schema validators.
30            let generator = settings.into_generator();
31            let schema = generator.into_root_schema_for::<T>();
32            let object = serde_json::to_value(schema).expect("failed to serialize schema");
33            let object = match object {
34                serde_json::Value::Object(object) => object,
35                _ => panic!(
36                    "Schema serialization produced non-object value: expected JSON object but got {:?}",
37                    object
38                ),
39            };
40            let schema = Arc::new(object);
41            cache
42                .write()
43                .expect("schema cache lock poisoned")
44                .insert(TypeId::of::<T>(), schema.clone());
45
46            schema
47        }
48    })
49}
50
51// TODO: should be updated according to the new specifications
52/// Schema used when input is empty.
53pub fn schema_for_empty_input() -> Arc<JsonObject> {
54    std::sync::Arc::new(
55        serde_json::json!({
56            "type": "object",
57            "properties": {}
58        })
59        .as_object()
60        .unwrap()
61        .clone(),
62    )
63}
64
65/// Generate and validate a JSON schema for outputSchema (must have root type "object").
66pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
67    thread_local! {
68        static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
69    };
70
71    CACHE_FOR_OUTPUT.with(|cache| {
72        // Try to get from cache first
73        if let Some(result) = cache
74            .read()
75            .expect("output schema cache lock poisoned")
76            .get(&TypeId::of::<T>())
77        {
78            return result.clone();
79        }
80
81        // Generate and validate schema
82        let schema = schema_for_type::<T>();
83        let result = match schema.get("type") {
84            Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
85            Some(serde_json::Value::String(t)) => Err(format!(
86                "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
87                t
88            )),
89            None => Err(
90                "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
91            ),
92            Some(other) => Err(format!(
93                "Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
94                other
95            )),
96        };
97
98        // Cache the result (both success and error cases)
99        cache
100            .write()
101            .expect("output schema cache lock poisoned")
102            .insert(TypeId::of::<T>(), result.clone());
103
104        result
105    })
106}
107
108/// Trait for extracting parts from a context, unifying tool and prompt extraction
109pub trait FromContextPart<C>: Sized {
110    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
111}
112
113/// Common extractors that can be used by both tool and prompt handlers
114impl<C> FromContextPart<C> for RequestContext<RoleServer>
115where
116    C: AsRequestContext,
117{
118    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
119        Ok(context.as_request_context().clone())
120    }
121}
122
123impl<C> FromContextPart<C> for tokio_util::sync::CancellationToken
124where
125    C: AsRequestContext,
126{
127    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
128        Ok(context.as_request_context().ct.clone())
129    }
130}
131
132impl<C> FromContextPart<C> for crate::model::Extensions
133where
134    C: AsRequestContext,
135{
136    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
137        Ok(context.as_request_context().extensions.clone())
138    }
139}
140
141#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
142pub struct Extension<T>(pub T);
143
144impl<C, T> FromContextPart<C> for Extension<T>
145where
146    C: AsRequestContext,
147    T: Send + Sync + 'static + Clone,
148{
149    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
150        let extension = context
151            .as_request_context()
152            .extensions
153            .get::<T>()
154            .cloned()
155            .ok_or_else(|| {
156                crate::ErrorData::invalid_params(
157                    format!("missing extension {}", std::any::type_name::<T>()),
158                    None,
159                )
160            })?;
161        Ok(Extension(extension))
162    }
163}
164
165impl<C> FromContextPart<C> for crate::Peer<RoleServer>
166where
167    C: AsRequestContext,
168{
169    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
170        Ok(context.as_request_context().peer.clone())
171    }
172}
173
174impl<C> FromContextPart<C> for crate::model::Meta
175where
176    C: AsRequestContext,
177{
178    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
179        let request_context = context.as_request_context_mut();
180        let mut meta = crate::model::Meta::default();
181        std::mem::swap(&mut meta, &mut request_context.meta);
182        Ok(meta)
183    }
184}
185
186#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
187pub struct RequestId(pub crate::model::RequestId);
188
189impl<C> FromContextPart<C> for RequestId
190where
191    C: AsRequestContext,
192{
193    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
194        Ok(RequestId(context.as_request_context().id.clone()))
195    }
196}
197
198/// Trait for types that can provide access to RequestContext
199pub trait AsRequestContext {
200    fn as_request_context(&self) -> &RequestContext<RoleServer>;
201    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct TestObject {
        value: i32,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct AnotherTestObject {
        value: i32,
    }

    #[test]
    fn test_schema_for_type_handles_primitive() {
        let schema = schema_for_type::<i32>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_array() {
        let schema = schema_for_type::<Vec<i32>>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
        let items = schema
            .get("items")
            .and_then(|v| v.as_object())
            .expect("array schema should have an object-valued `items`");
        assert_eq!(items.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_struct() {
        let schema = schema_for_type::<TestObject>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        let properties = schema
            .get("properties")
            .and_then(|v| v.as_object())
            .expect("struct schema should have an object-valued `properties`");
        assert!(properties.contains_key("value"));
    }

    #[test]
    fn test_schema_for_type_caches_primitive_types() {
        let schema1 = schema_for_type::<i32>();
        let schema2 = schema_for_type::<i32>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_caches_struct_types() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<TestObject>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_different_types_different_schemas() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<AnotherTestObject>();

        assert!(!Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_arc_can_be_shared() {
        let schema = schema_for_type::<TestObject>();
        let cloned = schema.clone();

        assert!(Arc::ptr_eq(&schema, &cloned));
    }

    #[test]
    fn test_schema_for_type_populates_cache_on_fresh_thread() {
        // The caches are thread-local, so a dedicated thread is guaranteed to start with
        // an empty cache. The first call there takes the miss-then-insert path; the second
        // must be served from the cache.
        let (first, second) = std::thread::spawn(|| {
            (schema_for_type::<TestObject>(), schema_for_type::<TestObject>())
        })
        .join()
        .expect("schema generation thread panicked");

        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_schema_for_output_rejects_primitive() {
        let err =
            schema_for_output::<i32>().expect_err("a primitive root type must be rejected");
        assert!(err.contains("'integer'"), "unexpected error message: {err}");
    }

    #[test]
    fn test_schema_for_output_accepts_object() {
        let schema = schema_for_output::<TestObject>()
            .expect("an object root type must be accepted");
        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
    }

    #[test]
    fn test_schema_for_output_caches_result() {
        let first = schema_for_output::<TestObject>().expect("object schema accepted");
        let second = schema_for_output::<TestObject>().expect("object schema accepted");

        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_schema_for_empty_input_is_empty_object_schema() {
        let schema = schema_for_empty_input();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        assert_eq!(schema.get("properties"), Some(&serde_json::json!({})));
        assert_eq!(schema.len(), 2);
    }
}