use crate::mocks::shared::{ContentWrapper, InnerWrapper};
use anyhow::{anyhow, bail};
use futures::{Sink, Stream, ready};
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::Mutex;
/// Creates a connected pair of in-memory mock streams.
///
/// The second end is built from the first with its transmit and receive buffers swapped, so
/// what one end sends through its `Sink` impl is what the other end yields from its `Stream`
/// impl, in both directions.
pub fn mock_streams<T>() -> (MockStream<T>, MockStream<T>)
where
    T: Send,
{
    let local: MockStream<T> = Default::default();
    let remote = MockStream::make_connection(&local);
    (local, remote)
}
/// One end of an in-memory, bidirectional mock connection.
///
/// Ends are created in connected pairs by `mock_streams`; the peer of an end is built with the
/// `tx` and `rx` buffers swapped.
pub struct MockStream<T>
where
    T: 'static,
{
    /// Outgoing buffer: the `Sink` impl pushes items here for the peer to receive.
    tx: InnerWrapper<VecDeque<T>>,
    /// Incoming buffer: the `Stream` impl pops items the peer has sent.
    rx: InnerWrapper<VecDeque<T>>,
}
impl<T> MockStream<T> {
pub fn clone_tx_buffer(&self) -> Arc<Mutex<ContentWrapper<VecDeque<T>>>>
where
T: Send,
{
self.tx.clone_buffer()
}
pub fn clone_rx_buffer(&self) -> Arc<Mutex<ContentWrapper<VecDeque<T>>>>
where
T: Send,
{
self.rx.clone_buffer()
}
fn make_connection(&self) -> Self
where
T: Send,
{
MockStream {
tx: self.rx.cloned_buffer(),
rx: self.tx.cloned_buffer(),
}
}
}
impl<T> Default for MockStream<T> {
fn default() -> Self {
MockStream {
tx: InnerWrapper::default(),
rx: InnerWrapper::default(),
}
}
}
impl<T> Stream for MockStream<T>
where
    T: Send,
{
    type Item = T;

    /// Yields the next item the peer has sent.
    ///
    /// If the receive buffer is empty, this registers the task's waker and returns
    /// `Poll::Pending`; the peer's flush takes and wakes that waker.
    ///
    /// The buffer lock is released before every return. If a consumer stopped polling while
    /// the lock was still held, the peer's `poll_ready` on the same buffer would never complete.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        ready!(Pin::new(&mut self.rx).poll_guard_ready(cx));
        // `poll_guard_ready` has just reported the lock as acquired, so a missing guard would be
        // a bug in the lock state machine rather than a recoverable condition.
        #[allow(clippy::unwrap_used)]
        let guard = self.rx.guard().unwrap();
        let next = guard.content.pop_front();
        if next.is_none() {
            // Register while still holding the lock so the peer's next flush can wake us.
            guard.waker = Some(cx.waker().clone());
        }
        self.rx.transition_to_idle();
        match next {
            Some(item) => Poll::Ready(Some(item)),
            None => Poll::Pending,
        }
    }

    /// Reports the number of currently buffered items as a lower bound.
    ///
    /// There is no upper bound: the peer may keep sending, so the current count is not a
    /// maximum. If the buffer lock is held elsewhere right now, `(0, None)` is returned.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let Ok(guard) = self.rx.buffer.try_lock() else {
            return (0, None);
        };
        (guard.content.len(), None)
    }
}
impl<T> Sink<T> for MockStream<T>
where
    T: Send,
{
    type Error = anyhow::Error;

    /// Completes once this end holds the lock on its transmit buffer.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(Pin::new(&mut self.tx).poll_guard_ready(cx));
        Poll::Ready(Ok(()))
    }

    /// Appends `item` to the transmit buffer. The peer is not woken until the next flush.
    ///
    /// # Errors
    ///
    /// Fails if the transmit lock is not held, i.e. `poll_ready` did not return `Ready` first.
    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let Some(guard) = self.tx.guard() else {
            bail!("invalid lock state to send messages");
        };
        guard.content.push_back(item);
        Ok(())
    }

    /// Wakes a receiver parked on the peer's stream and releases the transmit lock.
    ///
    /// Items are only pushed while the lock is held, and every flush releases it. So when no
    /// guard is held, nothing is pending, and the flush succeeds as the `Sink` contract requires
    /// (e.g. `flush()` after `send()`, or `send_all` with a pending source).
    fn poll_flush(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let Some(guard) = self.tx.guard() else {
            return Poll::Ready(Ok(()));
        };
        if let Some(waker) = guard.waker.take() {
            waker.wake();
        }
        self.tx.transition_to_idle();
        Poll::Ready(Ok(()))
    }

    /// Flushes anything still pending, waking the receiver, then returns the transmit side to idle.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // The `Sink` contract requires buffered items to be flushed before the sink is closed.
        ready!(<Self as Sink<T>>::poll_flush(self.as_mut(), cx))?;
        self.tx.transition_to_idle();
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};

    /// A message sent on either end of the pair is received on the other end.
    #[tokio::test]
    async fn basic() {
        let (mut left, mut right) = mock_streams();

        left.send("foomp").await.unwrap();
        assert_eq!(right.next().await.unwrap(), "foomp");

        right.send("bar").await.unwrap();
        assert_eq!(left.next().await.unwrap(), "bar");
    }
}