1use super::*;
4use crate::*; use std::marker::PhantomData;
6use std::ptr;
7use std::slice;
8
9#[derive(Debug, Clone, Copy)]
11pub struct BufferUsage {
12 flags: VkBufferUsageFlags,
13}
14
15impl BufferUsage {
16 pub const STORAGE: Self = Self { flags: VkBufferUsageFlags::STORAGE_BUFFER };
17 pub const TRANSFER_SRC: Self = Self { flags: VkBufferUsageFlags::TRANSFER_SRC };
18 pub const TRANSFER_DST: Self = Self { flags: VkBufferUsageFlags::TRANSFER_DST };
19
20 pub fn storage() -> Self {
21 Self::STORAGE
22 }
23
24 pub fn transfer_src() -> Self {
25 Self::TRANSFER_SRC
26 }
27
28 pub fn transfer_dst() -> Self {
29 Self::TRANSFER_DST
30 }
31}
32
33impl std::ops::BitOr for BufferUsage {
34 type Output = Self;
35
36 fn bitor(self, rhs: Self) -> Self::Output {
37 Self {
38 flags: VkBufferUsageFlags::from_bits_truncate(self.flags.bits() | rhs.flags.bits())
39 }
40 }
41}
42
43pub struct Buffer {
48 pub(super) context: ComputeContext,
49 pub(super) buffer: VkBuffer,
50 pub(super) memory: VkDeviceMemory,
51 pub(super) size: usize,
52 pub(super) usage: BufferUsage,
53 pub(super) _marker: PhantomData<*const u8>,
54}
55
56unsafe impl Send for Buffer {}
58unsafe impl Sync for Buffer {}
59
60impl Buffer {
61 pub fn size(&self) -> usize {
63 self.size
64 }
65
66 pub fn usage(&self) -> BufferUsage {
68 self.usage
69 }
70
71 pub fn raw(&self) -> VkBuffer {
73 self.buffer
74 }
75}
76
77impl ComputeContext {
78 pub fn create_buffer<T>(&self, data: &[T]) -> Result<Buffer>
80 where
81 T: Copy + 'static,
82 {
83 let size = std::mem::size_of_val(data);
84 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;
85
86 unsafe {
87 let buffer = self.create_buffer_raw(size, usage)?;
89
90 let staging_usage = BufferUsage::TRANSFER_SRC;
92 let staging = self.create_buffer_raw(size, staging_usage)?;
93
94 self.with_inner(|inner| {
96 let mut mapped_ptr = ptr::null_mut();
97 let result = vkMapMemory(
98 inner.device,
99 staging.memory,
100 0,
101 size as VkDeviceSize,
102 0,
103 &mut mapped_ptr,
104 );
105
106 if result != VkResult::Success {
107 return Err(KronosError::from(result));
108 }
109
110 ptr::copy_nonoverlapping(
111 data.as_ptr() as *const u8,
112 mapped_ptr as *mut u8,
113 size,
114 );
115
116 vkUnmapMemory(inner.device, staging.memory);
117 Ok(())
118 })?;
119
120 self.copy_buffer(&staging, &buffer, size)?;
122
123 Ok(buffer)
125 }
126 }
127
128 pub fn create_buffer_uninit(&self, size: usize) -> Result<Buffer> {
130 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST | BufferUsage::TRANSFER_SRC;
131 unsafe { self.create_buffer_raw(size, usage) }
132 }
133
134 unsafe fn create_buffer_raw(&self, size: usize, usage: BufferUsage) -> Result<Buffer> {
145 self.with_inner(|inner| {
146 let buffer_info = VkBufferCreateInfo {
148 sType: VkStructureType::BufferCreateInfo,
149 pNext: ptr::null(),
150 flags: VkBufferCreateFlags::empty(),
151 size: size as VkDeviceSize,
152 usage: usage.flags,
153 sharingMode: VkSharingMode::Exclusive,
154 queueFamilyIndexCount: 0,
155 pQueueFamilyIndices: ptr::null(),
156 };
157
158 let mut buffer = VkBuffer::NULL;
159 let result = vkCreateBuffer(inner.device, &buffer_info, ptr::null(), &mut buffer);
160
161 if result != VkResult::Success {
162 return Err(KronosError::BufferCreationFailed(format!("vkCreateBuffer failed: {:?}", result)));
163 }
164
165 let mut mem_requirements = VkMemoryRequirements::default();
167 vkGetBufferMemoryRequirements(inner.device, buffer, &mut mem_requirements);
168
169 let memory_type_index = Self::find_memory_type(
171 &inner.memory_properties,
172 mem_requirements.memoryTypeBits,
173 if usage.flags.contains(VkBufferUsageFlags::TRANSFER_SRC) {
174 VkMemoryPropertyFlags::HOST_VISIBLE | VkMemoryPropertyFlags::HOST_COHERENT
175 } else {
176 VkMemoryPropertyFlags::DEVICE_LOCAL
177 },
178 )?;
179
180 let alloc_info = VkMemoryAllocateInfo {
182 sType: VkStructureType::MemoryAllocateInfo,
183 pNext: ptr::null(),
184 allocationSize: mem_requirements.size,
185 memoryTypeIndex: memory_type_index,
186 };
187
188 let mut memory = VkDeviceMemory::NULL;
189 let result = vkAllocateMemory(inner.device, &alloc_info, ptr::null(), &mut memory);
190
191 if result != VkResult::Success {
192 vkDestroyBuffer(inner.device, buffer, ptr::null());
193 return Err(KronosError::BufferCreationFailed(format!("vkAllocateMemory failed: {:?}", result)));
194 }
195
196 let result = vkBindBufferMemory(inner.device, buffer, memory, 0);
198
199 if result != VkResult::Success {
200 vkFreeMemory(inner.device, memory, ptr::null());
201 vkDestroyBuffer(inner.device, buffer, ptr::null());
202 return Err(KronosError::BufferCreationFailed(format!("vkBindBufferMemory failed: {:?}", result)));
203 }
204
205 Ok(Buffer {
206 context: self.clone(),
207 buffer,
208 memory,
209 size,
210 usage,
211 _marker: std::marker::PhantomData,
212 })
213 })
214 }
215
216 fn find_memory_type(
218 memory_properties: &VkPhysicalDeviceMemoryProperties,
219 type_filter: u32,
220 properties: VkMemoryPropertyFlags,
221 ) -> Result<u32> {
222 for i in 0..memory_properties.memoryTypeCount {
223 if (type_filter & (1 << i)) != 0
224 && memory_properties.memoryTypes[i as usize].propertyFlags.contains(properties) {
225 return Ok(i);
226 }
227 }
228
229 Err(KronosError::BufferCreationFailed("No suitable memory type found".into()))
230 }
231
232 unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer, size: usize) -> Result<()> {
244 self.with_inner(|inner| {
245 let alloc_info = VkCommandBufferAllocateInfo {
247 sType: VkStructureType::CommandBufferAllocateInfo,
248 pNext: ptr::null(),
249 commandPool: inner.command_pool,
250 level: VkCommandBufferLevel::Primary,
251 commandBufferCount: 1,
252 };
253
254 let mut command_buffer = VkCommandBuffer::NULL;
255 vkAllocateCommandBuffers(inner.device, &alloc_info, &mut command_buffer);
256
257 let begin_info = VkCommandBufferBeginInfo {
259 sType: VkStructureType::CommandBufferBeginInfo,
260 pNext: ptr::null(),
261 flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
262 pInheritanceInfo: ptr::null(),
263 };
264
265 vkBeginCommandBuffer(command_buffer, &begin_info);
266
267 let region = VkBufferCopy {
269 srcOffset: 0,
270 dstOffset: 0,
271 size: size as VkDeviceSize,
272 };
273
274 vkCmdCopyBuffer(command_buffer, src.buffer, dst.buffer, 1, ®ion);
275
276 vkEndCommandBuffer(command_buffer);
278
279 let submit_info = VkSubmitInfo {
281 sType: VkStructureType::SubmitInfo,
282 pNext: ptr::null(),
283 waitSemaphoreCount: 0,
284 pWaitSemaphores: ptr::null(),
285 pWaitDstStageMask: ptr::null(),
286 commandBufferCount: 1,
287 pCommandBuffers: &command_buffer,
288 signalSemaphoreCount: 0,
289 pSignalSemaphores: ptr::null(),
290 };
291
292 let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
293 if result != VkResult::Success {
294 return Err(KronosError::from(result));
295 }
296
297 vkQueueWaitIdle(inner.queue);
299
300 vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &command_buffer);
302
303 Ok(())
304 })
305 }
306}
307
308impl Buffer {
309 pub fn read<T>(&self) -> Result<Vec<T>>
311 where
312 T: Copy + 'static,
313 {
314 let element_size = std::mem::size_of::<T>();
315 let element_count = self.size / element_size;
316
317 if self.size % element_size != 0 {
318 return Err(KronosError::BufferCreationFailed(
319 format!("Buffer size {} is not a multiple of element size {}", self.size, element_size)
320 ));
321 }
322
323 unsafe {
324 let staging = self.context.create_buffer_uninit(self.size)?;
326
327 self.context.copy_buffer(self, &staging, self.size)?;
329
330 self.context.with_inner(|inner| {
332 let mut mapped_ptr = ptr::null_mut();
333 let result = vkMapMemory(
334 inner.device,
335 staging.memory,
336 0,
337 self.size as VkDeviceSize,
338 0,
339 &mut mapped_ptr,
340 );
341
342 if result != VkResult::Success {
343 return Err(KronosError::from(result));
344 }
345
346 let slice = slice::from_raw_parts(mapped_ptr as *const T, element_count);
347 let vec = slice.to_vec();
348
349 vkUnmapMemory(inner.device, staging.memory);
350
351 Ok(vec)
352 })
353 }
354 }
355}
356
357impl Drop for Buffer {
358 fn drop(&mut self) {
359 unsafe {
360 self.context.with_inner(|inner| {
361 vkFreeMemory(inner.device, self.memory, ptr::null());
362 vkDestroyBuffer(inner.device, self.buffer, ptr::null());
363 });
364 }
365 }
366}
367