kronos_compute/implementation/
persistent_descriptors.rs1use std::collections::HashMap;
9use std::sync::Mutex;
10use crate::sys::*;
11use crate::core::*;
12use crate::ffi::*;
13use super::error::IcdError;
14
/// Largest push-constant range size accepted by `create_push_constant_range`.
///
/// Presumably chosen because 128 bytes is the minimum `maxPushConstantsSize`
/// every Vulkan implementation must support — TODO confirm.
pub const MAX_PUSH_CONSTANT_SIZE: u32 = 128;

/// Descriptor-set index of the persistent storage-buffer layout;
/// `create_compute_pipeline_layout` installs that layout as the only (first) set.
pub const PERSISTENT_DESCRIPTOR_SET: u32 = 0;
21struct PersistentDescriptor {
23 descriptor_set: VkDescriptorSet,
24 buffers: Vec<VkBuffer>,
25 generation: u64,
26}
27
28pub struct PersistentDescriptorManager {
30 pools: HashMap<u64, VkDescriptorPool>,
32
33 set0_layout: HashMap<u64, VkDescriptorSetLayout>,
35
36 descriptors: HashMap<u64, PersistentDescriptor>,
38
39 generation: u64,
41}
42
43lazy_static::lazy_static! {
44 static ref DESCRIPTOR_MANAGER: Mutex<PersistentDescriptorManager> = Mutex::new(PersistentDescriptorManager {
45 pools: HashMap::new(),
46 set0_layout: HashMap::new(),
47 descriptors: HashMap::new(),
48 generation: 0,
49 });
50}
51
52pub unsafe fn create_persistent_layout(
63 device: VkDevice,
64 max_bindings: u32,
65) -> Result<VkDescriptorSetLayout, IcdError> {
66 let mut manager = DESCRIPTOR_MANAGER.lock()?;
67 let device_key = device.as_raw();
68
69 if let Some(&layout) = manager.set0_layout.get(&device_key) {
71 return Ok(layout);
72 }
73
74 let mut bindings = Vec::with_capacity(max_bindings as usize);
76 for i in 0..max_bindings {
77 bindings.push(VkDescriptorSetLayoutBinding {
78 binding: i,
79 descriptorType: VkDescriptorType::StorageBuffer,
80 descriptorCount: 1,
81 stageFlags: VkShaderStageFlags::COMPUTE,
82 pImmutableSamplers: std::ptr::null(),
83 });
84 }
85
86 let create_info = VkDescriptorSetLayoutCreateInfo {
87 sType: VkStructureType::DescriptorSetLayoutCreateInfo,
88 pNext: std::ptr::null(),
89 flags: 0,
90 bindingCount: bindings.len() as u32,
91 pBindings: bindings.as_ptr(),
92 };
93
94 if let Some(icd) = super::icd_loader::get_icd() {
96 if let Some(create_fn) = icd.create_descriptor_set_layout {
97 let mut layout = VkDescriptorSetLayout::NULL;
98 let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
99
100 if result == VkResult::Success {
101 manager.set0_layout.insert(device_key, layout);
102 return Ok(layout);
103 }
104 return Err(IcdError::VulkanError(result));
105 }
106 }
107
108 Err(IcdError::MissingFunction("vkCreateDescriptorSetLayout"))
109}
110
111pub unsafe fn get_persistent_pool(
123 device: VkDevice,
124 max_sets: u32,
125 max_descriptors: u32,
126) -> Result<VkDescriptorPool, IcdError> {
127 let mut manager = DESCRIPTOR_MANAGER.lock()?;
128 let device_key = device.as_raw();
129
130 if let Some(&pool) = manager.pools.get(&device_key) {
132 return Ok(pool);
133 }
134
135 let pool_size = VkDescriptorPoolSize {
137 type_: VkDescriptorType::StorageBuffer,
138 descriptorCount: max_descriptors,
139 };
140
141 let create_info = VkDescriptorPoolCreateInfo {
142 sType: VkStructureType::DescriptorPoolCreateInfo,
143 pNext: std::ptr::null(),
144 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
145 maxSets: max_sets,
146 poolSizeCount: 1,
147 pPoolSizes: &pool_size,
148 };
149
150 if let Some(icd) = super::icd_loader::get_icd() {
152 if let Some(create_fn) = icd.create_descriptor_pool {
153 let mut pool = VkDescriptorPool::NULL;
154 let result = create_fn(device, &create_info, std::ptr::null(), &mut pool);
155
156 if result == VkResult::Success {
157 manager.pools.insert(device_key, pool);
158 return Ok(pool);
159 }
160 return Err(IcdError::VulkanError(result));
161 }
162 }
163
164 Err(IcdError::MissingFunction("vkCreateDescriptorPool"))
165}
166
167pub unsafe fn get_persistent_descriptor_set(
179 device: VkDevice,
180 buffers: &[VkBuffer],
181) -> Result<VkDescriptorSet, IcdError> {
182 let mut manager = DESCRIPTOR_MANAGER.lock()?;
183
184 let cache_key = buffers.iter()
186 .map(|b| b.as_raw())
187 .fold(0u64, |acc, h| acc.wrapping_add(h).rotate_left(7));
188
189 if let Some(descriptor) = manager.descriptors.get(&cache_key) {
191 if descriptor.buffers == buffers {
192 return Ok(descriptor.descriptor_set);
193 }
194 }
195
196 let layout = create_persistent_layout(device, buffers.len() as u32)?;
198 let pool = get_persistent_pool(device, 1000, 10000)?;
199
200 let alloc_info = VkDescriptorSetAllocateInfo {
202 sType: VkStructureType::DescriptorSetAllocateInfo,
203 pNext: std::ptr::null(),
204 descriptorPool: pool,
205 descriptorSetCount: 1,
206 pSetLayouts: &layout,
207 };
208
209 let mut descriptor_set = VkDescriptorSet::NULL;
210
211 if let Some(icd) = super::icd_loader::get_icd() {
212 if let Some(alloc_fn) = icd.allocate_descriptor_sets {
213 let result = alloc_fn(device, &alloc_info, &mut descriptor_set);
214 if result != VkResult::Success {
215 return Err(IcdError::VulkanError(result));
216 }
217 } else {
218 return Err(IcdError::MissingFunction("vkAllocateDescriptorSets"));
219 }
220 } else {
221 return Err(IcdError::NoIcdLoaded);
222 }
223
224 let mut buffer_infos = Vec::with_capacity(buffers.len());
226 let mut writes = Vec::with_capacity(buffers.len());
227
228 for (_i, &buffer) in buffers.iter().enumerate() {
229 buffer_infos.push(VkDescriptorBufferInfo {
230 buffer,
231 offset: 0,
232 range: VK_WHOLE_SIZE,
233 });
234 }
235
236 for (i, buffer_info) in buffer_infos.iter().enumerate() {
237 writes.push(VkWriteDescriptorSet {
238 sType: VkStructureType::WriteDescriptorSet,
239 pNext: std::ptr::null(),
240 dstSet: descriptor_set,
241 dstBinding: i as u32,
242 dstArrayElement: 0,
243 descriptorCount: 1,
244 descriptorType: VkDescriptorType::StorageBuffer,
245 pImageInfo: std::ptr::null(),
246 pBufferInfo: buffer_info,
247 pTexelBufferView: std::ptr::null(),
248 });
249 }
250
251 if let Some(icd) = super::icd_loader::get_icd() {
252 if let Some(update_fn) = icd.update_descriptor_sets {
253 update_fn(device, writes.len() as u32, writes.as_ptr(), 0, std::ptr::null());
254 }
255 }
256
257 manager.generation += 1;
259 let generation = manager.generation;
260 manager.descriptors.insert(cache_key, PersistentDescriptor {
261 descriptor_set,
262 buffers: buffers.to_vec(),
263 generation,
264 });
265
266 Ok(descriptor_set)
267}
268
269pub fn create_push_constant_range(size: u32) -> VkPushConstantRange {
271 assert!(size <= MAX_PUSH_CONSTANT_SIZE, "Push constant size {} exceeds limit {}", size, MAX_PUSH_CONSTANT_SIZE);
272
273 VkPushConstantRange {
274 stageFlags: VkShaderStageFlags::COMPUTE,
275 offset: 0,
276 size,
277 }
278}
279
280pub unsafe fn create_compute_pipeline_layout(
292 device: VkDevice,
293 set0_binding_count: u32,
294 push_constant_size: u32,
295) -> Result<VkPipelineLayout, IcdError> {
296 let set0_layout = create_persistent_layout(device, set0_binding_count)?;
297
298 let mut create_info = VkPipelineLayoutCreateInfo {
299 sType: VkStructureType::PipelineLayoutCreateInfo,
300 pNext: std::ptr::null(),
301 flags: 0,
302 setLayoutCount: 1,
303 pSetLayouts: &set0_layout,
304 pushConstantRangeCount: 0,
305 pPushConstantRanges: std::ptr::null(),
306 };
307
308 let push_range = if push_constant_size > 0 {
309 Some(create_push_constant_range(push_constant_size))
310 } else {
311 None
312 };
313
314 if let Some(ref range) = push_range {
315 create_info.pushConstantRangeCount = 1;
316 create_info.pPushConstantRanges = range;
317 }
318
319 let mut layout = VkPipelineLayout::NULL;
320
321 if let Some(icd) = super::icd_loader::get_icd() {
322 if let Some(create_fn) = icd.create_pipeline_layout {
323 let result = create_fn(device, &create_info, std::ptr::null(), &mut layout);
324 if result == VkResult::Success {
325 return Ok(layout);
326 }
327 return Err(IcdError::VulkanError(result));
328 }
329 }
330
331 Err(IcdError::MissingFunction("vkCreatePipelineLayout"))
332}
333
334pub unsafe fn cleanup_persistent_descriptors(device: VkDevice) -> Result<(), IcdError> {
346 let mut manager = DESCRIPTOR_MANAGER.lock()?;
347 let device_key = device.as_raw();
348
349 if let Some(pool) = manager.pools.remove(&device_key) {
351 if let Some(icd) = super::icd_loader::get_icd() {
352 if let Some(destroy_fn) = icd.destroy_descriptor_pool {
353 destroy_fn(device, pool, std::ptr::null());
354 }
355 }
356 }
357
358 if let Some(layout) = manager.set0_layout.remove(&device_key) {
360 if let Some(icd) = super::icd_loader::get_icd() {
361 if let Some(destroy_fn) = icd.destroy_descriptor_set_layout {
362 destroy_fn(device, layout, std::ptr::null());
363 }
364 }
365 }
366
367 manager.descriptors.retain(|_, desc| {
369 desc.generation > 0 });
372
373 Ok(())
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_push_constant_range() {
        let range = create_push_constant_range(64);
        assert_eq!(range.stageFlags, VkShaderStageFlags::COMPUTE);
        assert_eq!(range.offset, 0);
        assert_eq!(range.size, 64);
    }

    /// The limit is inclusive: exactly `MAX_PUSH_CONSTANT_SIZE` is accepted.
    #[test]
    fn test_push_constant_range_at_limit() {
        let range = create_push_constant_range(MAX_PUSH_CONSTANT_SIZE);
        assert_eq!(range.size, MAX_PUSH_CONSTANT_SIZE);
    }

    #[test]
    #[should_panic(expected = "exceeds limit")]
    fn test_push_constant_size_limit() {
        create_push_constant_range(256);
    }

    /// One byte over the limit must already be rejected.
    #[test]
    #[should_panic(expected = "exceeds limit")]
    fn test_push_constant_size_just_over_limit() {
        create_push_constant_range(MAX_PUSH_CONSTANT_SIZE + 1);
    }
}