1use super::*;
4use crate::*; use std::marker::PhantomData;
9use std::ptr;
10use std::slice;
11
12#[derive(Debug, Clone, Copy)]
14pub struct BufferUsage {
15 flags: VkBufferUsageFlags,
16}
17
18impl BufferUsage {
19 pub const STORAGE: Self = Self { flags: VkBufferUsageFlags::STORAGE_BUFFER };
20 pub const TRANSFER_SRC: Self = Self { flags: VkBufferUsageFlags::TRANSFER_SRC };
21 pub const TRANSFER_DST: Self = Self { flags: VkBufferUsageFlags::TRANSFER_DST };
22
23 pub fn storage() -> Self {
24 Self::STORAGE
25 }
26
27 pub fn transfer_src() -> Self {
28 Self::TRANSFER_SRC
29 }
30
31 pub fn transfer_dst() -> Self {
32 Self::TRANSFER_DST
33 }
34}
35
36impl std::ops::BitOr for BufferUsage {
37 type Output = Self;
38
39 fn bitor(self, rhs: Self) -> Self::Output {
40 Self {
41 flags: VkBufferUsageFlags::from_bits_truncate(self.flags.bits() | rhs.flags.bits())
42 }
43 }
44}
45
46pub struct Buffer {
51 pub(super) context: ComputeContext,
52 pub(super) buffer: VkBuffer,
53 pub(super) memory: VkDeviceMemory,
54 pub(super) size: usize,
55 pub(super) usage: BufferUsage,
56 pub(super) _marker: PhantomData<*const u8>,
57}
58
59unsafe impl Send for Buffer {}
61unsafe impl Sync for Buffer {}
62
63impl Buffer {
64 pub fn size(&self) -> usize {
66 self.size
67 }
68
69 pub fn usage(&self) -> BufferUsage {
71 self.usage
72 }
73
74 pub fn raw(&self) -> VkBuffer {
76 self.buffer
77 }
78}
79
80impl ComputeContext {
81 pub fn create_buffer<T>(&self, data: &[T]) -> Result<Buffer>
83 where
84 T: Copy + 'static,
85 {
86 let size = std::mem::size_of_val(data);
87 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;
88
89 unsafe {
90 let buffer = self.create_buffer_raw(size, usage)?;
92
93 let staging_usage = BufferUsage::TRANSFER_SRC;
95 let staging = self.create_buffer_raw(size, staging_usage)?;
96
97 self.with_inner(|inner| {
99 let mut mapped_ptr = ptr::null_mut();
100 let result = vkMapMemory(
101 inner.device,
102 staging.memory,
103 0,
104 size as VkDeviceSize,
105 0,
106 &mut mapped_ptr,
107 );
108
109 if result != VkResult::Success {
110 return Err(KronosError::from(result));
111 }
112
113 ptr::copy_nonoverlapping(
114 data.as_ptr() as *const u8,
115 mapped_ptr as *mut u8,
116 size,
117 );
118
119 vkUnmapMemory(inner.device, staging.memory);
120 Ok(())
121 })?;
122
123 self.copy_buffer(&staging, &buffer, size)?;
125
126 Ok(buffer)
128 }
129 }
130
131 pub fn create_buffer_uninit(&self, size: usize) -> Result<Buffer> {
133 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST | BufferUsage::TRANSFER_SRC;
134 unsafe { self.create_buffer_raw(size, usage) }
135 }
136
137 unsafe fn create_buffer_raw(&self, size: usize, usage: BufferUsage) -> Result<Buffer> {
148 self.with_inner(|inner| {
149 let buffer_info = VkBufferCreateInfo {
151 sType: VkStructureType::BufferCreateInfo,
152 pNext: ptr::null(),
153 flags: VkBufferCreateFlags::empty(),
154 size: size as VkDeviceSize,
155 usage: usage.flags,
156 sharingMode: VkSharingMode::Exclusive,
157 queueFamilyIndexCount: 0,
158 pQueueFamilyIndices: ptr::null(),
159 };
160
161 let mut buffer = VkBuffer::NULL;
162 log::debug!("API layer calling vkCreateBuffer for device {:?}", inner.device);
163 let result = vkCreateBuffer(inner.device, &buffer_info, ptr::null(), &mut buffer);
164
165 if result != VkResult::Success {
166 return Err(KronosError::BufferCreationFailed(format!("vkCreateBuffer failed: {:?}", result)));
167 }
168
169 let mut mem_requirements = VkMemoryRequirements::default();
171 vkGetBufferMemoryRequirements(inner.device, buffer, &mut mem_requirements);
172
173 let memory_type_index = Self::find_memory_type(
175 &inner.memory_properties,
176 mem_requirements.memoryTypeBits,
177 if usage.flags.contains(VkBufferUsageFlags::TRANSFER_SRC) {
178 VkMemoryPropertyFlags::HOST_VISIBLE | VkMemoryPropertyFlags::HOST_COHERENT
179 } else {
180 VkMemoryPropertyFlags::DEVICE_LOCAL
181 },
182 )?;
183
184 let alloc_info = VkMemoryAllocateInfo {
186 sType: VkStructureType::MemoryAllocateInfo,
187 pNext: ptr::null(),
188 allocationSize: mem_requirements.size,
189 memoryTypeIndex: memory_type_index,
190 };
191
192 let mut memory = VkDeviceMemory::NULL;
193 let result = vkAllocateMemory(inner.device, &alloc_info, ptr::null(), &mut memory);
194
195 if result != VkResult::Success {
196 vkDestroyBuffer(inner.device, buffer, ptr::null());
197 return Err(KronosError::BufferCreationFailed(format!("vkAllocateMemory failed: {:?}", result)));
198 }
199
200 let result = vkBindBufferMemory(inner.device, buffer, memory, 0);
202
203 if result != VkResult::Success {
204 vkFreeMemory(inner.device, memory, ptr::null());
205 vkDestroyBuffer(inner.device, buffer, ptr::null());
206 return Err(KronosError::BufferCreationFailed(format!("vkBindBufferMemory failed: {:?}", result)));
207 }
208
209 Ok(Buffer {
210 context: self.clone(),
211 buffer,
212 memory,
213 size,
214 usage,
215 _marker: std::marker::PhantomData,
216 })
217 })
218 }
219
220 fn find_memory_type(
222 memory_properties: &VkPhysicalDeviceMemoryProperties,
223 type_filter: u32,
224 properties: VkMemoryPropertyFlags,
225 ) -> Result<u32> {
226 for i in 0..memory_properties.memoryTypeCount {
227 if (type_filter & (1 << i)) != 0
228 && memory_properties.memoryTypes[i as usize].propertyFlags.contains(properties) {
229 return Ok(i);
230 }
231 }
232
233 Err(KronosError::BufferCreationFailed("No suitable memory type found".into()))
234 }
235
236 unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer, size: usize) -> Result<()> {
248 self.with_inner(|inner| {
249 let alloc_info = VkCommandBufferAllocateInfo {
251 sType: VkStructureType::CommandBufferAllocateInfo,
252 pNext: ptr::null(),
253 commandPool: inner.command_pool,
254 level: VkCommandBufferLevel::Primary,
255 commandBufferCount: 1,
256 };
257
258 let mut command_buffer = VkCommandBuffer::NULL;
259 vkAllocateCommandBuffers(inner.device, &alloc_info, &mut command_buffer);
260
261 let begin_info = VkCommandBufferBeginInfo {
263 sType: VkStructureType::CommandBufferBeginInfo,
264 pNext: ptr::null(),
265 flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
266 pInheritanceInfo: ptr::null(),
267 };
268
269 vkBeginCommandBuffer(command_buffer, &begin_info);
270
271 let region = VkBufferCopy {
273 srcOffset: 0,
274 dstOffset: 0,
275 size: size as VkDeviceSize,
276 };
277
278 vkCmdCopyBuffer(command_buffer, src.buffer, dst.buffer, 1, ®ion);
279
280 vkEndCommandBuffer(command_buffer);
282
283 let submit_info = VkSubmitInfo {
285 sType: VkStructureType::SubmitInfo,
286 pNext: ptr::null(),
287 waitSemaphoreCount: 0,
288 pWaitSemaphores: ptr::null(),
289 pWaitDstStageMask: ptr::null(),
290 commandBufferCount: 1,
291 pCommandBuffers: &command_buffer,
292 signalSemaphoreCount: 0,
293 pSignalSemaphores: ptr::null(),
294 };
295
296 let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
297 if result != VkResult::Success {
298 return Err(KronosError::from(result));
299 }
300
301 vkQueueWaitIdle(inner.queue);
303
304 vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &command_buffer);
306
307 Ok(())
308 })
309 }
310}
311
312impl Buffer {
313 pub fn read<T>(&self) -> Result<Vec<T>>
315 where
316 T: Copy + 'static,
317 {
318 let element_size = std::mem::size_of::<T>();
319 let element_count = self.size / element_size;
320
321 if self.size % element_size != 0 {
322 return Err(KronosError::BufferCreationFailed(
323 format!("Buffer size {} is not a multiple of element size {}", self.size, element_size)
324 ));
325 }
326
327 unsafe {
328 let staging = self.context.create_buffer_uninit(self.size)?;
330
331 self.context.copy_buffer(self, &staging, self.size)?;
333
334 self.context.with_inner(|inner| {
336 let mut mapped_ptr = ptr::null_mut();
337 let result = vkMapMemory(
338 inner.device,
339 staging.memory,
340 0,
341 self.size as VkDeviceSize,
342 0,
343 &mut mapped_ptr,
344 );
345
346 if result != VkResult::Success {
347 return Err(KronosError::from(result));
348 }
349
350 let slice = slice::from_raw_parts(mapped_ptr as *const T, element_count);
351 let vec = slice.to_vec();
352
353 vkUnmapMemory(inner.device, staging.memory);
354
355 Ok(vec)
356 })
357 }
358 }
359}
360
361impl Drop for Buffer {
362 fn drop(&mut self) {
363 unsafe {
364 self.context.with_inner(|inner| {
365 vkFreeMemory(inner.device, self.memory, ptr::null());
366 vkDestroyBuffer(inner.device, self.buffer, ptr::null());
367 });
368 }
369 }
370}
371