use crate::buffer::MkBufferUsage;
use crate::memory::MkMemoryType;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Abstraction over a GPU memory backend. Implementors own raw buffer
/// allocations and data transfers; higher-level code interacts with the
/// device only through these methods.
pub trait MkGpuBackend: Sized + Send + Sync {
    /// Opaque, cheaply clonable identifier for a buffer owned by the backend.
    type BufferHandle: Clone + Send + Sync;
    /// Backend-specific error type.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Human-readable backend name (e.g. "Dummy (CPU Simulation)").
    fn name(&self) -> &'static str;
    /// Static capabilities and limits reported by the device.
    fn capabilities(&self) -> MkGpuCapabilities;
    /// Allocates a buffer of `size` bytes with the given usage and memory type.
    fn create_buffer(
        &self,
        size: usize,
        usage: MkBufferUsage,
        memory_type: MkMemoryType,
    ) -> Result<Self::BufferHandle, Self::Error>;
    /// Releases a buffer. NOTE(review): the trait does not state what happens
    /// for an unknown handle — the dummy backend below silently ignores it.
    fn destroy_buffer(&self, handle: &Self::BufferHandle);
    /// Maps the buffer into host address space; `None` when not host-mappable.
    fn map(&self, handle: &Self::BufferHandle) -> Option<*mut u8>;
    /// Unmaps a previously mapped buffer.
    fn unmap(&self, handle: &Self::BufferHandle);
    /// Flushes host writes in `[offset, offset + size)` toward the device
    /// (conventionally required for non-coherent memory).
    fn flush(&self, handle: &Self::BufferHandle, offset: usize, size: usize);
    /// Makes device writes in `[offset, offset + size)` visible to the host
    /// (conventionally required for non-coherent memory).
    fn invalidate(&self, handle: &Self::BufferHandle, offset: usize, size: usize);
    /// Copies `size` bytes from the start of `src` to the start of `dst`.
    fn copy_buffer(
        &self,
        src: &Self::BufferHandle,
        dst: &Self::BufferHandle,
        size: usize,
    ) -> Result<(), Self::Error>;
    /// Copies `size` bytes from `src + src_offset` to `dst + dst_offset`.
    fn copy_buffer_regions(
        &self,
        src: &Self::BufferHandle,
        src_offset: usize,
        dst: &Self::BufferHandle,
        dst_offset: usize,
        size: usize,
    ) -> Result<(), Self::Error>;
    /// Blocks until all pending device work has completed.
    fn wait_idle(&self) -> Result<(), Self::Error>;
}
/// Static limits and properties reported by a backend's device.
#[derive(Debug, Clone, Default)]
pub struct MkGpuCapabilities {
    /// Largest single buffer the backend will allocate, in bytes.
    pub max_buffer_size: usize,
    /// Maximum number of simultaneous allocations.
    pub max_allocations: usize,
    /// Whether host and device share one memory (the dummy backend reports `true`).
    pub unified_memory: bool,
    /// Whether mapped memory is coherent, i.e. flush/invalidate are not needed.
    pub coherent_memory: bool,
    /// Device name string, e.g. "Dummy GPU".
    pub device_name: String,
    /// Vendor name string, e.g. "memkit".
    pub vendor_name: String,
}
/// CPU-only backend that simulates GPU buffers with host heap allocations.
/// Useful for tests and for exercising memkit logic without a real device.
pub struct DummyBackend {
    // Monotonic source of buffer ids; starts at 1 (see `with_config`).
    next_id: AtomicU64,
    // All live buffers, keyed by id, behind a reader/writer lock.
    // NOTE(review): the Arc is never cloned in this file — presumably kept
    // for sharing with other components; verify against callers.
    buffers: Arc<RwLock<HashMap<u64, DummyBuffer>>>,
    // Behavior knobs fixed at construction time.
    config: DummyBackendConfig,
}
/// Tuning knobs for [`DummyBackend`].
#[derive(Debug, Clone)]
pub struct DummyBackendConfig {
    /// Largest buffer `create_buffer` will accept, in bytes (default: 1 GiB).
    pub max_buffer_size: usize,
    /// When true, DeviceLocal buffers refuse to be mapped, mimicking a
    /// discrete GPU's non-host-visible memory.
    pub simulate_device_local: bool,
    /// Artificial latency added to every copy, in microseconds (0 = none).
    pub transfer_delay_us: u64,
}
impl Default for DummyBackendConfig {
fn default() -> Self {
Self {
max_buffer_size: 1024 * 1024 * 1024, simulate_device_local: true,
transfer_delay_us: 0,
}
}
}
/// Backend-internal record of one simulated buffer.
struct DummyBuffer {
    // Heap allocation (8-byte aligned, zero-initialized), or null when size == 0.
    data: *mut u8,
    // Allocation size in bytes.
    size: usize,
    // Usage flags recorded at creation.
    // NOTE(review): stored but never read back anywhere in this file.
    usage: MkBufferUsage,
    // Memory type recorded at creation.
    memory_type: MkMemoryType,
    // Whether `map` has been called without a matching `unmap`.
    mapped: bool,
}
// SAFETY: DummyBuffer instances live only inside `DummyBackend::buffers`,
// which is protected by an RwLock, and the struct itself does not alias its
// raw `data` pointer. NOTE(review): `map` hands the raw pointer out to
// callers unsynchronized — confirm that external synchronization of mapped
// memory is part of the backend contract.
unsafe impl Send for DummyBuffer {}
unsafe impl Sync for DummyBuffer {}
impl DummyBackend {
    /// Creates a backend with the default configuration.
    pub fn new() -> Self {
        Self::with_config(DummyBackendConfig::default())
    }

    /// Creates a backend using an explicit configuration.
    pub fn with_config(config: DummyBackendConfig) -> Self {
        let buffers = Arc::new(RwLock::new(HashMap::new()));
        Self {
            next_id: AtomicU64::new(1),
            buffers,
            config,
        }
    }

    /// Number of live (not yet destroyed) buffers.
    pub fn buffer_count(&self) -> usize {
        let guard = self.buffers.read().unwrap();
        guard.len()
    }

    /// Total bytes currently allocated across all live buffers.
    pub fn total_allocated(&self) -> usize {
        let guard = self.buffers.read().unwrap();
        guard.values().fold(0, |total, buf| total + buf.size)
    }
}
impl Default for DummyBackend {
fn default() -> Self {
Self::new()
}
}
impl Drop for DummyBackend {
    /// Frees every allocation still tracked by the backend.
    ///
    /// Fix: the original `unwrap()`ed both the lock and the `Layout`
    /// construction. A panic inside `drop` while the thread is already
    /// unwinding aborts the whole process, and the lock is exactly the kind
    /// of thing that may be poisoned during an unwind. We therefore recover
    /// the map from a poisoned lock and skip (rather than panic on) an
    /// impossible layout error.
    fn drop(&mut self) {
        // Take the map out even if another thread panicked while holding
        // the lock; the data itself is still valid.
        let mut guard = match self.buffers.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let buffers = std::mem::take(&mut *guard);
        drop(guard);
        for (_, buffer) in buffers {
            if !buffer.data.is_null() {
                // Layout must match the one used in create_buffer (align 8).
                // from_size_align only fails for invalid align or oversized
                // size, neither of which can occur for a live allocation.
                if let Ok(layout) = std::alloc::Layout::from_size_align(buffer.size, 8) {
                    // SAFETY: `data` was allocated with exactly this layout
                    // and each buffer is dropped (and thus freed) once.
                    unsafe { std::alloc::dealloc(buffer.data, layout) };
                }
            }
        }
    }
}
/// Handle to a buffer created by [`DummyBackend`]. Cheap to clone; the
/// backing allocation lives in the backend's map until `destroy_buffer`.
#[derive(Clone, Debug)]
pub struct DummyBufferHandle {
    // Key into DummyBackend::buffers.
    id: u64,
    // Size copied at creation so `size()` needs no lock.
    size: usize,
    // Memory type copied at creation; `map` uses it for the device-local check.
    memory_type: MkMemoryType,
}
// SAFETY: the handle holds only plain data (u64 id, usize size, and a copy of
// MkMemoryType) — no pointers and no interior mutability.
// NOTE(review): if MkMemoryType is itself Send + Sync these impls are
// redundant (the auto traits would apply); confirm and consider removing.
unsafe impl Send for DummyBufferHandle {}
unsafe impl Sync for DummyBufferHandle {}
impl DummyBufferHandle {
    /// Memory type this buffer was created with.
    pub fn memory_type(&self) -> MkMemoryType {
        self.memory_type
    }

    /// Buffer size in bytes, as recorded at creation.
    pub fn size(&self) -> usize {
        self.size
    }
}
/// Errors produced by [`DummyBackend`].
#[derive(Debug, Clone)]
pub enum DummyError {
    /// Requested size exceeded the configured `max_buffer_size`.
    BufferTooLarge { requested: usize, max: usize },
    /// The host allocator returned null, or the layout was invalid.
    AllocationFailed,
    /// No buffer with the given id is tracked by the backend.
    BufferNotFound(u64),
    /// Buffer is device-local and cannot be host-mapped.
    /// NOTE(review): declared but never constructed in this file — `map`
    /// signals this case via `None` instead.
    NotMappable,
    /// NOTE(review): declared but never constructed in this file.
    AlreadyMapped,
    /// Catch-all error carrying a message.
    /// NOTE(review): never constructed in this file.
    Other(String),
}
impl std::fmt::Display for DummyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DummyError::BufferTooLarge { requested, max } => {
write!(f, "Buffer size {} exceeds maximum {}", requested, max)
}
DummyError::AllocationFailed => write!(f, "Memory allocation failed"),
DummyError::BufferNotFound(id) => write!(f, "Buffer {} not found", id),
DummyError::NotMappable => write!(f, "Buffer is not mappable (device-local)"),
DummyError::AlreadyMapped => write!(f, "Buffer is already mapped"),
DummyError::Other(msg) => write!(f, "{}", msg),
}
}
}
// Marker impl: Display + Debug above satisfy the Error supertraits; these are
// leaf errors with no underlying source to chain.
impl std::error::Error for DummyError {}
impl MkGpuBackend for DummyBackend {
    type BufferHandle = DummyBufferHandle;
    type Error = DummyError;

    fn name(&self) -> &'static str {
        "Dummy (CPU Simulation)"
    }

    /// Reports fixed capabilities: unified, coherent "device" memory backed
    /// by the host heap, with the size limit taken from the configuration.
    fn capabilities(&self) -> MkGpuCapabilities {
        MkGpuCapabilities {
            max_buffer_size: self.config.max_buffer_size,
            max_allocations: usize::MAX,
            unified_memory: true,
            coherent_memory: true,
            device_name: "Dummy GPU".to_string(),
            vendor_name: "memkit".to_string(),
        }
    }

    /// Allocates a zeroed, 8-byte-aligned host allocation to stand in for a
    /// GPU buffer. Zero-size buffers are tracked but carry a null pointer.
    ///
    /// # Errors
    /// `BufferTooLarge` when `size` exceeds the configured limit;
    /// `AllocationFailed` when the layout is invalid or the allocator fails.
    fn create_buffer(
        &self,
        size: usize,
        usage: MkBufferUsage,
        memory_type: MkMemoryType,
    ) -> Result<Self::BufferHandle, Self::Error> {
        if size > self.config.max_buffer_size {
            return Err(DummyError::BufferTooLarge {
                requested: size,
                max: self.config.max_buffer_size,
            });
        }
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let data = if size > 0 {
            let layout = std::alloc::Layout::from_size_align(size, 8)
                .map_err(|_| DummyError::AllocationFailed)?;
            // SAFETY: layout has non-zero size and a valid power-of-two align.
            let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
            if ptr.is_null() {
                return Err(DummyError::AllocationFailed);
            }
            ptr
        } else {
            std::ptr::null_mut()
        };
        let buffer = DummyBuffer {
            data,
            size,
            usage,
            memory_type,
            mapped: false,
        };
        self.buffers.write().unwrap().insert(id, buffer);
        Ok(DummyBufferHandle { id, size, memory_type })
    }

    /// Frees the buffer's allocation; unknown handles are silently ignored.
    fn destroy_buffer(&self, handle: &Self::BufferHandle) {
        if let Some(buffer) = self.buffers.write().unwrap().remove(&handle.id) {
            if !buffer.data.is_null() {
                // Must match the layout used in create_buffer (align 8).
                let layout = std::alloc::Layout::from_size_align(buffer.size, 8).unwrap();
                // SAFETY: `data` was allocated with exactly this layout, and
                // the map entry was just removed, so it is freed exactly once.
                unsafe { std::alloc::dealloc(buffer.data, layout) };
            }
        }
    }

    /// Returns a host pointer to the buffer contents, or `None` for an
    /// unknown handle or a device-local buffer while `simulate_device_local`
    /// is enabled.
    ///
    /// NOTE(review): a zero-size buffer yields `Some(null)`; callers must not
    /// dereference the pointer for empty buffers.
    fn map(&self, handle: &Self::BufferHandle) -> Option<*mut u8> {
        let mut buffers = self.buffers.write().unwrap();
        let buffer = buffers.get_mut(&handle.id)?;
        if self.config.simulate_device_local && handle.memory_type == MkMemoryType::DeviceLocal {
            return None;
        }
        buffer.mapped = true;
        Some(buffer.data)
    }

    /// Clears the mapped flag; unknown handles are silently ignored.
    fn unmap(&self, handle: &Self::BufferHandle) {
        if let Some(buffer) = self.buffers.write().unwrap().get_mut(&handle.id) {
            buffer.mapped = false;
        }
    }

    /// No-op: this backend reports coherent memory, so flushes are unneeded.
    fn flush(&self, _handle: &Self::BufferHandle, _offset: usize, _size: usize) {}

    /// No-op: this backend reports coherent memory, so invalidation is unneeded.
    fn invalidate(&self, _handle: &Self::BufferHandle, _offset: usize, _size: usize) {}

    /// Copies `size` bytes from the start of `src` to the start of `dst`.
    fn copy_buffer(
        &self,
        src: &Self::BufferHandle,
        dst: &Self::BufferHandle,
        size: usize,
    ) -> Result<(), Self::Error> {
        self.copy_buffer_regions(src, 0, dst, 0, size)
    }

    /// Copies up to `size` bytes between buffer regions, clamped so neither
    /// side is read or written out of bounds.
    ///
    /// Fix: when `src` and `dst` refer to the same buffer the two regions may
    /// overlap, so a memmove-style `ptr::copy` is used for that case. The
    /// original used `ptr::copy_nonoverlapping` unconditionally, which is
    /// undefined behavior for overlapping same-buffer copies.
    ///
    /// # Errors
    /// `BufferNotFound` when either handle is unknown.
    fn copy_buffer_regions(
        &self,
        src: &Self::BufferHandle,
        src_offset: usize,
        dst: &Self::BufferHandle,
        dst_offset: usize,
        size: usize,
    ) -> Result<(), Self::Error> {
        if self.config.transfer_delay_us > 0 {
            // Simulated transfer latency; slept before taking the lock so
            // other threads are not blocked for the duration.
            std::thread::sleep(std::time::Duration::from_micros(self.config.transfer_delay_us));
        }
        let buffers = self.buffers.read().unwrap();
        let src_buf = buffers.get(&src.id)
            .ok_or(DummyError::BufferNotFound(src.id))?;
        let dst_buf = buffers.get(&dst.id)
            .ok_or(DummyError::BufferNotFound(dst.id))?;
        // Clamp the copy to what both regions can actually hold.
        let copy_size = size
            .min(src_buf.size.saturating_sub(src_offset))
            .min(dst_buf.size.saturating_sub(dst_offset));
        if copy_size > 0 && !src_buf.data.is_null() && !dst_buf.data.is_null() {
            // SAFETY: copy_size > 0 implies both offsets are in bounds, and
            // copy_size is clamped so both ranges lie entirely within their
            // respective live allocations.
            unsafe {
                if src.id == dst.id {
                    // Same allocation: ranges may overlap — memmove semantics.
                    std::ptr::copy(
                        src_buf.data.add(src_offset),
                        dst_buf.data.add(dst_offset),
                        copy_size,
                    );
                } else {
                    std::ptr::copy_nonoverlapping(
                        src_buf.data.add(src_offset),
                        dst_buf.data.add(dst_offset),
                        copy_size,
                    );
                }
            }
        }
        Ok(())
    }

    /// No-op: all "device" work in this backend completes synchronously.
    fn wait_idle(&self) -> Result<(), Self::Error> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes the byte pattern 0..=255 starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of at least 256 bytes.
    unsafe fn write_pattern(ptr: *mut u8) {
        for offset in 0..256usize {
            ptr.add(offset).write(offset as u8);
        }
    }

    /// Asserts that the 256 bytes at `ptr` hold the pattern 0..=255.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of at least 256 bytes.
    unsafe fn check_pattern(ptr: *const u8) {
        for offset in 0..256usize {
            assert_eq!(ptr.add(offset).read(), offset as u8);
        }
    }

    #[test]
    fn test_dummy_backend_create_destroy() {
        let backend = DummyBackend::new();
        assert_eq!(backend.buffer_count(), 0);

        let handle = backend
            .create_buffer(1024, MkBufferUsage::VERTEX, MkMemoryType::HostVisible)
            .expect("host-visible allocation should succeed");
        assert_eq!(backend.buffer_count(), 1);
        assert_eq!(backend.total_allocated(), 1024);

        backend.destroy_buffer(&handle);
        assert_eq!(backend.buffer_count(), 0);
    }

    #[test]
    fn test_dummy_backend_map_write_read() {
        let backend = DummyBackend::new();
        let handle = backend
            .create_buffer(1024, MkBufferUsage::VERTEX, MkMemoryType::HostVisible)
            .expect("host-visible allocation should succeed");

        let write_ptr = backend.map(&handle).expect("host-visible buffer must map");
        unsafe { write_pattern(write_ptr) };
        backend.unmap(&handle);

        // Data must survive an unmap/remap cycle.
        let read_ptr = backend.map(&handle).expect("remap must succeed");
        unsafe { check_pattern(read_ptr) };

        backend.destroy_buffer(&handle);
    }

    #[test]
    fn test_dummy_backend_copy() {
        let backend = DummyBackend::new();
        let src = backend
            .create_buffer(256, MkBufferUsage::TRANSFER_SRC, MkMemoryType::HostVisible)
            .expect("src allocation should succeed");
        let dst = backend
            .create_buffer(256, MkBufferUsage::TRANSFER_DST, MkMemoryType::HostVisible)
            .expect("dst allocation should succeed");

        let src_ptr = backend.map(&src).expect("src must map");
        unsafe { write_pattern(src_ptr) };
        backend.unmap(&src);

        backend.copy_buffer(&src, &dst, 256).expect("copy should succeed");

        let dst_ptr = backend.map(&dst).expect("dst must map");
        unsafe { check_pattern(dst_ptr) };

        backend.destroy_buffer(&src);
        backend.destroy_buffer(&dst);
    }

    #[test]
    fn test_dummy_backend_device_local_not_mappable() {
        let backend = DummyBackend::new();
        let handle = backend
            .create_buffer(1024, MkBufferUsage::VERTEX, MkMemoryType::DeviceLocal)
            .expect("device-local allocation should succeed");
        // With simulate_device_local (default), mapping must fail.
        assert!(backend.map(&handle).is_none());
        backend.destroy_buffer(&handle);
    }

    #[test]
    fn test_dummy_backend_capabilities() {
        let caps = DummyBackend::new().capabilities();
        assert_eq!(caps.device_name, "Dummy GPU");
        assert!(caps.unified_memory);
        assert!(caps.coherent_memory);
    }

    #[test]
    fn test_dummy_backend_buffer_too_large() {
        let backend = DummyBackend::with_config(DummyBackendConfig {
            max_buffer_size: 1024,
            ..Default::default()
        });
        let result =
            backend.create_buffer(2048, MkBufferUsage::VERTEX, MkMemoryType::HostVisible);
        assert!(matches!(result, Err(DummyError::BufferTooLarge { .. })));
    }
}