use crate::error::{KernelError, Result};
#[cfg(feature = "std")]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(feature = "std")]
use std::sync::Arc;
/// Identifies which GPU backend a device is associated with.
///
/// Discriminants are pinned by `#[repr(u8)]`. [`GpuBackend::None`] is the
/// `Default` value and is the only variant reported as unavailable by
/// [`GpuBackend::is_available`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum GpuBackend {
    /// No backend.
    #[default]
    None = 0,
    /// Vulkan backend.
    Vulkan = 1,
    /// Metal backend.
    Metal = 2,
    /// Direct3D 12 backend.
    Dx12 = 3,
    /// OpenGL backend.
    OpenGL = 4,
    /// WebGPU backend.
    WebGpu = 5,
}

impl GpuBackend {
    /// Returns the lowercase, static name of this backend (e.g. `"vulkan"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match *self {
            GpuBackend::None => "none",
            GpuBackend::Vulkan => "vulkan",
            GpuBackend::Metal => "metal",
            GpuBackend::Dx12 => "dx12",
            GpuBackend::OpenGL => "opengl",
            GpuBackend::WebGpu => "webgpu",
        }
    }

    /// Returns `true` for `Vulkan`, `Metal` and `Dx12`; `false` otherwise.
    #[must_use]
    pub const fn is_native(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match *self {
            Self::Vulkan | Self::Metal | Self::Dx12 => true,
            Self::None | Self::OpenGL | Self::WebGpu => false,
        }
    }

    /// Returns `true` for every variant except [`GpuBackend::None`].
    #[must_use]
    pub const fn is_available(&self) -> bool {
        match *self {
            Self::None => false,
            Self::Vulkan | Self::Metal | Self::Dx12 | Self::OpenGL | Self::WebGpu => true,
        }
    }
}
/// Bit set describing how a buffer may be used.
///
/// Each associated constant occupies a single distinct bit; sets are combined
/// with [`BufferUsage::union`] and queried with [`BufferUsage::contains`].
/// The `Default` value is the empty set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BufferUsage(u32);

impl BufferUsage {
    /// Bit 0.
    pub const MAP_READ: Self = Self(0x001);
    /// Bit 1.
    pub const MAP_WRITE: Self = Self(0x002);
    /// Bit 2.
    pub const COPY_SRC: Self = Self(0x004);
    /// Bit 3.
    pub const COPY_DST: Self = Self(0x008);
    /// Bit 4.
    pub const INDEX: Self = Self(0x010);
    /// Bit 5.
    pub const VERTEX: Self = Self(0x020);
    /// Bit 6.
    pub const UNIFORM: Self = Self(0x040);
    /// Bit 7.
    pub const STORAGE: Self = Self(0x080);
    /// Bit 8.
    pub const INDIRECT: Self = Self(0x100);

    /// Returns the set with no bits set.
    #[must_use]
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Returns the set containing every bit present in `self` or `other`.
    #[must_use]
    pub const fn union(self, other: Self) -> Self {
        let merged = self.0 | other.0;
        Self(merged)
    }

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// The empty set is contained in every set.
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        // `other` is a subset of `self` exactly when no bit of `other`
        // falls outside `self`.
        (other.0 & !self.0) == 0
    }

    /// Returns the raw bit pattern.
    #[must_use]
    pub const fn bits(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
/// Numeric limits advertised for a GPU device.
///
/// NOTE(review): the field names match the WebGPU limit names, so the
/// semantics are presumably the same — confirm before relying on it.
#[derive(Debug, Clone, Copy)]
pub struct GpuLimits {
    /// Largest buffer, in bytes, that `GpuDevice` will agree to create.
    pub max_buffer_size: u64,
    /// Maximum storage-buffer binding size (presumably bytes).
    pub max_storage_buffer_binding_size: u32,
    /// Maximum uniform-buffer binding size (presumably bytes).
    pub max_uniform_buffer_binding_size: u32,
    /// Maximum compute workgroup size along X.
    pub max_compute_workgroup_size_x: u32,
    /// Maximum compute workgroup size along Y.
    pub max_compute_workgroup_size_y: u32,
    /// Maximum compute workgroup size along Z.
    pub max_compute_workgroup_size_z: u32,
    /// Largest workgroup count accepted per dimension by `GpuDevice::dispatch`.
    pub max_compute_workgroups_per_dimension: u32,
    /// Maximum number of bind groups.
    pub max_bind_groups: u32,
}

impl Default for GpuLimits {
    /// Returns the built-in limit set: 256 MiB buffers, 128 MiB storage
    /// bindings, 64 KiB uniform bindings, 256x256x64 workgroup sizes,
    /// 65535 workgroups per dimension and 4 bind groups.
    fn default() -> Self {
        const KIB: u32 = 1024;
        const MIB: u32 = 1024 * KIB;

        Self {
            max_buffer_size: 256 * u64::from(MIB),
            max_storage_buffer_binding_size: 128 * MIB,
            max_uniform_buffer_binding_size: 64 * KIB,
            max_compute_workgroup_size_x: 256,
            max_compute_workgroup_size_y: 256,
            max_compute_workgroup_size_z: 64,
            max_compute_workgroups_per_dimension: 65_535,
            max_bind_groups: 4,
        }
    }
}
/// Descriptive metadata for a GPU device.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Device name.
    pub name: String,
    /// Vendor name.
    pub vendor: String,
    /// Backend the device is associated with.
    pub backend: GpuBackend,
    /// Device category.
    pub device_type: GpuDeviceType,
    /// Limits advertised for the device.
    pub limits: GpuLimits,
}

impl Default for GpuDeviceInfo {
    /// Returns the description of the mock device: name `"Mock GPU"`, vendor
    /// `"Pepita"`, no backend, device type `Other` and the default limits.
    fn default() -> Self {
        let limits = GpuLimits::default();
        Self {
            name: String::from("Mock GPU"),
            vendor: String::from("Pepita"),
            backend: GpuBackend::None,
            device_type: GpuDeviceType::Other,
            limits,
        }
    }
}
/// Category of a GPU device, stored in `GpuDeviceInfo::device_type`.
///
/// Discriminants are pinned by `#[repr(u8)]`; `Other` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum GpuDeviceType {
/// Discrete GPU.
DiscreteGpu = 0,
/// Integrated GPU.
IntegratedGpu = 1,
/// Virtual GPU.
VirtualGpu = 2,
/// CPU device.
Cpu = 3,
/// Any other kind of device; the default.
#[default]
Other = 4,
}
/// In-memory buffer: an identifier, a declared size, usage flags and the
/// backing bytes held in a `Vec<u8>`.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct Buffer {
    /// Identifier supplied by the creator.
    id: u64,
    /// Declared size in bytes.
    size: u64,
    /// Usage flags checked by the mapping methods.
    usage: BufferUsage,
    /// Backing storage.
    data: Vec<u8>,
}

#[cfg(feature = "std")]
impl Buffer {
    /// Creates a zero-filled buffer with a declared size of `size` bytes.
    ///
    /// The backing allocation length is `size as usize`; on targets where
    /// `usize` is narrower than `u64` that cast truncates.
    pub fn new(id: u64, size: u64, usage: BufferUsage) -> Self {
        let byte_len = size as usize;
        let data = vec![0u8; byte_len];
        Self {
            id,
            size,
            usage,
            data,
        }
    }

    /// Creates a buffer holding a copy of `data`; its size is `data.len()`.
    pub fn with_data(id: u64, data: &[u8], usage: BufferUsage) -> Self {
        let bytes = Vec::from(data);
        let size = bytes.len() as u64;
        Self {
            id,
            size,
            usage,
            data: bytes,
        }
    }

    /// Returns the buffer identifier.
    #[must_use]
    pub const fn id(&self) -> u64 {
        self.id
    }

    /// Returns the declared size in bytes.
    #[must_use]
    pub const fn size(&self) -> u64 {
        self.size
    }

    /// Returns the usage flags the buffer was created with.
    #[must_use]
    pub const fn usage(&self) -> BufferUsage {
        self.usage
    }

    /// Returns the backing bytes without any usage check.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Returns the backing bytes mutably without any usage check.
    pub fn data_mut(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }

    /// Returns the backing bytes for reading.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidRequest` when the buffer's usage does not
    /// contain `BufferUsage::MAP_READ`.
    pub fn map_read(&self) -> Result<&[u8]> {
        if self.usage.contains(BufferUsage::MAP_READ) {
            Ok(self.data.as_slice())
        } else {
            Err(KernelError::InvalidRequest)
        }
    }

    /// Returns the backing bytes for writing.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidRequest` when the buffer's usage does not
    /// contain `BufferUsage::MAP_WRITE`.
    pub fn map_write(&mut self) -> Result<&mut [u8]> {
        if self.usage.contains(BufferUsage::MAP_WRITE) {
            Ok(self.data.as_mut_slice())
        } else {
            Err(KernelError::InvalidRequest)
        }
    }
}
/// Shader stage identifier with discriminants pinned by `#[repr(u8)]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ShaderStage {
/// Vertex stage.
Vertex = 0,
/// Fragment stage.
Fragment = 1,
/// Compute stage.
Compute = 2,
}
/// A compute shader: its source text, entry-point name and workgroup size,
/// tagged with a process-wide unique id.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct ComputeShader {
    /// Identifier assigned at construction.
    id: u64,
    /// Shader source text, stored verbatim.
    source: String,
    /// Entry-point name; `"main"` unless overridden.
    entry_point: String,
    /// Workgroup size as `(x, y, z)`; `(64, 1, 1)` unless overridden.
    workgroup_size: (u32, u32, u32),
}

#[cfg(feature = "std")]
impl ComputeShader {
    /// Wraps `source` in a shader with entry point `"main"` and workgroup
    /// size `(64, 1, 1)`.
    ///
    /// The source is stored as given; no parsing or validation happens here,
    /// so this currently always returns `Ok`. Ids come from a shared atomic
    /// counter that starts at 1.
    pub fn from_wgsl(source: impl Into<String>) -> Result<Self> {
        static NEXT_SHADER_ID: AtomicU64 = AtomicU64::new(1);

        let id = NEXT_SHADER_ID.fetch_add(1, Ordering::Relaxed);
        let shader = Self {
            id,
            source: source.into(),
            entry_point: String::from("main"),
            workgroup_size: (64, 1, 1),
        };
        Ok(shader)
    }

    /// Returns the shader with its entry-point name replaced by `name`.
    #[must_use]
    pub fn with_entry_point(self, name: impl Into<String>) -> Self {
        Self {
            entry_point: name.into(),
            ..self
        }
    }

    /// Returns the shader with its workgroup size replaced by `(x, y, z)`.
    #[must_use]
    pub const fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        let dims = (x, y, z);
        self.workgroup_size = dims;
        self
    }

    /// Returns the shader identifier.
    #[must_use]
    pub const fn id(&self) -> u64 {
        self.id
    }

    /// Returns the shader source text.
    #[must_use]
    pub fn source(&self) -> &str {
        self.source.as_str()
    }

    /// Returns the entry-point name.
    #[must_use]
    pub fn entry_point(&self) -> &str {
        self.entry_point.as_str()
    }

    /// Returns the workgroup size as `(x, y, z)`.
    #[must_use]
    pub const fn workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }
}
/// A GPU device handle.
///
/// The only constructor in this type builds a mock: buffers live in host
/// memory and `dispatch` merely validates its arguments and counts calls.
#[cfg(feature = "std")]
pub struct GpuDevice {
    /// Static description of the device.
    info: GpuDeviceInfo,
    /// Availability flag consulted by every fallible operation.
    available: AtomicBool,
    /// Next buffer id to hand out; starts at 1.
    buffer_counter: AtomicU64,
    /// Number of successful `dispatch` calls.
    dispatch_count: AtomicU64,
}

#[cfg(feature = "std")]
impl GpuDevice {
    /// Creates an available mock device described by `GpuDeviceInfo::default()`.
    pub fn mock() -> Self {
        Self {
            info: GpuDeviceInfo::default(),
            available: AtomicBool::new(true),
            buffer_counter: AtomicU64::new(1),
            dispatch_count: AtomicU64::new(0),
        }
    }

    /// Returns the default device, which is currently always the mock device.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the signature.
    pub fn default_device() -> Result<Self> {
        Ok(Self::mock())
    }

    /// Returns the device description.
    #[must_use]
    pub const fn info(&self) -> &GpuDeviceInfo {
        &self.info
    }

    /// Returns whether the device is currently flagged as available.
    #[must_use]
    pub fn is_available(&self) -> bool {
        self.available.load(Ordering::Acquire)
    }

    /// Shared guard: fails with `KernelError::DeviceNotReady` when the device
    /// is not available.
    fn ensure_available(&self) -> Result<()> {
        if self.is_available() {
            Ok(())
        } else {
            Err(KernelError::DeviceNotReady)
        }
    }

    /// Creates a buffer initialised with a copy of `data`.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeviceNotReady` if the device is unavailable.
    /// * `KernelError::OutOfMemory` if `data.len()` exceeds
    ///   `limits.max_buffer_size`.
    pub fn create_buffer(&self, data: &[u8], usage: BufferUsage) -> Result<Buffer> {
        self.ensure_available()?;
        if data.len() as u64 > self.info.limits.max_buffer_size {
            return Err(KernelError::OutOfMemory);
        }
        let id = self.buffer_counter.fetch_add(1, Ordering::Relaxed);
        Ok(Buffer::with_data(id, data, usage))
    }

    /// Creates a buffer of `size` bytes without caller-provided contents.
    ///
    /// Despite the name, the buffer comes from `Buffer::new`, which
    /// zero-fills it.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeviceNotReady` if the device is unavailable.
    /// * `KernelError::OutOfMemory` if `size` exceeds `limits.max_buffer_size`.
    pub fn create_buffer_uninit(&self, size: u64, usage: BufferUsage) -> Result<Buffer> {
        self.ensure_available()?;
        if size > self.info.limits.max_buffer_size {
            return Err(KernelError::OutOfMemory);
        }
        let id = self.buffer_counter.fetch_add(1, Ordering::Relaxed);
        Ok(Buffer::new(id, size, usage))
    }

    /// Records a compute dispatch of `workgroups` = `(x, y, z)` workgroups.
    ///
    /// The shader and buffers are accepted but not used: nothing is executed,
    /// only the dispatch counter is incremented.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeviceNotReady` if the device is unavailable.
    /// * `KernelError::InvalidArgument` if any component of `workgroups`
    ///   exceeds `limits.max_compute_workgroups_per_dimension`.
    pub fn dispatch(
        &self,
        _shader: &ComputeShader,
        _buffers: &[&Buffer],
        workgroups: (u32, u32, u32),
    ) -> Result<()> {
        self.ensure_available()?;
        let max = self.info.limits.max_compute_workgroups_per_dimension;
        let (x, y, z) = workgroups;
        if x > max || y > max || z > max {
            return Err(KernelError::InvalidArgument);
        }
        self.dispatch_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Returns how many dispatches have been accepted so far.
    #[must_use]
    pub fn dispatch_count(&self) -> u64 {
        self.dispatch_count.load(Ordering::Relaxed)
    }

    /// Copies bytes from `src` into the start of `dst`.
    ///
    /// The number of bytes copied is the length of the shorter buffer; any
    /// remaining bytes of `dst` are left untouched.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeviceNotReady` if the device is unavailable.
    /// * `KernelError::InvalidRequest` if `src` lacks `BufferUsage::COPY_SRC`
    ///   or `dst` lacks `BufferUsage::COPY_DST`.
    pub fn copy_buffer(&self, src: &Buffer, dst: &mut Buffer) -> Result<()> {
        self.ensure_available()?;
        if !src.usage.contains(BufferUsage::COPY_SRC) {
            return Err(KernelError::InvalidRequest);
        }
        if !dst.usage.contains(BufferUsage::COPY_DST) {
            return Err(KernelError::InvalidRequest);
        }
        // Take the length from the real backing storage, not from the declared
        // `u64` size: casting that with `as usize` is lossy where `usize` is
        // narrower than `u64`, and a length longer than the storage would
        // panic in the slice operations below.
        let len = src.data.len().min(dst.data.len());
        dst.data[..len].copy_from_slice(&src.data[..len]);
        Ok(())
    }

    /// Submits pending work and waits for it; a no-op on the mock device.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::DeviceNotReady` if the device is unavailable.
    pub fn submit_and_wait(&self) -> Result<()> {
        self.ensure_available()
    }
}
#[cfg(feature = "std")]
impl Default for GpuDevice {
    /// Equivalent to [`GpuDevice::mock`].
    fn default() -> Self {
        GpuDevice::mock()
    }
}
#[cfg(feature = "std")]
impl std::fmt::Debug for GpuDevice {
    /// Formats the device as `GpuDevice { name, backend, available }`.
    ///
    /// The counters and the rest of the device info are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.info;
        let mut out = f.debug_struct("GpuDevice");
        out.field("name", &info.name);
        out.field("backend", &info.backend);
        out.field("available", &self.is_available());
        out.finish()
    }
}
/// A compute pipeline: a shared compute shader plus a bind-group count.
#[cfg(feature = "std")]
pub struct ComputePipeline {
    /// Shader owned by the pipeline, held behind an `Arc`.
    shader: Arc<ComputeShader>,
    /// Number of bind groups; 1 unless overridden.
    bind_groups: u32,
}

#[cfg(feature = "std")]
impl ComputePipeline {
    /// Builds a pipeline around `shader` with a single bind group.
    pub fn new(shader: ComputeShader) -> Self {
        let shader = Arc::new(shader);
        Self {
            shader,
            bind_groups: 1,
        }
    }

    /// Returns the pipeline with its bind-group count set to `count`.
    ///
    /// The value is stored as given; no limit is checked here.
    #[must_use]
    pub const fn with_bind_groups(mut self, count: u32) -> Self {
        self.bind_groups = count;
        self
    }

    /// Returns the shader this pipeline was built from.
    #[must_use]
    pub fn shader(&self) -> &ComputeShader {
        &*self.shader
    }
}
#[cfg(feature = "std")]
impl std::fmt::Debug for ComputePipeline {
    /// Formats the pipeline as `ComputePipeline { shader_id, bind_groups }`;
    /// the shader is represented by its id only.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shader_id = self.shader.id();
        let mut out = f.debug_struct("ComputePipeline");
        out.field("shader_id", &shader_id);
        out.field("bind_groups", &self.bind_groups);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // `Buffer`, `ComputeShader`, `GpuDevice` and `ComputePipeline` only exist
    // when the `std` feature is enabled, so every test that touches them
    // carries the same gate. Without it the test build would not compile
    // when the feature is off.

    #[test]
    fn test_gpu_backend_default() {
        assert_eq!(GpuBackend::default(), GpuBackend::None);
    }

    #[test]
    fn test_gpu_backend_name() {
        assert_eq!(GpuBackend::Vulkan.name(), "vulkan");
        assert_eq!(GpuBackend::Metal.name(), "metal");
        assert_eq!(GpuBackend::None.name(), "none");
    }

    #[test]
    fn test_gpu_backend_is_native() {
        assert!(GpuBackend::Vulkan.is_native());
        assert!(GpuBackend::Metal.is_native());
        assert!(GpuBackend::Dx12.is_native());
        assert!(!GpuBackend::OpenGL.is_native());
        assert!(!GpuBackend::WebGpu.is_native());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        assert!(GpuBackend::Vulkan.is_available());
        assert!(!GpuBackend::None.is_available());
    }

    #[test]
    fn test_buffer_usage_empty() {
        let usage = BufferUsage::empty();
        assert_eq!(usage.bits(), 0);
    }

    #[test]
    fn test_buffer_usage_union() {
        let usage = BufferUsage::STORAGE.union(BufferUsage::COPY_SRC);
        assert!(usage.contains(BufferUsage::STORAGE));
        assert!(usage.contains(BufferUsage::COPY_SRC));
        assert!(!usage.contains(BufferUsage::UNIFORM));
    }

    #[test]
    fn test_buffer_usage_contains() {
        let usage = BufferUsage::STORAGE;
        assert!(usage.contains(BufferUsage::STORAGE));
        assert!(!usage.contains(BufferUsage::UNIFORM));
    }

    #[test]
    fn test_gpu_limits_default() {
        let limits = GpuLimits::default();
        assert!(limits.max_buffer_size > 0);
        assert!(limits.max_compute_workgroup_size_x > 0);
    }

    #[test]
    fn test_gpu_device_info_default() {
        let info = GpuDeviceInfo::default();
        assert!(!info.name.is_empty());
        assert_eq!(info.backend, GpuBackend::None);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_new() {
        let buffer = Buffer::new(1, 1024, BufferUsage::STORAGE);
        assert_eq!(buffer.id(), 1);
        assert_eq!(buffer.size(), 1024);
        assert_eq!(buffer.data().len(), 1024);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_with_data() {
        let data = vec![1u8, 2, 3, 4];
        let buffer = Buffer::with_data(1, &data, BufferUsage::STORAGE);
        assert_eq!(buffer.size(), 4);
        assert_eq!(buffer.data(), &data);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_map_read() {
        let buffer = Buffer::new(1, 64, BufferUsage::MAP_READ);
        let data = buffer.map_read().unwrap();
        assert_eq!(data.len(), 64);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_map_read_invalid() {
        let buffer = Buffer::new(1, 64, BufferUsage::STORAGE);
        assert!(buffer.map_read().is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_map_write() {
        let mut buffer = Buffer::new(1, 64, BufferUsage::MAP_WRITE);
        let data = buffer.map_write().unwrap();
        data[0] = 42;
        assert_eq!(buffer.data()[0], 42);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_map_write_invalid() {
        let mut buffer = Buffer::new(1, 64, BufferUsage::STORAGE);
        assert!(buffer.map_write().is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_compute_shader_from_wgsl() {
        let shader = ComputeShader::from_wgsl("@compute fn main() {}").unwrap();
        assert!(!shader.source().is_empty());
        assert_eq!(shader.entry_point(), "main");
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_compute_shader_with_entry_point() {
        let shader = ComputeShader::from_wgsl("")
            .unwrap()
            .with_entry_point("compute_main");
        assert_eq!(shader.entry_point(), "compute_main");
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_compute_shader_with_workgroup_size() {
        let shader = ComputeShader::from_wgsl("")
            .unwrap()
            .with_workgroup_size(128, 1, 1);
        assert_eq!(shader.workgroup_size(), (128, 1, 1));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_mock() {
        let device = GpuDevice::mock();
        assert!(device.is_available());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_default() {
        let device = GpuDevice::default_device().unwrap();
        assert!(device.is_available());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_create_buffer() {
        let device = GpuDevice::mock();
        let data = vec![1u8, 2, 3, 4];
        let buffer = device.create_buffer(&data, BufferUsage::STORAGE).unwrap();
        assert_eq!(buffer.size(), 4);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_create_buffer_uninit() {
        let device = GpuDevice::mock();
        let buffer = device
            .create_buffer_uninit(1024, BufferUsage::STORAGE)
            .unwrap();
        assert_eq!(buffer.size(), 1024);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_create_buffer_uninit_too_large() {
        let device = GpuDevice::mock();
        // One byte over the limit is rejected before anything is allocated.
        let too_big = device.info().limits.max_buffer_size + 1;
        assert!(device
            .create_buffer_uninit(too_big, BufferUsage::STORAGE)
            .is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_dispatch() {
        let device = GpuDevice::mock();
        let shader = ComputeShader::from_wgsl("").unwrap();
        device.dispatch(&shader, &[], (64, 1, 1)).unwrap();
        assert_eq!(device.dispatch_count(), 1);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_dispatch_invalid_workgroups() {
        let device = GpuDevice::mock();
        let shader = ComputeShader::from_wgsl("").unwrap();
        let result = device.dispatch(&shader, &[], (100000, 1, 1));
        assert!(result.is_err());
        // A rejected dispatch must not be counted.
        assert_eq!(device.dispatch_count(), 0);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_copy_buffer() {
        let device = GpuDevice::mock();
        let src_data = vec![1u8, 2, 3, 4];
        let src = device
            .create_buffer(&src_data, BufferUsage::COPY_SRC)
            .unwrap();
        let mut dst = device
            .create_buffer_uninit(4, BufferUsage::COPY_DST)
            .unwrap();
        device.copy_buffer(&src, &mut dst).unwrap();
        assert_eq!(dst.data(), &src_data);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_copy_buffer_size_mismatch() {
        let device = GpuDevice::mock();
        let src = device
            .create_buffer(&[1u8, 2, 3, 4], BufferUsage::COPY_SRC)
            .unwrap();
        let mut dst = device
            .create_buffer_uninit(2, BufferUsage::COPY_DST)
            .unwrap();
        // Only as many bytes as fit in the shorter buffer are copied.
        device.copy_buffer(&src, &mut dst).unwrap();
        assert_eq!(dst.data(), &[1u8, 2][..]);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_copy_buffer_missing_usage() {
        let device = GpuDevice::mock();
        let src = device
            .create_buffer(&[1u8, 2], BufferUsage::STORAGE)
            .unwrap();
        let mut dst = device
            .create_buffer_uninit(2, BufferUsage::COPY_DST)
            .unwrap();
        assert!(device.copy_buffer(&src, &mut dst).is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_gpu_device_debug() {
        let device = GpuDevice::mock();
        let debug = format!("{:?}", device);
        assert!(debug.contains("GpuDevice"));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_compute_pipeline_new() {
        let shader = ComputeShader::from_wgsl("").unwrap();
        let pipeline = ComputePipeline::new(shader);
        assert!(pipeline.shader().id() > 0);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_compute_pipeline_with_bind_groups() {
        let shader = ComputeShader::from_wgsl("").unwrap();
        let pipeline = ComputePipeline::new(shader).with_bind_groups(2);
        let debug = format!("{:?}", pipeline);
        assert!(debug.contains("bind_groups: 2"));
    }
}