use core::sync::atomic::{AtomicBool, Ordering};
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use embedded_io::ErrorType;
use embedded_io_async::{Read, Write};
use crate::error::TransportError;
use crate::transport::MqttTransport;
/// In-memory transport double for exercising MQTT client code without a network.
///
/// Bytes queued with [`MockTransport::add_incoming_data`] are handed out by reads, every byte
/// written is captured for later inspection, and failures can be injected either as queued
/// errors or through the simulation flags.
#[derive(Debug)]
pub struct MockTransport {
    /// Bytes that reads will return, consumed from the front.
    incoming_data: VecDeque<u8>,
    /// Every byte accepted by writes, in write order.
    outgoing_data: Vec<u8>,
    /// Connection flag; reads, writes and flushes fail with `ConnectionLost` while it is unset.
    connected: AtomicBool,
    /// Errors handed out one per read call (FIFO), checked after the connection checks.
    read_errors: VecDeque<TransportError>,
    /// Errors handed out one per write call (FIFO), checked after the connection checks.
    write_errors: VecDeque<TransportError>,
    /// Value reported by `local_addr()`; `None` until set.
    local_addr: Option<String>,
    /// Value reported by `remote_addr()`; `None` until set.
    remote_addr: Option<String>,
    /// When set, reads with no queued error fail with `Timeout` instead of returning data.
    simulate_slow_read: bool,
    /// When set, every read and write marks the transport disconnected and fails.
    simulate_connection_loss: bool,
}

impl Default for MockTransport {
    /// Equivalent to [`MockTransport::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl MockTransport {
    /// Creates a connected transport with empty buffers, no queued errors, no addresses and
    /// all failure simulations turned off.
    pub fn new() -> Self {
        Self {
            incoming_data: VecDeque::new(),
            outgoing_data: Vec::new(),
            connected: AtomicBool::new(true),
            read_errors: VecDeque::new(),
            write_errors: VecDeque::new(),
            local_addr: None,
            remote_addr: None,
            simulate_slow_read: false,
            simulate_connection_loss: false,
        }
    }

    /// Same as [`MockTransport::new`], but the transport starts out disconnected.
    pub fn new_disconnected() -> Self {
        Self {
            connected: AtomicBool::new(false),
            ..Self::new()
        }
    }

    /// Appends `data` to the end of the queue that reads consume.
    pub fn add_incoming_data(&mut self, data: &[u8]) {
        self.incoming_data.extend(data);
    }

    /// Queues an already-encoded packet for reading.
    ///
    /// The bytes are queued exactly as given (no framing is added); this is an alias of
    /// [`MockTransport::add_incoming_data`].
    pub fn add_incoming_packet(&mut self, packet_data: &[u8]) {
        self.add_incoming_data(packet_data);
    }

    /// Returns every byte written so far.
    pub fn get_outgoing_data(&self) -> &[u8] {
        &self.outgoing_data
    }

    /// Discards all captured outgoing bytes.
    pub fn clear_outgoing_data(&mut self) {
        self.outgoing_data.clear();
    }

    /// Returns the last `n` written bytes, or all of them if fewer than `n` were written.
    pub fn get_last_outgoing_bytes(&self, n: usize) -> &[u8] {
        let start = self.outgoing_data.len().saturating_sub(n);
        &self.outgoing_data[start..]
    }

    /// Returns `true` if `expected` appears as a contiguous run anywhere in the written bytes.
    ///
    /// An empty `expected` is always contained, matching
    /// [`MockTransport::outgoing_ends_with`] for an empty suffix.
    pub fn outgoing_contains(&self, expected: &[u8]) -> bool {
        // `slice::windows` panics on a window size of 0, so the empty needle must be
        // answered before reaching it.
        expected.is_empty()
            || self
                .outgoing_data
                .windows(expected.len())
                .any(|window| window == expected)
    }

    /// Returns `true` if the written bytes end with `expected`.
    pub fn outgoing_ends_with(&self, expected: &[u8]) -> bool {
        self.outgoing_data.ends_with(expected)
    }

    /// Queues `error` to be returned by a future read (first queued, first returned).
    pub fn add_read_error(&mut self, error: TransportError) {
        self.read_errors.push_back(error);
    }

    /// Queues `error` to be returned by a future write (first queued, first returned).
    pub fn add_write_error(&mut self, error: TransportError) {
        self.write_errors.push_back(error);
    }

    /// Sets the connection flag directly.
    pub fn set_connected(&mut self, connected: bool) {
        self.connected.store(connected, Ordering::Relaxed);
    }

    /// Marks the transport disconnected.
    pub fn disconnect(&mut self) {
        self.set_connected(false);
    }

    /// Marks the transport connected again.
    pub fn reconnect(&mut self) {
        self.set_connected(true);
    }

    /// Sets the address reported as the local endpoint.
    pub fn set_local_addr(&mut self, addr: String) {
        self.local_addr = Some(addr);
    }

    /// Sets the address reported as the remote endpoint.
    pub fn set_remote_addr(&mut self, addr: String) {
        self.remote_addr = Some(addr);
    }

    /// Enables or disables the read-timeout simulation.
    pub fn set_simulate_slow_read(&mut self, enable: bool) {
        self.simulate_slow_read = enable;
    }

    /// Enables or disables the connection-loss simulation for reads and writes.
    pub fn set_simulate_connection_loss(&mut self, enable: bool) {
        self.simulate_connection_loss = enable;
    }

    /// Number of queued bytes not yet consumed by reads.
    pub fn incoming_data_len(&self) -> usize {
        self.incoming_data.len()
    }

    /// Number of bytes captured from writes.
    pub fn outgoing_data_len(&self) -> usize {
        self.outgoing_data.len()
    }

    /// Discards all queued incoming bytes.
    pub fn clear_incoming_data(&mut self) {
        self.incoming_data.clear();
    }

    /// Discards every queued read and write error.
    pub fn clear_errors(&mut self) {
        self.read_errors.clear();
        self.write_errors.clear();
    }

    /// Restores the state produced by [`MockTransport::new`].
    ///
    /// Addresses set with `set_local_addr` / `set_remote_addr` are intentionally left intact.
    pub fn reset(&mut self) {
        self.clear_incoming_data();
        self.clear_outgoing_data();
        self.clear_errors();
        self.reconnect();
        self.simulate_slow_read = false;
        self.simulate_connection_loss = false;
    }
}
/// The mock's `Read` / `Write` operations report failures as the crate's `TransportError`.
impl ErrorType for MockTransport {
    type Error = TransportError;
}
impl MqttTransport for MockTransport {
    /// Marks the transport disconnected; always succeeds, even if it was already closed.
    async fn close(&mut self) -> Result<(), TransportError> {
        self.set_connected(false);
        Ok(())
    }

    /// Reports the current value of the connection flag.
    fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Relaxed)
    }

    /// Address configured via `set_local_addr`, if any.
    fn local_addr(&self) -> Option<&str> {
        Option::as_deref(&self.local_addr)
    }

    /// Address configured via `set_remote_addr`, if any.
    fn remote_addr(&self) -> Option<&str> {
        Option::as_deref(&self.remote_addr)
    }
}
impl Read for MockTransport {
    /// Moves queued incoming bytes into `buf`, unless a failure is being simulated.
    ///
    /// The checks run in this order, and the first one that applies decides the result:
    /// 1. `simulate_connection_loss`: marks the transport disconnected, returns `ConnectionLost`.
    /// 2. Disconnected: returns `ConnectionLost`.
    /// 3. A queued read error: it is popped and returned.
    /// 4. `simulate_slow_read`: returns `Timeout`.
    ///
    /// Otherwise up to `buf.len()` bytes are removed from the front of the incoming queue,
    /// and their count is returned. The count is `Ok(0)` when nothing is queued or `buf`
    /// is empty.
    ///
    /// NOTE(review): embedded-io's `Read` contract treats `Ok(0)` for a non-empty `buf` as
    /// end-of-stream. Confirm that's the intended mock behaviour once the queue runs dry.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        if self.simulate_connection_loss {
            self.connected.store(false, Ordering::Relaxed);
            return Err(TransportError::ConnectionLost);
        }
        if !self.is_connected() {
            return Err(TransportError::ConnectionLost);
        }
        if let Some(error) = self.read_errors.pop_front() {
            return Err(error);
        }
        if self.simulate_slow_read {
            return Err(TransportError::Timeout);
        }
        let count = buf.len().min(self.incoming_data.len());
        // `drain` removes the bytes from the queue front in one pass. The old approach
        // popped and unwrapped one byte per index.
        for (slot, byte) in buf.iter_mut().zip(self.incoming_data.drain(..count)) {
            *slot = byte;
        }
        Ok(count)
    }
}
impl Write for MockTransport {
    /// Captures `buf` in the outgoing buffer and reports it as fully written, unless a
    /// failure is being simulated.
    ///
    /// The failure checks run in this order:
    /// 1. `simulate_connection_loss`: disconnects and returns `ConnectionLost`.
    /// 2. Disconnected: returns `ConnectionLost`.
    /// 3. The next queued write error, if any: it is popped and returned.
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let losing_connection = self.simulate_connection_loss;
        if losing_connection {
            self.connected.store(false, Ordering::Relaxed);
        }
        if losing_connection || !self.is_connected() {
            return Err(TransportError::ConnectionLost);
        }
        match self.write_errors.pop_front() {
            Some(injected) => Err(injected),
            None => {
                self.outgoing_data.extend_from_slice(buf);
                Ok(buf.len())
            }
        }
    }

    /// Succeeds while connected, and fails with `ConnectionLost` otherwise.
    ///
    /// Unlike `write`, it neither consults `simulate_connection_loss` nor consumes queued
    /// write errors.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        if self.is_connected() {
            Ok(())
        } else {
            Err(TransportError::ConnectionLost)
        }
    }
}
// Synchronous checks of MockTransport's setup and inspection helpers; none of these tests
// drive the async read/write paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_transport_basic_operations() {
        let mock = MockTransport::new();
        assert!(mock.is_connected());
        // Both buffers start empty.
        assert_eq!((mock.incoming_data_len(), mock.outgoing_data_len()), (0, 0));
    }

    #[test]
    fn test_mock_transport_data_capture() {
        let mut mock = MockTransport::new();
        mock.add_incoming_data(b"hello");
        assert_eq!(mock.incoming_data_len(), b"hello".len());
        // Queuing input must not leak into the captured output.
        assert_eq!(mock.outgoing_data_len(), 0);
    }

    #[test]
    fn test_mock_transport_connection_state() {
        let mut mock = MockTransport::new();
        assert!(mock.is_connected());

        mock.disconnect();
        assert!(!mock.is_connected());

        mock.reconnect();
        assert!(mock.is_connected());
    }

    #[test]
    fn test_mock_transport_addresses() {
        let mut mock = MockTransport::new();
        assert!(mock.local_addr().is_none());
        assert!(mock.remote_addr().is_none());

        mock.set_local_addr("127.0.0.1:1234".into());
        mock.set_remote_addr("broker.example.com:1883".into());

        assert_eq!(mock.local_addr(), Some("127.0.0.1:1234"));
        assert_eq!(mock.remote_addr(), Some("broker.example.com:1883"));
    }

    #[test]
    fn test_mock_transport_reset() {
        let mut mock = MockTransport::new();
        mock.add_incoming_data(b"test");
        mock.add_read_error(TransportError::Timeout);
        mock.disconnect();

        mock.reset();

        assert!(mock.is_connected());
        assert_eq!(mock.incoming_data_len(), 0);
        assert!(mock.read_errors.is_empty());
    }
}