#[cfg(test)]
use proptest::prelude::*;
use crate::bytes::BytesMut;
use crate::http::h2::{Header as HpackHeader, HpackDecoder, HpackEncoder};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockTcpConnection {
pub local_addr: String,
pub remote_addr: String,
pub state: TcpState,
pub bytes_sent: u64,
pub bytes_received: u64,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TcpState {
Closed,
Listen,
SynSent,
SynReceived,
Established,
FinWait1,
FinWait2,
CloseWait,
Closing,
LastAck,
TimeWait,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockUdpSocket {
pub local_addr: String,
pub sent_packets: Vec<MockUdpPacket>,
pub received_packets: Vec<MockUdpPacket>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockUdpPacket {
pub sequence: u64,
pub data: Vec<u8>,
pub timestamp: u64,
pub dest_addr: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockWebSocketFrame {
pub opcode: u8,
pub payload: Vec<u8>,
pub mask: Option<[u8; 4]>,
pub fin: bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockDnsQuery {
pub domain: String,
pub record_type: DnsRecordType,
pub query_id: u16,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnsRecordType {
A,
Aaaa,
Mx,
Txt,
Cname,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockDnsResponse {
pub query_id: u16,
pub answers: Vec<String>,
pub status: DnsStatus,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnsStatus {
Success,
NotFound,
ServerFailure,
Timeout,
}
#[derive(Debug, Clone, PartialEq)]
pub struct MockTlsHandshake {
pub version: TlsVersion,
pub cipher_suite: String,
pub client_random: [u8; 32],
pub server_random: [u8; 32],
pub completed: bool,
}
#[derive(Debug, Clone, PartialEq)]
pub enum TlsVersion {
Tls12,
Tls13,
}
#[derive(Debug, Clone, PartialEq)]
pub struct MockHttpMessage {
pub method: String,
pub path: String,
pub headers: Vec<(String, String)>,
pub body: Vec<u8>,
pub version: HttpVersion,
}
#[derive(Debug, Clone, PartialEq)]
pub enum HttpVersion {
Http1_1,
Http2,
Http3,
}
#[derive(Debug)]
pub struct HpackRoundTripContext {
encoder: HpackEncoder,
decoder: HpackDecoder,
}
#[derive(Debug, Clone, PartialEq)]
pub struct MockH2Frame {
pub frame_type: u8,
pub stream_id: u32,
pub flags: u8,
pub payload: Vec<u8>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct MockH3Frame {
pub frame_type: u64,
pub payload: Vec<u8>,
pub stream_id: u64,
}
impl MockTcpConnection {
pub fn new(local: &str, remote: &str) -> Self {
Self {
local_addr: local.to_string(),
remote_addr: remote.to_string(),
state: TcpState::Closed,
bytes_sent: 0,
bytes_received: 0,
}
}
pub fn connect(&mut self) -> bool {
match self.state {
TcpState::Closed => {
self.state = TcpState::SynSent;
true
}
_ => false,
}
}
pub fn establish(&mut self) -> bool {
match self.state {
TcpState::SynSent => {
self.state = TcpState::Established;
true
}
_ => false,
}
}
pub fn close(&mut self) -> bool {
match self.state {
TcpState::Established => {
self.state = TcpState::FinWait1;
true
}
TcpState::FinWait1 => {
self.state = TcpState::Closed;
true
}
_ => false,
}
}
pub fn is_closed(&self) -> bool {
self.state == TcpState::Closed
}
pub fn can_reestablish(&self) -> bool {
self.is_closed()
}
}
impl MockUdpSocket {
pub fn new(local_addr: &str) -> Self {
Self {
local_addr: local_addr.to_string(),
sent_packets: Vec::new(),
received_packets: Vec::new(),
}
}
pub fn send(&mut self, packet: MockUdpPacket) {
self.sent_packets.push(packet);
}
pub fn receive(&mut self, packet: MockUdpPacket) {
self.received_packets.push(packet);
}
pub fn ordering_preserved(&self) -> bool {
let received_sequences: Vec<u64> =
self.received_packets.iter().map(|p| p.sequence).collect();
received_sequences.windows(2).all(|w| w[0] <= w[1])
}
pub fn sent_count(&self) -> usize {
self.sent_packets.len()
}
pub fn received_count(&self) -> usize {
self.received_packets.len()
}
}
impl MockWebSocketFrame {
pub fn new(opcode: u8, payload: Vec<u8>, fin: bool) -> Self {
Self {
opcode,
payload,
mask: None,
fin,
}
}
pub fn apply_mask(&mut self, mask: [u8; 4]) {
self.mask = Some(mask);
for (i, byte) in self.payload.iter_mut().enumerate() {
*byte ^= mask[i % 4];
}
}
pub fn remove_mask(&mut self) {
if let Some(mask) = self.mask {
for (i, byte) in self.payload.iter_mut().enumerate() {
*byte ^= mask[i % 4];
}
self.mask = None;
}
}
pub fn mask_roundtrip_preserves_payload(&self, original_payload: &[u8]) -> bool {
self.mask.is_none() && self.payload == original_payload
}
}
impl MockDnsQuery {
pub fn lookup(&self) -> MockDnsResponse {
let answers = match self.domain.as_str() {
"example.com" => vec!["93.184.216.34".to_string()],
"localhost" => vec!["127.0.0.1".to_string()],
_ => vec![],
};
let status = if answers.is_empty() {
DnsStatus::NotFound
} else {
DnsStatus::Success
};
MockDnsResponse {
query_id: self.query_id,
answers,
status,
}
}
}
impl MockTlsHandshake {
pub fn new(version: TlsVersion, cipher_suite: &str) -> Self {
Self {
version,
cipher_suite: cipher_suite.to_string(),
client_random: [0u8; 32],
server_random: [0u8; 32],
completed: false,
}
}
pub fn complete(&mut self, client_random: [u8; 32], server_random: [u8; 32]) -> bool {
if !self.completed {
self.client_random = client_random;
self.server_random = server_random;
self.completed = true;
true
} else {
false
}
}
pub fn security_properties_maintained(&self) -> bool {
self.completed
&& self.client_random != [0u8; 32]
&& self.server_random != [0u8; 32]
&& !self.cipher_suite.is_empty()
}
}
impl MockHttpMessage {
pub fn new(method: &str, path: &str, version: HttpVersion) -> Self {
Self {
method: method.to_string(),
path: path.to_string(),
headers: Vec::new(),
body: Vec::new(),
version,
}
}
pub fn add_header(&mut self, name: &str, value: &str) {
self.headers.push((name.to_string(), value.to_string()));
}
pub fn set_body(&mut self, body: Vec<u8>) {
self.body = body;
}
pub fn encode_h1(&self) -> Vec<u8> {
let mut encoded = Vec::new();
encoded.extend(format!("{} {} HTTP/1.1\r\n", self.method, self.path).bytes());
for (name, value) in &self.headers {
encoded.extend(format!("{}: {}\r\n", name, value).bytes());
}
encoded.extend(b"\r\n");
encoded.extend(&self.body);
encoded
}
pub fn decode_h1(data: &[u8]) -> Option<Self> {
let text = String::from_utf8_lossy(data);
let lines: Vec<&str> = text.lines().collect();
if lines.is_empty() {
return None;
}
let parts: Vec<&str> = lines[0].split(' ').collect();
if parts.len() < 3 {
return None;
}
let mut message = Self::new(parts[0], parts[1], HttpVersion::Http1_1);
let mut i = 1;
while i < lines.len() && !lines[i].is_empty() {
if let Some(colon_pos) = lines[i].find(':') {
let name = lines[i][..colon_pos].trim();
let value = lines[i][colon_pos + 1..].trim();
message.add_header(name, value);
}
i += 1;
}
Some(message)
}
}
impl HpackRoundTripContext {
pub fn new(max_size: usize) -> Self {
Self {
encoder: HpackEncoder::with_max_size(max_size),
decoder: HpackDecoder::with_max_size(max_size),
}
}
pub fn encode_decode(&mut self, headers: &[HpackHeader]) -> Result<Vec<HpackHeader>, String> {
let mut encoded = BytesMut::new();
self.encoder.encode(headers, &mut encoded);
if encoded.is_empty() {
return Err("production HPACK encoder produced an empty header block".to_string());
}
let mut bytes = encoded.freeze();
self.decoder
.decode(&mut bytes)
.map_err(|err| format!("production HPACK decoder rejected encoded block: {err}"))
}
}
impl MockH2Frame {
pub fn new(frame_type: u8, stream_id: u32) -> Self {
Self {
frame_type,
stream_id,
flags: 0,
payload: Vec::new(),
}
}
pub fn set_payload(&mut self, payload: Vec<u8>) {
self.payload = payload;
}
pub fn serialize(&self) -> Vec<u8> {
let mut data = Vec::new();
data.extend(&(self.payload.len() as u32).to_be_bytes()[1..4]); data.push(self.frame_type);
data.push(self.flags);
data.extend(&self.stream_id.to_be_bytes());
data.extend(&self.payload);
data
}
pub fn deserialize(data: &[u8]) -> Option<Self> {
if data.len() < 9 {
return None;
}
let length = u32::from_be_bytes([0, data[0], data[1], data[2]]) as usize;
let frame_type = data[3];
let flags = data[4];
let stream_id = u32::from_be_bytes([data[5], data[6], data[7], data[8]]);
if data.len() < 9 + length {
return None;
}
let payload = data[9..9 + length].to_vec();
Some(Self {
frame_type,
stream_id,
flags,
payload,
})
}
}
impl MockH3Frame {
pub fn new(frame_type: u64, stream_id: u64) -> Self {
Self {
frame_type,
payload: Vec::new(),
stream_id,
}
}
pub fn set_payload(&mut self, payload: Vec<u8>) {
self.payload = payload;
}
pub fn serialize(&self) -> Vec<u8> {
let mut data = Vec::new();
data.extend(&self.frame_type.to_be_bytes());
data.extend(&(self.payload.len() as u64).to_be_bytes());
data.extend(&self.stream_id.to_be_bytes());
data.extend(&self.payload);
data
}
pub fn deserialize(data: &[u8]) -> Option<Self> {
if data.len() < 24 {
return None;
}
let frame_type = u64::from_be_bytes([
data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
]);
let length = u64::from_be_bytes([
data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
]) as usize;
let stream_id = u64::from_be_bytes([
data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
]);
if data.len() < 24 + length {
return None;
}
let payload = data[24..24 + length].to_vec();
Some(Self {
frame_type,
payload,
stream_id,
})
}
}
#[test]
fn test_mr_tcp_connect_close_round_trip() {
proptest!(|(
local_addr: String,
remote_addr: String,
cycles in 1u32..=5
)| {
for _ in 0..cycles {
let mut conn = MockTcpConnection::new(&local_addr, &remote_addr);
prop_assert!(conn.is_closed(), "Connection should start closed");
let connect_result = conn.connect();
prop_assert!(connect_result, "Connect should succeed from closed state");
let establish_result = conn.establish();
prop_assert!(establish_result, "Establish should succeed after connect");
let first_close = conn.close();
prop_assert!(first_close, "First close should succeed");
let second_close = conn.close();
prop_assert!(second_close, "Second close should complete transition");
prop_assert!(conn.can_reestablish(),
"Connection should be reestablishable after close cycle");
}
});
}
#[test]
fn test_mr_udp_send_recv_invariants() {
proptest!(|(
local_addr: String,
packets_data: Vec<(u64, Vec<u8>, String)>,
delivery_rate in 0.0f64..=1.0f64
)| {
let mut socket = MockUdpSocket::new(&local_addr);
let mut expected_sequences = Vec::new();
for (i, (sequence, data, dest)) in packets_data.iter().enumerate() {
let packet = MockUdpPacket {
sequence: *sequence,
data: data.clone(),
timestamp: i as u64,
dest_addr: dest.clone(),
};
socket.send(packet.clone());
expected_sequences.push(*sequence);
if (i as f64 / packets_data.len() as f64) < delivery_rate {
socket.receive(packet);
}
}
prop_assert!(
socket.received_count() <= socket.sent_count(),
"Received count {} exceeds sent count {}",
socket.received_count(), socket.sent_count()
);
if !socket.received_packets.is_empty() {
prop_assert!(
socket.ordering_preserved(),
"UDP packet ordering not preserved in received packets"
);
}
});
}
#[test]
fn test_mr_websocket_frame_mask_reversibility() {
proptest!(|(
opcode in 0u8..=15u8,
payload: Vec<u8>,
fin: bool,
mask_key: [u8; 4]
)| {
let original_payload = payload.clone();
let mut frame = MockWebSocketFrame::new(opcode, payload, fin);
frame.apply_mask(mask_key);
let masked_payload = frame.payload.clone();
if !original_payload.is_empty() && mask_key != [0, 0, 0, 0] {
let payloads_differ = masked_payload != original_payload;
prop_assert!(payloads_differ || original_payload.iter().enumerate().all(|(i, &b)| {
b == mask_key[i % 4]
}), "Masking should change payload (unless XOR cancellation)");
}
frame.remove_mask();
prop_assert!(
frame.mask_roundtrip_preserves_payload(&original_payload),
"WebSocket mask round-trip failed: original={:?}, final={:?}",
original_payload, frame.payload
);
prop_assert_eq!(frame.opcode, opcode, "Opcode should be preserved");
prop_assert_eq!(frame.fin, fin, "FIN flag should be preserved");
prop_assert_eq!(frame.mask, None, "Mask should be removed");
});
}
#[test]
fn test_mr_dns_lookup_determinism() {
proptest!(|(
domain: String,
record_type_idx in 0usize..5,
query_id1: u16,
query_id2: u16
)| {
let record_types = [
DnsRecordType::A, DnsRecordType::Aaaa, DnsRecordType::Mx,
DnsRecordType::Txt, DnsRecordType::Cname
];
let record_type = record_types[record_type_idx].clone();
let query1 = MockDnsQuery {
domain: domain.clone(),
record_type: record_type.clone(),
query_id: query_id1,
};
let query2 = MockDnsQuery {
domain,
record_type,
query_id: query_id2,
};
let response1 = query1.lookup();
let response2 = query2.lookup();
prop_assert_eq!(
response1.answers, response2.answers,
"DNS lookup determinism violated: same query produced different answers"
);
prop_assert_eq!(
response1.status, response2.status,
"DNS lookup determinism violated: same query produced different status"
);
prop_assert_eq!(response1.query_id, query_id1, "Query ID 1 should be preserved");
prop_assert_eq!(response2.query_id, query_id2, "Query ID 2 should be preserved");
});
}
#[test]
fn test_mr_tls_handshake_completion_property() {
proptest!(|(
version_idx in 0usize..2,
cipher_suite: String,
client_random: [u8; 32],
server_random: [u8; 32]
)| {
let versions = [TlsVersion::Tls12, TlsVersion::Tls13];
let version = versions[version_idx].clone();
let mut handshake = MockTlsHandshake::new(version.clone(), &cipher_suite);
if cipher_suite.is_empty() || client_random == [0u8; 32] || server_random == [0u8; 32] {
prop_assert!(
!handshake.security_properties_maintained(),
"Incomplete handshake should not maintain security properties"
);
}
let completion_result = handshake.complete(client_random, server_random);
if completion_result && !cipher_suite.is_empty() &&
client_random != [0u8; 32] && server_random != [0u8; 32] {
prop_assert!(
handshake.security_properties_maintained(),
"Completed TLS handshake should maintain security properties"
);
}
let second_completion = handshake.complete(client_random, server_random);
prop_assert!(
!second_completion,
"TLS handshake completion should be idempotent"
);
});
}
#[test]
fn test_mr_h1_codec_encode_decode_round_trip() {
proptest!(|(
method: String,
path: String,
headers: Vec<(String, String)>,
body: Vec<u8>
)| {
let mut original_message = MockHttpMessage::new(&method, &path, HttpVersion::Http1_1);
for (name, value) in &headers {
if !name.is_empty() && !name.contains(':') && !name.contains('\r') && !name.contains('\n') &&
!value.contains('\r') && !value.contains('\n') {
original_message.add_header(name, value);
}
}
original_message.set_body(body);
let encoded = original_message.encode_h1();
if let Some(decoded_message) = MockHttpMessage::decode_h1(&encoded) {
prop_assert_eq!(
decoded_message.method, original_message.method,
"HTTP method should be preserved in H1 codec round-trip"
);
prop_assert_eq!(
decoded_message.path, original_message.path,
"HTTP path should be preserved in H1 codec round-trip"
);
prop_assert_eq!(
decoded_message.headers.len(), original_message.headers.len(),
"HTTP headers count should be preserved in H1 codec round-trip"
);
for (orig_name, orig_value) in &original_message.headers {
let found = decoded_message.headers.iter()
.any(|(dec_name, dec_value)| dec_name == orig_name && dec_value == orig_value);
prop_assert!(found,
"Header {}:{} should be preserved in H1 codec round-trip", orig_name, orig_value);
}
}
});
}
#[test]
fn test_mr_h2_hpack_encode_decode_round_trip() {
proptest!(|(
headers: Vec<(String, String)>,
max_table_size in 1024usize..=8192usize
)| {
let valid_headers: Vec<HpackHeader> = headers
.iter()
.filter(|(name, value)| is_valid_hpack_header(name, value))
.take(16)
.map(|(name, value)| HpackHeader::new(name.clone(), value.clone()))
.collect();
if !valid_headers.is_empty() {
let mut hpack_context = HpackRoundTripContext::new(max_table_size);
let decoded_headers = hpack_context
.encode_decode(&valid_headers)
.map_err(TestCaseError::fail)?;
prop_assert_eq!(
decoded_headers,
valid_headers,
"production HPACK encode/decode should preserve valid headers"
);
}
});
}
fn is_valid_hpack_header(name: &str, value: &str) -> bool {
!name.is_empty()
&& name.len() <= 64
&& value.len() <= 256
&& name
.bytes()
.all(|byte| byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-')
&& !value.bytes().any(|byte| matches!(byte, 0 | b'\r' | b'\n'))
}
#[test]
fn test_mr_h2_frame_serialization_round_trip() {
proptest!(|(
frame_type in 0u8..=10u8,
stream_id: u32,
flags in 0u8..=255u8,
payload: Vec<u8>
)| {
let payload = if payload.len() > 1024 { payload[..1024].to_vec() } else { payload };
let mut original_frame = MockH2Frame::new(frame_type, stream_id);
original_frame.flags = flags;
original_frame.set_payload(payload.clone());
let serialized = original_frame.serialize();
if let Some(deserialized_frame) = MockH2Frame::deserialize(&serialized) {
prop_assert_eq!(
deserialized_frame.frame_type, original_frame.frame_type,
"H2 frame type should be preserved"
);
prop_assert_eq!(
deserialized_frame.stream_id, original_frame.stream_id,
"H2 stream ID should be preserved"
);
prop_assert_eq!(
deserialized_frame.flags, original_frame.flags,
"H2 frame flags should be preserved"
);
prop_assert_eq!(
deserialized_frame.payload, original_frame.payload,
"H2 frame payload should be preserved"
);
} else {
prop_assert!(serialized.len() >= 9, "H2 frame serialization should produce at least 9 bytes");
}
});
}
#[test]
fn test_mr_h3_frame_round_trip() {
proptest!(|(
frame_type: u64,
stream_id: u64,
payload: Vec<u8>
)| {
let payload = if payload.len() > 512 { payload[..512].to_vec() } else { payload };
let mut original_frame = MockH3Frame::new(frame_type, stream_id);
original_frame.set_payload(payload.clone());
let serialized = original_frame.serialize();
if let Some(deserialized_frame) = MockH3Frame::deserialize(&serialized) {
prop_assert_eq!(
deserialized_frame.frame_type, original_frame.frame_type,
"H3 frame type should be preserved: {} != {}",
deserialized_frame.frame_type, original_frame.frame_type
);
prop_assert_eq!(
deserialized_frame.stream_id, original_frame.stream_id,
"H3 stream ID should be preserved: {} != {}",
deserialized_frame.stream_id, original_frame.stream_id
);
prop_assert_eq!(
deserialized_frame.payload.clone(), original_frame.payload.clone(),
"H3 frame payload should be preserved"
);
} else {
prop_assert!(serialized.len() >= 24, "H3 frame serialization should produce at least 24 bytes");
}
});
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mock_implementations() {
let mut conn = MockTcpConnection::new("127.0.0.1:8080", "127.0.0.1:9090");
assert!(conn.is_closed());
assert!(conn.connect());
assert!(conn.establish());
assert!(conn.close());
let mut frame = MockWebSocketFrame::new(1, vec![1, 2, 3, 4], true);
let original = frame.payload.clone();
frame.apply_mask([0xAA, 0xBB, 0xCC, 0xDD]);
frame.remove_mask();
assert_eq!(frame.payload, original);
let query = MockDnsQuery {
domain: "example.com".to_string(),
record_type: DnsRecordType::A,
query_id: 12345,
};
let response = query.lookup();
assert_eq!(response.query_id, 12345);
assert!(!response.answers.is_empty());
}
}