1use agentic_tools_core::ToolContext;
4use agentic_tools_core::ToolError;
5use agentic_tools_core::ToolRegistry;
6use agentic_tools_core::fmt::TextOptions;
7use agentic_tools_core::fmt::fallback_text_from_json;
8use rmcp::RoleServer;
9use rmcp::ServerHandler;
10use rmcp::model as m;
11use rmcp::service::RequestContext;
12use std::collections::HashSet;
13use std::sync::Arc;
14
/// Selects whether a `RegistryServer` exposes structured tool output in
/// addition to plain text.
///
/// `PartialEq`, `Eq` and `Hash` are derived so callers can compare modes with
/// `==` and use them as map or set keys instead of relying on `matches!`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum OutputMode {
    /// Tool results carry text content only and no output schema is
    /// advertised. This is the default mode.
    #[default]
    Text,
    /// Tool results still carry text content; tools that declare an output
    /// schema additionally advertise it and return their JSON data as
    /// structured content.
    Structured,
}
26
27pub struct RegistryServer {
61 registry: Arc<ToolRegistry>,
62 allowlist: Option<HashSet<String>>,
63 output_mode: OutputMode,
64 text_options: TextOptions,
65 name: String,
66 version: String,
67}
68
69impl RegistryServer {
70 pub fn new(registry: Arc<ToolRegistry>) -> Self {
72 Self {
73 registry,
74 allowlist: None,
75 output_mode: OutputMode::default(),
76 text_options: TextOptions::default(),
77 name: "agentic-tools".to_string(),
78 version: env!("CARGO_PKG_VERSION").to_string(),
79 }
80 }
81
82 pub fn with_allowlist(mut self, allowlist: impl IntoIterator<Item = String>) -> Self {
86 self.allowlist = Some(allowlist.into_iter().collect());
87 self
88 }
89
90 pub fn with_output_mode(mut self, mode: OutputMode) -> Self {
92 self.output_mode = mode;
93 self
94 }
95
96 pub fn with_text_options(mut self, text_options: TextOptions) -> Self {
98 self.text_options = text_options;
99 self
100 }
101
102 pub fn with_info(mut self, name: &str, version: &str) -> Self {
104 self.name = name.to_string();
105 self.version = version.to_string();
106 self
107 }
108
109 pub fn name(&self) -> &str {
111 &self.name
112 }
113
114 pub fn version(&self) -> &str {
116 &self.version
117 }
118
119 pub fn output_mode(&self) -> OutputMode {
121 self.output_mode
122 }
123
124 pub fn effective_tool_names(&self) -> Vec<String> {
126 self.registry
127 .list_names()
128 .into_iter()
129 .filter(|n| self.is_allowed(n))
130 .collect()
131 }
132
133 fn is_allowed(&self, name: &str) -> bool {
134 self.allowlist.as_ref().is_none_or(|set| set.contains(name))
135 }
136}
137
138#[allow(clippy::manual_async_fn)]
140impl ServerHandler for RegistryServer {
141 fn initialize(
142 &self,
143 _params: m::InitializeRequestParams,
144 _ctx: RequestContext<RoleServer>,
145 ) -> impl std::future::Future<Output = Result<m::InitializeResult, m::ErrorData>> + Send + '_
146 {
147 async move {
148 let server_info =
149 m::Implementation::new(&self.name, &self.version).with_title(&self.name);
150 Ok(
151 m::InitializeResult::new(m::ServerCapabilities::builder().enable_tools().build())
152 .with_server_info(server_info),
153 )
154 }
155 }
156
157 fn list_tools(
158 &self,
159 _req: Option<m::PaginatedRequestParams>,
160 _ctx: RequestContext<RoleServer>,
161 ) -> impl std::future::Future<Output = Result<m::ListToolsResult, m::ErrorData>> + Send + '_
162 {
163 async move {
164 let mut tools = vec![];
165 for name in self.registry.list_names() {
166 if !self.is_allowed(&name) {
167 continue;
168 }
169 if let Some(erased) = self.registry.get(&name) {
170 let input_schema = erased.input_schema();
171 let schema_json = serde_json::to_value(&input_schema)
172 .unwrap_or(serde_json::json!({"type": "object"}));
173
174 let output_schema = if matches!(self.output_mode, OutputMode::Structured) {
176 erased.output_schema().and_then(|s| {
177 serde_json::to_value(&s)
178 .ok()
179 .and_then(|v| v.as_object().cloned())
180 .map(Arc::new)
181 })
182 } else {
183 None
184 };
185
186 let input_schema =
187 Arc::new(schema_json.as_object().cloned().unwrap_or_default());
188 let mut tool = m::Tool::new(name.clone(), erased.description(), input_schema)
189 .with_title(name);
190
191 if let Some(schema) = output_schema {
193 tool = tool.with_raw_output_schema(schema);
194 }
195
196 tools.push(tool);
197 }
198 }
199 Ok(m::ListToolsResult::with_all_items(tools))
200 }
201 }
202
203 fn call_tool(
204 &self,
205 req: m::CallToolRequestParams,
206 request_context: RequestContext<RoleServer>,
207 ) -> impl std::future::Future<Output = Result<m::CallToolResult, m::ErrorData>> + Send + '_
208 {
209 async move {
210 if !self.is_allowed(&req.name) {
211 return Ok(m::CallToolResult::error(vec![m::Content::text(format!(
212 "Tool '{}' not enabled on this server",
213 req.name
214 ))]));
215 }
216
217 let args = serde_json::Value::Object(req.arguments.unwrap_or_default());
218 let ctx = ToolContext::with_cancel(request_context.ct.child_token());
219 let text_opts = self.text_options.clone();
220
221 tracing::info!(tool = %req.name, "tool dispatch started");
222
223 let dispatch_result = self
224 .registry
225 .dispatch_json_formatted(&req.name, args, &ctx, &text_opts)
226 .await;
227
228 if matches!(&dispatch_result, Err(ToolError::Cancelled { .. })) || ctx.is_cancelled() {
229 tracing::info!(tool = %req.name, "tool dispatch exiting after cancellation");
230 }
231
232 match dispatch_result {
233 Ok(res) => {
234 let text = res
235 .text
236 .unwrap_or_else(|| fallback_text_from_json(&res.data));
237
238 let contents = vec![m::Content::text(text)];
240
241 let structured_content = if matches!(self.output_mode, OutputMode::Structured) {
243 let has_schema = self
245 .registry
246 .get(&req.name)
247 .and_then(|t| t.output_schema())
248 .is_some();
249
250 if has_schema { Some(res.data) } else { None }
251 } else {
252 None
253 };
254
255 let mut result = m::CallToolResult::success(contents);
257 result.structured_content = structured_content;
258 Ok(result)
259 }
260 Err(e) => Ok(m::CallToolResult::error(vec![m::Content::text(
261 e.to_string(),
262 )])),
263 }
264 }
265 }
266
267 fn ping(
268 &self,
269 _ctx: RequestContext<RoleServer>,
270 ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
271 async { Ok(()) }
272 }
273
274 fn complete(
275 &self,
276 _req: m::CompleteRequestParams,
277 _ctx: RequestContext<RoleServer>,
278 ) -> impl std::future::Future<Output = Result<m::CompleteResult, m::ErrorData>> + Send + '_
279 {
280 async {
281 Err(m::ErrorData::invalid_request(
282 "Method not implemented",
283 None,
284 ))
285 }
286 }
287
288 fn set_level(
289 &self,
290 _req: m::SetLevelRequestParams,
291 _ctx: RequestContext<RoleServer>,
292 ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
293 async { Ok(()) }
294 }
295
296 fn get_prompt(
297 &self,
298 _req: m::GetPromptRequestParams,
299 _ctx: RequestContext<RoleServer>,
300 ) -> impl std::future::Future<Output = Result<m::GetPromptResult, m::ErrorData>> + Send + '_
301 {
302 async {
303 Err(m::ErrorData::invalid_request(
304 "Method not implemented",
305 None,
306 ))
307 }
308 }
309
310 fn list_prompts(
311 &self,
312 _req: Option<m::PaginatedRequestParams>,
313 _ctx: RequestContext<RoleServer>,
314 ) -> impl std::future::Future<Output = Result<m::ListPromptsResult, m::ErrorData>> + Send + '_
315 {
316 async { Ok(m::ListPromptsResult::with_all_items(vec![])) }
317 }
318
319 fn list_resources(
320 &self,
321 _req: Option<m::PaginatedRequestParams>,
322 _ctx: RequestContext<RoleServer>,
323 ) -> impl std::future::Future<Output = Result<m::ListResourcesResult, m::ErrorData>> + Send + '_
324 {
325 async { Ok(m::ListResourcesResult::with_all_items(vec![])) }
326 }
327
328 fn list_resource_templates(
329 &self,
330 _req: Option<m::PaginatedRequestParams>,
331 _ctx: RequestContext<RoleServer>,
332 ) -> impl std::future::Future<Output = Result<m::ListResourceTemplatesResult, m::ErrorData>>
333 + Send
334 + '_ {
335 async { Ok(m::ListResourceTemplatesResult::with_all_items(vec![])) }
336 }
337
338 fn read_resource(
339 &self,
340 _req: m::ReadResourceRequestParams,
341 _ctx: RequestContext<RoleServer>,
342 ) -> impl std::future::Future<Output = Result<m::ReadResourceResult, m::ErrorData>> + Send + '_
343 {
344 async {
345 Err(m::ErrorData::invalid_request(
346 "Method not implemented",
347 None,
348 ))
349 }
350 }
351
352 fn subscribe(
353 &self,
354 _req: m::SubscribeRequestParams,
355 _ctx: RequestContext<RoleServer>,
356 ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
357 async {
358 Err(m::ErrorData::invalid_request(
359 "Method not implemented",
360 None,
361 ))
362 }
363 }
364
365 fn unsubscribe(
366 &self,
367 _req: m::UnsubscribeRequestParams,
368 _ctx: RequestContext<RoleServer>,
369 ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
370 async {
371 Err(m::ErrorData::invalid_request(
372 "Method not implemented",
373 None,
374 ))
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use agentic_tools_core::Tool;
    use agentic_tools_core::ToolError;
    use agentic_tools_core::fmt::TextFormat;
    use futures::future::BoxFuture;

    /// Builds a registry with no tools registered.
    fn empty_registry() -> Arc<ToolRegistry> {
        Arc::new(ToolRegistry::builder().finish())
    }

    /// Dispatches `tool_name` with a null argument through the server's
    /// registry, using the server's stored text options, and returns the
    /// formatted text.
    async fn dispatch_text_for_test(server: &RegistryServer, tool_name: &str) -> String {
        let tool_ctx = ToolContext::default();
        let formatted = server
            .registry
            .dispatch_json_formatted(
                tool_name,
                serde_json::Value::Null,
                &tool_ctx,
                &server.text_options,
            )
            .await
            .unwrap();
        formatted.text.unwrap()
    }

    #[test]
    fn test_registry_server_allowlist() {
        let registry = empty_registry();
        let server = RegistryServer::new(Arc::clone(&registry))
            .with_allowlist(["tool_a", "tool_b"].map(String::from));

        for listed in ["tool_a", "tool_b"] {
            assert!(server.is_allowed(listed));
        }
        assert!(!server.is_allowed("tool_c"));
    }

    #[test]
    fn test_registry_server_no_allowlist() {
        let registry = empty_registry();
        let server = RegistryServer::new(Arc::clone(&registry));

        // Without an allowlist every name is accepted.
        assert!(server.is_allowed("any_tool"));
    }

    #[test]
    fn test_registry_server_info() {
        let registry = empty_registry();
        let server = RegistryServer::new(Arc::clone(&registry)).with_info("my-server", "1.0.0");

        assert_eq!(server.name(), "my-server");
        assert_eq!(server.version(), "1.0.0");
    }

    /// Tool whose output is a JSON object, so it has an output schema.
    #[derive(Clone)]
    struct TestObjTool;

    #[derive(
        serde::Serialize, serde::Deserialize, schemars::JsonSchema, Clone, Debug, PartialEq,
    )]
    struct TestObjOut {
        message: String,
    }

    impl TextFormat for TestObjOut {
        fn fmt_text(&self, _opts: &TextOptions) -> String {
            format!("Message: {}", self.message)
        }
    }

    impl Tool for TestObjTool {
        type Input = ();
        type Output = TestObjOut;
        const NAME: &'static str = "test_obj_tool";
        const DESCRIPTION: &'static str = "outputs an object";

        fn call(
            &self,
            _input: (),
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<TestObjOut, ToolError>> {
            Box::pin(async {
                let out = TestObjOut {
                    message: "hello".into(),
                };
                Ok(out)
            })
        }
    }

    /// Tool whose text rendering depends on the supplied `TextOptions`.
    #[derive(Clone)]
    struct TestTextOptionsTool;

    #[derive(
        serde::Serialize, serde::Deserialize, schemars::JsonSchema, Clone, Debug, PartialEq,
    )]
    struct TestTextOptionsOut;

    impl TextFormat for TestTextOptionsOut {
        fn fmt_text(&self, opts: &TextOptions) -> String {
            let label = match opts.suppress_search_reminder {
                true => "suppressed",
                false => "default",
            };
            label.to_string()
        }
    }

    impl Tool for TestTextOptionsTool {
        type Input = ();
        type Output = TestTextOptionsOut;
        const NAME: &'static str = "test_text_options_tool";
        const DESCRIPTION: &'static str = "outputs text that depends on text options";

        fn call(
            &self,
            _input: (),
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<TestTextOptionsOut, ToolError>> {
            Box::pin(async { Ok(TestTextOptionsOut) })
        }
    }

    #[test]
    fn test_structured_mode_output_schema_gating() {
        let builder = ToolRegistry::builder().register::<TestObjTool, ()>(TestObjTool);
        let registry = Arc::new(builder.finish());

        let structured_server =
            RegistryServer::new(Arc::clone(&registry)).with_output_mode(OutputMode::Structured);
        assert!(matches!(
            structured_server.output_mode(),
            OutputMode::Structured
        ));

        let text_server =
            RegistryServer::new(Arc::clone(&registry)).with_output_mode(OutputMode::Text);
        assert!(matches!(text_server.output_mode(), OutputMode::Text));

        let tool = registry.get("test_obj_tool").unwrap();
        assert!(
            tool.output_schema().is_some(),
            "TestObjTool should have an output schema"
        );
    }

    #[tokio::test]
    async fn test_structured_mode_structured_content_via_dispatch() {
        let builder = ToolRegistry::builder().register::<TestObjTool, ()>(TestObjTool);
        let registry = Arc::new(builder.finish());

        let tool_ctx = ToolContext::default();
        let text_opts = TextOptions::default();
        let dispatched = registry
            .dispatch_json_formatted(
                "test_obj_tool",
                serde_json::Value::Null,
                &tool_ctx,
                &text_opts,
            )
            .await
            .unwrap();

        assert_eq!(dispatched.data, serde_json::json!({"message": "hello"}));
        assert!(dispatched.text.is_some());

        let tool = registry.get("test_obj_tool").unwrap();
        let has_schema = tool.output_schema().is_some();
        assert!(
            has_schema,
            "Tool should have output schema for structured content"
        );
    }

    #[tokio::test]
    async fn test_registry_server_uses_stored_text_options() {
        let builder =
            ToolRegistry::builder().register::<TestTextOptionsTool, ()>(TestTextOptionsTool);
        let registry = Arc::new(builder.finish());

        let default_server = RegistryServer::new(Arc::clone(&registry));
        let suppressed_options = TextOptions::default().with_suppress_search_reminder(true);
        let suppressed_server = RegistryServer::new(registry).with_text_options(suppressed_options);

        let default_text = dispatch_text_for_test(&default_server, "test_text_options_tool").await;
        assert_eq!(default_text, "default");

        let suppressed_text =
            dispatch_text_for_test(&suppressed_server, "test_text_options_tool").await;
        assert_eq!(suppressed_text, "suppressed");
    }

    #[test]
    fn test_output_mode_default_is_text() {
        let server = RegistryServer::new(empty_registry());

        assert!(matches!(server.output_mode(), OutputMode::Text));
    }
}