use crate::buffer::BufferManager;
use crate::error::{AccelError, AccelResult};
use crate::shaders::motion::block_sad;
use std::sync::Arc;
use vulkano::buffer::BufferUsage;
use vulkano::descriptor_set::allocator::{
StandardDescriptorSetAllocator, StandardDescriptorSetAllocatorCreateInfo,
};
use vulkano::descriptor_set::{DescriptorSet, WriteDescriptorSet};
use vulkano::device::{Device, Queue};
use vulkano::pipeline::{
compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo,
ComputePipeline, Pipeline, PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo,
};
use vulkano::sync::GpuFuture;
/// GPU compute kernel that runs the `block_sad` motion-estimation shader.
///
/// Construct it with [`MotionKernel::new`], then call
/// [`MotionKernel::estimate`] once per frame pair.
pub struct MotionKernel {
/// Logical device; used to create the pipeline and as the starting point of
/// the submission future chain.
device: Arc<Device>,
/// Queue the recorded compute command buffer is executed on.
queue: Arc<Queue>,
/// Creates device buffers and command buffers, and moves data to and from
/// the device.
buffer_manager: BufferManager,
/// Allocator backing the descriptor set built on every `estimate` call.
descriptor_allocator: Arc<StandardDescriptorSetAllocator>,
/// Compute pipeline built from the shader's `main` entry point.
pipeline: Arc<ComputePipeline>,
}
/// Push-constant block handed to the motion-estimation compute shader.
///
/// `#[repr(C)]` pins the field order and the `bytemuck` derives let the value
/// be viewed as plain bytes when it is pushed at offset 0.
// NOTE(review): the layout presumably has to match the push-constant block
// declared in the `block_sad` shader source — confirm before reordering fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
struct MotionPushConstants {
/// Frame width as passed to `estimate` (presumably in pixels).
width: u32,
/// Frame height as passed to `estimate` (presumably in pixels).
height: u32,
/// Edge length of one block, in the same unit as `width`/`height`.
block_size: u32,
/// Search range forwarded to the shader; `estimate` always sets this to 8.
search_range: u32,
}
impl MotionKernel {
    /// Search range forwarded to the shader through the push constants.
    const SEARCH_RANGE: u32 = 8;

    /// Size of one encoded motion vector in the output buffer: two
    /// little-endian `i16` values (`dx`, then `dy`).
    const BYTES_PER_VECTOR: usize = 4;

    /// Number of blocks covered by one workgroup along each axis.
    // NOTE(review): presumably mirrors the shader's `local_size_x/y` — confirm
    // against the `block_sad` shader source.
    const WORKGROUP_EDGE: u32 = 8;

    /// Creates the kernel by loading the `block_sad` shader and building its
    /// compute pipeline.
    ///
    /// # Errors
    ///
    /// * [`AccelError::ShaderCompilation`] if the shader module cannot be
    ///   loaded or has no `main` entry point.
    /// * [`AccelError::PipelineCreation`] if the pipeline layout or the
    ///   compute pipeline cannot be created.
    pub fn new(
        device: Arc<Device>,
        queue: Arc<Queue>,
        buffer_manager: BufferManager,
    ) -> AccelResult<Self> {
        let descriptor_allocator = Arc::new(StandardDescriptorSetAllocator::new(
            device.clone(),
            StandardDescriptorSetAllocatorCreateInfo::default(),
        ));

        let shader = block_sad::load(device.clone()).map_err(|e| {
            AccelError::ShaderCompilation(format!("Motion estimation shader: {e:?}"))
        })?;

        let stage =
            PipelineShaderStageCreateInfo::new(shader.entry_point("main").ok_or_else(|| {
                AccelError::ShaderCompilation(
                    "Motion estimation shader: entry point 'main' not found".to_string(),
                )
            })?);

        // The descriptor set layout is derived from the shader stage itself.
        let layout = PipelineLayout::new(
            device.clone(),
            PipelineDescriptorSetLayoutCreateInfo::from_stages(std::slice::from_ref(&stage))
                .into_pipeline_layout_create_info(device.clone())
                .map_err(|e| AccelError::PipelineCreation(format!("Motion layout: {e:?}")))?,
        )
        .map_err(|e| AccelError::PipelineCreation(format!("Motion layout creation: {e:?}")))?;

        let pipeline = ComputePipeline::new(
            device.clone(),
            None,
            ComputePipelineCreateInfo::stage_layout(stage, layout),
        )
        .map_err(|e| AccelError::PipelineCreation(format!("Motion pipeline: {e:?}")))?;

        Ok(Self {
            device,
            queue,
            buffer_manager,
            descriptor_allocator,
            pipeline,
        })
    }

    /// Runs block motion estimation of `current` against `reference` on the
    /// GPU and returns one `(dx, dy)` vector per 4 bytes of shader output.
    ///
    /// Both frames must contain exactly `width * height` bytes. The output
    /// buffer holds `ceil(width / block_size) * ceil(height / block_size)`
    /// vectors. A frame with zero area yields an empty vector without touching
    /// the GPU.
    ///
    /// This call blocks until the GPU has finished.
    ///
    /// # Errors
    ///
    /// * [`AccelError::BufferSizeMismatch`] if either frame does not contain
    ///   `width * height` bytes; `actual` is the length of the first frame
    ///   (`reference`, then `current`) that is wrong.
    /// * [`AccelError::Dispatch`] if `block_size` is zero, if the output
    ///   buffer size is not representable, or if recording/submitting the
    ///   dispatch fails.
    /// * [`AccelError::PipelineCreation`], [`AccelError::CommandBuffer`] or
    ///   [`AccelError::Synchronization`] if the corresponding Vulkan step
    ///   fails, plus whatever the buffer manager reports.
    // The remaining `as` casts only widen `usize` to `u64`.
    #[allow(clippy::cast_possible_truncation)]
    pub fn estimate(
        &self,
        reference: &[u8],
        current: &[u8],
        width: u32,
        height: u32,
        block_size: u32,
    ) -> AccelResult<Vec<(i16, i16)>> {
        // Multiply in u64: the product of two u32 values always fits, whereas
        // a u32 multiply could wrap and let a wrong-sized buffer through.
        // If the pixel count does not fit in usize no slice can match it, so
        // saturating still produces the mismatch error below.
        let pixel_count = u64::from(width) * u64::from(height);
        let expected_size = usize::try_from(pixel_count).unwrap_or(usize::MAX);

        // Report the length that is actually wrong, not the smaller of the two.
        let mismatched = [reference.len(), current.len()]
            .into_iter()
            .find(|&len| len != expected_size);
        if let Some(actual) = mismatched {
            return Err(AccelError::BufferSizeMismatch {
                expected: expected_size,
                actual,
            });
        }

        // `div_ceil` panics on a zero divisor; fail before any GPU work.
        if block_size == 0 {
            return Err(AccelError::Dispatch(
                "Dispatch: block_size must be non-zero".to_string(),
            ));
        }

        // No pixels means no blocks; Vulkan also forbids zero-sized buffers.
        if expected_size == 0 {
            return Ok(Vec::new());
        }

        let blocks_wide = width.div_ceil(block_size);
        let blocks_high = height.div_ceil(block_size);
        // Widen before multiplying so neither the block count nor the byte
        // count can wrap in u32.
        let mv_bytes = (u64::from(blocks_wide) * u64::from(blocks_high))
            .checked_mul(Self::BYTES_PER_VECTOR as u64)
            .ok_or_else(|| {
                AccelError::Dispatch(
                    "Dispatch: motion vector buffer size overflows u64".to_string(),
                )
            })?;

        let ref_buffer = self
            .buffer_manager
            .create_device_buffer(reference.len() as u64, BufferUsage::STORAGE_BUFFER)?;
        let cur_buffer = self
            .buffer_manager
            .create_device_buffer(current.len() as u64, BufferUsage::STORAGE_BUFFER)?;
        self.buffer_manager.upload_data(reference, &ref_buffer)?;
        self.buffer_manager.upload_data(current, &cur_buffer)?;

        let mv_buffer = self
            .buffer_manager
            .create_device_buffer(mv_bytes, BufferUsage::STORAGE_BUFFER)?;

        let layout = self
            .pipeline
            .layout()
            .set_layouts()
            .first()
            .ok_or_else(|| AccelError::PipelineCreation("No descriptor set layout".to_string()))?;

        // Bindings: 0 = reference frame, 1 = current frame, 2 = vector output.
        let descriptor_set = DescriptorSet::new(
            self.descriptor_allocator.clone(),
            layout.clone(),
            [
                WriteDescriptorSet::buffer(0, ref_buffer),
                WriteDescriptorSet::buffer(1, cur_buffer),
                WriteDescriptorSet::buffer(2, mv_buffer.clone()),
            ],
            [],
        )
        .map_err(|e| AccelError::PipelineCreation(format!("Descriptor set: {e:?}")))?;

        let mut builder = self.buffer_manager.create_command_buffer()?;

        let push_constants = MotionPushConstants {
            width,
            height,
            block_size,
            search_range: Self::SEARCH_RANGE,
        };

        builder
            .bind_pipeline_compute(self.pipeline.clone())
            .map_err(|e| AccelError::CommandBuffer(format!("Bind pipeline: {e:?}")))?
            .bind_descriptor_sets(
                PipelineBindPoint::Compute,
                self.pipeline.layout().clone(),
                0,
                descriptor_set,
            )
            .map_err(|e| AccelError::CommandBuffer(format!("Bind descriptor sets: {e:?}")))?
            .push_constants(self.pipeline.layout().clone(), 0, push_constants)
            .map_err(|e| AccelError::CommandBuffer(format!("Push constants: {e:?}")))?;

        // SAFETY: the three storage buffers bound above are sized from the
        // same `width`, `height` and `block_size` that are pushed to the
        // shader (`width * height` bytes per frame, 4 bytes per block for the
        // output). The shader is assumed to bound its accesses by those push
        // constants — TODO confirm against the `block_sad` shader source.
        unsafe {
            builder
                .dispatch([
                    blocks_wide.div_ceil(Self::WORKGROUP_EDGE),
                    blocks_high.div_ceil(Self::WORKGROUP_EDGE),
                    1,
                ])
                .map_err(|e| AccelError::Dispatch(format!("Dispatch: {e:?}")))?;
        }

        let command_buffer = builder
            .build()
            .map_err(|e| AccelError::CommandBuffer(format!("Build: {e:?}")))?;

        // Submit and block until the GPU is done so the download below reads
        // completed results.
        vulkano::sync::now(self.device.clone())
            .then_execute(self.queue.clone(), command_buffer)
            .map_err(|e| AccelError::Dispatch(format!("Execute: {e:?}")))?
            .then_signal_fence_and_flush()
            .map_err(|e| AccelError::Dispatch(format!("Flush: {e:?}")))?
            .wait(None)
            .map_err(|e| AccelError::Synchronization(format!("Wait: {e:?}")))?;

        let mv_data = self.buffer_manager.download_data(&mv_buffer)?;

        // Each vector is two little-endian i16 values: dx followed by dy.
        let motion_vectors = mv_data
            .chunks_exact(Self::BYTES_PER_VECTOR)
            .map(|chunk| {
                (
                    i16::from_le_bytes([chunk[0], chunk[1]]),
                    i16::from_le_bytes([chunk[2], chunk[3]]),
                )
            })
            .collect();

        Ok(motion_vectors)
    }
}