use agentic_tools_core::fmt::{TextOptions, fallback_text_from_json};
use agentic_tools_core::{ToolContext, ToolRegistry};
use rmcp::model as m;
use rmcp::service::RequestContext;
use rmcp::{RoleServer, ServerHandler};
use std::collections::HashSet;
use std::sync::Arc;
/// How tool results are presented to MCP clients.
///
/// Comparable with `==` and usable as a map/set key in addition to being `Copy`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum OutputMode {
    /// Text content only: no output schemas are advertised and no structured content is sent.
    #[default]
    Text,
    /// Advertise tool output schemas and attach the tool's JSON data as structured content
    /// for tools that have an output schema.
    Structured,
}
/// MCP server that exposes the tools of a [`ToolRegistry`].
///
/// Construct with `RegistryServer::new` and configure through the `with_*` builder methods.
/// Cloning is cheap for the registry itself (it is shared behind an [`Arc`]); the allowlist
/// and name/version strings are copied.
#[derive(Clone)]
pub struct RegistryServer {
    /// Source of tool definitions; also used to dispatch tool calls.
    registry: Arc<ToolRegistry>,
    /// When `Some`, only these tool names are listed and callable; `None` allows every tool.
    allowlist: Option<HashSet<String>>,
    /// Controls whether output schemas and structured content are emitted.
    output_mode: OutputMode,
    /// Server name reported during `initialize` (also used as the title).
    name: String,
    /// Server version reported during `initialize`.
    version: String,
}
impl RegistryServer {
pub fn new(registry: Arc<ToolRegistry>) -> Self {
Self {
registry,
allowlist: None,
output_mode: OutputMode::default(),
name: "agentic-tools".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
pub fn with_allowlist(mut self, allowlist: impl IntoIterator<Item = String>) -> Self {
self.allowlist = Some(allowlist.into_iter().collect());
self
}
pub fn with_output_mode(mut self, mode: OutputMode) -> Self {
self.output_mode = mode;
self
}
pub fn with_info(mut self, name: &str, version: &str) -> Self {
self.name = name.to_string();
self.version = version.to_string();
self
}
pub fn name(&self) -> &str {
&self.name
}
pub fn version(&self) -> &str {
&self.version
}
pub fn output_mode(&self) -> OutputMode {
self.output_mode
}
pub fn effective_tool_names(&self) -> Vec<String> {
self.registry
.list_names()
.into_iter()
.filter(|n| self.is_allowed(n))
.collect()
}
fn is_allowed(&self, name: &str) -> bool {
self.allowlist.as_ref().is_none_or(|set| set.contains(name))
}
}
// The `ServerHandler` trait declares these methods as returning
// `impl Future<Output = ...> + Send + '_`; they are implemented with explicit `async`
// blocks that mirror those signatures, which `manual_async_fn` would otherwise flag.
#[allow(clippy::manual_async_fn)]
impl ServerHandler for RegistryServer {
    /// Answers the MCP `initialize` handshake.
    ///
    /// Advertises only the `tools` capability and reports the configured server name and
    /// version; the name doubles as the human-readable title.
    fn initialize(
        &self,
        _params: m::InitializeRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::InitializeResult, m::ErrorData>> + Send + '_
    {
        async move {
            let server_info =
                m::Implementation::new(&self.name, &self.version).with_title(&self.name);
            Ok(
                m::InitializeResult::new(m::ServerCapabilities::builder().enable_tools().build())
                    .with_server_info(server_info),
            )
        }
    }

    /// Lists every registry tool that passes the allowlist.
    ///
    /// Pagination parameters are ignored; all tools are returned in a single page. Each
    /// tool's input schema is converted to a JSON object, falling back to
    /// `{"type": "object"}` when serialization fails or yields a non-object (e.g. a boolean
    /// schema). An output schema is attached only when `advertised_output_schema` yields one.
    fn list_tools(
        &self,
        _req: Option<m::PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::ListToolsResult, m::ErrorData>> + Send + '_
    {
        async move {
            let mut tools = Vec::new();
            for name in self.effective_tool_names() {
                let Some(erased) = self.registry.get(&name) else {
                    continue;
                };
                let input_schema = erased.input_schema();
                let input_schema = serde_json::to_value(&input_schema)
                    .ok()
                    .and_then(json_object)
                    .unwrap_or_else(default_input_schema);
                let output_schema = self.advertised_output_schema(&name);
                let mut tool =
                    m::Tool::new(name.clone(), erased.description(), Arc::new(input_schema))
                        .with_title(name);
                if let Some(schema) = output_schema {
                    tool = tool.with_raw_output_schema(Arc::new(schema));
                }
                tools.push(tool);
            }
            Ok(m::ListToolsResult::with_all_items(tools))
        }
    }

    /// Executes an allowlisted tool through the registry.
    ///
    /// Calls to tools outside the allowlist and tool failures are reported as
    /// `CallToolResult::error` with a text message rather than as protocol errors. On
    /// success the formatted text (or a fallback rendering of the JSON data) is returned, and
    /// the raw JSON data is attached as `structured_content` exactly when `list_tools`
    /// advertises an output schema for the tool.
    fn call_tool(
        &self,
        req: m::CallToolRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::CallToolResult, m::ErrorData>> + Send + '_
    {
        async move {
            if !self.is_allowed(&req.name) {
                return Ok(m::CallToolResult::error(vec![m::Content::text(format!(
                    "Tool '{}' not enabled on this server",
                    req.name
                ))]));
            }
            // Absent arguments are dispatched as an empty JSON object.
            let args = serde_json::Value::Object(req.arguments.unwrap_or_default());
            let ctx = ToolContext::default();
            let text_opts = TextOptions::default();
            match self
                .registry
                .dispatch_json_formatted(&req.name, args, &ctx, &text_opts)
                .await
            {
                Ok(res) => {
                    let text = res
                        .text
                        .unwrap_or_else(|| fallback_text_from_json(&res.data));
                    // Same gate as `list_tools`, so structured content is only sent for tools
                    // whose output schema the client was actually given.
                    let structured_content = self
                        .advertised_output_schema(&req.name)
                        .is_some()
                        .then_some(res.data);
                    let mut result = m::CallToolResult::success(vec![m::Content::text(text)]);
                    result.structured_content = structured_content;
                    Ok(result)
                }
                Err(e) => Ok(m::CallToolResult::error(vec![m::Content::text(
                    e.to_string(),
                )])),
            }
        }
    }

    /// Responds to `ping` with an empty success.
    fn ping(
        &self,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
        async { Ok(()) }
    }

    /// Completion is not supported; always answers with an `invalid_request` error.
    fn complete(
        &self,
        _req: m::CompleteRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::CompleteResult, m::ErrorData>> + Send + '_
    {
        async {
            Err(m::ErrorData::invalid_request(
                "Method not implemented",
                None,
            ))
        }
    }

    /// Accepts the request without changing anything (logging is not configurable here).
    fn set_level(
        &self,
        _req: m::SetLevelRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
        async { Ok(()) }
    }

    /// Prompts are not supported; always answers with an `invalid_request` error.
    fn get_prompt(
        &self,
        _req: m::GetPromptRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::GetPromptResult, m::ErrorData>> + Send + '_
    {
        async {
            Err(m::ErrorData::invalid_request(
                "Method not implemented",
                None,
            ))
        }
    }

    /// Returns an empty prompt list.
    fn list_prompts(
        &self,
        _req: Option<m::PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::ListPromptsResult, m::ErrorData>> + Send + '_
    {
        async { Ok(m::ListPromptsResult::with_all_items(vec![])) }
    }

    /// Returns an empty resource list.
    fn list_resources(
        &self,
        _req: Option<m::PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::ListResourcesResult, m::ErrorData>> + Send + '_
    {
        async { Ok(m::ListResourcesResult::with_all_items(vec![])) }
    }

    /// Returns an empty resource-template list.
    fn list_resource_templates(
        &self,
        _req: Option<m::PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::ListResourceTemplatesResult, m::ErrorData>>
    + Send
    + '_ {
        async { Ok(m::ListResourceTemplatesResult::with_all_items(vec![])) }
    }

    /// Resources are not supported; always answers with an `invalid_request` error.
    fn read_resource(
        &self,
        _req: m::ReadResourceRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<m::ReadResourceResult, m::ErrorData>> + Send + '_
    {
        async {
            Err(m::ErrorData::invalid_request(
                "Method not implemented",
                None,
            ))
        }
    }

    /// Resource subscriptions are not supported; always answers with `invalid_request`.
    fn subscribe(
        &self,
        _req: m::SubscribeRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
        async {
            Err(m::ErrorData::invalid_request(
                "Method not implemented",
                None,
            ))
        }
    }

    /// Resource subscriptions are not supported; always answers with `invalid_request`.
    fn unsubscribe(
        &self,
        _req: m::UnsubscribeRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<(), m::ErrorData>> + Send + '_ {
        async {
            Err(m::ErrorData::invalid_request(
                "Method not implemented",
                None,
            ))
        }
    }
}

impl RegistryServer {
    /// Output schema to advertise for tool `name`, if any.
    ///
    /// Returns `None` unless the server is in [`OutputMode::Structured`], the tool exists and
    /// reports an output schema, and that schema serializes to a JSON object (the rmcp tool
    /// model takes schemas as JSON object maps). `list_tools` and `call_tool` both use this,
    /// so the advertised schema and the presence of `structured_content` always agree.
    fn advertised_output_schema(
        &self,
        name: &str,
    ) -> Option<serde_json::Map<String, serde_json::Value>> {
        if !matches!(self.output_mode, OutputMode::Structured) {
            return None;
        }
        self.registry
            .get(name)
            .and_then(|tool| tool.output_schema())
            .and_then(|schema| serde_json::to_value(&schema).ok())
            .and_then(json_object)
    }
}

/// Unwraps a JSON object into its map; any other JSON value yields `None`.
fn json_object(value: serde_json::Value) -> Option<serde_json::Map<String, serde_json::Value>> {
    match value {
        serde_json::Value::Object(map) => Some(map),
        _ => None,
    }
}

/// Permissive fallback input schema: `{"type": "object"}`.
fn default_input_schema() -> serde_json::Map<String, serde_json::Value> {
    let mut schema = serde_json::Map::new();
    schema.insert("type".to_owned(), serde_json::Value::from("object"));
    schema
}
#[cfg(test)]
mod tests {
    use super::*;
    use agentic_tools_core::fmt::TextFormat;
    use agentic_tools_core::{Tool, ToolError};
    use futures::future::BoxFuture;

    /// Registry with no tools, for tests that only exercise server configuration.
    fn empty_registry() -> Arc<ToolRegistry> {
        Arc::new(ToolRegistry::builder().finish())
    }

    /// Registry containing only `TestObjTool`.
    fn obj_tool_registry() -> Arc<ToolRegistry> {
        Arc::new(
            ToolRegistry::builder()
                .register::<TestObjTool, ()>(TestObjTool)
                .finish(),
        )
    }

    #[test]
    fn test_registry_server_allowlist() {
        let server = RegistryServer::new(empty_registry())
            .with_allowlist(["tool_a".to_string(), "tool_b".to_string()]);
        assert!(server.is_allowed("tool_a"));
        assert!(server.is_allowed("tool_b"));
        assert!(!server.is_allowed("tool_c"));
    }

    #[test]
    fn test_registry_server_no_allowlist() {
        let server = RegistryServer::new(empty_registry());
        assert!(server.is_allowed("any_tool"));
    }

    #[test]
    fn test_registry_server_default_info() {
        let server = RegistryServer::new(empty_registry());
        assert_eq!(server.name(), "agentic-tools");
        assert_eq!(server.version(), env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_registry_server_info() {
        let server = RegistryServer::new(empty_registry()).with_info("my-server", "1.0.0");
        assert_eq!(server.name(), "my-server");
        assert_eq!(server.version(), "1.0.0");
    }

    #[test]
    fn test_effective_tool_names_respects_allowlist() {
        let registry = obj_tool_registry();

        let open = RegistryServer::new(registry.clone());
        assert!(open
            .effective_tool_names()
            .contains(&"test_obj_tool".to_string()));

        // Allowlisted names missing from the registry never surface.
        let allowed = RegistryServer::new(registry.clone())
            .with_allowlist(["test_obj_tool".to_string(), "missing_tool".to_string()]);
        assert_eq!(
            allowed.effective_tool_names(),
            vec!["test_obj_tool".to_string()]
        );

        let blocked = RegistryServer::new(registry).with_allowlist(["other_tool".to_string()]);
        assert!(blocked.effective_tool_names().is_empty());
    }

    #[derive(Clone)]
    struct TestObjTool;

    #[derive(
        serde::Serialize, serde::Deserialize, schemars::JsonSchema, Clone, Debug, PartialEq,
    )]
    struct TestObjOut {
        message: String,
    }

    impl TextFormat for TestObjOut {
        fn fmt_text(&self, _opts: &TextOptions) -> String {
            format!("Message: {}", self.message)
        }
    }

    impl Tool for TestObjTool {
        type Input = ();
        type Output = TestObjOut;
        const NAME: &'static str = "test_obj_tool";
        const DESCRIPTION: &'static str = "outputs an object";

        fn call(
            &self,
            _input: (),
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<TestObjOut, ToolError>> {
            Box::pin(async move {
                Ok(TestObjOut {
                    message: "hello".into(),
                })
            })
        }
    }

    #[test]
    fn test_structured_mode_output_schema_gating() {
        let registry = obj_tool_registry();
        let structured_server =
            RegistryServer::new(registry.clone()).with_output_mode(OutputMode::Structured);
        assert!(matches!(
            structured_server.output_mode(),
            OutputMode::Structured
        ));
        let text_server = RegistryServer::new(registry.clone()).with_output_mode(OutputMode::Text);
        assert!(matches!(text_server.output_mode(), OutputMode::Text));
        let tool = registry.get("test_obj_tool").unwrap();
        assert!(
            tool.output_schema().is_some(),
            "TestObjTool should have an output schema"
        );
    }

    #[tokio::test]
    async fn test_structured_mode_structured_content_via_dispatch() {
        let registry = obj_tool_registry();
        let ctx = ToolContext::default();
        let text_opts = TextOptions::default();
        let result = registry
            .dispatch_json_formatted("test_obj_tool", serde_json::json!(null), &ctx, &text_opts)
            .await
            .unwrap();
        assert_eq!(result.data, serde_json::json!({"message": "hello"}));
        assert!(result.text.is_some());
        let tool = registry.get("test_obj_tool").unwrap();
        let has_schema = tool.output_schema().is_some();
        assert!(
            has_schema,
            "Tool should have output schema for structured content"
        );
    }

    #[test]
    fn test_output_mode_default_is_text() {
        let server = RegistryServer::new(empty_registry());
        assert!(matches!(server.output_mode(), OutputMode::Text));
    }
}