#![cfg_attr(coverage_nightly, coverage(off))]
use crate::transport::{TransportAdapter, TransportError};
use async_trait::async_trait;
use pmcp::transport::{Transport as PmcpTransport, TransportMessage};
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::debug;
/// Mutable bookkeeping shared by every clone of a [`MockTransport`].
#[derive(Debug)]
struct MockState {
    /// Messages handed out, front first, by `receive`.
    receive_queue: VecDeque<TransportMessage>,
    /// Every message accepted by `send`, in call order.
    sent_messages: Vec<TransportMessage>,
    /// When `false`, `send` and `receive` fail with a "Not connected" error.
    connected: bool,
    /// One-shot error consumed and returned by the next `send` or `receive`.
    next_error: Option<String>,
    /// When `true`, `send` and `receive` sleep for `delay_ms` before completing.
    simulate_delay: bool,
    /// Length of the simulated delay, in milliseconds.
    delay_ms: u64,
}

/// In-memory transport for tests.
///
/// Responses are scripted with `queue_response`, sent messages are recorded
/// for later inspection, and errors, disconnects and latency can be injected.
/// Cloning the handle shares the same underlying state.
#[derive(Clone, Debug)]
pub struct MockTransport {
    /// Lock-protected state shared between clones of this handle.
    state: Arc<Mutex<MockState>>,
}
impl MockTransport {
    /// Creates a connected mock with no queued responses, no recorded sends,
    /// no pending error, and delay simulation disabled (with a 10 ms preset).
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn new() -> Self {
        let initial = MockState {
            receive_queue: VecDeque::new(),
            sent_messages: Vec::new(),
            connected: true,
            next_error: None,
            simulate_delay: false,
            delay_ms: 10,
        };
        Self {
            state: Arc::new(Mutex::new(initial)),
        }
    }

    /// Appends `message` to the back of the queue drained by `receive`.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn queue_response(&mut self, message: TransportMessage) {
        self.state.lock().await.receive_queue.push_back(message);
    }

    /// Returns a copy of every message sent so far, oldest first.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn get_sent_messages(&self) -> Vec<TransportMessage> {
        let guard = self.state.lock().await;
        guard.sent_messages.to_vec()
    }

    /// Arms a one-shot error; the next `send` or `receive` consumes and returns it.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn inject_error(&mut self, error: impl Into<String>) {
        self.state.lock().await.next_error = Some(error.into());
    }

    /// Enables simulated latency of `delay_ms` milliseconds on `send` and `receive`.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn set_delay(&mut self, delay_ms: u64) {
        let mut guard = self.state.lock().await;
        guard.delay_ms = delay_ms;
        guard.simulate_delay = true;
    }

    /// Marks the mock as disconnected so that `send` and `receive` fail.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn disconnect(&mut self) {
        self.state.lock().await.connected = false;
    }

    /// Clears queued responses, recorded sends and any pending error, and
    /// reconnects. Delay settings configured via `set_delay` are left untouched.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn reset(&mut self) {
        let mut guard = self.state.lock().await;
        let MockState {
            receive_queue,
            sent_messages,
            connected,
            next_error,
            ..
        } = &mut *guard;
        receive_queue.clear();
        sent_messages.clear();
        *next_error = None;
        *connected = true;
    }
}
impl Default for MockTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PmcpTransport for MockTransport {
async fn send(&mut self, message: TransportMessage) -> pmcp::Result<()> {
let mut state = self.state.lock().await;
if let Some(error) = state.next_error.take() {
return Err(pmcp::Error::transport(error));
}
if !state.connected {
return Err(pmcp::Error::transport("Not connected"));
}
if state.simulate_delay {
tokio::time::sleep(tokio::time::Duration::from_millis(state.delay_ms)).await;
}
debug!("Mock transport sending message");
state.sent_messages.push(message);
Ok(())
}
async fn receive(&mut self) -> pmcp::Result<TransportMessage> {
let mut state = self.state.lock().await;
if let Some(error) = state.next_error.take() {
return Err(pmcp::Error::transport(error));
}
if !state.connected {
return Err(pmcp::Error::transport("Not connected"));
}
if state.simulate_delay {
tokio::time::sleep(tokio::time::Duration::from_millis(state.delay_ms)).await;
}
state.receive_queue
.pop_front()
.ok_or_else(|| pmcp::Error::transport("No messages in queue"))
}
async fn close(&mut self) -> pmcp::Result<()> {
let mut state = self.state.lock().await;
state.connected = false;
Ok(())
}
fn is_connected(&self) -> bool {
self.state.try_lock().map(|s| s.connected).unwrap_or(false)
}
fn transport_type(&self) -> &'static str {
"mock"
}
}
#[async_trait]
impl TransportAdapter for MockTransport {
async fn send(&mut self, message: TransportMessage) -> Result<(), TransportError> {
PmcpTransport::send(self, message)
.await
.map_err(|e| TransportError::Send(e.to_string()))
}
async fn receive(&mut self) -> Result<TransportMessage, TransportError> {
PmcpTransport::receive(self)
.await
.map_err(|e| TransportError::Receive(e.to_string()))
}
async fn close(&mut self) -> Result<(), TransportError> {
PmcpTransport::close(self)
.await
.map_err(|e| TransportError::Connection(e.to_string()))
}
fn is_connected(&self) -> bool {
PmcpTransport::is_connected(self)
}
fn transport_type(&self) -> &'static str {
PmcpTransport::transport_type(self)
}
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    //! Tests for `MockTransport`.
    //!
    //! `MockTransport` implements both `PmcpTransport` and `TransportAdapter`, which
    //! declare identically named methods (`send`, `receive`, `is_connected`, ...).
    //! `use super::*` brings both traits into scope, so trait calls are written in
    //! fully qualified form to avoid ambiguous method resolution (E0034).
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn test_message_ordering(messages in prop::collection::vec("\\PC+", 1..20)) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let mut transport = MockTransport::new();
                for msg in &messages {
                    transport.queue_response(TransportMessage::text(msg)).await;
                }
                for expected in &messages {
                    let received = PmcpTransport::receive(&mut transport).await.unwrap();
                    assert_eq!(received, TransportMessage::text(expected));
                }
            });
        }

        #[test]
        fn test_sent_message_recording(messages in prop::collection::vec("\\PC+", 1..10)) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let mut transport = MockTransport::new();
                for msg in &messages {
                    PmcpTransport::send(&mut transport, TransportMessage::text(msg))
                        .await
                        .unwrap();
                }
                let sent = transport.get_sent_messages().await;
                assert_eq!(sent.len(), messages.len());
                for (sent_msg, expected) in sent.iter().zip(messages.iter()) {
                    assert_eq!(*sent_msg, TransportMessage::text(expected));
                }
            });
        }
    }

    #[tokio::test]
    async fn test_mock_transport_error_injection() {
        let mut transport = MockTransport::new();
        transport.inject_error("Test error").await;
        // The injected error is one-shot: the first send fails, the next succeeds.
        let result = PmcpTransport::send(&mut transport, TransportMessage::text("test")).await;
        assert!(result.is_err());
        let result = PmcpTransport::send(&mut transport, TransportMessage::text("test")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_mock_transport_connection_state() {
        let mut transport = MockTransport::new();
        assert!(PmcpTransport::is_connected(&transport));
        transport.disconnect().await;
        assert!(!PmcpTransport::is_connected(&transport));
        let result = PmcpTransport::send(&mut transport, TransportMessage::text("test")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_mock_transport_delay_simulation() {
        let mut transport = MockTransport::new();
        transport.set_delay(50).await;
        let start = tokio::time::Instant::now();
        PmcpTransport::send(&mut transport, TransportMessage::text("test"))
            .await
            .unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed.as_millis() >= 50);
    }

    #[tokio::test]
    async fn test_receive_on_empty_queue_errors() {
        let mut transport = MockTransport::new();
        assert!(PmcpTransport::receive(&mut transport).await.is_err());
    }

    #[tokio::test]
    async fn test_adapter_maps_send_errors() {
        let mut transport = MockTransport::new();
        transport.inject_error("boom").await;
        let result = TransportAdapter::send(&mut transport, TransportMessage::text("test")).await;
        assert!(matches!(result, Err(TransportError::Send(_))));
    }

    #[tokio::test]
    async fn test_reset_restores_connection() {
        let mut transport = MockTransport::new();
        transport.disconnect().await;
        transport.reset().await;
        assert!(PmcpTransport::is_connected(&transport));
        let result = PmcpTransport::send(&mut transport, TransportMessage::text("test")).await;
        assert!(result.is_ok());
    }
}