use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use fsqlite_types::ObjectId;
use fsqlite_types::encoding::{
append_u16_be, append_u32_be, append_u32_le, append_u64_be, append_u64_le, read_u16_be,
read_u32_be, read_u32_le, read_u64_be, read_u64_le,
};
use fsqlite_types::sync_primitives::Mutex;
/// Smallest legal value of a frame's big-endian length field: the 12 header
/// bytes that follow the length prefix (version, kind, request id) with an
/// empty payload.
pub const FRAME_MIN_LEN_BE: u32 = 12;
/// Largest legal value of a frame's length field (4 MiB).
pub const FRAME_MAX_LEN_BE: u32 = 4 * 1024 * 1024;
/// Protocol version written by `Frame::encode`; `Frame::decode` rejects any other value.
pub const PROTOCOL_VERSION: u16 = 1;
/// Permit limit. NOTE(review): presumably the value handed to `PermitManager::new`
/// by callers; it is not referenced in this part of the file.
pub const MAX_OUTSTANDING_PERMITS: usize = 16;
// `PermitManager::consume` purges consumed entries once their number exceeds
// `max_permits` multiplied by this factor.
const CONSUMED_PERMIT_GC_MULTIPLIER: usize = 8;
/// Upper bound, in bytes, on an encoded write-set summary (4 bytes per page number).
pub const WIRE_WRITE_SET_MAX_BYTES: usize = 1024 * 1024;
/// Upper bound on the combined number of read-witness, write-witness, edge and
/// merge references carried by one submit payload.
pub const WIRE_WITNESS_EDGE_MAX: usize = 65_536;
// Encoded frame header size: 4-byte length prefix plus the 12 bytes counted by
// `FRAME_MIN_LEN_BE`.
const FRAME_HEADER_WIRE_BYTES: usize = 16;
// Encoded size of a `WireTxnToken`: 12 meaningful bytes followed by 4 zero bytes.
const WIRE_TXN_TOKEN_BYTES: usize = 16;
/// Message kinds carried in the `kind` field of a frame header.
///
/// The numeric wire code of each variant is defined by `MessageKind::to_u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageKind {
/// Wire code 1.
Reserve,
/// Wire code 2.
SubmitNativePublish,
/// Wire code 3.
SubmitWalCommit,
/// Wire code 4.
RowidReserve,
/// Wire code 5.
Response,
/// Wire code 6.
Ping,
/// Wire code 7.
Pong,
}
impl MessageKind {
    /// Returns the numeric code used for this kind in a frame header.
    #[must_use]
    pub const fn to_u16(self) -> u16 {
        let code: u16 = match self {
            Self::Reserve => 1,
            Self::SubmitNativePublish => 2,
            Self::SubmitWalCommit => 3,
            Self::RowidReserve => 4,
            Self::Response => 5,
            Self::Ping => 6,
            Self::Pong => 7,
        };
        code
    }

    /// Maps a wire code back to its kind; codes outside `1..=7` yield `None`.
    #[must_use]
    pub const fn from_u16(v: u16) -> Option<Self> {
        let kind = match v {
            1 => Self::Reserve,
            2 => Self::SubmitNativePublish,
            3 => Self::SubmitWalCommit,
            4 => Self::RowidReserve,
            5 => Self::Response,
            6 => Self::Ping,
            7 => Self::Pong,
            // Anything else is not a kind this protocol version knows.
            _ => return None,
        };
        Some(kind)
    }
}
impl fmt::Display for MessageKind {
    /// Writes the upper-case protocol label of the kind (for example `RESERVE`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Self::Reserve => "RESERVE",
            Self::SubmitNativePublish => "SUBMIT_NATIVE_PUBLISH",
            Self::SubmitWalCommit => "SUBMIT_WAL_COMMIT",
            Self::RowidReserve => "ROWID_RESERVE",
            Self::Response => "RESPONSE",
            Self::Ping => "PING",
            Self::Pong => "PONG",
        })
    }
}
/// Reasons `Frame::decode` can reject a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameError {
/// The buffer is shorter than the fixed frame header.
TooShort,
/// The length field is below `FRAME_MIN_LEN_BE`; carries the offending value.
LenTooSmall(u32),
/// The length field is above `FRAME_MAX_LEN_BE`; carries the offending value.
LenTooLarge(u32),
/// The version field differs from `PROTOCOL_VERSION`; carries the value seen.
UnknownVersion(u16),
/// The kind field is not a known `MessageKind` code; carries the value seen.
UnknownKind(u16),
/// Fewer payload bytes follow the header than the length field announced.
PayloadTruncated { expected: u32, actual: usize },
}
impl fmt::Display for FrameError {
    /// Renders a one-line, human-readable description of the decode failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload is `Copy`, so bind by value.
        match *self {
            Self::PayloadTruncated { expected, actual } => write!(
                f,
                "payload truncated: expected {expected} bytes, got {actual}"
            ),
            Self::UnknownKind(v) => write!(f, "unknown message kind {v}"),
            Self::UnknownVersion(v) => write!(f, "unknown protocol version {v}"),
            Self::LenTooLarge(v) => write!(f, "len_be {v} exceeds cap {FRAME_MAX_LEN_BE}"),
            Self::LenTooSmall(v) => write!(f, "len_be {v} below minimum {FRAME_MIN_LEN_BE}"),
            Self::TooShort => f.write_str("frame buffer too short for header"),
        }
    }
}
// Marker impl so `FrameError` can be used wherever a `std::error::Error` is expected.
impl std::error::Error for FrameError {}
/// One protocol frame: a header (kind and request id) plus an opaque payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
/// Message kind stored in the header.
pub kind: MessageKind,
/// Request identifier stored in the header as a big-endian u64.
pub request_id: u64,
/// Payload bytes that follow the header; not interpreted by the frame codec.
pub payload: Vec<u8>,
}
impl Frame {
    /// Serialises the frame as `len_be (u32) | version (u16) | kind (u16) |
    /// request_id (u64) | payload`, with all header integers big-endian.
    ///
    /// `len_be` counts every byte after the length field itself.
    ///
    /// # Panics
    ///
    /// Panics when the payload is longer than
    /// `FRAME_MAX_LEN_BE - FRAME_MIN_LEN_BE` bytes.
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let max_payload_len = (FRAME_MAX_LEN_BE - FRAME_MIN_LEN_BE) as usize;
        let payload_len = self.payload.len();
        assert!(
            payload_len <= max_payload_len,
            "frame payload length {payload_len} exceeds max {max_payload_len}"
        );
        let payload_len_u32 = u32::try_from(payload_len).expect("payload_len fits u32");

        let mut wire = Vec::with_capacity(FRAME_HEADER_WIRE_BYTES + payload_len);
        append_u32_be(&mut wire, FRAME_MIN_LEN_BE + payload_len_u32);
        append_u16_be(&mut wire, PROTOCOL_VERSION);
        append_u16_be(&mut wire, self.kind.to_u16());
        append_u64_be(&mut wire, self.request_id);
        wire.extend_from_slice(&self.payload);
        wire
    }

    /// Parses one frame from the start of `buf`.
    ///
    /// Bytes after the announced payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`FrameError`] when the header is incomplete, the length field
    /// is out of range, the version or kind is unknown, or fewer payload bytes
    /// are present than the length field announces. Checks run in that order.
    pub fn decode(buf: &[u8]) -> Result<Self, FrameError> {
        if buf.len() < FRAME_HEADER_WIRE_BYTES {
            return Err(FrameError::TooShort);
        }
        let (header, body) = buf.split_at(FRAME_HEADER_WIRE_BYTES);

        let len_be = read_u32_be(&header[..4]).ok_or(FrameError::TooShort)?;
        match len_be {
            v if v < FRAME_MIN_LEN_BE => return Err(FrameError::LenTooSmall(v)),
            v if v > FRAME_MAX_LEN_BE => return Err(FrameError::LenTooLarge(v)),
            _ => {}
        }

        let version = read_u16_be(&header[4..6]).ok_or(FrameError::TooShort)?;
        if version != PROTOCOL_VERSION {
            return Err(FrameError::UnknownVersion(version));
        }

        let kind_raw = read_u16_be(&header[6..8]).ok_or(FrameError::TooShort)?;
        let Some(kind) = MessageKind::from_u16(kind_raw) else {
            return Err(FrameError::UnknownKind(kind_raw));
        };
        let request_id = read_u64_be(&header[8..16]).ok_or(FrameError::TooShort)?;

        // The length field includes the 12 header bytes that follow it.
        let expected = len_be - FRAME_MIN_LEN_BE;
        match body.get(..expected as usize) {
            Some(payload) => Ok(Self {
                kind,
                request_id,
                payload: payload.to_vec(),
            }),
            None => Err(FrameError::PayloadTruncated {
                expected,
                actual: body.len(),
            }),
        }
    }
}
/// Transaction identity as carried on the wire: an id plus an epoch.
///
/// Encoded in `WIRE_TXN_TOKEN_BYTES` bytes (see `to_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WireTxnToken {
/// Transaction id, encoded little-endian in bytes 0..8.
pub txn_id: u64,
/// Transaction epoch, encoded little-endian in bytes 8..12.
pub txn_epoch: u32,
}
impl WireTxnToken {
    /// Encodes the token as `txn_id (u64 LE) | txn_epoch (u32 LE) | 4 zero bytes`.
    #[must_use]
    pub fn to_bytes(self) -> [u8; WIRE_TXN_TOKEN_BYTES] {
        let mut wire = [0u8; WIRE_TXN_TOKEN_BYTES];
        let (id_part, tail) = wire.split_at_mut(8);
        id_part.copy_from_slice(&self.txn_id.to_le_bytes());
        // The last four bytes of `tail` stay zero as padding.
        tail[..4].copy_from_slice(&self.txn_epoch.to_le_bytes());
        wire
    }

    /// Decodes a token from the first 12 bytes of `src`.
    ///
    /// Returns `None` when `src` holds fewer than 12 bytes. The four padding
    /// bytes are neither required nor inspected.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let id_part = src.get(..8)?;
        let epoch_part = src.get(8..12)?;
        Some(Self {
            txn_id: read_u64_le(id_part)?,
            txn_epoch: read_u32_le(epoch_part)?,
        })
    }

    /// Returns `(txn_id, txn_epoch)`, the pair used as the idempotency key.
    #[must_use]
    pub const fn idempotency_key(self) -> (u64, u32) {
        let Self { txn_id, txn_epoch } = self;
        (txn_id, txn_epoch)
    }
}
/// Payload of a reserve request: a purpose byte and the requesting transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReservePayload {
/// Purpose code, encoded as the first byte; its values are not interpreted here.
pub purpose: u8,
/// Transaction token, encoded at byte offset 8.
pub txn: WireTxnToken,
}
impl ReservePayload {
    /// Encoded size: purpose byte, 7 padding bytes, 16-byte transaction token.
    const WIRE_BYTES: usize = 24;

    /// Encodes the payload as `purpose | 7 zero bytes | txn token`.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut wire = Vec::with_capacity(Self::WIRE_BYTES);
        wire.extend_from_slice(&[self.purpose, 0, 0, 0, 0, 0, 0, 0]);
        wire.extend_from_slice(&self.txn.to_bytes());
        wire
    }

    /// Decodes a payload; returns `None` when `src` is shorter than 24 bytes.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let (&purpose, _) = src.split_first()?;
        if src.len() < Self::WIRE_BYTES {
            return None;
        }
        // Bytes 1..8 are padding and are skipped without inspection.
        WireTxnToken::from_bytes(&src[8..]).map(|txn| Self { purpose, txn })
    }
}
/// Reply to a reserve request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReserveResponse {
/// Success; carries a permit id (wire tag 0).
Ok { permit_id: u64 },
/// Busy; carries a retry delay in milliseconds (wire tag 1).
Busy { retry_after_ms: u32 },
/// Failure; carries an error code (wire tag 2).
Err { code: u32 },
}
impl ReserveResponse {
    const TAG_OK: u8 = 0;
    const TAG_BUSY: u8 = 1;
    const TAG_ERR: u8 = 2;

    /// Encodes the response as `tag | 7 zero bytes | body`.
    ///
    /// The body is a u64 permit id for `Ok`, a u32 delay plus a zero u32 for
    /// `Busy`, and a single u32 code for `Err`. All integers are little-endian.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let tag = match self {
            Self::Ok { .. } => Self::TAG_OK,
            Self::Busy { .. } => Self::TAG_BUSY,
            Self::Err { .. } => Self::TAG_ERR,
        };
        let mut wire = Vec::with_capacity(16);
        wire.extend_from_slice(&[tag, 0, 0, 0, 0, 0, 0, 0]);
        match *self {
            Self::Ok { permit_id } => append_u64_le(&mut wire, permit_id),
            Self::Busy { retry_after_ms } => {
                append_u32_le(&mut wire, retry_after_ms);
                // Explicit zero padding after the delay.
                append_u32_le(&mut wire, 0);
            }
            Self::Err { code } => append_u32_le(&mut wire, code),
        }
        wire
    }

    /// Decodes a response; returns `None` for an unknown tag or a buffer too
    /// short for the tagged body.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let (&tag, _) = src.split_first()?;
        // Both 32-bit bodies live at the same offset.
        let body_u32 = || src.get(8..12).and_then(|bytes| read_u32_le(bytes));
        match tag {
            Self::TAG_OK => src
                .get(8..16)
                .and_then(|bytes| read_u64_le(bytes))
                .map(|permit_id| Self::Ok { permit_id }),
            Self::TAG_BUSY => body_u32().map(|retry_after_ms| Self::Busy { retry_after_ms }),
            Self::TAG_ERR => body_u32().map(|code| Self::Err { code }),
            _ => None,
        }
    }
}
/// Payload of a rowid-reserve request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RowidReservePayload {
/// Requesting transaction, encoded in bytes 0..16.
pub txn: WireTxnToken,
/// Schema epoch, encoded little-endian in bytes 16..24.
pub schema_epoch: u64,
/// Table identifier, encoded little-endian in bytes 24..28.
pub table_id: u32,
/// Count field, encoded little-endian in bytes 28..32.
pub count: u32,
}
impl RowidReservePayload {
const WIRE_BYTES: usize = 32;
#[must_use]
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(Self::WIRE_BYTES);
buf.extend_from_slice(&self.txn.to_bytes());
append_u64_le(&mut buf, self.schema_epoch);
append_u32_le(&mut buf, self.table_id);
append_u32_le(&mut buf, self.count);
buf
}
#[must_use]
pub fn from_bytes(src: &[u8]) -> Option<Self> {
if src.len() < Self::WIRE_BYTES {
return None;
}
let txn = WireTxnToken::from_bytes(src)?;
let schema_epoch = read_u64_le(&src[16..24])?;
let table_id = read_u32_le(&src[24..28])?;
let count = read_u32_le(&src[28..32])?;
Some(Self {
txn,
schema_epoch,
table_id,
count,
})
}
}
/// Reply to a rowid-reserve request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RowidReserveResponse {
/// Success; carries a starting rowid and a count (wire tag 0).
Ok { start_rowid: u64, count: u32 },
/// Failure; carries an error code (wire tag 1).
Err { code: u32 },
}
impl RowidReserveResponse {
    const TAG_OK: u8 = 0;
    const TAG_ERR: u8 = 1;

    /// Encodes the response as `tag | 7 zero bytes | body`.
    ///
    /// The `Ok` body is `start_rowid (u64) | count (u32) | zero u32`; the `Err`
    /// body is a single u32 code. All integers are little-endian.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut wire = Vec::with_capacity(24);
        match *self {
            Self::Ok { start_rowid, count } => {
                wire.extend_from_slice(&[Self::TAG_OK, 0, 0, 0, 0, 0, 0, 0]);
                append_u64_le(&mut wire, start_rowid);
                append_u32_le(&mut wire, count);
                // Trailing zero padding.
                append_u32_le(&mut wire, 0);
            }
            Self::Err { code } => {
                wire.extend_from_slice(&[Self::TAG_ERR, 0, 0, 0, 0, 0, 0, 0]);
                append_u32_le(&mut wire, code);
            }
        }
        wire
    }

    /// Decodes a response; returns `None` for an unknown tag or a buffer too
    /// short for the tagged body. `Ok` requires all 24 bytes, padding included.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let (&tag, _) = src.split_first()?;
        if tag == Self::TAG_OK {
            let body = src.get(8..24)?;
            let start_rowid = read_u64_le(&body[..8])?;
            let count = read_u32_le(&body[8..12])?;
            // The padding word must be present; its value is ignored.
            let _padding = read_u32_le(&body[12..16])?;
            Some(Self::Ok { start_rowid, count })
        } else if tag == Self::TAG_ERR {
            let code = read_u32_le(src.get(8..12)?)?;
            Some(Self::Err { code })
        } else {
            None
        }
    }
}
/// One spill-page record carried inside a `SubmitWalPayload`.
///
/// Encoded in 32 bytes with little-endian fields (see `to_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SpillPageEntry {
/// Page number, bytes 0..4. Entries in a payload must be strictly ascending by this field.
pub pgno: u32,
/// Offset field, bytes 8..16.
pub offset: u64,
/// Length field, bytes 16..20.
pub len: u32,
/// 64-bit value stored in bytes 24..32; by its name an XXH3-64 hash
/// (NOTE(review): what is hashed is not visible here).
pub xxh3_64: u64,
}
impl SpillPageEntry {
    /// Encoded size of one entry, including two 4-byte padding gaps.
    const WIRE_BYTES: usize = 32;

    /// Encodes the entry as
    /// `pgno (u32) | 4 zero | offset (u64) | len (u32) | 4 zero | xxh3_64 (u64)`,
    /// with all integers little-endian.
    #[must_use]
    pub fn to_bytes(self) -> [u8; Self::WIRE_BYTES] {
        let pgno = self.pgno.to_le_bytes();
        let offset = self.offset.to_le_bytes();
        let len = self.len.to_le_bytes();
        let hash = self.xxh3_64.to_le_bytes();
        // (start offset, field bytes); untouched bytes remain zero padding.
        let layout: [(usize, &[u8]); 4] = [(0, &pgno), (8, &offset), (16, &len), (24, &hash)];

        let mut wire = [0u8; Self::WIRE_BYTES];
        for (start, bytes) in layout {
            wire[start..start + bytes.len()].copy_from_slice(bytes);
        }
        wire
    }

    /// Decodes an entry from the first 32 bytes of `src`; returns `None` when
    /// fewer are available. Padding bytes are not inspected.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let fixed = src.get(..Self::WIRE_BYTES)?;
        let pgno = read_u32_le(&fixed[..4])?;
        let offset = read_u64_le(&fixed[8..16])?;
        let len = read_u32_le(&fixed[16..20])?;
        let xxh3_64 = read_u64_le(&fixed[24..32])?;
        Some(Self {
            pgno,
            offset,
            len,
            xxh3_64,
        })
    }
}
/// Payload of a native-publish submission.
///
/// Fields are listed in wire order; see `to_bytes` for the exact encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubmitNativePayload {
/// Permit id, encoded as a little-endian u64.
pub permit_id: u64,
/// Submitting transaction.
pub txn: WireTxnToken,
/// Begin sequence number, encoded as a little-endian u64.
pub begin_seq: u64,
/// Object id of the capsule (16 bytes on the wire).
pub capsule_object_id: ObjectId,
/// 32-byte capsule digest, copied verbatim.
pub capsule_digest_32: [u8; 32],
/// Page numbers; must be strictly ascending and fit `WIRE_WRITE_SET_MAX_BYTES`.
pub write_set_summary: Vec<u32>,
/// Read-witness references; must be strictly ascending by byte value.
pub read_witness_refs: Vec<ObjectId>,
/// Write-witness references; must be strictly ascending by byte value.
pub write_witness_refs: Vec<ObjectId>,
/// Edge references; must be strictly ascending by byte value.
pub edge_refs: Vec<ObjectId>,
/// Merge references; must be strictly ascending by byte value.
pub merge_refs: Vec<ObjectId>,
/// Abort-policy byte, encoded last; its values are not interpreted here.
pub abort_policy: u8,
}
impl SubmitNativePayload {
/// Encodes the payload.
///
/// Wire order: `permit_id (u64 LE) | txn token | begin_seq (u64 LE) |
/// capsule object id | 32-byte digest | page count (u32 LE) + pages (u32 LE
/// each) | read witnesses | write witnesses | edges | merges | abort_policy`.
/// Each reference array is a u32 LE count followed by the object-id bytes.
///
/// # Panics
///
/// Panics when the write-set summary or the combined reference count exceeds
/// its wire cap, or when any list is not strictly ascending.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn to_bytes(&self) -> Vec<u8> {
// Enforce on the way out the same invariants `from_bytes` checks on the way in.
assert!(
validate_write_set_summary(&self.write_set_summary),
"write_set_summary exceeds wire cap"
);
assert!(
is_canonical_pages(&self.write_set_summary),
"write_set_summary must be sorted ascending with no duplicates"
);
assert!(
validate_witness_edge_counts(
self.read_witness_refs.len(),
self.write_witness_refs.len(),
self.edge_refs.len(),
self.merge_refs.len()
),
"witness/edge counts exceed wire cap"
);
assert!(
is_canonical_object_ids(&self.read_witness_refs),
"read_witness_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.write_witness_refs),
"write_witness_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.edge_refs),
"edge_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.merge_refs),
"merge_refs must be sorted with no duplicates"
);
// 256 is only an initial capacity; the buffer grows as needed.
let mut buf = Vec::with_capacity(256);
append_u64_le(&mut buf, self.permit_id);
buf.extend_from_slice(&self.txn.to_bytes());
append_u64_le(&mut buf, self.begin_seq);
buf.extend_from_slice(self.capsule_object_id.as_bytes());
buf.extend_from_slice(&self.capsule_digest_32);
let ws_count = u32::try_from(self.write_set_summary.len())
.expect("write_set_summary length must fit u32");
append_u32_le(&mut buf, ws_count);
for &pgno in &self.write_set_summary {
append_u32_le(&mut buf, pgno);
}
encode_object_id_array(&mut buf, &self.read_witness_refs);
encode_object_id_array(&mut buf, &self.write_witness_refs);
encode_object_id_array(&mut buf, &self.edge_refs);
encode_object_id_array(&mut buf, &self.merge_refs);
buf.push(self.abort_policy);
buf
}
/// Decodes a payload produced by `to_bytes`.
///
/// Returns `None` when the buffer is truncated, a count exceeds its cap, or
/// any list is not strictly ascending. Bytes after `abort_policy` are ignored.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn from_bytes(src: &[u8]) -> Option<Self> {
// `pos` is the read cursor into `src`.
let mut pos = 0usize;
let permit_id = read_u64_le(src.get(pos..pos + 8)?)?;
pos += 8;
let txn = WireTxnToken::from_bytes(src.get(pos..)?)?;
pos += WIRE_TXN_TOKEN_BYTES;
let begin_seq = read_u64_le(src.get(pos..pos + 8)?)?;
pos += 8;
let capsule_object_id = ObjectId::from_bytes(src.get(pos..pos + 16)?.try_into().ok()?);
pos += 16;
let capsule_digest_32: [u8; 32] = src.get(pos..pos + 32)?.try_into().ok()?;
pos += 32;
let ws_count = read_u32_le(src.get(pos..pos + 4)?)? as usize;
pos += 4;
// Reject oversized counts and short buffers before allocating for the pages.
if ws_count > (WIRE_WRITE_SET_MAX_BYTES / 4) {
return None;
}
let ws_bytes = ws_count.checked_mul(4)?;
if src.len() < pos.checked_add(ws_bytes)? {
return None;
}
let mut write_set_summary = Vec::with_capacity(ws_count);
for _ in 0..ws_count {
write_set_summary.push(read_u32_le(src.get(pos..pos + 4)?)?);
pos += 4;
}
// `remaining` is the shared budget for the four reference arrays; each
// array may use only what the previous ones left over.
let mut remaining = WIRE_WITNESS_EDGE_MAX;
let (read_witness_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&read_witness_refs) {
return None;
}
remaining = remaining.saturating_sub(read_witness_refs.len());
let (write_witness_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&write_witness_refs) {
return None;
}
remaining = remaining.saturating_sub(write_witness_refs.len());
let (edge_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&edge_refs) {
return None;
}
remaining = remaining.saturating_sub(edge_refs.len());
let (merge_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&merge_refs) {
return None;
}
if !is_canonical_pages(&write_set_summary) {
return None;
}
if !validate_write_set_summary(&write_set_summary) {
return None;
}
if !validate_witness_edge_counts(
read_witness_refs.len(),
write_witness_refs.len(),
edge_refs.len(),
merge_refs.len(),
) {
return None;
}
// The abort-policy byte is the final field.
let abort_policy = *src.get(pos)?;
Some(Self {
permit_id,
txn,
begin_seq,
capsule_object_id,
capsule_digest_32,
write_set_summary,
read_witness_refs,
write_witness_refs,
edge_refs,
merge_refs,
abort_policy,
})
}
}
/// Payload of a WAL-commit submission.
///
/// Fields are listed in wire order; see `to_bytes` for the exact encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubmitWalPayload {
/// Permit id, encoded as a little-endian u64.
pub permit_id: u64,
/// Submitting transaction.
pub txn: WireTxnToken,
/// Mode byte, followed on the wire by 7 padding bytes; values are not interpreted here.
pub mode: u8,
/// Encoded as a little-endian u64.
pub snapshot_high: u64,
/// Encoded as a little-endian u64.
pub schema_epoch: u64,
/// Encoded as one byte (0 or 1); any non-zero byte decodes to `true`.
pub has_in_rw: bool,
/// Encoded as one byte (0 or 1); any non-zero byte decodes to `true`.
pub has_out_rw: bool,
/// Single byte, followed on the wire by 5 padding bytes.
pub wal_fec_r: u8,
/// Spill-page entries; must be strictly ascending by `pgno`.
pub spill_pages: Vec<SpillPageEntry>,
/// Read-witness references; must be strictly ascending by byte value.
pub read_witness_refs: Vec<ObjectId>,
/// Write-witness references; must be strictly ascending by byte value.
pub write_witness_refs: Vec<ObjectId>,
/// Edge references; must be strictly ascending by byte value.
pub edge_refs: Vec<ObjectId>,
/// Merge references; must be strictly ascending by byte value.
pub merge_refs: Vec<ObjectId>,
}
impl SubmitWalPayload {
/// Encodes the payload.
///
/// Wire order: `permit_id (u64 LE) | txn token | mode + 7 pad |
/// snapshot_high (u64 LE) | schema_epoch (u64 LE) | has_in_rw | has_out_rw |
/// wal_fec_r + 5 pad | spill count (u32 LE) + entries | read witnesses |
/// write witnesses | edges | merges`. Each reference array is a u32 LE count
/// followed by the object-id bytes.
///
/// # Panics
///
/// Panics when the combined reference count exceeds `WIRE_WITNESS_EDGE_MAX`
/// or when any list is not strictly ascending.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn to_bytes(&self) -> Vec<u8> {
assert!(
validate_witness_edge_counts(
self.read_witness_refs.len(),
self.write_witness_refs.len(),
self.edge_refs.len(),
self.merge_refs.len()
),
"witness/edge counts exceed wire cap"
);
assert!(
is_canonical_spill_pages(&self.spill_pages),
"spill_pages must be sorted by pgno with no duplicates"
);
assert!(
is_canonical_object_ids(&self.read_witness_refs),
"read_witness_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.write_witness_refs),
"write_witness_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.edge_refs),
"edge_refs must be sorted with no duplicates"
);
assert!(
is_canonical_object_ids(&self.merge_refs),
"merge_refs must be sorted with no duplicates"
);
// 256 is only an initial capacity; the buffer grows as needed.
let mut buf = Vec::with_capacity(256);
append_u64_le(&mut buf, self.permit_id);
buf.extend_from_slice(&self.txn.to_bytes());
buf.push(self.mode);
// Seven padding bytes after `mode`, then `snapshot_high`.
buf.extend_from_slice(&[0u8; 7]); append_u64_le(&mut buf, self.snapshot_high);
append_u64_le(&mut buf, self.schema_epoch);
buf.push(u8::from(self.has_in_rw));
buf.push(u8::from(self.has_out_rw));
buf.push(self.wal_fec_r);
// Five padding bytes after the three flag bytes, then the spill-page count.
buf.extend_from_slice(&[0u8; 5]); let sp_count =
u32::try_from(self.spill_pages.len()).expect("spill_pages length must fit u32");
append_u32_le(&mut buf, sp_count);
for sp in &self.spill_pages {
buf.extend_from_slice(&sp.to_bytes());
}
encode_object_id_array(&mut buf, &self.read_witness_refs);
encode_object_id_array(&mut buf, &self.write_witness_refs);
encode_object_id_array(&mut buf, &self.edge_refs);
encode_object_id_array(&mut buf, &self.merge_refs);
buf
}
/// Decodes a payload produced by `to_bytes`.
///
/// Returns `None` when the buffer is truncated, a count exceeds what the
/// buffer or the wire cap allows, or any list is not strictly ascending.
/// Bytes after the last reference array are ignored.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn from_bytes(src: &[u8]) -> Option<Self> {
// `pos` is the read cursor into `src`.
let mut pos = 0usize;
let permit_id = read_u64_le(src.get(pos..pos + 8)?)?;
pos += 8;
let txn = WireTxnToken::from_bytes(src.get(pos..)?)?;
pos += WIRE_TXN_TOKEN_BYTES;
let mode = *src.get(pos)?;
// Skip `mode` plus its 7 padding bytes.
pos += 8; let snapshot_high = read_u64_le(src.get(pos..pos + 8)?)?;
pos += 8;
let schema_epoch = read_u64_le(src.get(pos..pos + 8)?)?;
pos += 8;
let has_in_rw = *src.get(pos)? != 0;
pos += 1;
let has_out_rw = *src.get(pos)? != 0;
pos += 1;
let wal_fec_r = *src.get(pos)?;
// Skip `wal_fec_r` plus its 5 padding bytes.
pos += 6; let sp_count = read_u32_le(src.get(pos..pos + 4)?)? as usize;
pos += 4;
// The four reference arrays each need at least their u32 count prefix.
let min_tail = 4usize * 4;
let available = src.len().checked_sub(pos)?;
if available < min_tail {
return None;
}
// Bound the spill count by what the buffer can hold before allocating.
let max_spill = available
.checked_sub(min_tail)?
.checked_div(SpillPageEntry::WIRE_BYTES)?;
if sp_count > max_spill {
return None;
}
let mut spill_pages = Vec::with_capacity(sp_count);
for _ in 0..sp_count {
spill_pages.push(SpillPageEntry::from_bytes(src.get(pos..)?)?);
pos += SpillPageEntry::WIRE_BYTES;
}
if !is_canonical_spill_pages(&spill_pages) {
return None;
}
// `remaining` is the shared budget for the four reference arrays; each
// array may use only what the previous ones left over.
let mut remaining = WIRE_WITNESS_EDGE_MAX;
let (read_witness_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&read_witness_refs) {
return None;
}
remaining = remaining.saturating_sub(read_witness_refs.len());
let (write_witness_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&write_witness_refs) {
return None;
}
remaining = remaining.saturating_sub(write_witness_refs.len());
let (edge_refs, new_pos) = decode_object_id_array(src, pos, remaining)?;
pos = new_pos;
if !is_canonical_object_ids(&edge_refs) {
return None;
}
remaining = remaining.saturating_sub(edge_refs.len());
// The end position of the last array is not needed.
let (merge_refs, _) = decode_object_id_array(src, pos, remaining)?;
if !is_canonical_object_ids(&merge_refs) {
return None;
}
if !validate_witness_edge_counts(
read_witness_refs.len(),
write_witness_refs.len(),
edge_refs.len(),
merge_refs.len(),
) {
return None;
}
Some(Self {
permit_id,
txn,
mode,
snapshot_high,
schema_epoch,
has_in_rw,
has_out_rw,
wal_fec_r,
spill_pages,
read_witness_refs,
write_witness_refs,
edge_refs,
merge_refs,
})
}
}
/// Reply to a native-publish submission.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NativePublishResponse {
/// Success; carries a commit sequence number (wire tag 0).
Ok { commit_seq: u64 },
/// Conflict; carries page numbers and a reason byte (wire tag 1).
Conflict { pages: Vec<u32>, reason: u8 },
/// Aborted; carries a code (wire tag 2).
Aborted { code: u32 },
/// Failure; carries an error code (wire tag 3).
Err { code: u32 },
}
impl NativePublishResponse {
    const TAG_OK: u8 = 0;
    const TAG_CONFLICT: u8 = 1;
    const TAG_ABORTED: u8 = 2;
    const TAG_ERR: u8 = 3;

    /// Appends the shared `tag | 3 zero bytes | code (u32 LE)` layout.
    fn put_code(wire: &mut Vec<u8>, tag: u8, code: u32) {
        wire.extend_from_slice(&[tag, 0, 0, 0]);
        append_u32_le(wire, code);
    }

    /// Reads the u32 code stored at bytes 4..8.
    fn take_code(src: &[u8]) -> Option<u32> {
        read_u32_le(src.get(4..8)?)
    }

    /// Encodes the response with little-endian integers.
    ///
    /// * `Ok`: `tag | 7 zero | commit_seq (u64)`
    /// * `Conflict`: `tag | reason | 2 zero | count (u32) | pages (u32 each)`
    /// * `Aborted` / `Err`: `tag | 3 zero | code (u32)`
    ///
    /// # Panics
    ///
    /// Panics when a conflict lists more pages than fit in a u32.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut wire = Vec::with_capacity(32);
        match self {
            Self::Ok { commit_seq } => {
                wire.extend_from_slice(&[Self::TAG_OK, 0, 0, 0, 0, 0, 0, 0]);
                append_u64_le(&mut wire, *commit_seq);
            }
            Self::Conflict { pages, reason } => {
                let count =
                    u32::try_from(pages.len()).expect("conflict pages length must fit u32");
                wire.extend_from_slice(&[Self::TAG_CONFLICT, *reason, 0, 0]);
                append_u32_le(&mut wire, count);
                pages.iter().for_each(|&page| append_u32_le(&mut wire, page));
            }
            Self::Aborted { code } => Self::put_code(&mut wire, Self::TAG_ABORTED, *code),
            Self::Err { code } => Self::put_code(&mut wire, Self::TAG_ERR, *code),
        }
        wire
    }

    /// Decodes a response; returns `None` for an unknown tag, a truncated
    /// buffer, or a conflict count larger than the bytes that follow it.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let (&tag, _) = src.split_first()?;
        match tag {
            Self::TAG_OK => {
                let commit_seq = read_u64_le(src.get(8..16)?)?;
                Some(Self::Ok { commit_seq })
            }
            Self::TAG_CONFLICT => {
                let reason = *src.get(1)?;
                let count = Self::take_code(src)? as usize;
                // Reading the count proved `src` holds at least 8 bytes.
                let tail = &src[8..];
                if count > tail.len() / 4 {
                    return None;
                }
                let pages = tail
                    .chunks_exact(4)
                    .take(count)
                    .map(|chunk| read_u32_le(chunk))
                    .collect::<Option<Vec<u32>>>()?;
                Some(Self::Conflict { pages, reason })
            }
            Self::TAG_ABORTED => Self::take_code(src).map(|code| Self::Aborted { code }),
            Self::TAG_ERR => Self::take_code(src).map(|code| Self::Err { code }),
            _ => None,
        }
    }
}
/// Reply to a WAL-commit submission.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WalCommitResponse {
/// Success; carries a commit sequence number (wire tag 0).
Ok { commit_seq: u64 },
/// Conflict; carries page numbers and a reason byte (wire tag 1).
Conflict { pages: Vec<u32>, reason: u8 },
/// I/O failure; carries a code (wire tag 2).
IoError { code: u32 },
/// Failure; carries an error code (wire tag 3).
Err { code: u32 },
}
impl WalCommitResponse {
    const TAG_OK: u8 = 0;
    const TAG_CONFLICT: u8 = 1;
    const TAG_IO_ERROR: u8 = 2;
    const TAG_ERR: u8 = 3;

    /// Appends the shared `tag | 3 zero bytes | code (u32 LE)` layout.
    fn put_code(wire: &mut Vec<u8>, tag: u8, code: u32) {
        wire.extend_from_slice(&[tag, 0, 0, 0]);
        append_u32_le(wire, code);
    }

    /// Reads the u32 code stored at bytes 4..8.
    fn take_code(src: &[u8]) -> Option<u32> {
        read_u32_le(src.get(4..8)?)
    }

    /// Encodes the response with little-endian integers.
    ///
    /// * `Ok`: `tag | 7 zero | commit_seq (u64)`
    /// * `Conflict`: `tag | reason | 2 zero | count (u32) | pages (u32 each)`
    /// * `IoError` / `Err`: `tag | 3 zero | code (u32)`
    ///
    /// # Panics
    ///
    /// Panics when a conflict lists more pages than fit in a u32.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut wire = Vec::with_capacity(32);
        match self {
            Self::Ok { commit_seq } => {
                wire.extend_from_slice(&[Self::TAG_OK, 0, 0, 0, 0, 0, 0, 0]);
                append_u64_le(&mut wire, *commit_seq);
            }
            Self::Conflict { pages, reason } => {
                let count =
                    u32::try_from(pages.len()).expect("conflict pages length must fit u32");
                wire.extend_from_slice(&[Self::TAG_CONFLICT, *reason, 0, 0]);
                append_u32_le(&mut wire, count);
                pages.iter().for_each(|&page| append_u32_le(&mut wire, page));
            }
            Self::IoError { code } => Self::put_code(&mut wire, Self::TAG_IO_ERROR, *code),
            Self::Err { code } => Self::put_code(&mut wire, Self::TAG_ERR, *code),
        }
        wire
    }

    /// Decodes a response; returns `None` for an unknown tag, a truncated
    /// buffer, or a conflict count larger than the bytes that follow it.
    #[must_use]
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let (&tag, _) = src.split_first()?;
        match tag {
            Self::TAG_OK => {
                let commit_seq = read_u64_le(src.get(8..16)?)?;
                Some(Self::Ok { commit_seq })
            }
            Self::TAG_CONFLICT => {
                let reason = *src.get(1)?;
                let count = Self::take_code(src)? as usize;
                // Reading the count proved `src` holds at least 8 bytes.
                let tail = &src[8..];
                if count > tail.len() / 4 {
                    return None;
                }
                let pages = tail
                    .chunks_exact(4)
                    .take(count)
                    .map(|chunk| read_u32_le(chunk))
                    .collect::<Option<Vec<u32>>>()?;
                Some(Self::Conflict { pages, reason })
            }
            Self::TAG_IO_ERROR => Self::take_code(src).map(|code| Self::IoError { code }),
            Self::TAG_ERR => Self::take_code(src).map(|code| Self::Err { code }),
            _ => None,
        }
    }
}
/// Appends `ids` to `buf` as a u32 little-endian count followed by the raw
/// bytes of each id, in order.
///
/// Panics when `ids` holds more entries than fit in a u32.
fn encode_object_id_array(buf: &mut Vec<u8>, ids: &[ObjectId]) {
    let entry_count = u32::try_from(ids.len()).expect("ObjectId array length must fit u32");
    append_u32_le(buf, entry_count);
    ids.iter()
        .for_each(|object_id| buf.extend_from_slice(object_id.as_bytes()));
}
/// Decodes a length-prefixed array of object ids starting at `src[pos]`.
///
/// The encoding is a u32 little-endian count followed by `count` 16-byte ids.
/// On success, returns the ids and the offset of the first byte after the array.
///
/// Returns `None` when the count exceeds `count_cap`, when the buffer is too
/// short, or when any offset computation would overflow. Every offset,
/// including the 4-byte count prefix, is computed with checked arithmetic, so
/// an out-of-range `pos` yields `None` rather than an overflow panic.
fn decode_object_id_array(
    src: &[u8],
    pos: usize,
    count_cap: usize,
) -> Option<(Vec<ObjectId>, usize)> {
    let body_start = pos.checked_add(4)?;
    let count = read_u32_le(src.get(pos..body_start)?)? as usize;
    if count > count_cap {
        return None;
    }
    // Confirm the whole array is present before allocating for it.
    let end = body_start.checked_add(count.checked_mul(16)?)?;
    let body = src.get(body_start..end)?;
    let ids = body
        .chunks_exact(16)
        .map(|chunk| {
            let bytes: [u8; 16] = chunk.try_into().ok()?;
            Some(ObjectId::from_bytes(bytes))
        })
        .collect::<Option<Vec<ObjectId>>>()?;
    Some((ids, end))
}
/// Returns `true` when `pages` is strictly ascending (sorted, no duplicates).
///
/// Empty and single-element slices are canonical.
#[must_use]
pub fn is_canonical_pages(pages: &[u32]) -> bool {
    pages
        .iter()
        .zip(pages.iter().skip(1))
        .all(|(earlier, later)| earlier < later)
}
/// Returns `true` when the entries are strictly ascending by `pgno`
/// (sorted, no duplicate page numbers).
#[must_use]
pub fn is_canonical_spill_pages(spill_pages: &[SpillPageEntry]) -> bool {
    spill_pages
        .iter()
        .zip(spill_pages.iter().skip(1))
        .all(|(earlier, later)| earlier.pgno < later.pgno)
}
/// Returns `true` when `ids` is strictly ascending by raw byte value
/// (sorted, no duplicates).
#[must_use]
pub fn is_canonical_object_ids(ids: &[ObjectId]) -> bool {
    ids.iter()
        .zip(ids.iter().skip(1))
        .all(|(earlier, later)| earlier.as_bytes() < later.as_bytes())
}
/// Returns `true` when `pages`, encoded at 4 bytes per entry, fits within
/// `WIRE_WRITE_SET_MAX_BYTES`.
#[must_use]
pub fn validate_write_set_summary(pages: &[u32]) -> bool {
    // An overflowing byte length can never fit the cap.
    match pages.len().checked_mul(4) {
        Some(byte_len) => byte_len <= WIRE_WRITE_SET_MAX_BYTES,
        None => false,
    }
}
/// Returns `true` when a raw write-set byte length is a whole number of
/// 4-byte page entries and does not exceed `WIRE_WRITE_SET_MAX_BYTES`.
#[must_use]
pub fn validate_write_set_summary_raw_len(byte_len: usize) -> bool {
    if byte_len % 4 != 0 {
        return false;
    }
    byte_len <= WIRE_WRITE_SET_MAX_BYTES
}
/// Returns `true` when the four reference counts together do not exceed
/// `WIRE_WITNESS_EDGE_MAX`.
///
/// The sum saturates, so very large inputs are rejected instead of wrapping.
#[must_use]
pub fn validate_witness_edge_counts(
    read_w: usize,
    write_w: usize,
    edges: usize,
    merges: usize,
) -> bool {
    let total = [read_w, write_w, edges, merges]
        .iter()
        .fold(0usize, |sum, &count| sum.saturating_add(count));
    total <= WIRE_WITNESS_EDGE_MAX
}
/// Failures reported by `PermitManager`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermitError {
/// The number of reserved permits has reached the manager's limit.
Busy,
/// No entry exists for the given permit id.
NotFound(u64),
/// The given permit id has already been consumed.
AlreadyConsumed(u64),
}
impl fmt::Display for PermitError {
    /// Renders a one-line, human-readable description of the permit failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::AlreadyConsumed(id) => write!(f, "permit {id} already consumed"),
            Self::NotFound(id) => write!(f, "permit {id} not found"),
            Self::Busy => f.write_str("max outstanding permits reached"),
        }
    }
}
// Marker impl so `PermitError` can be used wherever a `std::error::Error` is expected.
impl std::error::Error for PermitError {}
// Lifecycle state of one entry in `PermitManager::active`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PermitState {
// Handed out by `reserve` and not yet consumed; counts against the limit.
Reserved,
// Marked by `consume`; kept so a second `consume` reports `AlreadyConsumed`
// until the entry is purged.
Consumed,
}
/// Hands out a bounded number of permits and tracks their state.
pub struct PermitManager {
// Maximum number of permits that may be in the `Reserved` state at once.
max_permits: usize,
// Next permit id to hand out; starts at 1.
next_id: AtomicU64,
// Number of entries currently in the `Reserved` state; updated while
// `active` is locked.
reserved_count: AtomicUsize,
// Permit id -> state, for reserved and not-yet-purged consumed permits.
active: Mutex<HashMap<u64, PermitState>>,
}
impl PermitManager {
    /// Creates a manager that allows at most `max_permits` reserved permits.
    /// Permit ids start at 1.
    #[must_use]
    pub fn new(max_permits: usize) -> Self {
        let next_id = AtomicU64::new(1);
        let reserved_count = AtomicUsize::new(0);
        let active = Mutex::new(HashMap::new());
        Self {
            max_permits,
            next_id,
            reserved_count,
            active,
        }
    }

    /// Reserves a new permit and returns its id.
    ///
    /// # Errors
    ///
    /// Returns [`PermitError::Busy`] when `max_permits` permits are already
    /// reserved.
    pub fn reserve(&self) -> Result<u64, PermitError> {
        // The table lock is held for the whole check-and-insert sequence.
        let mut table = self.active.lock();
        let outstanding = self.reserved_count.load(Ordering::Relaxed);
        if outstanding < self.max_permits {
            let permit_id = self.next_id.fetch_add(1, Ordering::Relaxed);
            table.insert(permit_id, PermitState::Reserved);
            self.reserved_count.fetch_add(1, Ordering::Relaxed);
            Ok(permit_id)
        } else {
            Err(PermitError::Busy)
        }
    }

    /// Marks a reserved permit as consumed, freeing its slot.
    ///
    /// Consumed entries stay in the table so that a repeat call can be
    /// reported. Once more than `max_permits * CONSUMED_PERMIT_GC_MULTIPLIER`
    /// of them have accumulated, they are all purged.
    ///
    /// # Errors
    ///
    /// Returns [`PermitError::NotFound`] for an unknown id (including one whose
    /// consumed entry has been purged) and [`PermitError::AlreadyConsumed`] for
    /// an id that is still tracked as consumed.
    pub fn consume(&self, permit_id: u64) -> Result<(), PermitError> {
        let mut table = self.active.lock();
        let Some(state) = table.get_mut(&permit_id) else {
            return Err(PermitError::NotFound(permit_id));
        };
        if *state == PermitState::Consumed {
            return Err(PermitError::AlreadyConsumed(permit_id));
        }
        *state = PermitState::Consumed;
        self.reserved_count.fetch_sub(1, Ordering::Relaxed);

        let still_reserved = self.reserved_count.load(Ordering::Relaxed);
        let consumed_entries = table.len().saturating_sub(still_reserved);
        let gc_threshold = self
            .max_permits
            .saturating_mul(CONSUMED_PERMIT_GC_MULTIPLIER)
            .max(self.max_permits);
        if consumed_entries > gc_threshold {
            table.retain(|_, entry| *entry == PermitState::Reserved);
        }
        Ok(())
    }

    /// Drops the entry for `permit_id`, if any. A permit that was still
    /// reserved also gives its slot back. Unknown ids are ignored.
    pub fn release(&self, permit_id: u64) {
        // Keep the lock across the counter update so the table and the counter
        // change together.
        let mut table = self.active.lock();
        let removed = table.remove(&permit_id);
        if removed == Some(PermitState::Reserved) {
            self.reserved_count.fetch_sub(1, Ordering::Relaxed);
        }
    }

    /// Returns the number of permits currently reserved.
    #[must_use]
    pub fn outstanding(&self) -> usize {
        self.reserved_count.load(Ordering::Relaxed)
    }

    /// Purges every consumed entry and resets the reserved counter to the
    /// number of entries left.
    pub fn gc_consumed(&self) {
        let mut table = self.active.lock();
        table.retain(|_, entry| *entry == PermitState::Reserved);
        let reserved_now = table.len();
        self.reserved_count.store(reserved_now, Ordering::Relaxed);
    }
}
/// Stores response bytes keyed by `(txn_id, txn_epoch)`.
///
/// NOTE(review): nothing in this type evicts entries; confirm that callers
/// bound its growth.
pub struct IdempotencyCache {
// (txn_id, txn_epoch) -> stored response bytes.
inner: Mutex<HashMap<(u64, u32), Vec<u8>>>,
}
impl IdempotencyCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        let inner = Mutex::new(HashMap::new());
        Self { inner }
    }

    /// Returns a copy of the response stored for `(txn_id, txn_epoch)`, if any.
    #[must_use]
    pub fn get(&self, txn_id: u64, txn_epoch: u32) -> Option<Vec<u8>> {
        let key = (txn_id, txn_epoch);
        self.inner.lock().get(&key).cloned()
    }

    /// Stores `response` for `(txn_id, txn_epoch)`, replacing any earlier value.
    pub fn insert(&self, txn_id: u64, txn_epoch: u32, response: Vec<u8>) {
        let key = (txn_id, txn_epoch);
        self.inner.lock().insert(key, response);
    }

    /// Returns the number of stored responses.
    #[must_use]
    pub fn len(&self) -> usize {
        let entries = self.inner.lock();
        entries.len()
    }

    /// Returns `true` when no response is stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let entries = self.inner.lock();
        entries.is_empty()
    }
}
impl Default for IdempotencyCache {
/// Equivalent to `IdempotencyCache::new`: an empty cache.
fn default() -> Self {
Self::new()
}
}
/// Failures reported by `authenticate_peer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PeerAuthError {
/// The peer's credentials could not be read from the socket.
NoCreds,
/// The peer's UID differs from the expected one.
UidMismatch { expected: u32, actual: u32 },
}
impl fmt::Display for PeerAuthError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::UidMismatch { expected, actual } => {
                write!(f, "UID mismatch: expected {expected}, got {actual}")
            }
            Self::NoCreds => f.write_str("could not retrieve peer credentials"),
        }
    }
}
// Marker impl so `PeerAuthError` can be used wherever a `std::error::Error` is expected.
impl std::error::Error for PeerAuthError {}
/// Checks that the process on the other end of `stream` runs as `expected_uid`.
///
/// The peer's credentials are read from the socket with the `PeerCredentials`
/// socket option.
///
/// # Errors
///
/// Returns [`PeerAuthError::NoCreds`] when the credentials cannot be read and
/// [`PeerAuthError::UidMismatch`] when the peer's UID differs.
#[cfg(target_os = "linux")]
pub fn authenticate_peer(
    stream: &std::os::unix::net::UnixStream,
    expected_uid: u32,
) -> Result<(), PeerAuthError> {
    use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};

    let actual_uid = getsockopt(stream, PeerCredentials)
        .map_err(|_| PeerAuthError::NoCreds)?
        .uid();
    if actual_uid == expected_uid {
        Ok(())
    } else {
        Err(PeerAuthError::UidMismatch {
            expected: expected_uid,
            actual: actual_uid,
        })
    }
}
/// A file descriptor received by `recv_with_fd`.
///
/// The descriptor is closed when this value is dropped.
#[cfg(target_os = "linux")]
#[derive(Debug)]
pub struct ReceivedFd(std::os::unix::io::RawFd);
#[cfg(target_os = "linux")]
impl ReceivedFd {
/// Returns the raw descriptor without giving up ownership; it is still
/// closed when `self` is dropped.
#[must_use]
pub fn raw_fd(&self) -> std::os::unix::io::RawFd {
self.0
}
}
#[cfg(target_os = "linux")]
impl Drop for ReceivedFd {
/// Closes the descriptor; a failure from `close` is deliberately ignored.
fn drop(&mut self) {
let _ = nix::unistd::close(self.0);
}
}
/// Sends `data` over `stream` with `fd` attached as `SCM_RIGHTS` ancillary data.
///
/// The descriptor travels with the first `sendmsg` call. If that call accepts
/// only part of `data`, the remainder is written with a plain `write_all`.
/// On success the return value is always `data.len()`.
///
/// # Errors
///
/// * An error (kind `Other`) when `data` is empty, because the descriptor needs
///   at least one data byte to accompany it.
/// * The OS error reported by `sendmsg`. It is converted with
///   `io::Error::from`, so `ErrorKind` and `raw_os_error()` reflect the real
///   errno (for example `WouldBlock` or `Interrupted`).
/// * `WriteZero` when `sendmsg` reports zero bytes written.
/// * Any error from writing the remainder.
#[cfg(target_os = "linux")]
pub fn send_with_fd(
    stream: &std::os::unix::net::UnixStream,
    data: &[u8],
    fd: std::os::unix::io::RawFd,
) -> std::io::Result<usize> {
    use nix::sys::socket::{ControlMessage, MsgFlags, sendmsg};
    use std::io::IoSlice;
    use std::io::Write as _;
    use std::os::unix::io::AsRawFd;
    if data.is_empty() {
        return Err(std::io::Error::other(
            "cannot send an fd with an empty data payload",
        ));
    }
    let iov = [IoSlice::new(data)];
    let fds = [fd];
    let cmsg = [ControlMessage::ScmRights(&fds)];
    // Keep the errno: `Error::other` would flatten every failure to `Other`.
    let mut sent = sendmsg::<()>(stream.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None)
        .map_err(std::io::Error::from)?;
    if sent == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "sendmsg wrote 0 bytes",
        ));
    }
    if sent < data.len() {
        // The descriptor went out with the first chunk; the rest is plain data.
        (&*stream).write_all(&data[sent..])?;
        sent = data.len();
    }
    Ok(sent)
}
/// Receives bytes into `buf` from `stream`, with at most one `SCM_RIGHTS`
/// file descriptor.
///
/// Returns the number of bytes read (0 is passed through unchanged) and the
/// received descriptor, if any, wrapped so that it is closed on drop.
///
/// # Errors
///
/// * The OS error reported by `recvmsg`. It is converted with
///   `io::Error::from`, so `ErrorKind` and `raw_os_error()` reflect the real
///   errno (for example `WouldBlock` or `Interrupted`).
/// * `InvalidData` when the ancillary data was truncated, could not be parsed,
///   or carried more than one descriptor. In the last case every received
///   descriptor is closed first.
#[cfg(target_os = "linux")]
pub fn recv_with_fd(
    stream: &std::os::unix::net::UnixStream,
    buf: &mut [u8],
) -> std::io::Result<(usize, Option<ReceivedFd>)> {
    use nix::cmsg_space;
    use nix::sys::socket::{MsgFlags, recvmsg};
    use std::io::IoSliceMut;
    use std::os::unix::io::AsRawFd;
    // Control buffer sized for a single descriptor.
    let mut cmsg_buf = cmsg_space!(std::os::unix::io::RawFd);
    let mut iov = [IoSliceMut::new(buf)];
    // Keep the errno: `Error::other` would flatten every failure to `Other`.
    let msg = recvmsg::<()>(
        stream.as_raw_fd(),
        &mut iov,
        Some(&mut cmsg_buf),
        MsgFlags::empty(),
    )
    .map_err(std::io::Error::from)?;
    let n = msg.bytes;
    // NOTE(review): descriptors the kernel may already have installed before
    // truncating are not closed on this path — confirm whether that can leak.
    if msg.flags.contains(MsgFlags::MSG_CTRUNC) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "ancillary data truncated",
        ));
    }
    let mut fds = Vec::<std::os::unix::io::RawFd>::new();
    let cmsgs = msg
        .cmsgs()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    for cmsg in cmsgs {
        if let nix::sys::socket::ControlMessageOwned::ScmRights(scm_fds) = cmsg {
            fds.extend(scm_fds);
        }
    }
    match fds.len() {
        0 => Ok((n, None)),
        1 => Ok((n, Some(ReceivedFd(fds[0])))),
        _ => {
            // Never hand back an ambiguous set of descriptors; close them all.
            for fd in fds {
                let _ = nix::unistd::close(fd);
            }
            Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "received more than one fd",
            ))
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Names used by the multi-process test. NOTE(review): judging by the `_ENV`
// suffix, the first four are environment-variable names passed to a child
// process; their consumers are outside this view.
#[cfg(target_os = "linux")]
const MULTIPROC_CHILD_MODE_ENV: &str = "FSQLITE_COORD_IPC_CHILD_MODE";
#[cfg(target_os = "linux")]
const MULTIPROC_SOCKET_PATH_ENV: &str = "FSQLITE_COORD_IPC_SOCKET_PATH";
#[cfg(target_os = "linux")]
const MULTIPROC_SEED_ENV: &str = "FSQLITE_COORD_IPC_SEED";
#[cfg(target_os = "linux")]
const MULTIPROC_ROUNDS_ENV: &str = "FSQLITE_COORD_IPC_ROUNDS";
// Marker that `parse_child_summary` searches for in a child's stdout.
#[cfg(target_os = "linux")]
const MULTIPROC_SUMMARY_PREFIX: &str = "IPC_MULTIPROC_SUMMARY";
// Test helper: reads exactly one frame from `stream`, collecting at most one
// file descriptor that arrives with any of its chunks.
//
// Returns the decoded frame and the descriptor, if one was received. Fails with
// `UnexpectedEof` if the stream ends mid-frame and with `InvalidData` if the
// length field is out of range, a second descriptor arrives, or decoding fails.
#[cfg(target_os = "linux")]
fn recv_frame_with_optional_fd(
stream: &std::os::unix::net::UnixStream,
) -> std::io::Result<(Frame, Option<ReceivedFd>)> {
use std::io::ErrorKind;
let mut header = [0u8; FRAME_HEADER_WIRE_BYTES];
let mut filled = 0usize;
let mut fd: Option<ReceivedFd> = None;
// Phase 1: read until the fixed-size header is complete.
while filled < FRAME_HEADER_WIRE_BYTES {
let (n, maybe_fd) = recv_with_fd(stream, &mut header[filled..])?;
if n == 0 {
return Err(std::io::Error::new(
ErrorKind::UnexpectedEof,
"eof while reading frame header",
));
}
filled += n;
if let Some(new_fd) = maybe_fd {
// Returning here drops `new_fd`, which closes it.
if fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"received multiple fds for a single frame",
));
}
fd = Some(new_fd);
}
}
// Validate the length field before allocating a buffer of that size.
let len_be = read_u32_be(&header[..4])
.ok_or_else(|| std::io::Error::new(ErrorKind::InvalidData, "missing frame length"))?;
if len_be < FRAME_MIN_LEN_BE {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("len_be {len_be} below minimum {FRAME_MIN_LEN_BE}"),
));
}
if len_be > FRAME_MAX_LEN_BE {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("len_be {len_be} exceeds cap {FRAME_MAX_LEN_BE}"),
));
}
// `len_be` excludes the 4-byte length prefix itself.
let total_len = 4usize + len_be as usize;
let mut wire = vec![0u8; total_len];
wire[..FRAME_HEADER_WIRE_BYTES].copy_from_slice(&header);
let mut pos = FRAME_HEADER_WIRE_BYTES;
// Phase 2: read the payload that follows the header.
while pos < total_len {
let (n, maybe_fd) = recv_with_fd(stream, &mut wire[pos..])?;
if n == 0 {
return Err(std::io::Error::new(
ErrorKind::UnexpectedEof,
"eof while reading frame payload",
));
}
pos += n;
if let Some(new_fd) = maybe_fd {
if fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"received multiple fds for a single frame",
));
}
fd = Some(new_fd);
}
}
let frame = Frame::decode(&wire)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
Ok((frame, fd))
}
/// Derives a reproducible transaction token from `seed` and `round`.
///
/// The id mixes the seed with the round number using wrapping arithmetic and
/// is clamped to at least 1; the epoch is `round + 1`, saturating.
#[cfg(target_os = "linux")]
fn deterministic_txn(seed: u64, round: u32) -> WireTxnToken {
    const MIX_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    let round_wide = u64::from(round);
    let stepped = seed
        .wrapping_mul(MIX_MULTIPLIER)
        .wrapping_add(round_wide + 1);
    let scrambled = stepped ^ (round_wide << 32);

    let txn_id = scrambled.max(1);
    let txn_epoch = round.saturating_add(1);
    WireTxnToken { txn_id, txn_epoch }
}
/// Folds one commit sequence number into a running digest.
///
/// The accumulator is rotated before the XOR, so the result depends on the
/// order in which sequence numbers are folded in.
#[cfg(target_os = "linux")]
fn fold_commit_digest(acc: u64, commit_seq: u64) -> u64 {
    const SEQ_MULTIPLIER: u64 = 0x9E37_79B1_85EB_CA87;

    let carried = acc.rotate_left(9);
    let spread = commit_seq.wrapping_mul(SEQ_MULTIPLIER);
    carried ^ spread
}
/// Extracts `(ok, busy, digest)` from the child's captured stdout.
///
/// A summary line contains `MULTIPROC_SUMMARY_PREFIX` followed by
/// whitespace-separated `key=value` tokens; the prefix may be preceded by
/// other text on the same line. Unknown keys are ignored.
///
/// Tokens without `=` are skipped rather than aborting the parse, and a line
/// that carries the prefix but not all three parseable fields does not stop
/// the scan: later lines are still considered. Returns `None` only when no
/// line yields a complete summary.
#[cfg(target_os = "linux")]
fn parse_child_summary(stdout: &str) -> Option<(u32, u32, u64)> {
    stdout.lines().find_map(|line| {
        let start_idx = line.find(MULTIPROC_SUMMARY_PREFIX)?;

        let mut ok: Option<u32> = None;
        let mut busy: Option<u32> = None;
        let mut digest: Option<u64> = None;

        // The first token is the prefix itself.
        for token in line[start_idx..].split_whitespace().skip(1) {
            let Some((key, value)) = token.split_once('=') else {
                // Stray text sharing the line must not invalidate the summary.
                continue;
            };
            match key {
                "ok" => ok = value.parse::<u32>().ok(),
                "busy" => busy = value.parse::<u32>().ok(),
                "digest" => digest = value.parse::<u64>().ok(),
                _ => {}
            }
        }

        // `None` here means "incomplete line": keep looking at later lines.
        Some((ok?, busy?, digest?))
    })
}
/// Connects to the Unix socket at `path`, retrying while the listener is not
/// yet reachable.
///
/// Up to 100 attempts are made, 10 ms apart. `NotFound`, `ConnectionRefused`
/// and `ConnectionReset` are treated as "not ready yet"; any other error is
/// returned immediately. When the attempts run out, the most recent error is
/// returned.
#[cfg(target_os = "linux")]
fn connect_with_retry(
    path: &std::path::Path,
) -> std::io::Result<std::os::unix::net::UnixStream> {
    use std::io::ErrorKind;
    use std::os::unix::net::UnixStream;
    use std::thread;
    use std::time::Duration;

    const MAX_ATTEMPTS: u32 = 100;
    const PAUSE_BETWEEN_ATTEMPTS: Duration = Duration::from_millis(10);

    let mut most_recent_failure: Option<std::io::Error> = None;
    for _attempt in 0..MAX_ATTEMPTS {
        let failure = match UnixStream::connect(path) {
            Ok(stream) => return Ok(stream),
            Err(failure) => failure,
        };
        let listener_not_ready = matches!(
            failure.kind(),
            ErrorKind::NotFound | ErrorKind::ConnectionRefused | ErrorKind::ConnectionReset
        );
        if !listener_not_ready {
            return Err(failure);
        }
        most_recent_failure = Some(failure);
        thread::sleep(PAUSE_BETWEEN_ATTEMPTS);
    }

    match most_recent_failure {
        Some(failure) => Err(failure),
        None => Err(std::io::Error::new(
            ErrorKind::TimedOut,
            format!("timed out connecting to {}", path.display()),
        )),
    }
}
#[cfg(target_os = "linux")]
#[allow(clippy::too_many_lines)]
fn run_multiprocess_child_client_script(
socket_path: &std::path::Path,
seed: u64,
rounds: u32,
) -> std::io::Result<(u32, u32, u64)> {
use std::io::ErrorKind;
use std::io::Write;
use std::os::fd::AsRawFd;
let stream = connect_with_retry(socket_path)?;
let mut request_id = 1_u64;
let mut ok_count = 0_u32;
let mut busy_count = 0_u32;
let mut digest = 0_u64;
for round in 0..rounds {
let txn = deterministic_txn(seed, round);
let reserve_frame = Frame {
kind: MessageKind::Reserve,
request_id,
payload: ReservePayload { purpose: 1, txn }.to_bytes(),
};
request_id = request_id.saturating_add(1);
(&stream).write_all(&reserve_frame.encode())?;
let (reserve_resp_frame, reserve_resp_fd) = recv_frame_with_optional_fd(&stream)?;
if reserve_resp_fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"reserve response carried unexpected fd",
));
}
if reserve_resp_frame.kind != MessageKind::Response {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!(
"expected reserve response frame, got {:?}",
reserve_resp_frame.kind
),
));
}
let reserve_resp = ReserveResponse::from_bytes(&reserve_resp_frame.payload)
.ok_or_else(|| {
std::io::Error::new(ErrorKind::InvalidData, "invalid reserve response payload")
})?;
let permit_id = match reserve_resp {
ReserveResponse::Ok { permit_id } => permit_id,
ReserveResponse::Busy { .. } => {
busy_count = busy_count.saturating_add(1);
continue;
}
ReserveResponse::Err { code } => {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("unexpected reserve error code {code}"),
));
}
};
let wal = SubmitWalPayload {
permit_id,
txn,
mode: 0,
snapshot_high: 10 + u64::from(round),
schema_epoch: 1,
has_in_rw: false,
has_out_rw: false,
wal_fec_r: 0,
spill_pages: vec![SpillPageEntry {
pgno: 1 + round,
offset: 0,
len: 4096,
xxh3_64: seed ^ u64::from(round),
}],
read_witness_refs: vec![],
write_witness_refs: vec![],
edge_refs: vec![],
merge_refs: vec![],
};
let (_spill_r, spill_w) = std::io::pipe()?;
let submit_frame = Frame {
kind: MessageKind::SubmitWalCommit,
request_id,
payload: wal.to_bytes(),
};
request_id = request_id.saturating_add(1);
send_with_fd(&stream, &submit_frame.encode(), spill_w.as_raw_fd())?;
let (submit_resp_frame, submit_resp_fd) = recv_frame_with_optional_fd(&stream)?;
if submit_resp_fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"submit response carried unexpected fd",
));
}
if submit_resp_frame.kind != MessageKind::Response {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!(
"expected submit response frame, got {:?}",
submit_resp_frame.kind
),
));
}
let commit_resp = WalCommitResponse::from_bytes(&submit_resp_frame.payload)
.ok_or_else(|| {
std::io::Error::new(ErrorKind::InvalidData, "invalid submit response payload")
})?;
let commit_seq = match commit_resp {
WalCommitResponse::Ok { commit_seq } => commit_seq,
WalCommitResponse::Conflict { .. }
| WalCommitResponse::Err { .. }
| WalCommitResponse::IoError { .. } => {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("unexpected submit response: {commit_resp:?}"),
));
}
};
let (_dup_spill_r, dup_spill_w) = std::io::pipe()?;
let dup_submit = Frame {
kind: MessageKind::SubmitWalCommit,
request_id,
payload: wal.to_bytes(),
};
request_id = request_id.saturating_add(1);
send_with_fd(&stream, &dup_submit.encode(), dup_spill_w.as_raw_fd())?;
let (dup_resp_frame, dup_resp_fd) = recv_frame_with_optional_fd(&stream)?;
if dup_resp_fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"duplicate submit response carried unexpected fd",
));
}
let dup_commit =
WalCommitResponse::from_bytes(&dup_resp_frame.payload).ok_or_else(|| {
std::io::Error::new(
ErrorKind::InvalidData,
"invalid duplicate submit response payload",
)
})?;
if dup_commit != (WalCommitResponse::Ok { commit_seq }) {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("idempotency mismatch: first={commit_seq} duplicate={dup_commit:?}"),
));
}
ok_count = ok_count.saturating_add(1);
digest = fold_commit_digest(digest, commit_seq);
}
let ping = Frame {
kind: MessageKind::Ping,
request_id,
payload: vec![],
};
(&stream).write_all(&ping.encode())?;
let (pong, pong_fd) = recv_frame_with_optional_fd(&stream)?;
if pong_fd.is_some() {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
"pong carried unexpected fd",
));
}
if pong.kind != MessageKind::Pong {
return Err(std::io::Error::new(
ErrorKind::InvalidData,
format!("expected pong, got {:?}", pong.kind),
));
}
Ok((ok_count, busy_count, digest))
}
/// Encoding then decoding a frame must reproduce its kind, request id and
/// payload: for every message kind, for an empty payload (whose wire size is
/// exactly the header), and for a frame carrying a real `ReservePayload`.
#[test]
fn test_frame_round_trip() {
    const ALL_KINDS: [MessageKind; 7] = [
        MessageKind::Reserve,
        MessageKind::SubmitNativePublish,
        MessageKind::SubmitWalCommit,
        MessageKind::RowidReserve,
        MessageKind::Response,
        MessageKind::Ping,
        MessageKind::Pong,
    ];

    for kind in ALL_KINDS {
        let sent = Frame {
            kind,
            request_id: 0xDEAD_BEEF_CAFE_BABE,
            payload: vec![1, 2, 3, 4, 5],
        };
        let received = Frame::decode(&sent.encode()).expect("decode must succeed");
        assert_eq!(received.kind, sent.kind, "kind mismatch for {kind}");
        assert_eq!(
            received.request_id, sent.request_id,
            "request_id mismatch for {kind}"
        );
        assert_eq!(
            received.payload, sent.payload,
            "payload mismatch for {kind}"
        );
    }

    // Empty payload: the wire image is nothing but the header.
    let empty_wire = Frame {
        kind: MessageKind::Ping,
        request_id: 42,
        payload: vec![],
    }
    .encode();
    assert_eq!(empty_wire.len(), FRAME_HEADER_WIRE_BYTES);
    let empty = Frame::decode(&empty_wire).expect("decode empty payload");
    assert_eq!(empty.kind, MessageKind::Ping);
    assert!(empty.payload.is_empty());

    // A structured payload must come out of the frame bit-for-bit parseable.
    let body = ReservePayload {
        purpose: 0,
        txn: WireTxnToken {
            txn_id: 100,
            txn_epoch: 3,
        },
    };
    let reserve_wire = Frame {
        kind: MessageKind::Reserve,
        request_id: 7,
        payload: body.to_bytes(),
    }
    .encode();
    let carried = Frame::decode(&reserve_wire).expect("decode reserve frame");
    let reparsed = ReservePayload::from_bytes(&carried.payload).expect("parse reserve payload");
    assert_eq!(reparsed, body);
}
/// `Frame::decode` must reject malformed input with the matching
/// `FrameError`: too few bytes, a declared length outside the allowed range,
/// an unknown version, an unknown kind, and a payload shorter than declared.
#[test]
fn test_frame_validation() {
    let four_bytes = [0u8; 4];
    assert_eq!(Frame::decode(&four_bytes), Err(FrameError::TooShort));

    // Header-sized buffers whose only non-zero field is the declared length.
    let length_cases = [
        (5_u32, FrameError::LenTooSmall(5)),
        (5_000_000_u32, FrameError::LenTooLarge(5_000_000)),
    ];
    for (declared, expected) in length_cases {
        let mut buf = [0u8; 16];
        buf[..4].copy_from_slice(&declared.to_be_bytes());
        assert_eq!(Frame::decode(&buf), Err(expected));
    }

    // Start from a valid empty ping and corrupt one header field at a time.
    let ping = Frame {
        kind: MessageKind::Ping,
        request_id: 0,
        payload: vec![],
    };

    let mut wrong_version = ping.encode();
    wrong_version[4..6].copy_from_slice(&99_u16.to_be_bytes());
    assert_eq!(
        Frame::decode(&wrong_version),
        Err(FrameError::UnknownVersion(99))
    );

    let mut wrong_kind = ping.encode();
    wrong_kind[6..8].copy_from_slice(&255_u16.to_be_bytes());
    assert_eq!(Frame::decode(&wrong_kind), Err(FrameError::UnknownKind(255)));

    // Declares len_be = 20 but supplies only 20 bytes on the wire in total.
    let mut truncated = vec![0u8; 20];
    truncated[..4].copy_from_slice(&20_u32.to_be_bytes());
    truncated[4..6].copy_from_slice(&1_u16.to_be_bytes());
    truncated[6..8].copy_from_slice(&6_u16.to_be_bytes());
    let outcome = Frame::decode(&truncated);
    assert_eq!(
        outcome,
        Err(FrameError::PayloadTruncated {
            expected: 8,
            actual: 4
        })
    );
}
/// The outstanding-permit count must go up on reserve and back down on both
/// consume and release, and stay at zero after consumed permits are collected.
#[test]
fn test_reserve_submit_discipline() {
    let manager = PermitManager::new(MAX_OUTSTANDING_PERMITS);

    // reserve -> consume
    let first = manager.reserve().expect("first reserve");
    assert_eq!(manager.outstanding(), 1);
    manager.consume(first).expect("consume p1");
    assert_eq!(manager.outstanding(), 0);

    // reserve -> release
    let second = manager.reserve().expect("second reserve");
    assert_eq!(manager.outstanding(), 1);
    manager.release(second);
    assert_eq!(manager.outstanding(), 0);

    // reserve -> consume -> gc
    let third = manager.reserve().expect("third reserve");
    manager.consume(third).expect("consume p3");
    manager.gc_consumed();
    assert_eq!(manager.outstanding(), 0);
}
/// A permit can be consumed exactly once; a second consume and a consume of
/// an id that was never issued are reported as distinct errors.
#[test]
fn test_permit_single_use() {
    let manager = PermitManager::new(MAX_OUTSTANDING_PERMITS);
    let permit = manager.reserve().expect("reserve");

    manager.consume(permit).expect("first consume");

    let second_attempt = manager.consume(permit);
    assert_eq!(second_attempt, Err(PermitError::AlreadyConsumed(permit)));

    let never_issued = manager.consume(999);
    assert_eq!(never_issued, Err(PermitError::NotFound(999)));
}
/// The idempotency cache is keyed by `(txn_id, txn_epoch)`: it misses before
/// insertion, returns the stored bytes afterwards, misses for a different
/// epoch of the same id, and keeps independent entries apart.
#[test]
fn test_idempotency() {
    let cache = IdempotencyCache::new();
    let token = WireTxnToken {
        txn_id: 42,
        txn_epoch: 1,
    };

    // Cold cache.
    assert!(cache.get(token.txn_id, token.txn_epoch).is_none());

    // First entry: stored bytes come back verbatim.
    let stored_ok = ReserveResponse::Ok { permit_id: 77 }.to_bytes();
    cache.insert(token.txn_id, token.txn_epoch, stored_ok.clone());
    let hit = cache
        .get(token.txn_id, token.txn_epoch)
        .expect("cache hit");
    assert_eq!(hit, stored_ok);

    // Same id, next epoch: a different key.
    assert!(cache.get(token.txn_id, token.txn_epoch + 1).is_none());

    // Second, unrelated entry.
    let stored_busy = ReserveResponse::Busy {
        retry_after_ms: 100,
    }
    .to_bytes();
    cache.insert(99, 2, stored_busy.clone());
    assert_eq!(cache.len(), 2);
    let second_hit = cache.get(99, 2).expect("second hit");
    assert_eq!(second_hit, stored_busy);
}
/// `authenticate_peer` accepts the uid the kernel reports for the socket peer
/// and rejects any other uid with a `UidMismatch` naming both values.
#[cfg(target_os = "linux")]
#[test]
fn test_peer_auth_rejects_wrong_uid() {
    use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
    use std::os::unix::net::UnixStream;

    // The second end is kept bound so the pair stays connected.
    let (local, _remote) = UnixStream::pair().expect("socketpair");
    let peer_uid = getsockopt(&local, PeerCredentials)
        .expect("peer_cred")
        .uid();

    authenticate_peer(&local, peer_uid).expect("peer auth ok");

    // Flipping the low bit always yields a different uid.
    let other_uid = peer_uid ^ 1;
    let rejection = authenticate_peer(&local, other_uid);
    assert_eq!(
        rejection,
        Err(PeerAuthError::UidMismatch {
            expected: other_uid,
            actual: peer_uid,
        })
    );
}
/// A descriptor sent with `send_with_fd` must arrive usable via
/// `recv_with_fd`: the read end of a pipe is passed across a socket pair, the
/// sender's copy is dropped, and data written to the pipe is then read back
/// through the received descriptor.
#[cfg(target_os = "linux")]
#[test]
fn test_scm_rights_fd_passing() {
    use std::io::pipe;
    use std::io::Write;
    use std::os::fd::AsRawFd;
    use std::os::unix::net::UnixStream;

    const SOCKET_BYTES: &[u8] = b"fd-marker";
    const PIPE_BYTES: &[u8] = b"pipe-data";

    let (tx_sock, rx_sock) = UnixStream::pair().expect("socketpair");
    let (read_end, mut write_end) = pipe().expect("pipe");

    // Ship the read end together with a small marker payload.
    let sent_len =
        send_with_fd(&tx_sock, SOCKET_BYTES, read_end.as_raw_fd()).expect("send_with_fd");
    assert_eq!(sent_len, SOCKET_BYTES.len());
    // From here on only the in-flight copy refers to the pipe's read side.
    drop(read_end);

    let mut socket_buf = [0u8; 64];
    let (got, attached) = recv_with_fd(&rx_sock, &mut socket_buf).expect("recv_with_fd");
    assert_eq!(&socket_buf[..got], SOCKET_BYTES);
    let received = attached.expect("fd must be attached");

    // Prove the received descriptor really is the pipe's read side.
    write_end.write_all(PIPE_BYTES).expect("write into pipe");
    let mut pipe_buf = [0u8; 64];
    let read_len =
        nix::unistd::read(received.raw_fd(), &mut pipe_buf).expect("read from received fd");
    assert_eq!(&pipe_buf[..read_len], PIPE_BYTES);
}
/// Canonical form means strictly ascending: empty and single-element inputs
/// qualify, while descending or duplicated entries do not. The same holds for
/// object ids, which compare byte-wise from the first byte. Also spot-checks
/// the write-set and witness/edge-count validators.
#[test]
fn test_canonical_ordering() {
    // Page numbers.
    let accepted_pages: [&[u32]; 3] = [&[], &[1], &[1, 2, 3]];
    for pages in accepted_pages {
        assert!(is_canonical_pages(pages));
    }
    let rejected_pages: [&[u32]; 2] = [&[2, 1], &[1, 1]];
    for pages in rejected_pages {
        assert!(!is_canonical_pages(pages));
    }

    // Object ids.
    let low = ObjectId::from_bytes([0u8; 16]);
    let mid = ObjectId::from_bytes([1u8; 16]);
    let high = ObjectId::from_bytes([2u8; 16]);
    assert!(is_canonical_object_ids(&[]));
    assert!(is_canonical_object_ids(&[low]));
    assert!(is_canonical_object_ids(&[low, mid, high]));
    assert!(!is_canonical_object_ids(&[mid, low]));
    assert!(!is_canonical_object_ids(&[low, low]));

    // The first byte dominates: 00 FF … sorts before 01 00 ….
    let mixed_low = {
        let mut raw = [0u8; 16];
        raw[1] = 255;
        raw
    };
    let mixed_high = {
        let mut raw = [0u8; 16];
        raw[0] = 1;
        raw
    };
    let mixed_pair = [
        ObjectId::from_bytes(mixed_low),
        ObjectId::from_bytes(mixed_high),
    ];
    assert!(is_canonical_object_ids(&mixed_pair));

    // Size validators.
    assert!(validate_write_set_summary(&[1, 2, 3]));
    assert!(validate_write_set_summary(&[]));
    let within_cap = validate_witness_edge_counts(10_000, 10_000, 10_000, 10_000);
    assert!(within_cap);
    let one_over_cap = validate_witness_edge_counts(30_000, 30_000, 5_000, 537);
    assert!(!one_over_cap);
}
/// Once `MAX_OUTSTANDING_PERMITS` permits are held, a further reserve is
/// refused with `Busy`; releasing one permit makes room for exactly one more.
#[test]
fn test_backpressure_busy() {
    let manager = PermitManager::new(MAX_OUTSTANDING_PERMITS);

    // Fill every slot.
    let held: Vec<_> = (0..MAX_OUTSTANDING_PERMITS)
        .map(|i| {
            manager
                .reserve()
                .unwrap_or_else(|_| unreachable!("reserve #{i}"))
        })
        .collect();
    assert_eq!(manager.outstanding(), MAX_OUTSTANDING_PERMITS);

    // Saturated: the next request is pushed back.
    let refused = manager.reserve();
    assert_eq!(refused, Err(PermitError::Busy));

    // Free one slot and take it again.
    manager.release(held[0]);
    assert_eq!(manager.outstanding(), MAX_OUTSTANDING_PERMITS - 1);
    let replacement = manager.reserve().expect("reserve after release");
    assert_eq!(manager.outstanding(), MAX_OUTSTANDING_PERMITS);
    manager.consume(replacement).expect("consume p17");
}
// End-to-end exchange over a socket pair within one process: a server thread
// and the test thread (acting as client) walk through
// RESERVE -> SUBMIT_WAL_COMMIT -> duplicate SUBMIT_WAL_COMMIT -> PING in
// lock-step. The duplicate submit must be answered from the idempotency cache
// with the same commit sequence as the first one.
#[cfg(target_os = "linux")]
#[test]
#[allow(clippy::too_many_lines)]
fn test_e2e_bd_1m07() {
use std::io::Write;
use std::os::fd::AsRawFd;
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Barrier};
let pm = Arc::new(PermitManager::new(MAX_OUTSTANDING_PERMITS));
let cache = Arc::new(IdempotencyCache::new());
// Two parties: the server thread and this (client) thread.
let barrier = Arc::new(Barrier::new(2));
let (client_sock, server_sock) = UnixStream::pair().expect("socketpair");
use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
// Both ends live in this process, so the peer uid read from the socket is
// the one authentication is expected to accept.
let expected_uid = getsockopt(&server_sock, PeerCredentials)
.expect("peer_cred")
.uid();
authenticate_peer(&server_sock, expected_uid).expect("E2E peer auth");
let pm_server = Arc::clone(&pm);
let cache_server = Arc::clone(&cache);
let barrier_server = Arc::clone(&barrier);
// ---- server side -----------------------------------------------------------
let server = std::thread::spawn(move || {
barrier_server.wait();
// Step 1: RESERVE (no fd expected) -> Ok { permit_id }.
let (frame, maybe_fd) =
recv_frame_with_optional_fd(&server_sock).expect("server recv reserve");
assert!(maybe_fd.is_none(), "reserve must not carry an fd");
assert_eq!(frame.kind, MessageKind::Reserve);
let _payload =
ReservePayload::from_bytes(&frame.payload).expect("parse reserve payload");
let permit_id = pm_server.reserve().expect("server reserve");
let resp = ReserveResponse::Ok { permit_id };
let resp_frame = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: resp.to_bytes(),
};
(&server_sock)
.write_all(&resp_frame.encode())
.expect("server write reserve response");
// Step 2: first SUBMIT_WAL_COMMIT, which must arrive with a spill fd.
let (frame, maybe_fd) =
recv_frame_with_optional_fd(&server_sock).expect("server recv submit");
let _spill_fd = maybe_fd.expect("submit must carry spill fd");
assert_eq!(frame.kind, MessageKind::SubmitWalCommit);
let wal_payload =
SubmitWalPayload::from_bytes(&frame.payload).expect("parse wal payload");
assert_eq!(wal_payload.permit_id, permit_id);
let key = wal_payload.txn.idempotency_key();
// Serve from the cache when this transaction was already answered;
// otherwise consume the permit, answer, and remember the answer.
if let Some(cached) = cache_server.get(key.0, key.1) {
let resp_frame = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: cached,
};
(&server_sock)
.write_all(&resp_frame.encode())
.expect("server write cached");
} else {
pm_server.consume(permit_id).expect("consume permit");
let commit_resp = WalCommitResponse::Ok { commit_seq: 42 };
let resp_bytes = commit_resp.to_bytes();
cache_server.insert(key.0, key.1, resp_bytes.clone());
let resp_frame = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: resp_bytes,
};
(&server_sock)
.write_all(&resp_frame.encode())
.expect("server write commit response");
}
// Step 3: duplicate submit. The cache must hit, so the permit is not
// consumed a second time.
let (frame, maybe_fd) =
recv_frame_with_optional_fd(&server_sock).expect("server recv dup");
let _spill_fd = maybe_fd.expect("dup submit must carry spill fd");
let wal_payload =
SubmitWalPayload::from_bytes(&frame.payload).expect("parse dup payload");
let key = wal_payload.txn.idempotency_key();
let cached = cache_server
.get(key.0, key.1)
.expect("idempotency cache must hit");
let resp_frame = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: cached,
};
(&server_sock)
.write_all(&resp_frame.encode())
.expect("server write dup response");
// Step 4: PING -> PONG echoing the request id.
let (frame, maybe_fd) =
recv_frame_with_optional_fd(&server_sock).expect("server recv ping");
assert!(maybe_fd.is_none(), "ping must not carry an fd");
assert_eq!(frame.kind, MessageKind::Ping);
let pong = Frame {
kind: MessageKind::Pong,
request_id: frame.request_id,
payload: vec![],
};
(&server_sock)
.write_all(&pong.encode())
.expect("server write pong");
});
// ---- client side -----------------------------------------------------------
barrier.wait();
let txn = WireTxnToken {
txn_id: 1,
txn_epoch: 1,
};
// Step 1: RESERVE.
let reserve = Frame {
kind: MessageKind::Reserve,
request_id: 1,
payload: ReservePayload { purpose: 0, txn }.to_bytes(),
};
(&client_sock)
.write_all(&reserve.encode())
.expect("client write reserve");
let (resp, resp_fd) =
recv_frame_with_optional_fd(&client_sock).expect("client read reserve resp");
assert!(resp_fd.is_none(), "reserve response must not carry an fd");
assert_eq!(resp.kind, MessageKind::Response);
let reserve_resp =
ReserveResponse::from_bytes(&resp.payload).expect("parse reserve response");
let permit_id = match reserve_resp {
ReserveResponse::Ok { permit_id } => permit_id,
other => unreachable!("expected Ok, got {other:?}"),
};
// Step 2: SUBMIT_WAL_COMMIT with the pipe's write end attached as spill fd.
let wal = SubmitWalPayload {
permit_id,
txn,
mode: 0,
snapshot_high: 10,
schema_epoch: 1,
has_in_rw: false,
has_out_rw: false,
wal_fec_r: 0,
spill_pages: vec![SpillPageEntry {
pgno: 1,
offset: 0,
len: 4096,
xxh3_64: 0xABCD,
}],
read_witness_refs: vec![],
write_witness_refs: vec![],
edge_refs: vec![],
merge_refs: vec![],
};
// `_spill_r` is a named binding, so the read end stays open until the end
// of the test.
let (_spill_r, spill_w) = std::io::pipe().expect("spill pipe");
let submit = Frame {
kind: MessageKind::SubmitWalCommit,
request_id: 2,
payload: wal.to_bytes(),
};
send_with_fd(&client_sock, &submit.encode(), spill_w.as_raw_fd())
.expect("client send submit");
let (resp, resp_fd) =
recv_frame_with_optional_fd(&client_sock).expect("client read commit resp");
assert!(resp_fd.is_none(), "commit response must not carry an fd");
let commit_resp =
WalCommitResponse::from_bytes(&resp.payload).expect("parse commit response");
assert_eq!(commit_resp, WalCommitResponse::Ok { commit_seq: 42 });
// Step 3: identical payload under a new request id, reusing the same spill fd.
let dup_submit = Frame {
kind: MessageKind::SubmitWalCommit,
request_id: 3,
payload: wal.to_bytes(),
};
send_with_fd(&client_sock, &dup_submit.encode(), spill_w.as_raw_fd())
.expect("client send dup submit");
let (resp, resp_fd) =
recv_frame_with_optional_fd(&client_sock).expect("client read dup resp");
assert!(resp_fd.is_none(), "dup response must not carry an fd");
let dup_resp =
WalCommitResponse::from_bytes(&resp.payload).expect("parse dup commit response");
assert_eq!(
dup_resp,
WalCommitResponse::Ok { commit_seq: 42 },
"idempotent response must match"
);
// Step 4: PING, expecting a PONG with the same request id.
let ping = Frame {
kind: MessageKind::Ping,
request_id: 4,
payload: vec![],
};
(&client_sock)
.write_all(&ping.encode())
.expect("client write ping");
let (resp, resp_fd) = recv_frame_with_optional_fd(&client_sock).expect("client read pong");
assert!(resp_fd.is_none(), "pong must not carry an fd");
assert_eq!(resp.kind, MessageKind::Pong);
assert_eq!(resp.request_id, 4);
// Surfaces any assertion failure raised on the server thread.
server.join().expect("server thread");
}
// Multi-process stress test. The parent binds a Unix listener, serves the
// protocol on a thread, and re-executes this very test binary as a child with
// the MULTIPROC_* environment variables set. In the child, the same test
// function takes the early branch below, runs the scripted client, prints a
// summary line and returns. The parent then cross-checks the child's summary
// against the server thread's own counters and digest.
#[cfg(target_os = "linux")]
#[test]
#[allow(clippy::too_many_lines)]
fn test_e2e_bd_1m07_multiprocess_seeded_stress() {
use std::io::{ErrorKind, Write};
use std::os::unix::net::UnixListener;
use std::process::Command;
use std::sync::Arc;
// ---- child role ------------------------------------------------------------
if std::env::var(MULTIPROC_CHILD_MODE_ENV).ok().as_deref() == Some("client") {
let socket_path = std::path::PathBuf::from(
std::env::var(MULTIPROC_SOCKET_PATH_ENV).expect("child socket path"),
);
let seed = std::env::var(MULTIPROC_SEED_ENV)
.expect("child seed")
.parse::<u64>()
.expect("parse child seed");
let rounds = std::env::var(MULTIPROC_ROUNDS_ENV)
.expect("child rounds")
.parse::<u32>()
.expect("parse child rounds");
let (ok, busy, digest) =
run_multiprocess_child_client_script(&socket_path, seed, rounds)
.expect("child client script");
// The parent recovers these values with `parse_child_summary`.
println!("{MULTIPROC_SUMMARY_PREFIX} ok={ok} busy={busy} digest={digest}");
return;
}
// ---- parent role -----------------------------------------------------------
let seed = 0xC0FF_EE11_AAA5_5501_u64;
let rounds = 96_u32;
// Socket file name built from the pid and the current time in nanoseconds.
let mut socket_path = std::env::temp_dir();
let unique = format!(
"coord-ipc-seeded-{}-{}.sock",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_or(0_u128, |duration| duration.as_nanos())
);
socket_path.push(unique);
if socket_path.exists() {
let _ = std::fs::remove_file(&socket_path);
}
// Bound before the child is spawned, so the path exists when it connects.
let listener = UnixListener::bind(&socket_path).expect("bind unix listener");
let pm = Arc::new(PermitManager::new(MAX_OUTSTANDING_PERMITS));
let cache = Arc::new(IdempotencyCache::new());
let pm_server = Arc::clone(&pm);
let cache_server = Arc::clone(&cache);
// Server thread: accepts one connection and answers frames until the peer
// closes the socket. Returns (ok_count, busy_count, digest).
let server = std::thread::spawn(move || -> (u32, u32, u64) {
let (server_sock, _) = listener.accept().expect("accept child connection");
let mut next_commit_seq = 1_000_u64;
let mut ok_count = 0_u32;
let mut busy_count = 0_u32;
let mut digest = 0_u64;
loop {
// EOF while waiting for the next frame header is the normal way out.
let (frame, maybe_fd) = match recv_frame_with_optional_fd(&server_sock) {
Ok(frame) => frame,
Err(err) if err.kind() == ErrorKind::UnexpectedEof => break,
Err(err) => panic!("server recv frame failed: {err}"),
};
match frame.kind {
MessageKind::Reserve => {
let _reserve = ReservePayload::from_bytes(&frame.payload)
.expect("parse reserve payload");
let response_payload = match pm_server.reserve() {
Ok(permit_id) => ReserveResponse::Ok { permit_id }.to_bytes(),
Err(PermitError::Busy) => {
busy_count = busy_count.saturating_add(1);
ReserveResponse::Busy { retry_after_ms: 1 }.to_bytes()
}
Err(err) => panic!("reserve failed unexpectedly: {err:?}"),
};
let response = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: response_payload,
};
(&server_sock)
.write_all(&response.encode())
.expect("server write reserve response");
}
MessageKind::SubmitWalCommit => {
let _spill_fd = maybe_fd.expect("submit must carry spill fd");
let wal = SubmitWalPayload::from_bytes(&frame.payload)
.expect("parse submit wal payload");
let key = wal.txn.idempotency_key();
// A cache hit replays the stored answer. Only a miss consumes the
// permit, advances the commit sequence and updates the counters,
// so duplicate submits do not count twice.
let response_payload = if let Some(cached) = cache_server.get(key.0, key.1)
{
cached
} else {
pm_server
.consume(wal.permit_id)
.expect("consume permit for submit");
next_commit_seq = next_commit_seq.saturating_add(1);
let response = WalCommitResponse::Ok {
commit_seq: next_commit_seq,
}
.to_bytes();
cache_server.insert(key.0, key.1, response.clone());
ok_count = ok_count.saturating_add(1);
digest = fold_commit_digest(digest, next_commit_seq);
response
};
let response = Frame {
kind: MessageKind::Response,
request_id: frame.request_id,
payload: response_payload,
};
(&server_sock)
.write_all(&response.encode())
.expect("server write submit response");
}
MessageKind::Ping => {
if let Some(fd) = maybe_fd {
panic!("ping must not carry fd: {}", fd.raw_fd());
}
let pong = Frame {
kind: MessageKind::Pong,
request_id: frame.request_id,
payload: vec![],
};
(&server_sock)
.write_all(&pong.encode())
.expect("server write pong");
}
other => panic!("unexpected frame kind in stress server: {other:?}"),
}
}
assert_eq!(
pm_server.outstanding(),
0,
"all permits must be consumed or released at end of stress run"
);
(ok_count, busy_count, digest)
});
// Re-run this test binary filtered to exactly this test, in child mode.
// `--nocapture` lets the child's summary line reach its stdout.
let child_output = Command::new(std::env::current_exe().expect("current test exe"))
.arg("--exact")
.arg("coordinator_ipc::tests::test_e2e_bd_1m07_multiprocess_seeded_stress")
.arg("--nocapture")
.arg("--test-threads=1")
.env(MULTIPROC_CHILD_MODE_ENV, "client")
.env(
MULTIPROC_SOCKET_PATH_ENV,
socket_path
.to_str()
.expect("socket path must be valid UTF-8 for env transport"),
)
.env(MULTIPROC_SEED_ENV, seed.to_string())
.env(MULTIPROC_ROUNDS_ENV, rounds.to_string())
.output()
.expect("spawn child stress process");
let child_stdout = String::from_utf8_lossy(&child_output.stdout).into_owned();
let child_stderr = String::from_utf8_lossy(&child_output.stderr).into_owned();
assert!(
child_output.status.success(),
"child process failed\nstatus: {:?}\nstdout:\n{}\nstderr:\n{}",
child_output.status.code(),
child_stdout,
child_stderr
);
let (child_ok, child_busy, child_digest) = parse_child_summary(&child_stdout)
.unwrap_or_else(|| {
panic!(
"missing child summary line `{MULTIPROC_SUMMARY_PREFIX}`\nstdout:\n{child_stdout}\nstderr:\n{child_stderr}"
)
});
// The child has exited, so its socket is closed and the server loop ends.
let (server_ok, server_busy, server_digest) = server.join().expect("server thread join");
// Every round ends either committed or busy, on both sides.
assert_eq!(
child_ok + child_busy,
rounds,
"child must account for all scripted rounds"
);
assert_eq!(server_ok + server_busy, rounds);
assert_eq!(
child_ok, server_ok,
"ok-count mismatch between child/server"
);
assert_eq!(
child_busy, server_busy,
"busy-count mismatch between child/server"
);
assert_eq!(
child_digest, server_digest,
"commit digest mismatch between child and server"
);
let cache_len = u32::try_from(cache.len()).expect("cache len must fit in u32");
assert_eq!(
cache_len, server_ok,
"cache cardinality must match committed txns"
);
// Best-effort cleanup; reached only when every assertion above passed.
let _ = std::fs::remove_file(&socket_path);
}
/// A `ReservePayload` serialises to exactly 24 bytes with the purpose in
/// byte 0 and zero padding in bytes 1..8, and decodes back to an equal value
/// for both purposes exercised here.
#[test]
fn test_reserve_v1_roundtrip() {
    // purpose = 0
    let purpose_zero = ReservePayload {
        purpose: 0,
        txn: WireTxnToken {
            txn_id: 0xDEAD_BEEF_0000_0001,
            txn_epoch: 7,
        },
    };
    let wire_zero = purpose_zero.to_bytes();
    assert_eq!(wire_zero.len(), 24, "ReserveV1 must be exactly 24 bytes");
    assert_eq!(&wire_zero[1..8], &[0u8; 7], "pad0 must be zero");
    let back_zero = ReservePayload::from_bytes(&wire_zero).expect("decode must succeed");
    assert_eq!(back_zero, purpose_zero, "round-trip mismatch");

    // purpose = 1
    let purpose_one = ReservePayload {
        purpose: 1,
        txn: WireTxnToken {
            txn_id: 42,
            txn_epoch: 0,
        },
    };
    let wire_one = purpose_one.to_bytes();
    assert_eq!(wire_one.len(), 24);
    assert_eq!(wire_one[0], 1, "purpose must be 1");
    let back_one = ReservePayload::from_bytes(&wire_one).expect("decode purpose=1");
    assert_eq!(back_one, purpose_one);
}
/// Every `ReserveResponse` variant carries its tag in byte 0, zero padding in
/// bytes 1..8, and round-trips through `to_bytes` / `from_bytes`; an unknown
/// tag is rejected.
#[test]
fn test_reserve_resp_tagged_union_variants() {
    // Checks one variant and hands back its wire image.
    let check_variant =
        |variant: ReserveResponse, tag: u8, tag_message: &str, decode_message: &str| {
            let wire = variant.to_bytes();
            assert_eq!(wire[0], tag, "{tag_message}");
            assert_eq!(&wire[1..8], &[0u8; 7], "pad0 must be zero");
            let back = ReserveResponse::from_bytes(&wire).expect(decode_message);
            assert_eq!(back, variant);
            wire
        };

    let ok_wire = check_variant(
        ReserveResponse::Ok { permit_id: 0xCAFE },
        0,
        "Ok tag = 0",
        "decode Ok",
    );
    check_variant(
        ReserveResponse::Busy {
            retry_after_ms: 500,
        },
        1,
        "Busy tag = 1",
        "decode Busy",
    );
    check_variant(
        ReserveResponse::Err { code: 0x07 },
        2,
        "Err tag = 2",
        "decode Err",
    );

    // Corrupt the tag of an otherwise valid encoding.
    let mut unknown_tag = ok_wire;
    unknown_tag[0] = 99;
    let rejected = ReserveResponse::from_bytes(&unknown_tag).is_none();
    assert!(rejected, "unknown tag must be rejected");
}
/// Sorted page numbers are canonical, encode as 4 little-endian bytes each,
/// pass the raw-length validator, and decode back to the same list.
#[test]
fn test_write_set_summary_canonical_encoding() {
    let pages = {
        let mut unsorted = vec![5_u32, 1, 100, 3];
        unsorted.sort_unstable();
        unsorted
    };
    assert!(is_canonical_pages(&pages), "sorted pages must be canonical");
    assert_eq!(pages, [1, 3, 5, 100]);

    // Encode: one little-endian u32 per page.
    let mut encoded = Vec::new();
    pages
        .iter()
        .for_each(|&page| append_u32_le(&mut encoded, page));
    assert_eq!(encoded.len(), 16, "4 pages × 4 bytes = 16");
    let length_ok = validate_write_set_summary_raw_len(encoded.len());
    assert!(length_ok, "16 bytes must be valid");

    // Decode: walk the buffer four bytes at a time.
    let reread: Vec<u32> = encoded
        .chunks_exact(4)
        .map(|word| read_u32_le(word).unwrap())
        .collect();
    assert_eq!(reread, pages, "round-trip mismatch");
}
/// The raw write-set length must be a whole number of 4-byte page entries.
#[test]
fn test_write_set_summary_len_not_multiple_of_4_rejected() {
    let rejected = [
        (7, "7 bytes is not a multiple of 4 — must be rejected"),
        (1, "1 byte is not a multiple of 4"),
        (5, "5 bytes is not a multiple of 4"),
    ];
    for (raw_len, why) in rejected {
        assert!(!validate_write_set_summary_raw_len(raw_len), "{why}");
    }

    let accepted = [(0, "0 is valid"), (4, "4 is valid"), (8, "8 is valid")];
    for (raw_len, why) in accepted {
        assert!(validate_write_set_summary_raw_len(raw_len), "{why}");
    }
}
/// A `Conflict` response is tagged 1 and preserves its page list and reason
/// across encode/decode.
#[test]
fn test_native_publish_conflict_response_page_list() {
    let sent = NativePublishResponse::Conflict {
        pages: vec![10, 42, 99],
        reason: 1,
    };

    let wire = sent.to_bytes();
    assert_eq!(wire[0], 1, "Conflict tag = 1");

    let received = NativePublishResponse::from_bytes(&wire).expect("decode Conflict response");
    assert_eq!(received, sent, "round-trip mismatch");

    let NativePublishResponse::Conflict { pages, reason } = &received else {
        unreachable!("must be Conflict variant");
    };
    assert_eq!(pages.len(), 3);
    assert_eq!(pages[0], 10);
    assert_eq!(pages[1], 42);
    assert_eq!(pages[2], 99);
    assert_eq!(*reason, 1);
}
/// The write-set summary is capped at 1 MiB: 262,144 pages (4 bytes each) fit,
/// one more page does not; the raw-length validator agrees at the boundary.
#[test]
fn test_wire_size_cap_write_set_summary_exceeds_1mib() {
    // Page-list form.
    let at_cap: Vec<u32> = (0..262_144).collect();
    let at_cap_ok = validate_write_set_summary(&at_cap);
    assert!(at_cap_ok, "262,144 pages (1 MiB) must be accepted");

    let over_cap: Vec<u32> = (0..262_145).collect();
    let over_cap_ok = validate_write_set_summary(&over_cap);
    assert!(!over_cap_ok, "262,145 pages (> 1 MiB) must be rejected");

    // Raw-length form.
    let exact_ok = validate_write_set_summary_raw_len(WIRE_WRITE_SET_MAX_BYTES);
    assert!(exact_ok, "exactly 1 MiB must pass");

    let next_entry_ok = validate_write_set_summary_raw_len(WIRE_WRITE_SET_MAX_BYTES + 4);
    assert!(!next_entry_ok, "1 MiB + 4 must fail");
}
/// The four witness/edge counts are capped in total at 65,536, whether the
/// total is spread evenly or concentrated in a single count.
#[test]
fn test_wire_size_cap_total_witness_count_exceeds_65536() {
    // Evenly spread.
    let spread_at_cap = validate_witness_edge_counts(16_384, 16_384, 16_384, 16_384);
    assert!(spread_at_cap, "65,536 total must be accepted");
    let spread_over_cap = validate_witness_edge_counts(16_384, 16_384, 16_384, 16_385);
    assert!(!spread_over_cap, "65,537 total must be rejected");

    // All in the first count.
    let single_at_cap = validate_witness_edge_counts(65_536, 0, 0, 0);
    assert!(single_at_cap);
    let single_over_cap = validate_witness_edge_counts(65_537, 0, 0, 0);
    assert!(!single_over_cap);
}
/// A `SpillPageEntry` is 32 bytes on the wire and round-trips on its own; a
/// `SubmitWalPayload` carrying three of them decodes with the entries in
/// order and with `schema_epoch` and the rw flags intact.
#[test]
fn test_wal_commit_spill_page_encoding() {
    let sp1 = SpillPageEntry {
        pgno: 1,
        offset: 4096,
        len: 4096,
        xxh3_64: 0x1234_5678_ABCD_EF01,
    };
    let sp2 = SpillPageEntry {
        pgno: 5,
        offset: 8192,
        len: 4096,
        xxh3_64: 0xFEDC_BA98_7654_3210,
    };
    let sp3 = SpillPageEntry {
        pgno: 10,
        offset: 12288,
        len: 4096,
        xxh3_64: 0,
    };
    let entries = [sp1, sp2, sp3];

    // Stand-alone entry encoding.
    let first_len = sp1.to_bytes().len();
    assert_eq!(first_len, 32, "SpillPageV1 must be exactly 32 bytes");
    for entry in entries {
        let back = SpillPageEntry::from_bytes(&entry.to_bytes()).unwrap();
        assert_eq!(back, entry);
    }

    // Entries embedded in a full WAL submit payload.
    let payload = SubmitWalPayload {
        permit_id: 42,
        txn: WireTxnToken {
            txn_id: 100,
            txn_epoch: 1,
        },
        mode: 1,
        snapshot_high: 999,
        schema_epoch: 5,
        has_in_rw: true,
        has_out_rw: false,
        wal_fec_r: 3,
        spill_pages: vec![sp1, sp2, sp3],
        read_witness_refs: vec![],
        write_witness_refs: vec![],
        edge_refs: vec![],
        merge_refs: vec![],
    };
    let wire = payload.to_bytes();
    let decoded = SubmitWalPayload::from_bytes(&wire).expect("decode WAL payload");

    assert_eq!(decoded.spill_pages.len(), 3);
    for (got, want) in decoded.spill_pages.iter().zip(entries.iter()) {
        assert_eq!(got, want);
    }
    assert_eq!(decoded.schema_epoch, 5);
    assert!(decoded.has_in_rw);
    assert!(!decoded.has_out_rw);
}
/// A `RowidReservePayload` is 32 bytes on the wire, round-trips, and places
/// its little-endian fields at offsets 0 (txn_id), 8 (txn_epoch),
/// 16 (schema_epoch), 24 (table_id) and 28 (count).
#[test]
fn test_rowid_reserve_roundtrip() {
    let request = RowidReservePayload {
        txn: WireTxnToken {
            txn_id: 0xAAAA_BBBB_CCCC_DDDD,
            txn_epoch: 42,
        },
        schema_epoch: 7,
        table_id: 100,
        count: 64,
    };

    let wire = request.to_bytes();
    assert_eq!(wire.len(), 32, "RowIdReserveV1 must be exactly 32 bytes");

    let reparsed = RowidReservePayload::from_bytes(&wire).expect("decode must succeed");
    assert_eq!(reparsed, request, "round-trip mismatch");

    // Field-by-field layout, read straight off the wire image.
    let wire_txn_id = read_u64_le(&wire[0..8]).unwrap();
    assert_eq!(wire_txn_id, request.txn.txn_id, "txn_id at offset 0");

    let wire_txn_epoch = read_u32_le(&wire[8..12]).unwrap();
    assert_eq!(
        wire_txn_epoch, request.txn.txn_epoch,
        "txn_epoch at offset 8"
    );

    let wire_schema_epoch = read_u64_le(&wire[16..24]).unwrap();
    assert_eq!(
        wire_schema_epoch, request.schema_epoch,
        "schema_epoch at offset 16"
    );

    let wire_table_id = read_u32_le(&wire[24..28]).unwrap();
    assert_eq!(wire_table_id, request.table_id, "table_id at offset 24");

    let wire_count = read_u32_le(&wire[28..32]).unwrap();
    assert_eq!(wire_count, request.count, "count at offset 28");
}
/// Verifies the 24-byte layout of a successful ROWID_RESERVE response (tag,
/// zeroed padding, start rowid, count, trailing pad) and that it decodes back
/// to the value that was encoded.
#[test]
fn test_rowid_reserve_response_ok_layout_and_roundtrip() {
    let response = RowidReserveResponse::Ok {
        start_rowid: 0x0123_4567_89AB_CDEF,
        count: 64,
    };
    let wire = response.to_bytes();

    assert_eq!(
        wire.len(),
        24,
        "ROWID_RESERVE ok response must be 24 bytes"
    );

    // Byte 0 carries the variant tag; bytes 1..8 must all be zero.
    assert_eq!(wire[0], RowidReserveResponse::TAG_OK, "tag at offset 0");
    let padding_has_nonzero = wire[1..8].iter().any(|&pad| pad != 0);
    assert!(!padding_has_nonzero, "bytes 1..8 are reserved padding");

    // Little-endian body: start_rowid, count, then a zeroed trailing word.
    let wire_start_rowid = read_u64_le(&wire[8..16]).unwrap();
    assert_eq!(
        wire_start_rowid, 0x0123_4567_89AB_CDEF,
        "start_rowid at offset 8"
    );

    let wire_count = read_u32_le(&wire[16..20]).unwrap();
    assert_eq!(wire_count, 64, "count at offset 16");

    let wire_pad1 = read_u32_le(&wire[20..24]).unwrap();
    assert_eq!(wire_pad1, 0, "pad1 at offset 20");

    let reparsed = RowidReserveResponse::from_bytes(&wire).expect("decode ok response");
    assert_eq!(reparsed, response);
}
/// Verifies the 12-byte layout of a failed ROWID_RESERVE response (tag, zeroed
/// padding, little-endian error code) and that it decodes back to the value
/// that was encoded.
#[test]
fn test_rowid_reserve_response_err_layout_and_roundtrip() {
    let response = RowidReserveResponse::Err { code: 0xDEAD_BEEF };
    let wire = response.to_bytes();

    assert_eq!(
        wire.len(),
        12,
        "ROWID_RESERVE err response must be 12 bytes"
    );

    // Byte 0 carries the variant tag; bytes 1..8 must all be zero.
    assert_eq!(wire[0], RowidReserveResponse::TAG_ERR, "tag at offset 0");
    let padding_has_nonzero = wire[1..8].iter().any(|&pad| pad != 0);
    assert!(!padding_has_nonzero, "bytes 1..8 are reserved padding");

    // The error code occupies the final four bytes.
    let wire_code = read_u32_le(&wire[8..12]).unwrap();
    assert_eq!(wire_code, 0xDEAD_BEEF, "error code at offset 8");

    let reparsed = RowidReserveResponse::from_bytes(&wire).expect("decode err response");
    assert_eq!(reparsed, response);
}
}