use std::collections::HashMap;
use futures::StreamExt as _;
use futures::channel::mpsc;
use futures::channel::oneshot;
use futures_concurrency::stream::StreamExt as _;
use rustc_hash::FxHashMap;
use uuid::Uuid;
use crate::Dispatch;
use crate::RoleId;
use crate::UntypedMessage;
use crate::jsonrpc::ConnectionTo;
use crate::jsonrpc::HandleDispatchFrom;
use crate::jsonrpc::OutgoingMessage;
use crate::jsonrpc::RawJsonRpcMessage;
use crate::jsonrpc::RawJsonRpcParams;
use crate::jsonrpc::ReplyMessage;
use crate::jsonrpc::Responder;
use crate::jsonrpc::ResponseRouter;
use crate::jsonrpc::dynamic_handler::DynHandleDispatchFrom;
use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage;
use crate::jsonrpc::outgoing_actor::send_raw_message;
use crate::jsonrpc::protocol_compat::ProtocolCompat;
use crate::role::Role;
use crate::schema::v1::{RequestId, Response};
use super::Handled;
/// State recorded for a request id when a `ReplyMessage::Subscribe` arrives, and held
/// until a transport response carrying the same id is received.
struct PendingReply {
/// Method name associated with the request; passed to
/// `ProtocolCompat::incoming_response` and to the `ResponseRouter` built for the reply.
method: String,
/// Role identifier handed to the `ResponseRouter` built for the reply.
role_id: RoleId,
/// One-shot channel sender for the response payload; handed to the `ResponseRouter`.
sender: oneshot::Sender<crate::jsonrpc::ResponsePayload>,
/// Cancellation disarm handle forwarded to the `ResponseRouter`; only present when the
/// `unstable_cancel_request` feature is enabled.
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm: super::SentRequestCancellationDisarm,
}
/// Drives the incoming side of a JSON-RPC connection until its merged input stream ends.
///
/// The three receivers (`transport_rx`, `dynamic_handler_rx`, `reply_rx`) are merged into a
/// single stream and processed one message at a time:
///
/// - `ReplyMessage::Subscribe` records a `PendingReply` under its request id.
/// - `DynamicHandlerMessage::AddDynamicHandler` first offers every queued pending message
///   to the new handler, then registers the handler under its UUID;
///   `DynamicHandlerMessage::RemoveDynamicHandler` unregisters it.
/// - Transport requests and notifications are converted with `dispatch_from_message` and
///   each resulting dispatch is run through `dispatch_dispatch`.
/// - Transport responses are matched against the recorded pending replies by id and
///   dispatched; a response with an unknown id is logged and dropped.
/// - Transport-level errors are logged and reported via `send_error_notification`.
///
/// # Errors
///
/// Returns the first error propagated (via `?`) from `dispatch_dispatch`,
/// `report_handler_error`, or `connection.send_error_notification`. Returns `Ok(())` once
/// the merged stream yields no further items.
pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
counterpart: Counterpart,
connection: &ConnectionTo<Counterpart>,
transport_rx: mpsc::UnboundedReceiver<Result<RawJsonRpcMessage, crate::Error>>,
dynamic_handler_rx: mpsc::UnboundedReceiver<DynamicHandlerMessage<Counterpart>>,
reply_rx: mpsc::UnboundedReceiver<ReplyMessage>,
mut handler: impl HandleDispatchFrom<Counterpart>,
protocol_compat: ProtocolCompat,
) -> Result<(), crate::Error> {
// Tag each source with its `IncomingProtocolMsg` variant and merge them into one stream.
let mut my_rx = transport_rx
.map(IncomingProtocolMsg::Transport)
.merge(dynamic_handler_rx.map(IncomingProtocolMsg::DynamicHandler))
.merge(reply_rx.map(IncomingProtocolMsg::Reply));
// Dynamically registered handlers, keyed by the UUID supplied at registration.
let mut dynamic_handlers: FxHashMap<Uuid, Box<dyn DynHandleDispatchFrom<Counterpart>>> =
FxHashMap::default();
// Dispatches that no handler accepted but that were flagged for retry; they are
// re-offered whenever a new dynamic handler is added.
let mut pending_messages: Vec<Dispatch> = vec![];
let request_cancellations = super::RequestCancellationRegistry::new();
// Reply subscriptions awaiting a transport response, keyed by request id.
let mut pending_replies: HashMap<RequestId, PendingReply> = HashMap::new();
while let Some(message_result) = my_rx.next().await {
tracing::trace!(message = ?message_result, actor = "incoming_protocol_actor");
match message_result {
IncomingProtocolMsg::Reply(message) => match message {
ReplyMessage::Subscribe {
id,
role_id,
method,
sender,
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm,
} => {
tracing::trace!(?id, %method, "incoming_actor: subscribing to response");
// Remember where to route the response once it arrives for this id.
pending_replies.insert(
id,
PendingReply {
method,
role_id,
sender,
#[cfg(feature = "unstable_cancel_request")]
cancellation_disarm,
},
);
}
},
IncomingProtocolMsg::DynamicHandler(message) => match message {
DynamicHandlerMessage::AddDynamicHandler(uuid, mut handler) => {
// Give the new handler a chance at every queued message before it is
// registered; whatever it declines stays queued.
let mut new_pending_messages = vec![];
for pending_message in pending_messages {
tracing::trace!(method = pending_message.method(), handler = ?handler.dyn_describe_chain(), "Retrying message");
// Captured before the message is moved into the handler so that an
// error can still be reported against it.
let id = pending_message.id();
let method = pending_message.method().to_string();
match handler
.dyn_handle_dispatch_from(pending_message, connection.clone())
.await
{
Ok(Handled::Yes) => {
tracing::trace!("Message handled");
}
Ok(Handled::No {
message: m,
retry: _,
}) => {
tracing::trace!(method = m.method(), handler = ?handler.dyn_describe_chain(), "Message not handled");
// Kept regardless of the `retry` flag returned by this handler.
new_pending_messages.push(m);
}
Err(err) => {
tracing::warn!(?err, handler = ?handler.dyn_describe_chain(), "Dynamic handler errored on pending message, reporting back");
report_handler_error(connection, id, method, err)?;
}
}
}
// `pending_messages` was moved by the loop above; re-initialise it here.
pending_messages = new_pending_messages;
dynamic_handlers.insert(uuid, handler);
}
DynamicHandlerMessage::RemoveDynamicHandler(uuid) => {
dynamic_handlers.remove(&uuid);
}
},
IncomingProtocolMsg::Transport(message) => match message {
Ok(message) => match message {
RawJsonRpcMessage::Request(request) => {
tracing::trace!(method = %request.method, id = ?request.id, "Handling request");
// Copies kept for the error path, since `request` is consumed below.
let request_method = request.method.to_string();
let request_id = request.id.clone();
match dispatch_from_message(
connection,
request.method,
request.params,
Some(request.id),
&protocol_compat,
&request_cancellations,
) {
Ok(dispatches) => {
for dispatch in dispatches {
dispatch_dispatch(
counterpart.clone(),
connection,
dispatch,
&mut dynamic_handlers,
&mut handler,
&mut pending_messages,
&request_cancellations,
)
.await?;
}
}
Err(error) => {
// Conversion failed: report the error against the request id.
report_handler_error(
connection,
Some(
serde_json::to_value(request_id)
.expect("RequestId serializes infallibly"),
),
request_method,
error,
)?;
}
}
}
RawJsonRpcMessage::Notification(notification) => {
tracing::trace!(method = %notification.method, "Handling notification");
let request_method = notification.method.to_string();
match dispatch_from_message(
connection,
notification.method,
notification.params,
None,
&protocol_compat,
&request_cancellations,
) {
Ok(dispatches) => {
for dispatch in dispatches {
dispatch_dispatch(
counterpart.clone(),
connection,
dispatch,
&mut dynamic_handlers,
&mut handler,
&mut pending_messages,
&request_cancellations,
)
.await?;
}
}
Err(error) => {
// No id is available for a notification, so none is passed on.
report_handler_error(connection, None, request_method, error)?;
}
}
}
RawJsonRpcMessage::Response(response) => {
// Flatten both response shapes into `(id, Result)`.
let (id, result) = match response {
Response::Result { id, result } => (id, Ok(result)),
Response::Error { id, error } => (id, Err(error)),
};
tracing::trace!(?id, "Handling response");
// Removing the entry means each subscription is used for at most one response.
if let Some(pending_reply) = pending_replies.remove(&id) {
let result =
protocol_compat.incoming_response(&pending_reply.method, result);
let dispatch = dispatch_from_response(id, pending_reply, result);
dispatch_dispatch(
counterpart.clone(),
connection,
dispatch,
&mut dynamic_handlers,
&mut handler,
&mut pending_messages,
&request_cancellations,
)
.await?;
} else {
tracing::warn!(
?id,
"incoming_actor: received response for unknown id, no subscriber found"
);
}
}
},
Err(error) => {
tracing::warn!(?error, "Transport parse error, sending error notification");
connection.send_error_notification(error)?;
}
},
}
}
Ok(())
}
/// Item type of the merged input stream consumed by `incoming_protocol_actor`; each
/// variant wraps the item type of one of the three source channels.
#[derive(Debug)]
enum IncomingProtocolMsg<Counterpart: Role> {
/// An item from the transport channel: a raw JSON-RPC message, or an error.
Transport(Result<RawJsonRpcMessage, crate::Error>),
/// A request to add or remove a dynamic handler.
DynamicHandler(DynamicHandlerMessage<Counterpart>),
/// A reply-subscription message.
Reply(ReplyMessage),
}
/// Converts a raw incoming request or notification into the dispatches to be handled.
///
/// The method name and transport parameters are first wrapped in an `UntypedMessage`.
///
/// * With `id` set, the message is passed through `ProtocolCompat::incoming_message` and
///   returned as a single `Dispatch::Request`, paired with a `Responder` built from the
///   connection's outgoing channel, the method name, the id, and `request_cancellations`.
/// * With `id` unset, the message is passed through
///   `ProtocolCompat::incoming_notification`, and every item it yields becomes a
///   `Dispatch::Notification`.
///
/// # Errors
///
/// Propagates any error returned by the `ProtocolCompat` translation step.
///
/// # Panics
///
/// Panics if `UntypedMessage::new` fails (the code treats this as impossible for
/// well-formed JSON).
fn dispatch_from_message<Counterpart: Role>(
    connection: &ConnectionTo<Counterpart>,
    method: std::sync::Arc<str>,
    params: Option<RawJsonRpcParams>,
    id: Option<RequestId>,
    protocol_compat: &ProtocolCompat,
    request_cancellations: &super::RequestCancellationRegistry,
) -> Result<Vec<Dispatch>, crate::Error> {
    let transport_params = crate::jsonrpc::params_from_transport(params);
    let untyped = UntypedMessage::new(&method, transport_params).expect("well-formed JSON");

    // Request path: exactly one dispatch, carrying the responder for this id.
    if let Some(request_id) = id {
        let request = protocol_compat.incoming_message(untyped)?;
        let responder = Responder::new(
            connection.message_tx.clone(),
            method.to_string(),
            request_id,
            request_cancellations,
        );
        return Ok(vec![Dispatch::Request(request, responder)]);
    }

    // Notification path: the compat layer may yield any number of notifications.
    let notifications = protocol_compat.incoming_notification(untyped)?;
    Ok(notifications
        .into_iter()
        .map(Dispatch::Notification)
        .collect())
}
/// Builds the `Dispatch::Response` for a response whose subscription was found.
///
/// The pending reply is taken apart and its pieces, together with the response `id`, are
/// handed to a new `ResponseRouter`; the router is then paired with `result`.
///
/// `method` and `id` are owned here and not used afterwards, so they are moved into the
/// router rather than cloned.
fn dispatch_from_response(
    id: RequestId,
    pending_reply: PendingReply,
    result: Result<serde_json::Value, crate::Error>,
) -> Dispatch {
    let PendingReply {
        method,
        role_id,
        sender,
        #[cfg(feature = "unstable_cancel_request")]
        cancellation_disarm,
    } = pending_reply;
    let router = ResponseRouter::new(
        method,
        id,
        role_id,
        sender,
        #[cfg(feature = "unstable_cancel_request")]
        cancellation_disarm,
    );
    Dispatch::Response(result, router)
}
/// Offers a single dispatch to each handler in turn until one accepts it.
///
/// Order of attempts:
///
/// 1. `request_cancellations.cancel_if_requested` is consulted; an error here is reported
///    back and ends processing of this dispatch.
/// 2. The connection's main `handler`.
/// 3. Every registered dynamic handler, in the map's iteration order.
/// 4. The counterpart role's default handler.
///
/// A handler that returns `Handled::Yes` ends processing. A handler error is reported via
/// `report_handler_error` and also ends processing. If every handler declines:
///
/// * when any of them asked for a retry, the dispatch is appended to `pending_messages`;
/// * otherwise, a protocol-level notification is ignored;
/// * otherwise, a request or notification is answered with a "method not found" error
///   carrying the method name as data, and a response is forwarded to its router.
///
/// # Errors
///
/// Returns whatever `report_handler_error`, `respond_with_error`, or
/// `respond_with_result` returns on the path taken.
#[tracing::instrument(
    skip(connection, dispatch, dynamic_handlers, handler, pending_messages),
    fields(method = dispatch.method()),
    level = "trace",
)]
async fn dispatch_dispatch<Counterpart: Role>(
    counterpart: Counterpart,
    connection: &ConnectionTo<Counterpart>,
    mut dispatch: Dispatch,
    dynamic_handlers: &mut FxHashMap<Uuid, Box<dyn DynHandleDispatchFrom<Counterpart>>>,
    handler: &mut impl HandleDispatchFrom<Counterpart>,
    pending_messages: &mut Vec<Dispatch>,
    request_cancellations: &super::RequestCancellationRegistry,
) -> Result<(), crate::Error> {
    tracing::trace!(?dispatch, "dispatch_dispatch");

    // Captured up front: `dispatch` is moved into each handler, so these are the only
    // copies available for logging and error reporting afterwards.
    let id = dispatch.id();
    let method = dispatch.method().to_string();
    // Set as soon as any declining handler asks for the message to be retried later.
    let mut retry_any = false;

    let cancelled = match request_cancellations.cancel_if_requested(&dispatch) {
        Ok(flag) => flag,
        Err(err) => {
            tracing::warn!(
                ?method,
                ?id,
                ?err,
                "Request cancellation notification errored"
            );
            return report_handler_error(connection, id, method, err);
        }
    };
    if cancelled {
        tracing::debug!(?method, "Marked request as cancelled");
    }

    // Stage 1: the connection's main handler chain.
    tracing::trace!(handler = ?handler.describe_chain(), "Attempting handler chain");
    let primary_outcome = handler
        .handle_dispatch_from(dispatch, connection.clone())
        .await;
    match primary_outcome {
        Err(err) => {
            tracing::warn!(?method, ?id, ?err, handler = ?handler.describe_chain(), "Handler errored, reporting back to remote");
            return report_handler_error(connection, id, method, err);
        }
        Ok(Handled::Yes) => {
            tracing::trace!(?method, ?id, handler = ?handler.describe_chain(), "Handler accepted message");
            return Ok(());
        }
        Ok(Handled::No {
            message: declined,
            retry,
        }) => {
            tracing::trace!(?method, ?id, handler = ?handler.describe_chain(), "Handler declined message");
            retry_any |= retry;
            dispatch = declined;
        }
    }

    // Stage 2: every dynamically registered handler.
    for candidate in dynamic_handlers.values_mut() {
        tracing::trace!(handler = ?candidate.dyn_describe_chain(), "Attempting dynamic handler");
        let dynamic_outcome = candidate
            .dyn_handle_dispatch_from(dispatch, connection.clone())
            .await;
        match dynamic_outcome {
            Err(err) => {
                tracing::warn!(?method, ?id, ?err, handler = ?candidate.dyn_describe_chain(), "Dynamic handler errored, reporting back to remote");
                return report_handler_error(connection, id, method, err);
            }
            Ok(Handled::Yes) => {
                tracing::trace!(?method, ?id, handler = ?candidate.dyn_describe_chain(), "Dynamic handler accepted message");
                return Ok(());
            }
            Ok(Handled::No {
                message: declined,
                retry,
            }) => {
                tracing::trace!(?method, ?id, handler = ?candidate.dyn_describe_chain(), "Dynamic handler declined message");
                retry_any |= retry;
                dispatch = declined;
            }
        }
    }

    // Stage 3: the role's built-in default handling.
    tracing::trace!(role = ?counterpart, "Attempting default handler");
    let default_outcome = counterpart
        .default_handle_dispatch_from(dispatch, connection.clone())
        .await;
    match default_outcome {
        Err(err) => {
            tracing::warn!(
                ?method,
                ?id,
                ?err,
                handler = "default",
                "Default handler errored, reporting back to remote"
            );
            return report_handler_error(connection, id, method, err);
        }
        Ok(Handled::Yes) => {
            tracing::trace!(?method, handler = "default", "Role accepted message");
            return Ok(());
        }
        Ok(Handled::No {
            message: declined,
            retry,
        }) => {
            tracing::trace!(?method, handler = "default", "Role declined message");
            retry_any |= retry;
            dispatch = declined;
        }
    }

    // Nobody accepted the message: park it for a later handler if a retry was requested.
    if retry_any {
        tracing::debug!(
            ?method,
            "Retrying message as new dynamic handlers are added"
        );
        pending_messages.push(dispatch);
        return Ok(());
    }

    if super::is_protocol_level_notification(&dispatch) {
        tracing::debug!(?method, "Ignoring unhandled protocol-level notification");
        return Ok(());
    }

    match dispatch {
        Dispatch::Response(result, router) => {
            tracing::trace!(?method, "Forwarding response");
            router.respond_with_result(result)
        }
        Dispatch::Request(..) | Dispatch::Notification(_) => {
            tracing::info!(?method, "Rejecting message with error, no handler");
            // Re-read the name from the dispatch as it stands after the handler chain.
            let unhandled_method = dispatch.method().to_string();
            dispatch.respond_with_error(
                crate::Error::method_not_found().data(unhandled_method),
                connection.clone(),
            )
        }
    }
}
/// Reports a handler error back over the connection.
///
/// When `id` is present, an error response for `method` is queued on the connection's
/// outgoing channel. The JSON id is converted to a `RequestId`, falling back to
/// `RequestId::Null` if that conversion fails. When `id` is absent, the error is sent
/// with `send_error_notification` instead.
///
/// # Errors
///
/// Returns whatever `send_raw_message` or `send_error_notification` returns.
fn report_handler_error<Counterpart: Role>(
    connection: &ConnectionTo<Counterpart>,
    id: Option<serde_json::Value>,
    method: String,
    error: crate::Error,
) -> Result<(), crate::Error> {
    if let Some(raw_id) = id {
        let error_response = OutgoingMessage::Response {
            id: serde_json::from_value(raw_id).unwrap_or(RequestId::Null),
            method,
            response: Err(error),
        };
        send_raw_message(&connection.message_tx, error_response)
    } else {
        connection.send_error_notification(error)
    }
}