rmcp_soddygo/handler/server/
common.rs1use std::{any::TypeId, collections::HashMap, sync::Arc};
4
5use schemars::JsonSchema;
6
7use crate::{
8 RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext,
9};
10
11pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
13 thread_local! {
14 static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
15 };
16 CACHE_FOR_TYPE.with(|cache| {
17 if let Some(x) = cache
18 .read()
19 .expect("schema cache lock poisoned")
20 .get(&TypeId::of::<T>())
21 {
22 x.clone()
23 } else {
24 let settings = SchemaSettings::draft2020_12();
27 let generator = settings.into_generator();
31 let schema = generator.into_root_schema_for::<T>();
32 let object = serde_json::to_value(schema).expect("failed to serialize schema");
33 let object = match object {
34 serde_json::Value::Object(object) => object,
35 _ => panic!(
36 "Schema serialization produced non-object value: expected JSON object but got {:?}",
37 object
38 ),
39 };
40 let schema = Arc::new(object);
41 cache
42 .write()
43 .expect("schema cache lock poisoned")
44 .insert(TypeId::of::<T>(), schema.clone());
45
46 schema
47 }
48 })
49}
50
51pub fn schema_for_empty_input() -> Arc<JsonObject> {
54 std::sync::Arc::new(
55 serde_json::json!({
56 "type": "object",
57 "properties": {}
58 })
59 .as_object()
60 .unwrap()
61 .clone(),
62 )
63}
64
65pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
67 thread_local! {
68 static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
69 };
70
71 CACHE_FOR_OUTPUT.with(|cache| {
72 if let Some(result) = cache
74 .read()
75 .expect("output schema cache lock poisoned")
76 .get(&TypeId::of::<T>())
77 {
78 return result.clone();
79 }
80
81 let schema = schema_for_type::<T>();
83 let result = match schema.get("type") {
84 Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
85 Some(serde_json::Value::String(t)) => Err(format!(
86 "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
87 t
88 )),
89 None => Err(
90 "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
91 ),
92 Some(other) => Err(format!(
93 "Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
94 other
95 )),
96 };
97
98 cache
100 .write()
101 .expect("output schema cache lock poisoned")
102 .insert(TypeId::of::<T>(), result.clone());
103
104 result
105 })
106}
107
108pub trait FromContextPart<C>: Sized {
110 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
111}
112
113impl<C> FromContextPart<C> for RequestContext<RoleServer>
115where
116 C: AsRequestContext,
117{
118 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
119 Ok(context.as_request_context().clone())
120 }
121}
122
123impl<C> FromContextPart<C> for tokio_util::sync::CancellationToken
124where
125 C: AsRequestContext,
126{
127 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
128 Ok(context.as_request_context().ct.clone())
129 }
130}
131
132impl<C> FromContextPart<C> for crate::model::Extensions
133where
134 C: AsRequestContext,
135{
136 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
137 Ok(context.as_request_context().extensions.clone())
138 }
139}
140
141#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
142pub struct Extension<T>(pub T);
143
144impl<C, T> FromContextPart<C> for Extension<T>
145where
146 C: AsRequestContext,
147 T: Send + Sync + 'static + Clone,
148{
149 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
150 let extension = context
151 .as_request_context()
152 .extensions
153 .get::<T>()
154 .cloned()
155 .ok_or_else(|| {
156 crate::ErrorData::invalid_params(
157 format!("missing extension {}", std::any::type_name::<T>()),
158 None,
159 )
160 })?;
161 Ok(Extension(extension))
162 }
163}
164
165impl<C> FromContextPart<C> for crate::Peer<RoleServer>
166where
167 C: AsRequestContext,
168{
169 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
170 Ok(context.as_request_context().peer.clone())
171 }
172}
173
174impl<C> FromContextPart<C> for crate::model::Meta
175where
176 C: AsRequestContext,
177{
178 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
179 let request_context = context.as_request_context_mut();
180 let mut meta = crate::model::Meta::default();
181 std::mem::swap(&mut meta, &mut request_context.meta);
182 Ok(meta)
183 }
184}
185
186#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
187pub struct RequestId(pub crate::model::RequestId);
188
189impl<C> FromContextPart<C> for RequestId
190where
191 C: AsRequestContext,
192{
193 fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
194 Ok(RequestId(context.as_request_context().id.clone()))
195 }
196}
197
198pub trait AsRequestContext {
200 fn as_request_context(&self) -> &RequestContext<RoleServer>;
201 fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct TestObject {
        value: i32,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct AnotherTestObject {
        value: i32,
    }

    #[test]
    fn test_schema_for_type_handles_primitive() {
        let schema = schema_for_type::<i32>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_array() {
        let schema = schema_for_type::<Vec<i32>>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
        let items = schema
            .get("items")
            .and_then(|v| v.as_object())
            .expect("array schema should have an `items` object");
        assert_eq!(items.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_struct() {
        let schema = schema_for_type::<TestObject>();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        let properties = schema
            .get("properties")
            .and_then(|v| v.as_object())
            .expect("struct schema should have a `properties` object");
        assert!(properties.contains_key("value"));
    }

    // Each caching test runs a cache miss (which takes the write lock) and then a
    // cache hit (which takes the read lock) on the same thread.
    #[test]
    fn test_schema_for_type_caches_primitive_types() {
        let schema1 = schema_for_type::<i32>();
        let schema2 = schema_for_type::<i32>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_caches_struct_types() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<TestObject>();

        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_different_types_different_schemas() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<AnotherTestObject>();

        assert!(!Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_arc_can_be_shared() {
        let schema = schema_for_type::<TestObject>();
        let cloned = schema.clone();

        assert!(Arc::ptr_eq(&schema, &cloned));
    }

    #[test]
    fn test_schema_for_empty_input_is_empty_object_schema() {
        let schema = schema_for_empty_input();

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        assert_eq!(schema.get("properties"), Some(&serde_json::json!({})));
        assert_eq!(schema.len(), 2);
    }

    #[test]
    fn test_schema_for_output_rejects_primitive() {
        let Err(message) = schema_for_output::<i32>() else {
            panic!("i32 must not be accepted as a tool output schema");
        };
        assert!(
            message.contains("'integer'"),
            "unexpected error message: {}",
            message
        );
    }

    #[test]
    fn test_schema_for_output_caches_rejection() {
        let first = schema_for_output::<i32>().err();
        let second = schema_for_output::<i32>().err();

        assert!(first.is_some());
        assert_eq!(first, second);
    }

    #[test]
    fn test_schema_for_output_accepts_object() {
        let schema = schema_for_output::<TestObject>()
            .expect("struct should be accepted as a tool output schema");

        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        // The accepted schema is the same cached Arc that `schema_for_type` returns.
        assert!(Arc::ptr_eq(&schema, &schema_for_type::<TestObject>()));
    }
}
289}