craballoc/slab.rs

//! Slab allocators that run on the CPU.
use core::sync::atomic::{AtomicUsize, Ordering};
use crabslab::{Array, Id, SlabItem};
use rustc_hash::{FxHashMap, FxHashSet};
use snafu::prelude::*;
use std::{
    hash::Hash,
    ops::Deref,
    sync::{atomic::AtomicBool, Arc, RwLock},
};

use crate::{
    range::{Range, RangeManager},
    runtime::{IsRuntime, SlabUpdate},
    value::{Hybrid, HybridArray, WeakGpuRef},
};

#[cfg(feature = "wgpu")]
mod wgpu_slab;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum SlabAllocatorError {
    #[snafu(display(
        "Slab has no internal buffer. Please call SlabAllocator::commit or \
         SlabAllocator::get_updated_buffer first."
    ))]
    NoInternalBuffer,

    #[snafu(display("Async recv error: {source}"))]
    AsyncRecv { source: async_channel::RecvError },

    #[cfg(feature = "wgpu")]
    #[snafu(display("Async error: {source}"))]
    Async { source: wgpu::BufferAsyncError },

    #[cfg(feature = "wgpu")]
    #[snafu(display("Poll error: {source}"))]
    Poll { source: wgpu::PollError },

    #[snafu(display("{source}"))]
    Other { source: Box<dyn std::error::Error> },
}

/// A thin wrapper around a buffer `T` that provides the ability to tell
/// if the buffer has been invalidated by the [`SlabAllocator`] that it
/// originated from.
///
/// Invalidation happens when the slab resizes. For this reason it is
/// important to create as many values as necessary _before_ calling
/// [`SlabAllocator::commit`] to avoid unnecessary invalidation.
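///
/// For illustration, a minimal sketch of how invalidation plays out, assuming a freshly
/// created `slab` and that `IsRuntime` and the allocator types are re-exported from the
/// prelude (the runtime setup itself is omitted):
///
/// ```rust,no_run
/// use craballoc::prelude::*;
///
/// fn demo<R: IsRuntime>(slab: &SlabAllocator<R>) {
///     let _a = slab.new_value(0u32);
///     let buffer = slab.commit();
///     assert!(buffer.is_valid());
///
///     // Staging enough new values to force the slab to resize will
///     // invalidate the previously returned buffer on the next commit.
///     let _b = slab.new_array([1u32, 2, 3, 4, 5, 6, 7, 8]);
///     slab.commit();
///     assert!(buffer.is_invalid());
/// }
/// ```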
pub struct SlabBuffer<T> {
    // Id of the slab's last `commit` invocation.
    slab_commit_invocation_k: Arc<AtomicUsize>,
    // Id of the slab's last buffer invalidation.
    slab_invalidation_k: Arc<AtomicUsize>,
    // The slab's `slab_invalidation_k` at the time of this buffer's creation.
    buffer_creation_k: usize,
    // The buffer created at `buffer_creation_k`.
    buffer: Arc<T>,
    // The buffer the source slab is currently working with.
    source_slab_buffer: Arc<RwLock<Option<SlabBuffer<T>>>>,
}

impl<T> Clone for SlabBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            slab_commit_invocation_k: self.slab_commit_invocation_k.clone(),
            slab_invalidation_k: self.slab_invalidation_k.clone(),
            buffer_creation_k: self.buffer_creation_k,
            buffer: self.buffer.clone(),
            source_slab_buffer: self.source_slab_buffer.clone(),
        }
    }
}

impl<T> Deref for SlabBuffer<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.buffer
    }
}

impl<T> SlabBuffer<T> {
    fn new(
        invalidation_k: Arc<AtomicUsize>,
        invocation_k: Arc<AtomicUsize>,
        buffer: T,
        source_slab_buffer: Arc<RwLock<Option<SlabBuffer<T>>>>,
    ) -> Self {
        SlabBuffer {
            buffer: buffer.into(),
            buffer_creation_k: invalidation_k.load(std::sync::atomic::Ordering::Relaxed),
            slab_invalidation_k: invalidation_k,
            slab_commit_invocation_k: invocation_k,
            source_slab_buffer,
        }
    }

    pub(crate) fn invalidation_k(&self) -> usize {
        self.slab_invalidation_k
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    pub(crate) fn invocation_k(&self) -> usize {
        self.slab_commit_invocation_k
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Returns the timestamp at which the internal buffer was created.
    ///
    /// The returned timestamp is not a unix timestamp. It is a
    /// monotonically increasing count of buffer invalidations.
    pub fn creation_time(&self) -> usize {
        self.buffer_creation_k
    }

    /// Determines whether this buffer has been invalidated by the slab
    /// it originated from.
    pub fn is_invalid(&self) -> bool {
        self.creation_time() < self.invalidation_k()
    }

    /// Determines whether this buffer is still valid, that is, whether it has _not_ been
    /// invalidated by the slab it originated from.
    pub fn is_valid(&self) -> bool {
        !self.is_invalid()
    }
    /// Returns `true` when the slab's internal buffer has been recreated, and this is that
    /// newly created buffer.
    ///
    /// This will return `false` if [`SlabAllocator::commit`] has been called again since the
    /// creation of this buffer.
    ///
    /// Typically this function is used by structs that own the [`SlabAllocator`]. These owning
    /// structs will call [`SlabAllocator::commit`], which returns a [`SlabBuffer`]. The call site
    /// can then call [`SlabBuffer::is_new_this_commit`] to determine if any
    /// downstream resources (like bind groups) need to be recreated.
    ///
    /// This pattern keeps the owning struct from having to also store the `SlabBuffer`.
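    ///
    /// For illustration, a minimal sketch of that pattern, assuming `IsRuntime` and the
    /// allocator types are re-exported from the prelude; the surrounding renderer code is
    /// hypothetical:
    ///
    /// ```rust,no_run
    /// use craballoc::prelude::*;
    ///
    /// fn prepare_frame<R: IsRuntime>(slab: &SlabAllocator<R>) {
    ///     let buffer = slab.commit();
    ///     if buffer.is_new_this_commit() {
    ///         // The internal buffer was recreated by this `commit`, so any
    ///         // downstream resources that reference it (bind groups, etc.)
    ///         // should be recreated here.
    ///     }
    /// }
    /// ```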
    pub fn is_new_this_commit(&self) -> bool {
        self.invocation_k() == self.buffer_creation_k
    }

    #[deprecated(since = "0.1.5", note = "please use `is_new_this_commit` instead")]
    pub fn is_new_this_upkeep(&self) -> bool {
        self.is_new_this_commit()
    }

    /// Synchronize this buffer with the slab's internal buffer.
    ///
    /// This checks whether the internal buffer is still the one the slab is working with,
    /// and updates it if the slab has since moved on to a newer buffer.
    ///
    /// Returns `true` if the buffer was updated.
    /// Returns `false` if the buffer remains the same.
    ///
    /// Use the result of this function to invalidate any bind groups or other downstream
    /// resources.
    ///
    /// ## Note
    /// Be cautious when using this function with multiple buffers to invalidate downstream
    /// resources. Keep in mind that using the
    /// [lazy boolean operators](https://doc.rust-lang.org/reference/expressions/operator-expr.html#lazy-boolean-operators)
    /// might not have the effect you are expecting!
    ///
    /// For example:
    ///
    /// ```rust,no_run
    /// use craballoc::prelude::*;
    ///
    /// let mut buffer_a: SlabBuffer<wgpu::Buffer> = todo!();
    /// let mut buffer_b: SlabBuffer<wgpu::Buffer> = todo!();
    /// let mut buffer_c: SlabBuffer<wgpu::Buffer> = todo!();
    ///
    /// let should_invalidate = buffer_a.update_if_invalid()
    ///     || buffer_b.update_if_invalid()
    ///     || buffer_c.update_if_invalid();
    /// ```
    ///
    /// If `buffer_a` is invalid, neither `buffer_b` nor `buffer_c` will be synchronized, because
    /// `||` short-circuits: once `buffer_a.update_if_invalid()` returns `true`, the remaining
    /// operands are never evaluated.
    ///
    /// Instead, we should write the following:
    ///
    /// ```rust,no_run
    /// use craballoc::prelude::*;
    ///
    /// let mut buffer_a: SlabBuffer<wgpu::Buffer> = todo!();
    /// let mut buffer_b: SlabBuffer<wgpu::Buffer> = todo!();
    /// let mut buffer_c: SlabBuffer<wgpu::Buffer> = todo!();
    ///
    /// let buffer_a_invalid = buffer_a.update_if_invalid();
    /// let buffer_b_invalid = buffer_b.update_if_invalid();
    /// let buffer_c_invalid = buffer_c.update_if_invalid();
    ///
    /// let should_invalidate = buffer_a_invalid || buffer_b_invalid || buffer_c_invalid;
    /// ```
    pub fn update_if_invalid(&mut self) -> bool {
        if self.is_invalid() {
            // UNWRAP: Safe because it is an invariant of the system. Once the `SlabBuffer`
            // is created, source_slab_buffer will always be Some.
            let updated_buffer = {
                let guard = self.source_slab_buffer.read().unwrap();
                guard.as_ref().unwrap().clone()
            };
            debug_assert!(updated_buffer.is_valid());
            *self = updated_buffer;
            true
        } else {
            false
        }
    }

    #[deprecated(since = "0.1.5", note = "please use `update_if_invalid` instead")]
    pub fn synchronize(&mut self) -> bool {
        self.update_if_invalid()
    }
}

/// An identifier for a unique source of updates.
#[derive(Clone, Copy, Debug)]
pub struct SourceId {
    pub key: usize,
    /// This field is just for debugging.
    pub type_is: &'static str,
}

impl core::fmt::Display for SourceId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}({})", self.type_is, self.key))
    }
}

impl PartialEq for SourceId {
    fn eq(&self, other: &Self) -> bool {
        self.key.eq(&other.key)
    }
}

impl Eq for SourceId {}

impl PartialOrd for SourceId {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.key.cmp(&other.key))
    }
}

impl Ord for SourceId {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.key.cmp(&other.key)
    }
}

impl Hash for SourceId {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.key.hash(state)
    }
}

/// Manages slab allocations and updates over a parameterised buffer.
///
/// Create a new instance using [`SlabAllocator::new`].
///
/// Upon creation you will need to call [`SlabAllocator::get_buffer`] or
/// [`SlabAllocator::commit`] at least once before any data is written to the
/// internal buffer.
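///
/// For illustration, a minimal sketch of the typical flow, assuming `IsRuntime`, the
/// allocator, and the hybrid value types are re-exported from the prelude, and that a
/// concrete runtime and buffer usages are constructed elsewhere:
///
/// ```rust,no_run
/// use craballoc::prelude::*;
///
/// fn setup<R: IsRuntime>(runtime: impl AsRef<R>, usages: R::BufferUsages) {
///     let slab = SlabAllocator::new(runtime, "example-slab", usages);
///
///     // Stage values on both the CPU and the GPU. Nothing is written to the
///     // internal buffer yet.
///     let _answer = slab.new_value(42u32);
///     let _numbers = slab.new_array([1u32, 2, 3]);
///
///     // `commit` creates the internal buffer if needed and writes any
///     // staged updates to it.
///     let buffer = slab.commit();
///     assert!(buffer.is_valid());
/// }
/// ```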
pub struct SlabAllocator<Runtime: IsRuntime> {
    pub(crate) notifier: (
        async_channel::Sender<SourceId>,
        async_channel::Receiver<SourceId>,
    ),
    runtime: Runtime,
    label: Arc<String>,
    len: Arc<AtomicUsize>,
    capacity: Arc<AtomicUsize>,
    needs_expansion: Arc<AtomicBool>,
    buffer: Arc<RwLock<Option<SlabBuffer<Runtime::Buffer>>>>,
    buffer_usages: Runtime::BufferUsages,
    // The value of invocation_k when the last buffer invalidation happened
    invalidation_k: Arc<AtomicUsize>,
    // The next monotonically increasing commit invocation identifier
    invocation_k: Arc<AtomicUsize>,
    // The next monotonically increasing update identifier
    pub(crate) update_k: Arc<AtomicUsize>,
    // Weak references to all values that can write updates into this slab
    pub(crate) update_sources: Arc<RwLock<FxHashMap<SourceId, WeakGpuRef>>>,
    // Set of ids of the update sources that have updates queued
    update_queue: Arc<RwLock<FxHashSet<SourceId>>>,
    // Recycled memory ranges
    pub(crate) recycles: Arc<RwLock<RangeManager<Range>>>,
}

impl<R: IsRuntime> Clone for SlabAllocator<R> {
    fn clone(&self) -> Self {
        SlabAllocator {
            runtime: self.runtime.clone(),
            notifier: self.notifier.clone(),
            label: self.label.clone(),
            len: self.len.clone(),
            capacity: self.capacity.clone(),
            needs_expansion: self.needs_expansion.clone(),
            buffer: self.buffer.clone(),
            buffer_usages: self.buffer_usages.clone(),
            invalidation_k: self.invalidation_k.clone(),
            invocation_k: self.invocation_k.clone(),
            update_k: self.update_k.clone(),
            update_sources: self.update_sources.clone(),
            update_queue: self.update_queue.clone(),
            recycles: self.recycles.clone(),
        }
    }
}

impl<R: IsRuntime> SlabAllocator<R> {
    pub fn new(
        runtime: impl AsRef<R>,
        name: impl AsRef<str>,
        default_buffer_usages: R::BufferUsages,
    ) -> Self {
        let label = Arc::new(name.as_ref().to_owned());
        Self {
            runtime: runtime.as_ref().clone(),
            label,
            notifier: async_channel::unbounded(),
            update_k: Default::default(),
            update_sources: Default::default(),
            update_queue: Default::default(),
            recycles: Default::default(),
            len: Default::default(),
            // Start with size 1, because some of `wgpu`'s validation depends on it.
            // See <https://github.com/gfx-rs/wgpu/issues/6414> for more info.
            capacity: Arc::new(AtomicUsize::new(1)),
            needs_expansion: Arc::new(true.into()),
            buffer: Default::default(),
            buffer_usages: default_buffer_usages,
            invalidation_k: Default::default(),
            invocation_k: Default::default(),
        }
    }

    pub(crate) fn next_update_k(&self) -> usize {
        self.update_k.fetch_add(1, Ordering::Relaxed)
    }

    pub(crate) fn insert_update_source(&self, id: SourceId, source: WeakGpuRef) {
        log::trace!("{} insert_update_source {id}", self.label);
        let _ = self.notifier.0.try_send(id);
        // UNWRAP: panic on purpose
        self.update_sources.write().unwrap().insert(id, source);
    }

    /// The length of the underlying buffer, in u32 slots.
    ///
    /// This does not include data that has not yet been committed.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Whether the underlying buffer is empty.
    ///
    /// This does not include data that has not yet been committed.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub(crate) fn allocate<T: SlabItem>(&self) -> Id<T> {
        // UNWRAP: we want to panic
        let may_range = self.recycles.write().unwrap().remove(T::SLAB_SIZE as u32);
        if let Some(range) = may_range {
            let id = Id::<T>::new(range.first_index);
            log::trace!(
                "slab allocate {}: dequeued {range:?} to {id:?}",
                std::any::type_name::<T>()
            );
            debug_assert_eq!(
                range.last_index,
                range.first_index + T::SLAB_SIZE as u32 - 1
            );
            id
        } else {
            self.maybe_expand_to_fit::<T>(1);
            let index = self.increment_len(T::SLAB_SIZE);
            Id::from(index)
        }
    }

    pub(crate) fn allocate_array<T: SlabItem>(&self, len: usize) -> Array<T> {
        if len == 0 {
            return Array::default();
        }

        // UNWRAP: we want to panic
        let may_range = self
            .recycles
            .write()
            .unwrap()
            .remove((T::SLAB_SIZE * len) as u32);
        if let Some(range) = may_range {
            let array = Array::<T>::new(range.first_index, len as u32);
            log::trace!(
                "slab allocate_array {len}x{}: dequeued {range:?} to {array:?}",
                std::any::type_name::<T>()
            );
            debug_assert_eq!(
                range.last_index,
                range.first_index + (T::SLAB_SIZE * len) as u32 - 1
            );
            array
        } else {
            self.maybe_expand_to_fit::<T>(len);
            let index = self.increment_len(T::SLAB_SIZE * len);
            Array::new(index as u32, len as u32)
        }
    }

    fn capacity(&self) -> usize {
        self.capacity.load(Ordering::Relaxed)
    }

    fn reserve_capacity(&self, capacity: usize) {
        self.capacity.store(capacity, Ordering::Relaxed);
        self.needs_expansion.store(true, Ordering::Relaxed);
    }

    fn increment_len(&self, n: usize) -> usize {
        self.len.fetch_add(n, Ordering::Relaxed)
    }

    fn maybe_expand_to_fit<T: SlabItem>(&self, len: usize) {
        let capacity = self.capacity();
        let capacity_needed = self.len() + T::SLAB_SIZE * len;
        if capacity_needed > capacity {
            // Grow by doubling until the new capacity fits the needed length.
            let mut new_capacity = capacity * 2;
            while new_capacity < capacity_needed {
                new_capacity = (new_capacity * 2).max(2);
            }
            self.reserve_capacity(new_capacity);
        }
    }

    /// Return the internal buffer used by this slab, if it has
    /// been created.
    pub fn get_buffer(&self) -> Option<SlabBuffer<R::Buffer>> {
        self.buffer.read().unwrap().clone()
    }

    /// Recreate the internal buffer, writing the contents of the previous buffer (if it
    /// exists) to the new one, then return the new buffer.
    fn recreate_buffer(&self) -> SlabBuffer<R::Buffer> {
        let new_buffer = self.runtime.buffer_create(
            self.capacity(),
            Some(self.label.as_ref()),
            self.buffer_usages.clone(),
        );
        let mut guard = self.buffer.write().unwrap();
        if let Some(old_buffer) = guard.take() {
            self.runtime
                .buffer_copy(&old_buffer, &new_buffer, Some(self.label.as_ref()));
        }
        let slab_buffer = SlabBuffer::new(
            self.invalidation_k.clone(),
            self.invocation_k.clone(),
            new_buffer,
            self.buffer.clone(),
        );
        *guard = Some(slab_buffer.clone());
        slab_buffer
    }

    /// Stage a new value that lives on the GPU _and_ CPU.
    pub fn new_value<T: SlabItem + Clone + Send + Sync + 'static>(&self, value: T) -> Hybrid<T> {
        Hybrid::new(self, value)
    }

    /// Stage a contiguous array of new values that live on the GPU _and_ CPU.
    pub fn new_array<T: SlabItem + Clone + Send + Sync + 'static>(
        &self,
        values: impl IntoIterator<Item = T>,
    ) -> HybridArray<T> {
        HybridArray::new(self, values)
    }

    /// Return the ids of all sources that require updating.
    pub fn get_updated_source_ids(&self) -> FxHashSet<SourceId> {
        // UNWRAP: panic on purpose
        let mut update_set = self.update_queue.write().unwrap();
        while let Ok(source_id) = self.notifier.1.try_recv() {
            update_set.insert(source_id);
        }
        update_set.clone()
    }

    /// Build the set of sources that require updates, draining the source
    /// notifier and resetting the stored `update_queue`.
    ///
    /// This also places recycled items into the recycle bin.
    fn drain_updated_sources(&self) -> RangeManager<SlabUpdate> {
        let update_set = self.get_updated_source_ids();
        // UNWRAP: panic on purpose
        *self.update_queue.write().unwrap() = Default::default();
        // Prepare all of our GPU buffer writes
        let mut writes = RangeManager::<SlabUpdate>::default();
        {
            // Recycle any update sources that are no longer needed, and collect the active
            // sources' updates into `writes`.
            let mut updates_guard = self.update_sources.write().unwrap();
            let mut recycles_guard = self.recycles.write().unwrap();
            for id in update_set {
                let delete = if let Some(gpu_ref) = updates_guard.get_mut(&id) {
                    let count = gpu_ref.weak.strong_count();
                    if count == 0 {
                        // recycle this allocation
                        let array = gpu_ref.u32_array;
                        log::debug!(
                            "{} drain_updated_sources: recycling {id} {array:?}",
                            self.label
                        );
                        if array.is_null() {
                            log::debug!("  cannot recycle, null");
                        } else if array.is_empty() {
                            log::debug!("  cannot recycle, empty");
                        } else {
                            recycles_guard.add_range(gpu_ref.u32_array.into());
                        }
                        true
                    } else {
                        gpu_ref.get_update().into_iter().flatten().for_each(|u| {
                            log::trace!("updating {id} {:?}", u.array);
                            writes.add_range(u)
                        });
                        false
                    }
                } else {
                    log::debug!("could not find {id}");
                    false
                };
                if delete {
                    let _ = updates_guard.remove(&id);
                }
            }
            // Defrag the recycle ranges
            let ranges = std::mem::take(&mut recycles_guard.ranges);
            let num_ranges_to_defrag = ranges.len();
            for range in ranges.into_iter() {
                recycles_guard.add_range(range);
            }
            let num_ranges = recycles_guard.ranges.len();
            if num_ranges < num_ranges_to_defrag {
                log::trace!("{num_ranges_to_defrag} ranges before, {num_ranges} after");
            }
        }

        writes
    }

    /// Returns whether any update sources, most likely from [`Hybrid`] or [`Gpu`](crate::value::Gpu) values,
    /// have queued updates waiting to be committed.
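    ///
    /// For illustration, a minimal sketch of one way this could be used, assuming `IsRuntime`
    /// and the allocator types are re-exported from the prelude:
    ///
    /// ```rust,no_run
    /// use craballoc::prelude::*;
    ///
    /// fn maybe_commit<R: IsRuntime>(slab: &SlabAllocator<R>) {
    ///     // Only commit when there is staged work waiting to be written.
    ///     if slab.has_queued_updates() {
    ///         slab.commit();
    ///     }
    /// }
    /// ```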
    pub fn has_queued_updates(&self) -> bool {
        !self.notifier.1.is_empty() || !self.update_queue.read().unwrap().is_empty()
    }

    /// Perform upkeep on the slab, synchronizing changes to the internal buffer.
    ///
    /// Changes made to [`Hybrid`] and [`Gpu`](crate::value::Gpu) values created by this slab are not committed
    /// until this function has been called.
    ///
    /// The internal buffer is not created until the first time this function is called.
    ///
    /// Returns a [`SlabBuffer`] wrapping the internal buffer that is currently in use by the allocator.
    pub fn commit(&self) -> SlabBuffer<R::Buffer> {
        let invocation_k = self
            .invocation_k
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            + 1;
        let buffer = if self.needs_expansion.swap(false, Ordering::Relaxed) {
            self.invalidation_k
                .store(invocation_k, std::sync::atomic::Ordering::Relaxed);
            self.recreate_buffer()
        } else {
            // UNWRAP: Safe because we know it exists or else it would need expansion
            self.get_buffer().unwrap()
        };
        let writes = self.drain_updated_sources();
        if !writes.is_empty() {
            self.runtime
                .buffer_write(writes.ranges.into_iter(), &buffer);
        }
        buffer
    }

    #[deprecated(since = "0.1.5", note = "please use `commit` instead")]
    pub fn upkeep(&self) -> SlabBuffer<R::Buffer> {
        self.commit()
    }

    /// Defragments the internal "recycle" buffer.
    pub fn defrag(&self) {
        // UNWRAP: panic on purpose
        let mut recycle_guard = self.recycles.write().unwrap();
        for range in std::mem::take(&mut recycle_guard.ranges) {
            recycle_guard.add_range(range);
        }
    }

    pub fn runtime(&self) -> &R {
        &self.runtime
    }
}