use agent_client_protocol_schema::v1::{
JsonRpcMessage as VersionedJsonRpcMessage, Notification as RpcNotification,
Request as RpcRequest, RequestId, Response as RpcResponse, SessionId,
};
use serde::{Deserialize, Serialize};
use std::any::TypeId;
#[cfg(feature = "unstable_cancel_request")]
use std::collections::HashMap;
use std::fmt::Debug;
use std::panic::Location;
use std::pin::pin;
use std::sync::Arc;
#[cfg(feature = "unstable_cancel_request")]
use std::sync::{
Mutex,
atomic::{AtomicBool, Ordering},
};
use uuid::Uuid;
#[cfg(feature = "unstable_cancel_request")]
use futures::FutureExt;
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Either};
use futures::{AsyncRead, AsyncWrite, StreamExt};
mod dynamic_handler;
pub(crate) mod handlers;
mod incoming_actor;
mod outgoing_actor;
mod protocol_compat;
pub(crate) mod run;
mod task_actor;
mod transport_actor;
use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage;
pub use crate::jsonrpc::handlers::NullHandler;
use crate::jsonrpc::handlers::{ChainedHandler, NamedHandler};
use crate::jsonrpc::handlers::{MessageHandler, NotificationHandler, RequestHandler};
use crate::jsonrpc::outgoing_actor::{OutgoingMessageTx, send_raw_message};
use crate::jsonrpc::protocol_compat::{ProtocolCompat, ProtocolMode};
use crate::jsonrpc::run::SpawnedRun;
use crate::jsonrpc::run::{ChainRun, NullRun, RunWithConnectionTo};
use crate::jsonrpc::task_actor::{Task, TaskTx};
use crate::mcp_server::McpServer;
use crate::role::HasPeer;
use crate::role::Role;
use crate::{Agent, Client, ConnectTo, RoleId};
/// A JSON-RPC message as exchanged with the transport, before any typed
/// decoding: a request, a notification, or a response.
///
/// Requests and notifications carry their params as [`RawJsonRpcParams`]
/// (absent params are `None`); responses carry an untyped
/// `serde_json::Value` payload.
#[derive(Debug, Clone)]
pub enum RawJsonRpcMessage {
    /// A message with both `method` and `id`; it expects a response.
    Request(RpcRequest<RawJsonRpcParams>),
    /// A message with `method` but no `id`; no response is expected.
    Notification(RpcNotification<RawJsonRpcParams>),
    /// A message carrying `result` or `error` for an earlier request.
    Response(RpcResponse<serde_json::Value>),
}
/// The `params` member of a JSON-RPC request or notification.
///
/// JSON-RPC only allows params to be a positional array or a named object;
/// any other JSON shape is rejected when converting into this type.
#[derive(Debug, Clone, PartialEq)]
pub enum RawJsonRpcParams {
    /// Positional parameters.
    Array(Vec<serde_json::Value>),
    /// Named parameters.
    Object(serde_json::Map<String, serde_json::Value>),
}
impl RawJsonRpcParams {
pub fn from_value(value: serde_json::Value) -> Result<Option<Self>, crate::Error> {
match value {
serde_json::Value::Null => Ok(None),
serde_json::Value::Array(array) => Ok(Some(Self::Array(array))),
serde_json::Value::Object(object) => Ok(Some(Self::Object(object))),
_ => {
Err(crate::Error::invalid_params()
.data("JSON-RPC params must be an object or array"))
}
}
}
#[must_use]
pub fn into_value(self) -> serde_json::Value {
match self {
Self::Array(array) => serde_json::Value::Array(array),
Self::Object(object) => serde_json::Value::Object(object),
}
}
}
impl Serialize for RawJsonRpcParams {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Self::Array(array) => array.serialize(serializer),
Self::Object(object) => object.serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for RawJsonRpcParams {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::Array(array) => Ok(Self::Array(array)),
serde_json::Value::Object(object) => Ok(Self::Object(object)),
_ => Err(serde::de::Error::custom(
"JSON-RPC params must be an object or array",
)),
}
}
}
impl RawJsonRpcMessage {
pub fn request(
method: String,
params: serde_json::Value,
id: RequestId,
) -> Result<Self, crate::Error> {
Ok(Self::Request(RpcRequest {
id,
method: Arc::from(method),
params: RawJsonRpcParams::from_value(params)?,
}))
}
pub fn notification(method: String, params: serde_json::Value) -> Result<Self, crate::Error> {
Ok(Self::Notification(RpcNotification {
method: Arc::from(method),
params: RawJsonRpcParams::from_value(params)?,
}))
}
#[must_use]
pub fn response(id: RequestId, response: Result<serde_json::Value, crate::Error>) -> Self {
Self::Response(RpcResponse::new(id, response))
}
#[must_use]
pub fn response_id(&self) -> Option<&RequestId> {
match self {
Self::Response(RpcResponse::Result { id, .. } | RpcResponse::Error { id, .. }) => {
Some(id)
}
Self::Request(_) | Self::Notification(_) => None,
}
}
}
impl Serialize for RawJsonRpcMessage {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Self::Request(request) => {
VersionedJsonRpcMessage::wrap(request.clone()).serialize(serializer)
}
Self::Notification(notification) => {
VersionedJsonRpcMessage::wrap(notification.clone()).serialize(serializer)
}
Self::Response(response) => {
VersionedJsonRpcMessage::wrap(response.clone()).serialize(serializer)
}
}
}
}
impl<'de> Deserialize<'de> for RawJsonRpcMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
if value.get("method").is_some() {
if value.get("id").is_some() {
let request = serde_json::from_value::<
VersionedJsonRpcMessage<RpcRequest<RawJsonRpcParams>>,
>(value)
.map_err(serde::de::Error::custom)?
.into_inner();
Ok(Self::Request(request))
} else {
let notification = serde_json::from_value::<
VersionedJsonRpcMessage<RpcNotification<RawJsonRpcParams>>,
>(value)
.map_err(serde::de::Error::custom)?
.into_inner();
Ok(Self::Notification(notification))
}
} else if value.get("result").is_some() || value.get("error").is_some() {
let response = serde_json::from_value::<
VersionedJsonRpcMessage<RpcResponse<serde_json::Value>>,
>(value)
.map_err(serde::de::Error::custom)?
.into_inner();
Ok(Self::Response(response))
} else {
Err(serde::de::Error::custom("invalid JSON-RPC message"))
}
}
}
/// Converts optional transport-level params into a JSON value, using
/// `null` when the message carried no params.
fn params_from_transport(params: Option<RawJsonRpcParams>) -> serde_json::Value {
    let Some(raw) = params else {
        return serde_json::Value::Null;
    };
    raw.into_value()
}
/// A handler that may consume a [`Dispatch`] arriving from `Counterpart`.
///
/// Returning [`Handled::Yes`] signals the dispatch was consumed. Returning
/// [`Handled::No`] hands the dispatch back, presumably so that a later
/// handler in a chain can try it.
// NOTE(review): the trait declares return-position `impl Future + Send`
// rather than `async fn`, so this allow may be vestigial; confirm before removing.
#[allow(async_fn_in_trait)]
pub trait HandleDispatchFrom<Counterpart: Role>: Send {
    /// Attempts to handle `message`, with `connection` available for talking
    /// back to the counterpart.
    ///
    /// The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when handling fails.
    fn handle_dispatch_from(
        &mut self,
        message: Dispatch,
        connection: ConnectionTo<Counterpart>,
    ) -> impl Future<Output = Result<Handled<Dispatch>, crate::Error>> + Send;
    /// Returns a `Debug` description of this handler, used for diagnostics.
    fn describe_chain(&self) -> impl std::fmt::Debug;
}
/// A mutable borrow of a handler is itself a handler: every call is
/// forwarded to the borrowed value.
impl<Counterpart: Role, H> HandleDispatchFrom<Counterpart> for &mut H
where
    H: HandleDispatchFrom<Counterpart>,
{
    fn handle_dispatch_from(
        &mut self,
        message: Dispatch,
        cx: ConnectionTo<Counterpart>,
    ) -> impl Future<Output = Result<Handled<Dispatch>, crate::Error>> + Send {
        let inner: &mut H = self;
        <H as HandleDispatchFrom<Counterpart>>::handle_dispatch_from(inner, message, cx)
    }

    fn describe_chain(&self) -> impl std::fmt::Debug {
        let inner: &H = self;
        <H as HandleDispatchFrom<Counterpart>>::describe_chain(inner)
    }
}
/// Configures and launches a JSON-RPC connection for the local role `Host`.
///
/// `Handler` processes dispatches arriving from `Host::Counterpart`.
/// `Runner` is background work driven alongside the connection. Both start
/// as `NullHandler` / `NullRun` (presumably no-ops) and are extended by the
/// chaining methods.
#[must_use]
#[derive(Debug)]
pub struct Builder<Host: Role, Handler = NullHandler, Runner = NullRun>
where
    Handler: HandleDispatchFrom<Host::Counterpart>,
    Runner: RunWithConnectionTo<Host::Counterpart>,
{
    /// The local role; the remote side is `host.counterpart()`.
    host: Host,
    /// Optional connection name, used for instrumentation when connecting.
    name: Option<String>,
    /// The dispatch handler (or chain of handlers).
    handler: Handler,
    /// Background runner(s) started together with the connection.
    responder: Runner,
    /// Protocol compatibility mode; initialised from `default_protocol_mode`.
    protocol_mode: ProtocolMode,
}
/// Picks the initial protocol mode for a builder whose local role is `Host`.
///
/// [`Agent`] gets the v1 agent mode, [`Client`] gets the v1 client mode, and
/// every other role starts with protocol compatibility disabled.
fn default_protocol_mode<Host: Role>() -> ProtocolMode {
    let host_type = TypeId::of::<Host>();
    if host_type == TypeId::of::<Agent>() {
        return ProtocolMode::v1_agent();
    }
    if host_type == TypeId::of::<Client>() {
        return ProtocolMode::v1_client();
    }
    ProtocolMode::disabled()
}
impl<Host: Role> Builder<Host, NullHandler, NullRun> {
    /// Creates an empty builder for `role`.
    ///
    /// It has no name, a [`NullHandler`], a `NullRun` runner and the role's
    /// default protocol mode. This is equivalent to
    /// `Builder::new_with(role, NullHandler)`.
    pub fn new(role: Host) -> Self {
        Self::new_with(role, NullHandler)
    }
}
impl<Host: Role, Handler> Builder<Host, Handler, NullRun>
where
    Handler: HandleDispatchFrom<Host::Counterpart>,
{
    /// Creates a builder for `role` that starts with `handler` installed.
    ///
    /// It has no name, a `NullRun` background runner and the protocol mode
    /// chosen by `default_protocol_mode` for `Host`.
    pub fn new_with(role: Host, handler: Handler) -> Self {
        let protocol_mode = default_protocol_mode::<Host>();
        Self {
            host: role,
            name: None,
            handler,
            responder: NullRun,
            protocol_mode,
        }
    }
}
impl<
    Host: Role,
    Handler: HandleDispatchFrom<Host::Counterpart>,
    Runner: RunWithConnectionTo<Host::Counterpart>,
> Builder<Host, Handler, Runner>
{
    /// Sets a name for the connection.
    ///
    /// The name is passed to `crate::util::instrument_with_connection_name`
    /// when the connection future is built.
    pub fn name(mut self, name: impl ToString) -> Self {
        self.name = Some(name.to_string());
        self
    }
    /// Overrides the protocol mode with the v1 agent mode.
    pub(crate) fn v1_agent(mut self) -> Self {
        self.protocol_mode = ProtocolMode::v1_agent();
        self
    }
    /// Overrides the protocol mode with the v1 client mode.
    pub(crate) fn v1_client(mut self) -> Self {
        self.protocol_mode = ProtocolMode::v1_client();
        self
    }
    /// Overrides the protocol mode with the v2 agent mode.
    #[cfg(feature = "unstable_protocol_v2")]
    pub(crate) fn v2_agent(mut self) -> Self {
        self.protocol_mode = ProtocolMode::v2_agent();
        self
    }
    /// Overrides the protocol mode with the v2 client mode.
    #[cfg(feature = "unstable_protocol_v2")]
    pub(crate) fn v2_client(mut self) -> Self {
        self.protocol_mode = ProtocolMode::v2_client();
        self
    }
    /// Merges another builder for the same role into this one.
    ///
    /// * The other builder's `host` is discarded; this builder's name is kept.
    /// * The other handler is wrapped in a `NamedHandler` carrying the other
    ///   builder's name, then combined via
    ///   `ChainedHandler::new(self.handler, other)`.
    /// * Runners are combined via `ChainRun::new(self.responder, other)`.
    /// * Protocol modes are combined with `ProtocolMode::merge`.
    pub fn with_connection_builder(
        self,
        other: Builder<
            Host,
            impl HandleDispatchFrom<Host::Counterpart>,
            impl RunWithConnectionTo<Host::Counterpart>,
        >,
    ) -> Builder<
        Host,
        impl HandleDispatchFrom<Host::Counterpart>,
        impl RunWithConnectionTo<Host::Counterpart>,
    > {
        let Builder {
            name: other_name,
            handler: other_handler,
            responder: other_responder,
            protocol_mode: other_protocol_mode,
            host: _,
        } = other;
        Builder {
            host: self.host,
            name: self.name,
            handler: ChainedHandler::new(
                self.handler,
                NamedHandler::new(other_name, other_handler),
            ),
            responder: ChainRun::new(self.responder, other_responder),
            protocol_mode: self.protocol_mode.merge(other_protocol_mode),
        }
    }
    /// Adds `handler` to the dispatch chain, as the second argument of
    /// `ChainedHandler::new` after the existing handler.
    pub fn with_handler(
        self,
        handler: impl HandleDispatchFrom<Host::Counterpart>,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner> {
        Builder {
            host: self.host,
            name: self.name,
            handler: ChainedHandler::new(self.handler, handler),
            responder: self.responder,
            protocol_mode: self.protocol_mode,
        }
    }
    /// Adds a background runner, chained after the existing runner(s) with
    /// `ChainRun`.
    pub fn with_responder<Run1>(
        self,
        responder: Run1,
    ) -> Builder<Host, Handler, impl RunWithConnectionTo<Host::Counterpart>>
    where
        Run1: RunWithConnectionTo<Host::Counterpart>,
    {
        Builder {
            host: self.host,
            name: self.name,
            handler: self.handler,
            responder: ChainRun::new(self.responder, responder),
            protocol_mode: self.protocol_mode,
        }
    }
    /// Adds a background task built from the connection handle.
    ///
    /// The caller's source location (captured through `#[track_caller]`) is
    /// recorded in the `SpawnedRun`.
    #[track_caller]
    pub fn with_spawned<F, Fut>(
        self,
        task: F,
    ) -> Builder<Host, Handler, impl RunWithConnectionTo<Host::Counterpart>>
    where
        F: FnOnce(ConnectionTo<Host::Counterpart>) -> Fut + Send,
        Fut: Future<Output = Result<(), crate::Error>> + Send,
    {
        let location = Location::caller();
        self.with_responder(SpawnedRun::new(location, task))
    }
    /// Registers `op` for requests of type `Req` and notifications of type
    /// `Notif` coming from the direct counterpart.
    ///
    /// `to_future_hack` adapts a call of `op` into a boxed future.
    /// NOTE(review): presumably this exists because futures returned by
    /// `AsyncFnMut` cannot be named as `Send`; confirm.
    pub fn on_receive_dispatch<Req, Notif, F, T, ToFut>(
        self,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Host::Counterpart>,
        Req: JsonRpcRequest,
        Notif: JsonRpcNotification,
        F: AsyncFnMut(
                Dispatch<Req, Notif>,
                ConnectionTo<Host::Counterpart>,
            ) -> Result<T, crate::Error>
            + Send,
        T: IntoHandled<Dispatch<Req, Notif>>,
        ToFut: Fn(
                &mut F,
                Dispatch<Req, Notif>,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        // Both the counterpart and the peer are the direct counterpart here.
        let handler = MessageHandler::new(
            self.host.counterpart(),
            self.host.counterpart(),
            op,
            to_future_hack,
        );
        self.with_handler(handler)
    }
    /// Registers `op` for requests of type `Req` coming from the direct
    /// counterpart; `op` receives a [`Responder`] to answer with.
    pub fn on_receive_request<Req: JsonRpcRequest, F, T, ToFut>(
        self,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Host::Counterpart>,
        F: AsyncFnMut(
                Req,
                Responder<Req::Response>,
                ConnectionTo<Host::Counterpart>,
            ) -> Result<T, crate::Error>
            + Send,
        T: IntoHandled<(Req, Responder<Req::Response>)>,
        ToFut: Fn(
                &mut F,
                Req,
                Responder<Req::Response>,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        let handler = RequestHandler::new(
            self.host.counterpart(),
            self.host.counterpart(),
            op,
            to_future_hack,
        );
        self.with_handler(handler)
    }
    /// Registers `op` for notifications of type `Notif` coming from the
    /// direct counterpart.
    pub fn on_receive_notification<Notif, F, T, ToFut>(
        self,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Host::Counterpart>,
        Notif: JsonRpcNotification,
        F: AsyncFnMut(Notif, ConnectionTo<Host::Counterpart>) -> Result<T, crate::Error> + Send,
        T: IntoHandled<(Notif, ConnectionTo<Host::Counterpart>)>,
        ToFut: Fn(
                &mut F,
                Notif,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        let handler = NotificationHandler::new(
            self.host.counterpart(),
            self.host.counterpart(),
            op,
            to_future_hack,
        );
        self.with_handler(handler)
    }
    /// Like [`Self::on_receive_dispatch`], but for messages attributed to `peer`
    /// (which the counterpart must be able to reach, `HasPeer<Peer>`).
    pub fn on_receive_dispatch_from<
        Req: JsonRpcRequest,
        Notif: JsonRpcNotification,
        Peer: Role,
        F,
        T,
        ToFut,
    >(
        self,
        peer: Peer,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Peer>,
        F: AsyncFnMut(
                Dispatch<Req, Notif>,
                ConnectionTo<Host::Counterpart>,
            ) -> Result<T, crate::Error>
            + Send,
        T: IntoHandled<Dispatch<Req, Notif>>,
        ToFut: Fn(
                &mut F,
                Dispatch<Req, Notif>,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        let handler = MessageHandler::new(self.host.counterpart(), peer, op, to_future_hack);
        self.with_handler(handler)
    }
    /// Like [`Self::on_receive_request`], but for requests attributed to `peer`.
    pub fn on_receive_request_from<Req: JsonRpcRequest, Peer: Role, F, T, ToFut>(
        self,
        peer: Peer,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Peer>,
        F: AsyncFnMut(
                Req,
                Responder<Req::Response>,
                ConnectionTo<Host::Counterpart>,
            ) -> Result<T, crate::Error>
            + Send,
        T: IntoHandled<(Req, Responder<Req::Response>)>,
        ToFut: Fn(
                &mut F,
                Req,
                Responder<Req::Response>,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        let handler = RequestHandler::new(self.host.counterpart(), peer, op, to_future_hack);
        self.with_handler(handler)
    }
    /// Like [`Self::on_receive_notification`], but for notifications
    /// attributed to `peer`.
    pub fn on_receive_notification_from<Notif: JsonRpcNotification, Peer: Role, F, T, ToFut>(
        self,
        peer: Peer,
        op: F,
        to_future_hack: ToFut,
    ) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
    where
        Host::Counterpart: HasPeer<Peer>,
        F: AsyncFnMut(Notif, ConnectionTo<Host::Counterpart>) -> Result<T, crate::Error> + Send,
        T: IntoHandled<(Notif, ConnectionTo<Host::Counterpart>)>,
        ToFut: Fn(
                &mut F,
                Notif,
                ConnectionTo<Host::Counterpart>,
            ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
            + Send
            + Sync,
    {
        let handler = NotificationHandler::new(self.host.counterpart(), peer, op, to_future_hack);
        self.with_handler(handler)
    }
    /// Installs an MCP server.
    ///
    /// The server is split into a dispatch handler (added with
    /// `with_handler`) and a background runner (added with `with_responder`).
    pub fn with_mcp_server(
        self,
        mcp_server: McpServer<Host::Counterpart, impl RunWithConnectionTo<Host::Counterpart>>,
    ) -> Builder<
        Host,
        impl HandleDispatchFrom<Host::Counterpart>,
        impl RunWithConnectionTo<Host::Counterpart>,
    >
    where
        Host::Counterpart: HasPeer<Agent> + HasPeer<Client>,
    {
        let (handler, responder) = mcp_server.into_handler_and_responder();
        self.with_handler(handler).with_responder(responder)
    }
    /// Runs the connection over `transport` with no main function of its own.
    ///
    /// The main function here is `future::pending()`, which never completes,
    /// so the result comes from the background actors.
    /// NOTE(review): assumes `crate::util::run_until` returns when the
    /// background future finishes or fails; confirm.
    ///
    /// # Errors
    ///
    /// Propagates any error from spawning the transport or from the actors.
    pub async fn connect_to(
        self,
        transport: impl ConnectTo<Host> + 'static,
    ) -> Result<(), crate::Error> {
        self.connect_with(transport, async move |_cx| future::pending().await)
            .await
    }
    /// Runs the connection over `transport` while executing `main_fn` with
    /// the connection handle, and returns `main_fn`'s result.
    ///
    /// # Errors
    ///
    /// Returns errors from `main_fn`, from spawning the transport, or from
    /// the background actors.
    pub async fn connect_with<R>(
        self,
        transport: impl ConnectTo<Host> + 'static,
        main_fn: impl AsyncFnOnce(ConnectionTo<Host::Counterpart>) -> Result<R, crate::Error>,
    ) -> Result<R, crate::Error> {
        let (_, future) = self.into_connection_and_future(transport, main_fn);
        future.await
    }
    /// Wires up the connection actors and returns the connection handle plus
    /// the future that drives everything.
    ///
    /// Nothing runs until the returned future is polled, except that the
    /// transport future is queued on the task channel immediately.
    fn into_connection_and_future<R>(
        self,
        transport: impl ConnectTo<Host> + 'static,
        main_fn: impl AsyncFnOnce(ConnectionTo<Host::Counterpart>) -> Result<R, crate::Error>,
    ) -> (
        ConnectionTo<Host::Counterpart>,
        impl Future<Output = Result<R, crate::Error>>,
    ) {
        let Self {
            name,
            handler,
            responder,
            host: me,
            protocol_mode,
        } = self;
        // Channels feeding the outgoing actor, the task actor, and the
        // dynamic-handler registry consumed by the incoming actor.
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
        let (new_task_tx, new_task_rx) = mpsc::unbounded();
        let (dynamic_handler_tx, dynamic_handler_rx) = mpsc::unbounded();
        let connection = ConnectionTo::new(
            me.counterpart(),
            outgoing_tx,
            new_task_tx,
            dynamic_handler_tx,
        );
        let transport_component = crate::DynConnectTo::new(transport);
        let (transport_channel, transport_future) = transport_component.into_channel_and_future();
        // Queue the transport future on the task actor now; a failure is
        // reported only when the returned future is first polled.
        let spawn_result = connection.spawn(transport_future);
        let Channel {
            rx: transport_incoming_rx,
            tx: transport_outgoing_tx,
        } = transport_channel;
        // The outgoing actor holds `reply_tx`, the incoming actor `reply_rx`.
        let (reply_tx, reply_rx) = mpsc::unbounded();
        let protocol_compat = ProtocolCompat::new(protocol_mode);
        let future = crate::util::instrument_with_connection_name(name, {
            let connection = connection.clone();
            async move {
                let () = spawn_result?;
                // All four actors run concurrently; `try_join!` fails fast on
                // the first error.
                let background = async {
                    futures::try_join!(
                        outgoing_actor::outgoing_protocol_actor(
                            outgoing_rx,
                            reply_tx.clone(),
                            transport_outgoing_tx,
                            protocol_compat.clone(),
                        ),
                        incoming_actor::incoming_protocol_actor(
                            me.counterpart(),
                            &connection,
                            transport_incoming_rx,
                            dynamic_handler_rx,
                            reply_rx,
                            handler,
                            protocol_compat,
                        ),
                        task_actor::task_actor(new_task_rx, &connection),
                        responder.run_with_connection_to(connection.clone()),
                    )?;
                    Ok(())
                };
                crate::util::run_until(Box::pin(background), Box::pin(main_fn(connection.clone())))
                    .await
            }
        });
        (connection, future)
    }
}
/// A builder for role `R` can itself act as the `R::Counterpart` end of a
/// connection: it connects using the supplied `client` as its transport.
impl<R, H, Run> ConnectTo<R::Counterpart> for Builder<R, H, Run>
where
    R: Role,
    H: HandleDispatchFrom<R::Counterpart> + 'static,
    Run: RunWithConnectionTo<R::Counterpart> + 'static,
{
    async fn connect_to(self, client: impl ConnectTo<R>) -> Result<(), crate::Error> {
        // Explicit turbofish selects the inherent `Builder::connect_to`.
        Builder::<R, H, Run>::connect_to(self, client).await
    }
}
/// A response delivered to the task awaiting an outgoing request.
pub(crate) struct ResponsePayload {
    /// The decoded result, or the error the peer (or local code) produced.
    pub(crate) result: Result<serde_json::Value, crate::Error>,
    /// Optional acknowledgement channel; `None` on the local-failure paths
    /// in `send_request_to`.
    /// NOTE(review): the receiver side is not visible here; confirm its use.
    pub(crate) ack_tx: Option<oneshot::Sender<()>>,
}
/// Prints the result in full, but only whether an ack channel is present
/// (as `Some("...")` / `None`), since the sender has no useful `Debug` form.
impl std::fmt::Debug for ResponsePayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ack_placeholder: Option<&str> = self.ack_tx.is_some().then_some("...");
        f.debug_struct("ResponsePayload")
            .field("result", &self.result)
            .field("ack_tx", &ack_placeholder)
            .finish()
    }
}
/// Messages sent from the outgoing actor to the incoming actor over the
/// reply channel created in `into_connection_and_future`.
enum ReplyMessage {
    /// Registers interest in the response to request `id`, so that the
    /// response can be routed to `sender`.
    Subscribe {
        id: RequestId,
        role_id: RoleId,
        method: String,
        sender: oneshot::Sender<ResponsePayload>,
        #[cfg(feature = "unstable_cancel_request")]
        cancellation_disarm: SentRequestCancellationDisarm,
    },
}
/// Shows only the request id and method; the channel and role fields are
/// omitted.
impl std::fmt::Debug for ReplyMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ReplyMessage::Subscribe { id, method, .. } = self;
        let mut out = f.debug_struct("Subscribe");
        out.field("id", id).field("method", method);
        out.finish()
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// A cloneable handle that observes (and, crate-internally, triggers)
/// cancellation of one incoming request. Clones share the same state.
#[derive(Clone)]
pub struct RequestCancellation {
    state: Arc<RequestCancellationState>,
}
#[cfg(feature = "unstable_cancel_request")]
/// Shared state behind a [`RequestCancellation`].
struct RequestCancellationState {
    /// Set to `true` on the first call to `cancel`.
    cancelled: AtomicBool,
    /// Sending half of the wake-up signal; taken (and fired) on the first
    /// `cancel`.
    signal_tx: Mutex<Option<oneshot::Sender<()>>>,
    /// Receiving half, made `Shared` so any number of `cancelled()` callers
    /// can await it.
    signal_rx: future::Shared<BoxFuture<'static, ()>>,
}
#[cfg(feature = "unstable_cancel_request")]
impl RequestCancellation {
    /// Creates a fresh, not-yet-cancelled marker.
    fn new() -> Self {
        let (sender, receiver) = oneshot::channel::<()>();
        // Whatever the receiver yields (signal or dropped sender) becomes `()`.
        let shared_signal = receiver.map(|_outcome| ()).boxed().shared();
        let state = RequestCancellationState {
            cancelled: AtomicBool::new(false),
            signal_tx: Mutex::new(Some(sender)),
            signal_rx: shared_signal,
        };
        Self {
            state: Arc::new(state),
        }
    }

    /// Completes once the shared cancellation signal resolves.
    pub async fn cancelled(&self) {
        let signal = self.state.signal_rx.clone();
        signal.await;
    }

    /// Runs `future` unless or until the request is cancelled.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::request_cancelled()` if the request was already
    /// cancelled (without polling `future`) or becomes cancelled before
    /// `future` finishes. Otherwise returns `future`'s own result.
    pub async fn run_until_cancelled<T>(
        &self,
        future: impl std::future::Future<Output = Result<T, crate::Error>>,
    ) -> Result<T, crate::Error> {
        if self.is_cancelled() {
            return Err(crate::Error::request_cancelled());
        }
        match future::select(pin!(future), pin!(self.cancelled())).await {
            Either::Right(((), _unfinished_work)) => Err(crate::Error::request_cancelled()),
            Either::Left((outcome, _unfired_signal)) => outcome,
        }
    }

    /// Returns whether `cancel` has been called on this marker (or a clone).
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.state.cancelled;
        flag.load(Ordering::Acquire)
    }

    /// Marks the request cancelled and wakes every `cancelled()` waiter.
    /// Calls after the first are no-ops.
    fn cancel(&self) {
        let was_already_cancelled = self.state.cancelled.swap(true, Ordering::AcqRel);
        if was_already_cancelled {
            return;
        }
        // Take the sender under the lock, then release the lock before firing.
        let maybe_sender = {
            let mut slot = self
                .state
                .signal_tx
                .lock()
                .expect("request cancellation signal mutex poisoned");
            slot.take()
        };
        if let Some(sender) = maybe_sender {
            let _ = sender.send(());
        }
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Debug output exposes only the current cancellation flag.
impl Debug for RequestCancellation {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.is_cancelled();
        let mut builder = formatter.debug_struct("RequestCancellation");
        builder.field("is_cancelled", &cancelled);
        builder.finish_non_exhaustive()
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Cancellation state of one registered incoming request.
#[derive(Debug)]
enum RequestCancellationEntry {
    /// Registered; no marker handed out yet and no cancel received.
    Armed,
    /// A cancel arrived before any marker was handed out; the first marker
    /// created later starts out cancelled.
    Cancelled,
    /// A marker has been handed out; cancellation goes through it.
    Marker(RequestCancellation),
}
#[cfg(feature = "unstable_cancel_request")]
/// A registry slot for one request id.
#[derive(Debug)]
struct RequestCancellationSlot {
    /// Distinguishes successive registrations that reuse the same id, so a
    /// stale responder cannot observe or remove a newer slot.
    generation: u64,
    entry: RequestCancellationEntry,
}
#[cfg(feature = "unstable_cancel_request")]
/// Lock-protected contents of the cancellation registry.
#[derive(Debug, Default)]
struct RequestCancellationRegistryInner {
    /// In-flight incoming requests keyed by their JSON-RPC id.
    slots: HashMap<RequestId, RequestCancellationSlot>,
    /// Next generation number to hand out; incremented on every `register`.
    next_generation: u64,
}
#[cfg(feature = "unstable_cancel_request")]
/// Shared, cloneable registry of in-flight incoming requests that can be
/// cancelled by id.
#[derive(Clone, Debug, Default)]
struct RequestCancellationRegistry {
    inner: Arc<Mutex<RequestCancellationRegistryInner>>,
}
#[cfg(not(feature = "unstable_cancel_request"))]
/// Zero-sized stand-in used when request cancellation is compiled out.
#[derive(Clone, Debug, Default)]
struct RequestCancellationRegistry;
#[cfg(feature = "unstable_cancel_request")]
/// A responder's claim on one registry slot; dropping it removes the slot
/// (only if the generation still matches).
#[derive(Debug)]
struct ResponderCancellation {
    id: RequestId,
    generation: u64,
    registry: RequestCancellationRegistry,
}
#[cfg(not(feature = "unstable_cancel_request"))]
/// Zero-sized stand-in used when request cancellation is compiled out.
#[derive(Debug)]
struct ResponderCancellation;
#[cfg(feature = "unstable_cancel_request")]
impl RequestCancellationRegistry {
    /// Creates an empty registry.
    fn new() -> Self {
        Self::default()
    }

    /// Locks the registry contents, panicking if the mutex is poisoned.
    fn lock_inner(&self) -> std::sync::MutexGuard<'_, RequestCancellationRegistryInner> {
        self.inner
            .lock()
            .expect("request cancellation registry mutex poisoned")
    }

    /// Registers an incoming request under `id` with a fresh generation and
    /// returns the responder's claim on that slot.
    ///
    /// A slot that is still present for the same id is replaced (and logged).
    fn register(&self, id: &RequestId) -> ResponderCancellation {
        let generation = {
            let mut guard = self.lock_inner();
            let assigned = guard.next_generation;
            guard.next_generation += 1;
            let fresh_slot = RequestCancellationSlot {
                generation: assigned,
                entry: RequestCancellationEntry::Armed,
            };
            if guard.slots.insert(id.clone(), fresh_slot).is_some() {
                tracing::debug!(
                    ?id,
                    "peer reused the ID of a request that is still in flight"
                );
            }
            assigned
        };
        ResponderCancellation {
            id: id.clone(),
            generation,
            registry: self.clone(),
        }
    }

    /// Returns the cancellation marker for slot (`id`, `generation`).
    ///
    /// The marker is created on first use; it starts cancelled if a cancel
    /// already arrived. A missing or stale slot yields a detached marker
    /// that is never cancelled through this registry.
    fn marker(&self, id: &RequestId, generation: u64) -> RequestCancellation {
        let mut guard = self.lock_inner();
        let entry = match guard.slots.get_mut(id) {
            Some(slot) if slot.generation == generation => &mut slot.entry,
            _ => return RequestCancellation::new(),
        };
        if let RequestCancellationEntry::Marker(existing) = entry {
            return existing.clone();
        }
        let fresh = RequestCancellation::new();
        if matches!(entry, RequestCancellationEntry::Cancelled) {
            fresh.cancel();
        }
        *entry = RequestCancellationEntry::Marker(fresh.clone());
        fresh
    }

    /// Cancels the request targeted by `dispatch` if it is a cancel
    /// notification (possibly wrapped in successor envelopes).
    ///
    /// Returns `Ok(true)` when the targeted id is registered.
    ///
    /// # Errors
    ///
    /// Fails when a cancel notification's params cannot be parsed.
    fn cancel_if_requested(&self, dispatch: &Dispatch) -> Result<bool, crate::Error> {
        Ok(cancellation_request_id(dispatch)?.is_some_and(|request_id| self.cancel(&request_id)))
    }

    /// Cancels the request registered under `request_id`.
    ///
    /// Returns `false` if no such request is registered. The marker, if any,
    /// is fired after the registry lock has been released.
    fn cancel(&self, request_id: &RequestId) -> bool {
        let pending_marker = {
            let mut guard = self.lock_inner();
            let Some(slot) = guard.slots.get_mut(request_id) else {
                return false;
            };
            match slot.entry {
                RequestCancellationEntry::Marker(ref marker) => Some(marker.clone()),
                RequestCancellationEntry::Cancelled => None,
                RequestCancellationEntry::Armed => {
                    slot.entry = RequestCancellationEntry::Cancelled;
                    None
                }
            }
        };
        if let Some(marker) = pending_marker {
            marker.cancel();
        }
        true
    }

    /// Removes the slot for `request_id`, but only if it still belongs to
    /// `generation` (a newer registration of the same id is left alone).
    fn remove(&self, request_id: &RequestId, generation: u64) {
        let mut guard = self.lock_inner();
        let owns_slot = guard
            .slots
            .get(request_id)
            .is_some_and(|slot| slot.generation == generation);
        if owns_slot {
            guard.slots.remove(request_id);
        }
    }
}
#[cfg(not(feature = "unstable_cancel_request"))]
/// Feature-disabled stub: same method names as the real registry, but
/// nothing is tracked and nothing is ever cancelled.
impl RequestCancellationRegistry {
    /// Creates the (stateless) registry.
    fn new() -> Self {
        Self
    }
    /// Returns a no-op responder claim.
    #[expect(
        clippy::unused_self,
        reason = "feature-disabled stub mirrors the real registry API"
    )]
    fn register(&self, _id: &RequestId) -> ResponderCancellation {
        ResponderCancellation
    }
    /// Always reports that no cancellation happened.
    #[expect(
        clippy::unused_self,
        clippy::unnecessary_wraps,
        reason = "feature-disabled stub mirrors the real registry API"
    )]
    fn cancel_if_requested(&self, _dispatch: &Dispatch) -> Result<bool, crate::Error> {
        Ok(false)
    }
}
#[cfg(feature = "unstable_cancel_request")]
impl ResponderCancellation {
    /// Returns the cancellation marker for this responder's registry slot.
    fn cancellation(&self) -> RequestCancellation {
        let Self {
            id,
            generation,
            registry,
        } = self;
        registry.marker(id, *generation)
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Releases the registry slot when the responder goes away. The registry
/// ignores the request if a newer registration now owns the id.
impl Drop for ResponderCancellation {
    fn drop(&mut self) {
        let Self {
            id,
            generation,
            registry,
        } = self;
        registry.remove(id, *generation);
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Extracts the target request id if `dispatch` is a cancel notification;
/// requests and responses never are.
///
/// # Errors
///
/// Fails when a cancel notification's params cannot be parsed.
fn cancellation_request_id(dispatch: &Dispatch) -> Result<Option<RequestId>, crate::Error> {
    match dispatch {
        Dispatch::Notification(message) => cancellation_request_id_from_message(message),
        Dispatch::Request(..) | Dispatch::Response(..) => Ok(None),
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Peels successor envelopes off `message`. If the innermost method is the
/// cancel-request method, parses it and returns the id it targets.
///
/// # Errors
///
/// Fails when the cancel notification's params do not parse.
fn cancellation_request_id_from_message(
    message: &UntypedMessage,
) -> Result<Option<RequestId>, crate::Error> {
    use crate::schema::v1::CancelRequestNotification;

    let (inner_method, inner_params) = peel_successor_envelopes(&message.method, &message.params);
    if CancelRequestNotification::matches_method(inner_method) {
        let parsed = CancelRequestNotification::parse_message(inner_method, inner_params)?;
        Ok(Some(parsed.request_id))
    } else {
        Ok(None)
    }
}
/// Repeatedly unwraps successor-message envelopes and returns the innermost
/// `(method, params)` pair.
///
/// Unwrapping stops when the method is not a successor method, or when the
/// envelope's params have no string `method` field. A missing inner
/// `params` becomes `null`.
fn peel_successor_envelopes<'message>(
    method: &'message str,
    params: &'message serde_json::Value,
) -> (&'message str, &'message serde_json::Value) {
    let mut current = (method, params);
    loop {
        let (outer_method, outer_params) = current;
        if !crate::schema::SuccessorMessage::<UntypedMessage>::matches_method(outer_method) {
            return current;
        }
        let Some(next_method) = outer_params
            .get("method")
            .and_then(serde_json::Value::as_str)
        else {
            return current;
        };
        let next_params = outer_params.get("params").unwrap_or(&serde_json::Value::Null);
        current = (next_method, next_params);
    }
}
#[cfg(feature = "unstable_cancel_request")]
/// Returns whether `notification` is a `$/cancel_request`.
///
/// A notification counts whether it is sent directly or wrapped in
/// successor envelopes. If a wrapped notification cannot be converted to an
/// untyped message, the failure is logged and `false` is returned.
#[must_use]
pub fn is_cancel_request_notification<N: JsonRpcNotification>(notification: &N) -> bool {
    use crate::schema::v1::CancelRequestNotification;

    let outer_method = notification.method();
    if CancelRequestNotification::matches_method(outer_method) {
        return true;
    }
    if !crate::schema::SuccessorMessage::<UntypedMessage>::matches_method(outer_method) {
        return false;
    }
    let untyped = match notification.to_untyped_message() {
        Ok(untyped) => untyped,
        Err(error) => {
            tracing::debug!(
                ?error,
                "failed to inspect successor-wrapped notification for cancellation"
            );
            return false;
        }
    };
    let (inner_method, _inner_params) = peel_successor_envelopes(&untyped.method, &untyped.params);
    CancelRequestNotification::matches_method(inner_method)
}
/// Returns whether `dispatch` is a notification whose innermost method
/// (after peeling successor envelopes) starts with `$/`.
fn is_protocol_level_notification(dispatch: &Dispatch) -> bool {
    match dispatch {
        Dispatch::Notification(message) => {
            let (inner_method, _inner_params) =
                peel_successor_envelopes(&message.method, &message.params);
            inner_method.starts_with("$/")
        }
        Dispatch::Request(..) | Dispatch::Response(..) => false,
    }
}
/// Work items sent from [`ConnectionTo`] handles to the outgoing actor.
#[derive(Debug)]
enum OutgoingMessage {
    /// An outgoing request; its response is delivered through `response_tx`.
    Request {
        id: RequestId,
        method: String,
        role_id: RoleId,
        untyped: UntypedMessage,
        response_tx: oneshot::Sender<ResponsePayload>,
        #[cfg(feature = "unstable_cancel_request")]
        cancellation_disarm: SentRequestCancellationDisarm,
    },
    /// An outgoing notification, already transformed for the target peer.
    Notification {
        untyped: UntypedMessage,
    },
    /// A response to an incoming request with the given `id`.
    Response {
        id: RequestId,
        method: String,
        response: Result<serde_json::Value, crate::Error>,
    },
    /// An error reported via `ConnectionTo::send_error_notification`.
    Error { error: crate::Error },
}
/// Outcome of offering a message to a handler.
#[must_use]
#[derive(Debug)]
pub enum Handled<T> {
    /// The handler consumed the message.
    Yes,
    /// The handler declined; the message is returned to the caller.
    No {
        message: T,
        /// NOTE(review): presumably asks the dispatcher to offer the message
        /// again later; semantics are defined outside this chunk; confirm.
        retry: bool,
    },
}
/// Conversion from a handler's return value into a [`Handled`] outcome.
pub trait IntoHandled<T> {
    fn into_handled(self) -> Handled<T>;
}
/// A handler that returns `()` always counts as having handled the message.
impl<T> IntoHandled<T> for () {
    fn into_handled(self) -> Handled<T> {
        Handled::<T>::Yes
    }
}
/// A [`Handled`] value is already an outcome; it is returned unchanged.
impl<T> IntoHandled<T> for Handled<T> {
    fn into_handled(self) -> Handled<T> {
        self
    }
}
/// A cloneable handle for talking to `Counterpart` over a running connection.
///
/// Every operation is forwarded to one of the connection's actors through an
/// unbounded channel.
#[derive(Clone, Debug)]
pub struct ConnectionTo<Counterpart: Role> {
    /// The remote role this handle talks to.
    counterpart: Counterpart,
    /// Channel into the outgoing actor.
    message_tx: OutgoingMessageTx,
    /// Channel into the task actor.
    task_tx: TaskTx,
    /// Channel used to add and remove dynamic handlers.
    dynamic_handler_tx: mpsc::UnboundedSender<DynamicHandlerMessage<Counterpart>>,
}
impl<Counterpart: Role> ConnectionTo<Counterpart> {
fn new(
counterpart: Counterpart,
message_tx: mpsc::UnboundedSender<OutgoingMessage>,
task_tx: mpsc::UnboundedSender<Task>,
dynamic_handler_tx: mpsc::UnboundedSender<DynamicHandlerMessage<Counterpart>>,
) -> Self {
Self {
counterpart,
message_tx,
task_tx,
dynamic_handler_tx,
}
}
pub fn counterpart(&self) -> Counterpart {
self.counterpart.clone()
}
#[track_caller]
pub fn spawn(
&self,
task: impl IntoFuture<Output = Result<(), crate::Error>, IntoFuture: Send + 'static>,
) -> Result<(), crate::Error> {
let location = std::panic::Location::caller();
let task = task.into_future();
Task::new(location, task).spawn(&self.task_tx)
}
#[track_caller]
pub fn spawn_connection<R: Role>(
&self,
builder: Builder<
R,
impl HandleDispatchFrom<R::Counterpart> + 'static,
impl RunWithConnectionTo<R::Counterpart> + 'static,
>,
transport: impl ConnectTo<R> + 'static,
) -> Result<ConnectionTo<R::Counterpart>, crate::Error> {
let (connection, future) =
builder.into_connection_and_future(transport, |_| std::future::pending());
Task::new(std::panic::Location::caller(), future).spawn(&self.task_tx)?;
Ok(connection)
}
pub fn send_proxied_message<Req: JsonRpcRequest<Response: Send>, Notif: JsonRpcNotification>(
&self,
message: Dispatch<Req, Notif>,
) -> Result<(), crate::Error>
where
Counterpart: HasPeer<Counterpart>,
{
self.send_proxied_message_to(self.counterpart(), message)
}
pub fn send_proxied_message_to<
Peer: Role,
Req: JsonRpcRequest<Response: Send>,
Notif: JsonRpcNotification,
>(
&self,
peer: Peer,
message: Dispatch<Req, Notif>,
) -> Result<(), crate::Error>
where
Counterpart: HasPeer<Peer>,
{
match message {
Dispatch::Request(request, responder) => self
.send_request_to(peer, request)
.forward_response_to(responder),
Dispatch::Notification(notification) => {
#[cfg(feature = "unstable_cancel_request")]
if is_cancel_request_notification(¬ification) {
tracing::debug!(
"not forwarding hop-scoped `$/cancel_request` notification across proxy hop"
);
return Ok(());
}
self.send_notification_to(peer, notification)
}
Dispatch::Response(result, router) => {
router.respond_with_result(result)
}
}
}
pub fn send_request<Req: JsonRpcRequest>(&self, request: Req) -> SentRequest<Req::Response>
where
Counterpart: HasPeer<Counterpart>,
{
self.send_request_to(self.counterpart.clone(), request)
}
pub fn send_request_to<Peer: Role, Req: JsonRpcRequest>(
&self,
peer: Peer,
request: Req,
) -> SentRequest<Req::Response>
where
Counterpart: HasPeer<Peer>,
{
let method = request.method().to_string();
let id = RequestId::Str(uuid::Uuid::new_v4().to_string());
let (response_tx, response_rx) = oneshot::channel();
let role_id = peer.role_id();
let remote_style = self.counterpart.remote_style(peer);
#[cfg(feature = "unstable_cancel_request")]
let cancellation =
SentRequestCancellation::new(self.message_tx.clone(), remote_style, id.clone());
match remote_style.transform_outgoing_message(request) {
Ok(untyped) => {
let message = OutgoingMessage::Request {
id: id.clone(),
method: method.clone(),
role_id,
untyped,
response_tx,
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm: cancellation.disarm_handle(),
};
match self.message_tx.unbounded_send(message) {
Ok(()) => (),
Err(error) => {
#[cfg(feature = "unstable_cancel_request")]
cancellation.disarm();
let OutgoingMessage::Request {
method,
response_tx,
..
} = error.into_inner()
else {
unreachable!();
};
response_tx
.send(ResponsePayload {
result: Err(crate::util::internal_error(format!(
"failed to send outgoing request `{method}"
))),
ack_tx: None,
})
.unwrap();
}
}
}
Err(err) => {
#[cfg(feature = "unstable_cancel_request")]
cancellation.disarm();
response_tx
.send(ResponsePayload {
result: Err(crate::util::internal_error(format!(
"failed to create untyped request for `{method}`: {err}"
))),
ack_tx: None,
})
.unwrap();
}
}
SentRequest::new(
id,
method.clone(),
self.task_tx.clone(),
response_rx,
#[cfg(feature = "unstable_cancel_request")]
cancellation,
)
.map(move |json| <Req::Response>::from_value(&method, json))
}
pub fn send_notification<N: JsonRpcNotification>(
&self,
notification: N,
) -> Result<(), crate::Error>
where
Counterpart: HasPeer<Counterpart>,
{
self.send_notification_to(self.counterpart.clone(), notification)
}
/// Sends `notification` to `peer`, shaped according to the remote style the
/// counterpart reports for that peer.
///
/// The notification is transformed into an untyped message via
/// `transform_outgoing_message` and then queued on the outgoing channel.
///
/// # Errors
///
/// Returns an error if the transformation fails or if `send_raw_message`
/// cannot queue the resulting message.
pub fn send_notification_to<Peer: Role, N: JsonRpcNotification>(
    &self,
    peer: Peer,
    notification: N,
) -> Result<(), crate::Error>
where
    Counterpart: HasPeer<Peer>,
{
    let style = self.counterpart.remote_style(peer);
    tracing::debug!(
        role = std::any::type_name::<Counterpart>(),
        peer = std::any::type_name::<Peer>(),
        notification_type = std::any::type_name::<N>(),
        remote_style = ?style,
        original_method = notification.method(),
        "send_notification_to"
    );
    let untyped = style.transform_outgoing_message(notification)?;
    tracing::debug!(
        transformed_method = %untyped.method,
        "send_notification_to transformed"
    );
    let message = OutgoingMessage::Notification { untyped };
    send_raw_message(&self.message_tx, message)
}
/// Asks this connection's counterpart to cancel the request with `request_id`.
///
/// Shorthand for [`Self::send_cancel_request_to`] using the counterpart role
/// itself as the peer.
#[cfg(feature = "unstable_cancel_request")]
pub fn send_cancel_request(
    &self,
    request_id: impl Into<crate::schema::v1::RequestId>,
) -> Result<(), crate::Error>
where
    Counterpart: HasPeer<Counterpart>,
{
    let peer = self.counterpart.clone();
    self.send_cancel_request_to(peer, request_id)
}
/// Sends a `CancelRequestNotification` for `request_id` to `peer`.
///
/// # Errors
///
/// Propagates any error from [`Self::send_notification_to`].
#[cfg(feature = "unstable_cancel_request")]
pub fn send_cancel_request_to<Peer: Role>(
    &self,
    peer: Peer,
    request_id: impl Into<crate::schema::v1::RequestId>,
) -> Result<(), crate::Error>
where
    Counterpart: HasPeer<Peer>,
{
    let cancel = crate::schema::v1::CancelRequestNotification::new(request_id);
    self.send_notification_to(peer, cancel)
}
pub fn send_error_notification(&self, error: crate::Error) -> Result<(), crate::Error> {
send_raw_message(&self.message_tx, OutgoingMessage::Error { error })
}
/// Registers `handler` as a dynamic handler on this connection.
///
/// The handler is tagged with a fresh random UUID and sent over
/// `dynamic_handler_tx`. Dropping the returned [`DynamicHandlerRegistration`]
/// sends a matching removal message.
///
/// # Errors
///
/// Returns an internal error if the registration message cannot be sent.
pub fn add_dynamic_handler(
    &self,
    handler: impl HandleDispatchFrom<Counterpart> + 'static,
) -> Result<DynamicHandlerRegistration<Counterpart>, crate::Error> {
    let registration_id = Uuid::new_v4();
    let message = DynamicHandlerMessage::AddDynamicHandler(registration_id, Box::new(handler));
    if let Err(send_error) = self.dynamic_handler_tx.unbounded_send(message) {
        return Err(crate::util::internal_error(send_error));
    }
    Ok(DynamicHandlerRegistration::new(
        registration_id,
        self.clone(),
    ))
}
/// Sends a request to remove the dynamic handler registered under `uuid`.
///
/// Best-effort: if the message cannot be sent (the channel's receiver is gone),
/// the failure is deliberately ignored.
fn remove_dynamic_handler(&self, uuid: Uuid) {
    let removal = DynamicHandlerMessage::RemoveDynamicHandler(uuid);
    // Called from `Drop`, where there is nothing useful to do with a failure.
    let _ = self.dynamic_handler_tx.unbounded_send(removal);
}
}
/// Handle returned by `ConnectionTo::add_dynamic_handler`.
///
/// Dropping the handle sends a removal message for the handler it refers to;
/// call [`DynamicHandlerRegistration::run_indefinitely`] to skip that.
///
/// NOTE(review): this type derives `Clone`, and every clone sends the same
/// removal message when dropped, so dropping any single clone unregisters the
/// handler. Confirm this is intended.
#[derive(Clone, Debug)]
pub struct DynamicHandlerRegistration<R: Role> {
    uuid: Uuid,
    cx: ConnectionTo<R>,
}

impl<R: Role> DynamicHandlerRegistration<R> {
    fn new(uuid: Uuid, cx: ConnectionTo<R>) -> Self {
        DynamicHandlerRegistration { cx, uuid }
    }

    /// Keeps the handler registered by never running this handle's destructor.
    ///
    /// NOTE(review): `mem::forget` also leaks the `ConnectionTo` clone held
    /// here, and everything it owns. Confirm that is acceptable.
    pub fn run_indefinitely(self) {
        std::mem::forget(self);
    }
}

impl<R: Role> Drop for DynamicHandlerRegistration<R> {
    fn drop(&mut self) {
        let Self { uuid, cx } = self;
        cx.remove_dynamic_handler(*uuid);
    }
}
/// Answers one incoming request.
///
/// The consuming methods (`respond`, `respond_with_error`, ...) hand the final
/// result to `send_fn`. For a responder built by `Responder::new`, that queues
/// an `OutgoingMessage::Response` on the outgoing channel. `T` is the response
/// type accepted; `cast`, `erase_to_json` and `wrap_params` convert between types.
#[must_use]
pub struct Responder<T: JsonRpcResponse = serde_json::Value> {
/// Method name of the request being answered.
method: String,
/// Id of the request being answered.
id: RequestId,
/// Cancellation state registered for this request id in `Responder::new`.
cancellation: ResponderCancellation,
/// Delivers the final result; it is a boxed `FnOnce`, so it runs at most once.
send_fn: Box<dyn FnOnce(Result<T, crate::Error>) -> Result<(), crate::Error> + Send>,
}
impl<T: JsonRpcResponse> std::fmt::Debug for Responder<T> {
    // Shows method, request id and response type name; the send callback and
    // cancellation state are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let response_type = std::any::type_name::<T>();
        let mut out = f.debug_struct("Responder");
        out.field("method", &self.method);
        out.field("id", &self.id);
        out.field("response_type", &response_type);
        out.finish_non_exhaustive()
    }
}
impl Responder<serde_json::Value> {
fn new(
message_tx: OutgoingMessageTx,
method: String,
id: RequestId,
cancellation_registry: &RequestCancellationRegistry,
) -> Self {
let id_clone = id.clone();
let method_clone = method.clone();
let cancellation = cancellation_registry.register(&id);
Self {
method,
id,
cancellation,
send_fn: Box::new(move |response: Result<serde_json::Value, crate::Error>| {
send_raw_message(
&message_tx,
OutgoingMessage::Response {
id: id_clone,
method: method_clone,
response,
},
)
}),
}
}
pub fn cast<T: JsonRpcResponse>(self) -> Responder<T> {
self.wrap_params(move |method, value| match value {
Ok(value) => T::into_json(value, method),
Err(e) => Err(e),
})
}
}
impl<T: JsonRpcResponse> Responder<T> {
#[must_use]
pub fn method(&self) -> &str {
&self.method
}
#[must_use]
pub fn id(&self) -> serde_json::Value {
crate::util::id_to_json(&self.id)
}
#[cfg(feature = "unstable_cancel_request")]
#[must_use]
pub fn cancellation(&self) -> RequestCancellation {
self.cancellation.cancellation()
}
pub fn erase_to_json(self) -> Responder<serde_json::Value> {
self.wrap_params(|method, value| T::from_value(method, value?))
}
pub fn wrap_method(self, method: String) -> Responder<T> {
Responder {
method,
id: self.id,
cancellation: self.cancellation,
send_fn: self.send_fn,
}
}
pub fn wrap_params<U: JsonRpcResponse>(
self,
wrap_fn: impl FnOnce(&str, Result<U, crate::Error>) -> Result<T, crate::Error> + Send + 'static,
) -> Responder<U> {
let method = self.method.clone();
Responder {
method: self.method,
id: self.id,
cancellation: self.cancellation,
send_fn: Box::new(move |input: Result<U, crate::Error>| {
let t_value = wrap_fn(&method, input);
(self.send_fn)(t_value)
}),
}
}
pub fn respond_with_result(
self,
response: Result<T, crate::Error>,
) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, "respond called");
(self.send_fn)(response)
}
pub fn respond(self, response: T) -> Result<(), crate::Error> {
self.respond_with_result(Ok(response))
}
pub fn respond_with_internal_error(self, message: impl ToString) -> Result<(), crate::Error> {
self.respond_with_error(crate::util::internal_error(message))
}
pub fn respond_with_error(self, error: crate::Error) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, ?error, "respond_with_error called");
self.respond_with_result(Err(error))
}
}
/// Routes a response that arrived for a request this side sent to the local
/// code awaiting it.
///
/// For a router built by `ResponseRouter::new`, `send_fn` delivers the result
/// over a oneshot channel and always returns `Ok(())`. A vanished receiver is
/// only logged.
#[must_use]
pub struct ResponseRouter<T: JsonRpcResponse = serde_json::Value> {
/// Method name of the original request.
method: String,
/// Id of the original request.
id: RequestId,
/// Role id recorded for this response (returned by `role_id()`).
role_id: RoleId,
/// Delivers the result; it is a boxed `FnOnce`, so it runs at most once.
send_fn: Box<dyn FnOnce(Result<T, crate::Error>) -> Result<(), crate::Error> + Send>,
}
impl<T: JsonRpcResponse> std::fmt::Debug for ResponseRouter<T> {
    // Shows method, request id and response type name; `role_id` and the send
    // callback are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let response_type = std::any::type_name::<T>();
        let mut out = f.debug_struct("ResponseRouter");
        out.field("method", &self.method);
        out.field("id", &self.id);
        out.field("response_type", &response_type);
        out.finish_non_exhaustive()
    }
}
impl ResponseRouter<serde_json::Value> {
pub(crate) fn new(
method: String,
id: RequestId,
role_id: RoleId,
sender: oneshot::Sender<ResponsePayload>,
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm: SentRequestCancellationDisarm,
) -> Self {
let response_method = method.clone();
let response_id = id.clone();
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm.disarm();
Self {
method,
id,
role_id,
send_fn: Box::new(move |response: Result<serde_json::Value, crate::Error>| {
if sender
.send(ResponsePayload {
result: response,
ack_tx: None,
})
.is_err()
{
tracing::debug!(
method = %response_method,
id = ?response_id,
"dropped response because local receiver was gone"
);
}
Ok(())
}),
}
}
pub fn cast<T: JsonRpcResponse>(self) -> ResponseRouter<T> {
self.wrap_params(move |method, value| match value {
Ok(value) => T::into_json(value, method),
Err(e) => Err(e),
})
}
}
impl<T: JsonRpcResponse> ResponseRouter<T> {
#[must_use]
pub fn method(&self) -> &str {
&self.method
}
#[must_use]
pub fn id(&self) -> serde_json::Value {
crate::util::id_to_json(&self.id)
}
#[must_use]
pub fn role_id(&self) -> RoleId {
self.role_id.clone()
}
pub fn erase_to_json(self) -> ResponseRouter<serde_json::Value> {
self.wrap_params(|method, value| T::from_value(method, value?))
}
fn wrap_params<U: JsonRpcResponse>(
self,
wrap_fn: impl FnOnce(&str, Result<U, crate::Error>) -> Result<T, crate::Error> + Send + 'static,
) -> ResponseRouter<U> {
let method = self.method.clone();
ResponseRouter {
method: self.method,
id: self.id,
role_id: self.role_id,
send_fn: Box::new(move |input: Result<U, crate::Error>| {
let t_value = wrap_fn(&method, input);
(self.send_fn)(t_value)
}),
}
}
pub fn respond_with_result(
self,
response: Result<T, crate::Error>,
) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, "response routed to awaiter");
(self.send_fn)(response)
}
pub fn respond(self, response: T) -> Result<(), crate::Error> {
self.respond_with_result(Ok(response))
}
pub fn respond_with_internal_error(self, message: impl ToString) -> Result<(), crate::Error> {
self.respond_with_error(crate::util::internal_error(message))
}
pub fn respond_with_error(self, error: crate::Error) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, ?error, "error routed to awaiter");
self.respond_with_result(Err(error))
}
}
/// A typed JSON-RPC message payload (request or notification) that converts to
/// and from an [`UntypedMessage`].
pub trait JsonRpcMessage: 'static + Debug + Sized + Send + Clone {
/// Returns `true` if a message with `method` should be parsed as this type
/// (checked before `parse_message` during dispatch).
fn matches_method(method: &str) -> bool;
/// The method name of this particular message.
fn method(&self) -> &str;
/// Converts this message into its untyped (method + JSON params) form.
fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error>;
/// Parses a message of this type from `method` and serializable `params`.
fn parse_message(method: &str, params: &impl Serialize) -> Result<Self, crate::Error>;
}
/// A JSON-RPC result payload that converts to and from `serde_json::Value`.
pub trait JsonRpcResponse: 'static + Debug + Sized + Send + Clone {
/// Serializes this response; `method` is the method name of the request it answers.
fn into_json(self, method: &str) -> Result<serde_json::Value, crate::Error>;
/// Parses a response to a request for `method` from JSON.
fn from_value(method: &str, value: serde_json::Value) -> Result<Self, crate::Error>;
}
/// Raw JSON responses pass through unchanged in both directions.
impl JsonRpcResponse for serde_json::Value {
    fn into_json(self, _method: &str) -> Result<serde_json::Value, crate::Error> {
        Ok(self)
    }

    fn from_value(_method: &str, value: serde_json::Value) -> Result<Self, crate::Error> {
        Ok(value)
    }
}
/// Marker for message types that are sent as notifications (no response).
pub trait JsonRpcNotification: JsonRpcMessage {}
/// A message type that is sent as a request and answered with `Self::Response`.
pub trait JsonRpcRequest: JsonRpcMessage {
/// The result type a peer sends back for this request.
type Response: JsonRpcResponse;
}
/// One message being dispatched to handlers.
///
/// `Req` / `Notif` default to [`UntypedMessage`], the not-yet-typed form.
#[derive(Debug)]
pub enum Dispatch<Req: JsonRpcRequest = UntypedMessage, Notif: JsonRpcMessage = UntypedMessage> {
/// A request together with the responder used to answer it.
Request(Req, Responder<Req::Response>),
/// A notification; no response is sent.
Notification(Notif),
/// The result for a request this side sent, plus the router that hands it to
/// the local awaiter.
Response(
Result<Req::Response, crate::Error>,
ResponseRouter<Req::Response>,
),
}
impl<Req: JsonRpcRequest, Notif: JsonRpcMessage> Dispatch<Req, Notif> {
    /// Maps the request (with its responder) and notification payloads to new
    /// types.
    ///
    /// Responses pass through untouched because the response type is shared.
    pub fn map<Req1, Notif1>(
        self,
        map_request: impl FnOnce(Req, Responder<Req::Response>) -> (Req1, Responder<Req1::Response>),
        map_notification: impl FnOnce(Notif) -> Notif1,
    ) -> Dispatch<Req1, Notif1>
    where
        Req1: JsonRpcRequest<Response = Req::Response>,
        Notif1: JsonRpcMessage,
    {
        match self {
            Dispatch::Request(req, responder) => {
                let (mapped_req, mapped_responder) = map_request(req, responder);
                Dispatch::Request(mapped_req, mapped_responder)
            }
            Dispatch::Notification(notif) => Dispatch::Notification(map_notification(notif)),
            Dispatch::Response(result, router) => Dispatch::Response(result, router),
        }
    }

    /// Reports `error` for this dispatch.
    ///
    /// Requests and responses go through their responder or router;
    /// notifications emit an error message on `cx`.
    pub fn respond_with_error<R: Role>(
        self,
        error: crate::Error,
        cx: ConnectionTo<R>,
    ) -> Result<(), crate::Error> {
        match self {
            Dispatch::Notification(_) => cx.send_error_notification(error),
            Dispatch::Request(_, responder) => responder.respond_with_error(error),
            Dispatch::Response(_, router) => router.respond_with_error(error),
        }
    }

    /// Converts to the untyped `Dispatch`; fails for `Response`.
    pub fn erase_to_json(self) -> Result<Dispatch, crate::Error> {
        Ok(match self {
            Dispatch::Request(request, responder) => {
                Dispatch::Request(request.to_untyped_message()?, responder.erase_to_json())
            }
            Dispatch::Notification(notification) => {
                Dispatch::Notification(notification.to_untyped_message()?)
            }
            Dispatch::Response(..) => {
                return Err(crate::util::internal_error(
                    "cannot erase Response variant to JSON",
                ));
            }
        })
    }

    /// Returns the untyped form of the request/notification payload; fails for
    /// `Response`.
    pub fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error> {
        match self {
            Dispatch::Request(request, _) => request.to_untyped_message(),
            Dispatch::Notification(notification) => notification.to_untyped_message(),
            Dispatch::Response(..) => Err(crate::util::internal_error(
                "Response variant has no untyped message representation",
            )),
        }
    }

    /// Same conversion as `erase_to_json`, with a different error message for
    /// `Response`.
    pub fn into_untyped_dispatch(self) -> Result<Dispatch, crate::Error> {
        Ok(match self {
            Dispatch::Request(request, responder) => {
                Dispatch::Request(request.to_untyped_message()?, responder.erase_to_json())
            }
            Dispatch::Notification(notification) => {
                Dispatch::Notification(notification.to_untyped_message()?)
            }
            Dispatch::Response(..) => {
                return Err(crate::util::internal_error(
                    "cannot convert Response variant to untyped message context",
                ));
            }
        })
    }

    /// JSON request id for requests and responses; `None` for notifications.
    pub fn id(&self) -> Option<serde_json::Value> {
        match self {
            Dispatch::Notification(_) => None,
            Dispatch::Request(_, responder) => Some(responder.id()),
            Dispatch::Response(_, router) => Some(router.id()),
        }
    }

    /// Method name of the payload, or of the original request for responses.
    pub fn method(&self) -> &str {
        match self {
            Dispatch::Notification(notification) => notification.method(),
            Dispatch::Request(request, _) => request.method(),
            Dispatch::Response(_, router) => router.method(),
        }
    }
}
impl Dispatch {
#[tracing::instrument(skip(self), fields(Request = ?std::any::type_name::<Req>(), Notif = ?std::any::type_name::<Notif>()), level = "trace", ret)]
pub(crate) fn into_typed_dispatch<Req: JsonRpcRequest, Notif: JsonRpcNotification>(
self,
) -> Result<Result<Dispatch<Req, Notif>, Dispatch>, crate::Error> {
tracing::debug!(
message = ?self,
"into_typed_dispatch"
);
match self {
Dispatch::Request(message, responder) => {
if Req::matches_method(&message.method) {
match Req::parse_message(&message.method, &message.params) {
Ok(req) => {
tracing::trace!(?req, "parsed ok");
Ok(Ok(Dispatch::Request(req, responder.cast())))
}
Err(err) => {
tracing::trace!(?err, "parse error");
Err(err)
}
}
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Request(message, responder)))
}
}
Dispatch::Notification(message) => {
if Notif::matches_method(&message.method) {
match Notif::parse_message(&message.method, &message.params) {
Ok(notif) => {
tracing::trace!(?notif, "parse ok");
Ok(Ok(Dispatch::Notification(notif)))
}
Err(err) => {
tracing::trace!(?err, "parse error");
Err(err)
}
}
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Notification(message)))
}
}
Dispatch::Response(result, cx) => {
let method = cx.method();
if Req::matches_method(method) {
let typed_result = match result {
Ok(value) => {
match <Req::Response as JsonRpcResponse>::from_value(method, value) {
Ok(parsed) => {
tracing::trace!(?parsed, "parse ok");
Ok(parsed)
}
Err(err) => {
tracing::trace!(?err, "parse error");
return Err(err);
}
}
}
Err(err) => {
tracing::trace!("error, passthrough");
Err(err)
}
};
Ok(Ok(Dispatch::Response(typed_result, cx.cast())))
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Response(result, cx)))
}
}
}
}
#[must_use]
pub fn has_field(&self, field_name: &str) -> bool {
self.message()
.and_then(|m| m.params().get(field_name))
.is_some()
}
pub(crate) fn has_session_id(&self) -> bool {
self.has_field("sessionId")
}
pub(crate) fn get_session_id(&self) -> Result<Option<SessionId>, crate::Error> {
let Some(message) = self.message() else {
return Ok(None);
};
let Some(value) = message.params().get("sessionId") else {
return Ok(None);
};
let session_id = serde_json::from_value(value.clone())?;
Ok(Some(session_id))
}
pub fn into_notification<N: JsonRpcNotification>(
self,
) -> Result<Result<N, Dispatch>, crate::Error> {
match self {
Dispatch::Notification(msg) => {
if !N::matches_method(&msg.method) {
return Ok(Err(Dispatch::Notification(msg)));
}
match N::parse_message(&msg.method, &msg.params) {
Ok(n) => Ok(Ok(n)),
Err(err) => Err(err),
}
}
Dispatch::Request(..) | Dispatch::Response(..) => Ok(Err(self)),
}
}
pub fn into_request<Req: JsonRpcRequest>(
self,
) -> Result<Result<(Req, Responder<Req::Response>), Dispatch>, crate::Error> {
match self {
Dispatch::Request(msg, responder) => {
if !Req::matches_method(&msg.method) {
return Ok(Err(Dispatch::Request(msg, responder)));
}
match Req::parse_message(&msg.method, &msg.params) {
Ok(req) => Ok(Ok((req, responder.cast()))),
Err(err) => Err(err),
}
}
Dispatch::Notification(..) | Dispatch::Response(..) => Ok(Err(self)),
}
}
}
impl<M: JsonRpcRequest + JsonRpcNotification> Dispatch<M, M> {
    /// Borrows the request or notification payload; `None` for responses.
    pub fn message(&self) -> Option<&M> {
        if let Dispatch::Request(msg, _) | Dispatch::Notification(msg) = self {
            Some(msg)
        } else {
            None
        }
    }

    /// Rewrites the request or notification payload with `map_message`.
    ///
    /// The responder and any response are kept unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `map_message`.
    pub(crate) fn try_map_message(
        self,
        map_message: impl FnOnce(M) -> Result<M, crate::Error>,
    ) -> Result<Dispatch<M, M>, crate::Error> {
        let mapped: Dispatch<M, M> = match self {
            Dispatch::Request(request, responder) => {
                Dispatch::Request(map_message(request)?, responder)
            }
            Dispatch::Notification(notification) => {
                Dispatch::Notification(map_message(notification)?)
            }
            Dispatch::Response(result, router) => Dispatch::Response(result, router),
        };
        Ok(mapped)
    }
}
/// A JSON-RPC message that has not been parsed into a concrete type: just the
/// method name and its params as JSON.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UntypedMessage {
/// JSON-RPC method name.
pub method: String,
/// Method params as arbitrary JSON.
pub params: serde_json::Value,
}
impl UntypedMessage {
    /// Builds a message for `method`, serializing `params` to JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if `params` fails to serialize.
    pub fn new(method: &str, params: impl Serialize) -> Result<Self, crate::Error> {
        let json_params = serde_json::to_value(params)?;
        Ok(Self {
            method: method.to_owned(),
            params: json_params,
        })
    }

    /// The method name.
    #[must_use]
    pub fn method(&self) -> &str {
        self.method.as_str()
    }

    /// The JSON params.
    #[must_use]
    pub fn params(&self) -> &serde_json::Value {
        &self.params
    }

    /// Splits into `(method, params)`.
    #[must_use]
    pub fn into_parts(self) -> (String, serde_json::Value) {
        let Self { method, params } = self;
        (method, params)
    }

    /// Builds a raw request when `id` is `Some`, otherwise a raw notification.
    pub(crate) fn into_raw_jsonrpc_message(
        self,
        id: Option<RequestId>,
    ) -> Result<RawJsonRpcMessage, crate::Error> {
        let Self { method, params } = self;
        if let Some(id) = id {
            RawJsonRpcMessage::request(method, params, id)
        } else {
            RawJsonRpcMessage::notification(method, params)
        }
    }
}
/// The untyped message accepts every method and round-trips params as JSON.
impl JsonRpcMessage for UntypedMessage {
    fn parse_message(method: &str, params: &impl Serialize) -> Result<Self, crate::Error> {
        Self::new(method, params)
    }

    fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error> {
        Ok(self.clone())
    }

    fn method(&self) -> &str {
        self.method.as_str()
    }

    fn matches_method(_method: &str) -> bool {
        true
    }
}
/// As a request, an untyped message is answered with raw JSON.
impl JsonRpcRequest for UntypedMessage {
type Response = serde_json::Value;
}
/// An untyped message can also stand in for any notification.
impl JsonRpcNotification for UntypedMessage {}
/// Handle to an outgoing request whose response has not been consumed yet.
///
/// The response arrives on `response_rx`. Consume it with `block_task`,
/// `on_receiving_result`, `on_receiving_ok_result` or `forward_response_to`,
/// or give it up with `detach`.
#[must_use = "dropping a SentRequest discards the response (and, with the \
`unstable_cancel_request` feature, asks the peer to cancel the \
request); consume it with `block_task`, `on_receiving_result`, \
`forward_response_to`, or `detach`"]
pub struct SentRequest<T> {
/// Id assigned to the request.
id: RequestId,
/// Method name of the request.
method: String,
/// Channel on which `consume_with` spawns the response-handling task.
task_tx: TaskTx,
/// Receives the response payload.
response_rx: oneshot::Receiver<ResponsePayload>,
/// Converts the JSON result into `T`; `map` composes onto it.
to_result: Box<dyn Fn(serde_json::Value) -> Result<T, crate::Error> + Send>,
/// Sends a cancel notification when dropped while still armed.
#[cfg(feature = "unstable_cancel_request")]
cancellation: SentRequestCancellation,
/// Upstream cancellations that are forwarded to this request while its
/// response is awaited.
#[cfg(feature = "unstable_cancel_request")]
cancellation_sources: Vec<RequestCancellation>,
}
/// Shared switch that turns off the automatic cancel notification of a
/// `SentRequestCancellation`. Clones share the same flag.
#[cfg(feature = "unstable_cancel_request")]
#[derive(Clone, Debug)]
pub(crate) struct SentRequestCancellationDisarm {
    /// `true` while a cancel notification may still be sent.
    armed: Arc<AtomicBool>,
}

#[cfg(feature = "unstable_cancel_request")]
impl SentRequestCancellationDisarm {
    /// Creates a new switch in the armed state.
    fn new() -> Self {
        let armed = Arc::new(AtomicBool::new(true));
        Self { armed }
    }

    /// Clears the flag so no cancel notification is sent afterwards.
    fn disarm(&self) {
        let flag: &AtomicBool = &self.armed;
        flag.store(false, Ordering::Release);
    }
}
/// Sends a `CancelRequestNotification` for a request this side issued, at most
/// once.
///
/// The notification goes out on an explicit `send` call, or on drop while still
/// armed.
#[cfg(feature = "unstable_cancel_request")]
struct SentRequestCancellation {
    message_tx: OutgoingMessageTx,
    remote_style: crate::role::RemoteStyle,
    request_id: RequestId,
    disarm: SentRequestCancellationDisarm,
}

#[cfg(feature = "unstable_cancel_request")]
impl SentRequestCancellation {
    /// Creates an armed canceller for `request_id`.
    fn new(
        message_tx: OutgoingMessageTx,
        remote_style: crate::role::RemoteStyle,
        request_id: RequestId,
    ) -> Self {
        let disarm = SentRequestCancellationDisarm::new();
        Self {
            request_id,
            remote_style,
            message_tx,
            disarm,
        }
    }

    /// Disarms without sending anything.
    fn disarm(&self) {
        self.disarm.disarm()
    }

    /// Returns a handle that can disarm this canceller from elsewhere.
    fn disarm_handle(&self) -> SentRequestCancellationDisarm {
        SentRequestCancellationDisarm::clone(&self.disarm)
    }

    /// Sends the cancel notification if still armed, disarming in the process.
    fn send(&self) -> Result<(), crate::Error> {
        // Atomically clear the flag so at most one caller ever sends.
        let was_armed = self.disarm.armed.swap(false, Ordering::AcqRel);
        if !was_armed {
            return Ok(());
        }
        let notification =
            crate::schema::v1::CancelRequestNotification::new(self.request_id.clone());
        let untyped = self.remote_style.transform_outgoing_message(notification)?;
        send_raw_message(&self.message_tx, OutgoingMessage::Notification { untyped })
    }
}

#[cfg(feature = "unstable_cancel_request")]
impl Drop for SentRequestCancellation {
    fn drop(&mut self) {
        // Best-effort: a failure during drop can only be logged.
        match self.send() {
            Ok(()) => {}
            Err(error) => tracing::debug!(?error, "failed to auto-cancel dropped request"),
        }
    }
}

#[cfg(feature = "unstable_cancel_request")]
impl Debug for SentRequestCancellation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let armed = self.disarm.armed.load(Ordering::Acquire);
        f.debug_struct("SentRequestCancellation")
            .field("request_id", &self.request_id)
            .field("remote_style", &self.remote_style)
            .field("armed", &armed)
            .finish_non_exhaustive()
    }
}
/// Awaits the response to a sent request while watching `sources` for upstream
/// cancellation, e.g. the responder cancellation registered by
/// `SentRequest::forward_response_to`.
///
/// * No sources: simply awaits `response_rx`.
/// * A source is already cancelled: sends the downstream cancel via
///   `cancellation`, then still awaits `response_rx`.
/// * Otherwise: races the sources' signals against the response. If a signal
///   fires first, the downstream cancel is sent and the response is still
///   awaited.
///
/// Once `response_rx` has resolved (with a payload or `Canceled`),
/// `cancellation` is disarmed, so dropping it afterwards sends nothing.
/// Failures to send the cancel are only logged.
#[cfg(feature = "unstable_cancel_request")]
async fn await_response_forwarding_cancellation(
response_rx: oneshot::Receiver<ResponsePayload>,
cancellation: &SentRequestCancellation,
sources: &[RequestCancellation],
) -> Result<ResponsePayload, oneshot::Canceled> {
// Best-effort: a failed cancel is logged, not returned.
let forward_cancellation = || {
if let Err(error) = cancellation.send() {
tracing::debug!(
?error,
"failed to forward cancellation to downstream request"
);
}
};
// `select_all` below panics on an empty iterator, so the no-source case must be
// handled separately.
let response = if sources.is_empty() {
response_rx.await
} else if sources.iter().any(RequestCancellation::is_cancelled) {
// Already cancelled upstream: cancel downstream now, but keep waiting for the reply.
forward_cancellation();
response_rx.await
} else {
// NOTE(review): any `signal_rx` completing is treated as a cancellation here;
// confirm it only resolves when its source is actually cancelled.
let cancelled = sources.iter().map(|source| source.state.signal_rx.clone());
match future::select(future::select_all(cancelled), response_rx).await {
Either::Left((_, response_rx)) => {
forward_cancellation();
response_rx.await
}
Either::Right((response, _)) => response,
}
};
// The response channel has resolved; disarm so dropping `cancellation` sends nothing.
cancellation.disarm();
response
}
impl<T: Debug> Debug for SentRequest<T> {
    // `to_result` is a closure and is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SentRequest");
        out.field("id", &self.id);
        out.field("method", &self.method);
        out.field("task_tx", &self.task_tx);
        out.field("response_rx", &self.response_rx);
        #[cfg(feature = "unstable_cancel_request")]
        out.field("cancellation", &self.cancellation)
            .field("cancellation_sources", &self.cancellation_sources);
        out.finish_non_exhaustive()
    }
}
impl SentRequest<serde_json::Value> {
fn new(
id: RequestId,
method: String,
task_tx: mpsc::UnboundedSender<Task>,
response_rx: oneshot::Receiver<ResponsePayload>,
#[cfg(feature = "unstable_cancel_request")] cancellation: SentRequestCancellation,
) -> Self {
Self {
id,
method,
response_rx,
task_tx,
to_result: Box::new(Ok),
#[cfg(feature = "unstable_cancel_request")]
cancellation,
#[cfg(feature = "unstable_cancel_request")]
cancellation_sources: Vec::new(),
}
}
}
impl<T> SentRequest<T> {
    /// Discards this handle without cancelling the request.
    ///
    /// With `unstable_cancel_request`, the automatic cancel-on-drop is disarmed
    /// first, so dropping the handle does not notify the peer.
    pub fn detach(self) {
        #[cfg(feature = "unstable_cancel_request")]
        {
            let auto_cancel = &self.cancellation;
            auto_cancel.disarm();
        }
        drop(self);
    }

    /// Sends a cancel notification for this request now. It is sent at most
    /// once overall.
    #[cfg(feature = "unstable_cancel_request")]
    pub fn cancel(&self) -> Result<(), crate::Error> {
        let auto_cancel = &self.cancellation;
        auto_cancel.send()
    }

    /// Adds `source`: if it is cancelled while the response is being awaited,
    /// the cancellation is forwarded to this request.
    #[cfg(feature = "unstable_cancel_request")]
    pub fn forward_cancellation_from(self, source: RequestCancellation) -> Self {
        let mut request = self;
        request.cancellation_sources.push(source);
        request
    }
}
impl<T: JsonRpcResponse> SentRequest<T> {
    /// JSON form of the id assigned to this request.
    #[must_use]
    pub fn id(&self) -> serde_json::Value {
        crate::util::id_to_json(&self.id)
    }

    /// Method name of this request.
    #[must_use]
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Post-processes the eventual result with `map_fn`.
    ///
    /// `map_fn` is applied after the existing conversion. Id, method, channels
    /// and cancellation state are kept.
    pub fn map<U>(
        self,
        map_fn: impl Fn(T) -> Result<U, crate::Error> + 'static + Send,
    ) -> SentRequest<U> {
        SentRequest {
            id: self.id,
            method: self.method,
            response_rx: self.response_rx,
            task_tx: self.task_tx,
            to_result: Box::new(move |value| map_fn((self.to_result)(value)?)),
            #[cfg(feature = "unstable_cancel_request")]
            cancellation: self.cancellation,
            #[cfg(feature = "unstable_cancel_request")]
            cancellation_sources: self.cancellation_sources,
        }
    }

    /// Spawns a task that forwards this request's result to `responder`.
    ///
    /// With `unstable_cancel_request`, cancellation of the request that
    /// `responder` answers is forwarded to this request as well.
    ///
    /// # Errors
    ///
    /// Returns any error from spawning the task on `task_tx`.
    #[track_caller]
    pub fn forward_response_to(self, responder: Responder<T>) -> Result<(), crate::Error>
    where
        T: Send,
    {
        #[cfg(feature = "unstable_cancel_request")]
        let this = self.forward_cancellation_from(responder.cancellation());
        #[cfg(not(feature = "unstable_cancel_request"))]
        let this = self;
        this.consume_with(async move |response| {
            responder.respond_with_result(response.unwrap_or_else(Err))
        })
    }

    /// Spawns a task on `task_tx` that awaits the response, converts it with
    /// `to_result`, and passes it to `handle`.
    ///
    /// The outer `Result` passed to `handle` is `Err` when the response never
    /// arrived (its oneshot sender was dropped). A payload's `ack_tx`, if any,
    /// is signalled only after `handle` has completed.
    #[track_caller]
    fn consume_with<F>(
        self,
        handle: impl FnOnce(Result<Result<T, crate::Error>, crate::Error>) -> F + 'static + Send,
    ) -> Result<(), crate::Error>
    where
        F: Future<Output = Result<(), crate::Error>> + 'static + Send,
        T: Send,
    {
        let task_tx = self.task_tx.clone();
        let method = self.method;
        let response_rx = self.response_rx;
        let to_result = self.to_result;
        #[cfg(feature = "unstable_cancel_request")]
        let cancellation = self.cancellation;
        #[cfg(feature = "unstable_cancel_request")]
        let cancellation_sources = self.cancellation_sources;
        let location = Location::caller();
        Task::new(location, async move {
            #[cfg(feature = "unstable_cancel_request")]
            let response = await_response_forwarding_cancellation(
                response_rx,
                &cancellation,
                &cancellation_sources,
            )
            .await;
            #[cfg(not(feature = "unstable_cancel_request"))]
            let response = response_rx.await;
            match response {
                Ok(ResponsePayload { result, ack_tx }) => {
                    let typed_result = result.and_then(|json_value| to_result(json_value));
                    let outcome = handle(Ok(typed_result)).await;
                    if let Some(tx) = ack_tx {
                        let _ = tx.send(());
                    }
                    outcome
                }
                Err(err) => {
                    handle(Err(crate::util::internal_error(format!(
                        "response to `{method}` never received: {err}"
                    ))))
                    .await
                }
            }
        })
        .spawn(&task_tx)
    }

    /// Awaits the response in the current task and returns the converted result.
    ///
    /// A payload's `ack_tx`, if any, is signalled as soon as the payload
    /// arrives, before conversion.
    ///
    /// # Errors
    ///
    /// Returns the error carried in the response, a conversion error from
    /// `to_result`, or an internal error if the response never arrived.
    pub async fn block_task(self) -> Result<T, crate::Error>
    where
        T: Send,
    {
        #[cfg(feature = "unstable_cancel_request")]
        let response = await_response_forwarding_cancellation(
            self.response_rx,
            &self.cancellation,
            &self.cancellation_sources,
        )
        .await;
        #[cfg(not(feature = "unstable_cancel_request"))]
        let response = self.response_rx.await;
        let ResponsePayload { result, ack_tx } = match response {
            Ok(payload) => payload,
            Err(err) => {
                return Err(crate::util::internal_error(format!(
                    "response to `{}` never received: {}",
                    self.method, err
                )));
            }
        };
        // Acknowledge for both success and error results, before conversion.
        if let Some(tx) = ack_tx {
            let _ = tx.send(());
        }
        result.and_then(|json_value| (self.to_result)(json_value))
    }

    /// Spawns a task that runs `task` with the successful result and `responder`.
    ///
    /// An error result is sent straight to `responder` instead.
    #[track_caller]
    pub fn on_receiving_ok_result<F>(
        self,
        responder: Responder<T>,
        task: impl FnOnce(T, Responder<T>) -> F + 'static + Send,
    ) -> Result<(), crate::Error>
    where
        F: Future<Output = Result<(), crate::Error>> + 'static + Send,
        T: Send,
    {
        self.on_receiving_result(async move |result| match result {
            Ok(value) => task(value, responder).await,
            Err(err) => responder.respond_with_error(err),
        })
    }

    /// Spawns a task that runs `task` with the result (success or error).
    ///
    /// If the response never arrives, `task` is not run and the spawned task
    /// returns the internal error instead.
    #[track_caller]
    pub fn on_receiving_result<F>(
        self,
        task: impl FnOnce(Result<T, crate::Error>) -> F + 'static + Send,
    ) -> Result<(), crate::Error>
    where
        F: Future<Output = Result<(), crate::Error>> + 'static + Send,
        T: Send,
    {
        self.consume_with(async move |response| match response {
            Ok(result) => task(result).await,
            Err(err) => Err(err),
        })
    }
}
/// Line-oriented transport: a sink for outgoing lines and a stream of incoming
/// lines.
///
/// Converting between lines and messages is done by the `transport_actor` line
/// actors (presumably one JSON-RPC message per line; confirm there).
#[derive(Debug)]
pub struct Lines<OutgoingSink, IncomingStream> {
    /// Sink that outgoing lines are written to.
    pub outgoing: OutgoingSink,
    /// Stream that incoming lines (or I/O errors) are read from.
    pub incoming: IncomingStream,
}

impl<OutgoingSink, IncomingStream> Lines<OutgoingSink, IncomingStream>
where
    OutgoingSink: futures::Sink<String, Error = std::io::Error> + Send + 'static,
    IncomingStream: futures::Stream<Item = std::io::Result<String>> + Send + 'static,
{
    /// Pairs an outgoing line sink with an incoming line stream.
    pub fn new(outgoing: OutgoingSink, incoming: IncomingStream) -> Self {
        Lines { incoming, outgoing }
    }
}
impl<OutgoingSink, IncomingStream, R: Role> ConnectTo<R> for Lines<OutgoingSink, IncomingStream>
where
    OutgoingSink: futures::Sink<String, Error = std::io::Error> + Send + 'static,
    IncomingStream: futures::Stream<Item = std::io::Result<String>> + Send + 'static,
{
    /// Runs `client` over this line transport until either side finishes.
    ///
    /// The result of whichever finishes first is returned; the other future is
    /// dropped.
    async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
        let (channel, transport) = ConnectTo::<R>::into_channel_and_future(self);
        let client_future = Box::pin(client.connect_to(channel));
        let (result, _) = futures::future::select(client_future, transport)
            .await
            .factor_first();
        result
    }

    /// Returns a `Channel` for the caller plus a future that pumps messages
    /// between that channel and the line sink and stream.
    ///
    /// The future fails as soon as either direction fails.
    fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
        let Self { outgoing, incoming } = self;
        let (caller_side, transport_side) = Channel::duplex();
        let transport = Box::pin(async move {
            let Channel { rx, tx } = transport_side;
            let write_lines = transport_actor::transport_outgoing_lines_actor(rx, outgoing);
            let read_lines = transport_actor::transport_incoming_lines_actor(incoming, tx);
            futures::try_join!(write_lines, read_lines)?;
            Ok(())
        });
        (caller_side, transport)
    }
}
/// Byte-stream transport: an `AsyncWrite` for outgoing data and an `AsyncRead`
/// for incoming data, framed as newline-delimited lines.
#[derive(Debug)]
pub struct ByteStreams<OB, IB> {
    /// Writer that outgoing bytes go to.
    pub outgoing: OB,
    /// Reader that incoming bytes come from.
    pub incoming: IB,
}

impl<OB, IB> ByteStreams<OB, IB>
where
    OB: AsyncWrite + Send + 'static,
    IB: AsyncRead + Send + 'static,
{
    /// Pairs an outgoing writer with an incoming reader.
    pub fn new(outgoing: OB, incoming: IB) -> Self {
        ByteStreams { incoming, outgoing }
    }
}
impl<OB, IB, R: Role> ConnectTo<R> for ByteStreams<OB, IB>
where
    OB: AsyncWrite + Send + 'static,
    IB: AsyncRead + Send + 'static,
{
    /// Runs `client` over these byte streams until either side finishes.
    ///
    /// The result of whichever finishes first is returned.
    async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
        let (channel, serve_self) = ConnectTo::<R>::into_channel_and_future(self);
        match futures::future::select(pin!(client.connect_to(channel)), serve_self).await {
            Either::Left((result, _)) | Either::Right((result, _)) => result,
        }
    }

    /// Adapts the byte streams into a [`Lines`] transport.
    ///
    /// Incoming bytes are split with `AsyncBufReadExt::lines`. Each outgoing
    /// line is written with a trailing `\n` and then flushed. `sink::unfold`'s
    /// own flush never reaches the underlying writer, so without the explicit
    /// flush a buffered `OB` could hold messages indefinitely.
    fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
        use futures::AsyncBufReadExt;
        use futures::AsyncWriteExt;
        use futures::io::BufReader;
        let Self { outgoing, incoming } = self;
        let incoming_lines = Box::pin(BufReader::new(incoming).lines());
        let outgoing_sink =
            futures::sink::unfold(Box::pin(outgoing), async move |mut writer, line: String| {
                let mut bytes = line.into_bytes();
                bytes.push(b'\n');
                writer.write_all(&bytes).await?;
                // Push each message through any buffering in `writer` so the peer sees it now.
                writer.flush().await?;
                Ok::<_, std::io::Error>(writer)
            });
        ConnectTo::<R>::into_channel_and_future(Lines::new(outgoing_sink, incoming_lines))
    }
}
#[derive(Debug)]
pub struct Channel {
pub rx: mpsc::UnboundedReceiver<Result<RawJsonRpcMessage, crate::Error>>,
pub tx: mpsc::UnboundedSender<Result<RawJsonRpcMessage, crate::Error>>,
}
impl Channel {
#[must_use]
pub fn duplex() -> (Self, Self) {
let (a_tx, b_rx) = mpsc::unbounded();
let (b_tx, a_rx) = mpsc::unbounded();
let channel_a = Self { rx: a_rx, tx: a_tx };
let channel_b = Self { rx: b_rx, tx: b_tx };
(channel_a, channel_b)
}
pub async fn copy(mut self) -> Result<(), crate::Error> {
while let Some(msg) = self.rx.next().await {
self.tx
.unbounded_send(msg)
.map_err(crate::util::internal_error)?;
}
Ok(())
}
}
impl<R: Role> ConnectTo<R> for Channel {
async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
let (client_channel, client_serve) = client.into_channel_and_future();
match futures::try_join!(
Channel {
rx: client_channel.rx,
tx: self.tx
}
.copy(),
Channel {
rx: self.rx,
tx: client_channel.tx
}
.copy(),
client_serve
) {
Ok(((), (), ())) => Ok(()),
Err(err) => Err(err),
}
}
fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
(self, Box::pin(future::ready(Ok(()))))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Note: `&params` below was previously mis-encoded as `¶ms` (an `&para;`
    // mojibake), which does not compile.

    /// Messages that are not `_proxy/successor` envelopes come back unchanged.
    #[test]
    fn peel_successor_envelopes_returns_plain_messages_unchanged() {
        let params = serde_json::json!({ "key": "value" });
        let (method, peeled) = peel_successor_envelopes("session/update", &params);
        assert_eq!(method, "session/update");
        assert_eq!(peeled, &params);
    }

    /// Nested `_proxy/successor` envelopes are unwrapped to the inner message.
    #[test]
    fn peel_successor_envelopes_unwraps_nested_envelopes() {
        let params = serde_json::json!({
            "method": "_proxy/successor",
            "params": {
                "method": "$/cancel_request",
                "params": { "requestId": "req-1" }
            }
        });
        let (method, peeled) = peel_successor_envelopes("_proxy/successor", &params);
        assert_eq!(method, "$/cancel_request");
        assert_eq!(peeled, &serde_json::json!({ "requestId": "req-1" }));
    }

    /// An envelope without the expected `method`/`params` shape is left intact.
    #[test]
    fn peel_successor_envelopes_leaves_malformed_envelopes_intact() {
        let params = serde_json::json!({ "unexpected": true });
        let (method, peeled) = peel_successor_envelopes("_proxy/successor", &params);
        assert_eq!(method, "_proxy/successor");
        assert_eq!(peeled, &params);
    }

    #[cfg(feature = "unstable_cancel_request")]
    mod cancel_request {
        use super::super::*;

        /// Builds an untyped message from JSON params known to be valid.
        fn notification(method: &str, params: serde_json::Value) -> UntypedMessage {
            UntypedMessage::new(method, params).expect("well-formed JSON")
        }

        #[test]
        fn cancellation_request_id_is_extracted_from_wrapped_notifications() {
            let message = notification(
                "_proxy/successor",
                serde_json::json!({
                    "method": "$/cancel_request",
                    "params": { "requestId": "req-1" }
                }),
            );
            let request_id = cancellation_request_id_from_message(&message)
                .expect("wrapped cancel should parse");
            assert_eq!(request_id, Some(RequestId::Str("req-1".into())));
        }

        #[test]
        fn malformed_successor_envelope_is_not_treated_as_cancellation() {
            let message = notification("_proxy/successor", serde_json::json!({ "bogus": true }));
            let request_id = cancellation_request_id_from_message(&message)
                .expect("malformed envelope should be left to the handler chain");
            assert_eq!(request_id, None);
        }

        #[test]
        fn cancel_request_notifications_are_detected_even_when_wrapped() {
            let plain = notification("$/cancel_request", serde_json::json!({ "requestId": 1 }));
            assert!(is_cancel_request_notification(&plain));
            let wrapped = notification(
                "_proxy/successor",
                serde_json::json!({
                    "method": "$/cancel_request",
                    "params": { "requestId": 1 }
                }),
            );
            assert!(is_cancel_request_notification(&wrapped));
            let other_wrapped = notification(
                "_proxy/successor",
                serde_json::json!({
                    "method": "session/update",
                    "params": {}
                }),
            );
            assert!(!is_cancel_request_notification(&other_wrapped));
            let malformed_envelope =
                notification("_proxy/successor", serde_json::json!({ "bogus": true }));
            assert!(!is_cancel_request_notification(&malformed_envelope));
        }

        #[test]
        fn malformed_cancel_request_params_error() {
            let message = notification(
                "$/cancel_request",
                serde_json::json!({ "requestId": { "not": "an id" } }),
            );
            cancellation_request_id_from_message(&message)
                .expect_err("malformed cancel params should error");
        }

        #[test]
        fn registry_marks_and_removes_requests() {
            let registry = RequestCancellationRegistry::new();
            let id = RequestId::Str("req-1".into());
            let responder_cancellation = registry.register(&id);
            let marker = responder_cancellation.cancellation();
            assert!(!marker.is_cancelled());
            assert!(registry.cancel(&id));
            assert!(marker.is_cancelled());
            assert!(responder_cancellation.cancellation().is_cancelled());
            drop(responder_cancellation);
            assert!(!registry.cancel(&id), "slot should be removed on drop");
        }

        /// Registering the same id twice must give the newer registration its
        /// own cancellation state.
        #[test]
        fn reused_request_id_does_not_cross_wire_cancellation_state() {
            let registry = RequestCancellationRegistry::new();
            let id = RequestId::Str("dup".into());
            let first = registry.register(&id);
            let first_marker = first.cancellation();
            let second = registry.register(&id);
            let second_marker = second.cancellation();
            assert!(registry.cancel(&id));
            assert!(second_marker.is_cancelled());
            assert!(
                !first_marker.is_cancelled(),
                "the stale request must not observe the newer request's cancellation"
            );
            assert!(!first.cancellation().is_cancelled());
            drop(first);
            assert!(registry.cancel(&id), "newer slot should still be present");
            drop(second);
            assert!(!registry.cancel(&id), "slot should be removed on drop");
        }
    }
}