mod lib_test;

use std::{error::Error, fmt, sync::RwLock};

use ash::vk;
use gpu_allocator::vulkan::{Allocation, AllocationCreateDesc, Allocator, AllocatorCreateDesc};
use rayon::prelude::*;

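// Layer and extension names must be NUL-terminated C strings. `concat!` appends
// the terminator at compile time, and the chain of pointer casts reinterprets
// the static `str` data as `*const c_char` without a runtime CString allocation.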
const LAYER_VALIDATION: *const std::os::raw::c_char = concat!("VK_LAYER_KHRONOS_validation", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;
const LAYER_DEBUG: *const std::os::raw::c_char = concat!("VK_LAYER_LUNARG_api_dump", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;

const EXT_VARIABLE_POINTERS: *const std::os::raw::c_char = concat!("VK_KHR_variable_pointers", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;
const EXT_GET_MEMORY_REQUIREMENTS2: *const std::os::raw::c_char = concat!("VK_KHR_get_memory_requirements2", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;
const EXT_DEDICATED_ALLOCATION: *const std::os::raw::c_char = concat!("VK_KHR_dedicated_allocation", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;
const EXT_PORTABILITY_SUBSET: *const std::os::raw::c_char = concat!("VK_KHR_portability_subset", "\0") as *const str as *const [std::os::raw::c_char] as *const std::os::raw::c_char;

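/// Creates a [`Vulkan`] context covering every compute-capable device.
///
/// A minimal usage sketch (illustrative only: it assumes the crate is imported
/// as `vulkan` and that a compiled SPIR-V compute kernel exists at
/// `shader.spv`):
///
/// ```ignore
/// let vk = vulkan::new(vulkan::DebugOption::Validation)?;
/// let mut spirv = std::fs::File::open("shader.spv")?;
/// let shader = vk.load_shader(&mut spirv, None)?;
/// let input = vec![vec![vec![1.0_f32; 64]]];   // one task, one storage buffer
/// let mut output = vec![0.0_f32; 64];
/// vk.compute(&input, &mut output, &shader)?;
/// ```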
pub fn new(
    debug: DebugOption
) -> Result<Vulkan, Box<dyn Error>> {
    let vk = Vulkan::new(debug)?;
    Ok(vk)
}

pub struct Vulkan {
    entry: ash::Entry,
    instance: ash::Instance,
    debug_layer: Option<DebugLayer>,
    compute: Option<Vec<Compute>>,
}

impl Vulkan {

    fn new(
        debug: DebugOption
    ) -> Result<Self, Box<dyn Error>> {

        let vk_layers = match debug {
            DebugOption::None => vec![],
            DebugOption::Validation => vec![LAYER_VALIDATION],
            DebugOption::Debug => vec![LAYER_VALIDATION, LAYER_DEBUG],
        };

        let mut info = match debug {
            DebugOption::None => vk::DebugUtilsMessengerCreateInfoEXT {
                ..Default::default()
            },
            _ => {
                vk::DebugUtilsMessengerCreateInfoEXT {
                    message_severity: vk::DebugUtilsMessageSeverityFlagsEXT::WARNING
                        | vk::DebugUtilsMessageSeverityFlagsEXT::VERBOSE
                        | vk::DebugUtilsMessageSeverityFlagsEXT::INFO
                        | vk::DebugUtilsMessageSeverityFlagsEXT::ERROR,
                    message_type: vk::DebugUtilsMessageTypeFlagsEXT::GENERAL
                        | vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE
                        | vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION,
                    pfn_user_callback: Some(DebugLayer::callback),
                    ..Default::default()
                }
            }
        };

        let entry = unsafe { ash::Entry::load()? };

        let instance = unsafe {
            entry.create_instance(&vk::InstanceCreateInfo::builder()
                .push_next(&mut info)
                .application_info(&vk::ApplicationInfo {
                    api_version: vk::make_api_version(0, 1, 2, 0),
                    engine_version: 0,
                    ..Default::default()
                })
                .enabled_layer_names(&vk_layers)
                .enabled_extension_names(&[ash::extensions::ext::DebugUtils::name().as_ptr()])
            , None)?
        };

        let debug_layer = match debug {
            DebugOption::None => None,
            _ => {
                let loader = ash::extensions::ext::DebugUtils::new(&entry, &instance);
                let messenger = unsafe { loader.create_debug_utils_messenger(&info, None)? };
                Some(DebugLayer{loader, messenger})
            },
        };

        let computes = Self::logical_devices(&instance)?;
        let compute = match computes.len() {
            0 => None,
            _ => Some(computes),
        };

        Ok(Self{entry, instance, debug_layer, compute})
    }

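    /// Reports the instance version as `(major, minor, patch)`. Vulkan packs a
    /// version into a single `u32`, which the `vk::api_version_*` helpers decode.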
    pub fn version(
        &self
    ) -> Result<(u32, u32, u32), Box<dyn Error>> {
        match self.entry.try_enumerate_instance_version()? {
            Some(v) => Ok((vk::api_version_major(v), vk::api_version_minor(v), vk::api_version_patch(v))),
            // `None` means a Vulkan 1.0 loader, so report 1.0.0 directly; decoding
            // the raw integer `1` as before would have yielded (0, 0, 1).
            None => Ok((1, 0, 0)),
        }
    }

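    /// Builds one `Compute` per physical device, skipping CPU-type devices and
    /// devices whose subgroup operations do not cover the compute stage.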
    fn logical_devices(
        instance: &ash::Instance,
    ) -> Result<Vec<Compute>, Box<dyn Error>> {
        let pdevices = unsafe { instance.enumerate_physical_devices()? };
        Ok(pdevices.into_iter()
            .filter(|pdevice| {
                let (_, properties) = Self::device_name(instance, *pdevice);
                let sp = Self::subgroup_properties(instance, *pdevice);
                properties.device_type.ne(&vk::PhysicalDeviceType::CPU)
                    && sp.supported_stages.contains(vk::ShaderStageFlags::COMPUTE)
            })
            .map(|pdevice| {
                let device = Self::create_device(instance, pdevice)?;
                let queue_infos = unsafe { Self::queue_infos(instance, pdevice) };
                let fences = Self::create_fences(&device, queue_infos)?;
                let allocator = Allocator::new(&AllocatorCreateDesc {
                    physical_device: pdevice,
                    device: device.clone(),
                    instance: instance.clone(),
                    debug_settings: Default::default(),
                    buffer_device_address: false,
                })?;
                let memory = unsafe { instance.get_physical_device_memory_properties(pdevice) };

                Ok(Compute { device, allocator: Some(RwLock::new(allocator)), fences, memory })
            })
            .collect::<Result<Vec<Compute>, Box<dyn Error>>>()?.into_iter()
            .filter(|c| !c.fences.is_empty())
            .collect::<Vec<Compute>>())
    }

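    /// Collects one `(family_index, priorities)` pair per compute-capable queue
    /// family, requesting every queue in the family at priority 1.0.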
    unsafe fn queue_infos(
        instance: &ash::Instance,
        pdevice: vk::PhysicalDevice,
    ) -> Vec<(usize, Vec<f32>)> {
        instance.get_physical_device_queue_family_properties(pdevice).iter().enumerate()
            .filter(|(_, prop)| prop.queue_flags.contains(vk::QueueFlags::COMPUTE))
            .map(|(idx, prop)| (idx, vec![1.0_f32; prop.queue_count as usize]))
            .collect()
    }

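    /// Returns the device name (decoded from the NUL-padded C char array) along
    /// with the full `PhysicalDeviceProperties`.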
    fn device_name(
        instance: &ash::Instance,
        pdevice: vk::PhysicalDevice,
    ) -> (String, vk::PhysicalDeviceProperties) {
        let mut dp2 = vk::PhysicalDeviceProperties2::builder().build();
        unsafe { instance.fp_v1_1().get_physical_device_properties2(pdevice, &mut dp2) };
        let device_name = dp2.properties.device_name.iter()
            .filter_map(|f| match *f as u8 {
                0 => None,
                _ => Some(*f as u8 as char),
            })
            .collect::<String>();
        (device_name, dp2.properties)
    }

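    /// Queries `PhysicalDeviceSubgroupProperties` by chaining it onto
    /// `PhysicalDeviceProperties2` via `push_next`; the driver fills both structs.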
    fn subgroup_properties(
        instance: &ash::Instance,
        pdevice: vk::PhysicalDevice,
    ) -> vk::PhysicalDeviceSubgroupProperties {
        let mut sp = vk::PhysicalDeviceSubgroupProperties::builder();
        let mut dp2 = vk::PhysicalDeviceProperties2::builder().push_next(&mut sp).build();
        unsafe { instance.fp_v1_1().get_physical_device_properties2(pdevice, &mut dp2); }
        sp.build()
    }

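    /// Creates one `Fence` per queue of every compute family, pairing each fence
    /// with its `vk::Queue` handle.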
    fn create_fences(
        device: &ash::Device,
        queue_infos: Vec<(usize, Vec<f32>)>,
    ) -> Result<Vec<Fence>, Box<dyn Error>> {
        queue_infos.into_iter()
            .flat_map(|(phy_index, queue_priorities)| {
                (0..queue_priorities.len()).map(move |queue_index| (phy_index, queue_index))
            })
            .map(|(phy_index, queue_index)| {
                // Propagate fence-creation failures instead of silently dropping
                // them (the previous `flatten` discarded any `Err`).
                let fence = unsafe { device.create_fence(&vk::FenceCreateInfo::default(), None)? };
                let present_queue = unsafe { device.get_device_queue(phy_index as u32, queue_index as u32) };
                Ok(Fence { fence, present_queue, phy_index: phy_index as u32 })
            })
            .collect()
    }

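    /// Creates the logical device with variable-pointer features enabled plus the
    /// extensions required by `gpu_allocator` (and, on Apple Silicon, the
    /// portability subset).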
    fn create_device(
        instance: &ash::Instance,
        pdevice: vk::PhysicalDevice,
    ) -> Result<ash::Device, vk::Result> {

        let features = vk::PhysicalDeviceFeatures {
            ..Default::default()
        };

        let mut variable_pointers = vk::PhysicalDeviceVariablePointersFeatures::builder()
            .variable_pointers(true)
            .variable_pointers_storage_buffer(true)
            .build();

        let mut ext_names: Vec<_> = vec![
            EXT_VARIABLE_POINTERS,
            EXT_GET_MEMORY_REQUIREMENTS2,
            EXT_DEDICATED_ALLOCATION,
        ];

        if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") {
            ext_names.push(EXT_PORTABILITY_SUBSET);
        }

        // Query the queue layout once and borrow the priority slices from it,
        // rather than enumerating the queue families twice.
        let queue_infos = unsafe { Self::queue_infos(instance, pdevice) };
        let queue_create_infos = queue_infos.iter()
            .map(|(phy_index, priorities)| {
                vk::DeviceQueueCreateInfo::builder()
                    .queue_family_index(*phy_index as u32)
                    .queue_priorities(priorities)
                    .build()
            })
            .collect::<Vec<vk::DeviceQueueCreateInfo>>();

        let device_info = vk::DeviceCreateInfo::builder()
            .queue_create_infos(&queue_create_infos)
            .enabled_extension_names(&ext_names)
            .enabled_features(&features)
            .push_next(&mut variable_pointers);

        unsafe { instance.create_device(pdevice, &device_info, None) }
    }

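    /// Reads SPIR-V from `x`, reflects its descriptor bindings with `rspirv`, and
    /// builds a compute pipeline per device, returning the first one.
    ///
    /// A hypothetical call supplying one specialization-data blob per device:
    ///
    /// ```ignore
    /// let mut spirv = std::fs::File::open("shader.spv")?;
    /// let shader = vk.load_shader(&mut spirv, Some(vec![vec![64u8]]))?;
    /// ```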
    pub fn load_shader<R: std::io::Read + std::io::Seek>(
        &self,
        x: &mut R,
        specializations: Option<Vec<Vec<u8>>>,
    ) -> Result<Shader<'_>, Box<dyn Error>> {
        let binary = ash::util::read_spv(x)?;
        let bindings = Shader::module(&binary).map(|module| Shader::descriptor_set_layout_bindings(Shader::binding_count(&module)))?;
        match &self.compute {
            Some(c) => {
                let shaders = c.iter()
                    .enumerate()
                    .map(|(idx, f)| {
                        match &specializations {
                            Some(specs) => {
                                let maps = (0..specs.len())
                                    .map(|id| {
                                        vk::SpecializationMapEntry::builder()
                                            .constant_id(id as u32)
                                            .offset(0)
                                            .size(1)
                                            .build()
                                    })
                                    .collect::<Vec<_>>();
                                let spec = vk::SpecializationInfo::builder()
                                    .data(specs.get(idx).unwrap())
                                    .map_entries(&maps);
                                Shader::create(&f.device, &bindings, &binary, spec)
                            },
                            None => {
                                let spec = vk::SpecializationInfo::builder();
                                Shader::create(&f.device, &bindings, &binary, spec)
                            },
                        }
                    })
                    .collect::<Result<Vec<Shader<'_>>, Box<dyn Error>>>()?;
                match shaders.into_iter().next() {
                    Some(s) => Ok(s),
                    None => Err("No compute capable devices".to_string().into()),
                }
            }
            None => Err("No compute capable devices".to_string().into()),
        }
    }

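    /// Runs `shader` on the first compute device. `input[i]` supplies the storage
    /// buffers for task `i`, and `output` is split into `input.len()` equal
    /// chunks, one per task.
    ///
    /// ```ignore
    /// // Two tasks, each with one input buffer; each writes half of `output`.
    /// let input = vec![vec![vec![1.0_f32; 32]], vec![vec![2.0_f32; 32]]];
    /// let mut output = vec![0.0_f32; 64];
    /// vk.compute(&input, &mut output, &shader)?;
    /// ```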
    pub fn compute<T: std::marker::Sync>(
        &self,
        input: &[Vec<Vec<T>>],
        output: &mut [T],
        shader: &Shader<'_>,
    ) -> Result<(), Box<dyn Error + Send + Sync>> {
        match &self.compute {
            Some(c) => c.first().unwrap().execute(input, output, shader),
            None => Err("No compute capable devices".to_string().into()),
        }
    }

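    /// Total number of device queues (one fence each) across all compute devices.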
    pub fn threads(
        &self
    ) -> usize {
        match &self.compute {
            Some(c) => c.iter().map(|d| d.fences.len()).sum::<usize>(),
            None => 0,
        }
    }
}

impl fmt::Display for Vulkan {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "cpu_logical_cores: {}", std::thread::available_parallelism().unwrap().get())?;
        let pdevices = unsafe { self.instance.enumerate_physical_devices().unwrap() };
        writeln!(f, "f32_size: {}", std::mem::size_of::<f32>())?;
        writeln!(f, "gpu_device_count: {}", pdevices.len())?;
        // A `for` loop (rather than `for_each`) lets `?` propagate write errors,
        // which the previous closures silently discarded.
        for pdevice in pdevices.into_iter()
            .filter(|pdevice| {
                let (_, properties) = Self::device_name(&self.instance, *pdevice);
                let sp = Self::subgroup_properties(&self.instance, *pdevice);
                properties.device_type.ne(&vk::PhysicalDeviceType::CPU)
                    && sp.supported_stages.contains(vk::ShaderStageFlags::COMPUTE)
            }) {

            let (device_name, properties) = Self::device_name(&self.instance, pdevice);
            writeln!(f, "name: {}", device_name)?;
            writeln!(f, "type: {:?}", properties.device_type)?;

            let queue_infos = unsafe { Self::queue_infos(&self.instance, pdevice) };
            writeln!(f, "queue_size: {:?}", queue_infos.len())?;
            let queue_str = queue_infos.iter()
                .map(|f| format!("{} {}", f.0, f.1.len()))
                .collect::<Vec<String>>();
            writeln!(f, "queues: {:?}", queue_str)?;
            let memory = unsafe { self.instance.get_physical_device_memory_properties(pdevice) };

            let sp = Self::subgroup_properties(&self.instance, pdevice);
            writeln!(f, "subgroup_size: {:?}", sp.subgroup_size)?;
            writeln!(f, "subgroup_operations: {:?}", sp.supported_operations)?;

            writeln!(f, "memory_heap_count: {}", memory.memory_heap_count)?;
            let mem_str = memory.memory_heaps.iter()
                .filter(|mh| mh.size.ne(&0))
                .enumerate()
                .map(|(idx, mh)| format!("{} {}", idx, mh.size / 1_073_741_824))
                .collect::<Vec<String>>();
            writeln!(f, "memory_heaps: {:?}", mem_str)?;
        }
        Ok(())
    }
}

impl Drop for Vulkan {
    fn drop(&mut self) {
        self.compute = None;
        self.debug_layer = None;
        unsafe { self.instance.destroy_instance(None) }
    }
}

pub enum DebugOption {
    None,
    Validation,
    Debug,
}

struct DebugLayer {
    loader: ash::extensions::ext::DebugUtils,
    messenger: vk::DebugUtilsMessengerEXT,
}

impl DebugLayer {
    extern "system" fn callback(
        message_severity: vk::DebugUtilsMessageSeverityFlagsEXT,
        message_type: vk::DebugUtilsMessageTypeFlagsEXT,
        p_callback_data: *const vk::DebugUtilsMessengerCallbackDataEXT,
        _p_user_data: *mut std::ffi::c_void,
    ) -> vk::Bool32 {
        let message = unsafe { std::ffi::CStr::from_ptr((*p_callback_data).p_message) };
        let severity = format!("{:?}", message_severity).to_lowercase();
        let ty = format!("{:?}", message_type).to_lowercase();
        println!("[Debug][{}][{}] {:?}", severity, ty, message);
        vk::FALSE
    }
}

impl Drop for DebugLayer {
    fn drop(&mut self) {
        unsafe { self.loader.destroy_debug_utils_messenger(self.messenger, None) }
    }
}

pub struct Shader<'a> {
    module: vk::ShaderModule,
    pipeline_layout: vk::PipelineLayout,
    pipeline: vk::Pipeline,
    set_layouts: Vec<vk::DescriptorSetLayout>,
    binding_count: u32,

    device: &'a ash::Device,
}

// Tie the impl lifetime to the struct so `create` can return `Shader<'a>`;
// a lifetime declared on the impl but unused in the self type does not compile.
impl<'a> Shader<'a> {

    fn module(
        binary: &[u32]
    ) -> Result<rspirv::dr::Module, Box<dyn Error>> {
        let mut loader = rspirv::dr::Loader::new();
        rspirv::binary::parse_words(binary, &mut loader)?;
        Ok(loader.module())
    }

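    /// Counts `Binding` decorations in the SPIR-V module, i.e. the number of
    /// descriptors the shader expects.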
    fn binding_count(
        module: &rspirv::dr::Module
    ) -> usize {
        module.annotations.iter()
            .flat_map(|f| f.operands.clone())
            .filter(|op| op.eq(&rspirv::dr::Operand::Decoration(rspirv::spirv::Decoration::Binding)))
            .count()
    }

    fn specialization(
        module: &rspirv::dr::Module
    ) -> usize {
        module.annotations.iter()
            .flat_map(|f| f.operands.clone())
            .filter(|op| op.eq(&rspirv::dr::Operand::Decoration(rspirv::spirv::Decoration::SpecId)))
            .count()
    }

    fn descriptor_set_layout_bindings(
        binding_count: usize
    ) -> Vec<vk::DescriptorSetLayoutBinding> {
        (0..binding_count).map(|i|
            vk::DescriptorSetLayoutBinding::builder()
                .binding(i as u32)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .descriptor_count(1)
                .stage_flags(vk::ShaderStageFlags::COMPUTE)
                .build()
        )
        .collect()
    }

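    /// Assembles the descriptor set layout, pipeline layout, shader module, and
    /// compute pipeline; the entry point is assumed to be `main`.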
    fn create(
        device: &'a ash::Device,
        bindings: &[vk::DescriptorSetLayoutBinding],
        binary: &[u32],
        spec: vk::SpecializationInfoBuilder,
    ) -> Result<Shader<'a>, Box<dyn Error>> {
        let set_layouts = unsafe { device.create_descriptor_set_layout(
            &vk::DescriptorSetLayoutCreateInfo::builder().bindings(bindings),
            None,
        ) }.map(|set_layout| vec![set_layout])?;

        let pipeline_layout = unsafe { device.create_pipeline_layout(
            &vk::PipelineLayoutCreateInfo::builder().set_layouts(&set_layouts),
            None,
        )? };

        let module = unsafe { device.create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(binary), None)? };
        let stage = vk::PipelineShaderStageCreateInfo::builder()
            .flags(vk::PipelineShaderStageCreateFlags::ALLOW_VARYING_SUBGROUP_SIZE_EXT | vk::PipelineShaderStageCreateFlags::REQUIRE_FULL_SUBGROUPS_EXT)
            .module(module)
            .name(std::ffi::CStr::from_bytes_with_nul(b"main\0")?)
            .specialization_info(&spec)
            .stage(vk::ShaderStageFlags::COMPUTE);

        let pipeline = unsafe { device.create_compute_pipelines(
            vk::PipelineCache::null(),
            &[vk::ComputePipelineCreateInfo::builder().stage(stage.build()).layout(pipeline_layout).build()],
            None,
        ) }.map(|pipelines| pipelines[0]).map_err(|(_, err)| err)?;

        Ok(Shader{module, pipeline_layout, pipeline, set_layouts, device, binding_count: bindings.len() as u32})
    }
}

impl<'a> Drop for Shader<'a> {
    fn drop(&mut self) {
        unsafe { self.device.destroy_pipeline_layout(self.pipeline_layout, None) };
        unsafe { self.device.destroy_shader_module(self.module, None) };
        for set_layout in self.set_layouts.iter().copied() {
            unsafe { self.device.destroy_descriptor_set_layout(set_layout, None) };
        }
        unsafe { self.device.destroy_pipeline(self.pipeline, None) };
    }
}

struct Fence {
    fence: vk::Fence,
    present_queue: vk::Queue,
    phy_index: u32,
}

struct Compute {
    device: ash::Device,
    allocator: Option<RwLock<Allocator>>,
    fences: Vec<Fence>,
    memory: vk::PhysicalDeviceMemoryProperties,
}

impl fmt::Debug for Compute {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Write to the formatter rather than stdout, so the output goes wherever
        // `{:?}` is directed and write errors propagate.
        writeln!(f, "Memory types: {}", self.memory.memory_type_count)?;
        for (idx, mt) in self.memory.memory_types.iter()
            .filter(|mt| !mt.property_flags.is_empty())
            .enumerate() {
            writeln!(f, "Index {} {:?} (heap {})", idx, mt.property_flags, mt.heap_index)?;
        }

        writeln!(f, "Memory heaps: {}", self.memory.memory_heap_count)?;
        for (idx, mh) in self.memory.memory_heaps.iter()
            .filter(|mh| mh.size.ne(&0))
            .enumerate() {
            writeln!(f, "{:?} GiB {:?} (heap {})", mh.size / 1_073_741_824, mh.flags, idx)?;
        }

        f.write_fmt(format_args!(" Found {} compute core(s) with {} thread(s) in total", self.cores().len(), self.fences.len()))
    }
}

impl Compute {

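    /// Distinct queue family indices in use, i.e. the families buffers may be
    /// shared across.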
    fn cores(
        &self
    ) -> Vec<u32> {
        self.fences.iter().fold(vec![], |mut acc, f| {
            if !acc.contains(&f.phy_index) {
                acc.push(f.phy_index);
            }
            acc
        })
    }

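    /// Records one command buffer: binding 0 is this task's chunk of the shared
    /// output buffer, bindings 1..=n are freshly allocated input buffers filled
    /// from `input`. Returns the input buffers, which must stay alive until the
    /// submission completes.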
    fn task<T>(
        &self,
        descriptor_set: vk::DescriptorSet,
        command_buffer: vk::CommandBuffer,
        shader: &Shader<'_>,
        output: vk::Buffer,
        input: &[Vec<T>],
        memory_mapping: (vk::DeviceSize, vk::DeviceSize),
    ) -> Result<Vec<Buffer<'_, '_>>, Box<dyn Error + Send + Sync>> {

        let input_buffers = input.iter().map(|data| {
            let buffer = Buffer::new(
                "cpu input",
                &self.device,
                &self.allocator,
                (data.len() * std::mem::size_of::<T>()) as vk::DeviceSize,
                vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::STORAGE_BUFFER,
                gpu_allocator::MemoryLocation::CpuToGpu,
                &self.cores(),
            )?.fill(data)?;
            Ok(buffer)
        })
        .collect::<Result<Vec<Buffer<'_, '_>>, Box<dyn Error + Send + Sync>>>()?;

        // Descriptor 0 points at the output chunk; descriptors 1..=n point at
        // the input buffers.
        let buffer_infos = (0..=input_buffers.len())
            .map(|f| match f {
                0 => [vk::DescriptorBufferInfo::builder()
                    .buffer(output)
                    .offset(memory_mapping.0)
                    .range(memory_mapping.1)
                    .build()],
                _ => [vk::DescriptorBufferInfo::builder()
                    .buffer(input_buffers[f-1].buffer)
                    .offset(0)
                    .range(vk::WHOLE_SIZE)
                    .build()],
            })
            .collect::<Vec<[vk::DescriptorBufferInfo; 1]>>();

        let wds = buffer_infos.iter().enumerate()
            .map(|(i, buf)| {
                vk::WriteDescriptorSet::builder()
                    .dst_set(descriptor_set)
                    .dst_binding(i as u32)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .buffer_info(buf)
                    .build()
            })
            .collect::<Vec<vk::WriteDescriptorSet>>();

        unsafe {
            self.device.update_descriptor_sets(&wds, &[]);
            self.device.begin_command_buffer(command_buffer, &vk::CommandBufferBeginInfo::builder().flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT))?;
            self.device.cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, shader.pipeline);
            self.device.cmd_bind_descriptor_sets(command_buffer, vk::PipelineBindPoint::COMPUTE, shader.pipeline_layout, 0, &[descriptor_set], &[]);
            self.device.cmd_dispatch(command_buffer, 1024, 1, 1);
            self.device.end_command_buffer(command_buffer)?;
        }

        Ok(input_buffers)
    }

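    /// Fans `input` out across one command buffer per entry, submits them in a
    /// single batch on the first queue, waits on the fence, then copies the
    /// mapped output buffer back into `output`.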
    fn execute<T: std::marker::Sync>(
        &self,
        input: &[Vec<Vec<T>>],
        output: &mut [T],
        shader: &Shader<'_>,
    ) -> Result<(), Box<dyn Error + Send + Sync>> {

        let output_buffer = Buffer::new(
            "output buffer",
            &self.device,
            &self.allocator,
            (output.len() * std::mem::size_of::<T>()) as vk::DeviceSize,
            vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::STORAGE_BUFFER,
            gpu_allocator::MemoryLocation::GpuToCpu,
            &self.cores(),
        )?;
        let output_chunk_size = output_buffer.device_size / input.len() as vk::DeviceSize;

        let command = Command::new(
            0,
            shader.binding_count,
            &shader.set_layouts,
            input.len() as u32,
            &self.device,
        )?;

        command.descriptor_sets.iter()
            .zip(command.command_buffers.iter())
            .zip(input.iter())
            .enumerate().map(|(idx, cmd)| {
                self.task(
                    *cmd.0.0,
                    *cmd.0.1,
                    shader,
                    output_buffer.buffer,
                    cmd.1,
                    (output_chunk_size * idx as vk::DeviceSize, output_chunk_size),
                )
            })
            .collect::<Result<Vec<_>, _>>()
            // The closure owns the collected input buffers for the duration of
            // the call, so they are freed only after the fence wait completes.
            .and_then(|_input_buffers| unsafe {
                let submits = [vk::SubmitInfo::builder().command_buffers(&command.command_buffers).build()];
                self.device.queue_submit(self.fences[0].present_queue, &submits, self.fences[0].fence)?;
                self.device.wait_for_fences(&[self.fences[0].fence], true, u64::MAX)?;
                self.device.reset_fences(&[self.fences[0].fence])?;
                Ok(())
            })?;

        let data_ptr = output_buffer.c_ptr.as_ptr().cast::<T>();
        unsafe { data_ptr.copy_to_nonoverlapping(output.as_mut_ptr(), output.len()) };

        Ok(())
    }
}

impl Drop for Compute {
    fn drop(&mut self) {
        unsafe { self.device.device_wait_idle().unwrap() }
        for fence in &self.fences {
            unsafe { self.device.destroy_fence(fence.fence, None) }
        }
        self.allocator = None;
        unsafe { self.device.destroy_device(None) }
    }
}

struct Command<'a> {
    descriptor_pool: vk::DescriptorPool,
    command_pool: vk::CommandPool,
    command_buffers: Vec<vk::CommandBuffer>,
    descriptor_sets: Vec<vk::DescriptorSet>,

    device: &'a ash::Device,
}

// As with `Shader`, tie the impl lifetime to the struct's.
impl<'a> Command<'a> {

    fn descriptor_pool(
        device: &ash::Device,
        descriptor_count: u32,
        max_sets: u32,
    ) -> Result<vk::DescriptorPool, vk::Result> {
        let descriptor_pool_size = [vk::DescriptorPoolSize::builder()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(descriptor_count)
            .build()];
        let descriptor_pool_info = vk::DescriptorPoolCreateInfo::builder()
            .max_sets(max_sets)
            .pool_sizes(&descriptor_pool_size);
        unsafe { device.create_descriptor_pool(&descriptor_pool_info, None) }
    }

    fn command_pool(
        device: &ash::Device,
        queue_family_index: u32,
    ) -> Result<vk::CommandPool, vk::Result> {
        let command_pool_info = vk::CommandPoolCreateInfo::builder()
            .queue_family_index(queue_family_index);
        unsafe { device.create_command_pool(&command_pool_info, None) }
    }

    fn allocate_command_buffers(
        device: &ash::Device,
        command_pool: vk::CommandPool,
        command_buffer_count: u32,
    ) -> Result<Vec<vk::CommandBuffer>, vk::Result> {
        let command_buffers_info = vk::CommandBufferAllocateInfo::builder()
            .command_buffer_count(command_buffer_count)
            .command_pool(command_pool);
        unsafe { device.allocate_command_buffers(&command_buffers_info) }
    }

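    /// Creates the descriptor pool, one descriptor set and one command buffer per
    /// requested task, and the command pool they are allocated from.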
    fn new(
        queue_family_index: u32,
        descriptor_count: u32,
        set_layouts: &[vk::DescriptorSetLayout],
        command_buffer_count: u32,
        device: &'a ash::Device,
    ) -> Result<Command<'a>, Box<dyn Error + Send + Sync>> {

        let descriptor_pool = Command::descriptor_pool(device, descriptor_count, command_buffer_count)?;

        let descriptor_set_info = vk::DescriptorSetAllocateInfo::builder()
            .descriptor_pool(descriptor_pool)
            .set_layouts(set_layouts);

        // Collect into a `Result` so an allocation failure propagates instead of
        // being silently skipped (the previous `flat_map` discarded errors).
        let descriptor_sets = (0..command_buffer_count)
            .map(|_| unsafe { device.allocate_descriptor_sets(&descriptor_set_info) }.map(|sets| sets[0]))
            .collect::<Result<Vec<_>, vk::Result>>()?;

        let command_pool = Command::command_pool(device, queue_family_index)?;
        let command_buffers = Command::allocate_command_buffers(device, command_pool, command_buffer_count)?;

        Ok(Command { descriptor_pool, command_pool, command_buffers, descriptor_sets, device })
    }
}

impl<'a> Drop for Command<'a> {
    fn drop(&mut self) {
        unsafe { self.device.destroy_command_pool(self.command_pool, None) };
        unsafe { self.device.destroy_descriptor_pool(self.descriptor_pool, None) };
    }
}

struct Buffer<'a, 'b> {
    buffer: vk::Buffer,
    allocation: Option<Allocation>,
    device_size: vk::DeviceSize,
    c_ptr: std::ptr::NonNull<std::ffi::c_void>,

    device: &'a ash::Device,
    allocator: &'b Option<RwLock<Allocator>>,
}

impl<'a, 'b> Buffer<'a, 'b> {

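    /// Creates a buffer and binds freshly allocated memory to it. Sharing mode is
    /// `EXCLUSIVE` for a single queue family and `CONCURRENT` otherwise.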
    fn new(
        name: &str,
        device: &'a ash::Device,
        allocator: &'b Option<RwLock<Allocator>>,
        device_size: vk::DeviceSize,
        usage: vk::BufferUsageFlags,
        location: gpu_allocator::MemoryLocation,
        queue_family_indices: &[u32],
    ) -> Result<Buffer<'a, 'b>, Box<dyn Error + Send + Sync>> {
        let create_info = vk::BufferCreateInfo::builder()
            .size(device_size)
            .usage(usage)
            .sharing_mode(match queue_family_indices.len() {
                1 => vk::SharingMode::EXCLUSIVE,
                _ => vk::SharingMode::CONCURRENT,
            })
            .queue_family_indices(queue_family_indices);
        let buffer = unsafe { device.create_buffer(&create_info, None)? };
        let requirements = unsafe { device.get_buffer_memory_requirements(buffer) };
        let mut malloc = allocator.as_ref().unwrap().write().unwrap();
        let allocation = malloc.allocate(&AllocationCreateDesc {
            name,
            requirements,
            location,
            linear: true,
        })?;
        let c_ptr = allocation.mapped_ptr().unwrap();
        unsafe { device.bind_buffer_memory(buffer, allocation.memory(), allocation.offset())? };
        Ok(Buffer { buffer, allocation: Some(allocation), c_ptr, device_size, device, allocator })
    }

    fn fill<T: Sized>(
        self,
        data: &[T],
    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let data_ptr = self.c_ptr.as_ptr().cast::<T>();
        unsafe { data_ptr.copy_from_nonoverlapping(data.as_ptr(), data.len()) };
        Ok(self)
    }
}

impl<'a, 'b> Drop for Buffer<'a, 'b> {
    fn drop(&mut self) {
        let lock = self.allocator.as_ref().unwrap();
        let mut malloc = lock.write().unwrap();
        malloc.free(self.allocation.take().unwrap()).unwrap();
        unsafe { self.device.destroy_buffer(self.buffer, None) };
    }
}