// snap_buf/lib.rs

1#![no_std]
2//! A [SnapBuf] is like a `Vec<u8>` with cheap snapshotting using copy on write.
3//!
4//! Internally, the data is broken up into segments that are organized in a tree structure.
5//! Only modified subtrees are cloned, so buffers with only little differences can share most of their memory.
6//! Moreover, subtrees which contain only zeros take up no memory.
7
8extern crate alloc;
9#[cfg(feature = "test")]
10extern crate std;
11
12use alloc::sync::Arc;
13use core::cmp::Ordering;
14use core::ops::Range;
15use core::{iter, mem, slice};
16use smallvec::SmallVec;
17
/// A growable byte buffer with cheap snapshotting via copy-on-write sharing.
#[derive(Debug)]
pub struct SnapBuf {
    // Logical length in bytes. May be far larger than allocated memory,
    // because all-zero subtrees are not materialized.
    size: usize,
    // Height of `root`: 0 means the root is a single leaf; at height h the
    // tree addresses tree_size(h) bytes.
    root_height: usize,
    // Root of the copy-on-write tree; `NodePointer(None)` reads as all zeros.
    root: NodePointer,
}
24
// Segment sizes. The "test" feature shrinks both so multi-level trees and
// partial leaves are exercised with tiny inputs; otherwise leaves are
// roughly page-sized.
const LEAF_SIZE: usize = if cfg!(feature = "test") { 32 } else { 4000 };
const INNER_SIZE: usize = if cfg!(feature = "test") { 4 } else { 500 };
27
// Test utilities, compiled (and exported) only with the "test" feature.
#[cfg(feature = "test")]
pub mod test;
30
/// A tree node: either an inner node fanning out to `INNER_SIZE` subtrees,
/// or a leaf holding `LEAF_SIZE` raw bytes.
#[derive(Clone, Debug)]
enum Node {
    Inner([NodePointer; INNER_SIZE]),
    Leaf([u8; LEAF_SIZE]),
}
36
/// Shared handle to a subtree. `None` encodes a subtree consisting entirely
/// of zero bytes, which therefore occupies no memory; cloning is a cheap
/// `Arc` refcount bump.
#[derive(Clone, Debug)]
struct NodePointer(Option<Arc<Node>>);
39
// Binds `$start`/`$end` from a half-open `isize` range expressed relative to
// the current subtree, and debug-checks that the range actually overlaps the
// subtree (callers may pass ranges extending past either edge). Relies on
// `tree_size` being in scope at the expansion site.
macro_rules! deconstruct_range{
    {$start:ident .. $end:ident = $range:expr,$height:expr} => {
         let $start = $range.start;
        let $end = $range.end;
         // assert range overlaps
        debug_assert!($start < tree_size($height) as isize);
        debug_assert!($end > 0);
    }
}
49
/// Returns `true` if `f` holds for every element of `x`.
///
/// The final element is probed first; if it fails, the leading elements are
/// never visited. Panics when `C == 0` (all call sites use non-empty arrays).
fn range_all<T, const C: usize>(x: &[T; C], mut f: impl FnMut(&T) -> bool) -> bool {
    if !f(x.last().unwrap()) {
        return false;
    }
    x[..C - 1].iter().all(f)
}
54
impl NodePointer {
    /// Returns the child array if this points at an inner node; `None` for
    /// leaves and for the empty (all-zero) pointer.
    fn children(&self) -> Option<&[NodePointer; INNER_SIZE]> {
        match &**(self.0.as_ref()?) {
            Node::Inner(x) => Some(x),
            Node::Leaf(_) => None,
        }
    }

    /// Copy-on-write access to the node. An absent node is materialized as
    /// all zeros (a leaf at height 0, an inner node otherwise); a shared
    /// node is cloned by `Arc::make_mut` before the mutable borrow is
    /// handed out.
    fn get_mut(&mut self, height: usize) -> &mut Node {
        let arc = self.0.get_or_insert_with(|| {
            Arc::new({
                if height == 0 {
                    Node::Leaf([0; LEAF_SIZE])
                } else {
                    Node::Inner(array_init::array_init(|_| NodePointer(None)))
                }
            })
        });
        Arc::make_mut(arc)
    }

    /// Copies `values` into this subtree starting at byte `start` (relative
    /// to the subtree; may be negative or run past the subtree's end — only
    /// the overlapping part is written). With `FREE_ZEROS`, nodes that end
    /// up all-zero are released back to the `None` representation.
    fn set_range<const FREE_ZEROS: bool>(&mut self, height: usize, start: isize, values: &[u8]) {
        deconstruct_range!(start..end = start .. start + values.len() as isize ,height);
        match self.get_mut(height) {
            Node::Inner(children) => {
                // Recurse into every child the write overlaps, shifting
                // `start` into each child's coordinate system. `values` is
                // passed whole; leaves slice it via the (possibly negative)
                // start offset.
                for (child_offset, child) in
                    Self::affected_children(children, height - 1, start..end)
                {
                    child.set_range::<FREE_ZEROS>(height - 1, start - child_offset, values);
                }
                if FREE_ZEROS && range_all(children, |c| c.0.is_none()) {
                    self.0 = None;
                }
            }
            Node::Leaf(bytes) => {
                // Negative start: the write began in an earlier leaf, so
                // skip the already-written prefix of `values`. Otherwise
                // skip into this leaf.
                let (src, dst) = if start < 0 {
                    (&values[-start as usize..], &mut bytes[..])
                } else {
                    (values, &mut bytes[start as usize..])
                };
                let len = src.len().min(dst.len());
                dst[..len].copy_from_slice(&src[..len]);
                if FREE_ZEROS && range_all(bytes, |b| *b == 0) {
                    self.0 = None;
                }
            }
        }
    }

    /// Yields `(byte_offset_within_parent, child)` for every child that
    /// overlaps `range` (child offsets are multiples of the child subtree
    /// size). The left edge is clamped at 0; iteration stops at the first
    /// child starting at or beyond `range.end`.
    fn affected_children(
        children: &mut [NodePointer; INNER_SIZE],
        child_height: usize,
        range: Range<isize>,
    ) -> impl Iterator<Item = (isize, &mut NodePointer)> {
        let start = range.start.max(0) as usize;
        let child_size = tree_size(child_height);
        children
            .iter_mut()
            .enumerate()
            .skip(start / child_size)
            .map(move |(i, c)| ((i * child_size) as isize, c))
            .take_while(move |(offset, _)| (*offset) < range.end)
    }

    /// Fills the overlap of `range` with `value`. Unlike `clear_range`, this
    /// never releases nodes, even when `value == 0`.
    fn fill_range(&mut self, height: usize, range: Range<isize>, value: u8) {
        deconstruct_range!(start..end=range,height);
        match self.get_mut(height) {
            Node::Inner(children) => {
                for (child_offset, child) in
                    Self::affected_children(children, height - 1, range.clone())
                {
                    child.fill_range(height - 1, start - child_offset..end - child_offset, value);
                }
            }
            Node::Leaf(bytes) => {
                // `end > 0` is guaranteed by the macro's debug assertion.
                let write_start = start.max(0) as usize;
                let write_end = (end as usize).min(bytes.len());
                bytes[write_start..write_end].fill(value);
            }
        }
    }

    /// Zeroes the overlap of `range`, dropping any subtree that becomes (or
    /// already is) entirely zero so its memory is freed.
    fn clear_range(&mut self, height: usize, range: Range<isize>) {
        deconstruct_range!(start..end = range,height);
        // Fast path: the whole subtree is cleared, or there is nothing here.
        if start <= 0 && end as usize >= tree_size(height) || self.0.is_none() {
            self.0 = None;
            return;
        }
        match self.get_mut(height) {
            Node::Inner(children) => {
                for (child_offset, child) in
                    Self::affected_children(children, height - 1, range.clone())
                {
                    child.clear_range(height - 1, start - child_offset..end - child_offset);
                }
                if range_all(children, |c| c.0.is_none()) {
                    self.0 = None;
                }
            }
            Node::Leaf(bytes) => {
                let write_start = start.max(0) as usize;
                let write_end = (end as usize).min(bytes.len());
                bytes[write_start..write_end].fill(0);
                if range_all(bytes, |b| *b == 0) {
                    self.0 = None;
                }
            }
        }
    }

    /// Replaces the leaf covering byte `offset` with `leaf`. `offset` is
    /// expected to be leaf-aligned; this is debug-checked on arrival at the
    /// leaf level (`offset == 0`).
    fn put_leaf(&mut self, height: usize, offset: usize, leaf: NodePointer) {
        match self.get_mut(height) {
            Node::Inner(children) => {
                let range = offset as isize..offset as isize + 1;
                let (co, c) = Self::affected_children(children, height - 1, range)
                    .next()
                    .unwrap();
                c.put_leaf(height - 1, offset - co as usize, leaf);
            }
            Node::Leaf(_) => {
                debug_assert_eq!(offset, 0);
                *self = leaf;
            }
        }
    }

    /// Finds the leaf containing byte `offset`, returning the residual
    /// offset inside the leaf together with a mutable (copy-on-write)
    /// reference to its bytes. Returns `None` when the subtree is not
    /// materialized (i.e. all zeros) — nothing is allocated in that case.
    fn locate_leaf(
        &mut self,
        height: usize,
        offset: usize,
    ) -> Option<(usize, &mut [u8; LEAF_SIZE])> {
        self.0.as_ref()?;
        match self.get_mut(height) {
            Node::Inner(children) => {
                let range = offset as isize..offset as isize + 1;
                let (co, c) = Self::affected_children(children, height - 1, range)
                    .next()
                    .unwrap();
                c.locate_leaf(height - 1, offset - co as usize)
            }
            Node::Leaf(x) => Some((offset, x)),
        }
    }
}
199
200const fn const_tree_size(height: usize) -> usize {
201    if height == 0 {
202        LEAF_SIZE
203    } else {
204        INNER_SIZE * const_tree_size(height - 1)
205    }
206}
207
/// Runtime convenience wrapper around [const_tree_size]: bytes addressed by
/// a subtree of `height`.
fn tree_size(height: usize) -> usize {
    const_tree_size(height)
}
211
212impl Default for SnapBuf {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
218impl SnapBuf {
219    /// Creates an empty buffer.
220    pub fn new() -> Self {
221        Self {
222            root_height: 0,
223            size: 0,
224            root: NodePointer(None),
225        }
226    }
227
228    fn shrink(&mut self, new_len: usize) {
229        self.root.clear_range(
230            self.root_height,
231            new_len as isize..tree_size(self.root_height) as isize,
232        );
233        self.size = new_len;
234    }
235
236    fn grow_height_until(&mut self, min_size: usize) {
237        while tree_size(self.root_height) < min_size {
238            if self.root.0.is_some() {
239                let new_root = Arc::new(Node::Inner(array_init::array_init(|x| {
240                    if x == 0 {
241                        self.root.clone()
242                    } else {
243                        NodePointer(None)
244                    }
245                })));
246                self.root = NodePointer(Some(new_root.clone()));
247            }
248            self.root_height += 1;
249        }
250    }
251
252    fn grow_zero(&mut self, new_len: usize) {
253        self.grow_height_until(new_len);
254        self.size = new_len;
255    }
256
257    /// Resizes the buffer, to the given length.
258    ///
259    /// If `new_len` is greater than `len`, the new space in the buffer is filled with copies of value.
260    /// This is more efficient if `value == 0`.
261    #[inline]
262    pub fn resize(&mut self, new_len: usize, value: u8) {
263        match new_len.cmp(&self.size) {
264            Ordering::Less => {
265                self.shrink(new_len);
266            }
267            Ordering::Equal => {}
268            Ordering::Greater => {
269                let old_len = self.size;
270                self.grow_zero(new_len);
271                if value != 0 {
272                    self.fill_range(old_len..new_len, value);
273                }
274            }
275        }
276    }
277
278    /// Shortens the buffer, keeping the first `new_len` bytes and discarding the rest.
279    ///
280    /// If `new_len` is greater or equal to the buffer’s current length, this has no effect.
281    pub fn truncate(&mut self, new_len: usize) {
282        if new_len < self.size {
283            self.shrink(new_len);
284        }
285    }
286
287    /// Fill the given range with copies of value.
288    ///
289    /// This is equivalent to calling [write](Self::write) with a slice filled with value.
290    /// Calling this with `value = 0` is not guaranteed to free up the zeroed segments,
291    /// use [clear_range](Self::clear_range) if that is required.
292    pub fn fill_range(&mut self, range: Range<usize>, value: u8) {
293        if self.size < range.end {
294            self.grow_zero(range.end);
295        }
296        if range.is_empty() {
297            return;
298        }
299        let range = range.start as isize..range.end as isize;
300        self.root.fill_range(self.root_height, range, value);
301    }
302
303    /// Writes data at the specified offset.
304    ///
305    /// If this extends past the current end of the buffer, the buffer is automatically resized.
306    /// If offset is larger than the current buffer length, the space between the current buffer
307    /// end and the written region is filled with zeros.
308    ///
309    /// Zeroing parts of the buffer using this method is not guaranteed to free up the zeroed segments,
310    /// use [write_with_zeros](Self::write_with_zeros) if that is required.
311    pub fn write(&mut self, offset: usize, data: &[u8]) {
312        self.write_inner::<false>(offset, data);
313    }
314
315    /// Like [write](Self::write), but detects and frees newly zeroed segments in the buffer.
316    pub fn write_with_zeros(&mut self, offset: usize, data: &[u8]) {
317        self.write_inner::<true>(offset, data);
318    }
319
320    fn write_inner<const FREE_ZEROS: bool>(&mut self, offset: usize, data: &[u8]) {
321        let write_end = offset + data.len();
322        if self.size < write_end {
323            self.resize(write_end, 0);
324        }
325        if data.is_empty() {
326            return;
327        }
328        self.root
329            .set_range::<FREE_ZEROS>(self.root_height, offset as isize, data);
330    }
331
332    /// Returns `true` if the buffer length is zero.
333    pub fn is_empty(&self) -> bool {
334        self.len() == 0
335    }
336
337    /// Returns the length of the buffer, the number of bytes it contains.
338    ///
339    /// The memory footprint of the buffer may be much smaller than this due to omission of zero segments and sharing with other buffers.
340    pub fn len(&self) -> usize {
341        self.size
342    }
343
344    /// Fill a range with zeros, freeing memory if possible.
345    ///
346    /// # Panics
347    /// Panics if range end is `range.end` > `self.len()`.
348    pub fn clear_range(&mut self, range: Range<usize>) {
349        assert!(range.end <= self.size);
350        if range.is_empty() {
351            return;
352        }
353        self.root
354            .clear_range(self.root_height, range.start as isize..range.end as isize);
355    }
356
357    /// Clears all data and sets length to 0.
358    pub fn clear(&mut self) {
359        *self = Self::new();
360    }
361
362    fn iter_nodes_pre_order(&self) -> impl Iterator<Item = (&NodePointer, usize)> {
363        struct IterStack<'a> {
364            stack_end_height: usize,
365            stack: SmallVec<[&'a [NodePointer]; 5]>,
366        }
367
368        #[allow(clippy::needless_lifetimes)]
369        fn split_first_in_place<'x, 's, T>(x: &'x mut &'s [T]) -> &'s T {
370            let (first, rest) = mem::take(x).split_first().unwrap();
371            *x = rest;
372            first
373        }
374
375        impl<'a> Iterator for IterStack<'a> {
376            type Item = (&'a NodePointer, usize);
377
378            fn next(&mut self) -> Option<Self::Item> {
379                let visit_now = loop {
380                    let last_level = self.stack.last_mut()?;
381                    if last_level.is_empty() {
382                        self.stack.pop();
383                        self.stack_end_height += 1;
384                    } else {
385                        break split_first_in_place(last_level);
386                    }
387                };
388                let ret = (visit_now, self.stack_end_height);
389                if let Some(children) = visit_now.children() {
390                    self.stack.push(children);
391                    self.stack_end_height -= 1;
392                }
393                Some(ret)
394            }
395        }
396
397        let mut stack = SmallVec::new();
398        stack.push(slice::from_ref(&self.root));
399        IterStack {
400            stack_end_height: self.root_height,
401            stack,
402        }
403    }
404
405    /// Returns an iterator over the byte slices constituting the buffer.
406    ///
407    /// The returned slices may overlap.
408    pub fn chunks(&self) -> impl Iterator<Item = &[u8]> {
409        let mut emitted = 0;
410        self.iter_nodes_pre_order()
411            .flat_map(|(node, height)| {
412                let zero_leaf = &[0u8; LEAF_SIZE];
413                match node.0.as_deref() {
414                    None => {
415                        let leaf_count = INNER_SIZE.pow(height as u32);
416                        iter::repeat_n(zero_leaf, leaf_count)
417                    }
418                    Some(Node::Inner(_)) => iter::repeat_n(zero_leaf, 0),
419                    Some(Node::Leaf(b)) => iter::repeat_n(b, 1),
420                }
421            })
422            .map(move |x| {
423                let emit = (self.size - emitted).min(x.len());
424                emitted += emit;
425                &x[..emit]
426            })
427            .filter(|x| !x.is_empty())
428    }
429
430    /// Returns an iterator over the buffer.
431    pub fn iter(&self) -> impl Iterator<Item = &u8> {
432        self.chunks().flat_map(|x| x.iter())
433    }
434
435    #[doc(hidden)]
436    pub fn bytes(&self) -> impl Iterator<Item = u8> + '_ {
437        self.iter().copied()
438    }
439
440    pub fn extend_from_slice(&mut self, data: &[u8]) {
441        self.write_with_zeros(self.size, data)
442    }
443}
444
impl Extend<u8> for SnapBuf {
    fn extend<T: IntoIterator<Item = u8>>(&mut self, iter: T) {
        // Builds up to one leaf of data, starting at in-leaf offset
        // `start_at`. Returns (bytes now accounted for in the leaf, leaf
        // pointer); the pointer is `None` when every byte consumed was zero
        // (or the iterator was already empty), so all-zero leaves cost no
        // memory.
        fn generate_leaf(
            start_at: usize,
            iter: &mut impl Iterator<Item = u8>,
        ) -> (usize, NodePointer) {
            let mut consumed = start_at;
            // Scan leading zeros without allocating; stop at the first
            // non-zero byte, at iterator exhaustion, or at a full leaf.
            let first_non_zero = loop {
                if let Some(x) = iter.next() {
                    consumed += 1;
                    if x != 0 {
                        break x;
                    }
                } else {
                    return (consumed, NodePointer(None));
                }
                if consumed == LEAF_SIZE {
                    return (LEAF_SIZE, NodePointer(None));
                }
            };
            // A non-zero byte was seen: materialize the leaf and copy the
            // remaining bytes straight into it.
            let mut leaf = Arc::new(Node::Leaf([0u8; LEAF_SIZE]));
            let leaf_mut = if let Node::Leaf(x) = Arc::get_mut(&mut leaf).unwrap() {
                x
            } else {
                unreachable!()
            };
            // `consumed` was already advanced past the byte we just read.
            leaf_mut[consumed - 1] = first_non_zero;
            while consumed < LEAF_SIZE {
                if let Some(x) = iter.next() {
                    leaf_mut[consumed] = x;
                    consumed += 1;
                } else {
                    break;
                }
            }
            (consumed, NodePointer(Some(leaf)))
        }

        let it = &mut iter.into_iter();
        // First, top up the leaf containing offset `size` in place — but
        // only if it is materialized; an absent leaf is all zeros and is
        // rebuilt by generate_leaf below. Afterwards `size` is leaf-aligned
        // (or the iterator ran dry and we returned).
        if self.size < tree_size(self.root_height) {
            if let Some((offset, first_leaf)) = self.root.locate_leaf(self.root_height, self.size) {
                for i in offset..LEAF_SIZE {
                    let Some(x) = it.next() else { return };
                    first_leaf[i] = x;
                    self.size += 1;
                }
                assert_eq!(self.size % LEAF_SIZE, 0);
            }
        } else {
            // The tree is exactly full; fullness implies leaf alignment.
            assert_eq!(self.size % LEAF_SIZE, 0);
        }
        // Append leaf by leaf, splicing each finished non-zero leaf into the
        // tree at its leaf-aligned offset.
        loop {
            let in_leaf_offset = self.size % LEAF_SIZE;
            let (consumed, leaf) = generate_leaf(in_leaf_offset, it);
            let old_size = self.size;
            self.size = old_size - in_leaf_offset + consumed;
            self.grow_height_until(self.size);
            if leaf.0.is_some() {
                self.root
                    .put_leaf(self.root_height, old_size - in_leaf_offset, leaf);
            }
            // A short leaf means the iterator was exhausted.
            if consumed < LEAF_SIZE {
                return;
            }
            assert_eq!(self.size % LEAF_SIZE, 0);
        }
    }
}
513
514impl FromIterator<u8> for SnapBuf {
515    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
516        let mut iter = iter.into_iter();
517        let mut ret = Self::new();
518        ret.extend(&mut iter);
519        ret
520    }
521}