1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use hashbrown::HashMap;
use std::{num::NonZeroU64, sync::Arc};

/// Buffer storage for wgpu.
///
/// Owns every buffer it allocates, keyed by [`StorageId`]. Deallocation is deferred:
/// ids passed to `dealloc` are queued in `deallocations`, and the corresponding
/// buffers are only destroyed when [`WgpuStorage::perform_deallocations`] is called.
pub struct WgpuStorage {
    /// Live buffers, keyed by the id handed out by `alloc`.
    memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
    /// Ids queued by `dealloc`; drained by `perform_deallocations`.
    deallocations: Vec<StorageId>,
    /// Device on which `alloc` creates new buffers.
    device: Arc<wgpu::Device>,
}

impl core::fmt::Debug for WgpuStorage {
    /// Formats the storage by showing only the device it allocates on.
    ///
    /// The buffer map and the pending-deallocation queue are intentionally omitted.
    /// `debug_struct` writes straight into the formatter (no intermediate `String`)
    /// and honors formatter flags such as `{:#?}` for pretty printing, while
    /// producing the same `WgpuStorage { device: ... }` text for plain `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WgpuStorage")
            .field("device", &self.device)
            .finish()
    }
}

/// The memory resource that can be allocated for wgpu.
///
/// Pairs a shared handle to a wgpu buffer with a [`WgpuResourceKind`] describing
/// which region of that buffer the resource covers. `derive(new)` provides the
/// `WgpuResource::new(buffer, kind)` constructor (arguments in field order).
#[derive(new, Debug)]
pub struct WgpuResource {
    /// The wgpu buffer, shared via `Arc` with the owning storage.
    pub buffer: Arc<wgpu::Buffer>,
    /// How the resource is used: the whole buffer or a slice of it.
    pub kind: WgpuResourceKind,
}

impl WgpuResource {
    /// Builds the binding that exposes this resource to a shader.
    ///
    /// A [`WgpuResourceKind::Slice`] binds only its `(offset, size)` window, while
    /// [`WgpuResourceKind::Full`] binds the entire underlying buffer.
    pub fn as_binding(&self) -> wgpu::BindingResource {
        wgpu::BindingResource::Buffer(match self.kind {
            WgpuResourceKind::Slice(offset, size) => wgpu::BufferBinding {
                buffer: &self.buffer,
                offset,
                size: Some(size),
            },
            WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(),
        })
    }

    /// Size of the region this resource covers: the slice length for a slice,
    /// otherwise the size of the whole buffer.
    pub fn size(&self) -> u64 {
        match &self.kind {
            WgpuResourceKind::Slice(_, len) => len.get(),
            WgpuResourceKind::Full => self.buffer.size(),
        }
    }

    /// Starting offset of the region this resource covers; a full-buffer
    /// resource always starts at `0`.
    pub fn offset(&self) -> u64 {
        match &self.kind {
            WgpuResourceKind::Slice(start, _) => *start,
            WgpuResourceKind::Full => 0,
        }
    }
}

/// How the resource is used, either as a slice or fully.
#[derive(Debug)]
pub enum WgpuResourceKind {
    /// Represents an entire buffer: offset `0`, size equal to the buffer's size.
    Full,
    /// A slice over a buffer, given as `(offset, size)`.
    ///
    /// The size is a `wgpu::BufferSize` (a `NonZeroU64`), so an empty slice cannot
    /// be represented.
    Slice(wgpu::BufferAddress, wgpu::BufferSize),
}

/// Keeps actual wgpu buffer references in a hashmap with ids as key.
impl WgpuStorage {
    /// Builds an empty storage whose buffers will be created on `device`.
    pub fn new(device: Arc<wgpu::Device>) -> Self {
        let memory = HashMap::new();
        let deallocations = Vec::new();

        Self {
            memory,
            deallocations,
            device,
        }
    }

    /// Destroys every buffer whose id was queued for deallocation, emptying the queue.
    ///
    /// Ids that no longer map to a live buffer are silently skipped.
    pub fn perform_deallocations(&mut self) {
        // Borrow the map separately so it can be used while the queue is drained.
        let memory = &mut self.memory;

        self.deallocations
            .drain(..)
            .filter_map(|id| memory.remove(&id))
            .for_each(|buffer| buffer.destroy());
    }
}

impl ComputeStorage for WgpuStorage {
    type Resource = WgpuResource;

    /// Returns the resource (shared buffer plus the region in use) described by `handle`.
    ///
    /// # Panics
    ///
    /// Panics if `handle.id` has no live buffer in this storage (it was never
    /// allocated here, or was already destroyed by
    /// [`WgpuStorage::perform_deallocations`]), or if the handle describes a
    /// zero-sized slice, which a `wgpu::BufferSize` cannot represent.
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let buffer = self
            .memory
            .get(&handle.id)
            .expect("storage handle refers to an unknown or deallocated buffer");

        let kind = match handle.utilization {
            StorageUtilization::Full(_) => WgpuResourceKind::Full,
            StorageUtilization::Slice(offset, size) => {
                // usize -> u64 is lossless on every target wgpu supports.
                let size = NonZeroU64::new(size as u64)
                    .expect("slice storage handles must have a non-zero size");
                WgpuResourceKind::Slice(offset as u64, size)
            }
        };

        WgpuResource::new(Arc::clone(buffer), kind)
    }

    /// Creates a new wgpu buffer of `size` (bytes, per `wgpu::BufferDescriptor`)
    /// and returns a handle covering the whole buffer.
    ///
    /// The buffer is created unmapped and usable as a storage buffer and as a
    /// copy source/destination.
    fn alloc(&mut self, size: usize) -> StorageHandle {
        let id = StorageId::new();
        let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: size as u64,
            usage: wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        }));

        self.memory.insert(id.clone(), buffer);

        StorageHandle::new(id, StorageUtilization::Full(size))
    }

    /// Queues the buffer identified by `id` for destruction.
    ///
    /// Nothing is freed here: the buffer stays reachable through `get` until
    /// [`WgpuStorage::perform_deallocations`] is called.
    fn dealloc(&mut self, id: StorageId) {
        self.deallocations.push(id);
    }
}