stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Mock transport for testing.

use async_trait::async_trait;
use std::collections::VecDeque;
use std::io::{self};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use super::traits::Transport;
use crate::error::ConnectError;

/// Mock [`Transport`] implementation for tests.
///
/// Lets a test script how connecting behaves: a number of initial connect
/// failures, a fixed custom error, simulated latency, and pre-defined bytes
/// that streams hand back on read.
///
/// All state sits behind a shared `Arc<Mutex<..>>`, so every clone of a
/// `MockTransport` and every [`MockStream`] returned by `connect` observes
/// and mutates the same state.
#[derive(Clone)]
pub struct MockTransport {
    // Shared, lock-protected state; `connect` hands a clone of this `Arc` to each stream.
    inner: Arc<Mutex<MockTransportInner>>,
}

/// Lock-protected state shared between a [`MockTransport`], its clones, and
/// the [`MockStream`]s it creates.
struct MockTransportInner {
    /// Number of connect attempts that are refused before one may succeed
    fail_count: usize,
    /// Connect attempts refused so far; `connect` refuses while this is below `fail_count`
    current_failures: usize,
    /// Delay `connect` sleeps for before returning a stream (not applied to failed attempts)
    latency: Option<Duration>,
    /// Queue of scripted chunks; each `poll_read` serves at most one chunk, front first
    read_data: VecDeque<Vec<u8>>,
    /// Every byte accepted by `poll_write`, appended in write order
    write_data: Vec<u8>,
    /// Set by `poll_shutdown`; while set, reads fill nothing and writes fail with `BrokenPipe`
    closed: bool,
    /// When set, every `connect` returns a clone of this error (checked before `fail_count`)
    custom_error: Option<ConnectError>,
}

impl MockTransport {
    /// Create a new mock transport that always succeeds
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(MockTransportInner {
                fail_count: 0,
                current_failures: 0,
                latency: None,
                read_data: VecDeque::new(),
                write_data: Vec::new(),
                closed: false,
                custom_error: None,
            })),
        }
    }

    /// Configure the transport to fail N times before succeeding
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn fail_times(self, count: usize) -> Self {
        self.inner.lock().unwrap().fail_count = count;
        self
    }

    /// Configure simulated latency
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn with_latency(self, latency: Duration) -> Self {
        self.inner.lock().unwrap().latency = Some(latency);
        self
    }

    /// Add data that will be returned on read
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn with_read_data(self, data: impl Into<Vec<u8>>) -> Self {
        self.inner.lock().unwrap().read_data.push_back(data.into());
        self
    }

    /// Configure a custom error to return
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn with_error(self, error: ConnectError) -> Self {
        self.inner.lock().unwrap().custom_error = Some(error);
        self
    }

    /// Get the data that was written to the mock
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn get_written_data(&self) -> Vec<u8> {
        self.inner.lock().unwrap().write_data.clone()
    }

    /// Reset the mock state
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn reset(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.current_failures = 0;
        inner.write_data.clear();
        inner.closed = false;
    }
}

impl Default for MockTransport {
    /// Equivalent to [`MockTransport::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Transport for MockTransport {
    type Stream = MockStream;

    /// Simulate a connection attempt according to the scripted state.
    ///
    /// The outcome is decided in this order:
    /// 1. a configured custom error is returned (cloned) on every call;
    /// 2. while fewer than `fail_count` attempts have been refused, the
    ///    attempt is counted and fails with `ConnectError::Refused`;
    /// 3. otherwise the configured latency (if any) is awaited and a
    ///    [`MockStream`] sharing this transport's state is returned.
    ///
    /// `host` and `port` are only used in the debug log line.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    async fn connect(&self, host: &str, port: u16) -> Result<Self::Stream, ConnectError> {
        // Decide the outcome inside a synchronous scope so the std mutex
        // guard is gone before the `.await` below; only the latency escapes.
        let delay = {
            let mut state = self.inner.lock().unwrap();

            match state.custom_error.clone() {
                // A custom error overrides everything else.
                Some(err) => return Err(err),
                // Still within the scripted number of refusals.
                None if state.current_failures < state.fail_count => {
                    state.current_failures += 1;
                    return Err(ConnectError::Refused);
                }
                None => state.latency,
            }
        };

        // Lock is released here, so sleeping cannot block other users of the mock.
        if let Some(wait) = delay {
            tokio::time::sleep(wait).await;
        }

        tracing::debug!(host = %host, port = %port, "Mock connection established");

        let stream = MockStream {
            inner: Arc::clone(&self.inner),
        };
        Ok(stream)
    }

    /// Static identifier of this transport.
    fn name(&self) -> &'static str {
        "mock"
    }
}

/// Mock stream for testing
///
/// Returned by [`MockTransport`]'s `connect`. It holds a clone of the
/// transport's shared state, so reads consume the transport's queued read
/// data and writes are captured where the transport can report them.
pub struct MockStream {
    // Same `Arc` as the transport that created this stream.
    inner: Arc<Mutex<MockTransportInner>>,
}

impl AsyncRead for MockStream {
    /// Serve at most one queued chunk into `buf`; never returns `Pending`.
    ///
    /// If the stream is closed or the queue is empty, nothing is written to
    /// `buf` and the call still completes with `Ok(())`. When the front
    /// chunk is larger than the space left in `buf`, the unread tail is put
    /// back at the front of the queue for the next read.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[allow(clippy::significant_drop_tightening)]
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut state = self.inner.lock().unwrap();

        // A closed stream leaves the queue untouched.
        let next_chunk = if state.closed {
            None
        } else {
            state.read_data.pop_front()
        };

        if let Some(mut chunk) = next_chunk {
            let take = buf.remaining().min(chunk.len());
            // `chunk` keeps the first `take` bytes; `rest` is whatever did not fit.
            let rest = chunk.split_off(take);
            buf.put_slice(&chunk);

            if !rest.is_empty() {
                state.read_data.push_front(rest);
            }
        }

        Poll::Ready(Ok(()))
    }
}

impl AsyncWrite for MockStream {
    /// Append all of `buf` to the captured write data and report it as
    /// fully written; fails with `BrokenPipe` once the stream is closed.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[allow(clippy::significant_drop_tightening)]
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut state = self.inner.lock().unwrap();

        let outcome = if state.closed {
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Connection closed",
            ))
        } else {
            state.write_data.extend_from_slice(buf);
            Ok(buf.len())
        };

        Poll::Ready(outcome)
    }

    /// Nothing is buffered, so flushing always completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Mark the shared state as closed and complete immediately.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[allow(clippy::significant_drop_tightening)]
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut state = self.inner.lock().unwrap();
        state.closed = true;
        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An unconfigured mock accepts the very first connection attempt.
    #[tokio::test]
    async fn test_mock_transport_success() {
        let mock = MockTransport::new();
        assert!(mock.connect("localhost", 8080).await.is_ok());
    }

    /// With `fail_times(2)` the first two attempts are refused and the third succeeds.
    #[tokio::test]
    async fn test_mock_transport_fail_then_succeed() {
        let mock = MockTransport::new().fail_times(2);

        // The scripted failures come first...
        for _ in 0..2 {
            assert!(mock.connect("localhost", 8080).await.is_err());
        }

        // ...and the next attempt goes through.
        assert!(mock.connect("localhost", 8080).await.is_ok());
    }

    /// A configured custom error is what `connect` reports.
    #[tokio::test]
    async fn test_mock_transport_custom_error() {
        let mock = MockTransport::new().with_error(ConnectError::InvalidUri("test".into()));
        let outcome = mock.connect("localhost", 8080).await;
        assert!(matches!(outcome, Err(ConnectError::InvalidUri(_))));
    }
}