use std::{fmt::Debug, hash::Hash};
use agent_client_protocol_schema::{NewSessionRequest, NewSessionResponse, SessionId};
use crate::jsonrpc::{Builder, handlers::NullHandler, run::NullRun};
use crate::role::{HasPeer, RemoteStyle};
use crate::schema::{InitializeProxyRequest, InitializeRequest, METHOD_INITIALIZE_PROXY};
use crate::util::MatchDispatchFrom;
use crate::{ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Role, RoleId};
/// Zero-sized role marker named `Client`.
///
/// Its [`Role::Counterpart`] is [`Agent`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Client;
/// [`Role`] implementation for [`Client`]; its counterpart is [`Agent`].
impl Role for Client {
    type Counterpart = Agent;

    /// Builds the [`RoleId`] via `RoleId::from_singleton`.
    fn role_id(&self) -> RoleId {
        RoleId::from_singleton(self)
    }

    /// Returns the [`Agent`] marker.
    fn counterpart(&self) -> Self::Counterpart {
        Agent
    }

    /// Reports every dispatch as unhandled (`Handled::No`) with `retry: false`.
    ///
    /// The connection argument is not consulted.
    async fn default_handle_dispatch_from(
        &self,
        message: Dispatch,
        _cx: ConnectionTo<Client>,
    ) -> Result<Handled<Dispatch>, crate::Error> {
        let declined = Handled::No {
            message,
            retry: false,
        };
        Ok(declined)
    }
}
impl Client {
    /// Starts a [`Builder`] for the `Client` role, initially parameterised with
    /// the `NullHandler` / `NullRun` placeholder types.
    pub fn builder(self) -> Builder<Self, NullHandler, NullRun> {
        Builder::new(self)
    }

    /// Connects to `agent` and runs `main_fn` with the resulting
    /// `ConnectionTo<Agent>`, returning `main_fn`'s value.
    ///
    /// Shorthand for `self.builder().connect_with(agent, main_fn)`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Builder::connect_with` yields. Presumably this
    /// covers connection failures as well as errors returned by `main_fn`;
    /// confirm against `Builder::connect_with`.
    pub async fn connect_with<R>(
        self,
        agent: impl ConnectTo<Client>,
        main_fn: impl AsyncFnOnce(ConnectionTo<Agent>) -> Result<R, crate::Error>,
    ) -> Result<R, crate::Error> {
        let builder = self.builder();
        builder.connect_with(agent, main_fn).await
    }
}
/// `Client` reports [`RemoteStyle::Counterpart`] for a [`Client`] peer.
impl HasPeer<Client> for Client {
    fn remote_style(&self, _: Client) -> RemoteStyle {
        RemoteStyle::Counterpart
    }
}
/// Zero-sized role marker named `Agent`.
///
/// Its [`Role::Counterpart`] is [`Client`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Agent;
/// [`Role`] implementation for [`Agent`]; its counterpart is [`Client`].
impl Role for Agent {
    type Counterpart = Client;

    /// Builds the [`RoleId`] via `RoleId::from_singleton`.
    fn role_id(&self) -> RoleId {
        RoleId::from_singleton(self)
    }

    /// Returns the [`Client`] marker.
    fn counterpart(&self) -> Self::Counterpart {
        Client
    }

    /// Declines dispatches coming from the [`Agent`] peer.
    ///
    /// Such a dispatch is reported as `Handled::No`. `retry` is set exactly when
    /// the message carries a session id. Dispatches that the `Agent` arm does not
    /// match are resolved by `MatchDispatchFrom::done`.
    async fn default_handle_dispatch_from(
        &self,
        message: Dispatch,
        cx: ConnectionTo<Agent>,
    ) -> Result<Handled<Dispatch>, crate::Error> {
        let matcher = MatchDispatchFrom::new(message, &cx);
        matcher
            .if_message_from(Agent, async |dispatch: Dispatch| {
                // NOTE(review): session-scoped messages ask for a retry, presumably
                // because a handler for that session may not be registered yet.
                let retry = dispatch.has_session_id();
                Ok(Handled::No {
                    message: dispatch,
                    retry,
                })
            })
            .await
            .done()
    }
}
impl Agent {
    /// Starts a [`Builder`] for the `Agent` role, initially parameterised with
    /// the `NullHandler` / `NullRun` placeholder types.
    pub fn builder(self) -> Builder<Self, NullHandler, NullRun> {
        Builder::new(self)
    }
}
/// `Agent` reports [`RemoteStyle::Counterpart`] for an [`Agent`] peer.
impl HasPeer<Agent> for Agent {
    fn remote_style(&self, _: Agent) -> RemoteStyle {
        RemoteStyle::Counterpart
    }
}
/// Zero-sized role marker named `Proxy`.
///
/// Its [`Role::Counterpart`] is [`Conductor`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Proxy;
impl Role for Proxy {
type Counterpart = Conductor;
async fn default_handle_dispatch_from(
&self,
message: crate::Dispatch,
_connection: crate::ConnectionTo<Self>,
) -> Result<crate::Handled<crate::Dispatch>, crate::Error> {
Ok(Handled::No {
message,
retry: false,
})
}
fn role_id(&self) -> RoleId {
RoleId::from_singleton(self)
}
fn counterpart(&self) -> Self::Counterpart {
Conductor
}
}
impl Proxy {
    /// Starts a [`Builder`] for the `Proxy` role, initially parameterised with
    /// the `NullHandler` / `NullRun` placeholder types.
    pub fn builder(self) -> Builder<Self, NullHandler, NullRun> {
        Builder::new(self)
    }
}
/// `Proxy` reports [`RemoteStyle::Counterpart`] for a [`Proxy`] peer.
impl HasPeer<Proxy> for Proxy {
    fn remote_style(&self, _: Proxy) -> RemoteStyle {
        RemoteStyle::Counterpart
    }
}
/// Zero-sized role marker named `Conductor`.
///
/// Its [`Role::Counterpart`] is [`Proxy`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Conductor;
/// [`Role`] implementation for [`Conductor`]; its counterpart is [`Proxy`].
///
/// Its default dispatch handler routes traffic between the [`Client`] peer and
/// the [`Agent`] peer; see `default_handle_dispatch_from` below.
impl Role for Conductor {
type Counterpart = Proxy;
/// Builds the [`RoleId`] via `RoleId::from_singleton`.
fn role_id(&self) -> RoleId {
RoleId::from_singleton(self)
}
/// Returns the [`Proxy`] marker.
fn counterpart(&self) -> Self::Counterpart {
Proxy
}
/// Fallback routing for dispatches on a `ConnectionTo<Conductor>`.
///
/// The match arms are chained in this order:
/// 1. An `InitializeRequest` from the [`Client`] is rejected with an
///    `invalid_request` error. The error tells the caller to use
///    `METHOD_INITIALIZE_PROXY` instead.
/// 2. An `InitializeProxyRequest` from the [`Client`] is unwrapped. Its inner
///    `initialize` request is sent to the [`Agent`], and the agent's response is
///    forwarded to the original responder.
/// 3. A `NewSessionRequest` from the [`Client`] is sent to the [`Agent`]. If the
///    agent's result is `Ok`, a `ProxySessionMessages` dynamic handler is
///    registered for the returned session id. The result (success or error) is
///    then relayed through the responder.
/// 4. Any other message from the [`Client`] is proxied to the [`Agent`].
/// 5. Any other message from the [`Agent`] is proxied to the [`Client`].
///
/// Anything none of these arms matches is resolved by `MatchDispatchFrom::done`.
async fn default_handle_dispatch_from(
&self,
message: Dispatch,
cx: ConnectionTo<Conductor>,
) -> Result<Handled<Dispatch>, crate::Error> {
MatchDispatchFrom::new(message, &cx)
// A plain `initialize` is refused; proxies must use the proxy-specific method.
.if_request_from(Client, async |_req: InitializeRequest, responder| {
responder.respond_with_error(crate::Error::invalid_request().data(format!(
"proxies must be initialized with `{METHOD_INITIALIZE_PROXY}`"
)))
})
.await
// Unwrap the proxy envelope and pass the inner `initialize` on to the agent.
.if_request_from(
Client,
async |request: InitializeProxyRequest, responder| {
let InitializeProxyRequest { initialize } = request;
cx.send_request_to(Agent, initialize)
.forward_response_to(responder)
},
)
.await
.if_request_from(Client, async |request: NewSessionRequest, responder| {
cx.send_request_to(Agent, request).on_receiving_result({
let cx = cx.clone();
async move |result| {
// The session handler is registered before the response is sent, so
// routing for the new session id exists before the client gets it.
if let Ok(NewSessionResponse { session_id, .. }) = &result {
// NOTE(review): if registration fails, `?` returns early and
// `responder` is dropped without an explicit response; confirm
// that is the intended outcome.
cx.add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))?
.run_indefinitely();
}
responder.respond_with_result(result)
}
})
})
.await
// Catch-all pass-through in both directions.
.if_message_from(Client, async |message: Dispatch| {
cx.send_proxied_message_to(Agent, message)
})
.await
.if_message_from(Agent, async |message: Dispatch| {
cx.send_proxied_message_to(Client, message)
})
.await
.done()
}
}
impl Conductor {
    /// Starts a [`Builder`] for the `Conductor` role, initially parameterised
    /// with the `NullHandler` / `NullRun` placeholder types.
    pub fn builder(self) -> Builder<Self, NullHandler, NullRun> {
        Builder::new(self)
    }
}
/// `Conductor` reports [`RemoteStyle::Predecessor`] for a [`Client`] peer.
impl HasPeer<Client> for Conductor {
    fn remote_style(&self, _: Client) -> RemoteStyle {
        RemoteStyle::Predecessor
    }
}
/// `Conductor` reports [`RemoteStyle::Successor`] for an [`Agent`] peer.
impl HasPeer<Agent> for Conductor {
    fn remote_style(&self, _: Agent) -> RemoteStyle {
        RemoteStyle::Successor
    }
}
/// Dynamic handler that forwards one session's messages from the [`Agent`]
/// peer to the [`Client`] peer.
///
/// The `Conductor` role's default handler installs one of these after a
/// successful `NewSessionRequest`.
pub(crate) struct ProxySessionMessages {
    /// Session whose agent-originated messages this handler forwards.
    session_id: SessionId,
}
impl ProxySessionMessages {
pub fn new(session_id: SessionId) -> Self {
Self { session_id }
}
}
/// Session-scoped forwarding for any role that can address both an [`Agent`]
/// peer and a [`Client`] peer.
impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for ProxySessionMessages
where
    Counterpart: HasPeer<Agent> + HasPeer<Client>,
{
    /// Forwards a message from the [`Agent`] to the [`Client`] when its session
    /// id equals `self.session_id`, and reports it as `Handled::Yes`.
    ///
    /// An agent message without a session id, or with a different one, is
    /// returned as `Handled::No` with `retry: false`. Dispatches not matched by
    /// the `Agent` arm are resolved by `MatchDispatchFrom::done`.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the message's session id and from
    /// forwarding it to the client.
    async fn handle_dispatch_from(
        &mut self,
        message: Dispatch,
        connection: ConnectionTo<Counterpart>,
    ) -> Result<Handled<Dispatch>, crate::Error> {
        let matcher = MatchDispatchFrom::new(message, &connection);
        matcher
            .if_message_from(Agent, async |incoming| {
                let belongs_to_session = incoming
                    .get_session_id()?
                    .is_some_and(|id| id == self.session_id);
                if !belongs_to_session {
                    return Ok(Handled::No {
                        message: incoming,
                        retry: false,
                    });
                }
                connection.send_proxied_message_to(Client, incoming)?;
                Ok(Handled::Yes)
            })
            .await
            .done()
    }

    /// Describes this handler as `ProxySessionMessages(<session id>)`.
    fn describe_chain(&self) -> impl std::fmt::Debug {
        format!("ProxySessionMessages({})", self.session_id)
    }
}