kronos_compute/api/buffer.rs

//! Safe buffer management with automatic memory allocation
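//!
//! A minimal end-to-end sketch, assuming a `ComputeContext` has already been
//! created elsewhere in the crate (the compute dispatch in the middle is only
//! indicated, not shown):
//!
//! ```ignore
//! fn round_trip(ctx: &ComputeContext) -> Result<Vec<f32>> {
//!     // Upload input data through a staging buffer.
//!     let input = ctx.create_buffer(&[1.0f32, 2.0, 3.0, 4.0])?;
//!     // Output buffer that a compute pass can write and we can read back.
//!     let output = ctx.create_buffer_uninit(input.size())?;
//!     // ... record and submit a compute dispatch using `input` and `output` ...
//!     output.read::<f32>()
//! }
//! ```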

use super::*;
use crate::*; // Import all functions from the crate root
use std::marker::PhantomData;
use std::ptr;
use std::slice;

/// Usage flags for buffers
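///
/// Flags mirror the underlying `VkBufferUsageFlags` and can be combined with
/// the `|` operator; a small sketch:
///
/// ```ignore
/// let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;
/// ```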
#[derive(Debug, Clone, Copy)]
pub struct BufferUsage {
    flags: VkBufferUsageFlags,
}

impl BufferUsage {
    pub const STORAGE: Self = Self { flags: VkBufferUsageFlags::STORAGE_BUFFER };
    pub const TRANSFER_SRC: Self = Self { flags: VkBufferUsageFlags::TRANSFER_SRC };
    pub const TRANSFER_DST: Self = Self { flags: VkBufferUsageFlags::TRANSFER_DST };

    pub fn storage() -> Self {
        Self::STORAGE
    }

    pub fn transfer_src() -> Self {
        Self::TRANSFER_SRC
    }

    pub fn transfer_dst() -> Self {
        Self::TRANSFER_DST
    }
}

impl std::ops::BitOr for BufferUsage {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self {
            flags: VkBufferUsageFlags::from_bits_truncate(self.flags.bits() | rhs.flags.bits())
        }
    }
}

/// A GPU buffer with automatic memory management
///
/// Buffers are automatically freed when dropped. Memory is currently
/// allocated per buffer with `vkAllocateMemory`; routing these allocations
/// through the pool allocator is the intended end state (see the note in
/// `create_buffer_raw`).
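///
/// A small usage sketch of the accessors (here `ctx` is an existing
/// [`ComputeContext`]):
///
/// ```ignore
/// let buf = ctx.create_buffer(&[0u32; 256])?;
/// assert_eq!(buf.size(), 256 * std::mem::size_of::<u32>());
/// let handle: VkBuffer = buf.raw(); // raw handle for advanced, unsafe use
/// ```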
pub struct Buffer {
    pub(super) context: ComputeContext,
    pub(super) buffer: VkBuffer,
    pub(super) memory: VkDeviceMemory,
    pub(super) size: usize,
    pub(super) usage: BufferUsage,
    pub(super) _marker: PhantomData<*const u8>,
}

// SAFETY: `Buffer` only holds plain Vulkan handles plus a `ComputeContext`;
// the `PhantomData<*const u8>` marker opts it out of the auto traits, so
// `Send` and `Sync` are re-asserted explicitly here.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}

impl Buffer {
    /// Get the size of the buffer in bytes
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the usage flags
    pub fn usage(&self) -> BufferUsage {
        self.usage
    }

    /// Get the raw Vulkan buffer handle (for advanced usage)
    pub fn raw(&self) -> VkBuffer {
        self.buffer
    }
}

impl ComputeContext {
    /// Create a buffer with data
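    ///
    /// The element type only needs to be `Copy + 'static`; the contents are
    /// uploaded through a temporary host-visible staging buffer. A sketch
    /// (`ctx` is an existing context):
    ///
    /// ```ignore
    /// let weights: Vec<f32> = vec![0.0; 1024];
    /// let buffer = ctx.create_buffer(&weights)?;
    /// ```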
    pub fn create_buffer<T>(&self, data: &[T]) -> Result<Buffer>
    where
        T: Copy + 'static,
    {
        let size = std::mem::size_of_val(data);
        let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;

        unsafe {
            // Create buffer
            let buffer = self.create_buffer_raw(size, usage)?;

            // Create staging buffer
            let staging_usage = BufferUsage::TRANSFER_SRC;
            let staging = self.create_buffer_raw(size, staging_usage)?;

            // Map and copy data
            self.with_inner(|inner| {
                let mut mapped_ptr = ptr::null_mut();
                let result = vkMapMemory(
                    inner.device,
                    staging.memory,
                    0,
                    size as VkDeviceSize,
                    0,
                    &mut mapped_ptr,
                );

                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }

                ptr::copy_nonoverlapping(
                    data.as_ptr() as *const u8,
                    mapped_ptr as *mut u8,
                    size,
                );

                vkUnmapMemory(inner.device, staging.memory);
                Ok(())
            })?;

            // Copy staging to device buffer
            self.copy_buffer(&staging, &buffer, size)?;

            // Staging buffer will be dropped automatically
            Ok(buffer)
        }
    }

    /// Create an uninitialized buffer
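    ///
    /// The buffer is created with `STORAGE | TRANSFER_DST | TRANSFER_SRC`
    /// usage, so it can be written by shaders and read back later. A sketch
    /// for an output buffer holding 1024 `f32` results:
    ///
    /// ```ignore
    /// let output = ctx.create_buffer_uninit(1024 * std::mem::size_of::<f32>())?;
    /// ```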
    pub fn create_buffer_uninit(&self, size: usize) -> Result<Buffer> {
        let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST | BufferUsage::TRANSFER_SRC;
        unsafe { self.create_buffer_raw(size, usage) }
    }

    /// Internal: Create a raw buffer
    ///
    /// # Safety
    ///
    /// This function is unsafe because:
    /// - It directly calls Vulkan functions that require valid device handles
    /// - The caller must ensure `self` contains valid Vulkan handles (device, queue, etc.)
    /// - The created buffer must be properly destroyed to avoid memory leaks
    /// - Memory allocation may fail and must be handled appropriately
    /// - The returned Buffer takes ownership of the Vulkan resources
    unsafe fn create_buffer_raw(&self, size: usize, usage: BufferUsage) -> Result<Buffer> {
        self.with_inner(|inner| {
            // Create buffer
            let buffer_info = VkBufferCreateInfo {
                sType: VkStructureType::BufferCreateInfo,
                pNext: ptr::null(),
                flags: VkBufferCreateFlags::empty(),
                size: size as VkDeviceSize,
                usage: usage.flags,
                sharingMode: VkSharingMode::Exclusive,
                queueFamilyIndexCount: 0,
                pQueueFamilyIndices: ptr::null(),
            };

            let mut buffer = VkBuffer::NULL;
            let result = vkCreateBuffer(inner.device, &buffer_info, ptr::null(), &mut buffer);

            if result != VkResult::Success {
                return Err(KronosError::BufferCreationFailed(format!("vkCreateBuffer failed: {:?}", result)));
            }

            // Get memory requirements
            let mut mem_requirements = VkMemoryRequirements::default();
            vkGetBufferMemoryRequirements(inner.device, buffer, &mut mem_requirements);

            // Pick a memory type: buffers that can act as a transfer source
            // (staging/readback) get host-visible, host-coherent memory so they
            // can be mapped; all other buffers stay device-local.
            let memory_type_index = Self::find_memory_type(
                &inner.memory_properties,
                mem_requirements.memoryTypeBits,
                if usage.flags.contains(VkBufferUsageFlags::TRANSFER_SRC) {
                    VkMemoryPropertyFlags::HOST_VISIBLE | VkMemoryPropertyFlags::HOST_COHERENT
                } else {
                    VkMemoryPropertyFlags::DEVICE_LOCAL
                },
            )?;

            // Allocate memory (this would use the pool allocator in the real implementation)
            let alloc_info = VkMemoryAllocateInfo {
                sType: VkStructureType::MemoryAllocateInfo,
                pNext: ptr::null(),
                allocationSize: mem_requirements.size,
                memoryTypeIndex: memory_type_index,
            };

            let mut memory = VkDeviceMemory::NULL;
            let result = vkAllocateMemory(inner.device, &alloc_info, ptr::null(), &mut memory);

            if result != VkResult::Success {
                vkDestroyBuffer(inner.device, buffer, ptr::null());
                return Err(KronosError::BufferCreationFailed(format!("vkAllocateMemory failed: {:?}", result)));
            }

            // Bind memory to buffer
            let result = vkBindBufferMemory(inner.device, buffer, memory, 0);

            if result != VkResult::Success {
                vkFreeMemory(inner.device, memory, ptr::null());
                vkDestroyBuffer(inner.device, buffer, ptr::null());
                return Err(KronosError::BufferCreationFailed(format!("vkBindBufferMemory failed: {:?}", result)));
            }

            Ok(Buffer {
                context: self.clone(),
                buffer,
                memory,
                size,
                usage,
                _marker: std::marker::PhantomData,
            })
        })
    }

    /// Find a suitable memory type
    fn find_memory_type(
        memory_properties: &VkPhysicalDeviceMemoryProperties,
        type_filter: u32,
        properties: VkMemoryPropertyFlags,
    ) -> Result<u32> {
        for i in 0..memory_properties.memoryTypeCount {
            if (type_filter & (1 << i)) != 0
                && memory_properties.memoryTypes[i as usize].propertyFlags.contains(properties) {
                return Ok(i);
            }
        }

        Err(KronosError::BufferCreationFailed("No suitable memory type found".into()))
    }

    /// Copy data between buffers
    ///
    /// # Safety
    ///
    /// This function is unsafe because:
    /// - It directly calls Vulkan functions that require valid handles
    /// - The caller must ensure both `src` and `dst` buffers are valid
    /// - The `size` must not exceed the size of either buffer
    /// - Both buffers must have appropriate usage flags (TRANSFER_SRC for src, TRANSFER_DST for dst)
    /// - The function submits commands to the GPU queue and waits for completion
    /// - Concurrent access to the buffers during copy is undefined behavior
    unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer, size: usize) -> Result<()> {
        self.with_inner(|inner| {
            // Allocate command buffer
            let alloc_info = VkCommandBufferAllocateInfo {
                sType: VkStructureType::CommandBufferAllocateInfo,
                pNext: ptr::null(),
                commandPool: inner.command_pool,
                level: VkCommandBufferLevel::Primary,
                commandBufferCount: 1,
            };

            let mut command_buffer = VkCommandBuffer::NULL;
            vkAllocateCommandBuffers(inner.device, &alloc_info, &mut command_buffer);

            // Begin recording
            let begin_info = VkCommandBufferBeginInfo {
                sType: VkStructureType::CommandBufferBeginInfo,
                pNext: ptr::null(),
                flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
                pInheritanceInfo: ptr::null(),
            };

            vkBeginCommandBuffer(command_buffer, &begin_info);

            // Record copy command
            let region = VkBufferCopy {
                srcOffset: 0,
                dstOffset: 0,
                size: size as VkDeviceSize,
            };

            vkCmdCopyBuffer(command_buffer, src.buffer, dst.buffer, 1, &region);

            // End recording
            vkEndCommandBuffer(command_buffer);

            // Submit
            let submit_info = VkSubmitInfo {
                sType: VkStructureType::SubmitInfo,
                pNext: ptr::null(),
                waitSemaphoreCount: 0,
                pWaitSemaphores: ptr::null(),
                pWaitDstStageMask: ptr::null(),
                commandBufferCount: 1,
                pCommandBuffers: &command_buffer,
                signalSemaphoreCount: 0,
                pSignalSemaphores: ptr::null(),
            };

            let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
            if result != VkResult::Success {
                return Err(KronosError::from(result));
            }

            // Wait for completion
            vkQueueWaitIdle(inner.queue);

            // Free command buffer
            vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &command_buffer);

            Ok(())
        })
    }
}

impl Buffer {
    /// Read data from the buffer
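    ///
    /// The copy back to the host goes through a staging buffer, so the buffer
    /// must be usable as a transfer source (e.g. created via
    /// [`ComputeContext::create_buffer_uninit`]), and its size must be a
    /// multiple of `size_of::<T>()`. A sketch:
    ///
    /// ```ignore
    /// let results: Vec<f32> = output.read::<f32>()?;
    /// ```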
    pub fn read<T>(&self) -> Result<Vec<T>>
    where
        T: Copy + 'static,
    {
        let element_size = std::mem::size_of::<T>();
        let element_count = self.size / element_size;

        if self.size % element_size != 0 {
            return Err(KronosError::BufferCreationFailed(
                format!("Buffer size {} is not a multiple of element size {}", self.size, element_size)
            ));
        }

        unsafe {
            // Create staging buffer
            let staging = self.context.create_buffer_uninit(self.size)?;

            // Copy device to staging
            self.context.copy_buffer(self, &staging, self.size)?;

            // Map and read
            self.context.with_inner(|inner| {
                let mut mapped_ptr = ptr::null_mut();
                let result = vkMapMemory(
                    inner.device,
                    staging.memory,
                    0,
                    self.size as VkDeviceSize,
                    0,
                    &mut mapped_ptr,
                );

                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }

                let slice = slice::from_raw_parts(mapped_ptr as *const T, element_count);
                let vec = slice.to_vec();

                vkUnmapMemory(inner.device, staging.memory);

                Ok(vec)
            })
        }
    }
}

impl Drop for Buffer {
    fn drop(&mut self) {
        unsafe {
            self.context.with_inner(|inner| {
                vkFreeMemory(inner.device, self.memory, ptr::null());
                vkDestroyBuffer(inner.device, self.buffer, ptr::null());
            });
        }
    }
}