1use std::collections::{HashMap, HashSet};
13use std::future::Future;
14use std::sync::{Arc, RwLock};
15use std::time::Duration;
16
17use async_trait::async_trait;
18use bamboo_agent_core::tools::{
19 FunctionCall, ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
20};
21use bamboo_subagent::{AgentRef, InboxKind, InboxMessage, MsgId};
22use chrono::Utc;
23use serde::{Deserialize, Serialize};
24use tokio::sync::Mutex;
25use tokio_util::sync::CancellationToken;
26
27use crate::client::BrokerClient;
28use crate::error::{BrokerError, BrokerResult};
29use crate::mux::MultiplexedClient;
30
/// Delay before the first restart of the orchestrator-side proxy service
/// ([`serve_mcp_proxy_supervised`]) after its session ends or errors.
const PROXY_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_millis(500);
/// Upper bound for the doubling restart delay of the orchestrator-side proxy service.
const PROXY_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
/// A proxy session that stayed up at least this long resets the restart delay back to
/// [`PROXY_RECONNECT_INITIAL_BACKOFF`], so an occasional drop after a long healthy run is
/// retried quickly.
const PROXY_BACKOFF_RESET_AFTER: Duration = Duration::from_secs(10);

/// Delay after the first failed reconnect attempt of a worker-side [`McpProxyExecutor`].
const WORKER_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Upper bound for the doubling delay between worker-side reconnect attempts. Kept small
/// because a tool call is blocked while the worker reconnects.
const WORKER_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(2);
/// Reconnect attempts a worker-side [`McpProxyExecutor`] makes before the pending tool call
/// fails.
const WORKER_RECONNECT_MAX_ATTEMPTS: u32 = 5;
50
/// Body of an [`InboxKind::McpRequest`] inbox message sent from a worker-side
/// [`McpProxyExecutor`] to the orchestrator running [`serve_mcp_proxy`].
///
/// Serialized with an internal `"op"` tag whose value is the snake_case variant name, e.g.
/// `{"op":"manifest"}` or `{"op":"call","tool":"…","arguments":"…"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum McpRequest {
    /// Ask for the tool schemas the requester may use (filtered by its role's allowlist).
    Manifest,
    /// Invoke `tool`; `arguments` is passed through verbatim as the tool call's argument
    /// string (presumably JSON-encoded — TODO confirm against callers).
    Call { tool: String, arguments: String },
}
60
/// Body of an [`InboxKind::McpReply`] message; its `correlation_id` is the id of the
/// request it answers.
///
/// [`handle_mcp_request`] sets exactly one field:
/// - `manifest` for a `Manifest` request;
/// - `result` for a `Call` that reached the backend successfully;
/// - `error` otherwise.
///
/// Unset fields are omitted on the wire and decode as `None` when absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct McpReply {
    /// Tool schemas visible to the requester (answer to [`McpRequest::Manifest`]).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub manifest: Option<Vec<ToolSchema>>,
    /// Outcome of a forwarded [`McpRequest::Call`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub result: Option<ProxiedResult>,
    /// Human-readable failure: undecodable request, allowlist rejection, or backend error.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
74
/// Wire form of a backend [`ToolResult`]: only the success flag and result text cross the
/// broker. Display preference and images are not forwarded; the worker side fills them
/// with defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxiedResult {
    /// The backend tool's own success flag.
    pub success: bool,
    /// The backend tool's textual result.
    pub result: String,
}
81
82#[derive(Debug, Clone, Default)]
109pub struct RoleToolAllowlist {
110 by_role: HashMap<String, HashSet<String>>,
113}
114
115impl RoleToolAllowlist {
116 pub fn unrestricted() -> Self {
118 Self::default()
119 }
120
121 pub fn from_entries<R, T, I>(entries: I) -> Self
125 where
126 R: Into<String>,
127 T: Into<String>,
128 I: IntoIterator<Item = (R, Vec<T>)>,
129 {
130 let by_role = entries
131 .into_iter()
132 .map(|(role, tools)| (role.into(), tools.into_iter().map(Into::into).collect()))
133 .collect();
134 Self { by_role }
135 }
136
137 pub fn with_role(
139 mut self,
140 role: impl Into<String>,
141 tools: impl IntoIterator<Item = impl Into<String>>,
142 ) -> Self {
143 self.by_role
144 .insert(role.into(), tools.into_iter().map(Into::into).collect());
145 self
146 }
147
148 fn is_restricted(&self, role: Option<&str>) -> bool {
151 role.is_some_and(|r| self.by_role.contains_key(r))
152 }
153
154 fn allows(&self, role: Option<&str>, tool: &str) -> bool {
157 match role.and_then(|r| self.by_role.get(r)) {
158 Some(allowed) => allowed.contains(tool),
159 None => true, }
161 }
162
163 fn filter_manifest(&self, role: Option<&str>, mut tools: Vec<ToolSchema>) -> Vec<ToolSchema> {
166 if let Some(allowed) = role.and_then(|r| self.by_role.get(r)) {
167 tools.retain(|t| allowed.contains(&t.function.name));
168 }
169 tools
170 }
171}
172
173pub async fn serve_mcp_proxy(
179 endpoint: &str,
180 me: AgentRef,
181 token: &str,
182 backend: Arc<dyn ToolExecutor>,
183 allowlist: Arc<RoleToolAllowlist>,
184) -> BrokerResult<()> {
185 let mut client = BrokerClient::connect(endpoint, me.clone(), token).await?;
186 client.subscribe().await?;
187
188 let (reply_tx, mut reply_rx) =
203 tokio::sync::mpsc::unbounded_channel::<(MsgId, String, McpReply)>();
204 loop {
205 tokio::select! {
206 Some((corr, reply_to, reply_body)) = reply_rx.recv() => {
209 let reply = InboxMessage {
210 id: MsgId::new(),
211 from: me.clone(),
212 kind: InboxKind::McpReply,
213 body: serde_json::to_value(reply_body).unwrap_or_default(),
214 created_at: Utc::now(),
215 correlation_id: Some(corr.clone()),
216 };
217 client.deliver(&reply_to, reply).await?;
218 client.ack(corr).await?;
219 }
220 msg = client.next_message() => {
221 let Some(msg) = msg else { break }; if msg.kind != InboxKind::McpRequest {
223 let _ = client.ack(msg.id).await;
224 continue;
225 }
226 let backend = Arc::clone(&backend);
227 let allowlist = Arc::clone(&allowlist);
228 let reply_tx = reply_tx.clone();
229 let corr = msg.id.clone();
230 let reply_to = msg.from.session_id.clone();
231 tokio::spawn(async move {
232 let reply_body = handle_mcp_request(backend.as_ref(), &allowlist, msg).await;
233 let _ = reply_tx.send((corr, reply_to, reply_body));
235 });
236 }
237 }
238 }
239 Ok(())
240}
241
242pub async fn serve_mcp_proxy_supervised(
258 endpoint: &str,
259 me: AgentRef,
260 token: &str,
261 backend: Arc<dyn ToolExecutor>,
262 allowlist: Arc<RoleToolAllowlist>,
263 shutdown: CancellationToken,
264) {
265 supervise_reconnect(
266 || {
267 serve_mcp_proxy(
268 endpoint,
269 me.clone(),
270 token,
271 backend.clone(),
272 allowlist.clone(),
273 )
274 },
275 shutdown,
276 PROXY_RECONNECT_INITIAL_BACKOFF,
277 PROXY_RECONNECT_MAX_BACKOFF,
278 PROXY_BACKOFF_RESET_AFTER,
279 )
280 .await
281}
282
283async fn supervise_reconnect<F, Fut>(
288 mut serve_once: F,
289 shutdown: CancellationToken,
290 initial_backoff: Duration,
291 max_backoff: Duration,
292 reset_after: Duration,
293) where
294 F: FnMut() -> Fut + Send,
295 Fut: Future<Output = BrokerResult<()>> + Send,
296{
297 let mut backoff = initial_backoff;
298 loop {
299 let started = std::time::Instant::now();
302 let outcome = tokio::select! {
303 biased;
304 _ = shutdown.cancelled() => {
305 tracing::info!("MCP proxy supervisor: shutdown requested, stopping");
306 return;
307 }
308 r = serve_once() => r,
309 };
310
311 if started.elapsed() >= reset_after {
313 backoff = initial_backoff;
314 }
315
316 match outcome {
317 Ok(()) => tracing::warn!(
318 "MCP proxy connection ended; restarting (backoff {:?})",
319 backoff
320 ),
321 Err(e) => tracing::warn!(
322 "MCP proxy service errored: {e}; restarting (backoff {:?})",
323 backoff
324 ),
325 }
326
327 let slept = tokio::select! {
329 biased;
330 _ = shutdown.cancelled() => false,
331 _ = tokio::time::sleep(backoff) => true,
332 };
333 if !slept {
334 tracing::info!("MCP proxy supervisor: shutdown during backoff, stopping");
335 return;
336 }
337 backoff = std::cmp::min(backoff * 2, max_backoff);
338 }
339}
340
341async fn handle_mcp_request(
342 backend: &dyn ToolExecutor,
343 allowlist: &RoleToolAllowlist,
344 msg: InboxMessage,
345) -> McpReply {
346 let role = msg.from.role.as_deref();
349 match serde_json::from_value::<McpRequest>(msg.body) {
350 Ok(McpRequest::Manifest) => {
351 let tools = allowlist.filter_manifest(role, backend.list_tools());
352 if allowlist.is_restricted(role) {
353 tracing::debug!(
354 role = role.unwrap_or("<none>"),
355 tools = tools.len(),
356 "mcp proxy: serving role-scoped manifest"
357 );
358 }
359 McpReply {
360 manifest: Some(tools),
361 ..Default::default()
362 }
363 }
364 Ok(McpRequest::Call { tool, arguments }) => {
365 if !allowlist.allows(role, &tool) {
368 tracing::warn!(
369 role = role.unwrap_or("<none>"),
370 tool = %tool,
371 "mcp proxy: rejecting tool call not on role allowlist"
372 );
373 return McpReply {
374 error: Some(format!(
375 "tool '{tool}' is not allowed for role '{}'",
376 role.unwrap_or("<none>")
377 )),
378 ..Default::default()
379 };
380 }
381 let call = ToolCall {
382 id: format!("mcp-{}", MsgId::new().as_str()),
383 tool_type: "function".to_string(),
384 function: FunctionCall {
385 name: tool,
386 arguments,
387 },
388 };
389 match backend.execute(&call).await {
390 Ok(r) => McpReply {
391 result: Some(ProxiedResult {
392 success: r.success,
393 result: r.result,
394 }),
395 ..Default::default()
396 },
397 Err(e) => McpReply {
398 error: Some(e.to_string()),
399 ..Default::default()
400 },
401 }
402 }
403 Err(e) => McpReply {
404 error: Some(format!("bad mcp request: {e}")),
405 ..Default::default()
406 },
407 }
408}
409
410pub struct McpProxyExecutor {
415 client: tokio::sync::RwLock<Arc<MultiplexedClient>>,
421 reconnect_lock: Mutex<()>,
426 me: AgentRef,
427 endpoint: String,
428 token: String,
429 orchestrator: String,
430 manifest: RwLock<Vec<ToolSchema>>,
434 timeout: Duration,
435}
436
437impl McpProxyExecutor {
438 pub async fn connect(
442 endpoint: &str,
443 proxy_id: impl Into<String>,
444 token: &str,
445 orchestrator: impl Into<String>,
446 timeout: Duration,
447 ) -> BrokerResult<Self> {
448 let me = AgentRef {
449 session_id: proxy_id.into(),
450 role: None,
451 };
452 let orchestrator = orchestrator.into();
453 let mut client = BrokerClient::connect(endpoint, me.clone(), token).await?;
454 client.subscribe().await?;
455 let mux = client.into_multiplexed(me.clone());
456
457 let reply = mux
458 .request(
459 &orchestrator,
460 InboxKind::McpRequest,
461 serde_json::to_value(McpRequest::Manifest).expect("McpRequest serializes"),
462 timeout,
463 )
464 .await?;
465 let reply: McpReply = serde_json::from_value(reply)
466 .map_err(|e| BrokerError::Protocol(format!("bad manifest reply: {e}")))?;
467 let manifest = reply.manifest.unwrap_or_default();
468
469 Ok(Self {
470 client: tokio::sync::RwLock::new(Arc::new(mux)),
471 reconnect_lock: Mutex::new(()),
472 me,
473 endpoint: endpoint.to_string(),
474 token: token.to_string(),
475 orchestrator,
476 manifest: RwLock::new(manifest),
477 timeout,
478 })
479 }
480
481 pub fn tool_count(&self) -> usize {
483 self.manifest.read().map(|m| m.len()).unwrap_or(0)
484 }
485
486 async fn request_once(&self, body: serde_json::Value) -> BrokerResult<serde_json::Value> {
489 let mux = self.client.read().await.clone();
492 mux.request(
493 &self.orchestrator,
494 InboxKind::McpRequest,
495 body,
496 self.timeout,
497 )
498 .await
499 }
500
501 async fn request_with_reconnect(
507 &self,
508 body: serde_json::Value,
509 ) -> Result<serde_json::Value, ToolError> {
510 match self.request_once(body.clone()).await {
511 Ok(v) => Ok(v),
512 Err(first) => {
513 if !self.connection_broken().await {
516 return Err(ToolError::Execution(format!("mcp proxy: {first}")));
517 }
518 tracing::warn!("mcp proxy connection dropped; reconnecting: {first}");
519 self.reconnect_if_needed()
520 .await
521 .map_err(|re| ToolError::Execution(format!("mcp proxy (reconnect): {re}")))?;
522 self.request_once(body)
524 .await
525 .map_err(|re| ToolError::Execution(format!("mcp proxy: {re}")))
526 }
527 }
528 }
529
530 async fn connection_broken(&self) -> bool {
534 !self.client.read().await.reader_alive()
535 }
536
537 async fn reconnect_if_needed(&self) -> BrokerResult<()> {
545 let _guard = self.reconnect_lock.lock().await;
546 if !self.connection_broken().await {
549 return Ok(());
550 }
551 let mut backoff = WORKER_RECONNECT_INITIAL_BACKOFF;
552 for _ in 0..WORKER_RECONNECT_MAX_ATTEMPTS {
553 match self.reconnect_once().await {
554 Ok(()) => return Ok(()),
555 Err(e) => {
556 tracing::warn!("mcp proxy reconnect failed (backoff {:?}): {e}", backoff);
557 }
558 }
559 tokio::time::sleep(backoff).await;
563 backoff = std::cmp::min(backoff * 2, WORKER_RECONNECT_MAX_BACKOFF);
564 }
565 Err(BrokerError::Transport(
566 "mcp proxy reconnect attempts exhausted".into(),
567 ))
568 }
569
570 async fn reconnect_once(&self) -> BrokerResult<()> {
574 let mut client =
575 BrokerClient::connect(&self.endpoint, self.me.clone(), &self.token).await?;
576 client.subscribe().await?;
577 let mux = client.into_multiplexed(self.me.clone());
578 let reply = mux
581 .request(
582 &self.orchestrator,
583 InboxKind::McpRequest,
584 serde_json::to_value(McpRequest::Manifest).expect("McpRequest serializes"),
585 self.timeout,
586 )
587 .await?;
588 let reply: McpReply = serde_json::from_value(reply)
589 .map_err(|e| BrokerError::Protocol(format!("bad manifest reply: {e}")))?;
590 let manifest = reply.manifest.unwrap_or_default();
591 {
592 let mut slot = self.client.write().await;
596 *slot = Arc::new(mux);
597 }
598 if let Ok(mut m) = self.manifest.write() {
599 *m = manifest;
600 }
601 Ok(())
602 }
603}
604
605#[async_trait]
606impl ToolExecutor for McpProxyExecutor {
607 async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
608 {
609 let manifest = self
610 .manifest
611 .read()
612 .map_err(|_| ToolError::Execution("mcp proxy manifest lock poisoned".into()))?;
613 if !manifest
614 .iter()
615 .any(|s| s.function.name == call.function.name)
616 {
617 return Err(ToolError::NotFound(call.function.name.clone()));
618 }
619 }
620 let body = serde_json::to_value(McpRequest::Call {
621 tool: call.function.name.clone(),
622 arguments: call.function.arguments.clone(),
623 })
624 .expect("McpRequest serializes");
625
626 let reply = self.request_with_reconnect(body).await?;
630
631 let reply: McpReply = serde_json::from_value(reply)
632 .map_err(|e| ToolError::Execution(format!("bad mcp reply: {e}")))?;
633 if let Some(err) = reply.error {
634 return Err(ToolError::Execution(err));
635 }
636 let r = reply
637 .result
638 .ok_or_else(|| ToolError::Execution("mcp reply missing result".to_string()))?;
639 Ok(ToolResult {
640 success: r.success,
641 result: r.result,
642 display_preference: None,
643 images: Vec::new(),
644 })
645 }
646
647 async fn execute_with_context(
648 &self,
649 call: &ToolCall,
650 _ctx: ToolExecutionContext<'_>,
651 ) -> Result<ToolResult, ToolError> {
652 self.execute(call).await
653 }
654
655 fn list_tools(&self) -> Vec<ToolSchema> {
656 self.manifest.read().map(|m| m.clone()).unwrap_or_default()
657 }
658}
659
#[cfg(test)]
mod tests {
    //! Three groups of tests:
    //! - unit tests of the role allowlist via `handle_mcp_request`;
    //! - end-to-end tests that run `serve_mcp_proxy` and `McpProxyExecutor` over a real
    //!   in-process broker;
    //! - a reconnect test against a scripted WebSocket server that drops its first
    //!   connection.

    use super::*;
    use crate::core::BrokerCore;
    use crate::server::BrokerServer;
    use bamboo_agent_core::tools::FunctionSchema;
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use tokio::net::TcpListener;

    /// Broker auth token shared by every test server and client.
    const TOKEN: &str = "t";

    /// Backend with a single `nova_click` tool; every call echoes its name and arguments.
    struct StubMcp;

    #[async_trait]
    impl ToolExecutor for StubMcp {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: true,
                result: format!(
                    "ran {} args={}",
                    call.function.name, call.function.arguments
                ),
                display_preference: None,
                images: Vec::new(),
            })
        }
        async fn execute_with_context(
            &self,
            call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.execute(call).await
        }
        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![ToolSchema {
                schema_type: "function".into(),
                function: FunctionSchema {
                    name: "nova_click".into(),
                    description: "click a mark".into(),
                    parameters: json!({ "type": "object" }),
                },
            }]
        }
    }

    /// Backend with three tools, used to observe allowlist filtering; every call returns
    /// `"ran <tool>"`.
    struct MultiToolMcp;

    #[async_trait]
    impl ToolExecutor for MultiToolMcp {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: true,
                result: format!("ran {}", call.function.name),
                display_preference: None,
                images: Vec::new(),
            })
        }
        async fn execute_with_context(
            &self,
            call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.execute(call).await
        }
        fn list_tools(&self) -> Vec<ToolSchema> {
            ["nova_screenshot", "nova_click", "fetch_url"]
                .into_iter()
                .map(|name| ToolSchema {
                    schema_type: "function".into(),
                    function: FunctionSchema {
                        name: name.into(),
                        description: "t".into(),
                        parameters: json!({ "type": "object" }),
                    },
                })
                .collect()
        }
    }

    /// Builds an `McpRequest` inbox message as if sent by session `worker#mcp` with `role`.
    fn worker_request(role: Option<&str>, req: McpRequest) -> InboxMessage {
        InboxMessage {
            id: MsgId::new(),
            from: AgentRef {
                session_id: "worker#mcp".into(),
                role: role.map(Into::into),
            },
            kind: InboxKind::McpRequest,
            body: serde_json::to_value(req).unwrap(),
            created_at: Utc::now(),
            correlation_id: None,
        }
    }

    /// Tool names in a manifest reply, in order. Panics if the reply carries no manifest.
    fn manifest_names(reply: &McpReply) -> Vec<String> {
        reply
            .manifest
            .as_ref()
            .expect("manifest reply")
            .iter()
            .map(|t| t.function.name.clone())
            .collect()
    }

    /// Checks manifest filtering by role:
    /// - a restricted role sees only its allowed tools;
    /// - a role without an entry sees every tool;
    /// - a request with no role sees every tool.
    #[tokio::test]
    async fn manifest_is_filtered_by_role_allowlist() {
        let backend = MultiToolMcp;
        let allowlist = RoleToolAllowlist::unrestricted().with_role("researcher", ["fetch_url"]);

        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(Some("researcher"), McpRequest::Manifest),
        )
        .await;
        assert_eq!(manifest_names(&reply), vec!["fetch_url".to_string()]);

        // "operator" has no entry, so it is unrestricted.
        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(Some("operator"), McpRequest::Manifest),
        )
        .await;
        assert_eq!(manifest_names(&reply).len(), 3);

        // No role at all is also unrestricted.
        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(None, McpRequest::Manifest),
        )
        .await;
        assert_eq!(manifest_names(&reply).len(), 3);
    }

    /// A restricted role's call to a non-allowed tool gets an error reply naming the tool.
    /// An allowed tool, or a request with no role, is executed normally.
    #[tokio::test]
    async fn call_is_rejected_when_tool_not_on_role_allowlist() {
        let backend = MultiToolMcp;
        let allowlist = RoleToolAllowlist::unrestricted().with_role("researcher", ["fetch_url"]);

        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(
                Some("researcher"),
                McpRequest::Call {
                    tool: "nova_screenshot".into(),
                    arguments: "{}".into(),
                },
            ),
        )
        .await;
        assert!(reply.result.is_none());
        let err = reply.error.expect("a rejection error");
        assert!(
            err.contains("nova_screenshot") && err.contains("not allowed"),
            "{err}"
        );

        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(
                Some("researcher"),
                McpRequest::Call {
                    tool: "fetch_url".into(),
                    arguments: "{}".into(),
                },
            ),
        )
        .await;
        assert!(reply.error.is_none());
        assert_eq!(reply.result.expect("result").result, "ran fetch_url");

        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(
                None,
                McpRequest::Call {
                    tool: "nova_screenshot".into(),
                    arguments: "{}".into(),
                },
            ),
        )
        .await;
        assert!(reply.error.is_none());
        assert_eq!(reply.result.expect("result").result, "ran nova_screenshot");
    }

    /// A role mapped to an empty tool list sees an empty manifest (lockout), not the full
    /// one.
    #[tokio::test]
    async fn empty_allowlist_entry_is_explicit_lockout() {
        let backend = MultiToolMcp;
        let allowlist = RoleToolAllowlist::from_entries(vec![("sandbox", Vec::<String>::new())]);
        let reply = handle_mcp_request(
            &backend,
            &allowlist,
            worker_request(Some("sandbox"), McpRequest::Manifest),
        )
        .await;
        assert!(manifest_names(&reply).is_empty());
    }

    /// End to end over a real broker:
    /// - the worker-side proxy fetches the orchestrator's manifest;
    /// - a call to a listed tool is forwarded and its result returned;
    /// - an unlisted tool fails locally with `NotFound`.
    #[tokio::test]
    async fn proxy_lists_and_forwards_calls_over_the_broker() {
        let dir = tempfile::tempdir().unwrap();
        let core = Arc::new(BrokerCore::new(dir.path()));
        let server = Arc::new(BrokerServer::new(core, TOKEN));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = server.serve(listener).await;
        });
        let endpoint = format!("ws://{addr}");

        // Orchestrator side: serve the stub backend under session id "orchestrator".
        let ep = endpoint.clone();
        tokio::spawn(async move {
            let _ = serve_mcp_proxy(
                &ep,
                AgentRef {
                    session_id: "orchestrator".into(),
                    role: None,
                },
                TOKEN,
                Arc::new(StubMcp),
                Arc::new(RoleToolAllowlist::unrestricted()),
            )
            .await;
        });

        let proxy = McpProxyExecutor::connect(
            &endpoint,
            "worker#mcp",
            TOKEN,
            "orchestrator",
            Duration::from_secs(5),
        )
        .await
        .expect("proxy connects + fetches manifest");
        let tools = proxy.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "nova_click");

        let call = ToolCall {
            id: "c1".into(),
            tool_type: "function".into(),
            function: FunctionCall {
                name: "nova_click".into(),
                arguments: "{\"mark\":7}".into(),
            },
        };
        let result = proxy.execute(&call).await.expect("proxied call returns");
        assert!(result.success);
        assert_eq!(result.result, "ran nova_click args={\"mark\":7}");

        // Not in the manifest: rejected locally without a round trip.
        let miss = ToolCall {
            id: "c2".into(),
            tool_type: "function".into(),
            function: FunctionCall {
                name: "not_proxied".into(),
                arguments: "{}".into(),
            },
        };
        assert!(matches!(
            proxy.execute(&miss).await,
            Err(ToolError::NotFound(_))
        ));
    }

    /// Eight calls in flight on one shared proxy must each get back the reply for their
    /// own arguments, i.e. replies are not crossed between concurrent requests.
    #[tokio::test]
    async fn proxy_handles_concurrent_calls_correctly() {
        let dir = tempfile::tempdir().unwrap();
        let core = Arc::new(BrokerCore::new(dir.path()));
        let server = Arc::new(BrokerServer::new(core, TOKEN));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = server.serve(listener).await;
        });
        let endpoint = format!("ws://{addr}");

        let ep = endpoint.clone();
        tokio::spawn(async move {
            let _ = serve_mcp_proxy(
                &ep,
                AgentRef {
                    session_id: "orchestrator".into(),
                    role: None,
                },
                TOKEN,
                Arc::new(StubMcp),
                Arc::new(RoleToolAllowlist::unrestricted()),
            )
            .await;
        });

        let proxy = Arc::new(
            McpProxyExecutor::connect(
                &endpoint,
                "worker#mcp",
                TOKEN,
                "orchestrator",
                Duration::from_secs(5),
            )
            .await
            .expect("proxy connects"),
        );

        // Each task's arguments encode its index, so a crossed reply shows up as a mismatch.
        let mut handles = Vec::new();
        for i in 0..8u32 {
            let p = proxy.clone();
            handles.push(tokio::spawn(async move {
                let call = ToolCall {
                    id: format!("c{i}"),
                    tool_type: "function".into(),
                    function: FunctionCall {
                        name: "nova_click".into(),
                        arguments: format!("{{\"mark\":{i}}}"),
                    },
                };
                let r = p.execute(&call).await.expect("proxied call returns");
                (i, r.result)
            }));
        }
        for h in handles {
            let (i, result) = h.await.unwrap();
            assert_eq!(result, format!("ran nova_click args={{\"mark\":{i}}}"));
        }
    }

    /// Four concurrent calls to a 200 ms tool must finish well under the ~800 ms a serial
    /// orchestrator would need. This proves `serve_mcp_proxy` handles requests concurrently.
    #[tokio::test]
    async fn concurrent_proxy_calls_overlap_at_the_orchestrator() {
        use std::time::Instant;

        /// Backend whose only tool sleeps 200 ms before answering.
        struct SlowMcp;

        #[async_trait]
        impl ToolExecutor for SlowMcp {
            async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(ToolResult {
                    success: true,
                    result: format!("done {}", call.function.arguments),
                    display_preference: None,
                    images: Vec::new(),
                })
            }
            async fn execute_with_context(
                &self,
                call: &ToolCall,
                _ctx: ToolExecutionContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                self.execute(call).await
            }
            fn list_tools(&self) -> Vec<ToolSchema> {
                vec![ToolSchema {
                    schema_type: "function".into(),
                    function: FunctionSchema {
                        name: "nova_click".into(),
                        description: "click a mark".into(),
                        parameters: json!({ "type": "object" }),
                    },
                }]
            }
        }

        let dir = tempfile::tempdir().unwrap();
        let core = Arc::new(BrokerCore::new(dir.path()));
        let server = Arc::new(BrokerServer::new(core, TOKEN));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = server.serve(listener).await;
        });
        let endpoint = format!("ws://{addr}");

        let ep = endpoint.clone();
        tokio::spawn(async move {
            let _ = serve_mcp_proxy(
                &ep,
                AgentRef {
                    session_id: "orchestrator".into(),
                    role: None,
                },
                TOKEN,
                Arc::new(SlowMcp),
                Arc::new(RoleToolAllowlist::unrestricted()),
            )
            .await;
        });

        let proxy = Arc::new(
            McpProxyExecutor::connect(
                &endpoint,
                "worker#mcp",
                TOKEN,
                "orchestrator",
                Duration::from_secs(5),
            )
            .await
            .expect("proxy connects"),
        );

        // Timing starts after connect so only the four proxied calls are measured.
        let start = Instant::now();
        let mut handles = Vec::new();
        for i in 0..4u32 {
            let p = proxy.clone();
            handles.push(tokio::spawn(async move {
                let call = ToolCall {
                    id: format!("c{i}"),
                    tool_type: "function".into(),
                    function: FunctionCall {
                        name: "nova_click".into(),
                        arguments: format!("{{\"i\":{i}}}"),
                    },
                };
                p.execute(&call).await.expect("returns")
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "4 concurrent 200ms proxy calls must OVERLAP at the orchestrator \
             (serial would be ~800ms); took {elapsed:?}"
        );
    }

    /// Checks the supervisor's restart and shutdown behaviour with a fake `serve_once`:
    /// - the first three runs return `Ok(())` immediately (a dropped connection) and must be
    ///   restarted with the tiny test backoff;
    /// - the fourth run never returns;
    /// - cancelling the token must then stop the supervisor promptly, even mid-run.
    #[tokio::test]
    async fn supervisor_restarts_on_drop_and_stops_on_shutdown() {
        let shutdown = CancellationToken::new();
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_serve = calls.clone();
        let serve = move || {
            // Count every run; runs 0..3 "drop" immediately, later runs stay connected.
            let c = calls_for_serve.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 3 {
                    Ok(())
                } else {
                    std::future::pending::<BrokerResult<()>>().await
                }
            }
        };

        let started = std::time::Instant::now();
        let task = tokio::spawn(supervise_reconnect(
            serve,
            shutdown.clone(),
            Duration::from_millis(2),
            Duration::from_millis(8),
            Duration::from_secs(60), // reset_after: far longer than any run in this test
        ));

        // Wait (bounded) until the fourth run has started.
        tokio::time::timeout(Duration::from_secs(3), async {
            while calls.load(Ordering::SeqCst) < 4 {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("supervisor restarted within bounded backoff");
        assert!(calls.load(Ordering::SeqCst) >= 4);
        assert!(
            started.elapsed() < Duration::from_secs(3),
            "restarts were bounded-fast, took {:?}",
            started.elapsed()
        );

        shutdown.cancel();
        // The fourth run is pending forever, so this only completes if shutdown preempts it.
        tokio::time::timeout(Duration::from_secs(3), task)
            .await
            .expect("supervisor stops promptly on shutdown")
            .expect("supervisor task did not panic");
    }

    /// Fake orchestrator: builds the `McpReply` inbox message for `req_msg`. Replies are
    /// sent `from` `orch` and correlated to the request id. Mirrors `StubMcp`'s manifest and
    /// echo format.
    fn answer_mcp_request(req_msg: InboxMessage, orch: &AgentRef) -> InboxMessage {
        let reply_body = match serde_json::from_value::<McpRequest>(req_msg.body) {
            Ok(McpRequest::Manifest) => McpReply {
                manifest: Some(vec![ToolSchema {
                    schema_type: "function".into(),
                    function: FunctionSchema {
                        name: "nova_click".into(),
                        description: "click a mark".into(),
                        parameters: json!({ "type": "object" }),
                    },
                }]),
                ..Default::default()
            },
            Ok(McpRequest::Call { tool, arguments }) => McpReply {
                result: Some(ProxiedResult {
                    success: true,
                    result: format!("ran {tool} args={arguments}"),
                }),
                ..Default::default()
            },
            Err(_) => McpReply {
                error: Some("bad mcp request".into()),
                ..Default::default()
            },
        };
        InboxMessage {
            id: MsgId::new(),
            from: orch.clone(),
            kind: InboxKind::McpReply,
            body: serde_json::to_value(reply_body).unwrap_or_default(),
            created_at: Utc::now(),
            correlation_id: Some(req_msg.id),
        }
    }

    /// Scripted WebSocket "broker" used to exercise the worker's reconnect path.
    ///
    /// Per connection:
    /// - waits for the first parseable client frame and answers `Welcome`;
    /// - answers every `Deliver` with `Delivered` plus the fake orchestrator's reply.
    ///
    /// The first accepted connection is closed after its second delivery (manifest + one
    /// call); later connections stay up.
    ///
    /// Returns the `ws://` endpoint and a counter of accepted connections.
    async fn flaky_mcp_broker() -> (String, Arc<AtomicU32>) {
        use futures_util::{SinkExt, StreamExt};
        use tokio_tungstenite::accept_async;
        use tokio_tungstenite::tungstenite::Message;

        use crate::proto::{BrokerFrame, ClientFrame};

        let orch = AgentRef {
            session_id: "orchestrator".into(),
            role: None,
        };
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let first_taken = Arc::new(AtomicBool::new(false));
        let conns = Arc::new(AtomicU32::new(0));
        let conns_for_loop = conns.clone();
        tokio::spawn(async move {
            loop {
                let (stream, _) = match listener.accept().await {
                    Ok(s) => s,
                    Err(_) => break,
                };
                // Only the very first connection is the one that gets dropped.
                let is_first = !first_taken.swap(true, Ordering::SeqCst);
                conns_for_loop.fetch_add(1, Ordering::SeqCst);
                let orch = orch.clone();
                tokio::spawn(async move {
                    let ws = match accept_async(stream).await {
                        Ok(ws) => ws,
                        Err(_) => return,
                    };
                    let (mut sink, mut source) = ws.split();

                    // Handshake: the first parseable client frame gets a Welcome.
                    while let Some(Ok(msg)) = source.next().await {
                        if let Message::Text(t) = msg {
                            if ClientFrame::from_text(&t).is_ok() {
                                let _ = sink
                                    .send(Message::Text(BrokerFrame::Welcome.to_text()))
                                    .await;
                                break;
                            }
                        }
                    }

                    // Request loop: ack each Deliver, then push the orchestrator's reply.
                    // Every other frame (e.g. subscribe/ack) is ignored.
                    let mut delivered = 0u32;
                    while let Some(Ok(msg)) = source.next().await {
                        let message = match msg {
                            Message::Text(t) => match ClientFrame::from_text(&t) {
                                Ok(ClientFrame::Deliver { message, .. }) => message,
                                _ => continue,
                            },
                            _ => continue,
                        };
                        let _ = sink
                            .send(Message::Text(
                                BrokerFrame::Delivered {
                                    id: message.id.clone(),
                                }
                                .to_text(),
                            ))
                            .await;
                        let reply = answer_mcp_request(message, &orch);
                        let _ = sink
                            .send(Message::Text(
                                BrokerFrame::Message { message: reply }.to_text(),
                            ))
                            .await;
                        if is_first {
                            delivered += 1;
                            if delivered == 2 {
                                // Simulate a transient drop after manifest + first call.
                                let _ = sink.send(Message::Close(None)).await;
                                break;
                            }
                        }
                    }
                });
            }
        });
        (format!("ws://{addr}"), conns)
    }

    /// Checks the worker's recovery from a transient drop:
    /// - after the broker drops the first connection, the next call must reconnect and
    ///   succeed rather than hang or fail permanently;
    /// - a later call must work on the new connection.
    #[tokio::test]
    async fn proxy_executor_reconnects_after_transient_drop() {
        let (endpoint, conns) = flaky_mcp_broker().await;

        let proxy = McpProxyExecutor::connect(
            &endpoint,
            "worker#mcp",
            TOKEN,
            "orchestrator",
            Duration::from_secs(5),
        )
        .await
        .expect("proxy connects + fetches manifest");
        assert_eq!(proxy.list_tools().len(), 1);

        let call = |n: usize| ToolCall {
            id: format!("c{n}"),
            tool_type: "function".into(),
            function: FunctionCall {
                name: "nova_click".into(),
                arguments: "{}".into(),
            },
        };

        // Second delivery on connection 1; the server closes that connection right after.
        let r1 = proxy.execute(&call(1)).await.expect("call1 on conn1");
        assert!(r1.success);

        // Fails on the dead connection, then reconnects and is re-sent once.
        let r2 = tokio::time::timeout(Duration::from_secs(15), proxy.execute(&call(2)))
            .await
            .expect("call2 did not hang")
            .expect("call2 succeeds after reconnect (not a permanent error)");
        assert!(r2.success);
        assert_eq!(r2.result, "ran nova_click args={}");

        assert!(
            conns.load(Ordering::SeqCst) >= 2,
            "worker reconnected (>=2 connections accepted), got {}",
            conns.load(Ordering::SeqCst)
        );

        // The replacement connection keeps working for subsequent calls.
        let r3 = proxy.execute(&call(3)).await.expect("call3 on conn2");
        assert!(r3.success);
    }
}