rivi_loader/
lib.rs

1mod lib_test;
2
3use std::{error::Error, fmt, sync::RwLock};
4
5use ash::vk;
6use gpu_allocator::vulkan::{Allocation, AllocationCreateDesc, Allocator, AllocatorCreateDesc};
7use rayon::prelude::*;
8
/// NUL-terminated name of the Khronos validation layer.
const LAYER_VALIDATION: *const std::os::raw::c_char =
    b"VK_LAYER_KHRONOS_validation\0".as_ptr() as *const std::os::raw::c_char;
/// NUL-terminated name of the LunarG API dump layer.
const LAYER_DEBUG: *const std::os::raw::c_char =
    b"VK_LAYER_LUNARG_api_dump\0".as_ptr() as *const std::os::raw::c_char;

/// NUL-terminated name of the variable pointers device extension.
const EXT_VARIABLE_POINTERS: *const std::os::raw::c_char =
    b"VK_KHR_variable_pointers\0".as_ptr() as *const std::os::raw::c_char;
/// NUL-terminated name of the get_memory_requirements2 device extension.
const EXT_GET_MEMORY_REQUIREMENTS2: *const std::os::raw::c_char =
    b"VK_KHR_get_memory_requirements2\0".as_ptr() as *const std::os::raw::c_char;
/// NUL-terminated name of the dedicated allocation device extension.
const EXT_DEDICATED_ALLOCATION: *const std::os::raw::c_char =
    b"VK_KHR_dedicated_allocation\0".as_ptr() as *const std::os::raw::c_char;
/// NUL-terminated name of the portability subset device extension.
const EXT_PORTABILITY_SUBSET: *const std::os::raw::c_char =
    b"VK_KHR_portability_subset\0".as_ptr() as *const std::os::raw::c_char;
16
17pub fn new(
18    debug: DebugOption
19) -> Result<Vulkan, Box<dyn Error>> {
20    let vk = Vulkan::new(debug)?;
21    Ok(vk)
22}
23
24pub struct Vulkan {
25    entry: ash::Entry, // Needs to outlive Instance and Devices.
26    instance: ash::Instance, // Needs to outlive Devices.
27    debug_layer: Option<DebugLayer>,
28    compute: Option<Vec<Compute>>,
29}
30
31impl Vulkan {
32
33    fn new(
34        debug: DebugOption
35    ) -> Result<Self, Box<dyn Error>> {
36
37        let vk_layers = match debug {
38            DebugOption::None => vec![],
39            DebugOption::Validation => vec![LAYER_VALIDATION],
40            DebugOption::Debug => vec![LAYER_VALIDATION, LAYER_DEBUG],
41        };
42
43        let mut info = match debug {
44            DebugOption::None => vk::DebugUtilsMessengerCreateInfoEXT {
45                ..Default::default()
46            },
47            _ => {
48                vk::DebugUtilsMessengerCreateInfoEXT {
49                    message_severity: vk::DebugUtilsMessageSeverityFlagsEXT::WARNING
50                        | vk::DebugUtilsMessageSeverityFlagsEXT::VERBOSE
51                        | vk::DebugUtilsMessageSeverityFlagsEXT::INFO
52                        | vk::DebugUtilsMessageSeverityFlagsEXT::ERROR,
53                    message_type: vk::DebugUtilsMessageTypeFlagsEXT::GENERAL
54                        | vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE
55                        | vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION,
56                    pfn_user_callback: Some(DebugLayer::callback),
57                    ..Default::default()
58                }
59            }
60        };
61
62        let entry = unsafe { ash::Entry::load()? };
63
64        let instance = unsafe {
65            entry.create_instance(&vk::InstanceCreateInfo::builder()
66                .push_next(&mut info)
67                .application_info(&vk::ApplicationInfo {
68                    api_version: vk::make_api_version(0, 1, 2, 0),
69                    engine_version: 0,
70                    ..Default::default()
71                })
72                .enabled_layer_names(&vk_layers)
73                .enabled_extension_names(&[ash::extensions::ext::DebugUtils::name().as_ptr()])
74            , None)?
75        };
76
77        let debug_layer = match debug {
78            DebugOption::None => None,
79            _ => {
80                let loader = ash::extensions::ext::DebugUtils::new(&entry, &instance);
81                let messenger = unsafe { loader.create_debug_utils_messenger(&info, None)? };
82                Some(DebugLayer{loader, messenger})
83            },
84        };
85
86        let computes = Self::logical_devices(&instance)?;
87        let compute = match computes.len() {
88            0 => None,
89            _ => Some(computes),
90        };
91
92        Ok(Self{entry, instance, debug_layer, compute})
93    }
94
95    pub fn version(
96        &self
97    ) -> Result<(u32, u32, u32), Box<dyn Error>>  {
98        match self.entry.try_enumerate_instance_version()? {
99            Some(v) => Ok((vk::api_version_major(v), vk::api_version_minor(v), vk::api_version_patch(v))),
100            None => Ok((vk::api_version_major(1), vk::api_version_minor(0), vk::api_version_patch(0))),
101        }
102    }
103
104    fn logical_devices(
105        instance: &ash::Instance,
106    ) -> Result<Vec<Compute>, Box<dyn Error>> {
107        let pdevices = unsafe { instance.enumerate_physical_devices()? };
108        Ok(pdevices.into_iter()
109            .filter(|pdevice| {
110                let (_, properties) = Self::device_name(instance, *pdevice);
111                let sp = Self::subgroup_properties(instance, *pdevice);
112                properties.device_type.ne(&vk::PhysicalDeviceType::CPU)
113                && sp.supported_stages.contains(vk::ShaderStageFlags::COMPUTE)
114            })
115            .map(|pdevice| {
116                let device = Self::create_device(instance, pdevice)?;
117                let queue_infos = unsafe { Self::queue_infos(instance, pdevice) };
118                let fences = Self::create_fences(&device, queue_infos)?;
119                let allocator = Allocator::new(&AllocatorCreateDesc {
120                    physical_device: pdevice,
121                    device: device.clone(),
122                    instance: instance.clone(),
123                    debug_settings: Default::default(),
124                    buffer_device_address: false,
125                })?;
126                let memory = unsafe { instance.get_physical_device_memory_properties(pdevice) };
127
128                Ok(Compute { device, allocator: Some(RwLock::new(allocator)), fences, memory})
129
130            })
131            .collect::<Result<Vec<Compute>, Box<dyn Error>>>()?.into_iter()
132            .filter(|c| !c.fences.is_empty())
133            .collect::<Vec<Compute>>())
134    }
135
136    unsafe fn queue_infos(
137        instance: &ash::Instance,
138        pdevice: vk::PhysicalDevice,
139    ) -> Vec<(usize, Vec<f32>)> {
140        instance.get_physical_device_queue_family_properties(pdevice).iter().enumerate()
141            .filter(|(_, prop)| prop.queue_flags.contains(vk::QueueFlags::COMPUTE))
142            .map(|(idx, prop)| (idx, vec![1.0_f32; prop.queue_count as usize]))
143            .collect()
144    }
145
146    fn device_name(
147        instance: &ash::Instance,
148        pdevice: vk::PhysicalDevice,
149    ) -> (String, vk::PhysicalDeviceProperties)  {
150        let mut dp2 = vk::PhysicalDeviceProperties2::builder().build();
151        unsafe { instance.fp_v1_1().get_physical_device_properties2(pdevice, &mut dp2) };
152        let device_name = dp2.properties.device_name.iter()
153            .filter_map(|f| match *f as u8 {
154                0 => None,
155                _ => Some(*f as u8 as char),
156            })
157            .collect::<String>();
158        (device_name, dp2.properties)
159    }
160
161    fn subgroup_properties(
162        instance: &ash::Instance,
163        pdevice: vk::PhysicalDevice,
164    ) -> vk::PhysicalDeviceSubgroupProperties {
165        // Retrieving Subgroup operations will segfault a Mac
166        // https://www.khronos.org/blog/vulkan-subgroup-tutorial
167        let mut sp = vk::PhysicalDeviceSubgroupProperties::builder();
168        let mut dp2 = vk::PhysicalDeviceProperties2::builder().push_next(&mut sp).build();
169        unsafe { instance.fp_v1_1().get_physical_device_properties2(pdevice, &mut dp2); }
170        sp.build()
171    }
172
173    fn create_fences(
174        device: &ash::Device,
175        queue_infos: Vec<(usize, Vec<f32>)>,
176    ) -> Result<Vec<Fence>, Box<dyn Error>> {
177        Ok(queue_infos.into_iter().flat_map(|(phy_index, queue_priorities)| {
178            (0..queue_priorities.len()).into_iter().map(|queue_index| {
179                let vk_fence = unsafe { device.create_fence(&vk::FenceCreateInfo::default(), None)? };
180                let present_queue = unsafe { device.get_device_queue(phy_index as u32, queue_index as u32) };
181                Ok(Fence{ fence: vk_fence, present_queue, phy_index: phy_index as u32 })
182            })
183            .collect::<Result<Vec<Fence>, Box<dyn Error>>>().into_iter()
184            .flatten()
185        })
186        .collect())
187    }
188
189    fn create_device(
190        instance: &ash::Instance,
191        pdevice: vk::PhysicalDevice,
192    ) -> Result<ash::Device, vk::Result> {
193
194        let features = vk::PhysicalDeviceFeatures {
195            ..Default::default()
196        };
197
198        let mut variable_pointers = vk::PhysicalDeviceVariablePointersFeatures::builder()
199            .variable_pointers(true)
200            .variable_pointers_storage_buffer(true)
201            .build();
202
203        let mut ext_names: Vec<_> = vec![
204            EXT_VARIABLE_POINTERS,
205            EXT_GET_MEMORY_REQUIREMENTS2,
206            EXT_DEDICATED_ALLOCATION,
207        ];
208
209        if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") {
210            ext_names.push(EXT_PORTABILITY_SUBSET);
211        }
212
213        // See: https://github.com/MaikKlein/ash/issues/539
214        let priorities = unsafe { Self::queue_infos(instance, pdevice) }.into_iter().map(|f| f.1).collect::<Vec<_>>();
215        let queue_create_infos = unsafe { Self::queue_infos(instance, pdevice) }.into_iter().enumerate().map(|(idx, (phy_index, _))| {
216            vk::DeviceQueueCreateInfo::builder()
217                .queue_family_index(phy_index as u32)
218                .queue_priorities(&priorities[idx])
219                .build()
220        })
221        .collect::<Vec<vk::DeviceQueueCreateInfo>>();
222
223        let device_info = vk::DeviceCreateInfo::builder()
224            .queue_create_infos(&queue_create_infos)
225            .enabled_extension_names(&ext_names)
226            .enabled_features(&features)
227            .push_next(&mut variable_pointers);
228
229        unsafe { instance.create_device(pdevice, &device_info, None) }
230    }
231
232    pub fn load_shader<R: std::io::Read + std::io::Seek>(
233        &self,
234        x: &mut R,
235        specializations: Option<Vec<Vec<u8>>>,
236    ) -> Result<Shader<'_>, Box<dyn Error>> {
237        let binary = ash::util::read_spv(x)?;
238        let bindings = Shader::module(&binary).map(|module| Shader::descriptor_set_layout_bindings(Shader::binding_count(&module)))?;
239        match &self.compute {
240            Some(c) => {
241                let shaders = c.iter()
242                    .enumerate()
243                    .map(|(idx, f)| {
244                        match &specializations {
245                            Some(specs) => {
246                                let maps = (0..specs.len())
247                                    .into_iter()
248                                    .map(|id| {
249                                        vk::SpecializationMapEntry::builder()
250                                            .constant_id(id as u32)
251                                            .offset(0)
252                                            .size(1)
253                                            .build()
254                                    })
255                                    .collect::<Vec<_>>();
256                                let spec = vk::SpecializationInfo::builder()
257                                    .data(specs.get(idx).unwrap())
258                                    .map_entries(&maps);
259                                Shader::create(&f.device, &bindings, &binary, spec)
260                            },
261                            None => {
262                                let spec = vk::SpecializationInfo::builder();
263                                Shader::create(&f.device, &bindings, &binary, spec)
264                            },
265                        }
266                    })
267                    .collect::<Result<Vec<Shader<'_>>, Box<dyn Error>>>()?;
268                match shaders.into_iter().next() {
269                    Some(s) => Ok(s),
270                    None => Err("No compute capable devices".to_string().into()),
271                }
272            }
273            None => Err("No compute capable devices".to_string().into()),
274        }
275    }
276
277    pub fn compute<T: std::marker::Sync>(
278        &self,
279        input: &[Vec<Vec<T>>],
280        output: &mut [T],
281        shader: &Shader<'_>,
282    ) -> Result<(), Box<dyn Error + Send + Sync>> {
283        match &self.compute {
284            Some(c) => c.first().unwrap().execute(input, output, shader),
285            None => Err("No compute capable devices".to_string().into()),
286        }
287    }
288
289    pub fn threads(
290        &self
291    ) -> usize {
292        match &self.compute {
293            Some(c) => c.iter().map(|d| d.fences.len()).sum::<usize>(),
294            None => 0,
295        }
296    }
297}
298
299impl fmt::Display for Vulkan {
300    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
301        writeln!(f, "cpu_logical_cores: {}", std::thread::available_parallelism().unwrap().get());
302        let pdevices = unsafe { self.instance.enumerate_physical_devices().unwrap() };
303        writeln!(f, "f32_size: {}", std::mem::size_of::<f32>());
304        writeln!(f, "gpu_device_count: {}", pdevices.len());
305        pdevices.into_iter()
306            .filter(|pdevice| {
307                let (_, properties) = Self::device_name(&self.instance, *pdevice);
308                let sp = Self::subgroup_properties(&self.instance, *pdevice);
309                properties.device_type.ne(&vk::PhysicalDeviceType::CPU)
310                && sp.supported_stages.contains(vk::ShaderStageFlags::COMPUTE)
311            })
312            .for_each(|pdevice| {
313
314                let (device_name, properties) = Self::device_name(&self.instance, pdevice);
315                writeln!(f, "name: {}", device_name);
316                writeln!(f, "type: {:?}", properties.device_type);
317
318                let queue_infos = unsafe { Self::queue_infos(&self.instance, pdevice) };
319                writeln!(f, "queue_size: {:?}", queue_infos.len());
320                let queue_str = queue_infos.iter()
321                    .map(|f| {
322                        format!("{} {}", f.0, f.1.len())
323                    })
324                    .collect::<Vec<String>>();
325                writeln!(f, "queues: {:?}", queue_str);
326                let memory = unsafe { self.instance.get_physical_device_memory_properties(pdevice) };
327
328                let sp = Self::subgroup_properties(&self.instance, pdevice);
329                writeln!(f, "subgroup_size: {:?}", sp.subgroup_size);
330                writeln!(f, "subgroup_operations: {:?}", sp.supported_operations);
331
332                writeln!(f, "memory_heap_count: {}", memory.memory_heap_count);
333                let mem_str = memory.memory_heaps.iter()
334                    .filter(|mh| mh.size.ne(&0))
335                    .enumerate()
336                    .map(|(idx, mh)| {
337                        format!("{} {}", idx, mh.size / 1_073_741_824)
338                    })
339                    .collect::<Vec<String>>();
340                writeln!(f, "memory_heaps: {:?}", mem_str);
341
342            });
343        write!(f, "")
344    }
345}
346
347impl Drop for Vulkan {
348    fn drop(&mut self) {
349        self.compute = None;
350        self.debug_layer = None;
351        unsafe { self.instance.destroy_instance(None) }
352    }
353}
354
/// Selects which Vulkan debug layers are enabled when the context is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DebugOption {
    /// No layers and no debug messenger.
    None,
    /// Enables the Khronos validation layer.
    Validation,
    /// Enables the Khronos validation layer and the LunarG API dump layer.
    Debug,
}
360
361struct DebugLayer {
362    loader: ash::extensions::ext::DebugUtils,
363    messenger: vk::DebugUtilsMessengerEXT,
364}
365
366impl DebugLayer {
367    extern "system" fn callback(
368        message_severity: vk::DebugUtilsMessageSeverityFlagsEXT,
369        message_type: vk::DebugUtilsMessageTypeFlagsEXT,
370        p_callback_data: *const vk::DebugUtilsMessengerCallbackDataEXT,
371        _p_user_data: *mut std::ffi::c_void,
372    ) -> vk::Bool32 {
373        let message = unsafe { std::ffi::CStr::from_ptr((*p_callback_data).p_message) };
374        let severity = format!("{:?}", message_severity).to_lowercase();
375        let ty = format!("{:?}", message_type).to_lowercase();
376        println!("[Debug][{}][{}] {:?}", severity, ty, message);
377        vk::FALSE
378    }
379}
380
381impl Drop for DebugLayer {
382    fn drop(&mut self) {
383        unsafe { self.loader.destroy_debug_utils_messenger(self.messenger, None) }
384    }
385}
386
387pub struct Shader<'a> {
388    module: vk::ShaderModule,
389    pipeline_layout: vk::PipelineLayout,
390    pipeline: vk::Pipeline,
391    set_layouts: Vec<vk::DescriptorSetLayout>,
392    binding_count: u32,
393
394    device: &'a ash::Device,
395}
396
397impl <'a> Shader<'_> {
398
399    fn module(
400        binary: &[u32]
401    ) -> Result<rspirv::dr::Module, Box<dyn Error>> {
402        let mut loader = rspirv::dr::Loader::new();
403        rspirv::binary::parse_words(binary, &mut loader)?;
404        Ok(loader.module())
405    }
406
407    fn binding_count(
408        module: &rspirv::dr::Module
409    ) -> usize {
410        module.annotations.iter()
411            .flat_map(|f| f.operands.clone())
412            .filter(|op| op.eq(&rspirv::dr::Operand::Decoration(rspirv::spirv::Decoration::Binding)))
413            .count()
414    }
415
416    fn specialization(
417        module: &rspirv::dr::Module
418    ) -> usize {
419        module.annotations.iter()
420            .flat_map(|f| f.operands.clone())
421            .filter(|op| op.eq(&rspirv::dr::Operand::Decoration(rspirv::spirv::Decoration::SpecId)))
422            .count()
423    }
424
425    fn descriptor_set_layout_bindings(
426        binding_count: usize
427    ) -> Vec<vk::DescriptorSetLayoutBinding> {
428        (0..binding_count).into_iter().map(|i|
429            vk::DescriptorSetLayoutBinding::builder()
430                .binding(i as u32)
431                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
432                .descriptor_count(1)
433                .stage_flags(vk::ShaderStageFlags::COMPUTE)
434                .build()
435        )
436        .collect()
437    }
438
439    fn create(
440        device: &'a ash::Device,
441        bindings: &[vk::DescriptorSetLayoutBinding],
442        binary: &[u32],
443        spec: vk::SpecializationInfoBuilder,
444    ) -> Result<Shader<'a>, Box<dyn Error>> {
445        let set_layouts = unsafe { device.create_descriptor_set_layout(
446            &vk::DescriptorSetLayoutCreateInfo::builder().bindings(bindings),
447            None,
448        ) }.map(|set_layout| vec![set_layout])?;
449
450        let pipeline_layout = unsafe { device.create_pipeline_layout(
451            &vk::PipelineLayoutCreateInfo::builder().set_layouts(&set_layouts),
452            None,
453        )? };
454
455        let module = unsafe { device.create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(binary), None)? };
456        let stage = vk::PipelineShaderStageCreateInfo::builder()
457            // According to https://raphlinus.github.io/gpu/2020/04/30/prefix-sum.html
458            // "Another problem is querying the subgroup size from inside the kernel, which has a
459            // surprising gotcha. Unless the VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT
460            // flag is set at pipeline creation time, the gl_SubgroupSize variable is defined to have
461            // the value from VkPhysicalDeviceSubgroupProperties, which in my experiment is always 32 on
462            // Intel no matter the actual subgroup size. But setting that flag makes it give the value expected."
463            .flags(vk::PipelineShaderStageCreateFlags::ALLOW_VARYING_SUBGROUP_SIZE_EXT|vk::PipelineShaderStageCreateFlags::REQUIRE_FULL_SUBGROUPS_EXT)
464            .module(module)
465            .name(std::ffi::CStr::from_bytes_with_nul(b"main\0")?)
466            .specialization_info(&spec)
467            .stage(vk::ShaderStageFlags::COMPUTE);
468
469        let pipeline = unsafe { device.create_compute_pipelines(
470            vk::PipelineCache::null(),
471            &[vk::ComputePipelineCreateInfo::builder().stage(stage.build()).layout(pipeline_layout).build()],
472            None,
473        ) }.map(|pipelines| pipelines[0]).map_err(|(_, err)| err)?;
474
475        Ok(Shader{module, pipeline_layout, pipeline, set_layouts, device, binding_count: bindings.len() as u32})
476    }
477}
478
479impl <'a> Drop for Shader<'a> {
480    fn drop(&mut self) {
481        unsafe { self.device.destroy_pipeline_layout(self.pipeline_layout, None) };
482        unsafe { self.device.destroy_shader_module(self.module, None) };
483        for set_layout in self.set_layouts.iter().copied() {
484            unsafe { self.device.destroy_descriptor_set_layout(set_layout, None) };
485        }
486        unsafe { self.device.destroy_pipeline(self.pipeline, None) };
487    }
488}
489
490struct Fence {
491    fence: vk::Fence,
492    present_queue: vk::Queue,
493    phy_index: u32,
494}
495
496struct Compute {
497    device: ash::Device,
498    allocator: Option<RwLock<Allocator>>,
499    fences: Vec<Fence>,
500    memory: vk::PhysicalDeviceMemoryProperties,
501}
502
503impl fmt::Debug for Compute {
504    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505
506        println!("Memory types: {}", self.memory.memory_type_count);
507        self.memory.memory_types.iter()
508            .filter(|mt| !mt.property_flags.is_empty())
509            .enumerate()
510            .for_each(|(idx, mt)| {
511                println!("Index {} {:?} (heap {})", idx, mt.property_flags, mt.heap_index);
512            });
513
514        println!("Memory heaps: {}", self.memory.memory_heap_count);
515        self.memory.memory_heaps.iter()
516            .filter(|mh| mh.size.ne(&0))
517            .enumerate()
518            .for_each(|(idx, mh)| {
519                println!("{:?} GiB {:?} (heap {})", mh.size / 1_073_741_824, mh.flags, idx);
520            });
521
522        f.write_fmt(format_args!("  Found {} compute core(s) with {} total of thread(s)", self.cores().len(), self.fences.len()))
523    }
524}
525
526impl Compute {
527
528    fn cores(
529        &self
530    ) -> Vec<u32> {
531        self.fences.iter().fold(vec![], |mut acc, f| {
532            if !acc.contains(&f.phy_index) {
533                acc.push(f.phy_index);
534            }
535            acc
536        })
537    }
538
539    // task must return a result vector to avoid
540    // rust ownership system to delete it before
541    // it is used by vulkan
542    fn task<T>(
543        &self,
544        descriptor_set: vk::DescriptorSet,
545        command_buffer: vk::CommandBuffer,
546        shader: &Shader<'_>,
547        output: vk::Buffer,
548        input: &[Vec<T>],
549        memory_mapping: (vk::DeviceSize, vk::DeviceSize),
550    ) -> Result<Vec<Buffer<'_, '_>>, Box<dyn Error + Send + Sync>> {
551
552        let input_buffers = input.iter().map(|data| {
553            let buffer = Buffer::new(
554                "cpu input",
555                &self.device,
556                &self.allocator,
557                (data.len() * std::mem::size_of::<T>()) as vk::DeviceSize,
558                vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::STORAGE_BUFFER,
559                gpu_allocator::MemoryLocation::CpuToGpu,
560                &self.cores(),
561            )?.fill(data)?;
562            Ok(buffer)
563        })
564        .collect::<Result<Vec<Buffer<'_, '_>>, Box<dyn Error + Send + Sync>>>()?;
565
566        let buffer_infos = (0..=input_buffers.len()).into_iter()
567            .map(|f| match f {
568                0 => [vk::DescriptorBufferInfo::builder()
569                    .buffer(output)
570                    .offset(memory_mapping.0)
571                    .range(memory_mapping.1)
572                    .build()],
573                _ => [vk::DescriptorBufferInfo::builder()
574                    .buffer(input_buffers[f-1].buffer)
575                    .offset(0)
576                    .range(vk::WHOLE_SIZE)
577                    .build()],
578            })
579            .collect::<Vec<[vk::DescriptorBufferInfo; 1]>>();
580
581        let wds = buffer_infos.iter().enumerate()
582            .map(|(i, buf)| {
583                vk::WriteDescriptorSet::builder()
584                    .dst_set(descriptor_set)
585                    .dst_binding(i as u32)
586                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
587                    .buffer_info(buf)
588                    .build()
589            })
590            .collect::<Vec<vk::WriteDescriptorSet>>();
591
592        unsafe {
593            self.device.update_descriptor_sets(&wds, &[]);
594            self.device.begin_command_buffer(command_buffer, &vk::CommandBufferBeginInfo::builder().flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT))?;
595            self.device.cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, shader.pipeline);
596            self.device.cmd_bind_descriptor_sets(command_buffer, vk::PipelineBindPoint::COMPUTE, shader.pipeline_layout, 0, &[descriptor_set], &[]);
597            self.device.cmd_dispatch(command_buffer, 1024, 1, 1);
598            self.device.end_command_buffer(command_buffer)?;
599        }
600
601        Ok(input_buffers)
602    }
603
604    fn execute<T: std::marker::Sync>(
605        &self,
606        input: &[Vec<Vec<T>>],
607        output: &mut [T],
608        shader: &Shader<'_>,
609    ) -> Result<(), Box<dyn Error + Send + Sync>> {
610
611        let output_buffer = Buffer::new(
612            "output buffer",
613            &self.device,
614            &self.allocator,
615            (output.len() * std::mem::size_of::<T>()) as vk::DeviceSize,
616            vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::STORAGE_BUFFER,
617            gpu_allocator::MemoryLocation::GpuToCpu,
618            &self.cores(),
619        )?;
620        let output_chunk_size = output_buffer.device_size / input.len() as vk::DeviceSize;
621
622        let command = Command::new(
623            0,
624            shader.binding_count,
625            &shader.set_layouts,
626            input.len() as u32,
627            &self.device,
628        )?;
629
630        command.descriptor_sets.iter()
631            .zip(command.command_buffers.iter())
632            .zip(input.iter())
633            .enumerate().map(|(idx, cmd)| {
634                self.task(
635                    *cmd.0.0,
636                    *cmd.0.1,
637                    shader,
638                    output_buffer.buffer,
639                    cmd.1,
640                    (output_chunk_size * idx as vk::DeviceSize, output_chunk_size),
641                )
642        })
643        .collect::<Result<Vec<_>, _>>()
644        .and_then(|_| unsafe {
645            let submits = [vk::SubmitInfo::builder().command_buffers(&command.command_buffers).build()];
646            self.device.queue_submit(self.fences[0].present_queue, &submits, self.fences[0].fence)?;
647            self.device.wait_for_fences(&[self.fences[0].fence], true, u64::MAX)?;
648            self.device.reset_fences(&[self.fences[0].fence])?;
649            Ok(())
650        })?;
651
652        let data_ptr = output_buffer.c_ptr.as_ptr().cast::<T>();
653        unsafe { data_ptr.copy_to_nonoverlapping(output.as_mut_ptr(), output.len()) };
654
655        Ok(())
656    }
657}
658
659impl Drop for Compute {
660    fn drop(&mut self) {
661        unsafe { self.device.device_wait_idle().unwrap() }
662        for fence in &self.fences {
663            unsafe { self.device.destroy_fence(fence.fence, None) }
664        }
665        self.allocator = None;
666        unsafe { self.device.destroy_device(None) }
667    }
668}
669
670struct Command<'a> {
671    descriptor_pool: vk::DescriptorPool,
672    command_pool: vk::CommandPool,
673    command_buffers: Vec<vk::CommandBuffer>,
674    descriptor_sets: Vec<vk::DescriptorSet>,
675
676    device: &'a ash::Device,
677}
678
679impl <'a> Command<'_> {
680
681    fn descriptor_pool(
682        device: &ash::Device,
683        descriptor_count: u32,
684        max_sets: u32,
685    ) -> Result<vk::DescriptorPool, vk::Result> {
686        let descriptor_pool_size = [vk::DescriptorPoolSize::builder()
687            .ty(vk::DescriptorType::STORAGE_BUFFER)
688            .descriptor_count(descriptor_count)
689            .build()];
690        let descriptor_pool_info = vk::DescriptorPoolCreateInfo::builder()
691            .max_sets(max_sets)
692            .pool_sizes(&descriptor_pool_size);
693        unsafe { device.create_descriptor_pool(&descriptor_pool_info, None) }
694    }
695
696    fn command_pool(
697        device: &ash::Device,
698        queue_family_index: u32,
699    ) -> Result<vk::CommandPool, vk::Result> {
700        let command_pool_info = vk::CommandPoolCreateInfo::builder()
701            .queue_family_index(queue_family_index);
702        unsafe { device.create_command_pool(&command_pool_info, None) }
703    }
704
705    fn allocate_command_buffers(
706        device: &ash::Device,
707        command_pool: vk::CommandPool,
708        command_buffer_count: u32,
709    ) -> Result<Vec<vk::CommandBuffer>, vk::Result> {
710        let command_buffers_info = vk::CommandBufferAllocateInfo::builder()
711            .command_buffer_count(command_buffer_count)
712            .command_pool(command_pool);
713        unsafe { device.allocate_command_buffers(&command_buffers_info) }
714    }
715
716    fn new(
717        queue_family_index: u32,
718        descriptor_count: u32,
719        set_layouts: &[vk::DescriptorSetLayout],
720        command_buffer_count: u32,
721        device: &'a ash::Device,
722    ) -> Result<Command<'a>, Box<dyn Error + Send + Sync>> {
723
724        let descriptor_pool = Command::descriptor_pool(device, descriptor_count, command_buffer_count)?;
725
726        let descriptor_set_info = vk::DescriptorSetAllocateInfo::builder()
727            .descriptor_pool(descriptor_pool)
728            .set_layouts(set_layouts);
729
730        let descriptor_sets = (0..command_buffer_count).into_iter().flat_map(|_| {
731            unsafe { device.allocate_descriptor_sets(&descriptor_set_info) }.map(|sets| sets[0])
732        }).collect();
733
734        let command_pool = Command::command_pool(device, queue_family_index)?;
735        let command_buffers = Command::allocate_command_buffers(device, command_pool, command_buffer_count)?;
736
737        Ok(Command { descriptor_pool, command_pool, command_buffers, descriptor_sets, device })
738    }
739}
740
741impl <'a> Drop for Command<'a> {
742    fn drop(&mut self) {
743        unsafe { self.device.destroy_command_pool(self.command_pool, None) };
744        unsafe { self.device.destroy_descriptor_pool(self.descriptor_pool, None) };
745    }
746}
747
748struct Buffer<'a, 'b>  {
749    buffer: vk::Buffer,
750    allocation: Option<Allocation>,
751    device_size: vk::DeviceSize,
752    c_ptr: std::ptr::NonNull<std::ffi::c_void>,
753
754    device: &'a ash::Device,
755    allocator: &'b Option<RwLock<Allocator>>,
756}
757
758impl <'a, 'b> Buffer<'_, '_> {
759
760    fn new(
761        name: &str,
762        device: &'a ash::Device,
763        allocator: &'b Option<RwLock<Allocator>>,
764        device_size: vk::DeviceSize,
765        usage: vk::BufferUsageFlags,
766        location: gpu_allocator::MemoryLocation,
767        queue_family_indices: &[u32],
768    ) -> Result<Buffer<'a, 'b>, Box<dyn Error + Send + Sync>> {
769        let create_info = vk::BufferCreateInfo::builder()
770            .size(device_size)
771            .usage(usage)
772            .sharing_mode(match queue_family_indices.len() {
773                1 => vk::SharingMode::EXCLUSIVE,
774                _ => vk::SharingMode::CONCURRENT,
775            })
776            .queue_family_indices(queue_family_indices);
777        let buffer = unsafe { device.create_buffer(&create_info, None)? };
778        let requirements = unsafe { device.get_buffer_memory_requirements(buffer) };
779        let mut malloc = allocator.as_ref().unwrap().write().unwrap();
780        let allocation = malloc.allocate(&AllocationCreateDesc {
781            name,
782            requirements,
783            location,
784            linear: true,
785        })?;
786        let c_ptr = allocation.mapped_ptr().unwrap();
787        unsafe { device.bind_buffer_memory(buffer, allocation.memory(), allocation.offset())? };
788        Ok(Buffer { buffer, allocation: Some(allocation), c_ptr, device_size, device, allocator })
789    }
790
791    fn fill<T: Sized>(
792        self,
793        data: &[T],
794    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
795        let data_ptr = self.c_ptr.as_ptr().cast::<T>();
796        unsafe { data_ptr.copy_from_nonoverlapping(data.as_ptr(), data.len()) };
797        Ok(self)
798    }
799}
800
801impl <'a, 'b> Drop for Buffer<'a, 'b> {
802    fn drop(&mut self) {
803        let lock = self.allocator.as_ref().unwrap();
804        let mut malloc = lock.write().unwrap();
805        malloc.free(self.allocation.take().unwrap()).unwrap();
806        unsafe { self.device.destroy_buffer(self.buffer, None) };
807    }
808}