1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use mlua_isle::AsyncIsle;
11use rmcp::{
12 handler::client::ClientHandler,
13 model::{
14 CreateElicitationRequestParams, CreateElicitationResult, CreateMessageRequestParams,
15 CreateMessageResult, ElicitationAction, ElicitationCreateRequestMethod, LoggingLevel,
16 LoggingMessageNotificationParam, ProgressNotificationParam,
17 ResourceUpdatedNotificationParam, Role, SamplingMessage, SamplingMessageContent,
18 },
19 service::{NotificationContext, RequestContext, RoleClient},
20 ErrorData as McpError,
21};
22use tokio::sync::mpsc;
23
/// Lua global (handler isle) holding the table that maps a server name to its
/// sampling handler function.
pub(crate) const MCP_SAMPLING_HANDLERS: &str = "__mcp_sampling_handlers";

/// Lua global (handler isle) holding the sampling dispatcher, called by
/// `create_message` as `dispatch(server_name, params_json)`.
const MCP_DISPATCH_SAMPLING: &str = "__mcp_dispatch_sampling";

/// Lua global (handler isle) holding the table that maps a server name to its
/// roots handler function.
pub(crate) const MCP_ROOTS_HANDLERS: &str = "__mcp_roots_handlers";

/// Lua global (handler isle) holding the roots dispatcher, called by
/// `list_roots` as `dispatch(server_name)`.
const MCP_DISPATCH_ROOTS: &str = "__mcp_dispatch_roots";

/// Lua global (handler isle) holding the table that maps a server name to its
/// elicitation handler function.
pub(crate) const MCP_ELICITATION_HANDLERS: &str = "__mcp_elicitation_handlers";

/// Lua global (handler isle) holding the elicitation dispatcher, called by
/// `create_elicitation` as `dispatch(server_name, message, schema_json)`.
const MCP_DISPATCH_ELICITATION: &str = "__mcp_dispatch_elicitation";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// progress-notification callback.
pub const MCP_USER_PROGRESS_CBS: &str = "__mcp_user_progress_cbs";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// log-message callback.
pub const MCP_USER_LOG_CBS: &str = "__mcp_user_log_cbs";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// resource-updated callback.
pub const MCP_USER_RESOURCE_UPDATE_CBS: &str = "__mcp_user_resource_update_cbs";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// resources-list-changed callback.
pub const MCP_USER_RESOURCES_LIST_CHANGED_CBS: &str = "__mcp_user_resources_list_changed_cbs";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// tools-list-changed callback.
pub const MCP_USER_TOOLS_LIST_CHANGED_CBS: &str = "__mcp_user_tools_list_changed_cbs";

/// Lua global, looked up on the main isle, mapping a server name to the user's
/// prompts-list-changed callback.
pub const MCP_USER_PROMPTS_LIST_CHANGED_CBS: &str = "__mcp_user_prompts_list_changed_cbs";

/// Capacity of the bounded notification queue created by
/// `AgentBlockClientHandler::start_dispatch_task`. Notifications enqueued with
/// `try_send` while the queue is full are dropped with a warning.
const NOTIFY_CHANNEL_CAPACITY: usize = 128;

/// Deferred builder for the event table passed to a user callback. It runs
/// inside `isle.exec` with the Lua state and the originating server name.
type BuildEvFn = Box<dyn FnOnce(&mlua::Lua, &str) -> mlua::Result<mlua::Table> + Send + 'static>;
92
/// A notification waiting in the queue drained by the task started in
/// `AgentBlockClientHandler::start_dispatch_task`.
pub(crate) struct NotificationItem {
    /// Isle whose Lua state holds `cbs_table` and runs the callback.
    pub(crate) isle: Arc<AsyncIsle>,
    /// Server that emitted the notification; also the key into `cbs_table`.
    pub(crate) server_name: String,
    /// Name of the Lua global table mapping server name -> callback.
    pub(crate) cbs_table: &'static str,
    /// Builds the event table handed to the callback (invoked inside `isle.exec`).
    pub(crate) build_ev: BuildEvFn,
    /// Originating handler method name, used in log and error messages.
    pub(crate) caller: &'static str,
}
105
/// Per-server record of which client-side handlers and callbacks have been
/// registered from Lua.
///
/// Every flag starts out `false`; the `mark_*` and `set_trace_context` methods
/// on `AgentBlockClientHandler` switch them on.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ServerHandlerRegistry {
    /// A Lua progress callback is registered (checked by `on_progress`).
    pub(crate) on_progress: bool,
    /// A Lua log callback is registered (checked by `on_logging_message`).
    pub(crate) on_log: bool,
    /// A Lua resource-updated callback is registered.
    pub(crate) on_resource_updated: bool,
    /// A Lua resources-list-changed callback is registered.
    pub(crate) on_resource_list_changed: bool,
    /// A Lua tools-list-changed callback is registered.
    pub(crate) on_tool_list_changed: bool,
    /// A Lua prompts-list-changed callback is registered.
    pub(crate) on_prompt_list_changed: bool,
    /// A Lua sampling handler is registered (checked by `create_message`).
    pub(crate) sampling: bool,
    /// A Lua roots handler is registered (checked by `list_roots`).
    pub(crate) roots: bool,
    /// A Lua elicitation handler is registered (checked by `create_elicitation`).
    pub(crate) elicitation: bool,
    /// Trace-context flag, read back via `trace_context_enabled`.
    pub(crate) trace_context: bool,
}

impl ServerHandlerRegistry {
    /// Creates a registry entry with every handler flag cleared.
    fn new() -> Self {
        Self::default()
    }
}
152
/// `rmcp` client handler that routes server-initiated MCP traffic into Lua.
///
/// Notifications (progress, logging, resource/tool/prompt changes) are delivered
/// to user callbacks on `main_isle`. Requests (sampling, roots, elicitation) are
/// answered by Lua handlers on `handler_isle`. The registry sits behind an
/// `Arc`, so every clone observes the same registrations.
#[derive(Clone)]
pub struct AgentBlockClientHandler {
    /// Per-server registration flags keyed by server name; shared across clones.
    pub(crate) registry: Arc<Mutex<HashMap<String, ServerHandlerRegistry>>>,
    /// Isle that runs the request dispatchers (sampling, roots, elicitation).
    /// When `None`, those requests fail with `method_not_found`.
    pub(crate) handler_isle: Option<Arc<AsyncIsle>>,
    /// Isle that runs user notification callbacks. When `None`, notifications
    /// are not delivered to Lua; log messages still go through `tracing`.
    pub(crate) main_isle: Option<Arc<AsyncIsle>>,
    /// Server this handler instance serves; used as the registry and
    /// callback-table key. When `None`, notifications are not delivered to Lua
    /// and requests fail with `method_not_found`.
    pub(crate) server_name: Option<String>,
    /// Sender for the ordered notification queue created by
    /// `start_dispatch_task`. When `None`, each notification is dispatched on
    /// its own spawned task via `isle_dispatch`.
    pub(crate) notify_tx: Option<mpsc::Sender<NotificationItem>>,
}
203
204impl AgentBlockClientHandler {
205 pub fn new() -> Self {
211 Self {
212 registry: Arc::new(Mutex::new(HashMap::new())),
213 handler_isle: None,
214 main_isle: None,
215 server_name: None,
216 notify_tx: None,
217 }
218 }
219
220 pub(crate) fn start_dispatch_task(&mut self) {
228 let (tx, mut rx) = mpsc::channel::<NotificationItem>(NOTIFY_CHANNEL_CAPACITY);
229 self.notify_tx = Some(tx);
230 tokio::spawn(async move {
232 while let Some(item) = rx.recv().await {
233 let sn = item.server_name.clone();
234 let result = item
235 .isle
236 .exec(move |lua| {
237 use mlua::prelude::*;
238 let cbs: LuaTable = match lua.globals().get(item.cbs_table) {
239 Ok(t) => t,
240 Err(_) => return Ok(String::new()),
241 };
242 let cb: LuaFunction = match cbs.get(item.server_name.as_str()) {
243 Ok(f) => f,
244 Err(_) => return Ok(String::new()),
245 };
246 let ev = (item.build_ev)(lua, item.server_name.as_str()).map_err(|e| {
247 mlua_isle::IsleError::Lua(format!("{}: build_ev: {e}", item.caller))
248 })?;
249 if let Err(e) = cb.call::<()>(ev) {
250 tracing::warn!(
251 target: "mcp_client",
252 server = %item.server_name,
253 caller = %item.caller,
254 error = %e,
255 "user callback returned error"
256 );
257 }
258 Ok(String::new())
259 })
260 .await;
261 if let Err(e) = result {
262 tracing::warn!(
263 target: "mcp_client",
264 server = %sn,
265 error = %e,
266 "notification dispatch: main isle exec failed"
267 );
268 }
269 }
270 });
271 }
272
273 pub(crate) fn ensure_server(&self, server_name: &str) {
279 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
280 guard
281 .entry(server_name.to_string())
282 .or_insert_with(ServerHandlerRegistry::new);
283 }
284
285 pub fn mark_on_progress(&self, server_name: &str) {
288 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
289 let entry = guard
290 .entry(server_name.to_string())
291 .or_insert_with(ServerHandlerRegistry::new);
292 entry.on_progress = true;
293 }
294
295 pub fn mark_on_log(&self, server_name: &str) {
297 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
298 let entry = guard
299 .entry(server_name.to_string())
300 .or_insert_with(ServerHandlerRegistry::new);
301 entry.on_log = true;
302 }
303
304 pub fn mark_on_resource_updated(&self, server_name: &str) {
306 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
307 let entry = guard
308 .entry(server_name.to_string())
309 .or_insert_with(ServerHandlerRegistry::new);
310 entry.on_resource_updated = true;
311 }
312
313 pub fn mark_on_resource_list_changed(&self, server_name: &str) {
315 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
316 let entry = guard
317 .entry(server_name.to_string())
318 .or_insert_with(ServerHandlerRegistry::new);
319 entry.on_resource_list_changed = true;
320 }
321
322 pub fn mark_on_tool_list_changed(&self, server_name: &str) {
324 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
325 let entry = guard
326 .entry(server_name.to_string())
327 .or_insert_with(ServerHandlerRegistry::new);
328 entry.on_tool_list_changed = true;
329 }
330
331 pub fn mark_on_prompt_list_changed(&self, server_name: &str) {
333 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
334 let entry = guard
335 .entry(server_name.to_string())
336 .or_insert_with(ServerHandlerRegistry::new);
337 entry.on_prompt_list_changed = true;
338 }
339
340 pub(crate) fn set_trace_context(&self, server_name: &str, enabled: bool) {
343 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
344 let entry = guard
345 .entry(server_name.to_string())
346 .or_insert_with(ServerHandlerRegistry::new);
347 entry.trace_context = enabled;
348 }
349
350 pub fn trace_context_enabled(&self, server_name: &str) -> bool {
352 let guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
353 guard.get(server_name).is_some_and(|r| r.trace_context)
354 }
355
356 pub fn mark_sampling(&self, server_name: &str) {
358 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
359 let entry = guard
360 .entry(server_name.to_string())
361 .or_insert_with(ServerHandlerRegistry::new);
362 entry.sampling = true;
363 }
364
365 pub fn mark_roots(&self, server_name: &str) {
375 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
376 let entry = guard
377 .entry(server_name.to_string())
378 .or_insert_with(ServerHandlerRegistry::new);
379 entry.roots = true;
380 }
381
382 pub fn mark_elicitation(&self, server_name: &str) {
392 let mut guard = self.registry.lock().unwrap_or_else(|e| e.into_inner());
393 let entry = guard
394 .entry(server_name.to_string())
395 .or_insert_with(ServerHandlerRegistry::new);
396 entry.elicitation = true;
397 }
398}
399
400impl Default for AgentBlockClientHandler {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
406pub fn install_mcp_dispatcher_on_handler_isle(lua: &mlua::Lua) -> mlua::Result<()> {
419 use mlua::prelude::*;
420
421 lua.globals()
423 .set(MCP_SAMPLING_HANDLERS, lua.create_table()?)?;
424
425 let sampling_src = r#"
426 local HANDLERS = "__mcp_sampling_handlers"
427 return function(server_name, params_json)
428 local handlers = _G[HANDLERS]
429 local h = handlers and handlers[server_name]
430 if type(h) ~= "function" then
431 return nil -- signal: no handler registered
432 end
433 return h(server_name, params_json)
434 end
435 "#;
436 let dispatch_sampling: LuaFunction = lua
437 .load(sampling_src)
438 .set_name("@agent_block:__mcp_dispatch_sampling")
439 .eval()?;
440 lua.globals()
441 .set(MCP_DISPATCH_SAMPLING, dispatch_sampling)?;
442
443 lua.globals().set(MCP_ROOTS_HANDLERS, lua.create_table()?)?;
445
446 let roots_src = r#"
447 local HANDLERS = "__mcp_roots_handlers"
448 return function(server_name)
449 local handlers = _G[HANDLERS]
450 local h = handlers and handlers[server_name]
451 if type(h) ~= "function" then
452 return nil -- signal: no handler registered
453 end
454 return h(server_name)
455 end
456 "#;
457 let dispatch_roots: LuaFunction = lua
458 .load(roots_src)
459 .set_name("@agent_block:__mcp_dispatch_roots")
460 .eval()?;
461 lua.globals().set(MCP_DISPATCH_ROOTS, dispatch_roots)?;
462
463 lua.globals()
465 .set(MCP_ELICITATION_HANDLERS, lua.create_table()?)?;
466
467 let elicitation_src = r#"
468 local HANDLERS = "__mcp_elicitation_handlers"
469 return function(server_name, message, schema_json)
470 local handlers = _G[HANDLERS]
471 local h = handlers and handlers[server_name]
472 if type(h) ~= "function" then
473 return nil -- signal: no handler registered → Decline
474 end
475 return h(server_name, message, schema_json)
476 end
477 "#;
478 let dispatch_elicitation: LuaFunction = lua
479 .load(elicitation_src)
480 .set_name("@agent_block:__mcp_dispatch_elicitation")
481 .eval()?;
482 lua.globals()
483 .set(MCP_DISPATCH_ELICITATION, dispatch_elicitation)?;
484
485 Ok(())
486}
487
488fn isle_dispatch<F>(
504 isle: Arc<AsyncIsle>,
505 server_name: String,
506 cbs_table: &'static str,
507 build_ev: F,
508 caller: &'static str,
509) where
510 F: FnOnce(&mlua::Lua, &str) -> mlua::Result<mlua::Table> + Send + 'static,
511{
512 tokio::spawn(async move {
513 let sn = server_name.clone();
514 let result = isle
515 .exec(move |lua| {
516 use mlua::prelude::*;
517 let cbs: LuaTable = match lua.globals().get(cbs_table) {
519 Ok(t) => t,
520 Err(_) => return Ok(String::new()), };
522 let cb: LuaFunction = match cbs.get(server_name.as_str()) {
523 Ok(f) => f,
524 Err(_) => return Ok(String::new()), };
526 let ev = build_ev(lua, server_name.as_str())
530 .map_err(|e| mlua_isle::IsleError::Lua(format!("{caller}: build_ev: {e}")))?;
531 if let Err(e) = cb.call::<()>(ev) {
532 tracing::warn!(
533 target: "mcp_client",
534 server = %server_name,
535 caller = %caller,
536 error = %e,
537 "user callback returned error"
538 );
539 }
540 Ok(String::new())
541 })
542 .await;
543 if let Err(e) = result {
544 tracing::warn!(
545 target: "mcp_client",
546 server = %sn,
547 error = %e,
548 "{}: main isle exec failed",
549 caller
550 );
551 }
552 });
553}
554
555impl ClientHandler for AgentBlockClientHandler {
556 fn on_progress(
557 &self,
558 params: ProgressNotificationParam,
559 _context: NotificationContext<RoleClient>,
560 ) -> impl std::future::Future<Output = ()> + Send + '_ {
561 let main_isle = self.main_isle.clone();
564 let registry = Arc::clone(&self.registry);
565 let server_name_opt = self.server_name.clone();
568 let notify_tx = self.notify_tx.clone();
570
571 async move {
572 let main_isle = match main_isle {
573 Some(i) => i,
574 None => return, };
576
577 let server_name = match server_name_opt {
583 Some(s) => s,
584 None => return, };
586 let has_cb = {
587 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
588 guard.get(&server_name).is_some_and(|r| r.on_progress)
589 };
590 if !has_cb {
592 return;
593 }
594
595 let token_str = match ¶ms.progress_token.0 {
596 rmcp::model::NumberOrString::Number(n) => n.to_string(),
597 rmcp::model::NumberOrString::String(s) => s.to_string(),
598 };
599 let progress_f64: f64 = params.progress;
600 let total_opt: Option<f64> = params.total;
601 let message_opt: Option<String> = params.message;
602
603 if let Some(tx) = notify_tx {
606 let item = NotificationItem {
607 isle: main_isle,
608 server_name,
609 cbs_table: MCP_USER_PROGRESS_CBS,
610 build_ev: Box::new(move |lua, server_for_task| {
611 let ev = lua.create_table()?;
612 ev.set("type", "progress")?;
613 ev.set("server", server_for_task)?;
614 ev.set("token", token_str.as_str())?;
615 ev.set("progress", progress_f64)?;
616 if let Some(t) = total_opt {
617 ev.set("total", t)?;
618 }
619 if let Some(ref m) = message_opt {
620 ev.set("message", m.as_str())?;
621 }
622 Ok(ev)
623 }),
624 caller: "on_progress",
625 };
626 if let Err(e) = tx.try_send(item) {
627 tracing::warn!(
629 target: "mcp_client",
630 error = %e,
631 "on_progress: notification channel full, dropping notification \
632 (server is emitting faster than Lua can consume)"
633 );
634 }
635 } else {
636 isle_dispatch(
638 main_isle,
639 server_name,
640 MCP_USER_PROGRESS_CBS,
641 move |lua, server_for_task| {
642 let ev = lua.create_table()?;
643 ev.set("type", "progress")?;
644 ev.set("server", server_for_task)?;
645 ev.set("token", token_str.as_str())?;
646 ev.set("progress", progress_f64)?;
647 if let Some(t) = total_opt {
648 ev.set("total", t)?;
649 }
650 if let Some(ref m) = message_opt {
651 ev.set("message", m.as_str())?;
652 }
653 Ok(ev)
654 },
655 "on_progress",
656 );
657 }
658 }
659 }
660
661 fn on_logging_message(
662 &self,
663 params: LoggingMessageNotificationParam,
664 _context: NotificationContext<RoleClient>,
665 ) -> impl std::future::Future<Output = ()> + Send + '_ {
666 let main_isle = self.main_isle.clone();
667 let registry = Arc::clone(&self.registry);
668 let server_name = self.server_name.clone();
669 let notify_tx = self.notify_tx.clone();
670
671 async move {
672 let level = ¶ms.level;
673 let logger = params.logger.as_deref().unwrap_or("").to_string();
674 let data_str = match serde_json::to_string(¶ms.data) {
676 Ok(s) => s,
677 Err(e) => {
678 tracing::warn!(
679 target: "mcp_client",
680 error = %e,
681 "on_logging_message: failed to serialize data"
682 );
683 return;
684 }
685 };
686
687 let level_str = match level {
688 LoggingLevel::Debug => "debug",
689 LoggingLevel::Info | LoggingLevel::Notice => "info",
690 LoggingLevel::Warning => "warning",
691 LoggingLevel::Error
692 | LoggingLevel::Critical
693 | LoggingLevel::Alert
694 | LoggingLevel::Emergency => "error",
695 }
696 .to_string();
697
698 let sn_str = server_name.as_deref().unwrap_or("unknown").to_string();
700
701 let has_lua_handler = server_name.as_deref().is_some_and(|sn| {
703 registry
704 .lock()
705 .unwrap_or_else(|e| e.into_inner())
706 .get(sn)
707 .is_some_and(|r| r.on_log)
708 });
709
710 if has_lua_handler {
711 if let (Some(isle), Some(sn)) = (main_isle, server_name) {
712 let level_task = level_str.clone();
713 let logger_task = logger.clone();
714 let data_task = data_str.clone();
715
716 if let Some(tx) = notify_tx {
717 let item = NotificationItem {
718 isle,
719 server_name: sn,
720 cbs_table: MCP_USER_LOG_CBS,
721 build_ev: Box::new(move |lua, server_for_task| {
722 let ev = lua.create_table()?;
723 ev.set("type", "log")?;
724 ev.set("server", server_for_task)?;
725 ev.set("level", level_task.as_str())?;
726 ev.set("logger", logger_task.as_str())?;
727 ev.set("data", data_task.as_str())?;
728 Ok(ev)
729 }),
730 caller: "on_logging_message",
731 };
732 if let Err(e) = tx.try_send(item) {
733 tracing::warn!(
734 target: "mcp_client",
735 error = %e,
736 "on_logging_message: notification channel full, dropping notification"
737 );
738 }
739 } else {
740 isle_dispatch(
742 isle,
743 sn,
744 MCP_USER_LOG_CBS,
745 move |lua, server_for_task| {
746 let ev = lua.create_table()?;
747 ev.set("type", "log")?;
748 ev.set("server", server_for_task)?;
749 ev.set("level", level_task.as_str())?;
750 ev.set("logger", logger_task.as_str())?;
751 ev.set("data", data_task.as_str())?;
752 Ok(ev)
753 },
754 "on_logging_message",
755 );
756 }
757 return;
758 }
759 }
760
761 match level {
764 LoggingLevel::Debug => {
765 tracing::debug!(
766 target: "lua",
767 script = "mcp_server",
768 server = %sn_str,
769 logger = %logger,
770 "{}",
771 data_str
772 );
773 }
774 LoggingLevel::Info | LoggingLevel::Notice => {
775 tracing::info!(
776 target: "lua",
777 script = "mcp_server",
778 server = %sn_str,
779 logger = %logger,
780 "{}",
781 data_str
782 );
783 }
784 LoggingLevel::Warning => {
785 tracing::warn!(
786 target: "lua",
787 script = "mcp_server",
788 server = %sn_str,
789 logger = %logger,
790 "{}",
791 data_str
792 );
793 }
794 LoggingLevel::Error
795 | LoggingLevel::Critical
796 | LoggingLevel::Alert
797 | LoggingLevel::Emergency => {
798 tracing::error!(
799 target: "lua",
800 script = "mcp_server",
801 server = %sn_str,
802 logger = %logger,
803 "{}",
804 data_str
805 );
806 }
807 }
808 }
809 }
810
811 fn on_resource_updated(
812 &self,
813 params: ResourceUpdatedNotificationParam,
814 _context: NotificationContext<RoleClient>,
815 ) -> impl std::future::Future<Output = ()> + Send + '_ {
816 let main_isle = self.main_isle.clone();
817 let registry = Arc::clone(&self.registry);
818 let server_name_opt = self.server_name.clone();
819 let notify_tx = self.notify_tx.clone();
820
821 async move {
822 let main_isle = match main_isle {
823 Some(i) => i,
824 None => return,
825 };
826 let server_name = match server_name_opt {
827 Some(s) => s,
828 None => return,
829 };
830 let has_cb = {
831 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
832 guard
833 .get(&server_name)
834 .is_some_and(|r| r.on_resource_updated)
835 };
837 if !has_cb {
838 return;
839 }
840
841 let uri = params.uri.clone();
842
843 if let Some(tx) = notify_tx {
844 let item = NotificationItem {
845 isle: main_isle,
846 server_name,
847 cbs_table: MCP_USER_RESOURCE_UPDATE_CBS,
848 build_ev: Box::new(move |lua, server_for_task| {
849 let ev = lua.create_table()?;
850 ev.set("type", "resource_update")?;
851 ev.set("server", server_for_task)?;
852 ev.set("uri", uri.as_str())?;
853 Ok(ev)
854 }),
855 caller: "on_resource_updated",
856 };
857 if let Err(e) = tx.try_send(item) {
858 tracing::warn!(
859 target: "mcp_client",
860 error = %e,
861 "on_resource_updated: notification channel full, dropping notification \
862 (server is emitting faster than Lua can consume)"
863 );
864 }
865 } else {
866 isle_dispatch(
867 main_isle,
868 server_name,
869 MCP_USER_RESOURCE_UPDATE_CBS,
870 move |lua, server_for_task| {
871 let ev = lua.create_table()?;
872 ev.set("type", "resource_update")?;
873 ev.set("server", server_for_task)?;
874 ev.set("uri", uri.as_str())?;
875 Ok(ev)
876 },
877 "on_resource_updated",
878 );
879 }
880 }
881 }
882
883 fn on_resource_list_changed(
884 &self,
885 _context: NotificationContext<RoleClient>,
886 ) -> impl std::future::Future<Output = ()> + Send + '_ {
887 let main_isle = self.main_isle.clone();
888 let registry = Arc::clone(&self.registry);
889 let server_name_opt = self.server_name.clone();
890 let notify_tx = self.notify_tx.clone();
891
892 async move {
893 let main_isle = match main_isle {
894 Some(i) => i,
895 None => return,
896 };
897 let server_name = match server_name_opt {
898 Some(s) => s,
899 None => return,
900 };
901 let has_cb = {
902 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
903 guard
904 .get(&server_name)
905 .is_some_and(|r| r.on_resource_list_changed)
906 };
908 if !has_cb {
909 return;
910 }
911
912 if let Some(tx) = notify_tx {
913 let item = NotificationItem {
914 isle: main_isle,
915 server_name,
916 cbs_table: MCP_USER_RESOURCES_LIST_CHANGED_CBS,
917 build_ev: Box::new(move |lua, server_for_task| {
918 let ev = lua.create_table()?;
919 ev.set("type", "resources_list_changed")?;
920 ev.set("server", server_for_task)?;
921 Ok(ev)
922 }),
923 caller: "on_resource_list_changed",
924 };
925 if let Err(e) = tx.try_send(item) {
926 tracing::warn!(
927 target: "mcp_client",
928 error = %e,
929 "on_resource_list_changed: notification channel full, dropping notification"
930 );
931 }
932 } else {
933 isle_dispatch(
934 main_isle,
935 server_name,
936 MCP_USER_RESOURCES_LIST_CHANGED_CBS,
937 move |lua, server_for_task| {
938 let ev = lua.create_table()?;
939 ev.set("type", "resources_list_changed")?;
940 ev.set("server", server_for_task)?;
941 Ok(ev)
942 },
943 "on_resource_list_changed",
944 );
945 }
946 }
947 }
948
949 fn on_tool_list_changed(
950 &self,
951 _context: NotificationContext<RoleClient>,
952 ) -> impl std::future::Future<Output = ()> + Send + '_ {
953 let main_isle = self.main_isle.clone();
954 let registry = Arc::clone(&self.registry);
955 let server_name_opt = self.server_name.clone();
956 let notify_tx = self.notify_tx.clone();
957
958 async move {
959 let main_isle = match main_isle {
960 Some(i) => i,
961 None => return,
962 };
963 let server_name = match server_name_opt {
964 Some(s) => s,
965 None => return,
966 };
967 let has_cb = {
968 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
969 guard
970 .get(&server_name)
971 .is_some_and(|r| r.on_tool_list_changed)
972 };
974 if !has_cb {
975 return;
976 }
977
978 if let Some(tx) = notify_tx {
979 let item = NotificationItem {
980 isle: main_isle,
981 server_name,
982 cbs_table: MCP_USER_TOOLS_LIST_CHANGED_CBS,
983 build_ev: Box::new(move |lua, server_for_task| {
984 let ev = lua.create_table()?;
985 ev.set("type", "tools_list_changed")?;
986 ev.set("server", server_for_task)?;
987 Ok(ev)
988 }),
989 caller: "on_tool_list_changed",
990 };
991 if let Err(e) = tx.try_send(item) {
992 tracing::warn!(
993 target: "mcp_client",
994 error = %e,
995 "on_tool_list_changed: notification channel full, dropping notification"
996 );
997 }
998 } else {
999 isle_dispatch(
1000 main_isle,
1001 server_name,
1002 MCP_USER_TOOLS_LIST_CHANGED_CBS,
1003 move |lua, server_for_task| {
1004 let ev = lua.create_table()?;
1005 ev.set("type", "tools_list_changed")?;
1006 ev.set("server", server_for_task)?;
1007 Ok(ev)
1008 },
1009 "on_tool_list_changed",
1010 );
1011 }
1012 }
1013 }
1014
1015 fn on_prompt_list_changed(
1016 &self,
1017 _context: NotificationContext<RoleClient>,
1018 ) -> impl std::future::Future<Output = ()> + Send + '_ {
1019 let main_isle = self.main_isle.clone();
1020 let registry = Arc::clone(&self.registry);
1021 let server_name_opt = self.server_name.clone();
1022 let notify_tx = self.notify_tx.clone();
1023
1024 async move {
1025 let main_isle = match main_isle {
1026 Some(i) => i,
1027 None => return,
1028 };
1029 let server_name = match server_name_opt {
1030 Some(s) => s,
1031 None => return,
1032 };
1033 let has_cb = {
1034 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
1035 guard
1036 .get(&server_name)
1037 .is_some_and(|r| r.on_prompt_list_changed)
1038 };
1040 if !has_cb {
1041 return;
1042 }
1043
1044 if let Some(tx) = notify_tx {
1045 let item = NotificationItem {
1046 isle: main_isle,
1047 server_name,
1048 cbs_table: MCP_USER_PROMPTS_LIST_CHANGED_CBS,
1049 build_ev: Box::new(move |lua, server_for_task| {
1050 let ev = lua.create_table()?;
1051 ev.set("type", "prompts_list_changed")?;
1052 ev.set("server", server_for_task)?;
1053 Ok(ev)
1054 }),
1055 caller: "on_prompt_list_changed",
1056 };
1057 if let Err(e) = tx.try_send(item) {
1058 tracing::warn!(
1059 target: "mcp_client",
1060 error = %e,
1061 "on_prompt_list_changed: notification channel full, dropping notification"
1062 );
1063 }
1064 } else {
1065 isle_dispatch(
1066 main_isle,
1067 server_name,
1068 MCP_USER_PROMPTS_LIST_CHANGED_CBS,
1069 move |lua, server_for_task| {
1070 let ev = lua.create_table()?;
1071 ev.set("type", "prompts_list_changed")?;
1072 ev.set("server", server_for_task)?;
1073 Ok(ev)
1074 },
1075 "on_prompt_list_changed",
1076 );
1077 }
1078 }
1079 }
1080
1081 fn create_message(
1082 &self,
1083 params: CreateMessageRequestParams,
1084 _context: RequestContext<RoleClient>,
1085 ) -> impl std::future::Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ {
1086 let isle = self.handler_isle.clone();
1087 let registry = Arc::clone(&self.registry);
1088 let server_name = self.server_name.clone();
1089
1090 async move {
1091 let sn = match server_name.as_deref() {
1093 Some(s) => s.to_string(),
1094 None => {
1095 return Err(McpError::method_not_found::<
1096 rmcp::model::CreateMessageRequestMethod,
1097 >());
1098 }
1099 };
1100
1101 let has_sampling = {
1103 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
1104 guard.get(&sn).is_some_and(|r| r.sampling)
1105 };
1106
1107 if !has_sampling {
1108 return Err(McpError::method_not_found::<
1109 rmcp::model::CreateMessageRequestMethod,
1110 >());
1111 }
1112
1113 let isle = match isle {
1114 Some(i) => i,
1115 None => {
1116 return Err(McpError::method_not_found::<
1117 rmcp::model::CreateMessageRequestMethod,
1118 >());
1119 }
1120 };
1121
1122 let params_json = match serde_json::to_string(¶ms) {
1124 Ok(s) => s,
1125 Err(e) => {
1126 tracing::warn!(
1127 target: "mcp_client",
1128 server = %sn,
1129 error = %e,
1130 "create_message: failed to serialize params"
1131 );
1132 return Err(McpError::internal_error(
1133 format!("create_message serialize: {e}"),
1134 None,
1135 ));
1136 }
1137 };
1138
1139 let sn_task = sn.clone();
1141 let params_task = params_json.clone();
1142 let result_json = isle
1143 .exec(move |lua| {
1144 use mlua::prelude::*;
1145 let dispatch: LuaFunction =
1146 lua.globals().get(MCP_DISPATCH_SAMPLING).map_err(|e| {
1147 mlua_isle::IsleError::Lua(format!(
1148 "create_message: get dispatcher: {e}"
1149 ))
1150 })?;
1151 let result: LuaValue = dispatch
1152 .call((sn_task.as_str(), params_task.as_str()))
1153 .map_err(|e| {
1154 mlua_isle::IsleError::Lua(format!("create_message: dispatch: {e}"))
1155 })?;
1156
1157 match result {
1159 LuaValue::Nil => Ok(String::new()),
1160 LuaValue::Table(tbl) => {
1161 let json_val = crate::lua_json::lua_to_json(lua, LuaValue::Table(tbl))
1163 .map_err(|e| {
1164 mlua_isle::IsleError::Lua(format!(
1165 "create_message: lua_to_json: {e}"
1166 ))
1167 })?;
1168 serde_json::to_string(&json_val).map_err(|e| {
1169 mlua_isle::IsleError::Lua(format!("create_message: to_string: {e}"))
1170 })
1171 }
1172 other => Err(mlua_isle::IsleError::Lua(format!(
1173 "create_message: handler must return table or nil, got: {:?}",
1174 other.type_name()
1175 ))),
1176 }
1177 })
1178 .await;
1179
1180 match result_json {
1181 Err(e) => {
1182 tracing::warn!(
1183 target: "mcp_client",
1184 server = %sn,
1185 error = %e,
1186 "create_message: handler isle error"
1187 );
1188 Err(McpError::internal_error(
1189 format!("sampling handler: {e}"),
1190 None,
1191 ))
1192 }
1193 Ok(json_str) if json_str.is_empty() => {
1194 Err(McpError::method_not_found::<
1196 rmcp::model::CreateMessageRequestMethod,
1197 >())
1198 }
1199 Ok(json_str) => {
1200 let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
1202 McpError::internal_error(
1203 format!("sampling handler result parse: {e}"),
1204 None,
1205 )
1206 })?;
1207
1208 let model = v
1209 .get("model")
1210 .and_then(|v| v.as_str())
1211 .unwrap_or("unknown")
1212 .to_string();
1213 let stop_reason = v
1214 .get("stop_reason")
1215 .and_then(|v| v.as_str())
1216 .map(ToString::to_string);
1217 let role_str = v
1218 .get("role")
1219 .and_then(|v| v.as_str())
1220 .unwrap_or("assistant");
1221 let role = match role_str {
1222 "user" => Role::User,
1223 _ => Role::Assistant,
1224 };
1225 let content_str = v
1226 .get("content")
1227 .and_then(|v| v.as_str())
1228 .unwrap_or("")
1229 .to_string();
1230
1231 let message =
1232 SamplingMessage::new(role, SamplingMessageContent::text(content_str));
1233 let mut result = CreateMessageResult::new(message, model);
1234 if let Some(sr) = stop_reason {
1235 result = result.with_stop_reason(sr);
1236 }
1237 Ok(result)
1238 }
1239 }
1240 }
1241 }
1242
1243 fn list_roots(
1257 &self,
1258 _context: RequestContext<RoleClient>,
1259 ) -> impl std::future::Future<Output = Result<rmcp::model::ListRootsResult, McpError>> + Send + '_
1260 {
1261 let isle = self.handler_isle.clone();
1262 let registry = Arc::clone(&self.registry);
1263 let server_name = self.server_name.clone();
1264
1265 async move {
1266 let sn = match server_name.as_deref() {
1268 Some(s) => s.to_string(),
1269 None => {
1270 return Err(McpError::method_not_found::<
1271 rmcp::model::ListRootsRequestMethod,
1272 >());
1273 }
1274 };
1275
1276 let has_roots = {
1278 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
1279 guard.get(&sn).is_some_and(|r| r.roots)
1280 };
1281
1282 if !has_roots {
1283 return Err(McpError::method_not_found::<
1284 rmcp::model::ListRootsRequestMethod,
1285 >());
1286 }
1287
1288 let isle = match isle {
1289 Some(i) => i,
1290 None => {
1291 return Err(McpError::method_not_found::<
1292 rmcp::model::ListRootsRequestMethod,
1293 >());
1294 }
1295 };
1296
1297 let sn_task = sn.clone();
1299 let result_val = isle
1300 .exec(move |lua| {
1301 use mlua::prelude::*;
1302 let dispatch: LuaFunction =
1303 lua.globals().get(MCP_DISPATCH_ROOTS).map_err(|e| {
1304 mlua_isle::IsleError::Lua(format!("list_roots: get dispatcher: {e}"))
1305 })?;
1306 let result: LuaValue = dispatch.call(sn_task.as_str()).map_err(|e| {
1307 mlua_isle::IsleError::Lua(format!("list_roots: dispatch: {e}"))
1308 })?;
1309
1310 match result {
1312 LuaValue::Nil => Ok(String::new()),
1313 LuaValue::Table(tbl) => {
1314 let json_val = crate::lua_json::lua_to_json(lua, LuaValue::Table(tbl))
1316 .map_err(|e| {
1317 mlua_isle::IsleError::Lua(format!(
1318 "list_roots: lua_to_json: {e}"
1319 ))
1320 })?;
1321 serde_json::to_string(&json_val).map_err(|e| {
1322 mlua_isle::IsleError::Lua(format!("list_roots: to_string: {e}"))
1323 })
1324 }
1325 other => Err(mlua_isle::IsleError::Lua(format!(
1326 "list_roots: handler must return table or nil, got: {:?}",
1327 other.type_name()
1328 ))),
1329 }
1330 })
1331 .await;
1332
1333 match result_val {
1334 Err(e) => {
1335 tracing::warn!(
1336 target: "mcp_client",
1337 server = %sn,
1338 error = %e,
1339 "list_roots: handler isle error"
1340 );
1341 Err(McpError::internal_error(
1342 format!("roots handler: {e}"),
1343 None,
1344 ))
1345 }
1346 Ok(json_str) if json_str.is_empty() => {
1347 Err(McpError::method_not_found::<
1349 rmcp::model::ListRootsRequestMethod,
1350 >())
1351 }
1352 Ok(json_str) => {
1353 let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
1355 McpError::internal_error(format!("roots handler result parse: {e}"), None)
1356 })?;
1357
1358 let entries = v.as_array().ok_or_else(|| {
1360 McpError::internal_error(
1361 "roots handler result parse: expected array".to_string(),
1362 None,
1363 )
1364 })?;
1365
1366 let mut roots = Vec::with_capacity(entries.len());
1367 for entry in entries {
1368 let uri = entry
1369 .get("uri")
1370 .and_then(|v| v.as_str())
1371 .unwrap_or("")
1372 .to_string();
1373 let name = entry
1374 .get("name")
1375 .and_then(|v| v.as_str())
1376 .map(ToString::to_string);
1377 let root = if let Some(n) = name {
1378 rmcp::model::Root::new(uri).with_name(n)
1379 } else {
1380 rmcp::model::Root::new(uri)
1381 };
1382 roots.push(root);
1383 }
1384 Ok(rmcp::model::ListRootsResult::new(roots))
1385 }
1386 }
1387 }
1388 }
1389
1390 fn create_elicitation(
1407 &self,
1408 request: CreateElicitationRequestParams,
1409 _context: RequestContext<RoleClient>,
1410 ) -> impl std::future::Future<Output = Result<CreateElicitationResult, McpError>> + Send + '_
1411 {
1412 let isle = self.handler_isle.clone();
1413 let registry = Arc::clone(&self.registry);
1414 let server_name = self.server_name.clone();
1415
1416 async move {
1417 let (message, requested_schema) = match request {
1419 CreateElicitationRequestParams::UrlElicitationParams { .. } => {
1420 return Ok(CreateElicitationResult {
1421 action: ElicitationAction::Decline,
1422 content: None,
1423 meta: None,
1424 });
1425 }
1426 CreateElicitationRequestParams::FormElicitationParams {
1427 message,
1428 requested_schema,
1429 ..
1430 } => (message, requested_schema),
1431 };
1432
1433 let sn = match server_name.as_deref() {
1435 Some(s) => s.to_string(),
1436 None => {
1437 return Err(McpError::method_not_found::<ElicitationCreateRequestMethod>());
1438 }
1439 };
1440
1441 let has_elicitation = {
1443 let guard = registry.lock().unwrap_or_else(|e| e.into_inner());
1444 guard.get(&sn).is_some_and(|r| r.elicitation)
1445 };
1446
1447 if !has_elicitation {
1448 return Ok(CreateElicitationResult {
1450 action: ElicitationAction::Decline,
1451 content: None,
1452 meta: None,
1453 });
1454 }
1455
1456 let isle = match isle {
1457 Some(i) => i,
1458 None => {
1459 return Err(McpError::method_not_found::<ElicitationCreateRequestMethod>());
1460 }
1461 };
1462
1463 let schema_json = serde_json::to_string(&requested_schema).map_err(|e| {
1465 McpError::internal_error(format!("create_elicitation: schema serialize: {e}"), None)
1466 })?;
1467
1468 let sn_task = sn.clone();
1470 let message_task = message.clone();
1471 let result_val = isle
1472 .exec(move |lua| {
1473 use mlua::prelude::*;
1474 let dispatch: LuaFunction =
1475 lua.globals().get(MCP_DISPATCH_ELICITATION).map_err(|e| {
1476 mlua_isle::IsleError::Lua(format!(
1477 "create_elicitation: get dispatcher: {e}"
1478 ))
1479 })?;
1480 let result: LuaValue = dispatch
1481 .call((
1482 sn_task.as_str(),
1483 message_task.as_str(),
1484 schema_json.as_str(),
1485 ))
1486 .map_err(|e| {
1487 mlua_isle::IsleError::Lua(format!("create_elicitation: dispatch: {e}"))
1488 })?;
1489
1490 match result {
1492 LuaValue::Nil => Ok(String::new()),
1493 LuaValue::Table(tbl) => {
1494 let json_val = crate::lua_json::lua_to_json(lua, LuaValue::Table(tbl))
1496 .map_err(|e| {
1497 mlua_isle::IsleError::Lua(format!(
1498 "create_elicitation: lua_to_json: {e}"
1499 ))
1500 })?;
1501 serde_json::to_string(&json_val).map_err(|e| {
1502 mlua_isle::IsleError::Lua(format!(
1503 "create_elicitation: to_string: {e}"
1504 ))
1505 })
1506 }
1507 other => Err(mlua_isle::IsleError::Lua(format!(
1508 "create_elicitation: handler must return table or nil, got: {:?}",
1509 other.type_name()
1510 ))),
1511 }
1512 })
1513 .await;
1514
1515 match result_val {
1516 Err(e) => {
1517 tracing::warn!(
1518 target: "mcp_client",
1519 server = %sn,
1520 error = %e,
1521 "create_elicitation: handler isle error"
1522 );
1523 Err(McpError::internal_error(
1524 format!("elicitation handler: {e}"),
1525 None,
1526 ))
1527 }
1528 Ok(json_str) if json_str.is_empty() => {
1529 Ok(CreateElicitationResult {
1531 action: ElicitationAction::Decline,
1532 content: None,
1533 meta: None,
1534 })
1535 }
1536 Ok(json_str) => {
1537 let v: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
1539 McpError::internal_error(
1540 format!("elicitation handler result parse: {e}"),
1541 None,
1542 )
1543 })?;
1544
1545 let action_str = v
1546 .get("action")
1547 .and_then(serde_json::Value::as_str)
1548 .ok_or_else(|| {
1549 McpError::internal_error(
1550 "elicitation handler result: missing or non-string 'action' field"
1551 .to_string(),
1552 None,
1553 )
1554 })?;
1555
1556 let content = v.get("content").cloned();
1557
1558 match action_str {
1559 "accept" => {
1560 if content.is_none() {
1561 tracing::warn!(
1562 target: "mcp_client",
1563 server = %sn,
1564 "create_elicitation: action=accept but content is nil"
1565 );
1566 return Err(McpError::internal_error(
1567 "elicitation handler: action=accept but content is nil"
1568 .to_string(),
1569 None,
1570 ));
1571 }
1572 Ok(CreateElicitationResult {
1573 action: ElicitationAction::Accept,
1574 content,
1575 meta: None,
1576 })
1577 }
1578 "decline" => {
1579 if content.is_some() {
1580 tracing::warn!(
1581 target: "mcp_client",
1582 server = %sn,
1583 "create_elicitation: action=decline but content is non-nil"
1584 );
1585 return Err(McpError::internal_error(
1586 "elicitation handler: action=decline but content is non-nil"
1587 .to_string(),
1588 None,
1589 ));
1590 }
1591 Ok(CreateElicitationResult {
1592 action: ElicitationAction::Decline,
1593 content: None,
1594 meta: None,
1595 })
1596 }
1597 "cancel" => {
1598 if content.is_some() {
1599 tracing::warn!(
1600 target: "mcp_client",
1601 server = %sn,
1602 "create_elicitation: action=cancel but content is non-nil"
1603 );
1604 return Err(McpError::internal_error(
1605 "elicitation handler: action=cancel but content is non-nil"
1606 .to_string(),
1607 None,
1608 ));
1609 }
1610 Ok(CreateElicitationResult {
1611 action: ElicitationAction::Cancel,
1612 content: None,
1613 meta: None,
1614 })
1615 }
1616 other => {
1617 tracing::warn!(
1618 target: "mcp_client",
1619 server = %sn,
1620 action = %other,
1621 "create_elicitation: unknown action"
1622 );
1623 Err(McpError::internal_error(
1624 format!("elicitation handler: unknown action: {other}"),
1625 None,
1626 ))
1627 }
1628 }
1629 }
1630 }
1631 }
1632 }
1633}
1634
#[cfg(test)]
mod tests {
    use super::*;

    // --- AgentBlockClientHandler construction and registry sharing ---

    #[test]
    fn new_handler_has_empty_registry() {
        let handler = AgentBlockClientHandler::new();
        let guard = handler.registry.lock().unwrap();
        assert!(guard.is_empty());
    }

    #[test]
    fn new_handler_has_no_server_name() {
        let handler = AgentBlockClientHandler::new();
        assert!(handler.server_name.is_none());
    }

    #[test]
    fn server_name_is_preserved_through_clone() {
        let mut handler = AgentBlockClientHandler::new();
        handler.server_name = Some("srv-a".to_string());
        let cloned = handler.clone();
        assert_eq!(cloned.server_name.as_deref(), Some("srv-a"));
    }

    #[test]
    fn ensure_server_creates_entry() {
        let handler = AgentBlockClientHandler::new();
        handler.ensure_server("my-server");
        let guard = handler.registry.lock().unwrap();
        assert!(guard.contains_key("my-server"));
    }

    #[test]
    fn ensure_server_idempotent() {
        let handler = AgentBlockClientHandler::new();
        handler.ensure_server("srv");
        handler.ensure_server("srv");
        let guard = handler.registry.lock().unwrap();
        assert_eq!(guard.len(), 1);
    }

    #[test]
    fn clone_shares_registry() {
        let h1 = AgentBlockClientHandler::new();
        let h2 = h1.clone();
        h1.ensure_server("alpha");
        let guard = h2.registry.lock().unwrap();
        assert!(guard.contains_key("alpha"), "clone must share registry Arc");
    }

    // --- per-server capability flags set by the mark_* helpers ---

    #[test]
    fn mark_on_progress_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_progress("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_progress);
    }

    #[test]
    fn mark_on_log_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_log("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_log);
    }

    #[test]
    fn mark_sampling_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_sampling("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().sampling);
    }

    #[test]
    fn mark_on_resource_updated_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_resource_updated("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_resource_updated);
    }

    #[test]
    fn mark_on_resource_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_resource_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_resource_list_changed);
    }

    #[test]
    fn mark_on_tool_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_tool_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_tool_list_changed);
    }

    #[test]
    fn mark_on_prompt_list_changed_sets_flag() {
        let h = AgentBlockClientHandler::new();
        h.ensure_server("srv");
        h.mark_on_prompt_list_changed("srv");
        let guard = h.registry.lock().unwrap();
        assert!(guard.get("srv").unwrap().on_prompt_list_changed);
    }

    // --- handler-Isle dispatcher installation ---

    // The handler Isle gets the sampling table + dispatcher, but not the
    // progress/log handler tables.
    #[test]
    fn install_dispatcher_creates_sampling_globals() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();

        let _: mlua::Table = lua.globals().get(MCP_SAMPLING_HANDLERS).unwrap();
        let _: mlua::Function = lua.globals().get(MCP_DISPATCH_SAMPLING).unwrap();

        let progress_handlers: mlua::Value = lua.globals().get("__mcp_progress_handlers").unwrap();
        assert!(
            matches!(progress_handlers, mlua::Value::Nil),
            "__mcp_progress_handlers must not be installed on handler Isle"
        );
        let log_handlers: mlua::Value = lua.globals().get("__mcp_log_handlers").unwrap();
        assert!(
            matches!(log_handlers, mlua::Value::Nil),
            "__mcp_log_handlers must not be installed on handler Isle"
        );
    }

    // User notification callback tables must not leak onto the handler Isle.
    #[test]
    fn handler_isle_has_no_user_callback_tables() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();

        let progress_cbs: mlua::Value = lua.globals().get(MCP_USER_PROGRESS_CBS).unwrap();
        assert!(
            matches!(progress_cbs, mlua::Value::Nil),
            "__mcp_user_progress_cbs must not be on handler Isle"
        );
        let log_cbs: mlua::Value = lua.globals().get(MCP_USER_LOG_CBS).unwrap();
        assert!(
            matches!(log_cbs, mlua::Value::Nil),
            "__mcp_user_log_cbs must not be on handler Isle"
        );
    }

    // A progress callback stored in MCP_USER_PROGRESS_CBS keeps its closure
    // upvalue across separate `exec` calls on the same Isle.
    #[tokio::test]
    async fn main_isle_progress_cb_preserves_upvalue() {
        use mlua_isle::AsyncIsle;

        let (isle, driver) = AsyncIsle::spawn(|_lua: &mlua::Lua| Ok(()))
            .await
            .expect("AsyncIsle::spawn should succeed");

        isle.exec(|lua| {
            lua.load(
                r#"
                __mcp_user_progress_cbs = {}
                local hits = 0
                __mcp_user_progress_cbs["test-srv"] = function(ev)
                    hits = hits + 1
                end
                _G.get_hits = function() return hits end
                "#,
            )
            .exec()
            .map_err(|e| mlua_isle::IsleError::Lua(format!("setup: {e}")))?;
            Ok(String::new())
        })
        .await
        .expect("setup exec");

        for _ in 0..3 {
            isle.exec(|lua| {
                use mlua::prelude::*;
                let cbs: LuaTable = lua
                    .globals()
                    .get(MCP_USER_PROGRESS_CBS)
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("get cbs: {e}")))?;
                let cb: LuaFunction = cbs
                    .get("test-srv")
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("get cb: {e}")))?;
                let ev = lua
                    .create_table()
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("create ev: {e}")))?;
                // Propagate callback failures so `expect("dispatch exec")`
                // reports the Lua error instead of a later hit-count mismatch.
                cb.call::<()>(ev)
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("call cb: {e}")))?;
                Ok(String::new())
            })
            .await
            .expect("dispatch exec");
        }

        let hits_str = isle
            .exec(|lua| {
                use mlua::prelude::*;
                let get_hits: LuaFunction = lua
                    .globals()
                    .get("get_hits")
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("get_hits: {e}")))?;
                let n: i64 = get_hits
                    .call(())
                    .map_err(|e| mlua_isle::IsleError::Lua(format!("call get_hits: {e}")))?;
                Ok(n.to_string())
            })
            .await
            .expect("read hits exec");
        let hits: i64 = hits_str.parse().expect("hits must be integer");
        assert_eq!(hits, 3, "upvalue counter must reach 3");

        driver.shutdown().await.expect("shutdown");
    }

    // --- sampling dispatcher behaviour ---

    #[test]
    fn sampling_dispatcher_returns_nil_when_no_handler() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();
        let dispatch: mlua::Function = lua.globals().get(MCP_DISPATCH_SAMPLING).unwrap();
        let result: mlua::Value = dispatch.call(("no-srv", "{}")).unwrap();
        assert!(
            matches!(result, mlua::Value::Nil),
            "expected nil when no handler"
        );
    }

    #[test]
    fn sampling_dispatcher_calls_registered_handler() {
        let lua = mlua::Lua::new();
        install_mcp_dispatcher_on_handler_isle(&lua).unwrap();

        lua.load(
            r#"
            __mcp_sampling_handlers["srv"] = function(sn, params_json)
                return { model = "test-model", stop_reason = "endTurn",
                         role = "assistant", content = "hello" }
            end
            local result = __mcp_dispatch_sampling("srv", "{}")
            assert(type(result) == "table")
            assert(result.model == "test-model")
            assert(result.content == "hello")
            "#,
        )
        .exec()
        .unwrap();
    }
}