1use super::*;
4use crate::*; #[cfg(feature = "implementation")]
9use crate::implementation::{
10 vkCreateBuffer, vkDestroyBuffer, vkGetBufferMemoryRequirements,
11 vkBindBufferMemory, vkAllocateMemory, vkFreeMemory,
12 vkMapMemory, vkUnmapMemory, vkCmdCopyBuffer,
13};
14
15use std::marker::PhantomData;
18use std::ptr;
19use std::slice;
20
21#[derive(Debug, Clone, Copy)]
23pub struct BufferUsage {
24 flags: VkBufferUsageFlags,
25}
26
27impl BufferUsage {
28 pub const STORAGE: Self = Self { flags: VkBufferUsageFlags::STORAGE_BUFFER };
29 pub const TRANSFER_SRC: Self = Self { flags: VkBufferUsageFlags::TRANSFER_SRC };
30 pub const TRANSFER_DST: Self = Self { flags: VkBufferUsageFlags::TRANSFER_DST };
31
32 pub fn storage() -> Self {
33 Self::STORAGE
34 }
35
36 pub fn transfer_src() -> Self {
37 Self::TRANSFER_SRC
38 }
39
40 pub fn transfer_dst() -> Self {
41 Self::TRANSFER_DST
42 }
43}
44
45impl std::ops::BitOr for BufferUsage {
46 type Output = Self;
47
48 fn bitor(self, rhs: Self) -> Self::Output {
49 Self {
50 flags: VkBufferUsageFlags::from_bits_truncate(self.flags.bits() | rhs.flags.bits())
51 }
52 }
53}
54
55pub struct Buffer {
60 pub(super) context: ComputeContext,
61 pub(super) buffer: VkBuffer,
62 pub(super) memory: VkDeviceMemory,
63 pub(super) size: usize,
64 pub(super) usage: BufferUsage,
65 pub(super) _marker: PhantomData<*const u8>,
66}
67
68unsafe impl Send for Buffer {}
70unsafe impl Sync for Buffer {}
71
72impl Buffer {
73 pub fn size(&self) -> usize {
75 self.size
76 }
77
78 pub fn usage(&self) -> BufferUsage {
80 self.usage
81 }
82
83 pub fn raw(&self) -> VkBuffer {
85 self.buffer
86 }
87}
88
89impl ComputeContext {
90 pub fn create_buffer<T>(&self, data: &[T]) -> Result<Buffer>
92 where
93 T: Copy + 'static,
94 {
95 let size = std::mem::size_of_val(data);
96 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;
97
98 unsafe {
99 let buffer = self.create_buffer_raw(size, usage)?;
101
102 let staging_usage = BufferUsage::TRANSFER_SRC;
104 let staging = self.create_buffer_raw(size, staging_usage)?;
105
106 self.with_inner(|inner| {
108 let mut mapped_ptr = ptr::null_mut();
109 let result = vkMapMemory(
110 inner.device,
111 staging.memory,
112 0,
113 size as VkDeviceSize,
114 0,
115 &mut mapped_ptr,
116 );
117
118 if result != VkResult::Success {
119 return Err(KronosError::from(result));
120 }
121
122 ptr::copy_nonoverlapping(
123 data.as_ptr() as *const u8,
124 mapped_ptr as *mut u8,
125 size,
126 );
127
128 vkUnmapMemory(inner.device, staging.memory);
129 Ok(())
130 })?;
131
132 self.copy_buffer(&staging, &buffer, size)?;
134
135 Ok(buffer)
137 }
138 }
139
140 pub fn create_buffer_uninit(&self, size: usize) -> Result<Buffer> {
142 let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST | BufferUsage::TRANSFER_SRC;
143 unsafe { self.create_buffer_raw(size, usage) }
144 }
145
146 unsafe fn create_buffer_raw(&self, size: usize, usage: BufferUsage) -> Result<Buffer> {
157 self.with_inner(|inner| {
158 let buffer_info = VkBufferCreateInfo {
160 sType: VkStructureType::BufferCreateInfo,
161 pNext: ptr::null(),
162 flags: VkBufferCreateFlags::empty(),
163 size: size as VkDeviceSize,
164 usage: usage.flags,
165 sharingMode: VkSharingMode::Exclusive,
166 queueFamilyIndexCount: 0,
167 pQueueFamilyIndices: ptr::null(),
168 };
169
170 let mut buffer = VkBuffer::NULL;
171 log::debug!("API layer calling vkCreateBuffer for device {:?}", inner.device);
172 let result = vkCreateBuffer(inner.device, &buffer_info, ptr::null(), &mut buffer);
173
174 if result != VkResult::Success {
175 return Err(KronosError::BufferCreationFailed(format!("vkCreateBuffer failed: {:?}", result)));
176 }
177
178 let mut mem_requirements = VkMemoryRequirements::default();
180 vkGetBufferMemoryRequirements(inner.device, buffer, &mut mem_requirements);
181
182 let memory_type_index = Self::find_memory_type(
184 &inner.memory_properties,
185 mem_requirements.memoryTypeBits,
186 if usage.flags.contains(VkBufferUsageFlags::TRANSFER_SRC) {
187 VkMemoryPropertyFlags::HOST_VISIBLE | VkMemoryPropertyFlags::HOST_COHERENT
188 } else {
189 VkMemoryPropertyFlags::DEVICE_LOCAL
190 },
191 )?;
192
193 let alloc_info = VkMemoryAllocateInfo {
195 sType: VkStructureType::MemoryAllocateInfo,
196 pNext: ptr::null(),
197 allocationSize: mem_requirements.size,
198 memoryTypeIndex: memory_type_index,
199 };
200
201 let mut memory = VkDeviceMemory::NULL;
202 let result = vkAllocateMemory(inner.device, &alloc_info, ptr::null(), &mut memory);
203
204 if result != VkResult::Success {
205 vkDestroyBuffer(inner.device, buffer, ptr::null());
206 return Err(KronosError::BufferCreationFailed(format!("vkAllocateMemory failed: {:?}", result)));
207 }
208
209 let result = vkBindBufferMemory(inner.device, buffer, memory, 0);
211
212 if result != VkResult::Success {
213 vkFreeMemory(inner.device, memory, ptr::null());
214 vkDestroyBuffer(inner.device, buffer, ptr::null());
215 return Err(KronosError::BufferCreationFailed(format!("vkBindBufferMemory failed: {:?}", result)));
216 }
217
218 Ok(Buffer {
219 context: self.clone(),
220 buffer,
221 memory,
222 size,
223 usage,
224 _marker: std::marker::PhantomData,
225 })
226 })
227 }
228
229 fn find_memory_type(
231 memory_properties: &VkPhysicalDeviceMemoryProperties,
232 type_filter: u32,
233 properties: VkMemoryPropertyFlags,
234 ) -> Result<u32> {
235 for i in 0..memory_properties.memoryTypeCount {
236 if (type_filter & (1 << i)) != 0
237 && memory_properties.memoryTypes[i as usize].propertyFlags.contains(properties) {
238 return Ok(i);
239 }
240 }
241
242 Err(KronosError::BufferCreationFailed("No suitable memory type found".into()))
243 }
244
245 unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer, size: usize) -> Result<()> {
257 self.with_inner(|inner| {
258 let alloc_info = VkCommandBufferAllocateInfo {
260 sType: VkStructureType::CommandBufferAllocateInfo,
261 pNext: ptr::null(),
262 commandPool: inner.command_pool,
263 level: VkCommandBufferLevel::Primary,
264 commandBufferCount: 1,
265 };
266
267 let mut command_buffer = VkCommandBuffer::NULL;
268 vkAllocateCommandBuffers(inner.device, &alloc_info, &mut command_buffer);
269
270 let begin_info = VkCommandBufferBeginInfo {
272 sType: VkStructureType::CommandBufferBeginInfo,
273 pNext: ptr::null(),
274 flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
275 pInheritanceInfo: ptr::null(),
276 };
277
278 vkBeginCommandBuffer(command_buffer, &begin_info);
279
280 let region = VkBufferCopy {
282 srcOffset: 0,
283 dstOffset: 0,
284 size: size as VkDeviceSize,
285 };
286
287 vkCmdCopyBuffer(command_buffer, src.buffer, dst.buffer, 1, ®ion);
288
289 vkEndCommandBuffer(command_buffer);
291
292 let submit_info = VkSubmitInfo {
294 sType: VkStructureType::SubmitInfo,
295 pNext: ptr::null(),
296 waitSemaphoreCount: 0,
297 pWaitSemaphores: ptr::null(),
298 pWaitDstStageMask: ptr::null(),
299 commandBufferCount: 1,
300 pCommandBuffers: &command_buffer,
301 signalSemaphoreCount: 0,
302 pSignalSemaphores: ptr::null(),
303 };
304
305 let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
306 if result != VkResult::Success {
307 return Err(KronosError::from(result));
308 }
309
310 vkQueueWaitIdle(inner.queue);
312
313 vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &command_buffer);
315
316 Ok(())
317 })
318 }
319}
320
321impl Buffer {
322 pub fn read<T>(&self) -> Result<Vec<T>>
324 where
325 T: Copy + 'static,
326 {
327 let element_size = std::mem::size_of::<T>();
328 let element_count = self.size / element_size;
329
330 if self.size % element_size != 0 {
331 return Err(KronosError::BufferCreationFailed(
332 format!("Buffer size {} is not a multiple of element size {}", self.size, element_size)
333 ));
334 }
335
336 unsafe {
337 let staging = self.context.create_buffer_uninit(self.size)?;
339
340 self.context.copy_buffer(self, &staging, self.size)?;
342
343 self.context.with_inner(|inner| {
345 let mut mapped_ptr = ptr::null_mut();
346 let result = vkMapMemory(
347 inner.device,
348 staging.memory,
349 0,
350 self.size as VkDeviceSize,
351 0,
352 &mut mapped_ptr,
353 );
354
355 if result != VkResult::Success {
356 return Err(KronosError::from(result));
357 }
358
359 let slice = slice::from_raw_parts(mapped_ptr as *const T, element_count);
360 let vec = slice.to_vec();
361
362 vkUnmapMemory(inner.device, staging.memory);
363
364 Ok(vec)
365 })
366 }
367 }
368}
369
370impl Drop for Buffer {
371 fn drop(&mut self) {
372 unsafe {
373 self.context.with_inner(|inner| {
374 vkFreeMemory(inner.device, self.memory, ptr::null());
375 vkDestroyBuffer(inner.device, self.buffer, ptr::null());
376 });
377 }
378 }
379}
380