1#![no_std]
2extern crate alloc;
9#[cfg(feature = "test")]
10extern crate std;
11
12use alloc::sync::Arc;
13use core::cmp::Ordering;
14use core::ops::Range;
15use core::{iter, mem, slice};
16use smallvec::SmallVec;
17
/// A growable byte buffer stored as a tree of reference-counted nodes.
///
/// Leaves hold `LEAF_SIZE` bytes and inner nodes hold `INNER_SIZE` children.
/// Subtrees whose bytes are all zero may be left unallocated, and nodes are
/// copied on write via `Arc::make_mut`.
#[derive(Debug)]
pub struct SnapBuf {
    /// Logical length of the buffer in bytes.
    size: usize,
    /// Height of `root`; a tree of height `h` can address `tree_size(h)` bytes.
    root_height: usize,
    /// Root of the tree; an empty pointer means every byte reads as zero.
    root: NodePointer,
}
24
/// Bytes stored in one leaf node (kept tiny under the `test` feature so that
/// small inputs already exercise multi-level trees).
#[cfg(feature = "test")]
const LEAF_SIZE: usize = 32;
/// Bytes stored in one leaf node.
#[cfg(not(feature = "test"))]
const LEAF_SIZE: usize = 4000;

/// Child pointers per inner node (kept tiny under the `test` feature).
#[cfg(feature = "test")]
const INNER_SIZE: usize = 4;
/// Child pointers per inner node.
#[cfg(not(feature = "test"))]
const INNER_SIZE: usize = 500;
27
// Module compiled only when the `test` cargo feature is enabled; its contents
// live in a separate file not shown here (presumably test support code —
// confirm against `test.rs`).
#[cfg(feature = "test")]
pub mod test;
30
/// A tree node: either an inner node holding `INNER_SIZE` child pointers or a
/// leaf holding `LEAF_SIZE` bytes.
#[derive(Clone, Debug)]
enum Node {
    /// Child pointers in buffer order; at height `h` each child spans
    /// `tree_size(h - 1)` bytes.
    Inner([NodePointer; INNER_SIZE]),
    /// Raw bytes of one leaf; leaves are created at height 0.
    Leaf([u8; LEAF_SIZE]),
}
36
/// Possibly-empty shared pointer to a [`Node`].
///
/// `None` stands for a subtree whose bytes all read as zero. Cloning only
/// bumps the `Arc` reference count; mutation goes through `Arc::make_mut`.
#[derive(Clone, Debug)]
struct NodePointer(Option<Arc<Node>>);
39
/// Binds the bounds of `$range` to the identifiers `$start` and `$end`.
///
/// In debug builds it also checks that the range overlaps a subtree of height
/// `$height`: it must begin before the subtree's end and finish after its
/// first byte. The range is only read (field by field), never moved.
macro_rules! deconstruct_range {
    {$start:ident .. $end:ident = $range:expr, $height:expr} => {
        let ($start, $end) = (($range).start, ($range).end);
        debug_assert!($start < tree_size($height) as isize);
        debug_assert!($end > 0);
    };
}
49
/// Returns `true` when `f` holds for every element of the array.
///
/// The last element is tested first as a cheap early-out (a trailing byte or
/// child is a likely place to find a non-matching value); the remaining
/// elements are then tested front to back, short-circuiting on the first
/// failure. Panics if `C == 0`.
fn range_all<T, const C: usize>(x: &[T; C], mut f: impl FnMut(&T) -> bool) -> bool {
    let (tail, head) = x.split_last().unwrap();
    f(tail) && head.iter().all(f)
}
54
55impl NodePointer {
56 fn children(&self) -> Option<&[NodePointer; INNER_SIZE]> {
57 match &**(self.0.as_ref()?) {
58 Node::Inner(x) => Some(x),
59 Node::Leaf(_) => None,
60 }
61 }
62
63 fn get_mut(&mut self, height: usize) -> &mut Node {
64 let arc = self.0.get_or_insert_with(|| {
65 Arc::new({
66 if height == 0 {
67 Node::Leaf([0; LEAF_SIZE])
68 } else {
69 Node::Inner(array_init::array_init(|_| NodePointer(None)))
70 }
71 })
72 });
73 Arc::make_mut(arc)
74 }
75
76 fn set_range<const FREE_ZEROS: bool>(&mut self, height: usize, start: isize, values: &[u8]) {
77 deconstruct_range!(start..end = start .. start + values.len() as isize ,height);
78 match self.get_mut(height) {
79 Node::Inner(children) => {
80 for (child_offset, child) in
81 Self::affected_children(children, height - 1, start..end)
82 {
83 child.set_range::<FREE_ZEROS>(height - 1, start - child_offset, values);
84 }
85 if FREE_ZEROS && range_all(children, |c| c.0.is_none()) {
86 self.0 = None;
87 }
88 }
89 Node::Leaf(bytes) => {
90 let (src, dst) = if start < 0 {
91 (&values[-start as usize..], &mut bytes[..])
92 } else {
93 (values, &mut bytes[start as usize..])
94 };
95 let len = src.len().min(dst.len());
96 dst[..len].copy_from_slice(&src[..len]);
97 if FREE_ZEROS && range_all(bytes, |b| *b == 0) {
98 self.0 = None;
99 }
100 }
101 }
102 }
103
104 fn affected_children(
105 children: &mut [NodePointer; INNER_SIZE],
106 child_height: usize,
107 range: Range<isize>,
108 ) -> impl Iterator<Item = (isize, &mut NodePointer)> {
109 let start = range.start.max(0) as usize;
110 let child_size = tree_size(child_height);
111 children
112 .iter_mut()
113 .enumerate()
114 .skip(start / child_size)
115 .map(move |(i, c)| ((i * child_size) as isize, c))
116 .take_while(move |(offset, _)| (*offset) < range.end)
117 }
118
119 fn fill_range(&mut self, height: usize, range: Range<isize>, value: u8) {
120 deconstruct_range!(start..end=range,height);
121 match self.get_mut(height) {
122 Node::Inner(children) => {
123 for (child_offset, child) in
124 Self::affected_children(children, height - 1, range.clone())
125 {
126 child.fill_range(height - 1, start - child_offset..end - child_offset, value);
127 }
128 }
129 Node::Leaf(bytes) => {
130 let write_start = start.max(0) as usize;
131 let write_end = (end as usize).min(bytes.len());
132 bytes[write_start..write_end].fill(value);
133 }
134 }
135 }
136
137 fn clear_range(&mut self, height: usize, range: Range<isize>) {
138 deconstruct_range!(start..end = range,height);
139 if start <= 0 && end as usize >= tree_size(height) || self.0.is_none() {
140 self.0 = None;
141 return;
142 }
143 match self.get_mut(height) {
144 Node::Inner(children) => {
145 for (child_offset, child) in
146 Self::affected_children(children, height - 1, range.clone())
147 {
148 child.clear_range(height - 1, start - child_offset..end - child_offset);
149 }
150 if range_all(children, |c| c.0.is_none()) {
151 self.0 = None;
152 }
153 }
154 Node::Leaf(bytes) => {
155 let write_start = start.max(0) as usize;
156 let write_end = (end as usize).min(bytes.len());
157 bytes[write_start..write_end].fill(0);
158 if range_all(bytes, |b| *b == 0) {
159 self.0 = None;
160 }
161 }
162 }
163 }
164
165 fn put_leaf(&mut self, height: usize, offset: usize, leaf: NodePointer) {
166 match self.get_mut(height) {
167 Node::Inner(children) => {
168 let range = offset as isize..offset as isize + 1;
169 let (co, c) = Self::affected_children(children, height - 1, range)
170 .next()
171 .unwrap();
172 c.put_leaf(height - 1, offset - co as usize, leaf);
173 }
174 Node::Leaf(_) => {
175 debug_assert_eq!(offset, 0);
176 *self = leaf;
177 }
178 }
179 }
180
181 fn locate_leaf(
182 &mut self,
183 height: usize,
184 offset: usize,
185 ) -> Option<(usize, &mut [u8; LEAF_SIZE])> {
186 self.0.as_ref()?;
187 match self.get_mut(height) {
188 Node::Inner(children) => {
189 let range = offset as isize..offset as isize + 1;
190 let (co, c) = Self::affected_children(children, height - 1, range)
191 .next()
192 .unwrap();
193 c.locate_leaf(height - 1, offset - co as usize)
194 }
195 Node::Leaf(x) => Some((offset, x)),
196 }
197 }
198}
199
200const fn const_tree_size(height: usize) -> usize {
201 if height == 0 {
202 LEAF_SIZE
203 } else {
204 INNER_SIZE * const_tree_size(height - 1)
205 }
206}
207
208fn tree_size(height: usize) -> usize {
209 const_tree_size(height)
210}
211
212impl Default for SnapBuf {
213 fn default() -> Self {
214 Self::new()
215 }
216}
217
218impl SnapBuf {
219 pub fn new() -> Self {
221 Self {
222 root_height: 0,
223 size: 0,
224 root: NodePointer(None),
225 }
226 }
227
228 fn shrink(&mut self, new_len: usize) {
229 self.root.clear_range(
230 self.root_height,
231 new_len as isize..tree_size(self.root_height) as isize,
232 );
233 self.size = new_len;
234 }
235
236 fn grow_height_until(&mut self, min_size: usize) {
237 while tree_size(self.root_height) < min_size {
238 if self.root.0.is_some() {
239 let new_root = Arc::new(Node::Inner(array_init::array_init(|x| {
240 if x == 0 {
241 self.root.clone()
242 } else {
243 NodePointer(None)
244 }
245 })));
246 self.root = NodePointer(Some(new_root.clone()));
247 }
248 self.root_height += 1;
249 }
250 }
251
252 fn grow_zero(&mut self, new_len: usize) {
253 self.grow_height_until(new_len);
254 self.size = new_len;
255 }
256
257 #[inline]
262 pub fn resize(&mut self, new_len: usize, value: u8) {
263 match new_len.cmp(&self.size) {
264 Ordering::Less => {
265 self.shrink(new_len);
266 }
267 Ordering::Equal => {}
268 Ordering::Greater => {
269 let old_len = self.size;
270 self.grow_zero(new_len);
271 if value != 0 {
272 self.fill_range(old_len..new_len, value);
273 }
274 }
275 }
276 }
277
278 pub fn truncate(&mut self, new_len: usize) {
282 if new_len < self.size {
283 self.shrink(new_len);
284 }
285 }
286
287 pub fn fill_range(&mut self, range: Range<usize>, value: u8) {
293 if self.size < range.end {
294 self.grow_zero(range.end);
295 }
296 if range.is_empty() {
297 return;
298 }
299 let range = range.start as isize..range.end as isize;
300 self.root.fill_range(self.root_height, range, value);
301 }
302
303 pub fn write(&mut self, offset: usize, data: &[u8]) {
312 self.write_inner::<false>(offset, data);
313 }
314
315 pub fn write_with_zeros(&mut self, offset: usize, data: &[u8]) {
317 self.write_inner::<true>(offset, data);
318 }
319
320 fn write_inner<const FREE_ZEROS: bool>(&mut self, offset: usize, data: &[u8]) {
321 let write_end = offset + data.len();
322 if self.size < write_end {
323 self.resize(write_end, 0);
324 }
325 if data.is_empty() {
326 return;
327 }
328 self.root
329 .set_range::<FREE_ZEROS>(self.root_height, offset as isize, data);
330 }
331
332 pub fn is_empty(&self) -> bool {
334 self.len() == 0
335 }
336
337 pub fn len(&self) -> usize {
341 self.size
342 }
343
344 pub fn clear_range(&mut self, range: Range<usize>) {
349 assert!(range.end <= self.size);
350 if range.is_empty() {
351 return;
352 }
353 self.root
354 .clear_range(self.root_height, range.start as isize..range.end as isize);
355 }
356
357 pub fn clear(&mut self) {
359 *self = Self::new();
360 }
361
362 fn iter_nodes_pre_order(&self) -> impl Iterator<Item = (&NodePointer, usize)> {
363 struct IterStack<'a> {
364 stack_end_height: usize,
365 stack: SmallVec<[&'a [NodePointer]; 5]>,
366 }
367
368 #[allow(clippy::needless_lifetimes)]
369 fn split_first_in_place<'x, 's, T>(x: &'x mut &'s [T]) -> &'s T {
370 let (first, rest) = mem::take(x).split_first().unwrap();
371 *x = rest;
372 first
373 }
374
375 impl<'a> Iterator for IterStack<'a> {
376 type Item = (&'a NodePointer, usize);
377
378 fn next(&mut self) -> Option<Self::Item> {
379 let visit_now = loop {
380 let last_level = self.stack.last_mut()?;
381 if last_level.is_empty() {
382 self.stack.pop();
383 self.stack_end_height += 1;
384 } else {
385 break split_first_in_place(last_level);
386 }
387 };
388 let ret = (visit_now, self.stack_end_height);
389 if let Some(children) = visit_now.children() {
390 self.stack.push(children);
391 self.stack_end_height -= 1;
392 }
393 Some(ret)
394 }
395 }
396
397 let mut stack = SmallVec::new();
398 stack.push(slice::from_ref(&self.root));
399 IterStack {
400 stack_end_height: self.root_height,
401 stack,
402 }
403 }
404
405 pub fn chunks(&self) -> impl Iterator<Item = &[u8]> {
409 let mut emitted = 0;
410 self.iter_nodes_pre_order()
411 .flat_map(|(node, height)| {
412 let zero_leaf = &[0u8; LEAF_SIZE];
413 match node.0.as_deref() {
414 None => {
415 let leaf_count = INNER_SIZE.pow(height as u32);
416 iter::repeat_n(zero_leaf, leaf_count)
417 }
418 Some(Node::Inner(_)) => iter::repeat_n(zero_leaf, 0),
419 Some(Node::Leaf(b)) => iter::repeat_n(b, 1),
420 }
421 })
422 .map(move |x| {
423 let emit = (self.size - emitted).min(x.len());
424 emitted += emit;
425 &x[..emit]
426 })
427 .filter(|x| !x.is_empty())
428 }
429
430 pub fn iter(&self) -> impl Iterator<Item = &u8> {
432 self.chunks().flat_map(|x| x.iter())
433 }
434
435 #[doc(hidden)]
436 pub fn bytes(&self) -> impl Iterator<Item = u8> + '_ {
437 self.iter().copied()
438 }
439
440 pub fn extend_from_slice(&mut self, data: &[u8]) {
441 self.write_with_zeros(self.size, data)
442 }
443}
444
445impl Extend<u8> for SnapBuf {
446 fn extend<T: IntoIterator<Item = u8>>(&mut self, iter: T) {
447 fn generate_leaf(
448 start_at: usize,
449 iter: &mut impl Iterator<Item = u8>,
450 ) -> (usize, NodePointer) {
451 let mut consumed = start_at;
452 let first_non_zero = loop {
453 if let Some(x) = iter.next() {
454 consumed += 1;
455 if x != 0 {
456 break x;
457 }
458 } else {
459 return (consumed, NodePointer(None));
460 }
461 if consumed == LEAF_SIZE {
462 return (LEAF_SIZE, NodePointer(None));
463 }
464 };
465 let mut leaf = Arc::new(Node::Leaf([0u8; LEAF_SIZE]));
466 let leaf_mut = if let Node::Leaf(x) = Arc::get_mut(&mut leaf).unwrap() {
467 x
468 } else {
469 unreachable!()
470 };
471 leaf_mut[consumed - 1] = first_non_zero;
472 while consumed < LEAF_SIZE {
473 if let Some(x) = iter.next() {
474 leaf_mut[consumed] = x;
475 consumed += 1;
476 } else {
477 break;
478 }
479 }
480 (consumed, NodePointer(Some(leaf)))
481 }
482
483 let it = &mut iter.into_iter();
484 if self.size < tree_size(self.root_height) {
485 if let Some((offset, first_leaf)) = self.root.locate_leaf(self.root_height, self.size) {
486 for i in offset..LEAF_SIZE {
487 let Some(x) = it.next() else { return };
488 first_leaf[i] = x;
489 self.size += 1;
490 }
491 assert_eq!(self.size % LEAF_SIZE, 0);
492 }
493 } else {
494 assert_eq!(self.size % LEAF_SIZE, 0);
495 }
496 loop {
497 let in_leaf_offset = self.size % LEAF_SIZE;
498 let (consumed, leaf) = generate_leaf(in_leaf_offset, it);
499 let old_size = self.size;
500 self.size = old_size - in_leaf_offset + consumed;
501 self.grow_height_until(self.size);
502 if leaf.0.is_some() {
503 self.root
504 .put_leaf(self.root_height, old_size - in_leaf_offset, leaf);
505 }
506 if consumed < LEAF_SIZE {
507 return;
508 }
509 assert_eq!(self.size % LEAF_SIZE, 0);
510 }
511 }
512}
513
514impl FromIterator<u8> for SnapBuf {
515 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
516 let mut iter = iter.into_iter();
517 let mut ret = Self::new();
518 ret.extend(&mut iter);
519 ret
520 }
521}