Skip to main content

mcpkit_rs/handler/server/
common.rs

1//! Common utilities shared between tool and prompt handlers
2
3#[cfg(feature = "schemars")]
4use std::{any::TypeId, collections::HashMap, sync::Arc};
5
6#[cfg(feature = "schemars")]
7use schemars::JsonSchema;
8
9#[cfg(feature = "schemars")]
10use crate::model::JsonObject;
11#[cfg(feature = "schemars")]
12use crate::schemars::generate::SchemaSettings;
13use crate::{RoleServer, service::RequestContext};
14
/// Generates the JSON schema for type `T`, as a JSON object.
///
/// The schema is produced with JSON Schema draft 2020-12 settings and the
/// `AddNullable` transform. Results are cached per thread, keyed by
/// [`TypeId`], so repeated calls for the same `T` on the same thread return
/// clones of one shared [`Arc`].
///
/// # Panics
///
/// Panics if the generated schema fails to serialize, or serializes to a
/// value that is not a JSON object.
#[cfg(feature = "schemars")]
pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
    thread_local! {
        // A static inside a generic fn is shared by every `T`, hence the
        // `TypeId` key. The map is only ever touched by its owning thread, so
        // `RefCell` is enough; no lock (and no poisoning) is needed.
        static CACHE_FOR_TYPE: std::cell::RefCell<HashMap<TypeId, Arc<JsonObject>>> =
            Default::default();
    }

    let type_id = TypeId::of::<T>();

    // Do the lookup in its own statement so the shared borrow is released
    // before we insert below. This avoids relying on edition-specific `if let`
    // temporary scoping, which would keep the borrow alive into an `else`
    // branch on editions before 2024.
    let cached = CACHE_FOR_TYPE.with(|cache| cache.borrow().get(&type_id).cloned());
    if let Some(schema) = cached {
        return schema;
    }

    // explicitly to align json schema version to official specifications.
    // refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
    let mut settings = SchemaSettings::draft2020_12();
    settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
    let generator = settings.into_generator();
    let root = generator.into_root_schema_for::<T>();
    let value = serde_json::to_value(root).expect("failed to serialize schema");
    let object = match value {
        serde_json::Value::Object(object) => object,
        other => panic!(
            "Schema serialization produced non-object value: expected JSON object but got {:?}",
            other
        ),
    };
    let schema = Arc::new(object);

    CACHE_FOR_TYPE.with(|cache| {
        cache.borrow_mut().insert(type_id, Arc::clone(&schema));
    });
    schema
}
53
/// Generates the JSON schema for `T` and validates it as a tool `outputSchema`.
///
/// The MCP specification requires `outputSchema` to have root
/// `"type": "object"`. On success the returned [`Arc`] is the same one cached
/// by [`schema_for_type`]. Both successes and failures are cached per thread,
/// keyed by [`TypeId`].
///
/// # Errors
///
/// Returns a human-readable message when the root `type` is missing, is a
/// string other than `"object"`, or is not a string at all.
#[cfg(feature = "schemars")]
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
    thread_local! {
        // Per-thread cache shared by every `T` (keyed by `TypeId`); the
        // thread-local is never accessed concurrently, so `RefCell` suffices.
        static CACHE_FOR_OUTPUT: std::cell::RefCell<HashMap<TypeId, Result<Arc<JsonObject>, String>>> =
            Default::default();
    }

    let type_id = TypeId::of::<T>();

    // Try the cache first; the borrow ends with this statement.
    let cached = CACHE_FOR_OUTPUT.with(|cache| cache.borrow().get(&type_id).cloned());
    if let Some(result) = cached {
        return result;
    }

    // Generate and validate schema
    let schema = schema_for_type::<T>();
    let result = match schema.get("type") {
        Some(serde_json::Value::String(t)) if t == "object" => Ok(Arc::clone(&schema)),
        Some(serde_json::Value::String(t)) => Err(format!(
            "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
            t
        )),
        None => Err(
            "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
        ),
        Some(other) => Err(format!(
            "Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
            other
        )),
    };

    // Cache the result (both success and error cases)
    CACHE_FOR_OUTPUT.with(|cache| {
        cache.borrow_mut().insert(type_id, result.clone());
    });
    result
}
97
98/// Trait for extracting parts from a context, unifying tool and prompt extraction
99pub trait FromContextPart<C>: Sized {
100    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
101}
102
103/// Common extractors that can be used by both tool and prompt handlers
104impl<C> FromContextPart<C> for RequestContext<RoleServer>
105where
106    C: AsRequestContext,
107{
108    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
109        Ok(context.as_request_context().clone())
110    }
111}
112
113impl<C> FromContextPart<C> for tokio_util::sync::CancellationToken
114where
115    C: AsRequestContext,
116{
117    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
118        Ok(context.as_request_context().ct.clone())
119    }
120}
121
122impl<C> FromContextPart<C> for crate::model::Extensions
123where
124    C: AsRequestContext,
125{
126    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
127        Ok(context.as_request_context().extensions.clone())
128    }
129}
130
/// Extractor wrapper for a single value of type `T` stored in the request
/// context's `extensions`.
///
/// Extraction clones the stored `T`. It fails with an `invalid_params` error
/// naming `T` when no value of that type is present.
pub struct Extension<T>(pub T);
132
133impl<C, T> FromContextPart<C> for Extension<T>
134where
135    C: AsRequestContext,
136    T: Send + Sync + 'static + Clone,
137{
138    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
139        let extension = context
140            .as_request_context()
141            .extensions
142            .get::<T>()
143            .cloned()
144            .ok_or_else(|| {
145                crate::ErrorData::invalid_params(
146                    format!("missing extension {}", std::any::type_name::<T>()),
147                    None,
148                )
149            })?;
150        Ok(Extension(extension))
151    }
152}
153
154impl<C> FromContextPart<C> for crate::Peer<RoleServer>
155where
156    C: AsRequestContext,
157{
158    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
159        Ok(context.as_request_context().peer.clone())
160    }
161}
162
163impl<C> FromContextPart<C> for crate::model::Meta
164where
165    C: AsRequestContext,
166{
167    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
168        let request_context = context.as_request_context_mut();
169        let mut meta = crate::model::Meta::default();
170        std::mem::swap(&mut meta, &mut request_context.meta);
171        Ok(meta)
172    }
173}
174
/// Extractor wrapper carrying the [`crate::model::RequestId`] of the request
/// being handled; extraction clones the request context's `id`.
///
/// Note: this type shares its name with `crate::model::RequestId`, which it
/// wraps — qualify the path when both are in scope.
pub struct RequestId(pub crate::model::RequestId);
176
177impl<C> FromContextPart<C> for RequestId
178where
179    C: AsRequestContext,
180{
181    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
182        Ok(RequestId(context.as_request_context().id.clone()))
183    }
184}
185
/// Trait for types that can provide access to a server-side [`RequestContext`].
///
/// The [`FromContextPart`] extractors in this module are implemented for any
/// context type `C: AsRequestContext`.
pub trait AsRequestContext {
    /// Returns a shared reference to the request context this value provides.
    fn as_request_context(&self) -> &RequestContext<RoleServer>;
    /// Returns an exclusive reference to the request context; the `Meta`
    /// extractor uses this to move `meta` out of the context.
    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
}
191
#[cfg(all(test, feature = "schemars"))]
mod tests {
    use schemars::JsonSchema;

    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct TestObject {
        value: i32,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct AnotherTestObject {
        value: i32,
    }

    #[test]
    fn test_schema_for_type_handles_primitive() {
        let schema = schema_for_type::<i32>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_array() {
        let schema = schema_for_type::<Vec<i32>>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
        let items = schema
            .get("items")
            .and_then(|v| v.as_object())
            .expect("array schema should have an object `items` entry");
        assert_eq!(items.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_struct() {
        let schema = schema_for_type::<TestObject>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        let properties = schema
            .get("properties")
            .and_then(|v| v.as_object())
            .expect("struct schema should have an object `properties` entry");
        assert!(properties.contains_key("value"));
    }

    #[test]
    fn test_schema_for_type_caches_primitive_types() {
        let schema1 = schema_for_type::<i32>();
        let schema2 = schema_for_type::<i32>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_caches_struct_types() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<TestObject>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_different_types_different_schemas() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<AnotherTestObject>();

        assert!(!Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_arc_can_be_shared() {
        let schema = schema_for_type::<TestObject>();

        // The cache retains its own reference, so the returned Arc is shared
        // with it rather than being a fresh, uniquely-owned allocation.
        assert!(Arc::strong_count(&schema) >= 2);
    }

    #[test]
    fn test_schema_for_output_rejects_primitive() {
        let err = schema_for_output::<i32>().expect_err("primitive root type must be rejected");

        assert!(err.contains("root type 'object'"), "unexpected error: {err}");
        assert!(err.contains("'integer'"), "unexpected error: {err}");
    }

    #[test]
    fn test_schema_for_output_rejects_array() {
        let err = schema_for_output::<Vec<i32>>().expect_err("array root type must be rejected");

        assert!(err.contains("'array'"), "unexpected error: {err}");
    }

    #[test]
    fn test_schema_for_output_accepts_object() {
        let schema =
            schema_for_output::<TestObject>().expect("object root type must be accepted");

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
    }

    #[test]
    fn test_schema_for_output_reuses_cached_type_schema() {
        let first = schema_for_output::<TestObject>().expect("object root type must be accepted");
        let second =
            schema_for_output::<TestObject>().expect("object root type must be accepted");

        assert!(Arc::ptr_eq(&first, &second));
        assert!(Arc::ptr_eq(&first, &schema_for_type::<TestObject>()));
    }

    #[test]
    fn test_schema_for_output_caches_errors() {
        let first = schema_for_output::<i32>().unwrap_err();
        let second = schema_for_output::<i32>().unwrap_err();

        assert_eq!(first, second);
    }
}