use agent_client_protocol_schema::SessionId;
pub use jsonrpcmsg;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::panic::Location;
use std::pin::pin;
use uuid::Uuid;
use boxfnonce::SendBoxFnOnce;
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Either};
use futures::{AsyncRead, AsyncWrite, StreamExt};
mod dynamic_handler;
pub(crate) mod handlers;
mod incoming_actor;
mod outgoing_actor;
pub(crate) mod run;
mod task_actor;
mod transport_actor;
use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage;
pub use crate::jsonrpc::handlers::NullHandler;
use crate::jsonrpc::handlers::{ChainedHandler, NamedHandler};
use crate::jsonrpc::handlers::{MessageHandler, NotificationHandler, RequestHandler};
use crate::jsonrpc::outgoing_actor::{OutgoingMessageTx, send_raw_message};
use crate::jsonrpc::run::SpawnedRun;
use crate::jsonrpc::run::{ChainRun, NullRun, RunWithConnectionTo};
use crate::jsonrpc::task_actor::{Task, TaskTx};
use crate::mcp_server::McpServer;
use crate::role::HasPeer;
use crate::role::Role;
use crate::util::json_cast;
use crate::{Agent, Client, ConnectTo, RoleId};
/// A handler for messages arriving on a connection whose other side is `Counterpart`.
///
/// Each incoming [`Dispatch`] is offered to the handler, which answers with a
/// [`Handled`] outcome: [`Handled::Yes`] when it dealt with the message, or
/// [`Handled::No`] carrying the message back so it can be offered elsewhere.
// NOTE(review): this trait declares no `async fn` (its methods return
// `impl Future` explicitly), so the `async_fn_in_trait` allow below may be
// unnecessary — confirm before removing.
#[allow(async_fn_in_trait)]
pub trait HandleDispatchFrom<Counterpart: Role>: Send {
/// Offers `message` to this handler; `connection` can be used to reply or to
/// send further messages. The returned future must be `Send`.
fn handle_dispatch_from(
&mut self,
message: Dispatch,
connection: ConnectionTo<Counterpart>,
) -> impl Future<Output = Result<Handled<Dispatch>, crate::Error>> + Send;
/// Returns a `Debug` description of this handler (presumably including any
/// handlers chained onto it), used for diagnostics.
fn describe_chain(&self) -> impl std::fmt::Debug;
}
/// Forwarding impl: a mutable borrow of a handler is itself a handler, and
/// every call is delegated to the borrowed `H`.
impl<Counterpart, H> HandleDispatchFrom<Counterpart> for &mut H
where
    Counterpart: Role,
    H: HandleDispatchFrom<Counterpart>,
{
    fn handle_dispatch_from(
        &mut self,
        message: Dispatch,
        cx: ConnectionTo<Counterpart>,
    ) -> impl Future<Output = Result<Handled<Dispatch>, crate::Error>> + Send {
        // Reborrow the underlying handler and call its impl explicitly.
        let inner: &mut H = &mut **self;
        <H as HandleDispatchFrom<Counterpart>>::handle_dispatch_from(inner, message, cx)
    }

    fn describe_chain(&self) -> impl std::fmt::Debug {
        let inner: &H = &**self;
        <H as HandleDispatchFrom<Counterpart>>::describe_chain(inner)
    }
}
/// Builder for a JSON-RPC connection in which this side plays role `Host`.
///
/// Incoming messages are routed to `Handler`; `Runner` is run concurrently with
/// the connection's actors once it is established.
#[must_use]
#[derive(Debug)]
pub struct Builder<Host: Role, Handler = NullHandler, Runner = NullRun>
where
Handler: HandleDispatchFrom<Host::Counterpart>,
Runner: RunWithConnectionTo<Host::Counterpart>,
{
/// The role this side of the connection plays.
host: Host,
/// Optional connection name, passed to `crate::util::instrument_with_connection_name`.
name: Option<String>,
/// Handler (or handler chain) that receives incoming messages.
handler: Handler,
/// Background work run alongside the connection via `run_with_connection_to`.
responder: Runner,
}
impl<Host: Role> Builder<Host, NullHandler, NullRun> {
    /// Creates a builder for `role` with no handlers, no name and no
    /// background runners.
    ///
    /// This is exactly `Builder::new_with(role, NullHandler)`.
    pub fn new(role: Host) -> Self {
        Self::new_with(role, NullHandler)
    }
}
impl<Host, Handler> Builder<Host, Handler, NullRun>
where
    Host: Role,
    Handler: HandleDispatchFrom<Host::Counterpart>,
{
    /// Creates a builder for `role` whose incoming messages go to `handler`.
    ///
    /// No name is set and no background runner is attached.
    pub fn new_with(role: Host, handler: Handler) -> Self {
        let name = None;
        let responder = NullRun;
        Self {
            host: role,
            name,
            handler,
            responder,
        }
    }
}
/// Configuration and connection methods. Every configuration method consumes
/// the builder and returns a new one with the extra handler/runner chained on.
impl<
Host: Role,
Handler: HandleDispatchFrom<Host::Counterpart>,
Runner: RunWithConnectionTo<Host::Counterpart>,
> Builder<Host, Handler, Runner>
{
/// Sets the connection name, later passed to `instrument_with_connection_name`.
pub fn name(mut self, name: impl ToString) -> Self {
self.name = Some(name.to_string());
self
}
/// Merges `other` into this builder.
///
/// `other`'s handler is chained after this builder's handler, wrapped in a
/// `NamedHandler` carrying `other`'s name. `other`'s runner is chained after
/// this builder's runner. This builder's host and name are kept; `other`'s
/// host is discarded.
pub fn with_connection_builder(
self,
other: Builder<
Host,
impl HandleDispatchFrom<Host::Counterpart>,
impl RunWithConnectionTo<Host::Counterpart>,
>,
) -> Builder<
Host,
impl HandleDispatchFrom<Host::Counterpart>,
impl RunWithConnectionTo<Host::Counterpart>,
> {
Builder {
host: self.host,
name: self.name,
handler: ChainedHandler::new(
self.handler,
NamedHandler::new(other.name, other.handler),
),
responder: ChainRun::new(self.responder, other.responder),
}
}
/// Chains `handler` onto the existing handler(s) via `ChainedHandler::new(existing, handler)`.
pub fn with_handler(
self,
handler: impl HandleDispatchFrom<Host::Counterpart>,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner> {
Builder {
host: self.host,
name: self.name,
handler: ChainedHandler::new(self.handler, handler),
responder: self.responder,
}
}
/// Chains `responder` onto the existing runner(s) via `ChainRun::new(existing, responder)`.
pub fn with_responder<Run1>(
self,
responder: Run1,
) -> Builder<Host, Handler, impl RunWithConnectionTo<Host::Counterpart>>
where
Run1: RunWithConnectionTo<Host::Counterpart>,
{
Builder {
host: self.host,
name: self.name,
handler: self.handler,
responder: ChainRun::new(self.responder, responder),
}
}
/// Adds `task` as a runner that receives the connection once it is up.
///
/// The caller's source location (via `#[track_caller]`) is recorded with the task.
#[track_caller]
pub fn with_spawned<F, Fut>(
self,
task: F,
) -> Builder<Host, Handler, impl RunWithConnectionTo<Host::Counterpart>>
where
F: FnOnce(ConnectionTo<Host::Counterpart>) -> Fut + Send,
Fut: Future<Output = Result<(), crate::Error>> + Send,
{
let location = Location::caller();
self.with_responder(SpawnedRun::new(location, task))
}
/// Handles requests of type `Req` and notifications of type `Notif` coming
/// from the counterpart itself, using `op`.
///
/// `to_future_hack` is expected to invoke `op` and box the resulting future.
/// NOTE(review): presumably this works around being unable to name the future
/// type of an `AsyncFnMut` in the handler's bounds — confirm.
pub fn on_receive_dispatch<Req, Notif, F, T, ToFut>(
self,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Host::Counterpart>,
Req: JsonRpcRequest,
Notif: JsonRpcNotification,
F: AsyncFnMut(
Dispatch<Req, Notif>,
ConnectionTo<Host::Counterpart>,
) -> Result<T, crate::Error>
+ Send,
T: IntoHandled<Dispatch<Req, Notif>>,
ToFut: Fn(
&mut F,
Dispatch<Req, Notif>,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = MessageHandler::new(
self.host.counterpart(),
self.host.counterpart(),
op,
to_future_hack,
);
self.with_handler(handler)
}
/// Handles requests of type `Req` from the counterpart itself; `op` receives
/// the request plus a typed [`Responder`] for the reply.
pub fn on_receive_request<Req: JsonRpcRequest, F, T, ToFut>(
self,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Host::Counterpart>,
F: AsyncFnMut(
Req,
Responder<Req::Response>,
ConnectionTo<Host::Counterpart>,
) -> Result<T, crate::Error>
+ Send,
T: IntoHandled<(Req, Responder<Req::Response>)>,
ToFut: Fn(
&mut F,
Req,
Responder<Req::Response>,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = RequestHandler::new(
self.host.counterpart(),
self.host.counterpart(),
op,
to_future_hack,
);
self.with_handler(handler)
}
/// Handles notifications of type `Notif` from the counterpart itself.
pub fn on_receive_notification<Notif, F, T, ToFut>(
self,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Host::Counterpart>,
Notif: JsonRpcNotification,
F: AsyncFnMut(Notif, ConnectionTo<Host::Counterpart>) -> Result<T, crate::Error> + Send,
T: IntoHandled<(Notif, ConnectionTo<Host::Counterpart>)>,
ToFut: Fn(
&mut F,
Notif,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = NotificationHandler::new(
self.host.counterpart(),
self.host.counterpart(),
op,
to_future_hack,
);
self.with_handler(handler)
}
/// Like [`Self::on_receive_dispatch`], but for messages from `peer`, reached
/// through the counterpart (`Host::Counterpart: HasPeer<Peer>`).
pub fn on_receive_dispatch_from<
Req: JsonRpcRequest,
Notif: JsonRpcNotification,
Peer: Role,
F,
T,
ToFut,
>(
self,
peer: Peer,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Peer>,
F: AsyncFnMut(
Dispatch<Req, Notif>,
ConnectionTo<Host::Counterpart>,
) -> Result<T, crate::Error>
+ Send,
T: IntoHandled<Dispatch<Req, Notif>>,
ToFut: Fn(
&mut F,
Dispatch<Req, Notif>,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = MessageHandler::new(self.host.counterpart(), peer, op, to_future_hack);
self.with_handler(handler)
}
/// Like [`Self::on_receive_request`], but for requests from `peer`.
pub fn on_receive_request_from<Req: JsonRpcRequest, Peer: Role, F, T, ToFut>(
self,
peer: Peer,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Peer>,
F: AsyncFnMut(
Req,
Responder<Req::Response>,
ConnectionTo<Host::Counterpart>,
) -> Result<T, crate::Error>
+ Send,
T: IntoHandled<(Req, Responder<Req::Response>)>,
ToFut: Fn(
&mut F,
Req,
Responder<Req::Response>,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = RequestHandler::new(self.host.counterpart(), peer, op, to_future_hack);
self.with_handler(handler)
}
/// Like [`Self::on_receive_notification`], but for notifications from `peer`.
pub fn on_receive_notification_from<Notif: JsonRpcNotification, Peer: Role, F, T, ToFut>(
self,
peer: Peer,
op: F,
to_future_hack: ToFut,
) -> Builder<Host, impl HandleDispatchFrom<Host::Counterpart>, Runner>
where
Host::Counterpart: HasPeer<Peer>,
F: AsyncFnMut(Notif, ConnectionTo<Host::Counterpart>) -> Result<T, crate::Error> + Send,
T: IntoHandled<(Notif, ConnectionTo<Host::Counterpart>)>,
ToFut: Fn(
&mut F,
Notif,
ConnectionTo<Host::Counterpart>,
) -> crate::BoxFuture<'_, Result<T, crate::Error>>
+ Send
+ Sync,
{
let handler = NotificationHandler::new(self.host.counterpart(), peer, op, to_future_hack);
self.with_handler(handler)
}
/// Installs an MCP server: its handler is chained onto the handlers and its
/// runner onto the runners.
pub fn with_mcp_server(
self,
mcp_server: McpServer<Host::Counterpart, impl RunWithConnectionTo<Host::Counterpart>>,
) -> Builder<
Host,
impl HandleDispatchFrom<Host::Counterpart>,
impl RunWithConnectionTo<Host::Counterpart>,
>
where
Host::Counterpart: HasPeer<Agent> + HasPeer<Client>,
{
let (handler, responder) = mcp_server.into_handler_and_responder();
self.with_handler(handler).with_responder(responder)
}
/// Connects over `transport` and serves the connection. The main function
/// passed to `connect_with` never completes, so this returns only when the
/// background work ends (see `crate::util::run_until`).
pub async fn connect_to(
self,
transport: impl ConnectTo<Host> + 'static,
) -> Result<(), crate::Error> {
self.connect_with(transport, async move |_cx| future::pending().await)
.await
}
/// Connects over `transport`, runs `main_fn` with the connection, and
/// returns its result. The connection's actors run concurrently with
/// `main_fn`; presumably they are dropped once `main_fn` finishes — confirm
/// against `crate::util::run_until`.
pub async fn connect_with<R>(
self,
transport: impl ConnectTo<Host> + 'static,
main_fn: impl AsyncFnOnce(ConnectionTo<Host::Counterpart>) -> Result<R, crate::Error>,
) -> Result<R, crate::Error> {
let (_, future) = self.into_connection_and_future(transport, main_fn);
future.await
}
/// Creates the connection's channels and returns the connection handle plus a
/// future that drives all of its actors and `main_fn`.
///
/// Apart from queueing the transport's future on the task channel, nothing
/// runs until the returned future is polled.
fn into_connection_and_future<R>(
self,
transport: impl ConnectTo<Host> + 'static,
main_fn: impl AsyncFnOnce(ConnectionTo<Host::Counterpart>) -> Result<R, crate::Error>,
) -> (
ConnectionTo<Host::Counterpart>,
impl Future<Output = Result<R, crate::Error>>,
) {
let Self {
name,
handler,
responder,
host: me,
} = self;
// One queue each for the outgoing actor, the task actor and dynamic-handler updates.
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
let (new_task_tx, new_task_rx) = mpsc::unbounded();
let (dynamic_handler_tx, dynamic_handler_rx) = mpsc::unbounded();
let connection = ConnectionTo::new(
me.counterpart(),
outgoing_tx,
new_task_tx,
dynamic_handler_tx,
);
let transport_component = crate::DynConnectTo::new(transport);
let (transport_channel, transport_future) = transport_component.into_channel_and_future();
// Queue the transport's own future as a task now; a failure is only
// reported when the returned future is first polled (`spawn_result?` below).
let spawn_result = connection.spawn(transport_future);
let Channel {
rx: transport_incoming_rx,
tx: transport_outgoing_tx,
} = transport_channel;
// Channel from the outgoing actor to the incoming actor (carries `ReplyMessage`s).
let (reply_tx, reply_rx) = mpsc::unbounded();
let future = crate::util::instrument_with_connection_name(name, {
let connection = connection.clone();
async move {
let () = spawn_result?;
// All four actors run concurrently; `try_join!` stops at the first error.
let background = async {
futures::try_join!(
outgoing_actor::outgoing_protocol_actor(
outgoing_rx,
reply_tx.clone(),
transport_outgoing_tx,
),
incoming_actor::incoming_protocol_actor(
me.counterpart(),
&connection,
transport_incoming_rx,
dynamic_handler_rx,
reply_rx,
handler,
),
task_actor::task_actor(new_task_rx, &connection),
responder.run_with_connection_to(connection.clone()),
)?;
Ok(())
};
crate::util::run_until(Box::pin(background), Box::pin(main_fn(connection.clone())))
.await
}
});
(connection, future)
}
}
/// A [`Builder`] for role `R` can be used wherever a `ConnectTo<R::Counterpart>`
/// is expected: connecting runs the builder over `client`.
impl<R, H, Run> ConnectTo<R::Counterpart> for Builder<R, H, Run>
where
    R: Role,
    H: HandleDispatchFrom<R::Counterpart> + 'static,
    Run: RunWithConnectionTo<R::Counterpart> + 'static,
{
    async fn connect_to(self, client: impl ConnectTo<R>) -> Result<(), crate::Error> {
        // Explicitly the inherent `Builder::connect_to`, not this trait method.
        Builder::<R, H, Run>::connect_to(self, client).await
    }
}
/// A response handed to whoever is awaiting a sent request.
pub(crate) struct ResponsePayload {
/// The response value, or the error that stands in for it.
pub(crate) result: Result<serde_json::Value, crate::Error>,
/// Optional acknowledgement channel; `SentRequest::block_task` signals it
/// after receiving the payload.
pub(crate) ack_tx: Option<oneshot::Sender<()>>,
}
impl std::fmt::Debug for ResponsePayload {
    /// Prints the result; the ack sender is shown only as `Some("...")` or `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ack_placeholder = self.ack_tx.is_some().then_some("...");
        let mut out = f.debug_struct("ResponsePayload");
        out.field("result", &self.result);
        out.field("ack_tx", &ack_placeholder);
        out.finish()
    }
}
/// Messages on the reply channel created in `into_connection_and_future`:
/// the outgoing actor holds the sender, the incoming actor the receiver.
enum ReplyMessage {
/// Presumably registers that the response to request `id` (for `method`,
/// addressed to `role_id`) should be delivered through `sender` — confirm
/// against the incoming actor.
Subscribe {
id: jsonrpcmsg::Id,
role_id: RoleId,
method: String,
sender: oneshot::Sender<ResponsePayload>,
},
}
impl std::fmt::Debug for ReplyMessage {
    /// Prints the request id and method; `role_id` and the sender are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let ReplyMessage::Subscribe { id, method, .. } = self;
        f.debug_struct("Subscribe")
            .field("id", id)
            .field("method", method)
            .finish()
    }
}
/// Messages queued for the outgoing actor (the `message_tx` of a `ConnectionTo`).
#[derive(Debug)]
enum OutgoingMessage {
/// A request to the peer; its response is delivered through `response_tx`,
/// whose receiver is held by the corresponding `SentRequest`.
Request {
id: jsonrpcmsg::Id,
method: String,
role_id: RoleId,
untyped: UntypedMessage,
response_tx: oneshot::Sender<ResponsePayload>,
},
/// A notification to the peer; nothing is awaited.
Notification {
untyped: UntypedMessage,
},
/// The reply to a request previously received with id `id` (queued by `Responder`).
Response {
id: jsonrpcmsg::Id,
response: Result<serde_json::Value, crate::Error>,
},
/// An error not tied to any request id (queued by `send_error_notification`).
Error { error: crate::Error },
}
/// Outcome of offering a message to a handler.
#[must_use]
#[derive(Debug)]
pub enum Handled<T> {
/// The handler took care of the message.
Yes,
/// The handler declined; the message is handed back in `message`.
No {
message: T,
/// NOTE(review): presumably asks the dispatcher to offer the message again
/// later — confirm against the incoming actor.
retry: bool,
},
}
/// Conversion from a handler callback's return value into a [`Handled`] outcome.
pub trait IntoHandled<T> {
/// Converts `self` into a [`Handled`] outcome.
fn into_handled(self) -> Handled<T>;
}
/// A callback returning `()` is treated as having handled the message.
impl<T> IntoHandled<T> for () {
    fn into_handled(self) -> Handled<T> {
        let () = self;
        Handled::Yes
    }
}
/// A callback that already returns a [`Handled`] passes it through unchanged.
impl<T> IntoHandled<T> for Handled<T> {
    fn into_handled(self) -> Handled<T> {
        std::convert::identity(self)
    }
}
/// Cloneable handle to a running connection with `Counterpart`, used to send
/// requests and notifications, spawn tasks and add dynamic handlers.
#[derive(Clone, Debug)]
pub struct ConnectionTo<Counterpart: Role> {
/// The role on the other end of the connection.
counterpart: Counterpart,
/// Queue drained by the outgoing actor.
message_tx: OutgoingMessageTx,
/// Queue drained by the task actor.
task_tx: TaskTx,
/// Queue of dynamic-handler additions/removals, drained by the incoming actor.
dynamic_handler_tx: mpsc::UnboundedSender<DynamicHandlerMessage<Counterpart>>,
}
impl<Counterpart: Role> ConnectionTo<Counterpart> {
    /// Assembles a connection handle from the sending ends of the connection's
    /// channels (used by `Builder::into_connection_and_future`, which keeps the
    /// receiving ends).
    fn new(
        counterpart: Counterpart,
        message_tx: mpsc::UnboundedSender<OutgoingMessage>,
        task_tx: mpsc::UnboundedSender<Task>,
        dynamic_handler_tx: mpsc::UnboundedSender<DynamicHandlerMessage<Counterpart>>,
    ) -> Self {
        Self {
            counterpart,
            message_tx,
            task_tx,
            dynamic_handler_tx,
        }
    }

    /// Returns a clone of the counterpart role value.
    pub fn counterpart(&self) -> Counterpart {
        self.counterpart.clone()
    }

    /// Queues `task` to run on this connection's task actor, recording the
    /// caller's source location.
    ///
    /// # Errors
    /// Returns whatever error `Task::spawn` reports when the task cannot be queued.
    #[track_caller]
    pub fn spawn(
        &self,
        task: impl IntoFuture<Output = Result<(), crate::Error>, IntoFuture: Send + 'static>,
    ) -> Result<(), crate::Error> {
        let location = std::panic::Location::caller();
        let task = task.into_future();
        Task::new(location, task).spawn(&self.task_tx)
    }

    /// Starts a nested connection built from `builder` over `transport`, running
    /// it as a task of this connection, and returns its handle immediately.
    ///
    /// The nested connection's main function never completes, so it runs until
    /// its own background work ends.
    #[track_caller]
    pub fn spawn_connection<R: Role>(
        &self,
        builder: Builder<
            R,
            impl HandleDispatchFrom<R::Counterpart> + 'static,
            impl RunWithConnectionTo<R::Counterpart> + 'static,
        >,
        transport: impl ConnectTo<R> + 'static,
    ) -> Result<ConnectionTo<R::Counterpart>, crate::Error> {
        let (connection, future) =
            builder.into_connection_and_future(transport, |_| std::future::pending());
        Task::new(std::panic::Location::caller(), future).spawn(&self.task_tx)?;
        Ok(connection)
    }

    /// Forwards `message` to the counterpart itself; see [`Self::send_proxied_message_to`].
    pub fn send_proxied_message<Req: JsonRpcRequest<Response: Send>, Notif: JsonRpcNotification>(
        &self,
        message: Dispatch<Req, Notif>,
    ) -> Result<(), crate::Error>
    where
        Counterpart: HasPeer<Counterpart>,
    {
        self.send_proxied_message_to(self.counterpart(), message)
    }

    /// Forwards a dispatch to `peer`.
    ///
    /// Requests are re-sent and their eventual response is forwarded to the
    /// original responder. Notifications are re-sent. Responses are handed to
    /// their router.
    pub fn send_proxied_message_to<
        Peer: Role,
        Req: JsonRpcRequest<Response: Send>,
        Notif: JsonRpcNotification,
    >(
        &self,
        peer: Peer,
        message: Dispatch<Req, Notif>,
    ) -> Result<(), crate::Error>
    where
        Counterpart: HasPeer<Peer>,
    {
        match message {
            Dispatch::Request(request, responder) => self
                .send_request_to(peer, request)
                .forward_response_to(responder),
            Dispatch::Notification(notification) => self.send_notification_to(peer, notification),
            Dispatch::Response(result, router) => router.respond_with_result(result),
        }
    }

    /// Sends `request` to the counterpart itself; see [`Self::send_request_to`].
    pub fn send_request<Req: JsonRpcRequest>(&self, request: Req) -> SentRequest<Req::Response>
    where
        Counterpart: HasPeer<Counterpart>,
    {
        self.send_request_to(self.counterpart.clone(), request)
    }

    /// Sends `request` to `peer` under a fresh UUID id and returns a handle that
    /// resolves to the typed response.
    ///
    /// This never fails synchronously. If the request cannot be converted or
    /// queued, the returned [`SentRequest`] resolves to an internal error.
    pub fn send_request_to<Peer: Role, Req: JsonRpcRequest>(
        &self,
        peer: Peer,
        request: Req,
    ) -> SentRequest<Req::Response>
    where
        Counterpart: HasPeer<Peer>,
    {
        /// Completes a request that never reached the outgoing actor with an
        /// internal error, so the caller's `SentRequest` does not wait forever.
        fn fail_request(response_tx: oneshot::Sender<ResponsePayload>, message: String) {
            response_tx
                .send(ResponsePayload {
                    result: Err(crate::util::internal_error(message)),
                    ack_tx: None,
                })
                // The matching receiver is still owned by `send_request_to`.
                .expect("response receiver is held by send_request_to");
        }

        let method = request.method().to_string();
        let id = jsonrpcmsg::Id::String(uuid::Uuid::new_v4().to_string());
        let (response_tx, response_rx) = oneshot::channel();
        let role_id = peer.role_id();
        let remote_style = self.counterpart.remote_style(peer);
        match remote_style.transform_outgoing_message(request) {
            Ok(untyped) => {
                let message = OutgoingMessage::Request {
                    id: id.clone(),
                    method: method.clone(),
                    role_id,
                    untyped,
                    response_tx,
                };
                if let Err(error) = self.message_tx.unbounded_send(message) {
                    // The outgoing actor is gone: recover the sender from the
                    // rejected message so the failure reaches the caller.
                    let OutgoingMessage::Request { response_tx, .. } = error.into_inner() else {
                        unreachable!("only a Request message was queued above");
                    };
                    fail_request(
                        response_tx,
                        format!("failed to send outgoing request `{method}`"),
                    );
                }
            }
            Err(err) => fail_request(
                response_tx,
                format!("failed to create untyped request for `{method}`: {err}"),
            ),
        }
        SentRequest::new(id, method.clone(), self.task_tx.clone(), response_rx)
            .map(move |json| <Req::Response>::from_value(&method, json))
    }

    /// Sends `notification` to the counterpart itself; see [`Self::send_notification_to`].
    pub fn send_notification<N: JsonRpcNotification>(
        &self,
        notification: N,
    ) -> Result<(), crate::Error>
    where
        Counterpart: HasPeer<Counterpart>,
    {
        self.send_notification_to(self.counterpart.clone(), notification)
    }

    /// Converts `notification` for delivery to `peer` and queues it on the
    /// outgoing actor.
    ///
    /// # Errors
    /// Fails if the conversion fails or the message cannot be queued.
    pub fn send_notification_to<Peer: Role, N: JsonRpcNotification>(
        &self,
        peer: Peer,
        notification: N,
    ) -> Result<(), crate::Error>
    where
        Counterpart: HasPeer<Peer>,
    {
        let remote_style = self.counterpart.remote_style(peer);
        tracing::debug!(
            role = std::any::type_name::<Counterpart>(),
            peer = std::any::type_name::<Peer>(),
            notification_type = std::any::type_name::<N>(),
            ?remote_style,
            original_method = notification.method(),
            "send_notification_to"
        );
        let transformed = remote_style.transform_outgoing_message(notification)?;
        tracing::debug!(
            transformed_method = %transformed.method,
            "send_notification_to transformed"
        );
        send_raw_message(
            &self.message_tx,
            OutgoingMessage::Notification {
                untyped: transformed,
            },
        )
    }

    /// Queues `error` to be sent to the peer without any request id.
    pub fn send_error_notification(&self, error: crate::Error) -> Result<(), crate::Error> {
        send_raw_message(&self.message_tx, OutgoingMessage::Error { error })
    }

    /// Registers `handler` with the incoming actor under a fresh UUID.
    ///
    /// The returned registration removes the handler again when dropped.
    ///
    /// # Errors
    /// Fails with an internal error if the dynamic-handler channel is closed.
    pub fn add_dynamic_handler(
        &self,
        handler: impl HandleDispatchFrom<Counterpart> + 'static,
    ) -> Result<DynamicHandlerRegistration<Counterpart>, crate::Error> {
        let uuid = Uuid::new_v4();
        self.dynamic_handler_tx
            .unbounded_send(DynamicHandlerMessage::AddDynamicHandler(
                uuid,
                Box::new(handler),
            ))
            .map_err(crate::util::internal_error)?;
        Ok(DynamicHandlerRegistration::new(uuid, self.clone()))
    }

    /// Requests removal of the dynamic handler registered under `uuid`.
    fn remove_dynamic_handler(&self, uuid: Uuid) {
        // Best effort: if the channel is closed there is nothing left to remove
        // from, so the send error is deliberately discarded.
        drop(
            self.dynamic_handler_tx
                .unbounded_send(DynamicHandlerMessage::RemoveDynamicHandler(uuid)),
        );
    }
}
/// Guard for a handler added with `ConnectionTo::add_dynamic_handler`.
///
/// Dropping it asks the connection to remove the handler; `run_indefinitely`
/// skips that so the handler stays registered.
// NOTE(review): this type derives `Clone`, and every clone carries the same
// `uuid`. Dropping *any* clone therefore removes the handler, even while other
// clones are still alive — confirm whether `Clone` is intended.
#[derive(Clone, Debug)]
pub struct DynamicHandlerRegistration<R: Role> {
uuid: Uuid,
cx: ConnectionTo<R>,
}
impl<R: Role> DynamicHandlerRegistration<R> {
    /// Pairs the id of a freshly added dynamic handler with its connection.
    fn new(uuid: Uuid, cx: ConnectionTo<R>) -> Self {
        Self { cx, uuid }
    }

    /// Keeps the handler registered by never running this guard's destructor,
    /// so no removal request is sent. The guard (including its connection
    /// clone) is leaked.
    pub fn run_indefinitely(self) {
        let _ = std::mem::ManuallyDrop::new(self);
    }
}
impl<R: Role> Drop for DynamicHandlerRegistration<R> {
    /// Asks the connection to remove the handler this guard registered.
    fn drop(&mut self) {
        let Self { uuid, cx } = self;
        cx.remove_dynamic_handler(*uuid);
    }
}
/// One-shot handle for answering a request received from the peer.
///
/// It is consumed by any of the `respond*` methods. `T` is the response type
/// the caller supplies; `cast`/`wrap_params` adapt it before sending.
#[must_use]
pub struct Responder<T: JsonRpcResponse = serde_json::Value> {
/// Method name of the request being answered.
method: String,
/// Id of the request being answered.
id: jsonrpcmsg::Id,
/// Performs the send; called at most once.
send_fn: SendBoxFnOnce<'static, (Result<T, crate::Error>,), Result<(), crate::Error>>,
}
impl<T: JsonRpcResponse> std::fmt::Debug for Responder<T> {
    /// Shows method, id and expected response type; the send callback is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let response_type = std::any::type_name::<T>();
        let mut out = f.debug_struct("Responder");
        out.field("method", &self.method)
            .field("id", &self.id)
            .field("response_type", &response_type);
        out.finish_non_exhaustive()
    }
}
impl Responder<serde_json::Value> {
fn new(message_tx: OutgoingMessageTx, method: String, id: jsonrpcmsg::Id) -> Self {
let id_clone = id.clone();
Self {
method,
id,
send_fn: SendBoxFnOnce::new(
move |response: Result<serde_json::Value, crate::Error>| {
send_raw_message(
&message_tx,
OutgoingMessage::Response {
id: id_clone,
response,
},
)
},
),
}
}
pub fn cast<T: JsonRpcResponse>(self) -> Responder<T> {
self.wrap_params(move |method, value| match value {
Ok(value) => T::into_json(value, method),
Err(e) => Err(e),
})
}
}
impl<T: JsonRpcResponse> Responder<T> {
#[must_use]
pub fn method(&self) -> &str {
&self.method
}
#[must_use]
pub fn id(&self) -> serde_json::Value {
crate::util::id_to_json(&self.id)
}
pub fn erase_to_json(self) -> Responder<serde_json::Value> {
self.wrap_params(|method, value| T::from_value(method, value?))
}
pub fn wrap_method(self, method: String) -> Responder<T> {
Responder {
method,
id: self.id,
send_fn: self.send_fn,
}
}
pub fn wrap_params<U: JsonRpcResponse>(
self,
wrap_fn: impl FnOnce(&str, Result<U, crate::Error>) -> Result<T, crate::Error> + Send + 'static,
) -> Responder<U> {
let method = self.method.clone();
Responder {
method: self.method,
id: self.id,
send_fn: SendBoxFnOnce::new(move |input: Result<U, crate::Error>| {
let t_value = wrap_fn(&method, input);
self.send_fn.call(t_value)
}),
}
}
pub fn respond_with_result(
self,
response: Result<T, crate::Error>,
) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, "respond called");
self.send_fn.call(response)
}
pub fn respond(self, response: T) -> Result<(), crate::Error> {
self.respond_with_result(Ok(response))
}
pub fn respond_with_internal_error(self, message: impl ToString) -> Result<(), crate::Error> {
self.respond_with_error(crate::util::internal_error(message))
}
pub fn respond_with_error(self, error: crate::Error) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, ?error, "respond_with_error called");
self.respond_with_result(Err(error))
}
}
/// One-shot handle that delivers a response received from the peer to the
/// local code awaiting it. `ResponseRouter::new` sends into a `oneshot` channel.
#[must_use]
pub struct ResponseRouter<T: JsonRpcResponse = serde_json::Value> {
/// Method name of the original request.
method: String,
/// Id of the original request.
id: jsonrpcmsg::Id,
/// Role the original request was addressed to.
role_id: RoleId,
/// Performs the delivery; called at most once.
send_fn: SendBoxFnOnce<'static, (Result<T, crate::Error>,), Result<(), crate::Error>>,
}
impl<T: JsonRpcResponse> std::fmt::Debug for ResponseRouter<T> {
    /// Shows method, id and expected response type; `role_id` and the send
    /// callback are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let response_type = std::any::type_name::<T>();
        let mut out = f.debug_struct("ResponseRouter");
        out.field("method", &self.method)
            .field("id", &self.id)
            .field("response_type", &response_type);
        out.finish_non_exhaustive()
    }
}
impl ResponseRouter<serde_json::Value> {
pub(crate) fn new(
method: String,
id: jsonrpcmsg::Id,
role_id: RoleId,
sender: oneshot::Sender<ResponsePayload>,
) -> Self {
Self {
method,
id,
role_id,
send_fn: SendBoxFnOnce::new(
move |response: Result<serde_json::Value, crate::Error>| {
sender
.send(ResponsePayload {
result: response,
ack_tx: None,
})
.map_err(|_| {
crate::util::internal_error("failed to send response, receiver dropped")
})
},
),
}
}
pub fn cast<T: JsonRpcResponse>(self) -> ResponseRouter<T> {
self.wrap_params(move |method, value| match value {
Ok(value) => T::into_json(value, method),
Err(e) => Err(e),
})
}
}
impl<T: JsonRpcResponse> ResponseRouter<T> {
#[must_use]
pub fn method(&self) -> &str {
&self.method
}
#[must_use]
pub fn id(&self) -> serde_json::Value {
crate::util::id_to_json(&self.id)
}
#[must_use]
pub fn role_id(&self) -> RoleId {
self.role_id.clone()
}
pub fn erase_to_json(self) -> ResponseRouter<serde_json::Value> {
self.wrap_params(|method, value| T::from_value(method, value?))
}
fn wrap_params<U: JsonRpcResponse>(
self,
wrap_fn: impl FnOnce(&str, Result<U, crate::Error>) -> Result<T, crate::Error> + Send + 'static,
) -> ResponseRouter<U> {
let method = self.method.clone();
ResponseRouter {
method: self.method,
id: self.id,
role_id: self.role_id,
send_fn: SendBoxFnOnce::new(move |input: Result<U, crate::Error>| {
let t_value = wrap_fn(&method, input);
self.send_fn.call(t_value)
}),
}
}
pub fn respond_with_result(
self,
response: Result<T, crate::Error>,
) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, "response routed to awaiter");
self.send_fn.call(response)
}
pub fn respond(self, response: T) -> Result<(), crate::Error> {
self.respond_with_result(Ok(response))
}
pub fn respond_with_internal_error(self, message: impl ToString) -> Result<(), crate::Error> {
self.respond_with_error(crate::util::internal_error(message))
}
pub fn respond_with_error(self, error: crate::Error) -> Result<(), crate::Error> {
tracing::debug!(id = ?self.id, ?error, "error routed to awaiter");
self.respond_with_result(Err(error))
}
}
/// A typed JSON-RPC message payload (request or notification) convertible to
/// and from an [`UntypedMessage`].
pub trait JsonRpcMessage: 'static + Debug + Sized + Send + Clone {
/// Whether messages named `method` should be parsed as this type
/// (checked before `parse_message` during dispatch).
fn matches_method(method: &str) -> bool;
/// The method name of this message.
fn method(&self) -> &str;
/// Converts this message into its untyped (method + JSON params) form.
fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error>;
/// Parses a message of this type from a method name and its params.
fn parse_message(method: &str, params: &impl Serialize) -> Result<Self, crate::Error>;
}
/// A JSON-RPC response type convertible to and from JSON.
pub trait JsonRpcResponse: 'static + Debug + Sized + Send + Clone {
/// Serializes this response; `method` is the method of the request it answers.
fn into_json(self, method: &str) -> Result<serde_json::Value, crate::Error>;
/// Parses the response to a request for `method` from JSON.
fn from_value(method: &str, value: serde_json::Value) -> Result<Self, crate::Error>;
}
/// Raw JSON is its own response type: both conversions are the identity and
/// never fail.
impl JsonRpcResponse for serde_json::Value {
    fn into_json(self, _method: &str) -> Result<serde_json::Value, crate::Error> {
        Ok(self)
    }

    fn from_value(_method: &str, value: serde_json::Value) -> Result<Self, crate::Error> {
        Ok(value)
    }
}
/// Marker for [`JsonRpcMessage`] types that are sent as notifications (no reply).
pub trait JsonRpcNotification: JsonRpcMessage {}
/// A [`JsonRpcMessage`] sent as a request.
pub trait JsonRpcRequest: JsonRpcMessage {
/// The type of the reply to this request.
type Response: JsonRpcResponse;
}
/// A message being dispatched to handlers, together with the means to answer it.
///
/// Both type parameters default to [`UntypedMessage`] (raw JSON).
#[derive(Debug)]
pub enum Dispatch<Req: JsonRpcRequest = UntypedMessage, Notif: JsonRpcMessage = UntypedMessage> {
/// A request plus the [`Responder`] used to send its reply.
Request(Req, Responder<Req::Response>),
/// A notification; nothing is sent back.
Notification(Notif),
/// A response (or error) to an earlier request, plus the [`ResponseRouter`]
/// that delivers it to the awaiter.
Response(
Result<Req::Response, crate::Error>,
ResponseRouter<Req::Response>,
),
}
impl<Req: JsonRpcRequest, Notif: JsonRpcMessage> Dispatch<Req, Notif> {
pub fn map<Req1, Notif1>(
self,
map_request: impl FnOnce(Req, Responder<Req::Response>) -> (Req1, Responder<Req1::Response>),
map_notification: impl FnOnce(Notif) -> Notif1,
) -> Dispatch<Req1, Notif1>
where
Req1: JsonRpcRequest<Response = Req::Response>,
Notif1: JsonRpcMessage,
{
match self {
Dispatch::Request(request, responder) => {
let (new_request, new_responder) = map_request(request, responder);
Dispatch::Request(new_request, new_responder)
}
Dispatch::Notification(notification) => {
let new_notification = map_notification(notification);
Dispatch::Notification(new_notification)
}
Dispatch::Response(result, router) => Dispatch::Response(result, router),
}
}
pub fn respond_with_error<R: Role>(
self,
error: crate::Error,
cx: ConnectionTo<R>,
) -> Result<(), crate::Error> {
match self {
Dispatch::Request(_, responder) => responder.respond_with_error(error),
Dispatch::Notification(_) => cx.send_error_notification(error),
Dispatch::Response(_, responder) => responder.respond_with_error(error),
}
}
pub fn erase_to_json(self) -> Result<Dispatch, crate::Error> {
match self {
Dispatch::Request(response, responder) => Ok(Dispatch::Request(
response.to_untyped_message()?,
responder.erase_to_json(),
)),
Dispatch::Notification(notification) => {
Ok(Dispatch::Notification(notification.to_untyped_message()?))
}
Dispatch::Response(_, _) => Err(crate::util::internal_error(
"cannot erase Response variant to JSON",
)),
}
}
pub fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error> {
match self {
Dispatch::Request(request, _) => request.to_untyped_message(),
Dispatch::Notification(notification) => notification.to_untyped_message(),
Dispatch::Response(_, _) => Err(crate::util::internal_error(
"Response variant has no untyped message representation",
)),
}
}
pub fn into_untyped_dispatch(self) -> Result<Dispatch, crate::Error> {
match self {
Dispatch::Request(request, responder) => Ok(Dispatch::Request(
request.to_untyped_message()?,
responder.erase_to_json(),
)),
Dispatch::Notification(notification) => {
Ok(Dispatch::Notification(notification.to_untyped_message()?))
}
Dispatch::Response(_, _) => Err(crate::util::internal_error(
"cannot convert Response variant to untyped message context",
)),
}
}
pub fn id(&self) -> Option<serde_json::Value> {
match self {
Dispatch::Request(_, cx) => Some(cx.id()),
Dispatch::Notification(_) => None,
Dispatch::Response(_, cx) => Some(cx.id()),
}
}
pub fn method(&self) -> &str {
match self {
Dispatch::Request(msg, _) => msg.method(),
Dispatch::Notification(msg) => msg.method(),
Dispatch::Response(_, cx) => cx.method(),
}
}
}
impl Dispatch {
#[tracing::instrument(skip(self), fields(Request = ?std::any::type_name::<Req>(), Notif = ?std::any::type_name::<Notif>()), level = "trace", ret)]
pub(crate) fn into_typed_dispatch<Req: JsonRpcRequest, Notif: JsonRpcNotification>(
self,
) -> Result<Result<Dispatch<Req, Notif>, Dispatch>, crate::Error> {
tracing::debug!(
message = ?self,
"into_typed_dispatch"
);
match self {
Dispatch::Request(message, responder) => {
if Req::matches_method(&message.method) {
match Req::parse_message(&message.method, &message.params) {
Ok(req) => {
tracing::trace!(?req, "parsed ok");
Ok(Ok(Dispatch::Request(req, responder.cast())))
}
Err(err) => {
tracing::trace!(?err, "parse error");
Err(err)
}
}
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Request(message, responder)))
}
}
Dispatch::Notification(message) => {
if Notif::matches_method(&message.method) {
match Notif::parse_message(&message.method, &message.params) {
Ok(notif) => {
tracing::trace!(?notif, "parse ok");
Ok(Ok(Dispatch::Notification(notif)))
}
Err(err) => {
tracing::trace!(?err, "parse error");
Err(err)
}
}
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Notification(message)))
}
}
Dispatch::Response(result, cx) => {
let method = cx.method();
if Req::matches_method(method) {
let typed_result = match result {
Ok(value) => {
match <Req::Response as JsonRpcResponse>::from_value(method, value) {
Ok(parsed) => {
tracing::trace!(?parsed, "parse ok");
Ok(parsed)
}
Err(err) => {
tracing::trace!(?err, "parse error");
return Err(err);
}
}
}
Err(err) => {
tracing::trace!("error, passthrough");
Err(err)
}
};
Ok(Ok(Dispatch::Response(typed_result, cx.cast())))
} else {
tracing::trace!("method doesn't match");
Ok(Err(Dispatch::Response(result, cx)))
}
}
}
}
#[must_use]
pub fn has_field(&self, field_name: &str) -> bool {
self.message()
.and_then(|m| m.params().get(field_name))
.is_some()
}
pub(crate) fn has_session_id(&self) -> bool {
self.has_field("sessionId")
}
pub(crate) fn get_session_id(&self) -> Result<Option<SessionId>, crate::Error> {
let Some(message) = self.message() else {
return Ok(None);
};
let Some(value) = message.params().get("sessionId") else {
return Ok(None);
};
let session_id = serde_json::from_value(value.clone())?;
Ok(Some(session_id))
}
pub fn into_notification<N: JsonRpcNotification>(
self,
) -> Result<Result<N, Dispatch>, crate::Error> {
match self {
Dispatch::Notification(msg) => {
if !N::matches_method(&msg.method) {
return Ok(Err(Dispatch::Notification(msg)));
}
match N::parse_message(&msg.method, &msg.params) {
Ok(n) => Ok(Ok(n)),
Err(err) => Err(err),
}
}
Dispatch::Request(..) | Dispatch::Response(..) => Ok(Err(self)),
}
}
pub fn into_request<Req: JsonRpcRequest>(
self,
) -> Result<Result<(Req, Responder<Req::Response>), Dispatch>, crate::Error> {
match self {
Dispatch::Request(msg, responder) => {
if !Req::matches_method(&msg.method) {
return Ok(Err(Dispatch::Request(msg, responder)));
}
match Req::parse_message(&msg.method, &msg.params) {
Ok(req) => Ok(Ok((req, responder.cast()))),
Err(err) => Err(err),
}
}
Dispatch::Notification(..) | Dispatch::Response(..) => Ok(Err(self)),
}
}
}
impl<M: JsonRpcRequest + JsonRpcNotification> Dispatch<M, M> {
    /// The request or notification payload; `None` for responses.
    pub fn message(&self) -> Option<&M> {
        if let Dispatch::Request(msg, _) | Dispatch::Notification(msg) = self {
            Some(msg)
        } else {
            None
        }
    }

    /// Applies `map_message` to the request/notification payload, keeping any
    /// responder; responses are returned untouched.
    ///
    /// # Errors
    /// Propagates the error from `map_message`.
    pub(crate) fn try_map_message(
        self,
        map_message: impl FnOnce(M) -> Result<M, crate::Error>,
    ) -> Result<Dispatch<M, M>, crate::Error> {
        let mapped = match self {
            Dispatch::Request(request, responder) => {
                Dispatch::Request(map_message(request)?, responder)
            }
            Dispatch::Notification(notification) => {
                Dispatch::Notification(map_message(notification)?)
            }
            response @ Dispatch::Response(..) => response,
        };
        Ok(mapped)
    }
}
/// A JSON-RPC method name with its params kept as raw JSON.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UntypedMessage {
/// The JSON-RPC method name.
pub method: String,
/// The params, as arbitrary JSON.
pub params: serde_json::Value,
}
impl UntypedMessage {
    /// Builds a message for `method`, serializing `params` to JSON.
    ///
    /// # Errors
    /// Fails if `params` cannot be serialized.
    pub fn new(method: &str, params: impl Serialize) -> Result<Self, crate::Error> {
        let json_params = serde_json::to_value(params)?;
        Ok(Self {
            method: method.to_owned(),
            params: json_params,
        })
    }

    /// The method name.
    #[must_use]
    pub fn method(&self) -> &str {
        self.method.as_str()
    }

    /// The raw JSON params.
    #[must_use]
    pub fn params(&self) -> &serde_json::Value {
        &self.params
    }

    /// Splits the message into `(method, params)`.
    #[must_use]
    pub fn into_parts(self) -> (String, serde_json::Value) {
        let Self { method, params } = self;
        (method, params)
    }

    /// Builds a JSON-RPC 2.0 request (a notification when `id` is `None`
    /// — presumably, per `jsonrpcmsg::Request::new_v2`).
    pub(crate) fn into_jsonrpc_msg(
        self,
        id: Option<jsonrpcmsg::Id>,
    ) -> Result<jsonrpcmsg::Request, crate::Error> {
        Ok(jsonrpcmsg::Request::new_v2(
            self.method,
            json_cast(self.params)?,
            id,
        ))
    }
}
impl JsonRpcMessage for UntypedMessage {
fn matches_method(_method: &str) -> bool {
true
}
fn method(&self) -> &str {
&self.method
}
fn to_untyped_message(&self) -> Result<UntypedMessage, crate::Error> {
Ok(self.clone())
}
fn parse_message(method: &str, params: &impl Serialize) -> Result<Self, crate::Error> {
UntypedMessage::new(method, params)
}
}
/// Sent as a request, an untyped message expects a raw JSON response.
impl JsonRpcRequest for UntypedMessage {
type Response = serde_json::Value;
}
/// An untyped message may also be sent as a notification.
impl JsonRpcNotification for UntypedMessage {}
/// Handle to an outgoing JSON-RPC request whose response has not been consumed yet.
///
/// The response payload arrives on `response_rx`. `to_result` converts the raw
/// JSON result into `T`. The handle is consumed by one of the waiting methods,
/// such as `block_task` or `on_receiving_result`.
pub struct SentRequest<T> {
    /// The request's JSON-RPC id.
    id: jsonrpcmsg::Id,
    /// Method name of the request, used in error messages.
    method: String,
    /// Used to spawn tasks that wait for the response.
    task_tx: TaskTx,
    /// Receives the response payload for this request.
    response_rx: oneshot::Receiver<ResponsePayload>,
    /// Converts the raw JSON result into the caller's type `T`.
    to_result: Box<dyn Fn(serde_json::Value) -> Result<T, crate::Error> + Send>,
}
/// Debug output shows the request id, method, and channel handles.
///
/// The boxed conversion closure is opaque and is omitted, hence
/// `finish_non_exhaustive`. No `T: Debug` bound is needed because none of the
/// printed fields depend on `T`. Leaving the bound off lets handles for
/// non-`Debug` result types be debug-printed too.
impl<T> Debug for SentRequest<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SentRequest")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("task_tx", &self.task_tx)
            .field("response_rx", &self.response_rx)
            .finish_non_exhaustive()
    }
}
impl SentRequest<serde_json::Value> {
fn new(
id: jsonrpcmsg::Id,
method: String,
task_tx: mpsc::UnboundedSender<Task>,
response_rx: oneshot::Receiver<ResponsePayload>,
) -> Self {
Self {
id,
method,
response_rx,
task_tx,
to_result: Box::new(Ok),
}
}
}
impl<T: JsonRpcResponse> SentRequest<T> {
    /// Returns the request's JSON-RPC id rendered as a JSON value.
    #[must_use]
    pub fn id(&self) -> serde_json::Value {
        crate::util::id_to_json(&self.id)
    }

    /// Returns the method name the request was sent with.
    #[must_use]
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Composes an additional conversion onto the response.
    ///
    /// `map_fn` receives the output of this handle's existing conversion. An
    /// error from either the existing conversion or `map_fn` becomes the final
    /// result.
    pub fn map<U>(
        self,
        map_fn: impl Fn(T) -> Result<U, crate::Error> + 'static + Send,
    ) -> SentRequest<U> {
        SentRequest {
            id: self.id,
            method: self.method,
            response_rx: self.response_rx,
            task_tx: self.task_tx,
            to_result: Box::new(move |value| map_fn((self.to_result)(value)?)),
        }
    }

    /// Forwards the eventual result (success or error) of this request to `responder`.
    ///
    /// This spawns a task via [`SentRequest::on_receiving_result`]. It is marked
    /// `#[track_caller]` so the spawned task records the caller's source
    /// location, as the sibling spawning methods do, rather than this method's.
    ///
    /// # Errors
    ///
    /// Returns an error if the task cannot be spawned.
    #[track_caller]
    pub fn forward_response_to(self, responder: Responder<T>) -> Result<(), crate::Error>
    where
        T: Send,
    {
        self.on_receiving_result(async move |result| responder.respond_with_result(result))
    }

    /// Waits for the response and returns it converted to `T`.
    ///
    /// If the payload carries an acknowledgement sender, it is signalled as soon
    /// as the payload arrives, before the conversion runs.
    ///
    /// # Errors
    ///
    /// Returns one of the following:
    /// - the error carried by the response;
    /// - an error from the conversion to `T`;
    /// - an internal error if the response channel closed without delivering a response.
    pub async fn block_task(self) -> Result<T, crate::Error>
    where
        T: Send,
    {
        let Self {
            method,
            response_rx,
            to_result,
            ..
        } = self;
        let ResponsePayload { result, ack_tx } = response_rx.await.map_err(|err| {
            crate::util::internal_error(format!("response to `{method}` never received: {err}"))
        })?;
        if let Some(tx) = ack_tx {
            // A failed send only means the acknowledging side has gone away.
            let _ = tx.send(());
        }
        result.and_then(to_result)
    }

    /// Spawns `task` to handle a successful response.
    ///
    /// An error response is sent straight back through `responder` instead.
    ///
    /// # Errors
    ///
    /// Returns an error if the task cannot be spawned.
    #[track_caller]
    pub fn on_receiving_ok_result<F>(
        self,
        responder: Responder<T>,
        task: impl FnOnce(T, Responder<T>) -> F + 'static + Send,
    ) -> Result<(), crate::Error>
    where
        F: Future<Output = Result<(), crate::Error>> + 'static + Send,
        T: Send,
    {
        self.on_receiving_result(async move |result| match result {
            Ok(value) => task(value, responder).await,
            Err(err) => responder.respond_with_error(err),
        })
    }

    /// Spawns a task that waits for the response and hands the converted result to `task`.
    ///
    /// The task is tagged with the caller's source location. If the payload
    /// carries an acknowledgement sender, it is signalled only after `task`
    /// has completed.
    ///
    /// # Errors
    ///
    /// Returns an error if the task cannot be spawned. Errors from `task`
    /// itself are reported by the spawned task, not returned here.
    #[track_caller]
    pub fn on_receiving_result<F>(
        self,
        task: impl FnOnce(Result<T, crate::Error>) -> F + 'static + Send,
    ) -> Result<(), crate::Error>
    where
        F: Future<Output = Result<(), crate::Error>> + 'static + Send,
        T: Send,
    {
        // `self` is consumed, so the sender can be moved out instead of cloned.
        let Self {
            method,
            task_tx,
            response_rx,
            to_result,
            ..
        } = self;
        let location = Location::caller();
        Task::new(location, async move {
            match response_rx.await {
                Ok(ResponsePayload { result, ack_tx }) => {
                    let outcome = task(result.and_then(to_result)).await;
                    // Acknowledge only once `task` has finished with the result.
                    if let Some(tx) = ack_tx {
                        let _ = tx.send(());
                    }
                    outcome
                }
                Err(err) => Err(crate::util::internal_error(format!(
                    "response to `{method}` never received: {err}"
                ))),
            }
        })
        .spawn(&task_tx)
    }
}
/// A transport built from a sink of outgoing text lines and a stream of incoming text lines.
///
/// NOTE(review): the line framing (presumably one JSON-RPC message per line)
/// is performed by the transport line actors. Confirm it there.
#[derive(Debug)]
pub struct Lines<OutgoingSink, IncomingStream> {
    /// Sink that receives each outgoing line.
    pub outgoing: OutgoingSink,
    /// Stream yielding each incoming line, or an I/O error.
    pub incoming: IncomingStream,
}
impl<OutgoingSink, IncomingStream> Lines<OutgoingSink, IncomingStream>
where
OutgoingSink: futures::Sink<String, Error = std::io::Error> + Send + 'static,
IncomingStream: futures::Stream<Item = std::io::Result<String>> + Send + 'static,
{
pub fn new(outgoing: OutgoingSink, incoming: IncomingStream) -> Self {
Self { outgoing, incoming }
}
}
impl<OutgoingSink, IncomingStream, R: Role> ConnectTo<R> for Lines<OutgoingSink, IncomingStream>
where
OutgoingSink: futures::Sink<String, Error = std::io::Error> + Send + 'static,
IncomingStream: futures::Stream<Item = std::io::Result<String>> + Send + 'static,
{
async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
let (channel, serve_self) = ConnectTo::<R>::into_channel_and_future(self);
match futures::future::select(Box::pin(client.connect_to(channel)), serve_self).await {
Either::Left((result, _)) | Either::Right((result, _)) => result,
}
}
fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
let Self { outgoing, incoming } = self;
let (channel_for_caller, channel_for_lines) = Channel::duplex();
let server_future = Box::pin(async move {
let Channel { rx, tx } = channel_for_lines;
let outgoing_future = transport_actor::transport_outgoing_lines_actor(rx, outgoing);
let incoming_future = transport_actor::transport_incoming_lines_actor(incoming, tx);
futures::try_join!(outgoing_future, incoming_future)?;
Ok(())
});
(channel_for_caller, server_future)
}
}
/// A transport over a raw byte writer and byte reader.
///
/// When connected, it is adapted into a newline-delimited [`Lines`] transport.
#[derive(Debug)]
pub struct ByteStreams<OB, IB> {
    /// Byte sink that outgoing lines are written to.
    pub outgoing: OB,
    /// Byte source that incoming lines are read from.
    pub incoming: IB,
}
impl<OB, IB> ByteStreams<OB, IB>
where
OB: AsyncWrite + Send + 'static,
IB: AsyncRead + Send + 'static,
{
pub fn new(outgoing: OB, incoming: IB) -> Self {
Self { outgoing, incoming }
}
}
impl<OB, IB, R: Role> ConnectTo<R> for ByteStreams<OB, IB>
where
    OB: AsyncWrite + Send + 'static,
    IB: AsyncRead + Send + 'static,
{
    /// Runs `client` against this byte transport until either side finishes.
    ///
    /// Whichever of the client and the transport future completes first
    /// provides the result. The other future is dropped.
    async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
        let (channel, serve_self) = ConnectTo::<R>::into_channel_and_future(self);
        match futures::future::select(pin!(client.connect_to(channel)), serve_self).await {
            Either::Left((result, _)) | Either::Right((result, _)) => result,
        }
    }

    /// Adapts the byte streams into a newline-delimited [`Lines`] transport.
    ///
    /// Incoming bytes are split into lines. Each outgoing line is written with
    /// a trailing `\n` and the writer is then flushed.
    fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
        use futures::AsyncBufReadExt;
        use futures::AsyncWriteExt;
        use futures::io::BufReader;
        let Self { outgoing, incoming } = self;
        let incoming_lines = Box::pin(BufReader::new(incoming).lines());
        let outgoing_sink =
            futures::sink::unfold(Box::pin(outgoing), async move |mut writer, line: String| {
                let mut bytes = line.into_bytes();
                bytes.push(b'\n');
                writer.write_all(&bytes).await?;
                // Flush after every message. `sink::unfold` never flushes the
                // underlying writer itself, so a buffering `AsyncWrite` would
                // otherwise hold the message back and the peer would never
                // receive it (or it would be lost on drop).
                writer.flush().await?;
                Ok::<_, std::io::Error>(writer)
            });
        ConnectTo::<R>::into_channel_and_future(Lines::new(outgoing_sink, incoming_lines))
    }
}
/// One end of an in-memory, unbounded duplex connection carrying JSON-RPC messages or errors.
///
/// - `rx` yields items sent by the peer end.
/// - `tx` sends items that the peer end receives.
///
/// Use [`Channel::duplex`] to create a connected pair.
#[derive(Debug)]
pub struct Channel {
    /// Items arriving from the peer end.
    pub rx: mpsc::UnboundedReceiver<Result<jsonrpcmsg::Message, crate::Error>>,
    /// Sends items to the peer end.
    pub tx: mpsc::UnboundedSender<Result<jsonrpcmsg::Message, crate::Error>>,
}
impl Channel {
#[must_use]
pub fn duplex() -> (Self, Self) {
let (a_tx, b_rx) = mpsc::unbounded();
let (b_tx, a_rx) = mpsc::unbounded();
let channel_a = Self { rx: a_rx, tx: a_tx };
let channel_b = Self { rx: b_rx, tx: b_tx };
(channel_a, channel_b)
}
pub async fn copy(mut self) -> Result<(), crate::Error> {
while let Some(msg) = self.rx.next().await {
self.tx
.unbounded_send(msg)
.map_err(crate::util::internal_error)?;
}
Ok(())
}
}
impl<R: Role> ConnectTo<R> for Channel {
async fn connect_to(self, client: impl ConnectTo<R::Counterpart>) -> Result<(), crate::Error> {
let (client_channel, client_serve) = client.into_channel_and_future();
match futures::try_join!(
Channel {
rx: client_channel.rx,
tx: self.tx
}
.copy(),
Channel {
rx: self.rx,
tx: client_channel.tx
}
.copy(),
client_serve
) {
Ok(((), (), ())) => Ok(()),
Err(err) => Err(err),
}
}
fn into_channel_and_future(self) -> (Channel, BoxFuture<'static, Result<(), crate::Error>>) {
(self, Box::pin(future::ready(Ok(()))))
}
}