graft_kernel/remote/
segment.rs

/*
 * Segments are sequences of compressed ZStd frames. All pages in a Segment are
 * stored in order by `PageIdx`.
 */
5
6use bytes::{Bytes, BytesMut};
7use graft_core::{
8    PageCount, PageIdx,
9    commit::SegmentFrameIdx,
10    page::{PAGESIZE, Page},
11};
12use smallvec::SmallVec;
13use zstd::zstd_safe::{CCtx, CParameter, DCtx, InBuffer, OutBuffer, zstd_sys::ZSTD_EndDirective};
14
/// The maximum number of pages per compressed ZStd frame.
/// At 4k per page, 64 pages caps each frame at 256k of uncompressed data.
const FRAME_MAX_PAGES: PageCount = PageCount::new(64);

/// The ZSTD compression level applied to every frame (3 is zstd's default).
const ZSTD_COMPRESSION_LEVEL: i32 = 3;
21
pub struct SegmentBuilder {
    /// index of compressed frames: one entry per finished frame, recording
    /// its compressed size and the last `PageIdx` it contains
    frames: SmallVec<[SegmentFrameIdx; 1]>,

    /// completed chunks of the resulting segment. each chunk represents a
    /// portion of the compressed stream of frames; a single frame's bytes may
    /// span multiple chunks
    chunks: SmallVec<[Bytes; 1]>,

    /// the compression context, reused (and session-reset) across frames
    cctx: CCtx<'static>,

    /// the last pageidx; used to ensure pages are pushed in order and to build
    /// the frame index
    last_pageidx: Option<PageIdx>,

    /// the number of pages written to the current (unfinished) frame
    current_frame_pages: PageCount,

    /// the compressed size in bytes of the current (unfinished) frame
    current_frame_bytes: usize,

    /// the active chunk the compressor is currently writing into
    chunk: Vec<u8>,
}
46
47impl Default for SegmentBuilder {
48    #[inline]
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
impl SegmentBuilder {
    /// Create a new, empty `SegmentBuilder` with a configured ZStd
    /// compression context.
    pub fn new() -> Self {
        let mut cctx = CCtx::create();
        // frame sizes are not known up front, so omit the content-size header
        cctx.set_parameter(CParameter::ContentSizeFlag(false))
            .expect("BUG: failed to set content size flag");
        // append a checksum to each frame so corruption is detectable when
        // decompressing
        cctx.set_parameter(CParameter::ChecksumFlag(true))
            .expect("BUG: failed to set checksum flag");
        cctx.set_parameter(CParameter::CompressionLevel(ZSTD_COMPRESSION_LEVEL))
            .expect("BUG: failed to set compression level");
        Self {
            frames: Default::default(),
            chunks: Default::default(),
            cctx,
            last_pageidx: None,
            current_frame_pages: PageCount::ZERO,
            current_frame_bytes: 0,
            // CCtx::out_size() is ZStd's recommended output buffer size
            chunk: Vec::with_capacity(CCtx::out_size()),
        }
    }

    /// Move the active chunk into `chunks` and start a fresh one with the
    /// recommended capacity.
    fn flush_chunk(&mut self) {
        let chunk = std::mem::replace(&mut self.chunk, Vec::with_capacity(CCtx::out_size()));
        self.chunks.push(chunk.into());
    }

    /// Compress `page` into the current frame.
    ///
    /// Pages must be written in strictly increasing `PageIdx` order; this is
    /// asserted. Once the current frame reaches `FRAME_MAX_PAGES` pages it is
    /// finalized via `end_frame`.
    pub fn write(&mut self, pageidx: PageIdx, page: &Page) {
        if let Some(last_pageidx) = self.last_pageidx.replace(pageidx) {
            assert!(pageidx > last_pageidx, "Pages must be pushed in order")
        }

        let mut in_buf = InBuffer::around(page.as_ref());

        // feed the page into the compressor until every input byte has been
        // consumed
        while in_buf.pos() < PAGESIZE {
            let start_pos = self.chunk.len();
            let mut out_buf = OutBuffer::around_pos(&mut self.chunk, start_pos);

            // pending_flush is ZStd's hint of how many bytes remain buffered
            // inside the compressor
            let pending_flush = self
                .cctx
                .compress_stream2(
                    &mut out_buf,
                    &mut in_buf,
                    ZSTD_EndDirective::ZSTD_e_continue,
                )
                .expect("BUG: failed to compress frame");

            // account for the bytes the compressor just emitted
            self.current_frame_bytes += out_buf.pos() - start_pos;

            if pending_flush > 0 && out_buf.pos() == out_buf.capacity() {
                // output buffer is full, swap chunks
                self.flush_chunk();
            }
        }

        self.current_frame_pages = self.current_frame_pages.saturating_incr();

        if self.current_frame_pages >= FRAME_MAX_PAGES {
            self.end_frame();
        }
    }

    /// Finalize the current ZStd frame: drain the compressor, record the
    /// frame in the index, and reset the per-frame state.
    fn end_frame(&mut self) {
        // no further input; repeatedly ask the compressor to end the frame
        // until it reports nothing left to flush
        let mut in_buf = InBuffer::around(&[]);
        loop {
            let start_pos = self.chunk.len();
            let mut out_buf = OutBuffer::around_pos(&mut self.chunk, start_pos);

            let pending_flush = self
                .cctx
                .compress_stream2(&mut out_buf, &mut in_buf, ZSTD_EndDirective::ZSTD_e_end)
                .expect("BUG: failed to compress frame");

            self.current_frame_bytes += out_buf.pos() - start_pos;

            if pending_flush > 0 && out_buf.pos() == out_buf.capacity() {
                // output buffer is full, swap chunks
                self.flush_chunk();
            } else if pending_flush == 0 {
                // the frame epilogue has been fully written
                break;
            }
        }

        // record the frame's compressed size and final pageidx in the index
        self.frames.push(SegmentFrameIdx::new(
            self.current_frame_bytes,
            self.last_pageidx.expect("BUG: flushing empty frame"),
        ));

        // reset current frame vars; the session reset keeps the configured
        // parameters but starts a fresh frame
        self.current_frame_bytes = 0;
        self.current_frame_pages = PageCount::ZERO;
        self.cctx
            .reset(zstd::zstd_safe::ResetDirective::SessionOnly)
            .expect("BUG: failed to reset context");
    }

    /// Consume the builder and return the frame index along with the chunks
    /// of the compressed segment. Concatenating the chunks in order yields
    /// the complete segment.
    pub fn finish(mut self) -> (SmallVec<[SegmentFrameIdx; 1]>, SmallVec<[Bytes; 1]>) {
        // flush the last frame if needed
        if self.current_frame_pages > 0 {
            self.end_frame();
        }

        let Self { mut chunks, chunk, frames, .. } = self;

        // flush the last chunk if it's non-empty
        if !chunk.is_empty() {
            chunks.push(chunk.into());
        }

        (frames, chunks)
    }
}
165
166pub fn segment_frame_iter<'a>(
167    mut pages: impl Iterator<Item = PageIdx> + 'a,
168    frame: &'a [u8],
169) -> impl Iterator<Item = (PageIdx, Page)> + 'a {
170    let mut dctx = DCtx::create();
171    let mut in_buf = InBuffer::around(frame);
172
173    std::iter::from_fn(move || {
174        if let Some(pageidx) = pages.next() {
175            let mut page = BytesMut::zeroed(PAGESIZE.as_usize());
176            let mut out_buf = OutBuffer::around(page.as_mut());
177
178            while out_buf.pos() < out_buf.capacity() {
179                let n = dctx
180                    .decompress_stream(&mut out_buf, &mut in_buf)
181                    .expect("BUG: failed to decompress segment frame");
182                assert!(
183                    n > 0 || out_buf.pos() == out_buf.capacity(),
184                    "BUG: reached end of frame before filling page"
185                );
186            }
187
188            Some((pageidx, Page::try_from(page).expect("BUG: invalid page")))
189        } else {
190            None
191        }
192    })
193}
194
#[cfg(test)]
mod test {
    use graft_core::pageidx;

    use super::*;

    /// A builder that never receives a page must produce no frames and no
    /// chunks.
    #[test]
    fn test_empty_segment() {
        let (frames, chunks) = SegmentBuilder::new().finish();
        assert!(frames.is_empty());
        assert!(chunks.is_empty());
    }

    /// Writing 1.5 frames worth of pages should produce one full frame, one
    /// partial frame, and a single chunk covering both.
    #[test]
    fn test_segment() {
        let mut builder = SegmentBuilder::new();

        // push one full frame (64 pages) plus half of a second one
        for idx in 1..=96 {
            builder.write(PageIdx::must_new(idx), &Page::test_filled(idx as u8));
        }

        let (frames, chunks) = builder.finish();

        // the frame index records the last pageidx of each frame
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].last_pageidx(), pageidx!(64));
        assert_eq!(frames[1].last_pageidx(), pageidx!(96));

        // both frames fit in a single chunk whose length equals the sum of
        // the recorded frame sizes
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].len(),
            frames[0].frame_size() + frames[1].frame_size()
        );
    }
}
231}