mcpkit_rs/handler/server/
common.rs1#[cfg(feature = "schemars")]
4use std::{any::TypeId, collections::HashMap, sync::Arc};
5
6#[cfg(feature = "schemars")]
7use schemars::JsonSchema;
8
9#[cfg(feature = "schemars")]
10use crate::model::JsonObject;
11#[cfg(feature = "schemars")]
12use crate::schemars::generate::SchemaSettings;
13use crate::{RoleServer, service::RequestContext};
14
#[cfg(feature = "schemars")]
/// Generates (or fetches from a per-thread cache) the JSON Schema for `T`.
///
/// The schema is built with draft 2020-12 settings plus schemars'
/// `AddNullable` transform, serialized into a JSON object, and cached keyed by
/// [`TypeId`]. Repeated calls for the same `T` on the same thread return clones
/// of the same [`Arc`].
///
/// # Panics
///
/// Panics if the generated schema cannot be serialized to JSON, or if it
/// serializes to a JSON value that is not an object.
pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
    thread_local! {
        // Thread-local storage is never shared across threads, so a `RefCell`
        // is enough for interior mutability; no lock (and no poisoning) needed.
        static CACHE_FOR_TYPE: std::cell::RefCell<HashMap<TypeId, Arc<JsonObject>>> =
            Default::default();
    }

    let key = TypeId::of::<T>();

    // Perform the lookup as its own statement so the shared borrow is released
    // before the cache is mutated below. (Keeping the borrow alive inside an
    // `if let ... else` scrutinee would deadlock/panic on editions before 2024.)
    let cached = CACHE_FOR_TYPE.with(|cache| cache.borrow().get(&key).cloned());
    if let Some(schema) = cached {
        return schema;
    }

    let mut settings = SchemaSettings::draft2020_12();
    settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
    let generator = settings.into_generator();
    let root = generator.into_root_schema_for::<T>();
    let value = serde_json::to_value(root).expect("failed to serialize schema");
    let object = match value {
        serde_json::Value::Object(object) => object,
        other => panic!(
            "Schema serialization produced non-object value: expected JSON object but got {:?}",
            other
        ),
    };
    let schema = Arc::new(object);

    CACHE_FOR_TYPE.with(|cache| {
        cache.borrow_mut().insert(key, Arc::clone(&schema));
    });

    schema
}
53
#[cfg(feature = "schemars")]
/// Returns the schema for `T` if it is usable as an MCP tool `outputSchema`.
///
/// The schema is obtained from [`schema_for_type`] and accepted only when its
/// root `"type"` field is the string `"object"`. Both the accepted schema and
/// any rejection message are cached per thread, keyed by [`TypeId`].
///
/// # Errors
///
/// Returns a human-readable message when the root `"type"` is missing, is a
/// string other than `"object"`, or is not a string at all.
///
/// # Panics
///
/// Panics if the output-schema cache lock is poisoned, and propagates any
/// panic raised by [`schema_for_type`].
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
    thread_local! {
        static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
    };

    let key = TypeId::of::<T>();

    CACHE_FOR_OUTPUT.with(|cache| {
        let hit = cache
            .read()
            .expect("output schema cache lock poisoned")
            .get(&key)
            .cloned();
        if let Some(previous) = hit {
            return previous;
        }

        let schema = schema_for_type::<T>();
        // The MCP spec demands an object at the root of a tool's output schema.
        let verdict = match schema.get("type") {
            None => Err(
                "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
            ),
            Some(serde_json::Value::String(kind)) if kind == "object" => Ok(Arc::clone(&schema)),
            Some(serde_json::Value::String(kind)) => Err(format!(
                "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
                kind
            )),
            Some(other) => Err(format!(
                "Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
                other
            )),
        };

        let mut guard = cache.write().expect("output schema cache lock poisoned");
        guard.insert(key, verdict.clone());
        drop(guard);

        verdict
    })
}
97
98pub trait FromContextPart<C>: Sized {
100 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
101}
102
103impl<C> FromContextPart<C> for RequestContext<RoleServer>
105where
106 C: AsRequestContext,
107{
108 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
109 Ok(context.as_request_context().clone())
110 }
111}
112
113impl<C> FromContextPart<C> for tokio_util::sync::CancellationToken
114where
115 C: AsRequestContext,
116{
117 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
118 Ok(context.as_request_context().ct.clone())
119 }
120}
121
122impl<C> FromContextPart<C> for crate::model::Extensions
123where
124 C: AsRequestContext,
125{
126 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
127 Ok(context.as_request_context().extensions.clone())
128 }
129}
130
131pub struct Extension<T>(pub T);
132
133impl<C, T> FromContextPart<C> for Extension<T>
134where
135 C: AsRequestContext,
136 T: Send + Sync + 'static + Clone,
137{
138 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
139 let extension = context
140 .as_request_context()
141 .extensions
142 .get::<T>()
143 .cloned()
144 .ok_or_else(|| {
145 crate::ErrorData::invalid_params(
146 format!("missing extension {}", std::any::type_name::<T>()),
147 None,
148 )
149 })?;
150 Ok(Extension(extension))
151 }
152}
153
154impl<C> FromContextPart<C> for crate::Peer<RoleServer>
155where
156 C: AsRequestContext,
157{
158 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
159 Ok(context.as_request_context().peer.clone())
160 }
161}
162
163impl<C> FromContextPart<C> for crate::model::Meta
164where
165 C: AsRequestContext,
166{
167 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
168 let request_context = context.as_request_context_mut();
169 let mut meta = crate::model::Meta::default();
170 std::mem::swap(&mut meta, &mut request_context.meta);
171 Ok(meta)
172 }
173}
174
175pub struct RequestId(pub crate::model::RequestId);
176
177impl<C> FromContextPart<C> for RequestId
178where
179 C: AsRequestContext,
180{
181 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
182 Ok(RequestId(context.as_request_context().id.clone()))
183 }
184}
185
186pub trait AsRequestContext {
188 fn as_request_context(&self) -> &RequestContext<RoleServer>;
189 fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
190}
191
#[cfg(all(test, feature = "schemars"))]
mod tests {
    use schemars::JsonSchema;

    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct TestObject {
        value: i32,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct AnotherTestObject {
        value: i32,
    }

    #[test]
    fn test_schema_for_type_handles_primitive() {
        let schema = schema_for_type::<i32>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_array() {
        let schema = schema_for_type::<Vec<i32>>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
        let items = schema
            .get("items")
            .and_then(|v| v.as_object())
            .expect("array schema should have an `items` object");
        assert_eq!(items.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_struct() {
        let schema = schema_for_type::<TestObject>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        let properties = schema
            .get("properties")
            .and_then(|v| v.as_object())
            .expect("struct schema should have a `properties` object");
        assert!(properties.contains_key("value"));
    }

    #[test]
    fn test_schema_for_type_caches_primitive_types() {
        let schema1 = schema_for_type::<i32>();
        let schema2 = schema_for_type::<i32>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_caches_struct_types() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<TestObject>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_different_types_different_schemas() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<AnotherTestObject>();

        assert!(!Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_arc_can_be_shared() {
        let schema = schema_for_type::<TestObject>();
        let cloned = schema.clone();

        assert!(Arc::ptr_eq(&schema, &cloned));
    }

    #[test]
    fn test_schema_for_output_rejects_primitive() {
        let err = schema_for_output::<i32>()
            .expect_err("a primitive must not be accepted as an output schema");
        assert!(
            err.contains("'integer'"),
            "error should name the offending root type, got: {err}"
        );
    }

    #[test]
    fn test_schema_for_output_accepts_object() {
        let schema = schema_for_output::<TestObject>()
            .expect("an object type must be accepted as an output schema");
        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
    }

    #[test]
    fn test_schema_for_output_reuses_type_schema() {
        let output = schema_for_output::<TestObject>().expect("object schema accepted");
        let plain = schema_for_type::<TestObject>();

        assert!(Arc::ptr_eq(&output, &plain));
    }

    #[test]
    fn test_schema_for_output_caches_result() {
        let first = schema_for_output::<TestObject>().expect("object schema accepted");
        let second = schema_for_output::<TestObject>().expect("object schema accepted");

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(schema_for_output::<i32>(), schema_for_output::<i32>());
    }
}
279}