use crate::{ElicitCommunicator, ElicitResult, ElicitationContext, StyleContext, StyleMarker};
use std::time::Instant;
use tokio::sync::{mpsc, watch};
use tracing::instrument;
/// A party that can appear in an observed elicitation conversation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Participant {
    /// The party that issues prompts (prompts mirrored to chat are attributed to it).
    Host,
    /// A human participant.
    Human,
    /// An agent, optionally labelled with a model name (rendered as `Agent(<model>)`).
    Agent(Option<String>),
    /// Any other participant; rendered verbatim as its name.
    Custom(String),
}

impl std::fmt::Display for Participant {
    /// Writes the participant label: the variant name for `Host`/`Human`,
    /// `Agent` or `Agent(<model>)` for agents, and the raw name for `Custom`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Agent(Some(model)) => return write!(f, "Agent({model})"),
            Self::Agent(None) => "Agent",
            Self::Host => "Host",
            Self::Human => "Human",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
/// Destination for mirrored chat traffic.
///
/// Elements, in order:
/// 1. the channel mirrored messages are pushed into;
/// 2. the participant credited with successful responses;
/// 3. a factory that turns `(sender, text)` into the channel's message type `M`.
type ChatSink<M> = (
    mpsc::UnboundedSender<M>,
    Participant,
    std::sync::Arc<dyn Fn(Participant, String) -> M + Sync + Send>,
);
/// A single message in an observed conversation.
#[derive(Clone, Debug)]
pub struct ChatMessage {
    /// Who produced the message.
    pub sender: Participant,
    /// The message text.
    pub content: String,
    /// Moment the message was built (`Instant::now()` inside [`ChatMessage::new`]).
    pub timestamp: Instant,
}

impl ChatMessage {
    /// Builds a message from `sender` and `content`, stamped with the current instant.
    pub fn new(sender: Participant, content: impl Into<String>) -> Self {
        let content: String = content.into();
        let timestamp = Instant::now();
        Self {
            sender,
            content,
            timestamp,
        }
    }
}
/// Decorates an inner communicator `C` so that in-flight prompts can be observed.
///
/// The prompt currently being serviced is published on a `watch` channel, and
/// prompts/responses can optionally be mirrored into a chat channel carrying
/// messages of type `M` (see [`ObservableCommunicator::with_chat`]).
pub struct ObservableCommunicator<C, M = ChatMessage> {
    /// The wrapped communicator that actually services requests.
    inner: C,
    /// Set to `Some(prompt)` while a prompt is in flight and back to `None` afterwards.
    prompt_tx: watch::Sender<Option<String>>,
    /// Optional chat mirror; `None` unless configured via `with_chat`.
    chat: Option<ChatSink<M>>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also require
// `M: Clone`, yet every field here is cloneable for any message type `M`.
impl<C, M> Clone for ObservableCommunicator<C, M>
where
    C: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            inner,
            prompt_tx,
            chat,
        } = self;
        Self {
            inner: inner.clone(),
            prompt_tx: prompt_tx.clone(),
            chat: chat.clone(),
        }
    }
}
impl<C, M> ObservableCommunicator<C, M> {
    /// Wraps `inner`, publishing in-flight prompts on `prompt_tx`.
    ///
    /// No chat mirroring is configured; chain [`Self::with_chat`] to enable it.
    pub fn new(inner: C, prompt_tx: watch::Sender<Option<String>>) -> Self {
        Self {
            chat: None,
            inner,
            prompt_tx,
        }
    }

    /// Enables mirroring of prompts and responses into `chat_tx`.
    ///
    /// Prompts are attributed to [`Participant::Host`]. Successful responses
    /// are attributed to `responder`. `make_message` converts a sender and a
    /// text into the channel's message type `M`. Any previously configured
    /// chat sink is replaced.
    pub fn with_chat(
        mut self,
        chat_tx: mpsc::UnboundedSender<M>,
        responder: Participant,
        make_message: impl Fn(Participant, String) -> M + Send + Sync + 'static,
    ) -> Self {
        let factory: std::sync::Arc<dyn Fn(Participant, String) -> M + Send + Sync> =
            std::sync::Arc::new(make_message);
        self.chat = Some((chat_tx, responder, factory));
        self
    }
}
impl<C: ElicitCommunicator, M: Send + 'static> ElicitCommunicator for ObservableCommunicator<C, M> {
#[instrument(skip(self), level = "debug", fields(prompt_len = prompt.len()))]
fn send_prompt(
&self,
prompt: &str,
) -> impl std::future::Future<Output = ElicitResult<String>> + Send {
let prompt_owned = prompt.to_string();
let watch_tx = self.prompt_tx.clone();
let chat = self.chat.clone();
let inner_future = self.inner.send_prompt(prompt);
async move {
watch_tx.send(Some(prompt_owned.clone())).ok();
if let Some((ref tx, _, ref make)) = chat {
tx.send(make(Participant::Host, prompt_owned)).ok();
}
let result = inner_future.await;
watch_tx.send(None).ok();
if let Some((ref tx, ref responder, ref make)) = chat
&& let Ok(ref response) = result
{
tx.send(make(responder.clone(), response.clone())).ok();
}
result
}
}
#[instrument(skip(self, params), level = "debug", fields(tool = %params.name))]
fn call_tool(
&self,
params: rmcp::model::CallToolRequestParams,
) -> impl std::future::Future<
Output = Result<rmcp::model::CallToolResult, rmcp::service::ServiceError>,
> + Send {
self.inner.call_tool(params)
}
fn style_context(&self) -> &StyleContext {
self.inner.style_context()
}
fn elicitation_context(&self) -> &ElicitationContext {
self.inner.elicitation_context()
}
fn with_style<T: 'static, S: StyleMarker + crate::style::ElicitationStyle + 'static>(
&self,
style: S,
) -> Self {
Self {
inner: self.inner.with_style::<T, S>(style),
prompt_tx: self.prompt_tx.clone(),
chat: self.chat.clone(),
}
}
}