use std::collections::VecDeque;
use std::io;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::Notify;
use super::io::{AsyncIo, AsyncReconnect, AsyncStream};
use crate::errors::Error;
/// Mutable state shared by every clone of a `MemoryStream`, guarded by a mutex.
///
/// `Default` yields empty buffers, `closed == false` and zero pending
/// reconnect failures.
#[derive(Default)]
struct Inner {
/// Queued inbound message bodies: `push_inbound` appends to the back and
/// `read_message` pops from the front.
inbound: VecDeque<Vec<u8>>,
/// Every byte passed to `write_all`, in call order; `captured` returns a copy.
outbound: Vec<u8>,
/// Set by `close`. Once true, `read_message` fails with `UnexpectedEof`
/// as soon as `inbound` is empty.
closed: bool,
/// Number of upcoming `reconnect` calls that should fail; each failing
/// call decrements it by one.
reconnect_failures: usize,
}
/// In-memory stream that implements `AsyncIo`, `AsyncReconnect` and
/// `AsyncStream` without touching a real socket.
///
/// Both fields are `Arc`s, so cloning produces another handle to the same
/// buffers and the same notifier rather than an independent stream.
// NOTE(review): presumably a test double (see the no-op `sleep` and the
// simulated reconnect failures) — confirm against callers.
#[derive(Clone, Default)]
pub(crate) struct MemoryStream {
/// Shared buffers and flags; locked briefly by every operation.
inner: Arc<Mutex<Inner>>,
/// Signalled by `push_inbound` and `close`; awaited by `read_message`.
notify: Arc<Notify>,
}
impl MemoryStream {
    /// Appends `body` to the back of the inbound queue and signals one reader.
    ///
    /// The mutex is released before the notification is sent.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn push_inbound(&self, body: Vec<u8>) {
        {
            let mut state = self.inner.lock().unwrap();
            state.inbound.push_back(body);
        }
        self.notify.notify_one();
    }

    /// Returns a copy of all bytes currently held in the outbound buffer.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn captured(&self) -> Vec<u8> {
        let state = self.inner.lock().unwrap();
        state.outbound.clone()
    }

    /// Marks the stream as closed and wakes every reader that is currently
    /// waiting.
    ///
    /// Messages already queued stay in the inbound queue; `closed` is only
    /// a flag.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn close(&self) {
        {
            let mut state = self.inner.lock().unwrap();
            state.closed = true;
        }
        self.notify.notify_waiters();
    }

    /// Sets how many upcoming `reconnect` calls should fail, overwriting any
    /// previous value.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn set_reconnect_failures(&self, count: usize) {
        let mut state = self.inner.lock().unwrap();
        state.reconnect_failures = count;
    }
}
impl std::fmt::Debug for MemoryStream {
    /// Writes an opaque `MemoryStream { .. }` placeholder.
    ///
    /// No fields are listed and the state mutex is never locked, so
    /// formatting neither blocks nor prints buffered bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("MemoryStream");
        opaque.finish_non_exhaustive()
    }
}
#[async_trait]
impl AsyncIo for MemoryStream {
/// Returns the next queued inbound message, waiting until one is pushed
/// or the stream is closed.
///
/// Queued messages take priority over the closed flag: messages that are
/// still in the queue when `close` is called are delivered first.
///
/// # Errors
/// Returns `Error::Io` wrapping an `UnexpectedEof` error once the queue is
/// empty and the stream has been closed.
///
/// # Panics
/// Panics if the state mutex is poisoned.
async fn read_message(&self) -> Result<Vec<u8>, Error> {
loop {
// Create and enable the `Notified` future BEFORE inspecting the state,
// so a `push_inbound`/`close` that lands between the check below and
// the `.await` is not lost. Do not reorder these statements.
let notified = self.notify.notified();
tokio::pin!(notified);
notified.as_mut().enable();
{
// Scoped so the std mutex guard is dropped before awaiting.
let mut inner = self.inner.lock().unwrap();
if let Some(body) = inner.inbound.pop_front() {
return Ok(body);
}
if inner.closed {
return Err(Error::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "MemoryStream closed")));
}
}
// Nothing available yet: wait for a signal, then re-check from the top.
notified.await;
}
}
/// Appends `buf` to the outbound capture buffer and always returns `Ok(())`.
///
/// The `closed` flag is not consulted, so writes are still recorded after
/// `close`.
///
/// # Panics
/// Panics if the state mutex is poisoned.
async fn write_all(&self, buf: &[u8]) -> Result<(), Error> {
self.inner.lock().unwrap().outbound.extend_from_slice(buf);
Ok(())
}
}
#[async_trait]
impl AsyncReconnect for MemoryStream {
    /// Simulates a reconnect attempt.
    ///
    /// While the configured failure budget is non-zero, one unit is consumed
    /// and the call fails; once it reaches zero every call succeeds.
    ///
    /// # Errors
    /// Returns `Error::Simple` for each call that consumes a unit of the
    /// failure budget.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    // NOTE(review): a successful reconnect does not reset `closed` — confirm
    // that this is intended.
    async fn reconnect(&self) -> Result<(), Error> {
        // Decide under the lock, then release it before building the result.
        let simulate_failure = {
            let mut state = self.inner.lock().unwrap();
            match state.reconnect_failures.checked_sub(1) {
                Some(remaining) => {
                    state.reconnect_failures = remaining;
                    true
                }
                None => false,
            }
        };

        if !simulate_failure {
            return Ok(());
        }
        Err(Error::Simple("simulated reconnect failure".into()))
    }

    /// Returns immediately; the requested duration is ignored.
    async fn sleep(&self, _duration: Duration) {}
}
// Empty impl: `MemoryStream` supplies no `AsyncStream` items of its own here.
// NOTE(review): presumably `AsyncStream` just combines `AsyncIo` and
// `AsyncReconnect` — confirm in `super::io`.
impl AsyncStream for MemoryStream {}
// Unit tests are kept out of line in `memory_tests.rs` and are compiled only
// for test builds.
#[cfg(test)]
#[path = "memory_tests.rs"]
mod tests;