use std::future::Future;
use std::pin::Pin;
use tokio::sync::mpsc;
use crate::controller::ControllerInputPayload;
/// An asynchronous, pull-based supplier of [`ControllerInputPayload`]s.
///
/// Implementors must be `Send + 'static`. Presumably this lets them be moved into a spawned
/// task that drives the controller; TODO confirm against callers.
pub trait InputSource: Send + 'static {
    /// Waits for the next payload from this source.
    ///
    /// `None` is expected to mean the source is exhausted and no further input will arrive.
    /// For example, the channel-backed source in this module returns `None` once its senders
    /// are gone.
    ///
    /// The method returns a boxed future rather than being an `async fn`, which keeps the
    /// trait usable as `dyn InputSource`. The future mutably borrows `self` while it is alive.
    fn recv(
        &mut self,
    ) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>>;
}
/// An [`InputSource`] backed by a bounded `tokio::sync::mpsc` channel.
///
/// Payloads pushed through the paired [`mpsc::Sender`] are yielded by
/// [`InputSource::recv`] in the order they were sent.
pub struct ChannelInputSource {
    // Receiving half of the channel; `InputSource::recv` pulls payloads from it.
    rx: mpsc::Receiver<ControllerInputPayload>,
}
impl ChannelInputSource {
    /// Wraps an existing channel receiver as an [`InputSource`].
    pub fn new(rx: mpsc::Receiver<ControllerInputPayload>) -> Self {
        Self { rx }
    }

    /// Creates a bounded channel that can queue up to `buffer` payloads.
    ///
    /// Returns the sending half together with a `ChannelInputSource` that reads from it.
    /// Once every clone of the returned sender has been dropped, `recv` still yields any
    /// payloads left in the buffer, and only then resolves to `None`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is `0`, because `tokio::sync::mpsc::channel` requires a
    /// non-zero capacity.
    // Ignoring the result drops both halves, which closes the channel immediately.
    #[must_use = "dropping the returned pair closes the channel immediately"]
    pub fn channel(buffer: usize) -> (mpsc::Sender<ControllerInputPayload>, Self) {
        let (tx, rx) = mpsc::channel(buffer);
        (tx, Self::new(rx))
    }
}
impl InputSource for ChannelInputSource {
    /// Receives the next payload from the underlying channel.
    ///
    /// Resolves to `None` once all senders have been dropped and no buffered payloads
    /// remain. Per tokio's documentation, `Receiver::recv` is cancel-safe, so dropping
    /// this future before it completes does not lose a payload.
    fn recv(
        &mut self,
    ) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
        // The receiver's own future already has the right output type and borrow, so it
        // can be pinned directly without wrapping it in another `async` block.
        let next_payload = self.rx.recv();
        Box::pin(next_payload)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    /// A single sent payload is received intact.
    #[tokio::test]
    async fn test_channel_input_source_recv() {
        let (tx, mut source) = ChannelInputSource::channel(10);
        let payload = ControllerInputPayload::data(1, "hello", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();
        let received = source.recv().await.unwrap();
        assert_eq!(received.session_id, 1);
        assert_eq!(received.content, "hello");
    }

    /// With every sender dropped and nothing buffered, `recv` yields `None`.
    #[tokio::test]
    async fn test_channel_input_source_closed() {
        let (tx, mut source) = ChannelInputSource::channel(10);
        drop(tx);
        let received = source.recv().await;
        assert!(received.is_none());
    }

    /// Payloads come out in the order they were sent.
    #[tokio::test]
    async fn test_channel_input_source_multiple() {
        let (tx, mut source) = ChannelInputSource::channel(10);
        for i in 0..3_i64 {
            let payload =
                ControllerInputPayload::data(1, format!("msg {i}"), TurnId::new_user_turn(i));
            tx.send(payload).await.unwrap();
        }
        for i in 0..3_i64 {
            let received = source.recv().await.unwrap();
            assert_eq!(received.content, format!("msg {i}"));
        }
    }

    /// Closing the channel must not discard payloads that are already buffered.
    /// They are delivered first, and only then does `recv` yield `None`.
    #[tokio::test]
    async fn test_channel_input_source_drains_before_close() {
        let (tx, mut source) = ChannelInputSource::channel(10);
        let payload = ControllerInputPayload::data(1, "last", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();
        drop(tx);
        let received = source
            .recv()
            .await
            .expect("buffered payload must be delivered after the senders are dropped");
        assert_eq!(received.content, "last");
        assert!(source.recv().await.is_none());
    }
}