// kronos_compute/api/buffer.rs

1//! Safe buffer management with automatic memory allocation
2
3use super::*;
4use crate::*; // Need all the type definitions
5
6// Explicitly import Vulkan functions from implementation when available
7// This ensures we use Kronos's multi-ICD aware implementation
8#[cfg(feature = "implementation")]
9use crate::implementation::{
10    vkCreateBuffer, vkDestroyBuffer, vkGetBufferMemoryRequirements, 
11    vkBindBufferMemory, vkAllocateMemory, vkFreeMemory,
12    vkMapMemory, vkUnmapMemory, vkCmdCopyBuffer,
13};
14
15// If implementation feature is not enabled, these functions must come from
16// linking to an external Vulkan library
17use std::marker::PhantomData;
18use std::ptr;
19use std::slice;
20
21/// Usage flags for buffers
22#[derive(Debug, Clone, Copy)]
23pub struct BufferUsage {
24    flags: VkBufferUsageFlags,
25}
26
27impl BufferUsage {
28    pub const STORAGE: Self = Self { flags: VkBufferUsageFlags::STORAGE_BUFFER };
29    pub const TRANSFER_SRC: Self = Self { flags: VkBufferUsageFlags::TRANSFER_SRC };
30    pub const TRANSFER_DST: Self = Self { flags: VkBufferUsageFlags::TRANSFER_DST };
31    
32    pub fn storage() -> Self {
33        Self::STORAGE
34    }
35    
36    pub fn transfer_src() -> Self {
37        Self::TRANSFER_SRC
38    }
39    
40    pub fn transfer_dst() -> Self {
41        Self::TRANSFER_DST
42    }
43}
44
45impl std::ops::BitOr for BufferUsage {
46    type Output = Self;
47    
48    fn bitor(self, rhs: Self) -> Self::Output {
49        Self {
50            flags: VkBufferUsageFlags::from_bits_truncate(self.flags.bits() | rhs.flags.bits())
51        }
52    }
53}
54
/// A GPU buffer with automatic memory management
/// 
/// Buffers are automatically freed when dropped and use the
/// pool allocator for efficient memory management.
pub struct Buffer {
    // Keeps the owning context alive so `Drop` can still reach the device.
    pub(super) context: ComputeContext,
    // Raw Vulkan buffer handle.
    pub(super) buffer: VkBuffer,
    // Device memory allocation bound to `buffer`.
    pub(super) memory: VkDeviceMemory,
    // Size in bytes requested at creation (not the aligned allocation size).
    pub(super) size: usize,
    // Usage flags the buffer was created with.
    pub(super) usage: BufferUsage,
    // Raw-pointer marker suppresses the auto Send/Sync impls; they are
    // re-asserted explicitly below.
    pub(super) _marker: PhantomData<*const u8>,
}

// Send + Sync for thread safety
// SAFETY: the struct only stores handle values plus a cloneable context; the
// marker pointer is never dereferenced. NOTE(review): soundness also depends
// on the underlying driver tolerating cross-thread use of these handles —
// confirm against the implementation's threading guarantees.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}
71
72impl Buffer {
73    /// Get the size of the buffer in bytes
74    pub fn size(&self) -> usize {
75        self.size
76    }
77    
78    /// Get the usage flags
79    pub fn usage(&self) -> BufferUsage {
80        self.usage
81    }
82    
83    /// Get the raw Vulkan buffer handle (for advanced usage)
84    pub fn raw(&self) -> VkBuffer {
85        self.buffer
86    }
87}
88
89impl ComputeContext {
    /// Create a device-local buffer initialized with the contents of `data`.
    ///
    /// The bytes of `data` are first written into a mapped host-visible
    /// staging buffer, then copied to the device buffer with a one-shot GPU
    /// transfer. The staging buffer is dropped (and its resources freed)
    /// before this function returns.
    ///
    /// NOTE(review): the device buffer's usage omits TRANSFER_SRC, yet
    /// `Buffer::read` later copies *from* such a buffer — confirm the
    /// implementation tolerates that, or add the flag with a device-local
    /// memory override.
    ///
    /// # Errors
    ///
    /// Returns an error if buffer/memory creation, mapping, or the GPU copy
    /// fails.
    pub fn create_buffer<T>(&self, data: &[T]) -> Result<Buffer> 
    where
        T: Copy + 'static,
    {
        let size = std::mem::size_of_val(data);
        // Destination: written once via transfer, then read by shaders.
        let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST;
        
        unsafe {
            // Create buffer
            let buffer = self.create_buffer_raw(size, usage)?;
            
            // Create staging buffer
            // (TRANSFER_SRC makes create_buffer_raw pick host-visible,
            // host-coherent memory, so no explicit flush is needed after the
            // byte copy below.)
            let staging_usage = BufferUsage::TRANSFER_SRC;
            let staging = self.create_buffer_raw(size, staging_usage)?;
            
            // Map and copy data
            self.with_inner(|inner| {
                let mut mapped_ptr = ptr::null_mut();
                let result = vkMapMemory(
                    inner.device,
                    staging.memory,
                    0,
                    size as VkDeviceSize,
                    0,
                    &mut mapped_ptr,
                );
                
                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }
                
                // `T: Copy` guarantees a plain byte copy is a valid way to
                // move the values into the mapped region.
                ptr::copy_nonoverlapping(
                    data.as_ptr() as *const u8,
                    mapped_ptr as *mut u8,
                    size,
                );
                
                vkUnmapMemory(inner.device, staging.memory);
                Ok(())
            })?;
            
            // Copy staging to device buffer
            self.copy_buffer(&staging, &buffer, size)?;
            
            // Staging buffer will be dropped automatically
            Ok(buffer)
        }
    }
139    
140    /// Create an uninitialized buffer
141    pub fn create_buffer_uninit(&self, size: usize) -> Result<Buffer> {
142        let usage = BufferUsage::STORAGE | BufferUsage::TRANSFER_DST | BufferUsage::TRANSFER_SRC;
143        unsafe { self.create_buffer_raw(size, usage) }
144    }
145    
146    /// Internal: Create a raw buffer
147    ///
148    /// # Safety
149    ///
150    /// This function is unsafe because:
151    /// - It directly calls Vulkan functions that require valid device handles
152    /// - The caller must ensure `self` contains valid Vulkan handles (device, queue, etc.)
153    /// - The created buffer must be properly destroyed to avoid memory leaks
154    /// - Memory allocation may fail and must be handled appropriately
155    /// - The returned Buffer takes ownership of the Vulkan resources
156    unsafe fn create_buffer_raw(&self, size: usize, usage: BufferUsage) -> Result<Buffer> {
157        self.with_inner(|inner| {
158            // Create buffer
159            let buffer_info = VkBufferCreateInfo {
160                sType: VkStructureType::BufferCreateInfo,
161                pNext: ptr::null(),
162                flags: VkBufferCreateFlags::empty(),
163                size: size as VkDeviceSize,
164                usage: usage.flags,
165                sharingMode: VkSharingMode::Exclusive,
166                queueFamilyIndexCount: 0,
167                pQueueFamilyIndices: ptr::null(),
168            };
169            
170            let mut buffer = VkBuffer::NULL;
171            log::debug!("API layer calling vkCreateBuffer for device {:?}", inner.device);
172            let result = vkCreateBuffer(inner.device, &buffer_info, ptr::null(), &mut buffer);
173            
174            if result != VkResult::Success {
175                return Err(KronosError::BufferCreationFailed(format!("vkCreateBuffer failed: {:?}", result)));
176            }
177            
178            // Get memory requirements
179            let mut mem_requirements = VkMemoryRequirements::default();
180            vkGetBufferMemoryRequirements(inner.device, buffer, &mut mem_requirements);
181            
182            // Find suitable memory type
183            let memory_type_index = Self::find_memory_type(
184                &inner.memory_properties,
185                mem_requirements.memoryTypeBits,
186                if usage.flags.contains(VkBufferUsageFlags::TRANSFER_SRC) {
187                    VkMemoryPropertyFlags::HOST_VISIBLE | VkMemoryPropertyFlags::HOST_COHERENT
188                } else {
189                    VkMemoryPropertyFlags::DEVICE_LOCAL
190                },
191            )?;
192            
193            // Allocate memory (this would use the pool allocator in the real implementation)
194            let alloc_info = VkMemoryAllocateInfo {
195                sType: VkStructureType::MemoryAllocateInfo,
196                pNext: ptr::null(),
197                allocationSize: mem_requirements.size,
198                memoryTypeIndex: memory_type_index,
199            };
200            
201            let mut memory = VkDeviceMemory::NULL;
202            let result = vkAllocateMemory(inner.device, &alloc_info, ptr::null(), &mut memory);
203            
204            if result != VkResult::Success {
205                vkDestroyBuffer(inner.device, buffer, ptr::null());
206                return Err(KronosError::BufferCreationFailed(format!("vkAllocateMemory failed: {:?}", result)));
207            }
208            
209            // Bind memory to buffer
210            let result = vkBindBufferMemory(inner.device, buffer, memory, 0);
211            
212            if result != VkResult::Success {
213                vkFreeMemory(inner.device, memory, ptr::null());
214                vkDestroyBuffer(inner.device, buffer, ptr::null());
215                return Err(KronosError::BufferCreationFailed(format!("vkBindBufferMemory failed: {:?}", result)));
216            }
217            
218            Ok(Buffer {
219                context: self.clone(),
220                buffer,
221                memory,
222                size,
223                usage,
224                _marker: std::marker::PhantomData,
225            })
226        })
227    }
228    
229    /// Find a suitable memory type
230    fn find_memory_type(
231        memory_properties: &VkPhysicalDeviceMemoryProperties,
232        type_filter: u32,
233        properties: VkMemoryPropertyFlags,
234    ) -> Result<u32> {
235        for i in 0..memory_properties.memoryTypeCount {
236            if (type_filter & (1 << i)) != 0 
237                && memory_properties.memoryTypes[i as usize].propertyFlags.contains(properties) {
238                return Ok(i);
239            }
240        }
241        
242        Err(KronosError::BufferCreationFailed("No suitable memory type found".into()))
243    }
244    
245    /// Copy data between buffers
246    ///
247    /// # Safety
248    ///
249    /// This function is unsafe because:
250    /// - It directly calls Vulkan functions that require valid handles
251    /// - The caller must ensure both `src` and `dst` buffers are valid
252    /// - The `size` must not exceed the size of either buffer
253    /// - Both buffers must have appropriate usage flags (TRANSFER_SRC for src, TRANSFER_DST for dst)
254    /// - The function submits commands to the GPU queue and waits for completion
255    /// - Concurrent access to the buffers during copy is undefined behavior
256    unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer, size: usize) -> Result<()> {
257        self.with_inner(|inner| {
258            // Allocate command buffer
259            let alloc_info = VkCommandBufferAllocateInfo {
260                sType: VkStructureType::CommandBufferAllocateInfo,
261                pNext: ptr::null(),
262                commandPool: inner.command_pool,
263                level: VkCommandBufferLevel::Primary,
264                commandBufferCount: 1,
265            };
266            
267            let mut command_buffer = VkCommandBuffer::NULL;
268            vkAllocateCommandBuffers(inner.device, &alloc_info, &mut command_buffer);
269            
270            // Begin recording
271            let begin_info = VkCommandBufferBeginInfo {
272                sType: VkStructureType::CommandBufferBeginInfo,
273                pNext: ptr::null(),
274                flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
275                pInheritanceInfo: ptr::null(),
276            };
277            
278            vkBeginCommandBuffer(command_buffer, &begin_info);
279            
280            // Record copy command
281            let region = VkBufferCopy {
282                srcOffset: 0,
283                dstOffset: 0,
284                size: size as VkDeviceSize,
285            };
286            
287            vkCmdCopyBuffer(command_buffer, src.buffer, dst.buffer, 1, &region);
288            
289            // End recording
290            vkEndCommandBuffer(command_buffer);
291            
292            // Submit
293            let submit_info = VkSubmitInfo {
294                sType: VkStructureType::SubmitInfo,
295                pNext: ptr::null(),
296                waitSemaphoreCount: 0,
297                pWaitSemaphores: ptr::null(),
298                pWaitDstStageMask: ptr::null(),
299                commandBufferCount: 1,
300                pCommandBuffers: &command_buffer,
301                signalSemaphoreCount: 0,
302                pSignalSemaphores: ptr::null(),
303            };
304            
305            let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
306            if result != VkResult::Success {
307                return Err(KronosError::from(result));
308            }
309            
310            // Wait for completion
311            vkQueueWaitIdle(inner.queue);
312            
313            // Free command buffer
314            vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &command_buffer);
315            
316            Ok(())
317        })
318    }
319}
320
321impl Buffer {
322    /// Read data from the buffer
323    pub fn read<T>(&self) -> Result<Vec<T>>
324    where
325        T: Copy + 'static,
326    {
327        let element_size = std::mem::size_of::<T>();
328        let element_count = self.size / element_size;
329        
330        if self.size % element_size != 0 {
331            return Err(KronosError::BufferCreationFailed(
332                format!("Buffer size {} is not a multiple of element size {}", self.size, element_size)
333            ));
334        }
335        
336        unsafe {
337            // Create staging buffer
338            let staging = self.context.create_buffer_uninit(self.size)?;
339            
340            // Copy device to staging
341            self.context.copy_buffer(self, &staging, self.size)?;
342            
343            // Map and read
344            self.context.with_inner(|inner| {
345                let mut mapped_ptr = ptr::null_mut();
346                let result = vkMapMemory(
347                    inner.device,
348                    staging.memory,
349                    0,
350                    self.size as VkDeviceSize,
351                    0,
352                    &mut mapped_ptr,
353                );
354                
355                if result != VkResult::Success {
356                    return Err(KronosError::from(result));
357                }
358                
359                let slice = slice::from_raw_parts(mapped_ptr as *const T, element_count);
360                let vec = slice.to_vec();
361                
362                vkUnmapMemory(inner.device, staging.memory);
363                
364                Ok(vec)
365            })
366        }
367    }
368}
369
// Releases the Vulkan resources when the handle goes out of scope.
impl Drop for Buffer {
    fn drop(&mut self) {
        unsafe {
            // `context` is owned by this struct, so the device is still alive
            // here. NOTE(review): nothing waits for the GPU — callers must
            // ensure the buffer is no longer in use when it drops; confirm
            // against the context's synchronization model.
            self.context.with_inner(|inner| {
                vkFreeMemory(inner.device, self.memory, ptr::null());
                vkDestroyBuffer(inner.device, self.buffer, ptr::null());
            });
        }
    }
}
380