use std::marker::PhantomData;
use std::sync::Arc;
use ash::vk;
use super::device::VulkanContext;
use super::error::VulkanError;
use crate::backend::BufferOps;
use crate::scalar::Scalar;
/// Typed, device-local Vulkan storage buffer holding `len` elements of `T`.
///
/// Owns both the `vk::Buffer` handle and its backing `vk::DeviceMemory`;
/// both are released in the `Drop` impl below. Host access goes through
/// host-visible staging buffers (see `BufferOps::write` / `read`).
pub struct VulkanBuffer<T: Scalar> {
// Shared handle to the owning device/context; keeps it alive for the
// lifetime of this buffer so the Drop teardown can use it.
pub(crate) ctx: Arc<VulkanContext>,
pub(crate) buffer: vk::Buffer,
pub(crate) memory: vk::DeviceMemory,
// Allocation size in bytes (len * T::BYTES), cached for staging copies.
pub(crate) size_bytes: u64,
// Element count, not bytes.
pub(crate) len: usize,
// Ties the untyped GPU allocation to element type `T` without storing one.
_marker: PhantomData<T>,
}
impl<T: Scalar> VulkanBuffer<T> {
    /// Allocates a device-local buffer sized for `len` elements of `T`.
    ///
    /// Usage flags include STORAGE (compute dispatch binding) plus
    /// TRANSFER_SRC/TRANSFER_DST so the buffer can participate in staging
    /// copies in both directions.
    ///
    /// # Errors
    /// Propagates any `VulkanError` from the context's `allocate_buffer`.
    pub(crate) fn new(ctx: Arc<VulkanContext>, len: usize) -> Result<Self, VulkanError> {
        // Widen BEFORE multiplying: `len * T::BYTES` in usize can wrap on
        // 32-bit targets (silently, in release builds) before the cast to
        // u64; doing the arithmetic in u64 cannot overflow for any
        // realistic element count.
        let size_bytes = len as u64 * T::BYTES as u64;
        // NOTE(review): Vulkan requires buffer size > 0; a zero `len` is
        // assumed to be rejected or handled inside `allocate_buffer` —
        // confirm against its implementation.
        let usage = vk::BufferUsageFlags::STORAGE_BUFFER
            | vk::BufferUsageFlags::TRANSFER_SRC
            | vk::BufferUsageFlags::TRANSFER_DST;
        let (buffer, memory, _) =
            ctx.allocate_buffer(size_bytes, usage, vk::MemoryPropertyFlags::DEVICE_LOCAL)?;
        Ok(Self {
            ctx,
            buffer,
            memory,
            size_bytes,
            len,
            _marker: PhantomData,
        })
    }

    /// Raw Vulkan buffer handle; valid only while `self` is alive.
    pub fn raw(&self) -> vk::Buffer {
        self.buffer
    }

    /// Size of the allocation in bytes (`len * T::BYTES`).
    pub fn size_bytes(&self) -> u64 {
        self.size_bytes
    }
}
impl<T: Scalar> BufferOps<super::VulkanBackend, T> for VulkanBuffer<T> {
    /// Number of elements the buffer holds.
    fn len(&self) -> usize {
        self.len
    }

    /// Uploads `src` into the device-local buffer through a host-visible
    /// staging buffer. `src` must contain exactly `self.len()` elements.
    fn write(&mut self, src: &[T]) -> Result<(), VulkanError> {
        match src.len() {
            n if n == self.len => {
                staging_copy_in(&self.ctx, self.buffer, self.size_bytes, src)
            }
            n => Err(VulkanError::LengthMismatch {
                expected: self.len,
                got: n,
            }),
        }
    }

    /// Downloads the buffer contents into `dst` through a staging buffer.
    /// `dst` must contain exactly `self.len()` elements.
    fn read(&self, dst: &mut [T]) -> Result<(), VulkanError> {
        match dst.len() {
            n if n == self.len => {
                staging_copy_out(&self.ctx, self.buffer, self.size_bytes, dst)
            }
            n => Err(VulkanError::LengthMismatch {
                expected: self.len,
                got: n,
            }),
        }
    }
}
/// Releases the Vulkan buffer handle and its device memory.
impl<T: Scalar> Drop for VulkanBuffer<T> {
fn drop(&mut self) {
// SAFETY: `buffer` and `memory` were created from `ctx.device` in `new`
// and are not used after this point; handle then memory is the required
// destruction order. NOTE(review): assumes no GPU work still references
// this buffer at drop time (all transfers here wait on a fence, but
// compute dispatches elsewhere must be idle) — confirm at call sites.
unsafe {
self.ctx.device.destroy_buffer(self.buffer, None);
self.ctx.device.free_memory(self.memory, None);
}
}
}
/// Uploads `src` to the device-local `dst` buffer through a host-visible,
/// host-coherent staging buffer (map, memcpy, GPU copy, wait).
///
/// `size_bytes` must equal `src.len() * T::BYTES`; the `BufferOps::write`
/// caller guarantees this after its length check.
///
/// # Errors
/// Propagates allocation, mapping, and submission failures. The staging
/// buffer and its memory are released on every path.
fn staging_copy_in<T: Scalar>(
    ctx: &VulkanContext,
    dst: vk::Buffer,
    size_bytes: u64,
    src: &[T],
) -> Result<(), VulkanError> {
    debug_assert_eq!(size_bytes, (src.len() * T::BYTES) as u64);
    // Nothing to transfer — and this also avoids requesting a zero-sized
    // Vulkan buffer, which the spec forbids (VkBufferCreateInfo::size > 0).
    if size_bytes == 0 {
        return Ok(());
    }
    let (staging, staging_mem, _) = ctx.allocate_buffer(
        size_bytes,
        vk::BufferUsageFlags::TRANSFER_SRC,
        // HOST_COHERENT makes the CPU write visible to the GPU without an
        // explicit vkFlushMappedMemoryRanges.
        vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT,
    )?;
    unsafe {
        let ptr = ctx
            .device
            .map_memory(staging_mem, 0, size_bytes, vk::MemoryMapFlags::empty())
            .map_err(|e| {
                // Mapping failed: release staging resources before bailing.
                ctx.device.destroy_buffer(staging, None);
                ctx.device.free_memory(staging_mem, None);
                VulkanError::vk("map_memory", e)
            })?;
        // SAFETY: `ptr` maps exactly `size_bytes` bytes of staging memory and
        // `src` covers the same byte count (debug_assert above); the regions
        // cannot overlap (host allocation vs. fresh driver mapping).
        std::ptr::copy_nonoverlapping(
            src.as_ptr().cast::<u8>(),
            ptr.cast::<u8>(),
            size_bytes as usize,
        );
        ctx.device.unmap_memory(staging_mem);
    }
    // Run the GPU copy, then free staging resources regardless of outcome.
    let result = copy_buffer_to_buffer(ctx, staging, dst, size_bytes);
    unsafe {
        ctx.device.destroy_buffer(staging, None);
        ctx.device.free_memory(staging_mem, None);
    }
    result
}
/// Downloads the device-local `src` buffer into `dst` through a host-visible,
/// host-coherent staging buffer (GPU copy, wait, map, memcpy).
///
/// `size_bytes` must equal `dst.len() * T::BYTES`; the `BufferOps::read`
/// caller guarantees this after its length check.
///
/// # Errors
/// Propagates allocation, copy, and mapping failures. The staging buffer and
/// its memory are released on every path.
fn staging_copy_out<T: Scalar>(
    ctx: &VulkanContext,
    src: vk::Buffer,
    size_bytes: u64,
    dst: &mut [T],
) -> Result<(), VulkanError> {
    debug_assert_eq!(size_bytes, (dst.len() * T::BYTES) as u64);
    // Nothing to transfer — and this also avoids requesting a zero-sized
    // Vulkan buffer, which the spec forbids (VkBufferCreateInfo::size > 0).
    if size_bytes == 0 {
        return Ok(());
    }
    let (staging, staging_mem, _) = ctx.allocate_buffer(
        size_bytes,
        vk::BufferUsageFlags::TRANSFER_DST,
        // HOST_COHERENT makes the GPU write visible to the CPU without an
        // explicit vkInvalidateMappedMemoryRanges.
        vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT,
    )?;
    // GPU copy first; on failure, free staging resources before bailing.
    if let Err(e) = copy_buffer_to_buffer(ctx, src, staging, size_bytes) {
        unsafe {
            ctx.device.destroy_buffer(staging, None);
            ctx.device.free_memory(staging_mem, None);
        }
        return Err(e);
    }
    unsafe {
        let ptr = ctx
            .device
            .map_memory(staging_mem, 0, size_bytes, vk::MemoryMapFlags::empty())
            .map_err(|e| {
                // Mapping failed: release staging resources before bailing.
                ctx.device.destroy_buffer(staging, None);
                ctx.device.free_memory(staging_mem, None);
                VulkanError::vk("map_memory", e)
            })?;
        // SAFETY: `ptr` maps exactly `size_bytes` bytes of staging memory and
        // `dst` covers the same byte count (debug_assert above); the regions
        // cannot overlap (driver mapping vs. caller-owned host slice).
        std::ptr::copy_nonoverlapping(
            ptr.cast_const().cast::<u8>(),
            dst.as_mut_ptr().cast::<u8>(),
            size_bytes as usize,
        );
        ctx.device.unmap_memory(staging_mem);
    }
    unsafe {
        ctx.device.destroy_buffer(staging, None);
        ctx.device.free_memory(staging_mem, None);
    }
    Ok(())
}
/// Synchronously copies `size_bytes` from `src` to `dst` on the context's
/// transfer queue: allocate a one-shot command buffer, record the copy,
/// submit, and block on the transfer fence.
///
/// # Errors
/// Propagates command-buffer allocation/recording and submission failures.
/// The command buffer is freed on every path (the previous version leaked it
/// whenever any step after allocation failed, because `?` returned before
/// `free_command_buffers`).
fn copy_buffer_to_buffer(
    ctx: &VulkanContext,
    src: vk::Buffer,
    dst: vk::Buffer,
    size_bytes: u64,
) -> Result<(), VulkanError> {
    // vkCmdCopyBuffer requires every region's size to be greater than zero.
    if size_bytes == 0 {
        return Ok(());
    }
    unsafe {
        let alloc_info = vk::CommandBufferAllocateInfo::default()
            .command_pool(ctx.transfer_pool)
            .level(vk::CommandBufferLevel::PRIMARY)
            .command_buffer_count(1);
        let cmd_bufs = ctx
            .device
            .allocate_command_buffers(&alloc_info)
            .map_err(|e| VulkanError::vk("allocate_command_buffers", e))?;
        // Record + submit in a helper so the command buffer is returned to
        // the pool no matter which step fails.
        let result = record_and_submit(ctx, &cmd_bufs, src, dst, size_bytes);
        ctx.device
            .free_command_buffers(ctx.transfer_pool, &cmd_bufs);
        result
    }
}

/// Records the buffer copy into `cmd_bufs[0]`, submits it on the transfer
/// queue, and waits on the transfer fence. The caller owns (and frees) the
/// command buffer.
fn record_and_submit(
    ctx: &VulkanContext,
    cmd_bufs: &[vk::CommandBuffer],
    src: vk::Buffer,
    dst: vk::Buffer,
    size_bytes: u64,
) -> Result<(), VulkanError> {
    // SAFETY: `cmd_bufs[0]` is a freshly allocated primary command buffer
    // from `ctx.transfer_pool`; `src`/`dst` are live buffers on the same
    // device; the fence wait below serializes the submission.
    unsafe {
        let cmd = cmd_bufs[0];
        let begin = vk::CommandBufferBeginInfo::default()
            .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);
        ctx.device
            .begin_command_buffer(cmd, &begin)
            .map_err(|e| VulkanError::vk("begin_command_buffer", e))?;
        // Default BufferCopy has zero src/dst offsets: whole-buffer copy.
        let region = [vk::BufferCopy::default().size(size_bytes)];
        ctx.device.cmd_copy_buffer(cmd, src, dst, &region);
        ctx.device
            .end_command_buffer(cmd)
            .map_err(|e| VulkanError::vk("end_command_buffer", e))?;
        let submit = [vk::SubmitInfo::default().command_buffers(cmd_bufs)];
        ctx.device
            .reset_fences(&[ctx.transfer_fence])
            .map_err(|e| VulkanError::vk("reset_fences", e))?;
        ctx.device
            .queue_submit(ctx.queue, &submit, ctx.transfer_fence)
            .map_err(|e| VulkanError::vk("queue_submit", e))?;
        ctx.device
            .wait_for_fences(&[ctx.transfer_fence], true, u64::MAX)
            .map_err(|e| VulkanError::vk("wait_for_fences", e))?;
    }
    Ok(())
}