snap_buf/
lib.rs

1#![no_std]
2//! A [SnapBuf] is like a `Vec<u8>` with cheap snapshotting using copy on write.
3//!
4//! Internally, the data is broken up into segments that are organized in a tree structure.
//! Only modified subtrees are cloned, so buffers that differ only slightly can share most of their memory.
6//! Moreover, subtrees which contain only zeros take up no memory.
7
8extern crate alloc;
9#[cfg(feature = "test")]
10extern crate std;
11
12use alloc::sync::Arc;
13use core::cmp::Ordering;
14use core::ops::Range;
15use core::{iter, mem, slice};
16use smallvec::SmallVec;
17
18#[derive(Debug)]
19pub struct SnapBuf {
20    size: usize,
21    root_height: usize,
22    root: NodePointer,
23}
24
/// Number of bytes held by one leaf node (32 when the `test` feature is enabled).
// NOTE(review): the small test values presumably force multi-level trees with little data — confirm.
const LEAF_SIZE: usize = if cfg!(feature = "test") { 32 } else { 4000 };
/// Number of child pointers held by one inner node (4 when the `test` feature is enabled).
const INNER_SIZE: usize = if cfg!(feature = "test") { 4 } else { 500 };
27
28#[cfg(feature = "test")]
29pub mod test;
30
/// A materialized tree node.
#[derive(Clone, Debug)]
enum Node {
    /// Interior node: a fixed-size array of pointers to the next lower level.
    Inner([NodePointer; INNER_SIZE]),
    /// Leaf node: `LEAF_SIZE` bytes of buffer content.
    Leaf([u8; LEAF_SIZE]),
}
36
/// Optional shared pointer to a subtree.
///
/// `None` represents a subtree whose bytes are all zero; it is materialized on demand by
/// `get_mut`. Cloning the pointer only clones the `Arc`, not the node.
#[derive(Clone, Debug)]
struct NodePointer(Option<Arc<Node>>);
39
// Binds the two bounds of `$range` to the identifiers `$start` and `$end` in the caller's
// scope and, in debug builds, checks that the range overlaps a node of the given height
// (whose own span is `0..tree_size($height)`).
macro_rules! deconstruct_range {
    {$start:ident .. $end:ident = $range:expr, $height:expr} => {
        let ($start, $end) = ($range.start, $range.end);
        // The range has to begin before the node ends ...
        debug_assert!($start < tree_size($height) as isize);
        // ... and has to end after the node begins.
        debug_assert!($end > 0);
    }
}
49
50impl NodePointer {
51    fn children(&self) -> Option<&[NodePointer; INNER_SIZE]> {
52        match &**(self.0.as_ref()?) {
53            Node::Inner(x) => Some(x),
54            Node::Leaf(_) => None,
55        }
56    }
57
58    fn get_mut(&mut self, height: usize) -> &mut Node {
59        let arc = self.0.get_or_insert_with(|| {
60            Arc::new({
61                if height == 0 {
62                    Node::Leaf([0; LEAF_SIZE])
63                } else {
64                    Node::Inner(array_init::array_init(|_| NodePointer(None)))
65                }
66            })
67        });
68        Arc::make_mut(arc)
69    }
70
71    fn set_range(&mut self, height: usize, start: isize, values: &[u8]) {
72        deconstruct_range!(start..end = start .. start + values.len() as isize ,height);
73        match self.get_mut(height) {
74            Node::Inner(children) => {
75                for (child_offset, child) in
76                    Self::affected_children(children, height - 1, start..end)
77                {
78                    child.set_range(height - 1, start - child_offset, values);
79                }
80            }
81            Node::Leaf(bytes) => {
82                let (src, dst) = if start < 0 {
83                    (&values[-start as usize..], &mut bytes[..])
84                } else {
85                    (values, &mut bytes[start as usize..])
86                };
87                let len = src.len().min(dst.len());
88                dst[..len].copy_from_slice(&src[..len]);
89            }
90        }
91    }
92
93    fn affected_children(
94        children: &mut [NodePointer; INNER_SIZE],
95        child_height: usize,
96        range: Range<isize>,
97    ) -> impl Iterator<Item = (isize, &mut NodePointer)> {
98        let start = range.start.max(0) as usize;
99        let child_size = tree_size(child_height);
100        children
101            .iter_mut()
102            .enumerate()
103            .skip(start / child_size)
104            .map(move |(i, c)| ((i * child_size) as isize, c))
105            .take_while(move |(offset, _)| (*offset) < range.end)
106    }
107
108    fn fill_range(&mut self, height: usize, range: Range<isize>, value: u8) {
109        deconstruct_range!(start..end=range,height);
110        match self.get_mut(height) {
111            Node::Inner(children) => {
112                for (child_offset, child) in
113                    Self::affected_children(children, height - 1, range.clone())
114                {
115                    child.fill_range(height - 1, start - child_offset..end - child_offset, value);
116                }
117            }
118            Node::Leaf(bytes) => {
119                let write_start = start.max(0) as usize;
120                let write_end = (end as usize).min(bytes.len());
121                bytes[write_start..write_end].fill(value);
122            }
123        }
124    }
125
126    fn clear_range(&mut self, height: usize, range: Range<isize>) {
127        fn range_all<T, const C: usize>(x: &[T; C], mut f: impl FnMut(&T) -> bool) -> bool {
128            let last = f(x.last().unwrap());
129            last && x[0..C - 1].iter().all(f)
130        }
131
132        deconstruct_range!(start..end = range,height);
133        if start <= 0 && end as usize >= tree_size(height) || self.0.is_none() {
134            self.0 = None;
135            return;
136        }
137        match self.get_mut(height) {
138            Node::Inner(children) => {
139                for (child_offset, child) in
140                    Self::affected_children(children, height - 1, range.clone())
141                {
142                    child.clear_range(height - 1, start - child_offset..end - child_offset);
143                }
144                if range_all(children, |c| c.0.is_none()) {
145                    self.0 = None;
146                }
147            }
148            Node::Leaf(bytes) => {
149                let write_start = start.max(0) as usize;
150                let write_end = (end as usize).min(bytes.len());
151                bytes[write_start..write_end].fill(0);
152                if range_all(bytes, |b| *b == 0) {
153                    self.0 = None;
154                }
155            }
156        }
157    }
158
159    fn put_leaf(&mut self, height: usize, offset: usize, leaf: NodePointer) {
160        match self.get_mut(height) {
161            Node::Inner(children) => {
162                let range = offset as isize..offset as isize + 1;
163                let (co, c) = Self::affected_children(children, height - 1, range)
164                    .next()
165                    .unwrap();
166                c.put_leaf(height - 1, offset - co as usize, leaf);
167            }
168            Node::Leaf(_) => {
169                debug_assert_eq!(offset, 0);
170                *self = leaf;
171            }
172        }
173    }
174
175    fn locate_leaf(
176        &mut self,
177        height: usize,
178        offset: usize,
179    ) -> Option<(usize, &mut [u8; LEAF_SIZE])> {
180        self.0.as_ref()?;
181        match self.get_mut(height) {
182            Node::Inner(children) => {
183                let range = offset as isize..offset as isize + 1;
184                let (co, c) = Self::affected_children(children, height - 1, range)
185                    .next()
186                    .unwrap();
187                c.locate_leaf(height - 1, offset - co as usize)
188            }
189            Node::Leaf(x) => Some((offset, x)),
190        }
191    }
192}
193
194const fn const_tree_size(height: usize) -> usize {
195    if height == 0 {
196        LEAF_SIZE
197    } else {
198        INNER_SIZE * const_tree_size(height - 1)
199    }
200}
201
/// Number of bytes covered by a subtree of the given height.
///
/// Non-`const` wrapper that simply forwards to [`const_tree_size`].
fn tree_size(height: usize) -> usize {
    const_tree_size(height)
}
205
impl Default for SnapBuf {
    /// Creates an empty buffer; identical to [`SnapBuf::new`].
    fn default() -> Self {
        Self::new()
    }
}
211
212impl SnapBuf {
213    /// Creates an empty buffer.
214    pub fn new() -> Self {
215        Self {
216            root_height: 0,
217            size: 0,
218            root: NodePointer(None),
219        }
220    }
221
222    fn shrink(&mut self, new_len: usize) {
223        self.root.clear_range(
224            self.root_height,
225            new_len as isize..tree_size(self.root_height) as isize,
226        );
227        self.size = new_len;
228    }
229
230    fn grow_height_until(&mut self, min_size: usize) {
231        while tree_size(self.root_height) < min_size {
232            if self.root.0.is_some() {
233                let new_root = Arc::new(Node::Inner(array_init::array_init(|x| {
234                    if x == 0 {
235                        self.root.clone()
236                    } else {
237                        NodePointer(None)
238                    }
239                })));
240                self.root = NodePointer(Some(new_root.clone()));
241            }
242            self.root_height += 1;
243        }
244    }
245
    /// Sets the length to `new_len`, first making the tree tall enough to cover it.
    ///
    /// No bytes are written here.
    // NOTE(review): the new space reads as zero only if everything beyond the old length is
    // already zero/absent; `shrink` clears that tail — confirm no other path leaves data there.
    fn grow_zero(&mut self, new_len: usize) {
        self.grow_height_until(new_len);
        self.size = new_len;
    }
250
251    /// Resizes the buffer, to the given length.
252    ///
253    /// If `new_len` is greater than `len`, the new space in the buffer is filled with copies of value.
254    /// This is more efficient if `value == 0`.
255    #[inline]
256    pub fn resize(&mut self, new_len: usize, value: u8) {
257        match new_len.cmp(&self.size) {
258            Ordering::Less => {
259                self.shrink(new_len);
260            }
261            Ordering::Equal => {}
262            Ordering::Greater => {
263                let old_len = self.size;
264                self.grow_zero(new_len);
265                if value != 0 {
266                    self.fill_range(old_len..new_len, value);
267                }
268            }
269        }
270    }
271
272    /// Shortens the buffer, keeping the first `new_len` bytes and discarding the rest.
273    ///
274    /// If `new_len` is greater or equal to the buffer’s current length, this has no effect.
275    pub fn truncate(&mut self, new_len: usize) {
276        if new_len > self.size {
277            self.shrink(new_len);
278        }
279    }
280
281    /// Fill the given range with copies of value.
282    ///
283    /// This is equivalent to calling [write](Self::write) with a slice filled with value.
284    /// Calling this with `value = 0` is not guaranteed to free up the zeroed segments,
285    /// use [clear_range](Self::clear_range) if that is required.
286    pub fn fill_range(&mut self, range: Range<usize>, value: u8) {
287        if self.size < range.end {
288            self.grow_zero(range.end);
289        }
290        if range.is_empty() {
291            return;
292        }
293        let range = range.start as isize..range.end as isize;
294        self.root.fill_range(self.root_height, range, value);
295    }
296
297    /// Writes data at the specified offset.
298    ///
299    /// If this extends past the current end of the buffer, the buffer is automatically resized.
300    /// If offset is larger than the current buffer length, the space between the current buffer
301    /// end and the written region is filled with zeros.
302    ///
303    /// Zeroing parts of the buffer using this method is not guaranteed to free up the zeroed segments,
304    /// use [clear_range](Self::clear_range) if that is required.
305    pub fn write(&mut self, offset: usize, data: &[u8]) {
306        let write_end = offset + data.len();
307        if self.size < write_end {
308            self.resize(write_end, 0);
309        }
310        if data.is_empty() {
311            return;
312        }
313        self.root.set_range(self.root_height, offset as isize, data);
314    }
315
316    /// Returns `true` if the buffer length is zero.
317    pub fn is_empty(&self) -> bool {
318        self.len() == 0
319    }
320
    /// Returns the length of the buffer, i.e. the number of bytes it contains.
    ///
    /// The memory footprint of the buffer may be much smaller than this due to omission of
    /// zero segments and sharing with other buffers.
    pub fn len(&self) -> usize {
        self.size
    }
327
328    /// Fill a range with zeros, freeing memory if possible.
329    ///
330    /// # Panics
331    /// Panics if range end is `range.end` > `self.len()`.
332    pub fn clear_range(&mut self, range: Range<usize>) {
333        assert!(range.end <= self.size);
334        if range.is_empty() {
335            return;
336        }
337        self.root
338            .clear_range(self.root_height, range.start as isize..range.end as isize);
339    }
340
341    /// Clears all data and sets length to 0.
342    pub fn clear(&mut self) {
343        *self = Self::new();
344    }
345
346    fn iter_nodes_pre_order(&self) -> impl Iterator<Item = (&NodePointer, usize)> {
347        struct IterStack<'a> {
348            stack_end_height: usize,
349            stack: SmallVec<[&'a [NodePointer]; 5]>,
350        }
351
352        #[allow(clippy::needless_lifetimes)]
353        fn split_first_in_place<'x, 's, T>(x: &'x mut &'s [T]) -> &'s T {
354            let (first, rest) = mem::take(x).split_first().unwrap();
355            *x = rest;
356            first
357        }
358
359        impl<'a> Iterator for IterStack<'a> {
360            type Item = (&'a NodePointer, usize);
361
362            fn next(&mut self) -> Option<Self::Item> {
363                let visit_now = loop {
364                    let last_level = self.stack.last_mut()?;
365                    if last_level.is_empty() {
366                        self.stack.pop();
367                        self.stack_end_height += 1;
368                    } else {
369                        break split_first_in_place(last_level);
370                    }
371                };
372                let ret = (visit_now, self.stack_end_height);
373                if let Some(children) = visit_now.children() {
374                    self.stack.push(children);
375                    self.stack_end_height -= 1;
376                }
377                Some(ret)
378            }
379        }
380
381        let mut stack = SmallVec::new();
382        stack.push(slice::from_ref(&self.root));
383        IterStack {
384            stack_end_height: self.root_height,
385            stack,
386        }
387    }
388
389    /// Returns an iterator over the byte slices constituting the buffer.
390    ///
391    /// The returned slices may overlap.
392    pub fn chunks(&self) -> impl Iterator<Item = &[u8]> {
393        let mut emitted = 0;
394        self.iter_nodes_pre_order()
395            .flat_map(|(node, height)| {
396                let zero_leaf = &[0u8; LEAF_SIZE];
397                match node.0.as_deref() {
398                    None => {
399                        let leaf_count = INNER_SIZE.pow(height as u32);
400                        iter::repeat_n(zero_leaf, leaf_count)
401                    }
402                    Some(Node::Inner(_)) => iter::repeat_n(zero_leaf, 0),
403                    Some(Node::Leaf(b)) => iter::repeat_n(b, 1),
404                }
405            })
406            .map(move |x| {
407                let emit = (self.size - emitted).min(x.len());
408                emitted += emit;
409                &x[..emit]
410            })
411            .filter(|x| !x.is_empty())
412    }
413
414    /// Returns an iterator over the buffer.
415    pub fn iter(&self) -> impl Iterator<Item = &u8> {
416        self.chunks().flat_map(|x| x.iter())
417    }
418
419    #[doc(hidden)]
420    pub fn bytes(&self) -> impl Iterator<Item = u8> + '_ {
421        self.iter().copied()
422    }
423
424    pub fn extend_from_slice(&mut self, data: &[u8]) {
425        self.write(self.size, data)
426    }
427}
428
429impl Extend<u8> for SnapBuf {
430    fn extend<T: IntoIterator<Item = u8>>(&mut self, iter: T) {
431        fn generate_leaf(
432            start_at: usize,
433            iter: &mut impl Iterator<Item = u8>,
434        ) -> (usize, NodePointer) {
435            let mut consumed = start_at;
436            let first_non_zero = loop {
437                if let Some(x) = iter.next() {
438                    consumed += 1;
439                    if x != 0 {
440                        break x;
441                    }
442                } else {
443                    return (consumed, NodePointer(None));
444                }
445                if consumed == LEAF_SIZE {
446                    return (LEAF_SIZE, NodePointer(None));
447                }
448            };
449            let mut leaf = Arc::new(Node::Leaf([0u8; LEAF_SIZE]));
450            let leaf_mut = if let Node::Leaf(x) = Arc::get_mut(&mut leaf).unwrap() {
451                x
452            } else {
453                unreachable!()
454            };
455            leaf_mut[consumed - 1] = first_non_zero;
456            while consumed < LEAF_SIZE {
457                if let Some(x) = iter.next() {
458                    leaf_mut[consumed] = x;
459                    consumed += 1;
460                } else {
461                    break;
462                }
463            }
464            (consumed, NodePointer(Some(leaf)))
465        }
466
467        let it = &mut iter.into_iter();
468        if self.size < tree_size(self.root_height) {
469            if let Some((offset, first_leaf)) = self.root.locate_leaf(self.root_height, self.size) {
470                for i in offset..LEAF_SIZE {
471                    let Some(x) = it.next() else { return };
472                    first_leaf[i] = x;
473                    self.size += 1;
474                }
475                assert_eq!(self.size % LEAF_SIZE, 0);
476            }
477        } else {
478            assert_eq!(self.size % LEAF_SIZE, 0);
479        }
480        loop {
481            let in_leaf_offset = self.size % LEAF_SIZE;
482            let (consumed, leaf) = generate_leaf(in_leaf_offset, it);
483            let old_size = self.size;
484            self.size = old_size - in_leaf_offset + consumed;
485            self.grow_height_until(self.size);
486            if leaf.0.is_some() {
487                self.root
488                    .put_leaf(self.root_height, old_size - in_leaf_offset, leaf);
489            }
490            if consumed < LEAF_SIZE {
491                return;
492            }
493            assert_eq!(self.size % LEAF_SIZE, 0);
494        }
495    }
496}
497
498impl FromIterator<u8> for SnapBuf {
499    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
500        let mut iter = iter.into_iter();
501        let mut ret = Self::new();
502        ret.extend(&mut iter);
503        ret
504    }
505}