#![cfg_attr(coverage_nightly, coverage(off))]
use crate::transport::{PmcpTransportWrapper, TransportAdapter, TransportError};
use pmcp::transport::WebSocketTransport;
use std::fmt::Debug;
use tracing::{debug, info};
/// Adapter that exposes a pmcp [`WebSocketTransport`] through the crate's
/// [`TransportAdapter`] interface.
///
/// Instances come from [`WebSocketTransportAdapter::connect`] (dialing a URL)
/// or [`WebSocketTransportAdapter::from_stream`] (wrapping a stream produced by
/// [`WebSocketServer::accept`]).
#[derive(Debug)]
pub struct WebSocketTransportAdapter {
// The pmcp transport, wrapped so that send/receive/close report `TransportError`
// (see the `TransportAdapter` impl, which delegates every call to it).
wrapper: PmcpTransportWrapper<WebSocketTransport>,
}
impl WebSocketTransportAdapter {
    /// Dials `url` and returns an adapter around the resulting client connection.
    ///
    /// # Errors
    ///
    /// Any failure from `WebSocketTransport::connect` is reported as
    /// [`TransportError::Connection`] with the underlying cause in the message.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn connect(url: &str) -> Result<Self, TransportError> {
        info!("Connecting to WebSocket at {}", url);
        let transport = match WebSocketTransport::connect(url).await {
            Ok(transport) => transport,
            Err(cause) => {
                let reason = format!("WebSocket connection failed: {}", cause);
                return Err(TransportError::Connection(reason));
            }
        };
        let adapter = Self {
            wrapper: PmcpTransportWrapper::new(transport),
        };
        debug!("WebSocket connection established");
        Ok(adapter)
    }

    /// Wraps a WebSocket stream that has already been handshaken, such as the
    /// one produced by [`WebSocketServer::accept`].
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn from_stream(stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Self {
        debug!("Creating WebSocket transport from accepted stream");
        Self {
            wrapper: PmcpTransportWrapper::new(WebSocketTransport::from_stream(stream)),
        }
    }

    /// Binds a TCP listener on `addr` and returns a server that performs the
    /// WebSocket handshake lazily, in [`WebSocketServer::accept`].
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if the address cannot be bound.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn serve(addr: &str) -> Result<WebSocketServer, TransportError> {
        info!("Starting WebSocket server on {}", addr);
        match tokio::net::TcpListener::bind(addr).await {
            Ok(listener) => Ok(WebSocketServer { listener }),
            Err(cause) => Err(TransportError::Connection(format!("Failed to bind: {}", cause))),
        }
    }

    /// Same as [`WebSocketTransportAdapter::connect`], but returns the adapter
    /// as a boxed [`TransportAdapter`] trait object.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WebSocketTransportAdapter::connect`].
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn boxed(url: &str) -> Result<Box<dyn TransportAdapter>, TransportError> {
        let adapter = Self::connect(url).await?;
        Ok(Box::new(adapter))
    }
}
#[async_trait::async_trait]
impl TransportAdapter for WebSocketTransportAdapter {
    /// Hands `message` to the wrapped pmcp transport.
    async fn send(&mut self, message: pmcp::transport::TransportMessage) -> Result<(), TransportError> {
        let wrapper = &mut self.wrapper;
        wrapper.send(message).await
    }

    /// Waits for the next message from the wrapped pmcp transport.
    async fn receive(&mut self) -> Result<pmcp::transport::TransportMessage, TransportError> {
        let wrapper = &mut self.wrapper;
        wrapper.receive().await
    }

    /// Closes the wrapped pmcp transport.
    async fn close(&mut self) -> Result<(), TransportError> {
        let wrapper = &mut self.wrapper;
        wrapper.close().await
    }

    /// Reports the connection state exactly as the wrapped transport sees it.
    fn is_connected(&self) -> bool {
        let wrapper = &self.wrapper;
        wrapper.is_connected()
    }

    /// Fixed identifier for this transport kind.
    fn transport_type(&self) -> &'static str {
        const KIND: &str = "websocket";
        KIND
    }
}
/// Listening WebSocket server created by [`WebSocketTransportAdapter::serve`].
///
/// It holds only the bound TCP listener. Each call to [`WebSocketServer::accept`]
/// takes one TCP connection and upgrades it to a WebSocket transport.
// `Debug` is derived for parity with `WebSocketTransportAdapter`; public types
// should be debuggable (Rust API guideline C-DEBUG), and `TcpListener: Debug`.
#[derive(Debug)]
pub struct WebSocketServer {
    listener: tokio::net::TcpListener,
}
impl WebSocketServer {
    /// Returns the local socket address the server's listener is bound to.
    ///
    /// This is the only way to learn the OS-assigned port after binding to
    /// port `0` (e.g. `serve("127.0.0.1:0")`), because the listener is private.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if the OS cannot report the address.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn local_addr(&self) -> Result<std::net::SocketAddr, TransportError> {
        self.listener
            .local_addr()
            .map_err(|e| TransportError::Connection(format!("Failed to read local address: {}", e)))
    }

    /// Accepts the next TCP connection and completes the server-side WebSocket
    /// handshake on it.
    ///
    /// The peer address is logged at `info` level before the handshake starts.
    ///
    /// NOTE(review): the handshake is awaited inline and no timeout is applied,
    /// so a peer that connects but never completes the handshake keeps this
    /// call pending. Callers that need isolation should wrap it in a timeout.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if either the TCP accept or the
    /// WebSocket handshake fails.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn accept(&mut self) -> Result<WebSocketTransportAdapter, TransportError> {
        let (stream, addr) = self
            .listener
            .accept()
            .await
            .map_err(|e| TransportError::Connection(format!("Accept failed: {}", e)))?;
        info!("Accepting WebSocket connection from {}", addr);
        let ws_stream = tokio_tungstenite::accept_async(stream)
            .await
            .map_err(|e| TransportError::Connection(format!("WebSocket handshake failed: {}", e)))?;
        Ok(WebSocketTransportAdapter::from_stream(ws_stream))
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Payloads span the full byte range; `0u8..255` previously excluded 0xFF.
        // NOTE(review): this only exercises the generator; no payload is sent
        // through the transport yet.
        #[test]
        fn test_websocket_frame_fragmentation(data in prop::collection::vec(any::<u8>(), 1..10000)) {
            prop_assert!(!data.is_empty());
        }

        // Inclusive upper bound so port 65535 is also generated.
        #[test]
        fn test_websocket_url_validation(
            scheme in prop::sample::select(vec!["ws", "wss", "http", "https", "ftp"]),
            host in "[a-z]{1,10}",
            port in 1u16..=65535
        ) {
            let url = format!("{}://{}:{}", scheme, host, port);
            let should_be_valid = scheme == "ws" || scheme == "wss";
            prop_assert_eq!(url.starts_with("ws://") || url.starts_with("wss://"), should_be_valid);
        }
    }

    #[tokio::test]
    async fn test_websocket_server_bind() {
        let result = WebSocketTransportAdapter::serve("127.0.0.1:0").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_websocket_connection_drop_recovery() {
        let Ok(mut server) = WebSocketTransportAdapter::serve("127.0.0.1:0").await else {
            panic!("binding 127.0.0.1:0 should succeed");
        };
        // The child test module can read the private listener directly.
        let addr = server
            .listener
            .local_addr()
            .expect("a bound listener has a local address");

        // A client that disconnects before the handshake must surface as an
        // error from `accept`, not a hang or a panic.
        let dropped = tokio::net::TcpStream::connect(addr)
            .await
            .expect("raw TCP connect to loopback listener");
        drop(dropped);
        assert!(server.accept().await.is_err());

        // The server must still serve a well-behaved client afterwards.
        let url = format!("ws://{}", addr);
        let (accepted, connected) =
            tokio::join!(server.accept(), WebSocketTransportAdapter::connect(&url));
        assert!(accepted.is_ok(), "server should accept after a dropped client");
        assert!(connected.is_ok(), "client should connect after a dropped client");
    }

    #[test]
    fn test_websocket_transport_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<WebSocketTransportAdapter>();
    }

    #[test]
    fn test_websocket_url_schemes() {
        let valid_ws = "ws://localhost:8080";
        let valid_wss = "wss://localhost:8080";
        assert!(valid_ws.starts_with("ws://"));
        assert!(valid_wss.starts_with("wss://"));
        let invalid_http = "http://localhost:8080";
        assert!(!invalid_http.starts_with("ws://") && !invalid_http.starts_with("wss://"));
    }

    #[test]
    fn test_websocket_port_ranges() {
        let standard_port = 80;
        let secure_port = 443;
        let common_dev_port = 8080;
        assert!(standard_port < 65536);
        assert!(secure_port < 65536);
        assert!(common_dev_port < 65536);
        assert!(standard_port > 0);
    }

    #[tokio::test]
    async fn test_server_bind_ephemeral_port() {
        let Ok(server) = WebSocketTransportAdapter::serve("127.0.0.1:0").await else {
            panic!("binding 127.0.0.1:0 should succeed");
        };
        let addr = server
            .listener
            .local_addr()
            .expect("a bound listener has a local address");
        assert!(addr.ip().is_loopback());
        assert_ne!(addr.port(), 0, "the OS should assign a concrete ephemeral port");
    }

    #[tokio::test]
    async fn test_server_bind_invalid_address() {
        // Without a `host:port` shape, address parsing fails locally, so the
        // test never depends on DNS (unlike a hostname such as `invalid-host`).
        match WebSocketTransportAdapter::serve("not-a-socket-address").await {
            Err(TransportError::Connection(msg)) => {
                assert!(msg.contains("bind"), "unexpected error message: {}", msg);
            }
            Err(_) => panic!("expected TransportError::Connection"),
            Ok(_) => panic!("binding an invalid address unexpectedly succeeded"),
        }
    }

    #[test]
    fn test_transport_error_variants() {
        let conn_err = TransportError::Connection("test".to_string());
        assert!(matches!(conn_err, TransportError::Connection(_)));
        let send_err = TransportError::Send("test".to_string());
        assert!(matches!(send_err, TransportError::Send(_)));
        let recv_err = TransportError::Receive("test".to_string());
        assert!(matches!(recv_err, TransportError::Receive(_)));
    }
}