use std::io::Write;
use super::{InvalidFramingReason, StreamMagic, WalExportError, WalRecordSink, MAX_RECORD_BYTES};
#[cfg(test)]
use super::STREAM_HEADER_MAGIC;
/// A [`WalRecordSink`] that stages framed WAL records in memory and hands
/// them to `writer` only when [`WalRecordSink::flush`] is called.
///
/// Staged bytes consist of the stream header magic (staged once, before the
/// first accepted record) followed by one frame per record: an 8-byte
/// big-endian length prefix and then the record bytes. After the first
/// record, each record's sequence number must be exactly one greater than the
/// previous one.
pub struct BufferedWalSink<W: Write> {
    /// Destination that receives the staged bytes on flush.
    writer: W,
    /// Staged bytes (header and frames) awaiting a flush.
    buffer: Vec<u8>,
    /// Limit on staged bytes; an append whose frame would push the total past
    /// this limit is rejected.
    capacity: usize,
    /// Set once the stream header magic has been staged.
    header_emitted: bool,
    /// Sequence number of the most recently accepted record, if any.
    last_seq: Option<u64>,
}
// Compile-time checks that the sink never implements `Seek`, for a range of
// writer types, including writers that are themselves seekable (`File`,
// `Cursor`).
// NOTE(review): presumably this guards the append-only design against a future
// `Seek` impl that would allow rewriting already-exported bytes — confirm.
static_assertions::assert_not_impl_any!(BufferedWalSink<Vec<u8>>: std::io::Seek);
static_assertions::assert_not_impl_any!(BufferedWalSink<std::fs::File>: std::io::Seek);
static_assertions::assert_not_impl_any!(BufferedWalSink<&mut Vec<u8>>: std::io::Seek);
static_assertions::assert_not_impl_any!(BufferedWalSink<std::io::BufWriter<Vec<u8>>>: std::io::Seek);
static_assertions::assert_not_impl_any!(BufferedWalSink<std::io::Cursor<Vec<u8>>>: std::io::Seek);
impl<W: Write> BufferedWalSink<W> {
    /// Staging limit used by [`BufferedWalSink::new`]: 16 MiB.
    pub const DEFAULT_CAPACITY: usize = 1 << 24;

    /// Creates a sink over `writer` that may stage up to
    /// [`Self::DEFAULT_CAPACITY`] bytes between flushes.
    pub fn new(writer: W) -> Self {
        Self::with_capacity(writer, Self::DEFAULT_CAPACITY)
    }

    /// Creates a sink over `writer` that may stage up to `capacity` bytes
    /// between flushes.
    ///
    /// The staging buffer starts empty and unallocated; `capacity` is a limit,
    /// not a pre-allocation.
    pub fn with_capacity(writer: W, capacity: usize) -> Self {
        Self {
            last_seq: None,
            header_emitted: false,
            buffer: Vec::new(),
            capacity,
            writer,
        }
    }

    /// Decodes the postcard-encoded `u64` sequence number at the start of
    /// `record_bytes`; whatever follows it is ignored here.
    ///
    /// # Errors
    ///
    /// Every postcard decode failure is reported as
    /// [`InvalidFramingReason::Truncated`].
    pub(super) fn extract_seq(record_bytes: &[u8]) -> Result<u64, WalExportError> {
        match postcard::take_from_bytes::<u64>(record_bytes) {
            Ok((seq, _rest)) => Ok(seq),
            Err(_) => Err(WalExportError::InvalidFraming(
                InvalidFramingReason::Truncated,
            )),
        }
    }

    /// Checks that a record is non-empty and at most `MAX_RECORD_BYTES` long,
    /// returning its length as a `u64`.
    ///
    /// # Errors
    ///
    /// [`InvalidFramingReason::LengthZero`] for an empty record, and
    /// [`InvalidFramingReason::LengthExceedsMax`] for an oversized one.
    fn validate_length(record_bytes: &[u8]) -> Result<u64, WalExportError> {
        match record_bytes.len() as u64 {
            0 => Err(WalExportError::InvalidFraming(
                InvalidFramingReason::LengthZero,
            )),
            len if len > MAX_RECORD_BYTES => Err(WalExportError::InvalidFraming(
                InvalidFramingReason::LengthExceedsMax {
                    prefix: len,
                    max: MAX_RECORD_BYTES,
                },
            )),
            len => Ok(len),
        }
    }

    /// Stages the stream header magic on the first call; later calls do
    /// nothing.
    fn ensure_header(&mut self) {
        if self.header_emitted {
            return;
        }
        self.buffer.extend_from_slice(StreamMagic::V1.bytes());
        self.header_emitted = true;
    }

    /// Consumes the sink and returns the underlying writer (tests only).
    #[cfg(test)]
    pub(crate) fn into_writer_for_test(self) -> W {
        self.writer
    }
}
impl<W: Write> WalRecordSink for BufferedWalSink<W> {
fn append_record(&mut self, record_bytes: &[u8]) -> Result<(), WalExportError> {
let len = Self::validate_length(record_bytes)?;
let seq = Self::extract_seq(record_bytes)?;
if let Some(prev) = self.last_seq {
let expected = match prev.checked_add(1) {
Some(e) => e,
None => return Err(WalExportError::SeqExhausted { last_seq: prev }),
};
if seq != expected {
return Err(WalExportError::AppendOnlyViolation {
expected_seq: expected,
got_seq: seq,
previous_seq: Some(prev),
});
}
}
self.ensure_header();
let frame_size = 8 + len as usize;
let current_buffer = self.buffer.len();
if current_buffer.saturating_add(frame_size) > self.capacity {
return Err(WalExportError::BufferOverflow {
capacity: self.capacity,
requested: frame_size,
current_buffer,
});
}
self.buffer.extend_from_slice(&len.to_be_bytes());
self.buffer.extend_from_slice(record_bytes);
self.last_seq = Some(seq);
Ok(())
}
fn flush(&mut self) -> Result<(), WalExportError> {
if self.buffer.is_empty() {
return Ok(());
}
self.writer.write_all(&self.buffer)?;
self.writer.flush()?;
self.buffer.clear();
Ok(())
}
}
#[cfg(test)]
// Tests intentionally panic (unwrap/expect/panic!) to report unexpected results.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a record: postcard-encoded `seq` followed by `padding` zero bytes.
    fn synth_record(seq: u64, padding: usize) -> Vec<u8> {
        let mut bytes = postcard::to_allocvec(&seq).unwrap();
        bytes.resize(bytes.len() + padding, 0);
        bytes
    }

    /// Builds `n` records with consecutive seqs `1..=n`, each padded by 16 bytes.
    fn build_records(n: u64) -> Vec<Vec<u8>> {
        (1..=n).map(|i| synth_record(i, 16)).collect()
    }

    #[test]
    fn fresh_sink_initial_state() {
        let sink = BufferedWalSink::new(Vec::<u8>::new());
        assert!(sink.buffer.is_empty());
        assert!(!sink.header_emitted);
        assert!(sink.last_seq.is_none());
    }

    #[test]
    fn with_capacity_pins_buffer_capacity() {
        let sink = BufferedWalSink::with_capacity(Vec::<u8>::new(), 1024);
        assert_eq!(sink.capacity, 1024);
    }

    #[test]
    fn first_append_emits_header_and_pins_seq() {
        let records = build_records(1);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&records[0]).expect("first append OK");
        assert!(sink.header_emitted);
        assert_eq!(sink.last_seq, Some(1));
        assert!(sink.buffer.starts_with(&STREAM_HEADER_MAGIC));
    }

    #[test]
    fn multi_record_append_in_order_succeeds() {
        let records = build_records(5);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        for (i, rec) in records.iter().enumerate() {
            sink.append_record(rec)
                .unwrap_or_else(|e| panic!("append #{i} failed: {e}"));
        }
        assert_eq!(sink.last_seq, Some(5));
    }

    #[test]
    fn first_record_may_start_at_any_seq() {
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&synth_record(42, 4)).expect("seq 42 OK");
        sink.append_record(&synth_record(43, 4)).expect("seq 43 OK");
        assert_eq!(sink.last_seq, Some(43));
    }

    #[test]
    fn out_of_order_seq_rejected_with_append_only_violation() {
        let records = build_records(3);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&records[0]).expect("seq 1 OK");
        let result = sink.append_record(&records[2]);
        match result {
            Err(WalExportError::AppendOnlyViolation {
                expected_seq,
                got_seq,
                previous_seq,
            }) => {
                assert_eq!(expected_seq, 2);
                assert_eq!(got_seq, 3);
                assert_eq!(previous_seq, Some(1));
            }
            other => panic!("expected AppendOnlyViolation, got: {other:?}"),
        }
    }

    #[test]
    fn append_only_violation_leaves_state_unchanged() {
        let records = build_records(3);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&records[0]).expect("seq 1 OK");
        let before = sink.buffer.clone();
        assert!(sink.append_record(&records[2]).is_err());
        assert_eq!(sink.buffer, before, "rejected record must not be staged");
        assert_eq!(sink.last_seq, Some(1));
        sink.append_record(&records[1]).expect("seq 2 still accepted");
        assert_eq!(sink.last_seq, Some(2));
    }

    #[test]
    fn duplicate_seq_rejected_with_append_only_violation() {
        let records = build_records(2);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&records[0]).expect("seq 1 OK");
        let result = sink.append_record(&records[0]);
        assert!(matches!(
            result,
            Err(WalExportError::AppendOnlyViolation {
                expected_seq: 2,
                got_seq: 1,
                previous_seq: Some(1),
            })
        ));
    }

    #[test]
    fn empty_record_rejected_with_length_zero() {
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        let result = sink.append_record(&[]);
        assert!(matches!(
            result,
            Err(WalExportError::InvalidFraming(
                InvalidFramingReason::LengthZero
            ))
        ));
    }

    #[test]
    fn truncated_record_rejected_with_truncated() {
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        let result = sink.append_record(&[0xFF]);
        assert!(matches!(
            result,
            Err(WalExportError::InvalidFraming(
                InvalidFramingReason::Truncated
            ))
        ));
    }

    #[test]
    fn rejected_framing_does_not_emit_header() {
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        assert!(sink.append_record(&[]).is_err());
        assert!(sink.append_record(&[0xFF]).is_err());
        assert!(!sink.header_emitted);
        assert!(sink.buffer.is_empty());
        assert!(sink.last_seq.is_none());
    }

    #[test]
    fn buffer_overflow_rejected_with_capacity() {
        let records = build_records(10);
        let single_frame = 8 + records[0].len();
        let cap = STREAM_HEADER_MAGIC.len() + single_frame;
        let mut sink = BufferedWalSink::with_capacity(Vec::<u8>::new(), cap);
        sink.append_record(&records[0]).expect("first record fits");
        let result = sink.append_record(&records[1]);
        match result {
            Err(WalExportError::BufferOverflow {
                capacity,
                requested,
                current_buffer,
            }) => {
                assert_eq!(capacity, cap);
                assert_eq!(requested, 8 + records[1].len());
                assert_eq!(current_buffer, STREAM_HEADER_MAGIC.len() + single_frame);
                assert!(current_buffer + requested > capacity);
            }
            other => panic!("expected BufferOverflow, got: {other:?}"),
        }
    }

    #[test]
    fn flush_writes_buffer_to_writer_and_clears() {
        let records = build_records(3);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        for rec in &records {
            sink.append_record(rec).expect("append OK");
        }
        let buffered_len = sink.buffer.len();
        sink.flush().expect("flush OK");
        assert!(sink.buffer.is_empty(), "buffer cleared after flush");
        assert_eq!(
            sink.writer.len(),
            buffered_len,
            "writer received the buffered bytes verbatim"
        );
        assert!(sink.writer.starts_with(&STREAM_HEADER_MAGIC));
    }

    #[test]
    fn flushed_stream_matches_wire_layout() {
        let records = build_records(2);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        for rec in &records {
            sink.append_record(rec).expect("append OK");
        }
        sink.flush().expect("flush OK");
        // Header magic, then per record: u64 big-endian length + record bytes.
        let mut expected = STREAM_HEADER_MAGIC.to_vec();
        for rec in &records {
            let len = u64::try_from(rec.len()).unwrap();
            expected.extend_from_slice(&len.to_be_bytes());
            expected.extend_from_slice(rec);
        }
        assert_eq!(sink.into_writer_for_test(), expected);
    }

    #[test]
    fn empty_flush_is_idempotent_noop() {
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.flush().expect("empty flush OK");
        assert!(sink.writer.is_empty());
        assert!(!sink.header_emitted);
    }

    #[test]
    fn double_flush_is_idempotent() {
        let records = build_records(2);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.append_record(&records[0]).expect("append OK");
        sink.append_record(&records[1]).expect("append OK");
        sink.flush().expect("first flush OK");
        let writer_len_after_first = sink.writer.len();
        sink.flush().expect("second flush OK");
        assert_eq!(
            sink.writer.len(),
            writer_len_after_first,
            "second flush is a no-op"
        );
    }

    #[test]
    fn header_emitted_exactly_once() {
        let records = build_records(3);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        for rec in &records {
            sink.append_record(rec).expect("append OK");
        }
        let count = sink
            .buffer
            .windows(STREAM_HEADER_MAGIC.len())
            .filter(|w| *w == STREAM_HEADER_MAGIC)
            .count();
        assert_eq!(count, 1, "header magic appears exactly once");
    }

    #[test]
    fn seq_wraparound_at_u64_max_rejected_with_seq_exhausted() {
        let records = build_records(1);
        let mut sink = BufferedWalSink::new(Vec::<u8>::new());
        sink.last_seq = Some(u64::MAX);
        sink.header_emitted = true;
        let result = sink.append_record(&records[0]);
        match result {
            Err(WalExportError::SeqExhausted { last_seq }) => {
                assert_eq!(last_seq, u64::MAX);
            }
            other => panic!("expected SeqExhausted, got: {other:?}"),
        }
    }

    #[test]
    fn length_exceeds_max_rejected() {
        let oversized = vec![0u8; (MAX_RECORD_BYTES + 1) as usize];
        let result = BufferedWalSink::<Vec<u8>>::validate_length(&oversized);
        match result {
            Err(WalExportError::InvalidFraming(InvalidFramingReason::LengthExceedsMax {
                prefix,
                max,
            })) => {
                assert_eq!(prefix, MAX_RECORD_BYTES + 1);
                assert_eq!(max, MAX_RECORD_BYTES);
            }
            other => panic!("expected LengthExceedsMax, got: {other:?}"),
        }
    }

    #[test]
    fn extract_seq_round_trips_through_postcard() {
        let records = build_records(7);
        for (i, rec) in records.iter().enumerate() {
            let seq = BufferedWalSink::<Vec<u8>>::extract_seq(rec).expect("decode OK");
            assert_eq!(seq, (i as u64) + 1);
        }
    }
}