use async_trait::async_trait;
use std::collections::VecDeque;
use std::io::{self};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::traits::Transport;
use crate::error::ConnectError;
/// In-memory [`Transport`] for tests: scripted connection failures, optional
/// connect latency, canned read data and captured writes.
///
/// All state lives behind one `Arc<Mutex<..>>`, so clones of the transport and
/// every stream it hands out observe (and mutate) the same configuration,
/// read queue and write log.
#[derive(Clone)]
pub struct MockTransport {
    // Shared with every clone and every `MockStream` produced by `connect`.
    inner: Arc<Mutex<MockTransportInner>>,
}
/// Mutable state shared by a `MockTransport`, its clones, and every
/// `MockStream` it has produced.
struct MockTransportInner {
    // ---- connection behaviour ----
    /// Number of `connect` calls (since creation or the last reset) that fail
    /// with `ConnectError::Refused` before connections start succeeding.
    fail_count: usize,
    /// How many of those scripted failures have already been reported.
    current_failures: usize,
    /// When set, every `connect` call returns a clone of this error.
    custom_error: Option<ConnectError>,
    /// Delay applied before a successful `connect` returns.
    latency: Option<Duration>,

    // ---- stream I/O ----
    /// Chunks handed out by reads, consumed front to back.
    read_data: VecDeque<Vec<u8>>,
    /// Every byte written through any stream of this transport, in order.
    write_data: Vec<u8>,
    /// Set by shutdown; afterwards reads report EOF and writes fail.
    closed: bool,
}
impl MockTransport {
#[must_use]
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(MockTransportInner {
fail_count: 0,
current_failures: 0,
latency: None,
read_data: VecDeque::new(),
write_data: Vec::new(),
closed: false,
custom_error: None,
})),
}
}
#[must_use]
pub fn fail_times(self, count: usize) -> Self {
self.inner.lock().unwrap().fail_count = count;
self
}
#[must_use]
pub fn with_latency(self, latency: Duration) -> Self {
self.inner.lock().unwrap().latency = Some(latency);
self
}
#[must_use]
pub fn with_read_data(self, data: impl Into<Vec<u8>>) -> Self {
self.inner.lock().unwrap().read_data.push_back(data.into());
self
}
#[must_use]
pub fn with_error(self, error: ConnectError) -> Self {
self.inner.lock().unwrap().custom_error = Some(error);
self
}
#[must_use]
pub fn get_written_data(&self) -> Vec<u8> {
self.inner.lock().unwrap().write_data.clone()
}
pub fn reset(&self) {
let mut inner = self.inner.lock().unwrap();
inner.current_failures = 0;
inner.write_data.clear();
inner.closed = false;
}
}
impl Default for MockTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Transport for MockTransport {
type Stream = MockStream;
async fn connect(&self, host: &str, port: u16) -> Result<Self::Stream, ConnectError> {
let latency = {
let mut inner = self.inner.lock().unwrap();
if let Some(ref error) = inner.custom_error {
return Err(error.clone());
}
if inner.current_failures < inner.fail_count {
inner.current_failures += 1;
return Err(ConnectError::Refused);
}
inner.latency
};
if let Some(lat) = latency {
tokio::time::sleep(lat).await;
}
tracing::debug!(host = %host, port = %port, "Mock connection established");
Ok(MockStream {
inner: self.inner.clone(),
})
}
fn name(&self) -> &'static str {
"mock"
}
}
/// Stream produced by `MockTransport`'s `connect`.
///
/// Reads drain the transport's shared read queue and writes append to its
/// shared write log, so every stream from the same transport sees the same data.
pub struct MockStream {
    // Same state object as the `MockTransport` that created this stream.
    inner: Arc<Mutex<MockTransportInner>>,
}
impl AsyncRead for MockStream {
    /// Copies up to `buf.remaining()` bytes from the front queued chunk.
    ///
    /// A partially consumed chunk keeps its unread tail at the front of the
    /// queue for the next read. Once the stream is shut down, or when no data is
    /// queued, nothing is copied, which callers observe as EOF. Completes
    /// immediately; never returns `Poll::Pending`.
    ///
    /// # Panics
    ///
    /// Panics if the shared state mutex is poisoned.
    // The guard must outlive the mutable borrow of the front chunk, so its
    // scope cannot be tightened the way the lint suggests.
    #[allow(clippy::significant_drop_tightening)]
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut guard = self.inner.lock().unwrap();
        let inner = &mut *guard;
        if inner.closed {
            return Poll::Ready(Ok(()));
        }
        // A zero-byte read means EOF to `AsyncRead` callers, so an empty queued
        // chunk must not be surfaced as one while real data may follow it.
        while matches!(inner.read_data.front(), Some(chunk) if chunk.is_empty()) {
            inner.read_data.pop_front();
        }
        if let Some(chunk) = inner.read_data.front_mut() {
            let len = buf.remaining().min(chunk.len());
            buf.put_slice(&chunk[..len]);
            if len == chunk.len() {
                inner.read_data.pop_front();
            } else {
                // Drop the delivered prefix in place rather than re-allocating the tail.
                chunk.drain(..len);
            }
        }
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for MockStream {
    /// Appends all of `buf` to the shared write log and reports it fully written.
    ///
    /// Fails with `io::ErrorKind::BrokenPipe` once the stream has been shut
    /// down. Completes immediately; never returns `Poll::Pending`.
    // The guard covers both the `closed` check and the append, so there is
    // nothing to tighten.
    #[allow(clippy::significant_drop_tightening)]
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut state = self.inner.lock().unwrap();
        let outcome = if state.closed {
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Connection closed",
            ))
        } else {
            state.write_data.extend_from_slice(buf);
            Ok(buf.len())
        };
        Poll::Ready(outcome)
    }

    /// Nothing is buffered, so flushing always succeeds immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Marks the shared state closed; later reads report EOF and later writes fail.
    ///
    /// The flag is shared by every stream of the same transport and stays set
    /// until `MockTransport::reset`. Always succeeds.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.inner.lock().unwrap().closed = true;
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to the mock, panicking if the attempt fails.
    async fn open_stream(transport: &MockTransport) -> MockStream {
        match transport.connect("localhost", 8080).await {
            Ok(stream) => stream,
            Err(_) => panic!("mock connection unexpectedly failed"),
        }
    }

    /// Performs one `poll_read` with a `capacity`-byte buffer and returns the filled bytes.
    async fn read_once(stream: &mut MockStream, capacity: usize) -> Vec<u8> {
        let mut storage = vec![0u8; capacity];
        let filled = std::future::poll_fn(|cx| {
            let mut buf = ReadBuf::new(&mut storage);
            Pin::new(&mut *stream)
                .poll_read(cx, &mut buf)
                .map_ok(|()| buf.filled().len())
        })
        .await
        .expect("mock reads never fail");
        storage.truncate(filled);
        storage
    }

    /// Performs one `poll_write` of `data`.
    async fn write_once(stream: &mut MockStream, data: &[u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, data)).await
    }

    /// Performs one `poll_shutdown`.
    async fn shutdown(stream: &mut MockStream) -> io::Result<()> {
        std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_shutdown(cx)).await
    }

    #[tokio::test]
    async fn test_mock_transport_success() {
        let transport = MockTransport::new();
        let result = transport.connect("localhost", 8080).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_mock_transport_fail_then_succeed() {
        let transport = MockTransport::new().fail_times(2);
        assert!(transport.connect("localhost", 8080).await.is_err());
        assert!(transport.connect("localhost", 8080).await.is_err());
        assert!(transport.connect("localhost", 8080).await.is_ok());
    }

    #[tokio::test]
    async fn test_mock_transport_custom_error() {
        let transport = MockTransport::new().with_error(ConnectError::InvalidUri("test".into()));
        let result = transport.connect("localhost", 8080).await;
        assert!(matches!(result, Err(ConnectError::InvalidUri(_))));
    }

    #[tokio::test]
    async fn test_custom_error_takes_precedence_over_scripted_failures() {
        let transport = MockTransport::new()
            .fail_times(1)
            .with_error(ConnectError::InvalidUri("test".into()));
        for _ in 0..2 {
            let result = transport.connect("localhost", 8080).await;
            assert!(matches!(result, Err(ConnectError::InvalidUri(_))));
        }
    }

    #[tokio::test]
    async fn test_stream_reads_queued_chunks_in_order() {
        let transport = MockTransport::new()
            .with_read_data(b"hello".to_vec())
            .with_read_data(b"!".to_vec());
        let mut stream = open_stream(&transport).await;
        assert_eq!(read_once(&mut stream, 3).await, b"hel");
        assert_eq!(read_once(&mut stream, 8).await, b"lo");
        assert_eq!(read_once(&mut stream, 8).await, b"!");
        assert!(read_once(&mut stream, 8).await.is_empty());
    }

    #[tokio::test]
    async fn test_stream_records_writes_and_rejects_after_shutdown() {
        let transport = MockTransport::new();
        let mut stream = open_stream(&transport).await;
        assert_eq!(write_once(&mut stream, b"abc").await.expect("write"), 3);
        assert_eq!(write_once(&mut stream, b"de").await.expect("write"), 2);
        assert_eq!(transport.get_written_data(), b"abcde");

        shutdown(&mut stream).await.expect("shutdown");
        let err = write_once(&mut stream, b"x")
            .await
            .expect_err("write after shutdown must fail");
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        assert!(read_once(&mut stream, 8).await.is_empty());
    }

    #[tokio::test]
    async fn test_reset_restores_failures_and_clears_state() {
        let transport = MockTransport::new().fail_times(1);
        assert!(matches!(
            transport.connect("localhost", 8080).await,
            Err(ConnectError::Refused)
        ));
        let mut stream = open_stream(&transport).await;
        write_once(&mut stream, b"data").await.expect("write");
        shutdown(&mut stream).await.expect("shutdown");

        transport.reset();
        assert!(transport.get_written_data().is_empty());
        assert!(transport.connect("localhost", 8080).await.is_err());
        let mut stream = open_stream(&transport).await;
        assert_eq!(write_once(&mut stream, b"ok").await.expect("write"), 2);
    }
}