use polaris_core_plugins::{IOError, IOMessage, IOProvider};
use tokio::sync::mpsc;
/// An [`IOProvider`] whose input and output are in-process tokio channels.
///
/// The provider holds only the consuming end of the input channel and the
/// producing end of the output channel. The opposite ends are handed back by
/// [`HttpIOProvider::new`], so an external driver can feed messages in and read
/// responses out. Presumably that driver is an HTTP handler, given the type name;
/// it is not shown in this file.
#[derive(Debug)]
pub struct HttpIOProvider {
    // `receive` only gets `&self`, but `Receiver::recv` needs `&mut self`. A tokio
    // mutex (not `std::sync::Mutex`) is used because the guard is held across the
    // `recv().await`.
    input_rx: tokio::sync::Mutex<mpsc::Receiver<IOMessage>>,
    // Unbounded: `send` never waits for capacity, so whoever holds the returned
    // receiver is responsible for draining it.
    output_tx: mpsc::UnboundedSender<IOMessage>,
}

impl HttpIOProvider {
    /// Creates a provider together with the channel ends used to drive it.
    ///
    /// Returns `(provider, input_tx, output_rx)`:
    /// - messages sent on `input_tx` are yielded by [`IOProvider::receive`];
    /// - messages passed to [`IOProvider::send`] arrive on `output_rx`.
    ///
    /// `input_buffer` is the capacity of the bounded input channel. Tokio's
    /// bounded channel panics on a capacity of zero, so `0` is raised to `1`
    /// (a single-slot buffer) instead of panicking.
    #[must_use]
    pub fn new(
        input_buffer: usize,
    ) -> (
        Self,
        mpsc::Sender<IOMessage>,
        mpsc::UnboundedReceiver<IOMessage>,
    ) {
        // `mpsc::channel` asserts `buffer > 0`; clamp so a zero argument is not a panic.
        let (input_tx, input_rx) = mpsc::channel(input_buffer.max(1));
        let (output_tx, output_rx) = mpsc::unbounded_channel();
        let provider = Self {
            input_rx: tokio::sync::Mutex::new(input_rx),
            output_tx,
        };
        (provider, input_tx, output_rx)
    }
}
impl IOProvider for HttpIOProvider {
    /// Pushes `message` onto the output channel returned by [`HttpIOProvider::new`].
    ///
    /// # Errors
    ///
    /// Returns [`IOError::Closed`] if the output receiver has been dropped.
    async fn send(&self, message: IOMessage) -> Result<(), IOError> {
        if self.output_tx.send(message).is_err() {
            return Err(IOError::Closed);
        }
        Ok(())
    }

    /// Waits for the next message pushed through the input sender.
    ///
    /// The receiver lock is held for the whole wait, so a concurrent caller waits
    /// on the lock until the current one returns.
    ///
    /// # Errors
    ///
    /// Returns [`IOError::Closed`] once every input sender has been dropped and no
    /// buffered messages remain.
    async fn receive(&self) -> Result<IOMessage, IOError> {
        let mut receiver = self.input_rx.lock().await;
        match receiver.recv().await {
            Some(message) => Ok(message),
            None => Err(IOError::Closed),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use polaris_core_plugins::IOContent;

    /// Reports whether `content` is a text payload equal to `expected`.
    fn is_text(content: &IOContent, expected: &str) -> bool {
        matches!(content, IOContent::Text(t) if t == expected)
    }

    #[tokio::test]
    async fn send_and_receive() {
        let (provider, input_tx, mut output_rx) = HttpIOProvider::new(8);

        // Inbound direction: input sender -> provider.receive().
        input_tx.send(IOMessage::user_text("hello")).await.unwrap();
        let inbound = provider.receive().await.unwrap();
        assert!(is_text(&inbound.content, "hello"));

        // Outbound direction: provider.send() -> output receiver.
        provider
            .send(IOMessage::system_text("response"))
            .await
            .unwrap();
        let outbound = output_rx.recv().await.unwrap();
        assert!(is_text(&outbound.content, "response"));
    }

    #[tokio::test]
    async fn receive_returns_closed_when_sender_dropped() {
        let (provider, input_tx, _output_rx) = HttpIOProvider::new(8);
        drop(input_tx);
        let outcome = provider.receive().await;
        assert!(matches!(outcome, Err(IOError::Closed)));
    }

    #[tokio::test]
    async fn send_returns_closed_when_receiver_dropped() {
        let (provider, _input_tx, output_rx) = HttpIOProvider::new(8);
        drop(output_rx);
        let outcome = provider.send(IOMessage::system_text("msg")).await;
        assert!(matches!(outcome, Err(IOError::Closed)));
    }

    #[tokio::test]
    async fn multiple_messages_in_order() {
        let (provider, input_tx, mut output_rx) = HttpIOProvider::new(8);

        for text in ["a", "b"] {
            input_tx.send(IOMessage::user_text(text)).await.unwrap();
        }
        for expected in ["a", "b"] {
            let inbound = provider.receive().await.unwrap();
            assert!(is_text(&inbound.content, expected));
        }

        for text in ["x", "y"] {
            provider.send(IOMessage::system_text(text)).await.unwrap();
        }
        for expected in ["x", "y"] {
            let outbound = output_rx.recv().await.unwrap();
            assert!(is_text(&outbound.content, expected));
        }
    }
}