use core::sync::atomic::{AtomicUsize, Ordering};
use crabslab::{Array, Id, SlabItem};
use rustc_hash::{FxHashMap, FxHashSet};
use snafu::prelude::*;
use std::{
    hash::Hash,
    ops::Deref,
    sync::{atomic::AtomicBool, Arc, RwLock},
};

use crate::{
    range::{Range, RangeManager},
    runtime::{IsRuntime, SlabUpdate},
    value::{Hybrid, HybridArray, WeakGpuRef},
};

#[cfg(feature = "wgpu")]
mod wgpu_slab;

/// Errors produced by [`SlabAllocator`] operations.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum SlabAllocatorError {
    #[snafu(display(
        "Slab has no internal buffer. Please call SlabAllocator::commit or \
         SlabAllocator::get_updated_buffer first."
    ))]
    NoInternalBuffer,

    #[snafu(display("Async recv error: {source}"))]
    AsyncRecv { source: async_channel::RecvError },

    #[cfg(feature = "wgpu")]
    #[snafu(display("Async error: {source}"))]
    Async { source: wgpu::BufferAsyncError },

    #[cfg(feature = "wgpu")]
    #[snafu(display("Poll error: {source}"))]
    Poll { source: wgpu::PollError },

    #[snafu(display("{source}"))]
    Other { source: Box<dyn std::error::Error> },
}

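/// A shared handle to the slab's underlying runtime buffer.
///
/// `SlabBuffer` is cheap to clone. A handle becomes invalid when the
/// allocator recreates its buffer (for example after growing), and can be
/// refreshed with [`SlabBuffer::update_if_invalid`].
///
/// A minimal usage sketch (not compiled as a doctest; the allocator setup is
/// assumed to exist elsewhere):
///
/// ```ignore
/// // Hold onto the buffer returned by `commit`, e.g. to build a bind group.
/// let mut buffer = slab.commit();
///
/// // ...more values are allocated, forcing the buffer to grow...
/// slab.commit();
///
/// // Refresh the handle; `true` means the underlying buffer was recreated
/// // and anything built from the old one (bind groups, etc.) must be rebuilt.
/// if buffer.update_if_invalid() {
///     // rebuild dependent resources here
/// }
/// ```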
pub struct SlabBuffer<T> {
    /// Total number of times `SlabAllocator::commit` has been invoked.
    slab_commit_invocation_k: Arc<AtomicUsize>,
    /// The commit count at which the allocator last recreated its buffer.
    slab_invalidation_k: Arc<AtomicUsize>,
    /// Value of the invalidation counter at the time this buffer was created.
    buffer_creation_k: usize,
    /// The underlying runtime buffer.
    buffer: Arc<T>,
    /// The allocator's most recently created buffer, shared so that stale
    /// handles can update themselves.
    source_slab_buffer: Arc<RwLock<Option<SlabBuffer<T>>>>,
}

impl<T> Clone for SlabBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            slab_commit_invocation_k: self.slab_commit_invocation_k.clone(),
            slab_invalidation_k: self.slab_invalidation_k.clone(),
            buffer_creation_k: self.buffer_creation_k,
            buffer: self.buffer.clone(),
            source_slab_buffer: self.source_slab_buffer.clone(),
        }
    }
}

impl<T> Deref for SlabBuffer<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.buffer
    }
}

impl<T> SlabBuffer<T> {
    fn new(
        invalidation_k: Arc<AtomicUsize>,
        invocation_k: Arc<AtomicUsize>,
        buffer: T,
        source_slab_buffer: Arc<RwLock<Option<SlabBuffer<T>>>>,
    ) -> Self {
        SlabBuffer {
            buffer: buffer.into(),
            buffer_creation_k: invalidation_k.load(std::sync::atomic::Ordering::Relaxed),
            slab_invalidation_k: invalidation_k,
            slab_commit_invocation_k: invocation_k,
            source_slab_buffer,
        }
    }

    pub(crate) fn invalidation_k(&self) -> usize {
        self.slab_invalidation_k
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    pub(crate) fn invocation_k(&self) -> usize {
        self.slab_commit_invocation_k
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Returns the "creation time" of this buffer, measured in commits of the
    /// source [`SlabAllocator`].
    pub fn creation_time(&self) -> usize {
        self.buffer_creation_k
    }

    /// Returns `true` if this handle is stale, meaning the allocator has since
    /// recreated its underlying buffer.
    pub fn is_invalid(&self) -> bool {
        self.creation_time() < self.invalidation_k()
    }

    /// Returns `true` if this handle still points at the allocator's current
    /// buffer.
    pub fn is_valid(&self) -> bool {
        !self.is_invalid()
    }

    /// Returns `true` if the underlying buffer was created by the most recent
    /// call to [`SlabAllocator::commit`].
    pub fn is_new_this_commit(&self) -> bool {
        self.invocation_k() == self.buffer_creation_k
    }

    #[deprecated(since = "0.1.5", note = "please use `is_new_this_commit` instead")]
    pub fn is_new_this_upkeep(&self) -> bool {
        self.is_new_this_commit()
    }

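    /// If this buffer handle is stale, replaces it with the allocator's
    /// current buffer.
    ///
    /// Returns `true` if the handle was replaced, in which case any resources
    /// derived from the old buffer should be rebuilt.
    ///
    /// A sketch of the intended pattern (not compiled as a doctest; `slab`,
    /// `buffer` and the bind-group rebuild are assumed to exist in the
    /// caller's code):
    ///
    /// ```ignore
    /// slab.commit();
    /// if buffer.update_if_invalid() {
    ///     bind_group = make_bind_group(&buffer);
    /// }
    /// ```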
    pub fn update_if_invalid(&mut self) -> bool {
        if self.is_invalid() {
            let updated_buffer = {
                let guard = self.source_slab_buffer.read().unwrap();
                guard.as_ref().unwrap().clone()
            };
            debug_assert!(updated_buffer.is_valid());
            *self = updated_buffer;
            true
        } else {
            false
        }
    }

    #[deprecated(since = "0.1.5", note = "please use `update_if_invalid` instead")]
    pub fn synchronize(&mut self) -> bool {
        self.update_if_invalid()
    }
}

/// Identifies a CPU-side source of slab updates.
///
/// Equality, ordering and hashing consider only `key`.
#[derive(Clone, Copy, Debug)]
pub struct SourceId {
    /// Unique key of the source value.
    pub key: usize,
    /// Name of the source's Rust type, used in display and log output.
    pub type_is: &'static str,
}

impl core::fmt::Display for SourceId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}({})", self.type_is, self.key))
    }
}

impl PartialEq for SourceId {
    fn eq(&self, other: &Self) -> bool {
        self.key.eq(&other.key)
    }
}

impl Eq for SourceId {}

impl PartialOrd for SourceId {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SourceId {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.key.cmp(&other.key)
    }
}

impl Hash for SourceId {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.key.hash(state)
    }
}

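/// Manages slab allocation and updating of a runtime buffer.
///
/// Values and arrays are staged with [`SlabAllocator::new_value`] and
/// [`SlabAllocator::new_array`]; CPU-side changes are written to the runtime
/// buffer when [`SlabAllocator::commit`] is called.
///
/// Cloning an allocator is cheap, and clones share the same underlying slab.
///
/// A sketch of typical use (not compiled as a doctest; the runtime value and
/// the buffer usages shown are assumptions that depend on which `IsRuntime`
/// implementation you use):
///
/// ```ignore
/// let slab = SlabAllocator::new(&runtime, "my-slab", wgpu::BufferUsages::STORAGE);
///
/// // Stage a value on the CPU; it is written to the GPU on `commit`.
/// let answer = slab.new_value(42u32);
///
/// // Write all staged updates and get the up-to-date buffer.
/// let buffer = slab.commit();
/// ```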
pub struct SlabAllocator<Runtime: IsRuntime> {
    /// Channel over which sources notify the allocator that they have updates.
    pub(crate) notifier: (
        async_channel::Sender<SourceId>,
        async_channel::Receiver<SourceId>,
    ),
    runtime: Runtime,
    label: Arc<String>,
    /// Length of the slab, in `u32` slots.
    len: Arc<AtomicUsize>,
    /// Capacity of the slab, in `u32` slots.
    capacity: Arc<AtomicUsize>,
    /// Whether the buffer must be recreated at the next commit.
    needs_expansion: Arc<AtomicBool>,
    /// The most recently created runtime buffer, if any.
    buffer: Arc<RwLock<Option<SlabBuffer<Runtime::Buffer>>>>,
    buffer_usages: Runtime::BufferUsages,
    /// Commit count at which the buffer was last recreated.
    invalidation_k: Arc<AtomicUsize>,
    /// Number of times `commit` has been invoked.
    invocation_k: Arc<AtomicUsize>,
    /// Monotonically increasing counter for new update sources.
    pub(crate) update_k: Arc<AtomicUsize>,
    /// All live update sources, keyed by id.
    pub(crate) update_sources: Arc<RwLock<FxHashMap<SourceId, WeakGpuRef>>>,
    /// Ids of sources with pending updates.
    update_queue: Arc<RwLock<FxHashSet<SourceId>>>,
    /// Ranges of the slab that have been dropped and may be reused.
    pub(crate) recycles: Arc<RwLock<RangeManager<Range>>>,
}

impl<R: IsRuntime> Clone for SlabAllocator<R> {
    fn clone(&self) -> Self {
        SlabAllocator {
            runtime: self.runtime.clone(),
            notifier: self.notifier.clone(),
            label: self.label.clone(),
            len: self.len.clone(),
            capacity: self.capacity.clone(),
            needs_expansion: self.needs_expansion.clone(),
            buffer: self.buffer.clone(),
            buffer_usages: self.buffer_usages.clone(),
            invalidation_k: self.invalidation_k.clone(),
            invocation_k: self.invocation_k.clone(),
            update_k: self.update_k.clone(),
            update_sources: self.update_sources.clone(),
            update_queue: self.update_queue.clone(),
            recycles: self.recycles.clone(),
        }
    }
}

impl<R: IsRuntime> SlabAllocator<R> {
    /// Creates a new slab allocator with the given label and default buffer
    /// usages, backed by the given runtime.
    pub fn new(
        runtime: impl AsRef<R>,
        name: impl AsRef<str>,
        default_buffer_usages: R::BufferUsages,
    ) -> Self {
        let label = Arc::new(name.as_ref().to_owned());
        Self {
            runtime: runtime.as_ref().clone(),
            label,
            notifier: async_channel::unbounded(),
            update_k: Default::default(),
            update_sources: Default::default(),
            update_queue: Default::default(),
            recycles: Default::default(),
            len: Default::default(),
            capacity: Arc::new(AtomicUsize::new(1)),
            needs_expansion: Arc::new(true.into()),
            buffer: Default::default(),
            buffer_usages: default_buffer_usages,
            invalidation_k: Default::default(),
            invocation_k: Default::default(),
        }
    }

    pub(crate) fn next_update_k(&self) -> usize {
        self.update_k.fetch_add(1, Ordering::Relaxed)
    }

    pub(crate) fn insert_update_source(&self, id: SourceId, source: WeakGpuRef) {
        log::trace!("{} insert_update_source {id}", self.label);
        let _ = self.notifier.0.try_send(id);
        self.update_sources.write().unwrap().insert(id, source);
    }

    /// Returns the length of the slab, in `u32` slots.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Returns `true` if nothing has been allocated on the slab.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Allocates an [`Id`] for a value of type `T`, reusing a recycled range
    /// of exactly the right size when one is available.
    pub(crate) fn allocate<T: SlabItem>(&self) -> Id<T> {
        let may_range = self.recycles.write().unwrap().remove(T::SLAB_SIZE as u32);
        if let Some(range) = may_range {
            let id = Id::<T>::new(range.first_index);
            log::trace!(
                "slab allocate {}: dequeued {range:?} to {id:?}",
                std::any::type_name::<T>()
            );
            debug_assert_eq!(
                range.last_index,
                range.first_index + T::SLAB_SIZE as u32 - 1
            );
            id
        } else {
            self.maybe_expand_to_fit::<T>(1);
            let index = self.increment_len(T::SLAB_SIZE);
            Id::from(index)
        }
    }

    /// Allocates a contiguous [`Array`] of `len` values of type `T`, reusing a
    /// recycled range of exactly the right size when one is available.
    pub(crate) fn allocate_array<T: SlabItem>(&self, len: usize) -> Array<T> {
        if len == 0 {
            return Array::default();
        }

        let may_range = self
            .recycles
            .write()
            .unwrap()
            .remove((T::SLAB_SIZE * len) as u32);
        if let Some(range) = may_range {
            let array = Array::<T>::new(range.first_index, len as u32);
            log::trace!(
                "slab allocate_array {len}x{}: dequeued {range:?} to {array:?}",
                std::any::type_name::<T>()
            );
            debug_assert_eq!(
                range.last_index,
                range.first_index + (T::SLAB_SIZE * len) as u32 - 1
            );
            array
        } else {
            self.maybe_expand_to_fit::<T>(len);
            let index = self.increment_len(T::SLAB_SIZE * len);
            Array::new(index as u32, len as u32)
        }
    }

    fn capacity(&self) -> usize {
        self.capacity.load(Ordering::Relaxed)
    }

    fn reserve_capacity(&self, capacity: usize) {
        self.capacity.store(capacity, Ordering::Relaxed);
        self.needs_expansion.store(true, Ordering::Relaxed);
    }

    fn increment_len(&self, n: usize) -> usize {
        self.len.fetch_add(n, Ordering::Relaxed)
    }

    /// Ensures there is capacity for `len` more values of type `T`, doubling
    /// the capacity until it fits and marking the slab for expansion.
    fn maybe_expand_to_fit<T: SlabItem>(&self, len: usize) {
        let capacity = self.capacity();
        let capacity_needed = self.len() + T::SLAB_SIZE * len;
        if capacity_needed > capacity {
            let mut new_capacity = capacity * 2;
            while new_capacity < capacity_needed {
                new_capacity = (new_capacity * 2).max(2);
            }
            self.reserve_capacity(new_capacity);
        }
    }

    /// Returns the current buffer, if one has been created by a previous call
    /// to [`SlabAllocator::commit`].
    pub fn get_buffer(&self) -> Option<SlabBuffer<R::Buffer>> {
        self.buffer.read().unwrap().clone()
    }

    /// Creates a new buffer with the current capacity, copies the contents of
    /// the old buffer into it (if there was one), and installs it as the
    /// allocator's current buffer.
    fn recreate_buffer(&self) -> SlabBuffer<R::Buffer> {
        let new_buffer = self.runtime.buffer_create(
            self.capacity(),
            Some(self.label.as_ref()),
            self.buffer_usages.clone(),
        );
        let mut guard = self.buffer.write().unwrap();
        if let Some(old_buffer) = guard.take() {
            self.runtime
                .buffer_copy(&old_buffer, &new_buffer, Some(self.label.as_ref()));
        }
        let slab_buffer = SlabBuffer::new(
            self.invalidation_k.clone(),
            self.invocation_k.clone(),
            new_buffer,
            self.buffer.clone(),
        );
        *guard = Some(slab_buffer.clone());
        slab_buffer
    }

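    /// Stages a new value on the slab, returning a [`Hybrid`] that keeps a
    /// CPU-side copy and whose updates are written to the runtime buffer on
    /// the next [`SlabAllocator::commit`].
    ///
    /// A small sketch (not compiled as a doctest; `slab` is assumed to be an
    /// existing allocator):
    ///
    /// ```ignore
    /// let answer = slab.new_value(42u32);
    /// let _buffer = slab.commit();
    /// ```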
    pub fn new_value<T: SlabItem + Clone + Send + Sync + 'static>(&self, value: T) -> Hybrid<T> {
        Hybrid::new(self, value)
    }

    /// Stages a new array of values on the slab, returning a [`HybridArray`]
    /// that keeps a CPU-side copy and whose updates are written to the runtime
    /// buffer on the next [`SlabAllocator::commit`].
    pub fn new_array<T: SlabItem + Clone + Send + Sync + 'static>(
        &self,
        values: impl IntoIterator<Item = T>,
    ) -> HybridArray<T> {
        HybridArray::new(self, values)
    }

    /// Returns the ids of all sources that have updates queued, draining the
    /// notifier channel into the update queue.
    pub fn get_updated_source_ids(&self) -> FxHashSet<SourceId> {
        let mut update_set = self.update_queue.write().unwrap();
        while let Ok(source_id) = self.notifier.1.try_recv() {
            update_set.insert(source_id);
        }
        update_set.clone()
    }

    /// Builds the set of writes to apply to the runtime buffer, recycling the
    /// ranges of any sources that have been dropped.
    fn drain_updated_sources(&self) -> RangeManager<SlabUpdate> {
        let update_set = self.get_updated_source_ids();
        *self.update_queue.write().unwrap() = Default::default();
        let mut writes = RangeManager::<SlabUpdate>::default();
        {
            let mut updates_guard = self.update_sources.write().unwrap();
            let mut recycles_guard = self.recycles.write().unwrap();
            for id in update_set {
                let delete = if let Some(gpu_ref) = updates_guard.get_mut(&id) {
                    let count = gpu_ref.weak.strong_count();
                    if count == 0 {
                        // The source has been dropped; recycle its range.
                        let array = gpu_ref.u32_array;
                        log::debug!(
                            "{} drain_updated_sources: recycling {id} {array:?}",
                            self.label
                        );
                        if array.is_null() {
                            log::debug!(" cannot recycle, null");
                        } else if array.is_empty() {
                            log::debug!(" cannot recycle, empty");
                        } else {
                            recycles_guard.add_range(gpu_ref.u32_array.into());
                        }
                        true
                    } else {
                        gpu_ref.get_update().into_iter().flatten().for_each(|u| {
                            log::trace!("updating {id} {:?}", u.array);
                            writes.add_range(u)
                        });
                        false
                    }
                } else {
                    log::debug!("could not find {id}");
                    false
                };
                if delete {
                    let _ = updates_guard.remove(&id);
                }
            }
            // Re-add all recycled ranges so contiguous ranges can coalesce.
            let ranges = std::mem::take(&mut recycles_guard.ranges);
            let num_ranges_to_defrag = ranges.len();
            for range in ranges.into_iter() {
                recycles_guard.add_range(range);
            }
            let num_ranges = recycles_guard.ranges.len();
            if num_ranges < num_ranges_to_defrag {
                log::trace!("{num_ranges_to_defrag} ranges before, {num_ranges} after");
            }
        }

        writes
    }

    /// Returns `true` if there are updates waiting to be committed.
    pub fn has_queued_updates(&self) -> bool {
        !self.notifier.1.is_empty() || !self.update_queue.read().unwrap().is_empty()
    }

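    /// Commits all staged changes: recreates the buffer if the slab needs to
    /// grow, then writes every queued update into the runtime buffer.
    ///
    /// Returns the current buffer. Any [`SlabBuffer`] handles created before a
    /// resize become invalid and can be refreshed with
    /// [`SlabBuffer::update_if_invalid`].
    ///
    /// A sketch of a typical frame loop (not compiled as a doctest; the render
    /// step is assumed to exist in the caller's code):
    ///
    /// ```ignore
    /// if slab.has_queued_updates() {
    ///     let buffer = slab.commit();
    ///     render_with(&buffer);
    /// }
    /// ```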
    pub fn commit(&self) -> SlabBuffer<R::Buffer> {
        let invocation_k = self
            .invocation_k
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            + 1;
        let buffer = if self.needs_expansion.swap(false, Ordering::Relaxed) {
            self.invalidation_k
                .store(invocation_k, std::sync::atomic::Ordering::Relaxed);
            self.recreate_buffer()
        } else {
            self.get_buffer().unwrap()
        };
        let writes = self.drain_updated_sources();
        if !writes.is_empty() {
            self.runtime
                .buffer_write(writes.ranges.into_iter(), &buffer);
        }
        buffer
    }

    #[deprecated(since = "0.1.5", note = "please use `commit` instead")]
    pub fn upkeep(&self) -> SlabBuffer<R::Buffer> {
        self.commit()
    }

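    /// Defragments the recycled-range list by re-adding every range to the
    /// [`RangeManager`], giving it a chance to merge contiguous ranges.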
    pub fn defrag(&self) {
        let mut recycle_guard = self.recycles.write().unwrap();
        for range in std::mem::take(&mut recycle_guard.ranges) {
            recycle_guard.add_range(range);
        }
    }

    /// Returns a reference to the underlying runtime.
    pub fn runtime(&self) -> &R {
        &self.runtime
    }
}