use std::{any::TypeId, fmt::Debug, future::Future, hash::Hash};
use serde::{Deserialize, Serialize};
use crate::schema::{METHOD_SUCCESSOR_MESSAGE, SuccessorMessage};
use crate::util::json_cast;
use crate::{Builder, ConnectionTo, Dispatch, Handled, JsonRpcMessage, UntypedMessage};
pub mod acp;
pub mod mcp;
/// A role that one side of a connection plays.
///
/// Every role names a [`Role::Counterpart`] type, and the relation is symmetric:
/// the counterpart's own `Counterpart` must be `Self`.
pub trait Role: Debug + Clone + Send + Sync + 'static + Eq + Ord + Hash {
    /// The role type on the other end; its `Counterpart` is required to be `Self`.
    type Counterpart: Role<Counterpart = Self>;

    /// Starts a [`Builder`] for this role. The default implementation is `Builder::new(self)`.
    fn builder(self) -> Builder<Self>
    where
        Self: Sized,
    {
        Builder::new(self)
    }

    /// Returns the runtime identifier of this role.
    ///
    /// `handle_incoming_dispatch` in this module compares it against a response's
    /// router role id to decide whether the response is handled.
    fn role_id(&self) -> RoleId;

    /// Fallback handling for a `message` arriving over `connection`.
    ///
    /// Returning `Handled::No` hands the message back unhandled.
    /// NOTE(review): presumably invoked when no user-registered handler claimed the
    /// dispatch — confirm against the caller.
    fn default_handle_dispatch_from(
        &self,
        message: Dispatch,
        connection: ConnectionTo<Self>,
    ) -> impl Future<Output = Result<Handled<Dispatch>, crate::Error>> + Send;

    /// Returns a value of this role's [`Role::Counterpart`] type.
    fn counterpart(&self) -> Self::Counterpart;
}
/// Implemented by a role that exchanges messages with a `Peer` role.
///
/// The role chooses how those messages are framed.
pub trait HasPeer<Peer: Role>: Role {
    /// Returns the [`RemoteStyle`] used for traffic involving `peer`.
    fn remote_style(&self, peer: Peer) -> RemoteStyle;
}
/// How messages exchanged with a peer are framed.
///
/// The per-variant behavior described below is what `transform_outgoing_message`
/// and `handle_incoming_dispatch` in this module implement.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum RemoteStyle {
    /// Outgoing and incoming messages are passed through unchanged.
    Counterpart,
    /// Outgoing messages are sent unchanged.
    ///
    /// Incoming `METHOD_SUCCESSOR_MESSAGE` envelopes are declined; other incoming
    /// messages are handled unchanged.
    Predecessor,
    /// Outgoing messages are wrapped in a [`SuccessorMessage`].
    ///
    /// Incoming non-response dispatches are accepted only as
    /// `METHOD_SUCCESSOR_MESSAGE` envelopes, which are unwrapped before handling.
    Successor,
}
impl RemoteStyle {
    /// Converts an outgoing message into an [`UntypedMessage`] framed for this style.
    ///
    /// For [`RemoteStyle::Successor`], the message is wrapped in a [`SuccessorMessage`]
    /// with `meta: None`. For every other style it is converted unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from `to_untyped_message`.
    pub(crate) fn transform_outgoing_message<M: JsonRpcMessage>(
        &self,
        msg: M,
    ) -> Result<UntypedMessage, crate::Error> {
        match *self {
            Self::Successor => {
                let envelope = SuccessorMessage {
                    message: msg,
                    meta: None,
                };
                envelope.to_untyped_message()
            }
            Self::Counterpart | Self::Predecessor => msg.to_untyped_message(),
        }
    }
}
/// Runtime identifier for a [`Role`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
#[non_exhaustive]
pub enum RoleId {
    /// A role identified purely by its type.
    ///
    /// Holds the type's name and its [`TypeId`], as filled in by
    /// [`RoleId::from_singleton`]. Equality includes the `TypeId`, so ids of
    /// different types never compare equal.
    Singleton(&'static str, TypeId),
}
impl RoleId {
    /// Builds a [`RoleId::Singleton`] for the role type `R`.
    ///
    /// The value of `_role` is never read; it only fixes `R` by inference. The id
    /// records `std::any::type_name::<R>()` together with `TypeId::of::<R>()`.
    pub fn from_singleton<R>(_role: &R) -> RoleId
    where
        R: Role + Default,
    {
        let name = std::any::type_name::<R>();
        let type_id = TypeId::of::<R>();
        Self::Singleton(name, type_id)
    }
}
/// Routes an incoming [`Dispatch`] to `handle_dispatch`, applying the framing implied
/// by `counterpart.remote_style(peer)`.
///
/// * Responses are not unwrapped by style. They are handled only when the response
///   router's role id equals `peer.role_id()`, and are returned unhandled otherwise.
/// * [`RemoteStyle::Counterpart`]: the dispatch is handled unchanged.
/// * [`RemoteStyle::Predecessor`]: `METHOD_SUCCESSOR_MESSAGE` envelopes are returned
///   unhandled; anything else is handled unchanged.
/// * [`RemoteStyle::Successor`]: only `METHOD_SUCCESSOR_MESSAGE` envelopes are accepted.
///   The inner message is unwrapped and handled. If the handler declines it, the message
///   is re-wrapped, keeping the original `meta`, so the caller gets the envelope back.
///
/// Every `Handled::No` constructed directly here uses `retry: false`. A `Handled::No`
/// produced by `handle_dispatch` keeps the handler's own `retry` value.
///
/// # Errors
///
/// Propagates errors from `handle_dispatch`, from decoding the envelope params with
/// `json_cast`, and from `try_map_message` when unwrapping or re-wrapping.
pub(crate) async fn handle_incoming_dispatch<Counterpart, Peer>(
    counterpart: Counterpart,
    peer: Peer,
    dispatch: Dispatch,
    connection: ConnectionTo<Counterpart>,
    handle_dispatch: impl AsyncFnOnce(
        Dispatch,
        ConnectionTo<Counterpart>,
    ) -> Result<Handled<Dispatch>, crate::Error>,
) -> Result<Handled<Dispatch>, crate::Error>
where
    Counterpart: Role + HasPeer<Peer>,
    Peer: Role,
{
    tracing::trace!(
        method = %dispatch.method(),
        ?counterpart,
        ?peer,
        ?dispatch,
        "handle_incoming_dispatch: enter"
    );

    // Responses are matched on the router's role id rather than framed by style.
    if let Dispatch::Response(_, router) = &dispatch {
        tracing::trace!(
            response_role_id = ?router.role_id(),
            peer_role_id = ?peer.role_id(),
            "handle_incoming_dispatch: response"
        );
        if router.role_id() == peer.role_id() {
            return handle_dispatch(dispatch, connection).await;
        }
        return Ok(Handled::No {
            message: dispatch,
            retry: false,
        });
    }

    let method = dispatch.method();
    match counterpart.remote_style(peer) {
        RemoteStyle::Counterpart => {
            tracing::trace!("handle_incoming_dispatch: Counterpart style, passing through");
            handle_dispatch(dispatch, connection).await
        }
        RemoteStyle::Predecessor => {
            // Successor envelopes are not unwrapped on a predecessor link; decline them.
            // Trace each outcome separately so the log reflects what actually happened.
            if method == METHOD_SUCCESSOR_MESSAGE {
                tracing::trace!(
                    method,
                    "handle_incoming_dispatch: Predecessor style received a successor envelope, returning Handled::No"
                );
                Ok(Handled::No {
                    message: dispatch,
                    retry: false,
                })
            } else {
                tracing::trace!("handle_incoming_dispatch: Predecessor style, passing through");
                handle_dispatch(dispatch, connection).await
            }
        }
        RemoteStyle::Successor => {
            if method != METHOD_SUCCESSOR_MESSAGE {
                tracing::trace!(
                    method,
                    expected = METHOD_SUCCESSOR_MESSAGE,
                    "handle_incoming_dispatch: Successor style but method doesn't match, returning Handled::No"
                );
                return Ok(Handled::No {
                    message: dispatch,
                    retry: false,
                });
            }
            tracing::trace!(
                "handle_incoming_dispatch: Successor style, unwrapping SuccessorMessage"
            );
            // Responses returned early above, so this error path is defensive only.
            let untyped_message = dispatch.message().ok_or_else(|| {
                crate::util::internal_error(
                    "Response variant cannot be unwrapped as SuccessorMessage",
                )
            })?;
            let SuccessorMessage { message, meta } = json_cast(untyped_message.params())?;
            let successor_dispatch = dispatch.try_map_message(|_| Ok(message))?;
            tracing::trace!(
                unwrapped_method = %successor_dispatch.method(),
                "handle_incoming_dispatch: unwrapped to inner message"
            );
            match handle_dispatch(successor_dispatch, connection).await? {
                Handled::Yes => {
                    tracing::trace!(
                        "handle_incoming_dispatch: inner handler returned Handled::Yes"
                    );
                    Ok(Handled::Yes)
                }
                Handled::No {
                    message: successor_dispatch,
                    retry,
                } => {
                    tracing::trace!(
                        "handle_incoming_dispatch: inner handler returned Handled::No, re-wrapping"
                    );
                    // Restore the envelope, with its original `meta`, so later handlers
                    // see the message in the same form it arrived in.
                    Ok(Handled::No {
                        message: successor_dispatch.try_map_message(|message| {
                            SuccessorMessage { message, meta }.to_untyped_message()
                        })?,
                        retry,
                    })
                }
            }
        }
    }
}
/// A generic role with no protocol-specific behavior.
///
/// It is its own counterpart, uses [`RemoteStyle::Counterpart`] framing with its peer,
/// and declines every dispatch in its default handler.
#[derive(
    Copy, Clone, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
pub struct UntypedRole;
impl UntypedRole {
pub fn builder(self) -> Builder<Self> {
Builder::new(self)
}
}
impl Role for UntypedRole {
type Counterpart = UntypedRole;
fn role_id(&self) -> RoleId {
RoleId::from_singleton(self)
}
async fn default_handle_dispatch_from(
&self,
message: Dispatch,
_connection: ConnectionTo<Self>,
) -> Result<Handled<Dispatch>, crate::Error> {
Ok(Handled::No {
message,
retry: false,
})
}
fn counterpart(&self) -> Self::Counterpart {
*self
}
}
impl HasPeer<UntypedRole> for UntypedRole {
    /// Untyped roles always pass messages through unchanged, whatever the peer value.
    fn remote_style(&self, _: UntypedRole) -> RemoteStyle {
        RemoteStyle::Counterpart
    }
}