use std::collections::HashMap;
use std::sync::Arc;
use axum::{
Router,
extract::{
State, WebSocketUpgrade,
ws::{Message, WebSocket},
},
response::Response,
routing::get,
};
use futures::{SinkExt, StreamExt};
use tokio::sync::{Mutex, RwLock, watch};
use crate::context::{
ChannelClientRequester, ClientRequesterHandle, OutgoingRequest, OutgoingRequestReceiver,
OutgoingRequestSender, outgoing_request_channel,
};
use crate::error::{Error, JsonRpcError, Result};
use crate::jsonrpc::JsonRpcService;
use crate::protocol::{
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification,
RequestId,
};
use crate::router::{McpRouter, RouterRequest, RouterResponse};
use crate::transport::service::{
CatchError, InjectAnnotations, McpBoxService, ServiceFactory, identity_factory,
};
/// Per-connection state registered in the [`SessionStore`].
struct Session {
    /// Random UUID v4 string assigned when the session is created.
    id: String,
    /// Router for this session; cloned whenever a service stack is built.
    router: McpRouter,
    /// Builds the boxed service stack from a router.
    service_factory: ServiceFactory,
    /// Sender for the current connection's cancellation flag. Sending `true`
    /// tells that connection's socket loop to close; the sender is swapped for
    /// a fresh one by `replace_connection`.
    cancel_tx: Mutex<watch::Sender<bool>>,
}
impl Session {
    /// Creates a session with a random UUID v4 id and a not-cancelled
    /// cancellation channel.
    fn new(router: McpRouter, service_factory: ServiceFactory) -> Self {
        // Only the sender is kept; receivers are handed out via `subscribe`.
        let cancel_tx = Mutex::new(watch::channel(false).0);
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            router,
            service_factory,
            cancel_tx,
        }
    }

    /// Builds this session's service stack from a clone of its router.
    fn make_service(&self) -> McpBoxService {
        let router = self.router.clone();
        (self.service_factory)(router)
    }

    /// Returns a new receiver subscribed to the current connection's
    /// cancellation channel.
    async fn cancel_receiver(&self) -> watch::Receiver<bool> {
        let guard = self.cancel_tx.lock().await;
        guard.subscribe()
    }

    /// Signals `true` to the current connection's receivers, installs a fresh
    /// channel, and returns a receiver for it that starts as `false`.
    async fn replace_connection(&self) -> watch::Receiver<bool> {
        let mut current = self.cancel_tx.lock().await;
        // Ignored on purpose: failure only means nobody is listening anymore.
        let _ = current.send(true);
        let (replacement, receiver) = watch::channel(false);
        *current = replacement;
        receiver
    }
}
impl std::fmt::Debug for Session {
    /// Shows the identifying fields only; the factory closure and the
    /// cancellation plumbing are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Session");
        builder.field("id", &self.id);
        builder.field("router", &self.router);
        builder.finish_non_exhaustive()
    }
}
/// Registry of sessions for currently open WebSocket connections, keyed by
/// session id.
#[derive(Debug, Default)]
struct SessionStore {
    sessions: RwLock<HashMap<String, Arc<Session>>>,
}
impl SessionStore {
fn new() -> Self {
Self::default()
}
async fn create(
&self,
router: McpRouter,
service_factory: ServiceFactory,
) -> (Arc<Session>, watch::Receiver<bool>) {
let session = Arc::new(Session::new(router, service_factory));
let cancel_rx = session.cancel_receiver().await;
let mut sessions = self.sessions.write().await;
sessions.insert(session.id.clone(), session.clone());
tracing::debug!(session_id = %session.id, "Created WebSocket session");
(session, cancel_rx)
}
#[cfg_attr(not(test), allow(dead_code))]
async fn reconnect(&self, id: &str) -> Option<(Arc<Session>, watch::Receiver<bool>)> {
let sessions = self.sessions.read().await;
let session = sessions.get(id)?;
let cancel_rx = session.replace_connection().await;
tracing::info!(session_id = %id, "Replaced active WebSocket connection (zombie prevention)");
Some((session.clone(), cancel_rx))
}
async fn remove(&self, id: &str) -> bool {
let mut sessions = self.sessions.write().await;
let removed = sessions.remove(id).is_some();
if removed {
tracing::debug!(session_id = %id, "Removed WebSocket session");
}
removed
}
}
/// A server-to-client request that is still waiting for the client's reply.
struct PendingRequest {
    /// Completed with the reply's `result` value or an error once the client answers.
    response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
/// Shared state behind the WebSocket route handler.
struct AppState {
    /// Router each new connection's session is derived from (via `with_fresh_session`).
    router_template: McpRouter,
    /// Builds the service stack for a session's router.
    service_factory: ServiceFactory,
    /// Sessions of currently open connections.
    sessions: SessionStore,
    /// Selects the bidirectional socket loop (server-to-client requests) over the simple one.
    sampling_enabled: bool,
}
/// MCP transport that serves an [`McpRouter`] over WebSocket using axum.
///
/// Configure it with [`with_sampling`](Self::with_sampling) and [`layer`](Self::layer)
/// (plus `oauth` when the `oauth` feature is enabled). Then either mount it with
/// [`into_router`](Self::into_router) / [`into_router_at`](Self::into_router_at),
/// or run it directly with [`serve`](Self::serve).
pub struct WebSocketTransport {
    router: McpRouter,
    sampling_enabled: bool,
    service_factory: ServiceFactory,
    #[cfg(feature = "oauth")]
    oauth_config: Option<crate::oauth::ProtectedResourceMetadata>,
}
impl WebSocketTransport {
pub fn new(router: McpRouter) -> Self {
Self {
router,
sampling_enabled: false,
service_factory: identity_factory(),
#[cfg(feature = "oauth")]
oauth_config: None,
}
}
pub fn with_sampling(mut self) -> Self {
self.sampling_enabled = true;
self
}
#[cfg(feature = "oauth")]
pub fn oauth(mut self, metadata: crate::oauth::ProtectedResourceMetadata) -> Self {
self.oauth_config = Some(metadata);
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<McpRouter> + Send + Sync + 'static,
L::Service:
tower::Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
<L::Service as tower::Service<RouterRequest>>::Error: std::fmt::Display + Send,
<L::Service as tower::Service<RouterRequest>>::Future: Send,
{
self.service_factory = Arc::new(move |router: McpRouter| {
let annotations = router.tool_annotations_map();
let wrapped = layer.layer(router);
tower::util::BoxCloneService::new(InjectAnnotations::new(
CatchError::new(wrapped),
annotations,
))
});
self
}
pub fn into_router(self) -> Router {
#[cfg(feature = "oauth")]
let oauth_config = self.oauth_config;
let state = Arc::new(AppState {
router_template: self.router,
service_factory: self.service_factory,
sessions: SessionStore::new(),
sampling_enabled: self.sampling_enabled,
});
let router = Router::new()
.route("/", get(handle_websocket))
.with_state(state);
#[cfg(feature = "oauth")]
let router = add_oauth_route(router, "", oauth_config.as_ref());
router
}
pub fn into_router_at(self, path: &str) -> Router {
#[cfg(feature = "oauth")]
let oauth_config = self.oauth_config;
let state = Arc::new(AppState {
router_template: self.router,
service_factory: self.service_factory,
sessions: SessionStore::new(),
sampling_enabled: self.sampling_enabled,
});
let ws_router = Router::new()
.route("/", get(handle_websocket))
.with_state(state);
let router = Router::new().nest(path, ws_router);
#[cfg(feature = "oauth")]
let router = add_oauth_route(router, path, oauth_config.as_ref());
router
}
pub async fn serve(self, addr: &str) -> Result<()> {
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(|e| Error::Transport(format!("Failed to bind to {}: {}", addr, e)))?;
tracing::info!("MCP WebSocket transport listening on {}", addr);
let router = self.into_router();
axum::serve(listener, router)
.await
.map_err(|e| Error::Transport(format!("Server error: {}", e)))?;
Ok(())
}
}
/// Adds a `GET` route serving `metadata` as JSON at the protected-resource
/// well-known path.
///
/// The path is prefixed with `base_path` (minus trailing slashes) when
/// `base_path` is non-empty. If `metadata` is `None`, `router` is returned unchanged.
#[cfg(feature = "oauth")]
fn add_oauth_route(
    router: Router,
    base_path: &str,
    metadata: Option<&crate::oauth::ProtectedResourceMetadata>,
) -> Router {
    let Some(metadata) = metadata.cloned() else {
        return router;
    };
    let suffix = crate::oauth::ProtectedResourceMetadata::well_known_path();
    let well_known_path = match base_path {
        "" => suffix.to_string(),
        prefix => format!("{}{}", prefix.trim_end_matches('/'), suffix),
    };
    // axum handlers must be cloneable; each request gets its own copy of the metadata.
    let handler = move || {
        let body = metadata.clone();
        async move { axum::Json(body) }
    };
    router.route(&well_known_path, get(handler))
}
/// MCP-specific values extracted from the `Sec-WebSocket-Protocol` request header.
#[derive(Debug, Default)]
struct McpSubprotocols {
    /// Token from an `mcp.auth.<token>` entry; the last non-empty entry wins.
    auth_token: Option<String>,
    /// Version from an `mcp.version.<version>` entry that is in
    /// `SUPPORTED_PROTOCOL_VERSIONS`; the last such entry wins.
    protocol_version: Option<String>,
    /// Full subprotocol names accepted above, later passed to `WebSocketUpgrade::protocols`.
    selected: Vec<String>,
}
/// Extracts MCP auth/version entries from the `Sec-WebSocket-Protocol` header.
///
/// Entries are comma-separated and trimmed.
/// - `mcp.auth.<token>` with a non-empty token is recorded.
/// - `mcp.version.<v>` is recorded only if `<v>` is a supported protocol
///   version; otherwise a warning is logged.
/// - Any other entry is ignored.
///
/// A missing or non-visible-ASCII header yields an empty result.
fn parse_mcp_subprotocols(headers: &axum::http::HeaderMap) -> McpSubprotocols {
    use crate::protocol::SUPPORTED_PROTOCOL_VERSIONS;

    let mut parsed = McpSubprotocols::default();
    let header_str = match headers
        .get("sec-websocket-protocol")
        .map(|value| value.to_str())
    {
        Some(Ok(text)) => text,
        _ => return parsed,
    };

    for entry in header_str.split(',').map(str::trim) {
        if let Some(token) = entry.strip_prefix("mcp.auth.") {
            if token.is_empty() {
                continue;
            }
            parsed.auth_token = Some(token.to_owned());
            parsed.selected.push(entry.to_owned());
        } else if let Some(version) = entry.strip_prefix("mcp.version.") {
            if !SUPPORTED_PROTOCOL_VERSIONS.contains(&version) {
                tracing::warn!(version = %version, "Unsupported MCP protocol version in subprotocol");
                continue;
            }
            parsed.protocol_version = Some(version.to_owned());
            parsed.selected.push(entry.to_owned());
        }
    }
    parsed
}
/// Axum handler that upgrades an HTTP request to an MCP WebSocket connection.
///
/// It parses MCP subprotocols from `Sec-WebSocket-Protocol` and collects
/// per-connection extensions: OAuth token claims when the `oauth` feature is
/// enabled, and the `mcp.auth.` token as [`WebSocketAuthToken`]. It then hands
/// the upgraded socket to `handle_socket`. If the request is not a valid
/// WebSocket upgrade, the extractor's rejection is returned as the response.
async fn handle_websocket(
    State(state): State<Arc<AppState>>,
    request: axum::extract::Request,
) -> Response {
    use axum::extract::FromRequestParts;
    use axum::response::IntoResponse;
    // Only the request parts are needed; the body of an upgrade request is unused.
    let (mut parts, _body) = request.into_parts();
    let subprotocols = parse_mcp_subprotocols(&parts.headers);
    if let Some(ref version) = subprotocols.protocol_version {
        tracing::debug!(version = %version, "Client requested MCP protocol version via subprotocol");
    }
    // NOTE(review): the auth-token insert below always mutates this value, so the
    // `unused_mut` allow looks redundant even without the `oauth` feature.
    #[allow(unused_mut)]
    let mut mcp_extensions = crate::router::Extensions::new();
    #[cfg(feature = "oauth")]
    {
        // Claims placed in the HTTP request extensions (presumably by auth
        // middleware upstream) are forwarded into the MCP extensions.
        if let Some(claims) = parts.extensions.get::<crate::oauth::token::TokenClaims>() {
            mcp_extensions.insert(claims.clone());
        }
    }
    // The subprotocol token is forwarded as-is; it is not validated here.
    if let Some(ref token) = subprotocols.auth_token {
        mcp_extensions.insert(WebSocketAuthToken(token.clone()));
    }
    // The upgrade is extracted only after the extensions above have been read.
    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
        Ok(ws) => ws,
        Err(e) => return e.into_response(),
    };
    // Offer only the recognised MCP subprotocols; otherwise keep axum's default.
    let ws = if !subprotocols.selected.is_empty() {
        ws.protocols(subprotocols.selected)
    } else {
        ws
    };
    ws.on_upgrade(move |socket| handle_socket(socket, state, mcp_extensions))
}
/// Token supplied by the client through an `mcp.auth.<token>` WebSocket
/// subprotocol.
///
/// It is inserted into the connection's MCP extensions. This module does not
/// validate it; handlers or middleware are expected to check it.
#[derive(Debug, Clone)]
pub struct WebSocketAuthToken(pub String);
/// Runs one upgraded WebSocket connection from start to finish.
///
/// It registers a fresh session, drives the simple or bidirectional socket loop
/// depending on `sampling_enabled`, and unregisters the session once the loop returns.
async fn handle_socket(
    socket: WebSocket,
    state: Arc<AppState>,
    mcp_extensions: crate::router::Extensions,
) {
    let session_router = state.router_template.with_fresh_session();
    let (session, cancel_rx) = state
        .sessions
        .create(session_router, state.service_factory.clone())
        .await;
    let session_id = session.id.clone();
    tracing::info!(session_id = %session_id, "WebSocket connection established");

    if !state.sampling_enabled {
        handle_socket_simple(socket, session, &session_id, mcp_extensions, cancel_rx).await;
    } else {
        handle_socket_bidirectional(socket, session, &session_id, mcp_extensions, cancel_rx).await;
    }

    // Each connection owns its session, so it is dropped as soon as the loop ends.
    let _ = state.sessions.remove(&session_id).await;
    tracing::info!(session_id = %session_id, "WebSocket connection closed");
}
/// Socket loop used when sampling is disabled.
///
/// Each text frame is processed to completion, and its response (if any) is
/// sent back before the next frame is read. The loop ends when:
/// - the stream ends or errors;
/// - a close frame arrives;
/// - a binary frame arrives (closed with code 1003);
/// - sending a response or pong fails;
/// - `cancel_rx` becomes `true` because a newer connection replaced this one
///   (closed with code 1000).
async fn handle_socket_simple(
    socket: WebSocket,
    session: Arc<Session>,
    session_id: &str,
    mcp_extensions: crate::router::Extensions,
    mut cancel_rx: watch::Receiver<bool>,
) {
    let mut service = JsonRpcService::new(session.make_service()).with_extensions(mcp_extensions);
    let (mut sender, mut receiver) = socket.split();
    loop {
        // Wait for the next frame or a cancellation signal, whichever comes first.
        let msg = tokio::select! {
            msg = receiver.next() => {
                match msg {
                    Some(msg) => msg,
                    // The stream has ended.
                    None => break,
                }
            }
            _ = cancel_rx.changed() => {
                if *cancel_rx.borrow() {
                    tracing::info!(session_id = %session_id, "Connection superseded by new connection, closing");
                    let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
                        code: 1000,
                        reason: "Connection replaced by newer WebSocket connection".into(),
                    }))).await;
                    break;
                }
                // NOTE(review): if `changed()` ever returned `Err` (sender dropped) while the
                // value is still `false`, this branch would be ready on every iteration and
                // the loop would spin. `replace_connection` always sends `true` before
                // dropping the old sender, so this should not happen today. Confirm.
                continue;
            }
        };
        let msg = match msg {
            Ok(m) => m,
            Err(e) => {
                tracing::error!(error = %e, "WebSocket receive error");
                break;
            }
        };
        match msg {
            Message::Text(text) => {
                match process_message(&mut service, &session.router, &text).await {
                    Ok(Some(response)) => {
                        let response_json = match serde_json::to_string(&response) {
                            Ok(json) => json,
                            Err(e) => {
                                // Drop this response but keep the connection open.
                                tracing::error!(error = %e, "Failed to serialize response");
                                continue;
                            }
                        };
                        if let Err(e) = sender.send(Message::Text(response_json.into())).await {
                            tracing::error!(error = %e, "Failed to send response");
                            break;
                        }
                    }
                    Ok(None) => {
                        // Notification: nothing to send back.
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "Error processing message");
                        // The request id is not recovered here, so the error goes out with a null id.
                        let error_response = JsonRpcResponse::error(
                            None,
                            JsonRpcError::internal_error(e.to_string()),
                        );
                        if let Ok(json) = serde_json::to_string(&error_response) {
                            let _ = sender.send(Message::Text(json.into())).await;
                        }
                    }
                }
            }
            Message::Binary(_) => {
                tracing::warn!(session_id = %session_id, "Received binary frame, closing with 1003");
                let _ = sender
                    .send(Message::Close(Some(axum::extract::ws::CloseFrame {
                        code: 1003,
                        reason: "Binary frames are not supported by MCP".into(),
                    })))
                    .await;
                break;
            }
            Message::Ping(data) => {
                if let Err(e) = sender.send(Message::Pong(data)).await {
                    tracing::error!(error = %e, "Failed to send pong");
                    break;
                }
            }
            Message::Pong(_) => {
                // Pongs need no handling.
            }
            Message::Close(_) => {
                tracing::info!(session_id = %session_id, "WebSocket close received");
                break;
            }
        }
    }
}
/// Socket loop used when sampling is enabled.
///
/// Besides answering client frames, this loop does two more things:
/// - it forwards server-to-client requests issued through the session's client
///   requester (`request_rx`);
/// - it routes the client's replies back to the waiting caller through
///   `pending_requests`.
///
/// The loop ends on a client close or receive error, a binary frame (code 1003),
/// a failed pong, or when `cancel_rx` becomes `true` because a newer connection
/// replaced this one (code 1000).
///
/// NOTE(review): text frames are handled inline via `handle_incoming_message(..).await`.
/// While a request is being served, neither `request_rx` nor the client's replies
/// are polled. A handler that awaits a client request (e.g. sampling) issued via
/// the requester therefore appears able to deadlock. Confirm against
/// `ChannelClientRequester` before relying on in-request sampling.
/// NOTE(review): `_mcp_extensions` is used despite its leading underscore.
async fn handle_socket_bidirectional(
    socket: WebSocket,
    session: Arc<Session>,
    session_id: &str,
    _mcp_extensions: crate::router::Extensions,
    mut cancel_rx: watch::Receiver<bool>,
) {
    // Requests made through `client_requester` arrive on `request_rx` and are written to the socket.
    let (request_tx, mut request_rx): (OutgoingRequestSender, OutgoingRequestReceiver) =
        outgoing_request_channel(32);
    let client_requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
    let router = session
        .router
        .clone()
        .with_client_requester(client_requester);
    // Built from the requester-equipped router rather than via `session.make_service()`.
    let mut service = JsonRpcService::new((session.service_factory)(router.clone()))
        .with_extensions(_mcp_extensions);
    // Outgoing requests awaiting a client reply, keyed by request id.
    let pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let (sender, mut receiver) = socket.split();
    // The sink is shared with the helper functions, so it sits behind a mutex.
    let sender = Arc::new(Mutex::new(sender));
    let session_id_owned = session_id.to_string();
    loop {
        tokio::select! {
            msg = receiver.next() => {
                let msg = match msg {
                    Some(Ok(m)) => m,
                    Some(Err(e)) => {
                        tracing::error!(error = %e, "WebSocket receive error");
                        break;
                    }
                    None => break,
                };
                match msg {
                    Message::Text(text) => {
                        // Errors are logged only; the connection stays open.
                        let result = handle_incoming_message(
                            &text,
                            &mut service,
                            &router,
                            pending_requests.clone(),
                            sender.clone(),
                        ).await;
                        if let Err(e) = result {
                            tracing::error!(error = %e, "Error handling incoming message");
                        }
                    }
                    Message::Binary(_) => {
                        tracing::warn!(session_id = %session_id_owned, "Received binary frame, closing with 1003");
                        let mut s = sender.lock().await;
                        let _ = s.send(Message::Close(Some(axum::extract::ws::CloseFrame {
                            code: 1003,
                            reason: "Binary frames are not supported by MCP".into(),
                        }))).await;
                        break;
                    }
                    Message::Ping(data) => {
                        let mut sender = sender.lock().await;
                        if let Err(e) = sender.send(Message::Pong(data)).await {
                            tracing::error!(error = %e, "Failed to send pong");
                            break;
                        }
                    }
                    Message::Pong(_) => {}
                    Message::Close(_) => {
                        tracing::info!(session_id = %session_id_owned, "WebSocket close received");
                        break;
                    }
                }
            }
            // If the channel yields `None`, the pattern fails and this branch is skipped for this round.
            Some(outgoing) = request_rx.recv() => {
                let result = send_outgoing_request(
                    outgoing,
                    pending_requests.clone(),
                    sender.clone(),
                ).await;
                if let Err(e) = result {
                    tracing::error!(error = %e, "Error sending outgoing request");
                }
            }
            _ = cancel_rx.changed() => {
                if *cancel_rx.borrow() {
                    tracing::info!(session_id = %session_id_owned, "Connection superseded by new connection, closing");
                    let mut s = sender.lock().await;
                    let _ = s.send(Message::Close(Some(axum::extract::ws::CloseFrame {
                        code: 1000,
                        reason: "Connection replaced by newer WebSocket connection".into(),
                    }))).await;
                    break;
                }
            }
        }
    }
}
async fn handle_incoming_message<S>(
text: &str,
service: &mut JsonRpcService<McpBoxService>,
router: &McpRouter,
pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
sender: Arc<Mutex<S>>,
) -> Result<()>
where
S: futures::Sink<Message> + Unpin,
S::Error: std::fmt::Display,
{
let parsed: serde_json::Value = serde_json::from_str(text)?;
if parsed.get("method").is_none()
&& (parsed.get("result").is_some() || parsed.get("error").is_some())
{
return handle_response(&parsed, pending_requests).await;
}
if parsed.get("id").is_none() {
if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(text) {
let mcp_notification = McpNotification::from_jsonrpc(¬ification)?;
router.handle_notification(mcp_notification);
}
return Ok(());
}
let message: JsonRpcMessage = serde_json::from_str(text)?;
match service.call_message(message).await {
Ok(response) => {
let response_json = serde_json::to_string(&response)
.map_err(|e| Error::Transport(format!("Failed to serialize response: {}", e)))?;
let mut sender = sender.lock().await;
sender
.send(Message::Text(response_json.into()))
.await
.map_err(|e| Error::Transport(format!("Failed to send response: {}", e)))?;
}
Err(e) => {
tracing::error!(error = %e, "Error processing message");
let error_response =
JsonRpcResponse::error(None, JsonRpcError::internal_error(e.to_string()));
if let Ok(json) = serde_json::to_string(&error_response) {
let mut sender = sender.lock().await;
let _ = sender.send(Message::Text(json.into())).await;
}
}
}
Ok(())
}
async fn handle_response(
parsed: &serde_json::Value,
pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
) -> Result<()> {
let id = match parsed.get("id") {
Some(id) => {
if let Some(n) = id.as_i64() {
RequestId::Number(n)
} else if let Some(s) = id.as_str() {
RequestId::String(s.to_string())
} else {
tracing::warn!("Response has invalid id type");
return Ok(());
}
}
None => {
tracing::warn!("Response missing id field");
return Ok(());
}
};
let pending = {
let mut pending_requests = pending_requests.lock().await;
pending_requests.remove(&id)
};
match pending {
Some(pending) => {
let result = if let Some(error) = parsed.get("error") {
let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
Err(Error::Internal(format!(
"Client error ({}): {}",
code, message
)))
} else if let Some(result) = parsed.get("result") {
Ok(result.clone())
} else {
Err(Error::Internal(
"Response has neither result nor error".to_string(),
))
};
let _ = pending.response_tx.send(result);
}
None => {
tracing::warn!(id = ?id, "Received response for unknown request");
}
}
Ok(())
}
async fn send_outgoing_request<S>(
outgoing: OutgoingRequest,
pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
sender: Arc<Mutex<S>>,
) -> Result<()>
where
S: futures::Sink<Message> + Unpin,
S::Error: std::fmt::Display,
{
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: outgoing.id.clone(),
method: outgoing.method,
params: Some(outgoing.params),
};
let request_json = serde_json::to_string(&request)
.map_err(|e| Error::Transport(format!("Failed to serialize request: {}", e)))?;
tracing::debug!(output = %request_json, "Sending request to client");
{
let mut pending = pending_requests.lock().await;
pending.insert(
outgoing.id,
PendingRequest {
response_tx: outgoing.response_tx,
},
);
}
let mut sender = sender.lock().await;
sender
.send(Message::Text(request_json.into()))
.await
.map_err(|e| Error::Transport(format!("Failed to send request: {}", e)))?;
Ok(())
}
/// Handles one text frame in non-sampling mode.
///
/// An id-less frame that parses as a notification is forwarded to `router`,
/// and `Ok(None)` is returned. Anything else is parsed as a [`JsonRpcMessage`]
/// and run through `service`; its response is returned as `Ok(Some(..))`.
///
/// # Errors
/// Returns an error if any of the following fails:
/// - parsing the frame as JSON or as a message;
/// - converting a notification;
/// - the service call itself.
async fn process_message(
    service: &mut JsonRpcService<McpBoxService>,
    router: &McpRouter,
    text: &str,
) -> Result<Option<crate::protocol::JsonRpcResponseMessage>> {
    let parsed: serde_json::Value = serde_json::from_str(text)?;
    // Id-less frames that are not valid notifications fall through to the generic path.
    if parsed.get("id").is_none()
        && let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(text)
    {
        let mcp_notification = McpNotification::from_jsonrpc(&notification)?;
        router.handle_notification(mcp_notification);
        return Ok(None);
    }
    let message: JsonRpcMessage = serde_json::from_str(text)?;
    let response = service.call_message(message).await?;
    Ok(Some(response))
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Minimal router shared by the tests below.
    fn create_test_router() -> McpRouter {
        McpRouter::new().server_info("test-server", "1.0.0")
    }
    // Smoke test: building the root router must not panic.
    #[tokio::test]
    async fn test_websocket_transport_builds() {
        let transport = WebSocketTransport::new(create_test_router());
        let _router = transport.into_router();
    }
    // Smoke test: nesting the WebSocket route under a sub-path must not panic.
    #[tokio::test]
    async fn test_websocket_transport_at_path() {
        let transport = WebSocketTransport::new(create_test_router());
        let _router = transport.into_router_at("/mcp");
    }
    // `layer` accepts the identity layer.
    #[tokio::test]
    async fn test_layer_with_identity() {
        let transport = WebSocketTransport::new(create_test_router())
            .layer(tower::layer::util::Identity::new());
        let _router = transport.into_router();
    }
    // `layer` accepts a layer whose service has a non-`Infallible` error type.
    #[tokio::test]
    async fn test_layer_with_timeout() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;
        let transport = WebSocketTransport::new(create_test_router())
            .layer(TimeoutLayer::new(Duration::from_secs(30)));
        let _router = transport.into_router();
    }
    // `layer` accepts a stack of layers composed with `ServiceBuilder`.
    #[tokio::test]
    async fn test_layer_with_composed_layers() {
        use std::time::Duration;
        use tower::ServiceBuilder;
        use tower::timeout::TimeoutLayer;
        let transport = WebSocketTransport::new(create_test_router()).layer(
            ServiceBuilder::new()
                .layer(TimeoutLayer::new(Duration::from_secs(30)))
                .concurrency_limit(100)
                .into_inner(),
        );
        let _router = transport.into_router();
    }
    // No `Sec-WebSocket-Protocol` header: nothing is extracted or selected.
    #[test]
    fn test_parse_mcp_subprotocols_empty() {
        let headers = axum::http::HeaderMap::new();
        let result = parse_mcp_subprotocols(&headers);
        assert!(result.auth_token.is_none());
        assert!(result.protocol_version.is_none());
        assert!(result.selected.is_empty());
    }
    // Auth token and supported version are both extracted and both selected.
    #[test]
    fn test_parse_mcp_subprotocols_auth_and_version() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "mcp.auth.my-secret-token, mcp.version.2025-11-25"
                .parse()
                .unwrap(),
        );
        let result = parse_mcp_subprotocols(&headers);
        assert_eq!(result.auth_token.as_deref(), Some("my-secret-token"));
        assert_eq!(result.protocol_version.as_deref(), Some("2025-11-25"));
        assert_eq!(result.selected.len(), 2);
    }
    // Unsupported versions are neither recorded nor selected.
    #[test]
    fn test_parse_mcp_subprotocols_unsupported_version() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "mcp.version.1999-01-01".parse().unwrap(),
        );
        let result = parse_mcp_subprotocols(&headers);
        assert!(result.protocol_version.is_none());
        assert!(result.selected.is_empty());
    }
    // Older versions still listed as supported are accepted.
    #[test]
    fn test_parse_mcp_subprotocols_older_supported_version() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "mcp.version.2025-03-26".parse().unwrap(),
        );
        let result = parse_mcp_subprotocols(&headers);
        assert_eq!(result.protocol_version.as_deref(), Some("2025-03-26"));
        assert_eq!(result.selected.len(), 1);
    }
    // An auth token alone leaves the version unset.
    #[test]
    fn test_parse_mcp_subprotocols_auth_only() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "mcp.auth.bearer-xyz123".parse().unwrap(),
        );
        let result = parse_mcp_subprotocols(&headers);
        assert_eq!(result.auth_token.as_deref(), Some("bearer-xyz123"));
        assert!(result.protocol_version.is_none());
    }
    // Non-MCP subprotocols are ignored and are not selected.
    #[test]
    fn test_parse_mcp_subprotocols_ignores_unknown() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "graphql-ws, mcp.auth.token, mcp.version.2025-11-25, other-protocol"
                .parse()
                .unwrap(),
        );
        let result = parse_mcp_subprotocols(&headers);
        assert_eq!(result.auth_token.as_deref(), Some("token"));
        assert_eq!(result.protocol_version.as_deref(), Some("2025-11-25"));
        assert_eq!(result.selected.len(), 2);
    }
    // Replacing the connection signals `true` to the previous receiver.
    #[tokio::test]
    async fn test_session_cancel_receiver() {
        let router = create_test_router();
        let session = Session::new(router, identity_factory());
        let mut rx = session.cancel_receiver().await;
        assert!(!*rx.borrow());
        let _new_rx = session.replace_connection().await;
        rx.changed().await.unwrap();
        assert!(*rx.borrow());
    }
    // The receiver returned for the replacement connection starts as `false`.
    #[tokio::test]
    async fn test_session_replace_connection_new_rx_starts_clean() {
        let router = create_test_router();
        let session = Session::new(router, identity_factory());
        let _rx1 = session.cancel_receiver().await;
        let rx2 = session.replace_connection().await;
        assert!(!*rx2.borrow(), "New receiver should start as not-cancelled");
    }
    // Reconnecting through the store cancels the old connection and hands out a clean receiver.
    #[tokio::test]
    async fn test_session_store_reconnect() {
        let router = create_test_router();
        let store = SessionStore::new();
        let (session, mut rx1) = store
            .create(router.with_fresh_session(), identity_factory())
            .await;
        let session_id = session.id.clone();
        let result = store.reconnect(&session_id).await;
        assert!(result.is_some());
        let (_session2, rx2) = result.unwrap();
        rx1.changed().await.unwrap();
        assert!(*rx1.borrow());
        assert!(!*rx2.borrow());
    }
}