use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use mlua_isle::AsyncIsle;
use rmcp::{
handler::client::ClientHandler,
model::{
CreateElicitationRequestParams, CreateElicitationResult, CreateMessageRequestParams,
CreateMessageResult, ElicitationAction, ElicitationCreateRequestMethod, LoggingLevel,
LoggingMessageNotificationParam, ProgressNotificationParam,
ResourceUpdatedNotificationParam, Role, SamplingMessage, SamplingMessageContent,
},
service::{NotificationContext, RequestContext, RoleClient},
ErrorData as McpError,
};
use tokio::sync::mpsc;
// Handler-isle globals: per-server Lua handlers for server -> client *requests*.

/// Lua global table mapping server name -> sampling handler function.
pub(crate) const MCP_SAMPLING_HANDLERS: &str = "__mcp_sampling_handlers";
/// Lua global table mapping server name -> roots handler function.
pub(crate) const MCP_ROOTS_HANDLERS: &str = "__mcp_roots_handlers";
/// Lua global table mapping server name -> elicitation handler function.
pub(crate) const MCP_ELICITATION_HANDLERS: &str = "__mcp_elicitation_handlers";

/// Lua global holding the sampling dispatcher function.
const MCP_DISPATCH_SAMPLING: &str = "__mcp_dispatch_sampling";
/// Lua global holding the roots dispatcher function.
const MCP_DISPATCH_ROOTS: &str = "__mcp_dispatch_roots";
/// Lua global holding the elicitation dispatcher function.
const MCP_DISPATCH_ELICITATION: &str = "__mcp_dispatch_elicitation";

// Main-isle globals: per-server user callbacks for server -> client *notifications*.

/// Lua global table mapping server name -> progress callback.
pub(crate) const MCP_USER_PROGRESS_CBS: &str = "__mcp_user_progress_cbs";
/// Lua global table mapping server name -> log callback.
pub(crate) const MCP_USER_LOG_CBS: &str = "__mcp_user_log_cbs";
/// Lua global table mapping server name -> resource-updated callback.
pub(crate) const MCP_USER_RESOURCE_UPDATE_CBS: &str = "__mcp_user_resource_update_cbs";
/// Lua global table mapping server name -> resources-list-changed callback.
pub(crate) const MCP_USER_RESOURCES_LIST_CHANGED_CBS: &str = "__mcp_user_resources_list_changed_cbs";
/// Lua global table mapping server name -> tools-list-changed callback.
pub(crate) const MCP_USER_TOOLS_LIST_CHANGED_CBS: &str = "__mcp_user_tools_list_changed_cbs";
/// Lua global table mapping server name -> prompts-list-changed callback.
pub(crate) const MCP_USER_PROMPTS_LIST_CHANGED_CBS: &str = "__mcp_user_prompts_list_changed_cbs";

/// Capacity of the ordered notification queue; notifications arriving while
/// it is full are dropped (the producers use `try_send`).
const NOTIFY_CHANNEL_CAPACITY: usize = 128;
/// Deferred constructor for the Lua event table handed to a user callback.
///
/// It is invoked on the target isle's Lua state and receives the server name.
type BuildEvFn = Box<dyn FnOnce(&mlua::Lua, &str) -> mlua::Result<mlua::Table> + Send + 'static>;

/// A notification queued for the dispatch task started by
/// `AgentBlockClientHandler::start_dispatch_task`.
pub(crate) struct NotificationItem {
    /// Server the notification came from; also the key looked up in `cbs_table`.
    pub(crate) server_name: String,
    /// Name of the Lua global table mapping server names to callbacks.
    pub(crate) cbs_table: &'static str,
    /// Isle whose Lua state holds `cbs_table`.
    pub(crate) isle: Arc<AsyncIsle>,
    /// Builds the event table passed to the callback.
    pub(crate) build_ev: BuildEvFn,
    /// Label used in log output (e.g. `"on_progress"`).
    pub(crate) caller: &'static str,
}
/// Per-server record of which client-side handlers/callbacks Lua registered.
///
/// Every flag starts out `false`; the `mark_*` and `set_trace_context`
/// methods on `AgentBlockClientHandler` switch them on.
#[derive(Debug, Clone, Default)]
pub(crate) struct ServerHandlerRegistry {
    /// A progress notification callback is registered.
    pub(crate) on_progress: bool,
    /// A log notification callback is registered.
    pub(crate) on_log: bool,
    /// A resource-updated notification callback is registered.
    pub(crate) on_resource_updated: bool,
    /// A resources-list-changed notification callback is registered.
    pub(crate) on_resource_list_changed: bool,
    /// A tools-list-changed notification callback is registered.
    pub(crate) on_tool_list_changed: bool,
    /// A prompts-list-changed notification callback is registered.
    pub(crate) on_prompt_list_changed: bool,
    /// A sampling (`create_message`) request handler is registered.
    pub(crate) sampling: bool,
    /// A `roots/list` request handler is registered.
    pub(crate) roots: bool,
    /// An elicitation request handler is registered.
    pub(crate) elicitation: bool,
    /// Trace-context propagation is enabled for this server.
    pub(crate) trace_context: bool,
}

impl ServerHandlerRegistry {
    /// Creates an entry with every flag cleared (same as `Default`).
    fn new() -> Self {
        Self::default()
    }
}
/// MCP client handler that forwards server -> client traffic to Lua.
///
/// Clones share the `registry` `Arc`, so flags recorded through one clone are
/// visible through every other clone.
#[derive(Clone)]
pub struct AgentBlockClientHandler {
    /// Which Lua handlers/callbacks each server has, keyed by server name.
    pub(crate) registry: Arc<Mutex<HashMap<String, ServerHandlerRegistry>>>,
    /// Isle that runs Lua handlers for server -> client requests
    /// (sampling, roots, elicitation).
    pub(crate) handler_isle: Option<Arc<AsyncIsle>>,
    /// Isle that runs user notification callbacks.
    pub(crate) main_isle: Option<Arc<AsyncIsle>>,
    /// Server this handler answers for. When unset, notifications are ignored
    /// and requests are refused.
    pub(crate) server_name: Option<String>,
    /// Sender of the ordered notification queue. `None` until
    /// `start_dispatch_task` runs; notifications then fall back to one spawned
    /// task each.
    pub(crate) notify_tx: Option<mpsc::Sender<NotificationItem>>,
}

impl AgentBlockClientHandler {
    /// Creates a handler with an empty registry and no isles, server name or queue.
    pub fn new() -> Self {
        Self {
            registry: Arc::new(Mutex::new(HashMap::new())),
            handler_isle: None,
            main_isle: None,
            server_name: None,
            notify_tx: None,
        }
    }

    /// Locks the registry, recovering the guard if a previous holder panicked.
    ///
    /// Critical sections on this map only insert default entries, write
    /// single flags, or read flags. The data is therefore still usable after a
    /// poisoning panic.
    fn registry_guard(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, ServerHandlerRegistry>> {
        self.registry.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Applies `update` to the entry for `server_name`, inserting an
    /// all-`false` entry first if the server is not yet known.
    fn with_server_entry(&self, server_name: &str, update: impl FnOnce(&mut ServerHandlerRegistry)) {
        let mut guard = self.registry_guard();
        let entry = guard
            .entry(server_name.to_string())
            .or_insert_with(ServerHandlerRegistry::new);
        update(entry);
    }

    /// Starts the background task that delivers queued notifications to Lua
    /// one at a time, in the order they were enqueued.
    ///
    /// Must be called inside a Tokio runtime (uses `tokio::spawn`). Only this
    /// instance, and clones made afterwards, hold the new sender. The task ends
    /// once every sender has been dropped. Callback errors and isle exec
    /// failures are logged and do not stop the loop.
    pub(crate) fn start_dispatch_task(&mut self) {
        let (tx, mut rx) = mpsc::channel::<NotificationItem>(NOTIFY_CHANNEL_CAPACITY);
        self.notify_tx = Some(tx);
        tokio::spawn(async move {
            while let Some(item) = rx.recv().await {
                let sn = item.server_name.clone();
                let caller = item.caller;
                let result = item
                    .isle
                    .exec(move |lua| {
                        use mlua::prelude::*;
                        // No callback table, or no callback for this server:
                        // nothing to deliver.
                        let cbs: LuaTable = match lua.globals().get(item.cbs_table) {
                            Ok(t) => t,
                            Err(_) => return Ok(String::new()),
                        };
                        let cb: LuaFunction = match cbs.get(item.server_name.as_str()) {
                            Ok(f) => f,
                            Err(_) => return Ok(String::new()),
                        };
                        let ev = (item.build_ev)(lua, item.server_name.as_str()).map_err(|e| {
                            mlua_isle::IsleError::Lua(format!("{caller}: build_ev: {e}"))
                        })?;
                        // A failing user callback is reported but does not fail the exec.
                        if let Err(e) = cb.call::<()>(ev) {
                            tracing::warn!(
                                target: "mcp_client",
                                server = %item.server_name,
                                caller = %caller,
                                error = %e,
                                "user callback returned error"
                            );
                        }
                        Ok(String::new())
                    })
                    .await;
                if let Err(e) = result {
                    tracing::warn!(
                        target: "mcp_client",
                        server = %sn,
                        caller = %caller,
                        error = %e,
                        "notification dispatch: main isle exec failed"
                    );
                }
            }
        });
    }

    /// Makes sure `server_name` has a registry entry (all flags `false` if new).
    pub(crate) fn ensure_server(&self, server_name: &str) {
        self.with_server_entry(server_name, |_| {});
    }

    /// Records that `server_name` has a Lua progress callback.
    pub(crate) fn mark_on_progress(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_progress = true);
    }

    /// Records that `server_name` has a Lua log callback.
    pub(crate) fn mark_on_log(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_log = true);
    }

    /// Records that `server_name` has a Lua resource-updated callback.
    pub(crate) fn mark_on_resource_updated(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_resource_updated = true);
    }

    /// Records that `server_name` has a Lua resources-list-changed callback.
    pub(crate) fn mark_on_resource_list_changed(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_resource_list_changed = true);
    }

    /// Records that `server_name` has a Lua tools-list-changed callback.
    pub(crate) fn mark_on_tool_list_changed(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_tool_list_changed = true);
    }

    /// Records that `server_name` has a Lua prompts-list-changed callback.
    pub(crate) fn mark_on_prompt_list_changed(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.on_prompt_list_changed = true);
    }

    /// Enables or disables trace-context propagation for `server_name`.
    pub(crate) fn set_trace_context(&self, server_name: &str, enabled: bool) {
        self.with_server_entry(server_name, |r| r.trace_context = enabled);
    }

    /// Returns whether trace context is enabled for `server_name`
    /// (`false` for unknown servers).
    pub(crate) fn trace_context_enabled(&self, server_name: &str) -> bool {
        let guard = self.registry_guard();
        guard.get(server_name).is_some_and(|r| r.trace_context)
    }

    /// Records that `server_name` has a Lua sampling handler.
    pub(crate) fn mark_sampling(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.sampling = true);
    }

    /// Records that `server_name` has a Lua roots handler.
    pub(crate) fn mark_roots(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.roots = true);
    }

    /// Records that `server_name` has a Lua elicitation handler.
    pub(crate) fn mark_elicitation(&self, server_name: &str) {
        self.with_server_entry(server_name, |r| r.elicitation = true);
    }
}

impl Default for AgentBlockClientHandler {
    fn default() -> Self {
        Self::new()
    }
}
pub fn install_mcp_dispatcher_on_handler_isle(lua: &mlua::Lua) -> mlua::Result<()> {
use mlua::prelude::*;
lua.globals()
.set(MCP_SAMPLING_HANDLERS, lua.create_table()?)?;
let sampling_src = r#"
local HANDLERS = "__mcp_sampling_handlers"
return function(server_name, params_json)
local handlers = _G[HANDLERS]
local h = handlers and handlers[server_name]
if type(h) ~= "function" then
return nil -- signal: no handler registered
end
return h(server_name, params_json)
end
"#;
let dispatch_sampling: LuaFunction = lua
.load(sampling_src)
.set_name("@agent_block:__mcp_dispatch_sampling")
.eval()?;
lua.globals()
.set(MCP_DISPATCH_SAMPLING, dispatch_sampling)?;
lua.globals().set(MCP_ROOTS_HANDLERS, lua.create_table()?)?;
let roots_src = r#"
local HANDLERS = "__mcp_roots_handlers"
return function(server_name)
local handlers = _G[HANDLERS]
local h = handlers and handlers[server_name]
if type(h) ~= "function" then
return nil -- signal: no handler registered
end
return h(server_name)
end
"#;
let dispatch_roots: LuaFunction = lua
.load(roots_src)
.set_name("@agent_block:__mcp_dispatch_roots")
.eval()?;
lua.globals().set(MCP_DISPATCH_ROOTS, dispatch_roots)?;
lua.globals()
.set(MCP_ELICITATION_HANDLERS, lua.create_table()?)?;
let elicitation_src = r#"
local HANDLERS = "__mcp_elicitation_handlers"
return function(server_name, message, schema_json)
local handlers = _G[HANDLERS]
local h = handlers and handlers[server_name]
if type(h) ~= "function" then
return nil -- signal: no handler registered → Decline
end
return h(server_name, message, schema_json)
end
"#;
let dispatch_elicitation: LuaFunction = lua
.load(elicitation_src)
.set_name("@agent_block:__mcp_dispatch_elicitation")
.eval()?;
lua.globals()
.set(MCP_DISPATCH_ELICITATION, dispatch_elicitation)?;
Ok(())
}
/// Fire-and-forget delivery of one notification to a user Lua callback.
///
/// Spawns a Tokio task that, on `isle`, looks up
/// `_G[cbs_table][server_name]`. If either lookup fails (no table, or no
/// callback for the server) the notification is silently ignored. Otherwise
/// the event is built with `build_ev` and passed to the callback.
///
/// A callback error is logged and swallowed. A `build_ev` error or an isle
/// exec failure is logged with `caller` as context. Each call spawns its own
/// task, so separate notifications may not be delivered in arrival order.
fn isle_dispatch<F>(
    isle: Arc<AsyncIsle>,
    server_name: String,
    cbs_table: &'static str,
    build_ev: F,
    caller: &'static str,
) where
    F: FnOnce(&mlua::Lua, &str) -> mlua::Result<mlua::Table> + Send + 'static,
{
    let server_for_log = server_name.clone();
    let task = async move {
        let outcome = isle
            .exec(move |lua| {
                use mlua::prelude::*;
                let callbacks: Option<LuaTable> = lua.globals().get(cbs_table).ok();
                let Some(callbacks) = callbacks else {
                    return Ok(String::new());
                };
                let callback: Option<LuaFunction> = callbacks.get(server_name.as_str()).ok();
                let Some(callback) = callback else {
                    return Ok(String::new());
                };
                let event = build_ev(lua, server_name.as_str())
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("{caller}: build_ev: {e}")))?;
                if let Err(err) = callback.call::<()>(event) {
                    tracing::warn!(
                        target: "mcp_client",
                        server = %server_name,
                        caller = %caller,
                        error = %err,
                        "user callback returned error"
                    );
                }
                Ok(String::new())
            })
            .await;
        if let Err(err) = outcome {
            tracing::warn!(
                target: "mcp_client",
                server = %server_for_log,
                error = %err,
                "{}: main isle exec failed",
                caller
            );
        }
    };
    tokio::spawn(task);
}
impl ClientHandler for AgentBlockClientHandler {
fn on_progress(
&self,
params: ProgressNotificationParam,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name_opt = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let main_isle = match main_isle {
Some(i) => i,
None => return, };
let server_name = match server_name_opt {
Some(s) => s,
None => return, };
let has_cb = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard.get(&server_name).is_some_and(|r| r.on_progress)
};
if !has_cb {
return;
}
let token_str = match ¶ms.progress_token.0 {
rmcp::model::NumberOrString::Number(n) => n.to_string(),
rmcp::model::NumberOrString::String(s) => s.to_string(),
};
let progress_f64: f64 = params.progress;
let total_opt: Option<f64> = params.total;
let message_opt: Option<String> = params.message;
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle: main_isle,
server_name,
cbs_table: MCP_USER_PROGRESS_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "progress")?;
ev.set("server", server_for_task)?;
ev.set("token", token_str.as_str())?;
ev.set("progress", progress_f64)?;
if let Some(t) = total_opt {
ev.set("total", t)?;
}
if let Some(ref m) = message_opt {
ev.set("message", m.as_str())?;
}
Ok(ev)
}),
caller: "on_progress",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_progress: notification channel full, dropping notification \
(server is emitting faster than Lua can consume)"
);
}
} else {
isle_dispatch(
main_isle,
server_name,
MCP_USER_PROGRESS_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "progress")?;
ev.set("server", server_for_task)?;
ev.set("token", token_str.as_str())?;
ev.set("progress", progress_f64)?;
if let Some(t) = total_opt {
ev.set("total", t)?;
}
if let Some(ref m) = message_opt {
ev.set("message", m.as_str())?;
}
Ok(ev)
},
"on_progress",
);
}
}
}
fn on_logging_message(
&self,
params: LoggingMessageNotificationParam,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let level = ¶ms.level;
let logger = params.logger.as_deref().unwrap_or("").to_string();
let data_str = match serde_json::to_string(¶ms.data) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_logging_message: failed to serialize data"
);
return;
}
};
let level_str = match level {
LoggingLevel::Debug => "debug",
LoggingLevel::Info | LoggingLevel::Notice => "info",
LoggingLevel::Warning => "warning",
LoggingLevel::Error
| LoggingLevel::Critical
| LoggingLevel::Alert
| LoggingLevel::Emergency => "error",
}
.to_string();
let sn_str = server_name.as_deref().unwrap_or("unknown").to_string();
let has_lua_handler = server_name.as_deref().is_some_and(|sn| {
registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.get(sn)
.is_some_and(|r| r.on_log)
});
if has_lua_handler {
if let (Some(isle), Some(sn)) = (main_isle, server_name) {
let level_task = level_str.clone();
let logger_task = logger.clone();
let data_task = data_str.clone();
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle,
server_name: sn,
cbs_table: MCP_USER_LOG_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "log")?;
ev.set("server", server_for_task)?;
ev.set("level", level_task.as_str())?;
ev.set("logger", logger_task.as_str())?;
ev.set("data", data_task.as_str())?;
Ok(ev)
}),
caller: "on_logging_message",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_logging_message: notification channel full, dropping notification"
);
}
} else {
isle_dispatch(
isle,
sn,
MCP_USER_LOG_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "log")?;
ev.set("server", server_for_task)?;
ev.set("level", level_task.as_str())?;
ev.set("logger", logger_task.as_str())?;
ev.set("data", data_task.as_str())?;
Ok(ev)
},
"on_logging_message",
);
}
return;
}
}
match level {
LoggingLevel::Debug => {
tracing::debug!(
target: "lua",
script = "mcp_server",
server = %sn_str,
logger = %logger,
"{}",
data_str
);
}
LoggingLevel::Info | LoggingLevel::Notice => {
tracing::info!(
target: "lua",
script = "mcp_server",
server = %sn_str,
logger = %logger,
"{}",
data_str
);
}
LoggingLevel::Warning => {
tracing::warn!(
target: "lua",
script = "mcp_server",
server = %sn_str,
logger = %logger,
"{}",
data_str
);
}
LoggingLevel::Error
| LoggingLevel::Critical
| LoggingLevel::Alert
| LoggingLevel::Emergency => {
tracing::error!(
target: "lua",
script = "mcp_server",
server = %sn_str,
logger = %logger,
"{}",
data_str
);
}
}
}
}
fn on_resource_updated(
&self,
params: ResourceUpdatedNotificationParam,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name_opt = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let main_isle = match main_isle {
Some(i) => i,
None => return,
};
let server_name = match server_name_opt {
Some(s) => s,
None => return,
};
let has_cb = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard
.get(&server_name)
.is_some_and(|r| r.on_resource_updated)
};
if !has_cb {
return;
}
let uri = params.uri.clone();
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle: main_isle,
server_name,
cbs_table: MCP_USER_RESOURCE_UPDATE_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "resource_update")?;
ev.set("server", server_for_task)?;
ev.set("uri", uri.as_str())?;
Ok(ev)
}),
caller: "on_resource_updated",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_resource_updated: notification channel full, dropping notification \
(server is emitting faster than Lua can consume)"
);
}
} else {
isle_dispatch(
main_isle,
server_name,
MCP_USER_RESOURCE_UPDATE_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "resource_update")?;
ev.set("server", server_for_task)?;
ev.set("uri", uri.as_str())?;
Ok(ev)
},
"on_resource_updated",
);
}
}
}
fn on_resource_list_changed(
&self,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name_opt = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let main_isle = match main_isle {
Some(i) => i,
None => return,
};
let server_name = match server_name_opt {
Some(s) => s,
None => return,
};
let has_cb = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard
.get(&server_name)
.is_some_and(|r| r.on_resource_list_changed)
};
if !has_cb {
return;
}
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle: main_isle,
server_name,
cbs_table: MCP_USER_RESOURCES_LIST_CHANGED_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "resources_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
}),
caller: "on_resource_list_changed",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_resource_list_changed: notification channel full, dropping notification"
);
}
} else {
isle_dispatch(
main_isle,
server_name,
MCP_USER_RESOURCES_LIST_CHANGED_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "resources_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
},
"on_resource_list_changed",
);
}
}
}
fn on_tool_list_changed(
&self,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name_opt = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let main_isle = match main_isle {
Some(i) => i,
None => return,
};
let server_name = match server_name_opt {
Some(s) => s,
None => return,
};
let has_cb = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard
.get(&server_name)
.is_some_and(|r| r.on_tool_list_changed)
};
if !has_cb {
return;
}
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle: main_isle,
server_name,
cbs_table: MCP_USER_TOOLS_LIST_CHANGED_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "tools_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
}),
caller: "on_tool_list_changed",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_tool_list_changed: notification channel full, dropping notification"
);
}
} else {
isle_dispatch(
main_isle,
server_name,
MCP_USER_TOOLS_LIST_CHANGED_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "tools_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
},
"on_tool_list_changed",
);
}
}
}
fn on_prompt_list_changed(
&self,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + Send + '_ {
let main_isle = self.main_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name_opt = self.server_name.clone();
let notify_tx = self.notify_tx.clone();
async move {
let main_isle = match main_isle {
Some(i) => i,
None => return,
};
let server_name = match server_name_opt {
Some(s) => s,
None => return,
};
let has_cb = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard
.get(&server_name)
.is_some_and(|r| r.on_prompt_list_changed)
};
if !has_cb {
return;
}
if let Some(tx) = notify_tx {
let item = NotificationItem {
isle: main_isle,
server_name,
cbs_table: MCP_USER_PROMPTS_LIST_CHANGED_CBS,
build_ev: Box::new(move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "prompts_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
}),
caller: "on_prompt_list_changed",
};
if let Err(e) = tx.try_send(item) {
tracing::warn!(
target: "mcp_client",
error = %e,
"on_prompt_list_changed: notification channel full, dropping notification"
);
}
} else {
isle_dispatch(
main_isle,
server_name,
MCP_USER_PROMPTS_LIST_CHANGED_CBS,
move |lua, server_for_task| {
let ev = lua.create_table()?;
ev.set("type", "prompts_list_changed")?;
ev.set("server", server_for_task)?;
Ok(ev)
},
"on_prompt_list_changed",
);
}
}
}
/// Handles a server sampling request by calling this server's Lua sampling
/// handler on the handler isle.
///
/// The request parameters reach the handler as a JSON string. The handler's
/// returned table is read as
/// `{ content = <string>, model = <string>?, role = "user"|"assistant"?, stop_reason = <string>? }`.
/// These defaults apply:
/// * a missing `model` becomes `"unknown"`;
/// * any role other than `"user"` becomes assistant;
/// * a missing or non-string `content` becomes empty text.
///
/// # Errors
/// * `method_not_found` when there is no server name, the server has no
///   sampling handler flag, no handler isle is configured, or the Lua
///   dispatcher returns `nil`.
/// * `internal_error` when serializing the request, running the handler, or
///   parsing its result fails.
fn create_message(
    &self,
    params: CreateMessageRequestParams,
    _context: RequestContext<RoleClient>,
) -> impl std::future::Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ {
    let isle = self.handler_isle.clone();
    let registry = Arc::clone(&self.registry);
    let server_name = self.server_name.clone();
    async move {
        let not_found =
            || McpError::method_not_found::<rmcp::model::CreateMessageRequestMethod>();
        let Some(sn) = server_name else {
            return Err(not_found());
        };
        let has_sampling = {
            let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
            guard.get(&sn).is_some_and(|r| r.sampling)
        };
        if !has_sampling {
            return Err(not_found());
        }
        let Some(isle) = isle else {
            return Err(not_found());
        };
        let params_json = serde_json::to_string(&params).map_err(|e| {
            tracing::warn!(
                target: "mcp_client",
                server = %sn,
                error = %e,
                "create_message: failed to serialize params"
            );
            McpError::internal_error(format!("create_message serialize: {e}"), None)
        })?;
        let sn_task = sn.clone();
        // An empty string from the isle means the dispatcher returned nil (no handler).
        let result_json = isle
            .exec(move |lua| {
                use mlua::prelude::*;
                let dispatch: LuaFunction =
                    lua.globals().get(MCP_DISPATCH_SAMPLING).map_err(|e| {
                        mlua_isle::IsleError::Lua(format!(
                            "create_message: get dispatcher: {e}"
                        ))
                    })?;
                let result: LuaValue = dispatch
                    .call((sn_task.as_str(), params_json.as_str()))
                    .map_err(|e| {
                        mlua_isle::IsleError::Lua(format!("create_message: dispatch: {e}"))
                    })?;
                match result {
                    LuaValue::Nil => Ok(String::new()),
                    LuaValue::Table(tbl) => {
                        let json_val = crate::bridge::lua_to_json(lua, LuaValue::Table(tbl))
                            .map_err(|e| {
                                mlua_isle::IsleError::Lua(format!(
                                    "create_message: lua_to_json: {e}"
                                ))
                            })?;
                        serde_json::to_string(&json_val).map_err(|e| {
                            mlua_isle::IsleError::Lua(format!("create_message: to_string: {e}"))
                        })
                    }
                    other => Err(mlua_isle::IsleError::Lua(format!(
                        "create_message: handler must return table or nil, got: {:?}",
                        other.type_name()
                    ))),
                }
            })
            .await;
        match result_json {
            Err(e) => {
                tracing::warn!(
                    target: "mcp_client",
                    server = %sn,
                    error = %e,
                    "create_message: handler isle error"
                );
                Err(McpError::internal_error(
                    format!("sampling handler: {e}"),
                    None,
                ))
            }
            Ok(json_str) if json_str.is_empty() => Err(not_found()),
            Ok(json_str) => {
                let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
                    McpError::internal_error(
                        format!("sampling handler result parse: {e}"),
                        None,
                    )
                })?;
                let str_field = |key: &str| v.get(key).and_then(serde_json::Value::as_str);
                let model = str_field("model").unwrap_or("unknown").to_string();
                let role = match str_field("role") {
                    Some("user") => Role::User,
                    _ => Role::Assistant,
                };
                let content_str = str_field("content").unwrap_or("").to_string();
                let message =
                    SamplingMessage::new(role, SamplingMessageContent::text(content_str));
                let mut result = CreateMessageResult::new(message, model);
                if let Some(sr) = str_field("stop_reason") {
                    result = result.with_stop_reason(sr.to_string());
                }
                Ok(result)
            }
        }
    }
}
/// Handles a server `roots/list` request by calling this server's Lua roots
/// handler on the handler isle.
///
/// The handler must return a list of `{ uri = <string>, name = <string>? }`
/// tables. An empty Lua table is accepted as "no roots" even if the JSON
/// bridge encodes it as `{}`.
///
/// # Errors
/// * `method_not_found` when there is no server name, the server has no roots
///   handler flag, no handler isle is configured, or the dispatcher returns `nil`.
/// * `internal_error` when the handler fails, the result is not a list, or an
///   entry lacks a non-empty string `uri`. Previously such an entry was sent
///   to the server as an empty URI.
fn list_roots(
    &self,
    _context: RequestContext<RoleClient>,
) -> impl std::future::Future<Output = Result<rmcp::model::ListRootsResult, McpError>> + Send + '_
{
    let isle = self.handler_isle.clone();
    let registry = Arc::clone(&self.registry);
    let server_name = self.server_name.clone();
    async move {
        let not_found = || McpError::method_not_found::<rmcp::model::ListRootsRequestMethod>();
        let Some(sn) = server_name else {
            return Err(not_found());
        };
        let has_roots = {
            let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
            guard.get(&sn).is_some_and(|r| r.roots)
        };
        if !has_roots {
            return Err(not_found());
        }
        let Some(isle) = isle else {
            return Err(not_found());
        };
        let sn_task = sn.clone();
        // An empty string from the isle means the dispatcher returned nil (no handler).
        let result_val = isle
            .exec(move |lua| {
                use mlua::prelude::*;
                let dispatch: LuaFunction =
                    lua.globals().get(MCP_DISPATCH_ROOTS).map_err(|e| {
                        mlua_isle::IsleError::Lua(format!("list_roots: get dispatcher: {e}"))
                    })?;
                let result: LuaValue = dispatch.call(sn_task.as_str()).map_err(|e| {
                    mlua_isle::IsleError::Lua(format!("list_roots: dispatch: {e}"))
                })?;
                match result {
                    LuaValue::Nil => Ok(String::new()),
                    LuaValue::Table(tbl) => {
                        let json_val = crate::bridge::lua_to_json(lua, LuaValue::Table(tbl))
                            .map_err(|e| {
                                mlua_isle::IsleError::Lua(format!(
                                    "list_roots: lua_to_json: {e}"
                                ))
                            })?;
                        serde_json::to_string(&json_val).map_err(|e| {
                            mlua_isle::IsleError::Lua(format!("list_roots: to_string: {e}"))
                        })
                    }
                    other => Err(mlua_isle::IsleError::Lua(format!(
                        "list_roots: handler must return table or nil, got: {:?}",
                        other.type_name()
                    ))),
                }
            })
            .await;
        match result_val {
            Err(e) => {
                tracing::warn!(
                    target: "mcp_client",
                    server = %sn,
                    error = %e,
                    "list_roots: handler isle error"
                );
                Err(McpError::internal_error(
                    format!("roots handler: {e}"),
                    None,
                ))
            }
            Ok(json_str) if json_str.is_empty() => Err(not_found()),
            Ok(json_str) => {
                let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
                    McpError::internal_error(format!("roots handler result parse: {e}"), None)
                })?;
                // An empty Lua table has no array/object hint; treat `{}` as no roots.
                let entries: &[serde_json::Value] = match &v {
                    serde_json::Value::Array(items) => items.as_slice(),
                    serde_json::Value::Object(map) if map.is_empty() => &[],
                    _ => {
                        return Err(McpError::internal_error(
                            "roots handler result parse: expected array".to_string(),
                            None,
                        ));
                    }
                };
                let roots = entries
                    .iter()
                    .map(|entry| -> Result<rmcp::model::Root, McpError> {
                        let uri = entry
                            .get("uri")
                            .and_then(serde_json::Value::as_str)
                            .filter(|u| !u.is_empty())
                            .ok_or_else(|| {
                                McpError::internal_error(
                                    "roots handler result: each root needs a non-empty string 'uri'"
                                        .to_string(),
                                    None,
                                )
                            })?;
                        let root = rmcp::model::Root::new(uri.to_string());
                        Ok(match entry.get("name").and_then(serde_json::Value::as_str) {
                            Some(name) => root.with_name(name.to_string()),
                            None => root,
                        })
                    })
                    .collect::<Result<Vec<_>, McpError>>()?;
                Ok(rmcp::model::ListRootsResult::new(roots))
            }
        }
    }
}
fn create_elicitation(
&self,
request: CreateElicitationRequestParams,
_context: RequestContext<RoleClient>,
) -> impl std::future::Future<Output = Result<CreateElicitationResult, McpError>> + Send + '_
{
let isle = self.handler_isle.clone();
let registry = Arc::clone(&self.registry);
let server_name = self.server_name.clone();
async move {
let (message, requested_schema) = match request {
CreateElicitationRequestParams::UrlElicitationParams { .. } => {
return Ok(CreateElicitationResult {
action: ElicitationAction::Decline,
content: None,
meta: None,
});
}
CreateElicitationRequestParams::FormElicitationParams {
message,
requested_schema,
..
} => (message, requested_schema),
};
let sn = match server_name.as_deref() {
Some(s) => s.to_string(),
None => {
return Err(McpError::method_not_found::<ElicitationCreateRequestMethod>());
}
};
let has_elicitation = {
let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
guard.get(&sn).is_some_and(|r| r.elicitation)
};
if !has_elicitation {
return Ok(CreateElicitationResult {
action: ElicitationAction::Decline,
content: None,
meta: None,
});
}
let isle = match isle {
Some(i) => i,
None => {
return Err(McpError::method_not_found::<ElicitationCreateRequestMethod>());
}
};
let schema_json = serde_json::to_string(&requested_schema).map_err(|e| {
McpError::internal_error(format!("create_elicitation: schema serialize: {e}"), None)
})?;
let sn_task = sn.clone();
let message_task = message.clone();
let result_val = isle
.exec(move |lua| {
use mlua::prelude::*;
let dispatch: LuaFunction =
lua.globals().get(MCP_DISPATCH_ELICITATION).map_err(|e| {
mlua_isle::IsleError::Lua(format!(
"create_elicitation: get dispatcher: {e}"
))
})?;
let result: LuaValue = dispatch
.call((
sn_task.as_str(),
message_task.as_str(),
schema_json.as_str(),
))
.map_err(|e| {
mlua_isle::IsleError::Lua(format!("create_elicitation: dispatch: {e}"))
})?;
match result {
LuaValue::Nil => Ok(String::new()),
LuaValue::Table(tbl) => {
let json_val = crate::bridge::lua_to_json(lua, LuaValue::Table(tbl))
.map_err(|e| {
mlua_isle::IsleError::Lua(format!(
"create_elicitation: lua_to_json: {e}"
))
})?;
serde_json::to_string(&json_val).map_err(|e| {
mlua_isle::IsleError::Lua(format!(
"create_elicitation: to_string: {e}"
))
})
}
other => Err(mlua_isle::IsleError::Lua(format!(
"create_elicitation: handler must return table or nil, got: {:?}",
other.type_name()
))),
}
})
.await;
match result_val {
Err(e) => {
tracing::warn!(
target: "mcp_client",
server = %sn,
error = %e,
"create_elicitation: handler isle error"
);
Err(McpError::internal_error(
format!("elicitation handler: {e}"),
None,
))
}
Ok(json_str) if json_str.is_empty() => {
Ok(CreateElicitationResult {
action: ElicitationAction::Decline,
content: None,
meta: None,
})
}
Ok(json_str) => {
let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
McpError::internal_error(
format!("elicitation handler result parse: {e}"),
None,
)
})?;
let action_str = v
.get("action")
.and_then(serde_json::Value::as_str)
.ok_or_else(|| {
McpError::internal_error(
"elicitation handler result: missing or non-string 'action' field"
.to_string(),
None,
)
})?;
let content = v.get("content").cloned();
match action_str {
"accept" => {
if content.is_none() {
tracing::warn!(
target: "mcp_client",
server = %sn,
"create_elicitation: action=accept but content is nil"
);
return Err(McpError::internal_error(
"elicitation handler: action=accept but content is nil"
.to_string(),
None,
));
}
Ok(CreateElicitationResult {
action: ElicitationAction::Accept,
content,
meta: None,
})
}
"decline" => {
if content.is_some() {
tracing::warn!(
target: "mcp_client",
server = %sn,
"create_elicitation: action=decline but content is non-nil"
);
return Err(McpError::internal_error(
"elicitation handler: action=decline but content is non-nil"
.to_string(),
None,
));
}
Ok(CreateElicitationResult {
action: ElicitationAction::Decline,
content: None,
meta: None,
})
}
"cancel" => {
if content.is_some() {
tracing::warn!(
target: "mcp_client",
server = %sn,
"create_elicitation: action=cancel but content is non-nil"
);
return Err(McpError::internal_error(
"elicitation handler: action=cancel but content is non-nil"
.to_string(),
None,
));
}
Ok(CreateElicitationResult {
action: ElicitationAction::Cancel,
content: None,
meta: None,
})
}
other => {
tracing::warn!(
target: "mcp_client",
server = %sn,
action = %other,
"create_elicitation: unknown action"
);
Err(McpError::internal_error(
format!("elicitation handler: unknown action: {other}"),
None,
))
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `map_err` adapter that wraps an `mlua::Error` into
    /// `mlua_isle::IsleError::Lua`, with the message formatted as `"<ctx>: <lua error>"`.
    ///
    /// Used inside `AsyncIsle::exec` closures so each failing step reports
    /// which step failed without repeating the conversion closure.
    fn isle_lua_err(ctx: &'static str) -> impl FnOnce(mlua::Error) -> mlua_isle::IsleError {
        move |e| mlua_isle::IsleError::Lua(format!("{ctx}: {e}"))
    }

    /// A freshly constructed handler starts with no per-server registry entries.
    #[test]
    fn new_handler_has_empty_registry() {
        let handler = AgentBlockClientHandler::new();
        let guard = handler.registry.lock().unwrap();
        assert!(guard.is_empty());
    }

    /// A freshly constructed handler is not bound to any server name.
    #[test]
    fn new_handler_has_no_server_name() {
        let handler = AgentBlockClientHandler::new();
        assert!(handler.server_name.is_none());
    }

    /// `Clone` carries the `server_name` field over to the copy.
    #[test]
    fn server_name_is_preserved_through_clone() {
        let mut handler = AgentBlockClientHandler::new();
        handler.server_name = Some("srv-a".to_string());
        let cloned = handler.clone();
        assert_eq!(cloned.server_name.as_deref(), Some("srv-a"));
    }

    /// `ensure_server` inserts a registry entry keyed by the server name.
    #[test]
    fn ensure_server_creates_entry() {
        let handler = AgentBlockClientHandler::new();
        handler.ensure_server("my-server");
        let guard = handler.registry.lock().unwrap();
        assert!(guard.contains_key("my-server"));
    }

    /// Calling `ensure_server` twice for the same name yields a single entry.
    #[test]
    fn ensure_server_idempotent() {
        let handler = AgentBlockClientHandler::new();
        handler.ensure_server("srv");
        handler.ensure_server("srv");
        let guard = handler.registry.lock().unwrap();
        assert_eq!(guard.len(), 1);
    }

    /// A clone observes entries added through the original handler, i.e. the
    /// registry is shared rather than deep-copied.
    #[test]
    fn clone_shares_registry() {
        let h1 = AgentBlockClientHandler::new();
        let h2 = h1.clone();
        h1.ensure_server("alpha");
        let guard = h2.registry.lock().unwrap();
        assert!(guard.contains_key("alpha"), "clone must share registry Arc");
    }

    /// `mark_on_progress` sets `on_progress` on the server's registry entry.
    #[test]
    fn mark_on_progress_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_progress("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_progress);
    }

    /// `mark_on_log` sets `on_log` on the server's registry entry.
    #[test]
    fn mark_on_log_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_log("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_log);
    }

    /// `mark_sampling` sets `sampling` on the server's registry entry.
    #[test]
    fn mark_sampling_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_sampling("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().sampling);
    }

    /// `mark_on_resource_updated` sets `on_resource_updated` on the entry.
    #[test]
    fn mark_on_resource_updated_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_resource_updated("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_resource_updated);
    }

    /// `mark_on_resource_list_changed` sets `on_resource_list_changed` on the entry.
    #[test]
    fn mark_on_resource_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_resource_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_resource_list_changed);
    }

    /// `mark_on_tool_list_changed` sets `on_tool_list_changed` on the entry.
    #[test]
    fn mark_on_tool_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_tool_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_tool_list_changed);
    }

    /// `mark_on_prompt_list_changed` sets `on_prompt_list_changed` on the entry.
    #[test]
    fn mark_on_prompt_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_prompt_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_prompt_list_changed);
    }

    /// Installing the dispatcher on a handler Isle creates the sampling
    /// handler table and dispatch function. It must NOT create the
    /// progress/log handler tables.
    #[test]
    fn install_dispatcher_creates_sampling_globals() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();
        let _: mlua::Table = lua.globals().get(MCP_SAMPLING_HANDLERS).unwrap();
        let _: mlua::Function = lua.globals().get(MCP_DISPATCH_SAMPLING).unwrap();
        let progress_handlers: mlua::Value = lua.globals().get("__mcp_progress_handlers").unwrap();
        assert!(
            matches!(progress_handlers, mlua::Value::Nil),
            "__mcp_progress_handlers must not be installed on handler Isle"
        );
        let log_handlers: mlua::Value = lua.globals().get("__mcp_log_handlers").unwrap();
        assert!(
            matches!(log_handlers, mlua::Value::Nil),
            "__mcp_log_handlers must not be installed on handler Isle"
        );
    }

    /// The user-facing progress/log callback tables must not exist on the handler Isle.
    #[test]
    fn handler_isle_has_no_user_callback_tables() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();
        let progress_cbs: mlua::Value = lua.globals().get(MCP_USER_PROGRESS_CBS).unwrap();
        assert!(
            matches!(progress_cbs, mlua::Value::Nil),
            "__mcp_user_progress_cbs must not be on handler Isle"
        );
        let log_cbs: mlua::Value = lua.globals().get(MCP_USER_LOG_CBS).unwrap();
        assert!(
            matches!(log_cbs, mlua::Value::Nil),
            "__mcp_user_log_cbs must not be on handler Isle"
        );
    }

    /// Registers a Lua closure in `__mcp_user_progress_cbs` that increments a
    /// captured local (`hits`), invokes it in three separate `exec` calls, and
    /// checks the counter reached 3. This confirms closure state survives
    /// across `exec` calls on the same Isle.
    #[tokio::test]
    async fn main_isle_progress_cb_preserves_upvalue() {
        use mlua_isle::AsyncIsle;
        let (isle, driver) = AsyncIsle::spawn(|_lua: &mlua::Lua| Ok(()))
            .await
            .expect("AsyncIsle::spawn should succeed");
        isle.exec(|lua| {
            lua.load(
                r#"
__mcp_user_progress_cbs = {}
local hits = 0
__mcp_user_progress_cbs["test-srv"] = function(ev)
hits = hits + 1
end
_G.get_hits = function() return hits end
"#,
            )
            .exec()
            .map_err(isle_lua_err("setup"))?;
            Ok(String::new())
        })
        .await
        .expect("setup exec");
        for _ in 0..3 {
            isle.exec(|lua| {
                use mlua::prelude::*;
                let cbs: LuaTable = lua
                    .globals()
                    .get(MCP_USER_PROGRESS_CBS)
                    .map_err(isle_lua_err("get cbs"))?;
                let cb: LuaFunction = cbs.get("test-srv").map_err(isle_lua_err("get cb"))?;
                let ev = lua.create_table().map_err(isle_lua_err("create ev"))?;
                // Propagate callback failures so the test reports the real Lua
                // error instead of a misleading hit-count mismatch later.
                cb.call::<()>(ev).map_err(isle_lua_err("call cb"))?;
                Ok(String::new())
            })
            .await
            .expect("dispatch exec");
        }
        let hits_str = isle
            .exec(|lua| {
                use mlua::prelude::*;
                let get_hits: LuaFunction = lua
                    .globals()
                    .get("get_hits")
                    .map_err(isle_lua_err("get_hits"))?;
                let n: i64 = get_hits.call(()).map_err(isle_lua_err("call get_hits"))?;
                Ok(n.to_string())
            })
            .await
            .expect("read hits exec");
        let hits: i64 = hits_str.parse().expect("hits must be integer");
        assert_eq!(hits, 3, "upvalue counter must reach 3");
        driver.shutdown().await.expect("shutdown");
    }

    /// With no handler registered for the server, the sampling dispatcher returns nil.
    #[test]
    fn sampling_dispatcher_returns_nil_when_no_handler() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();
        let dispatch: mlua::Function = lua.globals().get(MCP_DISPATCH_SAMPLING).unwrap();
        let result: mlua::Value = dispatch.call(("no-srv", "{}")).unwrap();
        assert!(
            matches!(result, mlua::Value::Nil),
            "expected nil when no handler"
        );
    }

    /// A handler registered under a server name is invoked by the sampling
    /// dispatcher. Its returned table is passed back unchanged; the assertions
    /// run inside Lua.
    #[test]
    fn sampling_dispatcher_calls_registered_handler() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();
        lua.load(
            r#"
__mcp_sampling_handlers["srv"] = function(sn, params_json)
return { model = "test-model", stop_reason = "endTurn",
role = "assistant", content = "hello" }
end
local result = __mcp_dispatch_sampling("srv", "{}")
assert(type(result) == "table")
assert(result.model == "test-model")
assert(result.content == "hello")
"#,
        )
        .exec()
        .unwrap();
    }
}