use crate::core::{
PageCount, PageIdx,
commit::SegmentFrameIdx,
page::{PAGESIZE, Page},
};
use bytes::{Bytes, BytesMut};
use smallvec::SmallVec;
use zstd::zstd_safe::{CCtx, CParameter, DCtx, InBuffer, OutBuffer, zstd_sys::ZSTD_EndDirective};
/// zstd compression level applied to every segment frame.
const ZSTD_COMPRESSION_LEVEL: i32 = 3;
/// Maximum number of pages packed into a single zstd frame; `SegmentBuilder` closes the
/// open frame as soon as it holds this many pages.
const FRAME_MAX_PAGES: PageCount = PageCount::new(64);
/// Incrementally compresses pages into a sequence of zstd frames.
///
/// Pages are fed through [`SegmentBuilder::write`] in strictly increasing index order and
/// grouped into frames of at most `FRAME_MAX_PAGES` pages. Compressed output accumulates
/// in byte chunks; because a chunk can be sealed mid-frame, a frame may straddle a chunk
/// boundary.
pub struct SegmentBuilder {
    // Compression state.
    /// zstd compression context, reset (session only) after each frame is finished.
    cctx: CCtx<'static>,
    /// Index of the most recently written page, used to enforce write ordering.
    last_pageidx: Option<PageIdx>,

    // Bookkeeping for the frame currently being built.
    /// Number of pages compressed into the open frame.
    current_frame_pages: PageCount,
    /// Compressed bytes emitted so far for the open frame.
    current_frame_bytes: usize,
    /// Index entries for every completed frame, in order.
    frames: SmallVec<[SegmentFrameIdx; 1]>,

    // Output.
    /// Buffer currently receiving compressor output.
    chunk: Vec<u8>,
    /// Buffers that filled up and were sealed, in output order.
    chunks: SmallVec<[Bytes; 1]>,
}
impl Default for SegmentBuilder {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl SegmentBuilder {
pub fn new() -> Self {
let mut cctx = CCtx::create();
cctx.set_parameter(CParameter::ContentSizeFlag(false))
.expect("BUG: failed to set content size flag");
cctx.set_parameter(CParameter::ChecksumFlag(true))
.expect("BUG: failed to set checksum flag");
cctx.set_parameter(CParameter::CompressionLevel(ZSTD_COMPRESSION_LEVEL))
.expect("BUG: failed to set compression level");
Self {
frames: Default::default(),
chunks: Default::default(),
cctx,
last_pageidx: None,
current_frame_pages: PageCount::ZERO,
current_frame_bytes: 0,
chunk: Vec::with_capacity(CCtx::out_size()),
}
}
fn flush_chunk(&mut self) {
let chunk = std::mem::replace(&mut self.chunk, Vec::with_capacity(CCtx::out_size()));
self.chunks.push(chunk.into());
}
pub fn write(&mut self, pageidx: PageIdx, page: &Page) {
if let Some(last_pageidx) = self.last_pageidx.replace(pageidx) {
assert!(pageidx > last_pageidx, "Pages must be pushed in order")
}
let mut in_buf = InBuffer::around(page.as_ref());
while in_buf.pos() < PAGESIZE {
let start_pos = self.chunk.len();
let mut out_buf = OutBuffer::around_pos(&mut self.chunk, start_pos);
let pending_flush = self
.cctx
.compress_stream2(
&mut out_buf,
&mut in_buf,
ZSTD_EndDirective::ZSTD_e_continue,
)
.expect("BUG: failed to compress frame");
self.current_frame_bytes += out_buf.pos() - start_pos;
if pending_flush > 0 && out_buf.pos() == out_buf.capacity() {
self.flush_chunk();
}
}
self.current_frame_pages = self.current_frame_pages.saturating_incr();
if self.current_frame_pages >= FRAME_MAX_PAGES {
self.end_frame();
}
}
fn end_frame(&mut self) {
let mut in_buf = InBuffer::around(&[]);
loop {
let start_pos = self.chunk.len();
let mut out_buf = OutBuffer::around_pos(&mut self.chunk, start_pos);
let pending_flush = self
.cctx
.compress_stream2(&mut out_buf, &mut in_buf, ZSTD_EndDirective::ZSTD_e_end)
.expect("BUG: failed to compress frame");
self.current_frame_bytes += out_buf.pos() - start_pos;
if pending_flush > 0 && out_buf.pos() == out_buf.capacity() {
self.flush_chunk();
} else if pending_flush == 0 {
break;
}
}
self.frames.push(SegmentFrameIdx::new(
self.current_frame_bytes,
self.last_pageidx.expect("BUG: flushing empty frame"),
));
self.current_frame_bytes = 0;
self.current_frame_pages = PageCount::ZERO;
self.cctx
.reset(zstd::zstd_safe::ResetDirective::SessionOnly)
.expect("BUG: failed to reset context");
}
pub fn finish(mut self) -> (SmallVec<[SegmentFrameIdx; 1]>, SmallVec<[Bytes; 1]>) {
if self.current_frame_pages > 0 {
self.end_frame();
}
let Self { mut chunks, chunk, frames, .. } = self;
if !chunk.is_empty() {
chunks.push(chunk.into());
}
(frames, chunks)
}
}
/// Lazily decompresses a segment frame into pages.
///
/// `frame` holds compressed bytes as produced by [`SegmentBuilder`]. Each call to `next`
/// takes the next index from `pages`, decodes `PAGESIZE` bytes from `frame`, and yields
/// them paired with that index. Iteration ends when `pages` is exhausted. The frame bytes
/// carry no page indices, so the caller must supply them in order.
///
/// # Panics
///
/// Panics if zstd reports a decompression error, if the frame ends before a page has been
/// filled, or if the decoded bytes are rejected by `Page::try_from`.
pub fn segment_frame_iter<'a>(
    mut pages: impl Iterator<Item = PageIdx> + 'a,
    frame: &'a [u8],
) -> impl Iterator<Item = (PageIdx, Page)> + 'a {
    let mut dctx = DCtx::create();
    let mut input = InBuffer::around(frame);
    std::iter::from_fn(move || {
        let pageidx = pages.next()?;
        let mut buf = BytesMut::zeroed(PAGESIZE.as_usize());
        let mut output = OutBuffer::around(buf.as_mut());
        // Keep feeding the decoder until the page buffer is completely filled.
        while output.pos() < output.capacity() {
            let hint = dctx
                .decompress_stream(&mut output, &mut input)
                .expect("BUG: failed to decompress segment frame");
            // zstd returns 0 once the frame is fully decoded; that is only acceptable
            // if it coincides with the page becoming full.
            let page_full = output.pos() == output.capacity();
            assert!(
                hint > 0 || page_full,
                "BUG: reached end of frame before filling page"
            );
        }
        let page = Page::try_from(buf).expect("BUG: invalid page");
        Some((pageidx, page))
    })
}
#[cfg(test)]
mod test {
    use crate::pageidx;

    use super::*;

    #[test]
    fn test_empty_segment() {
        let segment = SegmentBuilder::new();
        let (frames, chunks) = segment.finish();
        assert_eq!(frames.len(), 0);
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn test_segment() {
        let mut segment = SegmentBuilder::new();
        for i in 1..=96 {
            segment.write(PageIdx::must_new(i), &Page::test_filled(i as u8));
        }
        let (frames, chunks) = segment.finish();
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].last_pageidx(), pageidx!(64));
        assert_eq!(frames[1].last_pageidx(), pageidx!(96));
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].len(),
            frames[0].frame_size() + frames[1].frame_size()
        );
    }

    /// Compresses pages with [`SegmentBuilder`] and decodes every frame back with
    /// [`segment_frame_iter`], checking that each page survives the roundtrip intact.
    #[test]
    fn test_segment_roundtrip() {
        let mut segment = SegmentBuilder::new();
        for i in 1..=96 {
            segment.write(PageIdx::must_new(i), &Page::test_filled(i as u8));
        }
        let (frames, chunks) = segment.finish();
        assert_eq!(frames.len(), 2);
        assert_eq!(chunks.len(), 1);

        // Both frames live in the single chunk, laid out back to back.
        let (first, second) = chunks[0].split_at(frames[0].frame_size());
        for (frame, range) in [(first, 1..=64), (second, 65..=96)] {
            let decoded: Vec<_> =
                segment_frame_iter(range.clone().map(PageIdx::must_new), frame).collect();
            assert_eq!(decoded.len(), range.clone().count());
            for (i, (pageidx, page)) in range.zip(decoded) {
                assert_eq!(pageidx, PageIdx::must_new(i));
                let expected = Page::test_filled(i as u8);
                let got: &[u8] = page.as_ref();
                let want: &[u8] = expected.as_ref();
                assert_eq!(got, want);
            }
        }
    }
}