use std::sync::Arc;
use ash::vk;
use super::device::VulkanContext;
use super::error::VulkanError;
impl VulkanError {
    /// Attaches an I/O error as context to this error.
    ///
    /// NOTE(review): the context is currently discarded — `VulkanError` is
    /// declared elsewhere and, as used here, has nowhere to carry a source
    /// error. Consider extending `VulkanError` to store it; until then this
    /// method is a deliberate no-op that keeps the call site readable.
    fn with_context(self, _io: std::io::Error) -> Self {
        self
    }
}
/// SPIR-V bytecode of the stride-copy compute shader, embedded at compile
/// time from the build script's output directory.
const STRIDE_COPY_SPV: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/stride_copy.spv"));
/// Workgroup width used to size the dispatch in `record_dispatch`.
/// NOTE(review): this must match the `local_size_x` declared in the
/// stride-copy shader source — confirm against the shader.
const STRIDE_COPY_LOCAL_SIZE_X: u32 = 64;
/// Compute kernel that copies rows of u32 words between two strided storage
/// buffers. Owns every Vulkan object it creates; all are destroyed in `Drop`.
pub(crate) struct StrideCopyKernel {
    /// Shared device/context, kept alive for the lifetime of the kernel.
    ctx: Arc<VulkanContext>,
    /// Compiled stride-copy compute shader.
    shader_module: vk::ShaderModule,
    /// Descriptor set layout: two storage buffers (binding 0 = src, 1 = dst).
    dsl: vk::DescriptorSetLayout,
    /// Pipeline layout: the set layout above plus one push-constant range
    /// covering `PushBlock`.
    pipeline_layout: vk::PipelineLayout,
    /// The compute pipeline itself.
    pipeline: vk::Pipeline,
    /// Pool sized for exactly the one descriptor set below.
    descriptor_pool: vk::DescriptorPool,
    /// Single descriptor set, (re)written by `update_descriptor`.
    descriptor_set: vk::DescriptorSet,
}
/// Push-constant block handed to the shader in `record_dispatch`.
/// `#[repr(C)]` pins the field order/layout so it matches the shader's
/// push-constant declaration; all sizes/strides are in u32 words.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct PushBlock {
    // Number of u32 words per row to copy.
    row_uints: u32,
    // Source row stride, in u32 words.
    src_stride_uints: u32,
    // Destination row stride, in u32 words.
    dst_stride_uints: u32,
    // Number of rows to copy.
    n_rows: u32,
}
// SAFETY: PushBlock is #[repr(C)] and consists of four u32 fields, so it has
// no padding bytes, every bit pattern is valid, and the all-zero pattern is a
// valid value — satisfying both Zeroable and Pod.
unsafe impl bytemuck::Zeroable for PushBlock {}
unsafe impl bytemuck::Pod for PushBlock {}
impl StrideCopyKernel {
    /// Builds the kernel: decodes the embedded SPIR-V, then creates the shader
    /// module, descriptor set layout, pipeline layout, compute pipeline,
    /// descriptor pool, and descriptor set — in that order.
    ///
    /// # Errors
    /// Returns a `VulkanError` naming the failing Vulkan call. On failure,
    /// every object created so far is destroyed inside the corresponding
    /// `map_err` closure before returning, so a failed constructor leaks
    /// nothing. The teardown order in each closure mirrors creation order in
    /// reverse — keep them in sync if a step is added.
    pub(crate) fn new(ctx: Arc<VulkanContext>) -> Result<Self, VulkanError> {
        // Re-decode the embedded byte slice into aligned u32 words, as the
        // shader-module API requires. read_spv yields io::Error, which has no
        // vk::Result; ERROR_INITIALIZATION_FAILED stands in for it.
        let code_u32 =
            ash::util::read_spv(&mut std::io::Cursor::new(STRIDE_COPY_SPV)).map_err(|e| {
                VulkanError::vk("read_spv", ash::vk::Result::ERROR_INITIALIZATION_FAILED)
                    .with_context(e)
            })?;
        let shader_module = unsafe {
            let ci = vk::ShaderModuleCreateInfo::default().code(&code_u32);
            ctx.device
                .create_shader_module(&ci, None)
                .map_err(|e| VulkanError::vk("create_shader_module", e))?
        };
        let dsl = unsafe {
            // Two storage buffers: binding 0 = source, binding 1 = destination
            // (mirrors how update_descriptor writes them).
            let bindings = [
                vk::DescriptorSetLayoutBinding::default()
                    .binding(0)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .descriptor_count(1)
                    .stage_flags(vk::ShaderStageFlags::COMPUTE),
                vk::DescriptorSetLayoutBinding::default()
                    .binding(1)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .descriptor_count(1)
                    .stage_flags(vk::ShaderStageFlags::COMPUTE),
            ];
            let ci = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
            ctx.device
                .create_descriptor_set_layout(&ci, None)
                .map_err(|e| {
                    // Roll back the shader module created above.
                    ctx.device.destroy_shader_module(shader_module, None);
                    VulkanError::vk("create_descriptor_set_layout", e)
                })?
        };
        let pipeline_layout = unsafe {
            // Single push-constant range covering the whole PushBlock.
            let push_range = [vk::PushConstantRange::default()
                .stage_flags(vk::ShaderStageFlags::COMPUTE)
                .offset(0)
                .size(std::mem::size_of::<PushBlock>() as u32)];
            let dsls = [dsl];
            let ci = vk::PipelineLayoutCreateInfo::default()
                .set_layouts(&dsls)
                .push_constant_ranges(&push_range);
            ctx.device.create_pipeline_layout(&ci, None).map_err(|e| {
                // Roll back everything created so far, in reverse order.
                ctx.device.destroy_descriptor_set_layout(dsl, None);
                ctx.device.destroy_shader_module(shader_module, None);
                VulkanError::vk("create_pipeline_layout", e)
            })?
        };
        let pipeline = unsafe {
            // Entry point name must be NUL-terminated; "main" never contains
            // an interior NUL, so the unwrap cannot fail.
            let entry_name = std::ffi::CString::new("main").unwrap();
            let stage = vk::PipelineShaderStageCreateInfo::default()
                .stage(vk::ShaderStageFlags::COMPUTE)
                .module(shader_module)
                .name(&entry_name);
            let ci = [vk::ComputePipelineCreateInfo::default()
                .stage(stage)
                .layout(pipeline_layout)];
            // On failure ash returns (partial_pipelines, vk::Result); with a
            // single create-info the partial vec carries no live pipeline to
            // clean up, so it is ignored here.
            let pipelines = ctx
                .device
                .create_compute_pipelines(vk::PipelineCache::null(), &ci, None)
                .map_err(|(_, e)| {
                    ctx.device.destroy_pipeline_layout(pipeline_layout, None);
                    ctx.device.destroy_descriptor_set_layout(dsl, None);
                    ctx.device.destroy_shader_module(shader_module, None);
                    VulkanError::vk("create_compute_pipelines", e)
                })?;
            pipelines[0]
        };
        let descriptor_pool = unsafe {
            // Exactly one set with two storage-buffer descriptors.
            let sizes = [vk::DescriptorPoolSize::default()
                .ty(vk::DescriptorType::STORAGE_BUFFER)
                .descriptor_count(2)];
            let ci = vk::DescriptorPoolCreateInfo::default()
                .max_sets(1)
                .pool_sizes(&sizes);
            ctx.device.create_descriptor_pool(&ci, None).map_err(|e| {
                ctx.device.destroy_pipeline(pipeline, None);
                ctx.device.destroy_pipeline_layout(pipeline_layout, None);
                ctx.device.destroy_descriptor_set_layout(dsl, None);
                ctx.device.destroy_shader_module(shader_module, None);
                VulkanError::vk("create_descriptor_pool", e)
            })?
        };
        let descriptor_set = unsafe {
            let dsls = [dsl];
            let alloc = vk::DescriptorSetAllocateInfo::default()
                .descriptor_pool(descriptor_pool)
                .set_layouts(&dsls);
            // The set itself needs no individual free: destroying the pool
            // (here on error, or in Drop) reclaims it.
            ctx.device.allocate_descriptor_sets(&alloc).map_err(|e| {
                ctx.device.destroy_descriptor_pool(descriptor_pool, None);
                ctx.device.destroy_pipeline(pipeline, None);
                ctx.device.destroy_pipeline_layout(pipeline_layout, None);
                ctx.device.destroy_descriptor_set_layout(dsl, None);
                ctx.device.destroy_shader_module(shader_module, None);
                VulkanError::vk("allocate_descriptor_sets", e)
            })?[0]
        };
        Ok(Self {
            ctx,
            shader_module,
            dsl,
            pipeline_layout,
            pipeline,
            descriptor_pool,
            descriptor_set,
        })
    }

    /// Points the descriptor set at a new (src, dst) buffer pair, binding the
    /// full `src_bytes` / `dst_bytes` ranges starting at offset 0.
    ///
    /// NOTE(review): vkUpdateDescriptorSets requires that the set not be in
    /// use by pending GPU work when it is rewritten — callers presumably
    /// synchronize before switching buffers; verify at the call sites.
    pub(crate) fn update_descriptor(
        &self,
        src: vk::Buffer,
        src_bytes: u64,
        dst: vk::Buffer,
        dst_bytes: u64,
    ) {
        unsafe {
            let src_info = [vk::DescriptorBufferInfo::default()
                .buffer(src)
                .offset(0)
                .range(src_bytes)];
            let dst_info = [vk::DescriptorBufferInfo::default()
                .buffer(dst)
                .offset(0)
                .range(dst_bytes)];
            // Bindings 0/1 here must match the layout declared in new().
            let writes = [
                vk::WriteDescriptorSet::default()
                    .dst_set(self.descriptor_set)
                    .dst_binding(0)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .buffer_info(&src_info),
                vk::WriteDescriptorSet::default()
                    .dst_set(self.descriptor_set)
                    .dst_binding(1)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .buffer_info(&dst_info),
            ];
            self.ctx.device.update_descriptor_sets(&writes, &[]);
        }
    }

    /// Records bind + push-constants + dispatch commands into `cmd`.
    ///
    /// The dispatch covers `row_uints * n_rows` total u32 words, rounded up
    /// to whole workgroups of `STRIDE_COPY_LOCAL_SIZE_X` invocations.
    /// All counts/strides are in u32 words, matching `PushBlock`.
    ///
    /// NOTE(review): `groups_x` is narrowed with `as u32`; a copy larger than
    /// ~2^38 bytes would silently truncate the group count, and devices also
    /// cap maxComputeWorkGroupCount[0] below u32::MAX — confirm callers stay
    /// within those limits.
    pub(crate) fn record_dispatch(
        &self,
        cmd: vk::CommandBuffer,
        row_uints: u32,
        src_stride_uints: u32,
        dst_stride_uints: u32,
        n_rows: u32,
    ) {
        let pc = PushBlock {
            row_uints,
            src_stride_uints,
            dst_stride_uints,
            n_rows,
        };
        // Widen to u64 before multiplying so the product cannot overflow u32.
        let total = row_uints as u64 * n_rows as u64;
        let groups_x = total.div_ceil(STRIDE_COPY_LOCAL_SIZE_X as u64) as u32;
        unsafe {
            self.ctx
                .device
                .cmd_bind_pipeline(cmd, vk::PipelineBindPoint::COMPUTE, self.pipeline);
            self.ctx.device.cmd_bind_descriptor_sets(
                cmd,
                vk::PipelineBindPoint::COMPUTE,
                self.pipeline_layout,
                0,
                &[self.descriptor_set],
                &[],
            );
            self.ctx.device.cmd_push_constants(
                cmd,
                self.pipeline_layout,
                vk::ShaderStageFlags::COMPUTE,
                0,
                bytemuck::bytes_of(&pc),
            );
            self.ctx.device.cmd_dispatch(cmd, groups_x, 1, 1);
        }
    }
}
impl Drop for StrideCopyKernel {
    /// Destroys every Vulkan object this kernel owns.
    fn drop(&mut self) {
        // SAFETY: all handles were created from this same device in `new`
        // and are destroyed exactly once, after the device has gone idle so
        // no submitted work can still reference them. A failed idle-wait is
        // ignored — there is nothing better to do in a destructor.
        unsafe {
            let device = &self.ctx.device;
            let _ = device.device_wait_idle();
            // The descriptor set is reclaimed with its pool; the remaining
            // objects are torn down in reverse creation order.
            device.destroy_descriptor_pool(self.descriptor_pool, None);
            device.destroy_pipeline(self.pipeline, None);
            device.destroy_pipeline_layout(self.pipeline_layout, None);
            device.destroy_descriptor_set_layout(self.dsl, None);
            device.destroy_shader_module(self.shader_module, None);
        }
    }
}