use encase::{
    self,
    private::{AlignmentValue, Writer},
};
use wgpu;

#[cfg(feature = "burn-torch")]
use crate::error::CudaInteropError;
#[cfg(feature = "burn-torch")]
use cust_raw;

#[cfg(feature = "burn-torch")]
use std::sync::Arc;
#[cfg(feature = "burn-torch")]
use tch::Tensor;
#[cfg(feature = "burn-torch")]
use wgpu_cuda_interop::{vulkan_wgpu_interop::WgpuBufferCudaMem, AllocSize};

#[cfg(feature = "burn-torch")]
use log::debug;
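/// A `wgpu::Buffer` paired with a CPU-side byte buffer used to stage
/// `encase`-encoded chunks before they are uploaded to the GPU.
/// With the `burn-torch` feature enabled, the buffer can also be filled
/// directly from a CUDA-resident `tch::Tensor` through an interop staging buffer.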
#[derive(Clone)]
pub struct Buffer {
    pub buffer: wgpu::Buffer,
    pub size_bytes: usize,
    cpu_byte_buffer: Vec<u8>,
    offset: usize,
    alignment: AlignmentValue,

    #[cfg(feature = "burn-torch")]
    pub staging_buffer_backed_by_cuda_mem: Option<Arc<WgpuBufferCudaMem>>,
}

impl Buffer {
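    /// Creates a GPU buffer of `size_bytes` bytes with the given usage flags and an
    /// empty CPU-side chunk buffer. Chunk offsets are rounded up to 256 bytes, the
    /// default WebGPU minimum uniform/storage buffer offset alignment.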
    pub fn new_empty(device: &wgpu::Device, usage: wgpu::BufferUsages, label: wgpu::Label, size_bytes: usize) -> Self {
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size: size_bytes as u64,
            usage,
            mapped_at_creation: false,
        });

        let cpu_byte_buffer = Vec::new();

        Self {
            buffer,
            size_bytes,
            cpu_byte_buffer,
            offset: 0,
            alignment: AlignmentValue::new(256),
            #[cfg(feature = "burn-torch")]
            staging_buffer_backed_by_cuda_mem: None,
        }
    }

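    /// Wraps an existing `wgpu::Buffer`, taking its size from the buffer itself.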
    pub fn new_from_buffer(buffer: wgpu::Buffer) -> Self {
        let size_bytes = usize::try_from(buffer.size()).unwrap();
        let cpu_byte_buffer = Vec::new();

        Self {
            buffer,
            size_bytes,
            cpu_byte_buffer,
            offset: 0,
            alignment: AlignmentValue::new(256),
            #[cfg(feature = "burn-torch")]
            staging_buffer_backed_by_cuda_mem: None,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.size_bytes == 0
    }

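    /// Writes `data` into the GPU buffer at `offset_bytes` via the queue.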
    pub fn write_buffer(&mut self, queue: &wgpu::Queue, offset_bytes: usize, data: &[u8]) {
        queue.write_buffer(&self.buffer, offset_bytes as u64, data);
    }

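    /// Encodes `chunk` into the CPU-side byte buffer at the current offset and
    /// advances the offset to the next 256-byte boundary (the default alignment
    /// required for dynamic bind group offsets). Returns the byte offset at which
    /// the chunk was written.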
    pub fn push_cpu_chunk_aligned<T: encase::ShaderType + encase::internal::WriteInto>(&mut self, chunk: &T) -> u32 {
        let offset = self.offset;
        let mut writer = Writer::new(chunk, &mut self.cpu_byte_buffer, offset).unwrap();
        chunk.write_into(&mut writer);
        self.offset += usize::try_from(self.alignment.round_up(chunk.size().get())).unwrap();
        u32::try_from(offset).unwrap()
    }

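    /// Encodes `chunk` into the CPU-side byte buffer at the current offset and
    /// advances the offset by the chunk's exact size, without alignment padding.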
    pub fn push_cpu_chunk_packed<T: encase::ShaderType + encase::internal::WriteInto>(&mut self, chunk: &T) {
        let offset = self.offset;
        let mut writer = Writer::new(chunk, &mut self.cpu_byte_buffer, offset).unwrap();
        chunk.write_into(&mut writer);
        self.offset += usize::try_from(chunk.size().get()).unwrap();
    }

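    /// Uploads the entire CPU-side chunk buffer to the GPU buffer at offset 0.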
    pub fn upload_from_cpu_chunks(&mut self, queue: &wgpu::Queue) {
        queue.write_buffer(&self.buffer, 0, self.cpu_byte_buffer.as_slice());
    }

    pub fn reset_chunks_offset(&mut self) {
        self.offset = 0;
    }

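    /// Resets the chunk offset to 0 once it has grown past half of the buffer size,
    /// so that subsequent pushes do not run past the end of the buffer.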
    pub fn reset_chunks_offset_if_necessary(&mut self) {
        if self.offset > self.size_bytes / 2 {
            self.offset = 0;
        }
    }

    pub fn offset(&self) -> usize {
        self.offset
    }

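    /// Creates a new [`Buffer`] and fills it with the contents of a CUDA-resident
    /// `tch::Tensor` via [`Buffer::copy_from_tensor`].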
    #[cfg(feature = "burn-torch")]
    pub fn new_from_tensor(
        tensor: &Tensor,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        adapter: &wgpu::Adapter,
        usage: wgpu::BufferUsages,
        label: wgpu::Label,
    ) -> Self {
        // Start with a small placeholder buffer; copy_from_tensor resizes it to
        // match the tensor's byte size.
        let mut buffer = Self::new_empty(device, usage, label, 4);

        buffer
            .copy_from_tensor(tensor, device, queue, adapter)
            .expect("Failed to copy from tensor");

        buffer
    }

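    /// Copies a CUDA-resident `tch::Tensor` into this buffer without a CPU round trip,
    /// going through a CUDA/wgpu interop staging buffer. The tensor must be contiguous,
    /// of kind `Float` or `Int`, and have a batch dimension of exactly 1; the staging
    /// buffer and the wgpu buffer are (re)created whenever the required size changes.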
    #[cfg(feature = "burn-torch")]
    pub fn copy_from_tensor(
        &mut self,
        tensor: &Tensor,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        adapter: &wgpu::Adapter,
    ) -> Result<(), CudaInteropError> {
        // Only single-batch tensors are supported.
        if tensor.size().is_empty() || tensor.size()[0] != 1 {
            return Err(CudaInteropError::InvalidBatchSize(tensor.size().get(0).copied().unwrap_or(0) as usize));
        }

        let shape = tensor.size();
        let elem_size = match tensor.kind() {
            tch::Kind::Float => std::mem::size_of::<f32>(),
            tch::Kind::Int => std::mem::size_of::<i32>(),
            _ => {
                return Err(CudaInteropError::InvalidTensorType(tensor.kind()));
            }
        };
        // Total number of elements, excluding the leading batch dimension.
        let num_elements: i64 = shape.iter().skip(1).product();
        let num_elements = usize::try_from(num_elements).unwrap();
        let buf_size = AllocSize {
            height: 1,
            width: 1,
            stride: num_elements * elem_size,
        };

        if !tensor.is_contiguous() {
            return Err(CudaInteropError::InvalidNonContiguous);
        }

        // (Re)create the CUDA-backed staging buffer if it is missing or its size no longer matches.
        if self.staging_buffer_backed_by_cuda_mem.is_none()
            || self.staging_buffer_backed_by_cuda_mem.as_ref().unwrap().cuda_mem.alloc_size != buf_size
        {
            debug!("creating staging_buffer_backed_by_cuda_mem because it is missing or its size differs");
            let wgpu_cuda = wgpu_cuda_interop::interop::create_wgpu_cuda_buffer(device, adapter, buf_size, wgpu::BufferUsages::COPY_SRC);
            self.staging_buffer_backed_by_cuda_mem = Some(Arc::new(wgpu_cuda));
        }

        // (Re)create the destination wgpu buffer if its size no longer matches the tensor.
        if self.size_bytes != buf_size.stride {
            debug!("recreating the wgpu buffer because its size differs");
            self.size_bytes = buf_size.stride;
            self.buffer = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Buffer::from_tensor wgpu buffer"),
                size: self.size_bytes as u64,
                usage: self.buffer.usage(),
                mapped_at_creation: false,
            });
        }

        // Device-to-device copy: tensor memory -> CUDA/wgpu staging buffer -> destination buffer.
        let source_ptr = tensor.data_ptr() as cust_raw::CUdeviceptr;
        if let Some(staging_buffer) = self.staging_buffer_backed_by_cuda_mem.as_ref() {
            wgpu_cuda_interop::interop::cuda_buffer_to_wgpu(source_ptr, buf_size, staging_buffer, &self.buffer, device, queue);
        }

        Ok(())
    }
}