use std::ffi::c_void;
use std::ptr::NonNull;
use anyhow::Result;
use ash::vk;
use ash::vk::Handle;
use crate::core::device::ExtensionID;
use crate::core::traits::{AsRaw, Nameable};
use crate::util::align::align;
use crate::{Allocation, Allocator, DefaultAllocator, Device, Error, MemoryType};
/// A Vulkan buffer together with the device memory allocation bound to it.
///
/// Buffers are always created with the full set of usage flags returned by
/// `get_buffer_usage_flags`, so one `Buffer` can serve as a transfer source or
/// destination, vertex/index/indirect input, uniform/storage (texel) buffer, and be
/// accessed through its device address. Sub-ranges are exposed as [`BufferView`]s.
///
/// The `VkBuffer` is destroyed in `Drop`; the `memory` field is dropped afterwards,
/// because Rust drops fields only after `Drop::drop` returns.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct Buffer<A: Allocator = DefaultAllocator> {
/// Device that created `handle`; used to destroy it on drop.
#[derivative(Debug = "ignore")]
device: Device,
/// Backing allocation. It is not read after construction; it is stored so the
/// allocation lives as long as the buffer (presumably why `dead_code` is allowed).
#[derivative(Debug = "ignore")]
#[allow(dead_code)]
memory: A::Allocation,
/// Device address of the buffer, queried once at creation.
address: vk::DeviceAddress,
/// Host pointer to the start of the mapped memory, if the allocation reported one.
pointer: Option<NonNull<c_void>>,
/// Raw Vulkan buffer handle, owned by this struct.
handle: vk::Buffer,
/// Size of the buffer in bytes, as passed to `vkCreateBuffer`.
size: vk::DeviceSize,
}
// SAFETY(review): `Buffer` is not automatically `Send` because `NonNull` is not.
// This impl asserts `Send` for every allocator `A` without requiring
// `A::Allocation: Send`. NOTE(review): confirm that every allocation type used here
// is safe to move to another thread.
unsafe impl<A: Allocator> Send for Buffer<A> {}
/// A copyable, non-owning window of `size` bytes starting `offset` bytes into a
/// [`Buffer`].
///
/// `address` and (when the parent is host-mapped) `pointer` already point at the
/// start of the window. The view neither borrows nor keeps its parent alive, so it
/// must not be used after that `Buffer` is dropped (its handle is destroyed then).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BufferView {
/// Raw handle of the parent buffer (not owned by the view).
handle: vk::Buffer,
/// Host pointer to the first byte of the window, or `None` if the parent is not mapped.
pointer: Option<NonNull<c_void>>,
/// Device address of the first byte of the window.
address: vk::DeviceAddress,
/// Byte offset of the window from the start of the parent buffer.
offset: vk::DeviceSize,
/// Length of the window in bytes.
size: vk::DeviceSize,
}
// SAFETY(review): required because `NonNull` is not `Send`. The view is also `Copy`,
// so copies on different threads can each call `mapped_slice` and obtain overlapping
// `&mut` slices; callers must synchronise such access themselves.
unsafe impl Send for BufferView {}
/// Returns the usage flags every `Buffer` is created with.
///
/// The base set always allows transfers (source and destination), vertex, index and
/// indirect input, uniform/storage buffers and their texel variants, and access via
/// the buffer's device address. Acceleration-structure usages are added only when the
/// acceleration structure extension is enabled on `device`. Shader-binding-table usage
/// is added only when the ray tracing pipeline extension is enabled.
fn get_buffer_usage_flags(device: &Device) -> vk::BufferUsageFlags {
    type Usage = vk::BufferUsageFlags;

    let always = Usage::TRANSFER_SRC
        | Usage::TRANSFER_DST
        | Usage::VERTEX_BUFFER
        | Usage::INDEX_BUFFER
        | Usage::INDIRECT_BUFFER
        | Usage::UNIFORM_BUFFER
        | Usage::UNIFORM_TEXEL_BUFFER
        | Usage::STORAGE_BUFFER
        | Usage::STORAGE_TEXEL_BUFFER
        | Usage::SHADER_DEVICE_ADDRESS;

    let acceleration_structure =
        if device.is_extension_enabled(ExtensionID::AccelerationStructure) {
            Usage::ACCELERATION_STRUCTURE_STORAGE_KHR
                | Usage::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR
        } else {
            Usage::empty()
        };

    let shader_binding_table = if device.is_extension_enabled(ExtensionID::RayTracingPipeline) {
        Usage::SHADER_BINDING_TABLE_KHR
    } else {
        Usage::empty()
    };

    always | acceleration_structure | shader_binding_table
}
impl<A: Allocator> Buffer<A> {
pub fn new(
device: Device,
allocator: &mut A,
size: impl Into<vk::DeviceSize>,
location: MemoryType,
) -> Result<Self> {
let size = size.into();
let sharing_mode = if device.is_single_queue() {
vk::SharingMode::EXCLUSIVE
} else {
vk::SharingMode::CONCURRENT
};
let usage = get_buffer_usage_flags(&device);
let handle = unsafe {
device.create_buffer(
&vk::BufferCreateInfo {
s_type: vk::StructureType::BUFFER_CREATE_INFO,
p_next: std::ptr::null(),
flags: vk::BufferCreateFlags::empty(),
size,
usage,
sharing_mode,
queue_family_index_count: if sharing_mode == vk::SharingMode::CONCURRENT {
device.queue_families().len() as u32
} else {
0
},
p_queue_family_indices: if sharing_mode == vk::SharingMode::CONCURRENT {
device.queue_families().as_ptr()
} else {
std::ptr::null()
},
},
None,
)?
};
#[cfg(feature = "log-objects")]
trace!("Created new VkBuffer {handle:p} (size = {size} bytes)");
let requirements = unsafe { device.get_buffer_memory_requirements(handle) };
let memory = allocator.allocate("buffer", &requirements, location)?;
unsafe { device.bind_buffer_memory(handle, memory.memory(), memory.offset())? };
let address = unsafe {
device.get_buffer_device_address(&vk::BufferDeviceAddressInfo {
s_type: vk::StructureType::BUFFER_DEVICE_ADDRESS_INFO,
p_next: std::ptr::null(),
buffer: handle,
})
};
Ok(Self {
device,
pointer: memory.mapped_ptr(),
memory,
handle,
size,
address,
})
}
pub fn new_aligned(
device: Device,
allocator: &mut A,
size: impl Into<vk::DeviceSize>,
alignment: impl Into<vk::DeviceSize>,
location: MemoryType,
) -> Result<Self> {
let alignment = alignment.into();
let size = align(size.into(), alignment);
let sharing_mode = if device.is_single_queue() {
vk::SharingMode::EXCLUSIVE
} else {
vk::SharingMode::CONCURRENT
};
let usage = get_buffer_usage_flags(&device);
let handle = unsafe {
device.create_buffer(
&vk::BufferCreateInfo {
s_type: vk::StructureType::BUFFER_CREATE_INFO,
p_next: std::ptr::null(),
flags: vk::BufferCreateFlags::empty(),
size,
usage,
sharing_mode,
queue_family_index_count: if sharing_mode == vk::SharingMode::CONCURRENT {
device.queue_families().len() as u32
} else {
0
},
p_queue_family_indices: if sharing_mode == vk::SharingMode::CONCURRENT {
device.queue_families().as_ptr()
} else {
std::ptr::null()
},
},
None,
)?
};
#[cfg(feature = "log-objects")]
trace!("Created new VkBuffer {handle:p} (size = {size} bytes)");
let mut requirements = unsafe { device.get_buffer_memory_requirements(handle) };
requirements.alignment = alignment;
let memory = allocator.allocate("buffer", &requirements, location)?;
unsafe { device.bind_buffer_memory(handle, memory.memory(), memory.offset())? };
let address = unsafe {
device.get_buffer_device_address(&vk::BufferDeviceAddressInfo {
s_type: vk::StructureType::BUFFER_DEVICE_ADDRESS_INFO,
p_next: std::ptr::null(),
buffer: handle,
})
};
Ok(Self {
device,
pointer: memory.mapped_ptr(),
memory,
handle,
size,
address,
})
}
pub fn new_device_local(
device: Device,
allocator: &mut A,
size: impl Into<vk::DeviceSize>,
) -> Result<Self> {
Self::new(device, allocator, size, MemoryType::GpuOnly)
}
pub fn view(
&self,
offset: impl Into<vk::DeviceSize>,
size: impl Into<vk::DeviceSize>,
) -> Result<BufferView> {
let offset = offset.into();
let size = size.into();
if offset + size > self.size {
Err(anyhow::Error::from(Error::BufferViewOutOfRange))
} else {
Ok(BufferView {
handle: self.handle,
offset,
pointer: unsafe {
self.pointer
.map(|p| NonNull::new(p.as_ptr().offset(offset as isize)).unwrap())
},
address: self.address + offset,
size,
})
}
}
pub fn view_full(&self) -> BufferView {
BufferView {
handle: self.handle,
pointer: self.pointer,
offset: 0,
address: self.address,
size: self.size,
}
}
pub fn is_mapped(&self) -> bool {
self.pointer.is_some()
}
pub unsafe fn handle(&self) -> vk::Buffer {
self.handle
}
pub fn size(&self) -> vk::DeviceSize {
self.size
}
pub fn address(&self) -> vk::DeviceAddress {
self.address
}
}
unsafe impl AsRaw for Buffer {
unsafe fn as_raw(&self) -> u64 {
self.handle().as_raw()
}
}
impl Nameable for Buffer {
const OBJECT_TYPE: vk::ObjectType = vk::ObjectType::BUFFER;
}
impl<A: Allocator> Drop for Buffer<A> {
    /// Destroys the Vulkan buffer handle.
    ///
    /// The `memory` field is dropped after this method returns, so the buffer is
    /// destroyed before its backing allocation is dropped.
    fn drop(&mut self) {
        let handle = self.handle;
        #[cfg(feature = "log-objects")]
        trace!("Destroying VkBuffer {:p}", handle);
        // SAFETY: `handle` was created from `self.device` and is owned exclusively by
        // this `Buffer`; nothing in this type uses it after drop.
        unsafe { self.device.destroy_buffer(handle, None) };
    }
}
impl BufferView {
pub fn mapped_slice<T>(&mut self) -> Result<&mut [T]> {
if let Some(pointer) = self.pointer {
Ok(unsafe {
std::slice::from_raw_parts_mut(
pointer.cast::<T>().as_ptr(),
self.size as usize / std::mem::size_of::<T>(),
)
})
} else {
Err(anyhow::Error::from(Error::UnmappableBuffer))
}
}
pub unsafe fn handle(&self) -> vk::Buffer {
self.handle
}
pub fn offset(&self) -> vk::DeviceSize {
self.offset
}
pub fn size(&self) -> vk::DeviceSize {
self.size
}
pub fn address(&self) -> vk::DeviceAddress {
self.address
}
}