use anyhow::Result;
use tokio::sync::mpsc;
use crate::agent::streaming::{StreamConsumer, StreamEvent};
/// A [`StreamConsumer`] that forwards every event it receives into a tokio
/// unbounded mpsc channel.
///
/// Cloning a `ChannelConsumer` clones the underlying sender, so all clones
/// feed the same channel.
#[derive(Debug, Clone)]
pub struct ChannelConsumer {
    /// Sending half of the channel that events are forwarded into.
    tx: mpsc::UnboundedSender<StreamEvent>,
}
impl ChannelConsumer {
pub fn new(tx: mpsc::UnboundedSender<StreamEvent>) -> Self {
Self { tx }
}
}
impl StreamConsumer for ChannelConsumer {
    /// Sends a clone of `event` down the channel.
    ///
    /// Delivery is best-effort. If the receiving half has already been
    /// dropped, the send fails and the event is silently discarded. This
    /// method still returns `Ok(())`, so a departed listener never causes
    /// the caller to see an error.
    fn on_event(&self, event: &StreamEvent) -> Result<()> {
        // The only failure mode of an unbounded send is a closed receiver;
        // converting to `Option` discards that error on purpose.
        self.tx.send(event.clone()).ok();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single event is forwarded unchanged.
    #[tokio::test]
    async fn test_channel_consumer() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let consumer = ChannelConsumer::new(tx);
        let event = StreamEvent::Content {
            content: "test".to_string(),
        };
        consumer.on_event(&event).unwrap();

        let received = rx.recv().await.unwrap();
        match received {
            StreamEvent::Content { content } => assert_eq!(content, "test"),
            _ => panic!("Wrong event type"),
        }
    }

    /// Several events arrive in send order with their payloads intact, and
    /// nothing extra is sent.
    #[tokio::test]
    async fn test_channel_consumer_multiple_events() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let consumer = ChannelConsumer::new(tx);
        consumer
            .on_event(&StreamEvent::SessionStarted {
                session_id: "test".to_string(),
            })
            .unwrap();
        consumer
            .on_event(&StreamEvent::Content {
                content: "hello".to_string(),
            })
            .unwrap();
        consumer.on_event(&StreamEvent::Done).unwrap();

        // Check order AND payload, not merely that something was received.
        assert!(matches!(
            rx.recv().await,
            Some(StreamEvent::SessionStarted { session_id }) if session_id == "test"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamEvent::Content { content }) if content == "hello"
        ));
        assert!(matches!(rx.recv().await, Some(StreamEvent::Done)));

        // The consumer holds the only sender. Once it is dropped, an empty
        // channel yields `None`, which proves no unexpected extra events
        // were queued.
        drop(consumer);
        assert!(rx.recv().await.is_none());
    }

    /// Sending after the receiver is gone is swallowed rather than reported.
    #[test]
    fn test_channel_consumer_closed_receiver() {
        let (tx, rx) = mpsc::unbounded_channel();
        let consumer = ChannelConsumer::new(tx);
        drop(rx);
        let result = consumer.on_event(&StreamEvent::Done);
        assert!(result.is_ok());
    }
}