use std::future::Future;
use std::pin::Pin;
use tokio::sync::mpsc;
use crate::agent::UiMessage;
/// Error returned by [`EventSink`] methods when an event could not be delivered.
///
/// The public field carries the [`UiMessage`] that failed to send, so the caller
/// gets the undelivered event back.
#[derive(Debug)]
pub struct SendError(pub UiMessage);
impl std::fmt::Display for SendError {
    /// Writes a fixed description of the failure; the wrapped event is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to send event")
    }
}
// Empty impl: the default methods of `std::error::Error` are used as-is.
impl std::error::Error for SendError {}
/// A destination that [`UiMessage`] events can be pushed into.
///
/// The trait is usable as `Box<dyn EventSink>`; because `Clone` cannot be a
/// supertrait of a trait object, duplication goes through [`EventSink::clone_box`].
// NOTE(review): the lint is presumably silenced because `SendError` hands the whole
// `UiMessage` back to the caller, which makes the `Err` variant large — confirm.
#[allow(clippy::result_large_err)]
pub trait EventSink: Send + Sync + 'static {
    /// Delivers `event` synchronously.
    ///
    /// # Errors
    ///
    /// Returns a [`SendError`] when the implementation could not deliver the event.
    fn send(&self, event: UiMessage) -> Result<(), SendError>;

    /// Delivers `event` through a boxed future that borrows `self`.
    ///
    /// The default implementation simply runs [`EventSink::send`] when the future
    /// is polled; implementors may override it with a genuinely asynchronous send.
    ///
    /// # Errors
    ///
    /// Resolves to a [`SendError`] when the event could not be delivered.
    fn send_async(
        &self,
        event: UiMessage,
    ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
        let deferred = async move { self.send(event) };
        Box::pin(deferred)
    }

    /// Returns a duplicate of this sink behind a fresh `Box<dyn EventSink>`.
    fn clone_box(&self) -> Box<dyn EventSink>;
}
/// Lets a boxed sink be used wherever an `EventSink` is expected by forwarding
/// every method to the sink stored inside the box.
impl EventSink for Box<dyn EventSink> {
    /// Forwards to the boxed sink's `send`.
    fn send(&self, event: UiMessage) -> Result<(), SendError> {
        let inner: &dyn EventSink = &**self;
        inner.send(event)
    }

    /// Forwards to the boxed sink's `send_async`, so an overridden asynchronous
    /// implementation is used rather than the trait default.
    fn send_async(
        &self,
        event: UiMessage,
    ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
        let inner: &dyn EventSink = &**self;
        inner.send_async(event)
    }

    /// Forwards to the boxed sink's `clone_box`.
    fn clone_box(&self) -> Box<dyn EventSink> {
        let inner: &dyn EventSink = &**self;
        inner.clone_box()
    }
}
/// [`EventSink`] that forwards every event into a Tokio `mpsc` channel.
#[derive(Clone)]
pub struct ChannelEventSink {
/// Sending half of the channel that events are pushed into.
tx: mpsc::Sender<UiMessage>,
}
impl ChannelEventSink {
pub fn new(tx: mpsc::Sender<UiMessage>) -> Self {
Self { tx }
}
}
impl EventSink for ChannelEventSink {
    /// Attempts to enqueue `event` without waiting, using `Sender::try_send`.
    ///
    /// # Errors
    ///
    /// Returns a [`SendError`] holding the undelivered event when `try_send`
    /// fails (for example because the channel has no free capacity).
    fn send(&self, event: UiMessage) -> Result<(), SendError> {
        self.tx
            .try_send(event)
            .map_err(|err| SendError(err.into_inner()))
    }

    /// Enqueues `event` with `Sender::send`, awaiting the channel instead of
    /// failing immediately.
    ///
    /// # Errors
    ///
    /// Resolves to a [`SendError`] holding the undelivered event when the
    /// underlying `send` fails.
    fn send_async(
        &self,
        event: UiMessage,
    ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
        // The returned future is already tied to `&self` by the `'_` bound, so the
        // sender is borrowed in place rather than cloned for every event.
        Box::pin(async move { self.tx.send(event).await.map_err(|err| SendError(err.0)) })
    }

    /// Clones the sink (and with it the channel sender) into a new box.
    fn clone_box(&self) -> Box<dyn EventSink> {
        Box::new(self.clone())
    }
}
/// Stateless [`EventSink`] that prints a subset of events to stdout/stderr and
/// silently ignores the rest.
#[derive(Clone, Default)]
pub struct SimpleEventSink;
impl SimpleEventSink {
pub fn new() -> Self {
Self
}
}
impl EventSink for SimpleEventSink {
fn send(&self, event: UiMessage) -> Result<(), SendError> {
use std::io::Write;
match &event {
UiMessage::TextChunk { text, .. } => {
print!("{}", text);
std::io::stdout().flush().ok();
}
UiMessage::Error { error, .. } => {
eprintln!("Error: {}", error);
}
UiMessage::Complete { .. } => {
println!();
}
UiMessage::ToolExecuting { display_name, .. } => {
println!("[Tool: {}]", display_name);
}
UiMessage::ToolCompleted {
error: Some(err), ..
} => {
eprintln!("[Tool error: {}]", err);
}
UiMessage::PermissionRequired { .. } => {
eprintln!(
"Warning: SimpleEventSink received permission request. Use AutoApprovePolicy to handle permissions automatically."
);
}
UiMessage::BatchPermissionRequired { .. } => {
eprintln!(
"Warning: SimpleEventSink received batch permission request. Use AutoApprovePolicy to handle permissions automatically."
);
}
UiMessage::UserInteractionRequired { .. } => {
eprintln!(
"Warning: SimpleEventSink received user interaction request. Use AutoApprovePolicy to auto-cancel interactions."
);
}
_ => {
}
}
Ok(())
}
fn clone_box(&self) -> Box<dyn EventSink> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// An event pushed with the synchronous `send` arrives unchanged at the receiver.
    #[tokio::test]
    async fn test_channel_event_sink_send() {
        let (tx, mut rx) = mpsc::channel(10);
        let sink = ChannelEventSink::new(tx);

        sink.send(UiMessage::System {
            session_id: 1,
            message: "test".to_string(),
        })
        .unwrap();

        let delivered = rx.recv().await.unwrap();
        if let UiMessage::System {
            session_id,
            message,
        } = delivered
        {
            assert_eq!(session_id, 1);
            assert_eq!(message, "test");
        } else {
            panic!("unexpected message type");
        }
    }

    /// An event pushed with `send_async` arrives unchanged at the receiver.
    #[tokio::test]
    async fn test_channel_event_sink_send_async() {
        let (tx, mut rx) = mpsc::channel(10);
        let sink = ChannelEventSink::new(tx);

        sink.send_async(UiMessage::System {
            session_id: 2,
            message: "async test".to_string(),
        })
        .await
        .unwrap();

        let delivered = rx.recv().await.unwrap();
        if let UiMessage::System {
            session_id,
            message,
        } = delivered
        {
            assert_eq!(session_id, 2);
            assert_eq!(message, "async test");
        } else {
            panic!("unexpected message type");
        }
    }

    /// With capacity 1 and nobody receiving, the second synchronous `send` must fail.
    #[test]
    fn test_channel_event_sink_full_channel() {
        // `_rx` is kept alive so the failure comes from the full buffer, not a closed channel.
        let (tx, _rx) = mpsc::channel(1);
        let sink = ChannelEventSink::new(tx);

        sink.send(UiMessage::System {
            session_id: 1,
            message: "first".to_string(),
        })
        .unwrap();

        let overflow = sink.send(UiMessage::System {
            session_id: 1,
            message: "second".to_string(),
        });
        assert!(overflow.is_err());
    }

    /// `SimpleEventSink` accepts text, completion and error events without failing.
    #[test]
    fn test_simple_event_sink_send() {
        let sink = SimpleEventSink::new();
        let events = vec![
            UiMessage::TextChunk {
                session_id: 1,
                turn_id: None,
                text: "hello".to_string(),
                input_tokens: 0,
                output_tokens: 0,
            },
            UiMessage::Complete {
                session_id: 1,
                turn_id: None,
                input_tokens: 10,
                output_tokens: 20,
                stop_reason: None,
            },
            UiMessage::Error {
                session_id: 1,
                turn_id: None,
                error: "test error".to_string(),
            },
        ];

        events
            .into_iter()
            .for_each(|event| assert!(sink.send(event).is_ok()));
    }

    /// A sink behind `Box<dyn EventSink>` can still send and be cloned.
    #[test]
    fn test_boxed_event_sink() {
        let (tx, _rx) = mpsc::channel(10);
        let sink: Box<dyn EventSink> = Box::new(ChannelEventSink::new(tx));

        let outcome = sink.send(UiMessage::System {
            session_id: 1,
            message: "boxed test".to_string(),
        });
        assert!(outcome.is_ok());

        let _cloned = sink.clone_box();
    }
}