use super::Transport;
use async_trait::async_trait;
use parking_lot::Mutex;
use std::collections::VecDeque;
pub struct MockTransport {
expected_sends: Mutex<VecDeque<Vec<u8>>>,
queued_responses: Mutex<VecDeque<Vec<u8>>>,
queued_interrupts: Mutex<VecDeque<Vec<u8>>>,
actual_sends: Mutex<Vec<Vec<u8>>>,
}
impl MockTransport {
#[must_use]
pub fn new() -> Self {
Self {
expected_sends: Mutex::new(VecDeque::new()),
queued_responses: Mutex::new(VecDeque::new()),
queued_interrupts: Mutex::new(VecDeque::new()),
actual_sends: Mutex::new(Vec::new()),
}
}
pub fn expect_send(&self, data: Vec<u8>) {
self.expected_sends.lock().push_back(data);
}
pub fn queue_response(&self, data: Vec<u8>) {
self.queued_responses.lock().push_back(data);
}
pub fn queue_interrupt(&self, data: Vec<u8>) {
self.queued_interrupts.lock().push_back(data);
}
pub fn verify(&self) -> Result<(), String> {
let expected = self.expected_sends.lock();
let responses = self.queued_responses.lock();
let interrupts = self.queued_interrupts.lock();
let mut errors = Vec::new();
if !expected.is_empty() {
errors.push(format!(
"{} expected send(s) were not received",
expected.len()
));
}
if !responses.is_empty() {
errors.push(format!(
"{} queued response(s) were not consumed",
responses.len()
));
}
if !interrupts.is_empty() {
errors.push(format!(
"{} queued interrupt(s) were not consumed",
interrupts.len()
));
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join("; "))
}
}
pub fn get_sends(&self) -> Vec<Vec<u8>> {
self.actual_sends.lock().clone()
}
pub fn reset(&self) {
self.expected_sends.lock().clear();
self.queued_responses.lock().clear();
self.queued_interrupts.lock().clear();
self.actual_sends.lock().clear();
}
}
impl Default for MockTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Transport for MockTransport {
async fn send_bulk(&self, data: &[u8]) -> Result<(), crate::Error> {
self.actual_sends.lock().push(data.to_vec());
let expected = self.expected_sends.lock().pop_front();
if let Some(expected_data) = expected {
if data != expected_data.as_slice() {
return Err(crate::Error::invalid_data(format!(
"send mismatch: expected {:?}, got {:?}",
expected_data, data
)));
}
}
Ok(())
}
async fn receive_bulk(&self, _max_size: usize) -> Result<Vec<u8>, crate::Error> {
self.queued_responses
.lock()
.pop_front()
.ok_or(crate::Error::NoDevice)
}
async fn receive_interrupt(&self) -> Result<Vec<u8>, crate::Error> {
self.queued_interrupts
.lock()
.pop_front()
.ok_or(crate::Error::NoDevice)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_send_receive() {
let mock = MockTransport::new();
mock.queue_response(vec![1, 2, 3]);
mock.send_bulk(&[4, 5, 6]).await.unwrap();
let response = mock.receive_bulk(100).await.unwrap();
assert_eq!(response, vec![1, 2, 3]);
assert_eq!(mock.get_sends(), vec![vec![4, 5, 6]]);
}
#[tokio::test]
async fn test_expected_send_matches() {
let mock = MockTransport::new();
mock.expect_send(vec![1, 2, 3]);
mock.send_bulk(&[1, 2, 3]).await.unwrap();
mock.verify().unwrap();
}
#[tokio::test]
async fn test_expected_send_mismatch() {
let mock = MockTransport::new();
mock.expect_send(vec![1, 2, 3]);
let result = mock.send_bulk(&[1, 2, 4]).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_verify_fails_with_unconsumed_expectations() {
let mock = MockTransport::new();
mock.expect_send(vec![1, 2, 3]);
let result = mock.verify();
assert!(result.is_err());
}
#[tokio::test]
async fn test_verify_fails_with_unconsumed_responses() {
let mock = MockTransport::new();
mock.queue_response(vec![1, 2, 3]);
let result = mock.verify();
assert!(result.is_err());
}
#[tokio::test]
async fn test_receive_bulk_empty_queue_returns_error() {
let mock = MockTransport::new();
let result = mock.receive_bulk(100).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_receive_interrupt() {
let mock = MockTransport::new();
mock.queue_interrupt(vec![10, 20, 30]);
let result = mock.receive_interrupt().await.unwrap();
assert_eq!(result, vec![10, 20, 30]);
}
#[tokio::test]
async fn test_receive_interrupt_empty_queue_returns_error() {
let mock = MockTransport::new();
let result = mock.receive_interrupt().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_reset_clears_all_state() {
let mock = MockTransport::new();
mock.expect_send(vec![1, 2, 3]);
mock.queue_response(vec![4, 5, 6]);
mock.queue_interrupt(vec![7, 8, 9]);
mock.send_bulk(&[10, 11, 12]).await.ok();
mock.reset();
assert!(mock.get_sends().is_empty());
mock.verify().unwrap(); }
#[tokio::test]
async fn test_multiple_sends_and_responses() {
let mock = MockTransport::new();
mock.expect_send(vec![1, 2]);
mock.expect_send(vec![3, 4]);
mock.queue_response(vec![5, 6]);
mock.queue_response(vec![7, 8]);
mock.send_bulk(&[1, 2]).await.unwrap();
mock.send_bulk(&[3, 4]).await.unwrap();
let r1 = mock.receive_bulk(100).await.unwrap();
let r2 = mock.receive_bulk(100).await.unwrap();
assert_eq!(r1, vec![5, 6]);
assert_eq!(r2, vec![7, 8]);
mock.verify().unwrap();
}
#[tokio::test]
async fn test_default_impl() {
let mock = MockTransport::default();
mock.queue_response(vec![1, 2, 3]);
let response = mock.receive_bulk(100).await.unwrap();
assert_eq!(response, vec![1, 2, 3]);
}
}