kronos_compute/implementation/persistent_descriptors.rs

1//! Persistent descriptor management for optimal performance
2//! 
3//! Implements Set0 as persistent storage buffer descriptors that are:
4//! - Created once per buffer lifetime
5//! - Never updated in hot path
6//! - Parameters passed via push constants (≤128B)
7
8use std::collections::HashMap;
9use std::sync::Mutex;
10use crate::sys::*;
11use crate::core::*;
12use crate::ffi::*;
13use super::error::IcdError;
14
/// Largest push-constant block, in bytes, that this module will describe.
///
/// 128 bytes is the minimum `maxPushConstantsSize` every Vulkan implementation
/// must support, so ranges up to this size are portable.
pub const MAX_PUSH_CONSTANT_SIZE: u32 = 128;

/// Index of the descriptor set that holds the persistent storage-buffer bindings.
pub const PERSISTENT_DESCRIPTOR_SET: u32 = 0;
20
21/// Persistent descriptor cache entry
22struct PersistentDescriptor {
23    descriptor_set: VkDescriptorSet,
24    buffers: Vec<VkBuffer>,
25    generation: u64,
26}
27
28/// Global persistent descriptor manager
29pub struct PersistentDescriptorManager {
30    /// Device -> Pool mapping
31    pools: HashMap<u64, VkDescriptorPool>,
32    
33    /// Layout for Set0 (storage buffers only)
34    set0_layout: HashMap<u64, VkDescriptorSetLayout>,
35    
36    /// Buffer -> Descriptor mapping
37    descriptors: HashMap<u64, PersistentDescriptor>,
38    
39    /// Generation counter for cache invalidation
40    generation: u64,
41}
42
43lazy_static::lazy_static! {
44    static ref DESCRIPTOR_MANAGER: Mutex<PersistentDescriptorManager> = Mutex::new(PersistentDescriptorManager {
45        pools: HashMap::new(),
46        set0_layout: HashMap::new(),
47        descriptors: HashMap::new(),
48        generation: 0,
49    });
50}
51
52/// Create Set0 layout for storage buffers
53///
54/// # Safety
55///
56/// This function is unsafe because:
57/// - The device must be a valid VkDevice handle
58/// - Calls vkCreateDescriptorSetLayout through ICD function pointer
59/// - The returned layout must be destroyed with vkDestroyDescriptorSetLayout
60/// - Invalid device handle will cause undefined behavior
61/// - The ICD must be properly initialized with valid function pointers
62pub unsafe fn create_persistent_layout(
63    device: VkDevice,
64    max_bindings: u32,
65) -> Result<VkDescriptorSetLayout, IcdError> {
66    let mut manager = DESCRIPTOR_MANAGER.lock()?;
67    let device_key = device.as_raw();
68    
69    // Return existing layout if already created
70    if let Some(&layout) = manager.set0_layout.get(&device_key) {
71        return Ok(layout);
72    }
73    
74    // Create bindings for storage buffers
75    let mut bindings = Vec::with_capacity(max_bindings as usize);
76    for i in 0..max_bindings {
77        bindings.push(VkDescriptorSetLayoutBinding {
78            binding: i,
79            descriptorType: VkDescriptorType::StorageBuffer,
80            descriptorCount: 1,
81            stageFlags: VkShaderStageFlags::COMPUTE,
82            pImmutableSamplers: std::ptr::null(),
83        });
84    }
85    
86    let create_info = VkDescriptorSetLayoutCreateInfo {
87        sType: VkStructureType::DescriptorSetLayoutCreateInfo,
88        pNext: std::ptr::null(),
89        flags: 0,
90        bindingCount: bindings.len() as u32,
91        pBindings: bindings.as_ptr(),
92    };
93    
94    // Forward to ICD
95    if let Some(icd) = super::icd_loader::get_icd() {
96        if let Some(create_fn) = icd.create_descriptor_set_layout {
97            let mut layout = VkDescriptorSetLayout::NULL;
98            let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
99            
100            if result == VkResult::Success {
101                manager.set0_layout.insert(device_key, layout);
102                return Ok(layout);
103            }
104            return Err(IcdError::VulkanError(result));
105        }
106    }
107    
108    Err(IcdError::MissingFunction("vkCreateDescriptorSetLayout"))
109}
110
111/// Create or get persistent descriptor pool
112///
113/// # Safety
114///
115/// This function is unsafe because:
116/// - The device must be a valid VkDevice handle
117/// - Calls vkCreateDescriptorPool through ICD function pointer
118/// - The returned pool must be destroyed with vkDestroyDescriptorPool
119/// - Pool limits (max_sets, max_descriptors) must not exceed device limits
120/// - Invalid device handle will cause undefined behavior
121/// - Thread safety relies on the Mutex protecting the global manager
122pub unsafe fn get_persistent_pool(
123    device: VkDevice,
124    max_sets: u32,
125    max_descriptors: u32,
126) -> Result<VkDescriptorPool, IcdError> {
127    let mut manager = DESCRIPTOR_MANAGER.lock()?;
128    let device_key = device.as_raw();
129    
130    // Return existing pool if already created
131    if let Some(&pool) = manager.pools.get(&device_key) {
132        return Ok(pool);
133    }
134    
135    // Create pool for storage buffer descriptors only
136    let pool_size = VkDescriptorPoolSize {
137        type_: VkDescriptorType::StorageBuffer,
138        descriptorCount: max_descriptors,
139    };
140    
141    let create_info = VkDescriptorPoolCreateInfo {
142        sType: VkStructureType::DescriptorPoolCreateInfo,
143        pNext: std::ptr::null(),
144        flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
145        maxSets: max_sets,
146        poolSizeCount: 1,
147        pPoolSizes: &pool_size,
148    };
149    
150    // Forward to ICD
151    if let Some(icd) = super::icd_loader::get_icd() {
152        if let Some(create_fn) = icd.create_descriptor_pool {
153            let mut pool = VkDescriptorPool::NULL;
154            let result = create_fn(device, &create_info, std::ptr::null(), &mut pool);
155            
156            if result == VkResult::Success {
157                manager.pools.insert(device_key, pool);
158                return Ok(pool);
159            }
160            return Err(IcdError::VulkanError(result));
161        }
162    }
163    
164    Err(IcdError::MissingFunction("vkCreateDescriptorPool"))
165}
166
167/// Get or create persistent descriptor set for buffers
168///
169/// # Safety
170///
171/// This function is unsafe because:
172/// - The device must be a valid VkDevice handle
173/// - All buffers in the array must be valid VkBuffer handles
174/// - Calls multiple Vulkan functions through ICD pointers
175/// - The descriptor set references the provided buffers
176/// - Buffers must remain valid for the lifetime of the descriptor set
177/// - Buffer usage must be compatible with STORAGE_BUFFER descriptor type
178pub unsafe fn get_persistent_descriptor_set(
179    device: VkDevice,
180    buffers: &[VkBuffer],
181) -> Result<VkDescriptorSet, IcdError> {
182    let mut manager = DESCRIPTOR_MANAGER.lock()?;
183    
184    // Create cache key from buffer handles
185    let cache_key = buffers.iter()
186        .map(|b| b.as_raw())
187        .fold(0u64, |acc, h| acc.wrapping_add(h).rotate_left(7));
188    
189    // Check if we already have this descriptor set
190    if let Some(descriptor) = manager.descriptors.get(&cache_key) {
191        if descriptor.buffers == buffers {
192            return Ok(descriptor.descriptor_set);
193        }
194    }
195    
196    // Get or create layout and pool
197    let layout = create_persistent_layout(device, buffers.len() as u32)?;
198    let pool = get_persistent_pool(device, 1000, 10000)?;
199    
200    // Allocate descriptor set
201    let alloc_info = VkDescriptorSetAllocateInfo {
202        sType: VkStructureType::DescriptorSetAllocateInfo,
203        pNext: std::ptr::null(),
204        descriptorPool: pool,
205        descriptorSetCount: 1,
206        pSetLayouts: &layout,
207    };
208    
209    let mut descriptor_set = VkDescriptorSet::NULL;
210    
211    if let Some(icd) = super::icd_loader::get_icd() {
212        if let Some(alloc_fn) = icd.allocate_descriptor_sets {
213            let result = alloc_fn(device, &alloc_info, &mut descriptor_set);
214            if result != VkResult::Success {
215                return Err(IcdError::VulkanError(result));
216            }
217        } else {
218            return Err(IcdError::MissingFunction("vkAllocateDescriptorSets"));
219        }
220    } else {
221        return Err(IcdError::NoIcdLoaded);
222    }
223    
224    // Write descriptor set with buffer bindings
225    let mut buffer_infos = Vec::with_capacity(buffers.len());
226    let mut writes = Vec::with_capacity(buffers.len());
227    
228    for (_i, &buffer) in buffers.iter().enumerate() {
229        buffer_infos.push(VkDescriptorBufferInfo {
230            buffer,
231            offset: 0,
232            range: VK_WHOLE_SIZE,
233        });
234    }
235    
236    for (i, buffer_info) in buffer_infos.iter().enumerate() {
237        writes.push(VkWriteDescriptorSet {
238            sType: VkStructureType::WriteDescriptorSet,
239            pNext: std::ptr::null(),
240            dstSet: descriptor_set,
241            dstBinding: i as u32,
242            dstArrayElement: 0,
243            descriptorCount: 1,
244            descriptorType: VkDescriptorType::StorageBuffer,
245            pImageInfo: std::ptr::null(),
246            pBufferInfo: buffer_info,
247            pTexelBufferView: std::ptr::null(),
248        });
249    }
250    
251    if let Some(icd) = super::icd_loader::get_icd() {
252        if let Some(update_fn) = icd.update_descriptor_sets {
253            update_fn(device, writes.len() as u32, writes.as_ptr(), 0, std::ptr::null());
254        }
255    }
256    
257    // Cache the descriptor
258    manager.generation += 1;
259    let generation = manager.generation;
260    manager.descriptors.insert(cache_key, PersistentDescriptor {
261        descriptor_set,
262        buffers: buffers.to_vec(),
263        generation,
264    });
265    
266    Ok(descriptor_set)
267}
268
269/// Create push constant range for parameters
270pub fn create_push_constant_range(size: u32) -> VkPushConstantRange {
271    assert!(size <= MAX_PUSH_CONSTANT_SIZE, "Push constant size {} exceeds limit {}", size, MAX_PUSH_CONSTANT_SIZE);
272    
273    VkPushConstantRange {
274        stageFlags: VkShaderStageFlags::COMPUTE,
275        offset: 0,
276        size,
277    }
278}
279
280/// Create optimized pipeline layout with Set0 + push constants
281///
282/// # Safety
283///
284/// This function is unsafe because:
285/// - The device must be a valid VkDevice handle
286/// - Calls vkCreatePipelineLayout through ICD function pointer
287/// - The returned layout must be destroyed with vkDestroyPipelineLayout
288/// - push_constant_size must not exceed MAX_PUSH_CONSTANT_SIZE (128 bytes)
289/// - set0_binding_count must not exceed device limits
290/// - Invalid parameters may cause device lost or undefined behavior
291pub unsafe fn create_compute_pipeline_layout(
292    device: VkDevice,
293    set0_binding_count: u32,
294    push_constant_size: u32,
295) -> Result<VkPipelineLayout, IcdError> {
296    let set0_layout = create_persistent_layout(device, set0_binding_count)?;
297    
298    let mut create_info = VkPipelineLayoutCreateInfo {
299        sType: VkStructureType::PipelineLayoutCreateInfo,
300        pNext: std::ptr::null(),
301        flags: 0,
302        setLayoutCount: 1,
303        pSetLayouts: &set0_layout,
304        pushConstantRangeCount: 0,
305        pPushConstantRanges: std::ptr::null(),
306    };
307    
308    let push_range = if push_constant_size > 0 {
309        Some(create_push_constant_range(push_constant_size))
310    } else {
311        None
312    };
313    
314    if let Some(ref range) = push_range {
315        create_info.pushConstantRangeCount = 1;
316        create_info.pPushConstantRanges = range;
317    }
318    
319    let mut layout = VkPipelineLayout::NULL;
320    
321    if let Some(icd) = super::icd_loader::get_icd() {
322        if let Some(create_fn) = icd.create_pipeline_layout {
323            let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
324            if result == VkResult::Success {
325                return Ok(layout);
326            }
327            return Err(IcdError::VulkanError(result));
328        }
329    }
330    
331    Err(IcdError::MissingFunction("vkCreatePipelineLayout"))
332}
333
334/// Cleanup persistent descriptors for a device
335///
336/// # Safety
337///
338/// This function is unsafe because:
339/// - The device must be a valid VkDevice handle
340/// - Calls vkDestroyDescriptorPool and vkDestroyDescriptorSetLayout
341/// - All descriptor sets allocated from the pool become invalid
342/// - Must be called before device destruction
343/// - Concurrent use of descriptors during cleanup causes undefined behavior
344/// - The global manager mutex provides thread safety for the cleanup
345pub unsafe fn cleanup_persistent_descriptors(device: VkDevice) -> Result<(), IcdError> {
346    let mut manager = DESCRIPTOR_MANAGER.lock()?;
347    let device_key = device.as_raw();
348    
349    // Clean up pool
350    if let Some(pool) = manager.pools.remove(&device_key) {
351        if let Some(icd) = super::icd_loader::get_icd() {
352            if let Some(destroy_fn) = icd.destroy_descriptor_pool {
353                destroy_fn(device, pool, std::ptr::null());
354            }
355        }
356    }
357    
358    // Clean up layout
359    if let Some(layout) = manager.set0_layout.remove(&device_key) {
360        if let Some(icd) = super::icd_loader::get_icd() {
361            if let Some(destroy_fn) = icd.destroy_descriptor_set_layout {
362                destroy_fn(device, layout, std::ptr::null());
363            }
364        }
365    }
366    
367    // Remove cached descriptors for this device
368    manager.descriptors.retain(|_, desc| {
369        // In a real implementation, we'd track which descriptors belong to which device
370        desc.generation > 0 // Placeholder
371    });
372    
373    Ok(())
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// A typical size yields a compute-only range at offset 0 of that size.
    #[test]
    fn test_push_constant_range() {
        let range = create_push_constant_range(64);
        assert_eq!(range.stageFlags, VkShaderStageFlags::COMPUTE);
        assert_eq!(range.offset, 0);
        assert_eq!(range.size, 64);
    }

    /// The limit itself is inclusive: exactly MAX_PUSH_CONSTANT_SIZE is accepted.
    #[test]
    fn test_push_constant_range_at_limit() {
        let range = create_push_constant_range(MAX_PUSH_CONSTANT_SIZE);
        assert_eq!(range.size, MAX_PUSH_CONSTANT_SIZE);
    }

    /// Oversized ranges panic with the limit message (and not for some other reason).
    #[test]
    #[should_panic(expected = "exceeds limit")]
    fn test_push_constant_size_limit() {
        create_push_constant_range(256); // Exceeds 128 byte limit
    }
}
393}