use proptest::prelude::*;
use qail_pg::protocol::wire::{BackendMessage, FrontendMessage, PROTOCOL_VERSION_3_2};
/// Strategy for identifier-like strings of 0..=64 ASCII letters, digits and underscores.
///
/// The character class contains no NUL byte, so every generated value can be embedded as a
/// NUL-terminated C string in a wire message. A `&str` regex pattern is itself a proptest
/// `Strategy<Value = String>`, so no mapping step is needed.
fn arb_pg_string() -> impl Strategy<Value = String> {
    "[a-zA-Z0-9_]{0,64}"
}
fn arb_param() -> impl Strategy<Value = Option<Vec<u8>>> {
prop_oneof![
3 => Just(None),
7 => proptest::collection::vec(any::<u8>(), 0..128).prop_map(Some),
]
}
fn arb_frontend_message() -> impl Strategy<Value = FrontendMessage> {
prop_oneof![
(arb_pg_string(), arb_pg_string()).prop_map(|(user, database)| FrontendMessage::Startup {
user,
database,
protocol_version: PROTOCOL_VERSION_3_2,
startup_params: Vec::new(),
}),
arb_pg_string().prop_map(FrontendMessage::Query),
(
arb_pg_string(),
arb_pg_string(),
proptest::collection::vec(any::<u32>(), 0..8),
)
.prop_map(|(name, query, param_types)| FrontendMessage::Parse {
name,
query,
param_types,
}),
(
arb_pg_string(),
arb_pg_string(),
proptest::collection::vec(arb_param(), 0..8),
)
.prop_map(|(portal, statement, params)| FrontendMessage::Bind {
portal,
statement,
params,
}),
(arb_pg_string(), 0i32..=i32::MAX)
.prop_map(|(portal, max_rows)| FrontendMessage::Execute { portal, max_rows }),
Just(FrontendMessage::Sync),
Just(FrontendMessage::Terminate),
arb_pg_string().prop_map(FrontendMessage::PasswordMessage),
(
arb_pg_string(),
proptest::collection::vec(any::<u8>(), 0..64)
)
.prop_map(|(mechanism, data)| FrontendMessage::SASLInitialResponse { mechanism, data }),
proptest::collection::vec(any::<u8>(), 0..64).prop_map(FrontendMessage::SASLResponse),
]
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
/// Every generated frontend message must encode to a well-formed wire frame.
///
/// * `Startup` has no type byte. Its leading Int32 length must equal the total size, and the
///   next Int32 must be `PROTOCOL_VERSION_3_2`.
/// * `Sync` / `Terminate` are exactly 5 bytes, tagged `'S'` / `'X'`, with declared length 4.
/// * Every other variant must carry the type byte belonging to *that* variant. Its declared
///   length must account for every byte after the type byte.
#[test]
fn frontend_encode_valid_structure(msg in arb_frontend_message()) {
    let bytes = msg
        .encode_checked()
        .expect("safe generated frontend message must encode");
    match &msg {
        FrontendMessage::Startup { .. } => {
            prop_assert!(bytes.len() >= 8, "Startup must be ≥8 bytes");
            // Decode lengths as unsigned so a corrupt header cannot wrap into a huge usize.
            let declared_len =
                u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
            prop_assert_eq!(bytes.len(), declared_len, "Startup length mismatch");
            let version = i32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
            prop_assert_eq!(version, PROTOCOL_VERSION_3_2, "Must use protocol version 3.2");
        }
        FrontendMessage::Sync | FrontendMessage::Terminate => {
            prop_assert_eq!(bytes.len(), 5, "Sync/Terminate must be exactly 5 bytes");
            let expected_tag = if matches!(msg, FrontendMessage::Sync) { b'S' } else { b'X' };
            prop_assert_eq!(
                bytes[0],
                expected_tag,
                "Sync must be tagged 'S' and Terminate 'X'"
            );
            let declared_len =
                u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
            prop_assert_eq!(declared_len, 4, "Fixed messages have length = 4");
        }
        other => {
            prop_assert!(bytes.len() >= 5, "Messages must be ≥5 bytes");
            // Pin the tag to the specific variant. A set-membership check would accept a
            // message encoded under a sibling variant's tag (e.g. Query as 'P').
            let expected_tag = match other {
                FrontendMessage::Query(_) => Some(b'Q'),
                FrontendMessage::Parse { .. } => Some(b'P'),
                FrontendMessage::Bind { .. } => Some(b'B'),
                FrontendMessage::Execute { .. } => Some(b'E'),
                FrontendMessage::PasswordMessage(_)
                | FrontendMessage::SASLInitialResponse { .. }
                | FrontendMessage::SASLResponse(_) => Some(b'p'),
                // Any variant not produced by arb_frontend_message has no expectation here.
                _ => None,
            };
            prop_assert_eq!(
                Some(bytes[0]),
                expected_tag,
                "Type byte {:?} does not match variant {:?}",
                bytes[0] as char,
                other
            );
            let declared_len =
                u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
            prop_assert_eq!(
                bytes.len(),
                declared_len + 1,
                "Total length must be type_byte(1) + declared_length({})",
                declared_len
            );
        }
    }
}
/// A simple-query message carries its SQL as a C string, so the encoded frame must end in
/// a NUL byte.
#[test]
fn query_encode_null_terminated(sql in arb_pg_string()) {
    let query_msg = FrontendMessage::Query(sql);
    let encoded = query_msg.encode_checked().expect("query must encode");
    let final_byte = encoded.last().copied().unwrap();
    prop_assert_eq!(final_byte, 0u8, "Query must be null-terminated");
}
/// A Parse frame must contain two NUL-terminated strings (statement name, then query
/// text), immediately followed by an Int16 count equal to the number of parameter OIDs.
///
/// Both strings come from `arb_pg_string`, which never emits NUL. The first two NUL bytes
/// in the payload are therefore exactly the two string terminators.
#[test]
fn parse_encode_two_null_terminated_strings(
    name in arb_pg_string(),
    query in arb_pg_string(),
    param_types in proptest::collection::vec(any::<u32>(), 0..4),
) {
    let expected_count = param_types.len();
    let parse_msg = FrontendMessage::Parse {
        name,
        query,
        param_types,
    };
    let encoded = parse_msg.encode_checked().expect("parse must encode");
    // Skip the 1-byte type tag and the 4-byte length header.
    let body = &encoded[5..];
    let nul_offsets: Vec<usize> = body
        .iter()
        .enumerate()
        .filter_map(|(offset, &byte)| if byte == 0 { Some(offset) } else { None })
        .collect();
    prop_assert!(
        nul_offsets.len() >= 2,
        "Parse must have ≥2 null terminators, found {}",
        nul_offsets.len()
    );
    let tail = &body[nul_offsets[1] + 1..];
    prop_assert!(tail.len() >= 2, "Must have param count after strings");
    let declared_count = i16::from_be_bytes([tail[0], tail[1]]) as usize;
    prop_assert_eq!(declared_count, expected_count, "Param count mismatch");
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
/// Decoding arbitrary bytes (up to 255 of them) must never panic. Only the absence of a
/// panic is checked; whether decoding succeeds or fails is irrelevant here.
#[test]
fn backend_decode_never_panics(data in proptest::collection::vec(any::<u8>(), 0..256)) {
    let _outcome = BackendMessage::decode(&data);
}
/// Any buffer shorter than 5 bytes (too short for a type byte plus an Int32 length) must be
/// rejected with an error.
#[test]
fn backend_decode_rejects_short_buffers(data in proptest::collection::vec(any::<u8>(), 0..5)) {
    prop_assert!(
        BackendMessage::decode(&data).is_err(),
        "Buffer < 5 bytes must be rejected"
    );
}
/// A frame whose declared length exceeds the bytes actually present, by 10..1000 bytes,
/// must be reported as an error rather than decoded.
#[test]
fn backend_decode_handles_truncated_messages(
    msg_type in any::<u8>(),
    extra_len in 10u32..1000u32,
    payload in proptest::collection::vec(any::<u8>(), 0..8),
) {
    // The length field counts itself (4 bytes) plus the payload; pad it by `extra_len`
    // so the header promises more data than the buffer holds.
    let claimed_len = (payload.len() as u32) + 4 + extra_len;
    let frame: Vec<u8> = std::iter::once(msg_type)
        .chain((claimed_len as i32).to_be_bytes().iter().copied())
        .chain(payload.iter().copied())
        .collect();
    prop_assert!(
        BackendMessage::decode(&frame).is_err(),
        "Truncated messages must return Err"
    );
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(200))]
/// `PgEncoder::try_encode_query_string` must produce a single 'Q' frame. Its declared
/// length must cover every byte after the tag, and it must end with a NUL terminator.
#[test]
fn pg_encoder_query_valid_wire(sql in arb_pg_string()) {
    use bytes::BytesMut;
    use qail_pg::protocol::encoder::PgEncoder;
    let encoded: BytesMut =
        PgEncoder::try_encode_query_string(&sql).expect("safe sql must encode");
    let wire: &[u8] = &encoded;
    prop_assert_eq!(wire[0], b'Q', "Query must start with 'Q'");
    let length_header = [wire[1], wire[2], wire[3], wire[4]];
    let declared_len = i32::from_be_bytes(length_header) as usize;
    prop_assert_eq!(wire.len(), declared_len + 1, "Length must match actual size");
    prop_assert_eq!(wire[wire.len() - 1], 0u8, "Must be null-terminated");
}
/// `PgEncoder::encode_extended_query` must emit exactly Parse, Bind, Execute and Sync
/// frames, in that order. Each frame must have a sane length and none may run past the
/// end of the buffer.
#[test]
fn pg_encoder_extended_pipeline_valid(
    sql in arb_pg_string(),
    params in proptest::collection::vec(arb_param(), 0..4),
) {
    use qail_pg::protocol::encoder::PgEncoder;
    let result = PgEncoder::encode_extended_query(&sql, &params);
    match result {
        Ok(buf) => {
            let bytes = &buf[..];
            prop_assert!(!bytes.is_empty(), "Extended query buffer must not be empty");
            prop_assert_eq!(bytes[0], b'P', "Extended query must start with Parse ('P')");
            // Walk frame by frame. Each frame is 1 type byte plus an Int32 length that
            // counts itself and the body but not the type byte.
            let mut pos = 0;
            let mut msg_types = Vec::new();
            while pos < bytes.len() {
                prop_assert!(pos + 5 <= bytes.len(), "Truncated message at pos {}", pos);
                let msg_type = bytes[pos];
                let raw_len = i32::from_be_bytes([
                    bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], bytes[pos + 4],
                ]);
                // Validate the signed value before widening. A negative length would
                // otherwise wrap to a huge usize and overflow `pos + 1 + len` below.
                prop_assert!(raw_len >= 4, "Message length must be ≥4");
                let len = raw_len as usize;
                prop_assert!(
                    pos + 1 + len <= bytes.len(),
                    "Message at pos {} overflows buffer",
                    pos
                );
                msg_types.push(msg_type as char);
                pos += 1 + len;
            }
            prop_assert_eq!(
                msg_types,
                vec!['P', 'B', 'E', 'S'],
                "Extended query must be Parse+Bind+Execute+Sync"
            );
        }
        // NOTE(review): encoder errors are tolerated, so this property passes vacuously
        // whenever `encode_extended_query` rejects an input. Consider asserting `Ok` if
        // every generated (NUL-free, small) input is expected to encode.
        Err(_e) => {}
    }
}
}