use super::nv_coop_matrix2::CoopMatrixNV2Capabilities;
use super::nv_coop_matrix2::query_cooperative_matrix_nv2_limits;
use crate::utils::dtype::{bool_to_vk_bool32, vk_to_onnx_dtype};
use crate::utils::error::VKMLError;
use onnx_extractor::DataType;
use std::any::Any;
use std::ffi::{CStr, CString, c_void};
use std::os::raw::c_char;
use std::ptr;
use vulkanalia::vk::InstanceV1_0;
use vulkanalia::{Instance, vk};
pub struct DeviceCreateExtras {
pub names: Vec<CString>,
pub name_ptrs: Vec<*const c_char>,
pub pnext: DevicePNext,
}
impl DeviceCreateExtras {
pub fn device_create_next(&self) -> *const std::ffi::c_void {
self.pnext.ptr as *const std::ffi::c_void
}
}
pub struct DevicePNext {
pub ptr: *mut std::ffi::c_void,
_holders: Vec<Box<dyn Any>>,
}
#[derive(Clone, Debug)]
pub struct CoopMatrixShape {
pub m: u32,
pub n: u32,
pub k: u32,
pub a_type: DataType,
pub b_type: DataType,
pub c_type: DataType,
pub result_type: DataType,
}
#[derive(Debug, Default)]
pub struct VkExtensions {
memory_budget: bool,
shader_float_16_int8: bool,
storage_16bit: bool,
shader_bfloat16: bool,
cooperative_matrix: Option<Vec<CoopMatrixShape>>,
cooperative_matrix_nv2: Option<CoopMatrixNV2Capabilities>,
}
impl VkExtensions {
pub const VK_KHR_COOPERATIVE_MATRIX: &'static str = "VK_KHR_cooperative_matrix";
pub const VK_NV_COOPERATIVE_MATRIX2: &'static str = "VK_NV_cooperative_matrix2";
pub const VK_EXT_MEMORY_BUDGET: &'static str = "VK_EXT_memory_budget";
pub const VK_KHR_SHADER_FLOAT16_INT8: &'static str = "VK_KHR_shader_float16_int8";
pub const VK_KHR_16BIT_STORAGE: &'static str = "VK_KHR_16bit_storage";
pub const VK_KHR_SHADER_BFLOAT16: &'static str = "VK_KHR_shader_bfloat16";
pub fn from_extension_properties(
instance: &Instance,
physical_device: vk::PhysicalDevice,
props: &[vk::ExtensionProperties],
) -> Result<Self, VKMLError> {
let mut res = Self::default();
for p in props {
let name_cstr = unsafe { CStr::from_ptr(p.extension_name.as_ptr()) };
let name = name_cstr.to_string_lossy();
match name.as_ref() {
Self::VK_KHR_COOPERATIVE_MATRIX => {
res.cooperative_matrix =
Self::query_cooperative_matrix_limits(instance, physical_device);
}
Self::VK_NV_COOPERATIVE_MATRIX2 => {
res.cooperative_matrix_nv2 =
query_cooperative_matrix_nv2_limits(instance, physical_device);
}
Self::VK_EXT_MEMORY_BUDGET => res.memory_budget = true,
Self::VK_KHR_SHADER_FLOAT16_INT8 => {
res.shader_float_16_int8 = true;
}
Self::VK_KHR_16BIT_STORAGE => {
res.storage_16bit = true;
}
Self::VK_KHR_SHADER_BFLOAT16 => {
res.shader_bfloat16 = true;
}
_ => {}
}
}
Ok(res)
}
fn query_cooperative_matrix_limits(
instance: &Instance,
physical_device: vk::PhysicalDevice,
) -> Option<Vec<CoopMatrixShape>> {
unsafe {
let fp = (*instance)
.commands()
.get_physical_device_cooperative_matrix_properties_khr;
let mut count: u32 = 0;
let result = fp(physical_device, &mut count, std::ptr::null_mut());
if result != vk::Result::SUCCESS || count == 0 {
return None;
}
let mut properties =
Vec::<vk::CooperativeMatrixPropertiesKHR>::with_capacity(count as usize);
let result = fp(physical_device, &mut count, properties.as_mut_ptr());
if result != vk::Result::SUCCESS {
return None;
}
properties.set_len(count as usize);
let shapes: Vec<CoopMatrixShape> = properties
.iter()
.map(|p| CoopMatrixShape {
m: p.m_size,
n: p.n_size,
k: p.k_size,
a_type: vk_to_onnx_dtype(p.a_type).unwrap_or(DataType::Undefined),
b_type: vk_to_onnx_dtype(p.b_type).unwrap_or(DataType::Undefined),
c_type: vk_to_onnx_dtype(p.c_type).unwrap_or(DataType::Undefined),
result_type: vk_to_onnx_dtype(p.result_type).unwrap_or(DataType::Undefined),
})
.collect();
Some(shapes)
}
}
pub fn enabled_extension_names(&self) -> Vec<CString> {
let mut v = Vec::new();
if self.cooperative_matrix.is_some() {
v.push(CString::new(Self::VK_KHR_COOPERATIVE_MATRIX).unwrap());
}
if self.cooperative_matrix_nv2.is_some() {
v.push(CString::new(Self::VK_NV_COOPERATIVE_MATRIX2).unwrap());
}
if self.memory_budget {
v.push(CString::new(Self::VK_EXT_MEMORY_BUDGET).unwrap());
}
if self.shader_bfloat16 {
v.push(CString::new(Self::VK_KHR_SHADER_BFLOAT16).unwrap());
}
v
}
pub fn prepare_device_create(&self) -> DeviceCreateExtras {
let names = self.enabled_extension_names();
let name_ptrs: Vec<*const c_char> =
names.iter().map(|s| s.as_ptr() as *const c_char).collect();
let mut holders: Vec<Box<dyn Any>> = Vec::new();
let mut head: *mut c_void = ptr::null_mut();
if self.cooperative_matrix.is_some() {
let mut feat = Box::new(vk::PhysicalDeviceCooperativeMatrixFeaturesKHR {
s_type: vk::StructureType::PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR,
next: ptr::null_mut(),
cooperative_matrix: vk::TRUE,
..Default::default()
});
feat.next = head as *mut _;
head = (&*feat) as *const _ as *mut c_void;
holders.push(feat);
}
if let Some(nv2) = &self.cooperative_matrix_nv2 {
let mut feat = Box::new(vk::PhysicalDeviceCooperativeMatrix2FeaturesNV {
s_type: vk::StructureType::PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV,
next: ptr::null_mut(),
cooperative_matrix_workgroup_scope: bool_to_vk_bool32(nv2.features.workgroup_scope),
cooperative_matrix_flexible_dimensions: bool_to_vk_bool32(
nv2.features.flexible_dimensions,
),
cooperative_matrix_reductions: bool_to_vk_bool32(nv2.features.reductions),
cooperative_matrix_conversions: bool_to_vk_bool32(nv2.features.conversions),
cooperative_matrix_per_element_operations: bool_to_vk_bool32(
nv2.features.per_element_operations,
),
cooperative_matrix_tensor_addressing: bool_to_vk_bool32(
nv2.features.tensor_addressing,
),
cooperative_matrix_block_loads: bool_to_vk_bool32(nv2.features.block_loads),
});
feat.next = head as *mut _;
head = (&*feat) as *const _ as *mut c_void;
holders.push(feat);
}
if self.shader_float_16_int8 {
let mut feat = Box::new(vk::PhysicalDeviceShaderFloat16Int8Features {
s_type: vk::StructureType::PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES,
next: ptr::null_mut(),
shader_float16: vk::TRUE,
shader_int8: vk::FALSE,
});
feat.next = head as *mut _;
head = (&*feat) as *const _ as *mut c_void;
holders.push(feat);
}
if self.storage_16bit {
let mut feat = Box::new(vk::PhysicalDevice16BitStorageFeatures {
s_type: vk::StructureType::PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES,
next: ptr::null_mut(),
storage_buffer_16bit_access: vk::TRUE,
uniform_and_storage_buffer_16bit_access: vk::TRUE,
storage_push_constant16: vk::FALSE,
storage_input_output16: vk::FALSE,
});
feat.next = head as *mut _;
head = (&*feat) as *const _ as *mut c_void;
holders.push(feat);
}
if self.shader_bfloat16 {
let mut feat = Box::new(vk::PhysicalDeviceShaderBfloat16FeaturesKHR {
s_type: vk::StructureType::PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR,
next: ptr::null_mut(),
shader_b_float16_type: vk::TRUE,
shader_b_float_16dot_product: vk::FALSE,
shader_b_float16_cooperative_matrix: vk::TRUE,
});
feat.next = head as *mut _;
head = (&*feat) as *const _ as *mut c_void;
holders.push(feat);
}
DeviceCreateExtras {
names,
name_ptrs,
pnext: DevicePNext {
ptr: head,
_holders: holders,
},
}
}
pub fn coop_matrix_shapes(&self) -> Option<Vec<CoopMatrixShape>> {
self.cooperative_matrix.clone()
}
pub fn coop_matrix_nv2_capabilities(&self) -> Option<CoopMatrixNV2Capabilities> {
self.cooperative_matrix_nv2.clone()
}
pub fn supports_fp16(&self) -> bool {
self.shader_float_16_int8 && self.storage_16bit
}
pub fn supports_bf16(&self) -> bool {
self.shader_bfloat16 && self.shader_float_16_int8
}
pub fn has_memory_budget(&self) -> bool {
self.memory_budget
}
pub fn get_coop_matrix_sizes(
&self,
a_type: DataType,
b_type: DataType,
c_type: DataType,
result_type: DataType,
) -> Option<Vec<CoopMatrixShape>> {
self.cooperative_matrix.as_ref().map(|shapes| {
shapes
.iter()
.filter(|shape| {
shape.a_type == a_type
&& shape.b_type == b_type
&& shape.c_type == c_type
&& shape.result_type == result_type
})
.cloned()
.collect()
})
}
}