use crate::error::{AccelError, AccelResult};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Opaque identifier for a buffer owned by a [`ComputeBackend`].
///
/// The wrapped id is assigned by the backend that allocated the buffer and
/// is only meaningful to that backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferHandle(pub(crate) u64);

impl BufferHandle {
    /// Wraps a raw backend-assigned id in a handle.
    #[must_use]
    pub fn new(id: u64) -> Self {
        BufferHandle(id)
    }

    /// Returns the raw id carried by this handle.
    #[must_use]
    pub fn id(&self) -> u64 {
        let BufferHandle(raw) = self;
        *raw
    }
}
/// Workgroup counts for a compute dispatch.
#[derive(Debug, Clone, Copy)]
pub struct DispatchParams {
    /// Number of workgroups along X.
    pub groups_x: u32,
    /// Number of workgroups along Y.
    pub groups_y: u32,
    /// Number of workgroups along Z.
    pub groups_z: u32,
}

impl DispatchParams {
    /// One-dimensional dispatch; `groups_y` and `groups_z` are 1.
    #[must_use]
    pub const fn new_1d(groups_x: u32) -> Self {
        Self {
            groups_x,
            groups_y: 1,
            groups_z: 1,
        }
    }

    /// Two-dimensional dispatch; `groups_z` is 1.
    #[must_use]
    pub const fn new_2d(groups_x: u32, groups_y: u32) -> Self {
        Self {
            groups_x,
            groups_y,
            groups_z: 1,
        }
    }

    /// Fully explicit three-dimensional dispatch.
    #[must_use]
    pub const fn new_3d(groups_x: u32, groups_y: u32, groups_z: u32) -> Self {
        Self {
            groups_x,
            groups_y,
            groups_z,
        }
    }

    /// Computes the workgroup grid covering a `width` x `height` image with
    /// the given local workgroup size, rounding up so partial tiles at the
    /// right/bottom edges are still covered.
    ///
    /// Now `const` for consistency with the other constructors
    /// (`u32::div_ceil` is const-stable).
    ///
    /// # Panics
    ///
    /// Panics if `local_x` or `local_y` is zero (division by zero).
    #[must_use]
    pub const fn for_image(width: u32, height: u32, local_x: u32, local_y: u32) -> Self {
        Self {
            groups_x: width.div_ceil(local_x),
            groups_y: height.div_ceil(local_y),
            groups_z: 1,
        }
    }
}
/// Snapshot of a backend's buffer-memory accounting.
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuMemoryStats {
    pub allocated_bytes: u64,
    pub peak_bytes: u64,
    pub allocation_count: u64,
}

impl GpuMemoryStats {
    /// Bytes per mebibyte, for the MiB conversions below.
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;

    /// Currently allocated bytes, expressed in MiB.
    #[must_use]
    pub fn allocated_mib(&self) -> f64 {
        self.allocated_bytes as f64 / Self::BYTES_PER_MIB
    }

    /// Peak allocated bytes, expressed in MiB.
    #[must_use]
    pub fn peak_mib(&self) -> f64 {
        self.peak_bytes as f64 / Self::BYTES_PER_MIB
    }
}
/// Abstraction over a compute device that can allocate buffers and run
/// named kernels. Implementations must be shareable across threads
/// (`Send + Sync`).
pub trait ComputeBackend: Send + Sync {
    /// Allocates a buffer of `size` bytes and returns a handle to it.
    /// (Both implementations in this module zero-initialize the storage.)
    fn allocate_buffer(&self, size: u64) -> AccelResult<BufferHandle>;
    /// Copies `data` into the buffer identified by `handle`; the lengths
    /// must match exactly.
    fn upload_buffer(&self, handle: &BufferHandle, data: &[u8]) -> AccelResult<()>;
    /// Reads back the full contents of the buffer identified by `handle`.
    fn download_buffer(&self, handle: &BufferHandle) -> AccelResult<Vec<u8>>;
    /// Releases the buffer. Takes the handle by value so it cannot be
    /// reused after freeing.
    fn free_buffer(&self, handle: BufferHandle) -> AccelResult<()>;
    /// Launches the kernel registered under `kernel_name` over `buffers`
    /// with the given workgroup counts.
    fn dispatch_kernel(
        &self,
        kernel_name: &str,
        buffers: &[&BufferHandle],
        dispatch: DispatchParams,
    ) -> AccelResult<()>;
    /// Blocks until all previously submitted work has completed.
    fn synchronize(&self) -> AccelResult<()>;
    /// Human-readable name of the backend/device.
    fn backend_name(&self) -> &str;
    /// True when work actually runs on a GPU (false for CPU fallbacks).
    fn is_gpu(&self) -> bool;
    /// Current buffer-memory accounting for this backend.
    fn memory_stats(&self) -> GpuMemoryStats;
}
/// Internal record for a registered kernel: its SPIR-V blob plus a
/// human-readable label.
#[derive(Clone)]
struct KernelEntry {
    #[allow(dead_code)]
    spirv: Vec<u8>,
    label: String,
}

/// Thread-safe name -> kernel lookup table.
///
/// Lock poisoning is deliberately ignored (`into_inner`): the map holds
/// plain data, so a panic in another thread cannot leave it torn.
pub struct KernelRegistry {
    kernels: RwLock<HashMap<String, KernelEntry>>,
}

impl KernelRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            kernels: RwLock::new(HashMap::new()),
        }
    }

    /// Registers (or replaces) a kernel under `name`.
    pub fn register(&self, name: &str, spirv: &[u8], label: &str) {
        let entry = KernelEntry {
            spirv: spirv.to_vec(),
            label: label.to_string(),
        };
        self.kernels
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(name.to_string(), entry);
    }

    /// Returns the label of the kernel registered under `name`, if any.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<String> {
        self.kernels
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(name)
            .map(|entry| entry.label.clone())
    }

    /// Number of registered kernels.
    #[must_use]
    pub fn len(&self) -> usize {
        self.kernels
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len()
    }

    /// True when no kernels have been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.kernels
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .is_empty()
    }

    /// Names of all registered kernels, in arbitrary (hash-map) order.
    #[must_use]
    pub fn kernel_names(&self) -> Vec<String> {
        self.kernels
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .keys()
            .cloned()
            .collect()
    }
}

impl Default for KernelRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Geometry of a planar YUV frame: luma resolution plus chroma subsampling
/// factors (2 = half resolution on that axis, 1 = full resolution).
#[derive(Debug, Clone, Copy)]
pub struct YuvFrameInfo {
    pub width: u32,
    pub height: u32,
    pub chroma_subsample_x: u32,
    pub chroma_subsample_y: u32,
}

impl YuvFrameInfo {
    /// 4:2:0 layout: chroma halved in both dimensions.
    #[must_use]
    pub const fn yuv420(width: u32, height: u32) -> Self {
        Self {
            width,
            height,
            chroma_subsample_x: 2,
            chroma_subsample_y: 2,
        }
    }

    /// 4:4:4 layout: chroma at full resolution.
    #[must_use]
    pub const fn yuv444(width: u32, height: u32) -> Self {
        Self {
            width,
            height,
            chroma_subsample_x: 1,
            chroma_subsample_y: 1,
        }
    }

    /// Size in bytes of the luma (Y) plane, one byte per pixel.
    #[must_use]
    pub const fn luma_size(&self) -> u64 {
        (self.width as u64) * (self.height as u64)
    }

    /// Size in bytes of ONE chroma plane, rounding odd dimensions up.
    ///
    /// Made `const` for consistency with [`Self::luma_size`]
    /// (`u32::div_ceil` is const-stable).
    ///
    /// # Panics
    ///
    /// Panics if either subsampling factor is zero (division by zero).
    #[must_use]
    pub const fn chroma_size(&self) -> u64 {
        let cw = self.width.div_ceil(self.chroma_subsample_x) as u64;
        let ch = self.height.div_ceil(self.chroma_subsample_y) as u64;
        cw * ch
    }

    /// Total frame size in bytes: one luma plane plus two chroma planes
    /// (U and V).
    ///
    /// # Panics
    ///
    /// Panics if either subsampling factor is zero (see
    /// [`Self::chroma_size`]).
    #[must_use]
    pub const fn total_size(&self) -> u64 {
        self.luma_size() + 2 * self.chroma_size()
    }
}
/// Allocates a backend buffer sized for `info` and uploads `data` into it.
///
/// # Errors
///
/// Returns [`AccelError::BufferSizeMismatch`] when `data.len()` differs
/// from `info.total_size()`, and propagates allocation/upload failures
/// from the backend.
pub fn upload_yuv_frame(
    backend: &dyn ComputeBackend,
    data: &[u8],
    info: &YuvFrameInfo,
) -> AccelResult<BufferHandle> {
    let expected = info.total_size();
    // Compare in u64 so a frame larger than usize::MAX on 32-bit targets
    // cannot be truncated into a false match (the old `as usize` cast
    // could). `len() as u64` is lossless on all supported targets.
    if data.len() as u64 != expected {
        return Err(AccelError::BufferSizeMismatch {
            // The error payload is usize by definition; the truncation here
            // only affects the report, not the check above.
            expected: expected as usize,
            actual: data.len(),
        });
    }
    let handle = backend.allocate_buffer(expected)?;
    backend.upload_buffer(&handle, data)?;
    Ok(handle)
}
/// Downloads the full contents of a previously uploaded YUV frame buffer.
///
/// Thin convenience wrapper over [`ComputeBackend::download_buffer`]; no
/// size or layout validation is performed here.
///
/// # Errors
///
/// Propagates any download error from the backend.
pub fn download_yuv_frame(
    backend: &dyn ComputeBackend,
    handle: &BufferHandle,
) -> AccelResult<Vec<u8>> {
    backend.download_buffer(handle)
}
/// Vulkan-flavored backend.
///
/// Device detection is feature-gated (`vulkan-detect`). Buffer storage here
/// is host-side `Vec<u8>` staging and `dispatch_kernel` is a no-op at this
/// layer, so this type currently models allocation/transfer accounting
/// rather than real GPU execution.
pub struct VulkanComputeBackend {
    /// Reported device/backend name.
    name: String,
    /// Whether a Vulkan device was detected at construction time.
    available: bool,
    /// Monotonic source of buffer ids (starts at 1).
    next_id: std::sync::atomic::AtomicU64,
    /// Live allocations: buffer id -> size in bytes.
    allocations: RwLock<HashMap<u64, u64>>,
    /// Sum of live allocation sizes, in bytes.
    total_allocated: std::sync::atomic::AtomicU64,
    /// High-water mark of `total_allocated`.
    peak_allocated: std::sync::atomic::AtomicU64,
    /// Host-side staging storage for buffer contents, keyed by id.
    buffers: RwLock<HashMap<u64, Vec<u8>>>,
}
impl VulkanComputeBackend {
    /// Builds the backend, probing for a Vulkan device when the
    /// `vulkan-detect` feature is enabled; without the feature it is always
    /// constructed as unavailable.
    #[must_use]
    pub fn new() -> Self {
        #[cfg(feature = "vulkan-detect")]
        let (name, available) = {
            use crate::device::DeviceSelector;
            match DeviceSelector::default().select() {
                Ok(dev) => (dev.name().to_string(), true),
                Err(_) => ("Vulkan (unavailable)".to_string(), false),
            }
        };
        #[cfg(not(feature = "vulkan-detect"))]
        let (name, available) = ("Vulkan compute backend".to_string(), false);
        Self {
            name,
            available,
            next_id: std::sync::atomic::AtomicU64::new(1),
            allocations: RwLock::new(HashMap::new()),
            total_allocated: std::sync::atomic::AtomicU64::new(0),
            peak_allocated: std::sync::atomic::AtomicU64::new(0),
            buffers: RwLock::new(HashMap::new()),
        }
    }

    /// True when a Vulkan device was detected at construction.
    #[must_use]
    pub fn is_available(&self) -> bool {
        self.available
    }

    /// Records a new allocation of `size` bytes under `id`, updating the
    /// running total and the peak high-water mark.
    fn track_alloc(&self, id: u64, size: u64) {
        let mut allocs = self.allocations.write().unwrap_or_else(|e| e.into_inner());
        allocs.insert(id, size);
        // fetch_add returns the previous value, so add `size` back to get
        // the total after this allocation.
        let current =
            self.total_allocated.fetch_add(size, std::sync::atomic::Ordering::Relaxed) + size;
        // `fetch_max` atomically raises the peak when `current` is larger;
        // it replaces the previous hand-rolled compare-exchange loop with
        // the equivalent std operation.
        self.peak_allocated.fetch_max(current, std::sync::atomic::Ordering::Relaxed);
    }

    /// Forgets the allocation for `id` (if tracked) and decrements the
    /// running total. Unknown ids are ignored, so this is safe to call
    /// with a stale handle.
    fn track_free(&self, id: u64) {
        let mut allocs = self.allocations.write().unwrap_or_else(|e| e.into_inner());
        if let Some(size) = allocs.remove(&id) {
            self.total_allocated.fetch_sub(size, std::sync::atomic::Ordering::Relaxed);
        }
    }
}
impl Default for VulkanComputeBackend {
    /// Equivalent to [`VulkanComputeBackend::new`] (may probe for a device
    /// when the `vulkan-detect` feature is enabled).
    fn default() -> Self {
        Self::new()
    }
}
impl ComputeBackend for VulkanComputeBackend {
    /// Allocates a zero-initialized host-staging buffer of `size` bytes.
    fn allocate_buffer(&self, size: u64) -> AccelResult<BufferHandle> {
        // Reject sizes that cannot be represented as usize (possible on
        // 32-bit targets) instead of silently truncating with `as`. Checked
        // before any id or accounting mutation so failures have no side
        // effects.
        let len = usize::try_from(size).map_err(|_| {
            AccelError::BufferAllocation(format!("Buffer size too large for host: {}", size))
        })?;
        let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.track_alloc(id, size);
        let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
        bufs.insert(id, vec![0u8; len]);
        Ok(BufferHandle::new(id))
    }

    /// Copies `data` into the buffer; whole-buffer writes only.
    fn upload_buffer(&self, handle: &BufferHandle, data: &[u8]) -> AccelResult<()> {
        let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
        match bufs.get_mut(&handle.0) {
            Some(buf) if buf.len() == data.len() => {
                buf.copy_from_slice(data);
                Ok(())
            }
            Some(buf) => Err(AccelError::BufferSizeMismatch {
                expected: buf.len(),
                actual: data.len(),
            }),
            None => Err(AccelError::BufferAllocation(
                format!("Invalid buffer handle: {}", handle.0),
            )),
        }
    }

    /// Returns a copy of the buffer's current contents.
    fn download_buffer(&self, handle: &BufferHandle) -> AccelResult<Vec<u8>> {
        let bufs = self.buffers.read().unwrap_or_else(|e| e.into_inner());
        bufs.get(&handle.0)
            .cloned()
            .ok_or_else(|| AccelError::BufferAllocation(format!("Invalid buffer handle: {}", handle.0)))
    }

    /// Releases the buffer's storage and accounting entry.
    fn free_buffer(&self, handle: BufferHandle) -> AccelResult<()> {
        // Remove the storage first so a stale handle errors out before any
        // accounting is touched (previously the stats were adjusted before
        // the existence check).
        let removed = {
            let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
            bufs.remove(&handle.0)
        };
        if removed.is_none() {
            return Err(AccelError::BufferAllocation(
                format!("Double-free of buffer handle: {}", handle.0),
            ));
        }
        self.track_free(handle.0);
        Ok(())
    }

    /// Kernel execution is not implemented at this abstraction layer; the
    /// call is logged and succeeds.
    fn dispatch_kernel(
        &self,
        kernel_name: &str,
        _buffers: &[&BufferHandle],
        _dispatch: DispatchParams,
    ) -> AccelResult<()> {
        tracing::debug!(
            "VulkanComputeBackend::dispatch_kernel: '{}' (no-op in abstraction layer)",
            kernel_name
        );
        Ok(())
    }

    /// All operations here complete synchronously; nothing to wait for.
    fn synchronize(&self) -> AccelResult<()> {
        Ok(())
    }

    fn backend_name(&self) -> &str {
        &self.name
    }

    /// Reports GPU only when a device was actually detected.
    fn is_gpu(&self) -> bool {
        self.available
    }

    /// Snapshot of the allocation counters. The count and byte totals are
    /// read separately, so they may be momentarily inconsistent under
    /// concurrent allocation.
    fn memory_stats(&self) -> GpuMemoryStats {
        let allocs = self.allocations.read().unwrap_or_else(|e| e.into_inner());
        GpuMemoryStats {
            allocated_bytes: self.total_allocated.load(std::sync::atomic::Ordering::Relaxed),
            peak_bytes: self.peak_allocated.load(std::sync::atomic::Ordering::Relaxed),
            allocation_count: allocs.len() as u64,
        }
    }
}
/// Pure-CPU backend used when no GPU is available: buffers live in host
/// `Vec<u8>` storage and kernel dispatch is a logged no-op.
pub struct CpuFallbackBackend {
    /// Monotonic source of buffer ids (starts at 1).
    next_id: std::sync::atomic::AtomicU64,
    /// Live allocations: buffer id -> size in bytes.
    allocations: RwLock<HashMap<u64, u64>>,
    /// Storage for buffer contents, keyed by id.
    buffers: RwLock<HashMap<u64, Vec<u8>>>,
    /// Sum of live allocation sizes, in bytes.
    current_bytes: std::sync::atomic::AtomicU64,
    /// High-water mark of `current_bytes`.
    peak_bytes: std::sync::atomic::AtomicU64,
}
impl CpuFallbackBackend {
    /// Creates an empty CPU-hosted backend with zeroed counters.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: std::sync::atomic::AtomicU64::new(1),
            allocations: RwLock::new(HashMap::new()),
            buffers: RwLock::new(HashMap::new()),
            current_bytes: std::sync::atomic::AtomicU64::new(0),
            peak_bytes: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Raises the peak-bytes high-water mark to `current` if larger.
    ///
    /// `fetch_max` performs atomically what the previous hand-written
    /// compare-exchange loop did.
    fn update_peak(&self, current: u64) {
        self.peak_bytes.fetch_max(current, std::sync::atomic::Ordering::Relaxed);
    }
}
impl Default for CpuFallbackBackend {
    /// Equivalent to [`CpuFallbackBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ComputeBackend for CpuFallbackBackend {
    /// Allocates a zero-initialized host buffer of `size` bytes and tracks
    /// it in the accounting counters.
    fn allocate_buffer(&self, size: u64) -> AccelResult<BufferHandle> {
        // Reject sizes that cannot be represented as usize (possible on
        // 32-bit targets) instead of silently truncating with `as`. Checked
        // first so failures leave ids and counters untouched.
        let len = usize::try_from(size).map_err(|_| {
            AccelError::BufferAllocation(format!("Buffer size too large for host: {}", size))
        })?;
        let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        {
            // Scoped so the allocations lock is dropped before taking the
            // buffers lock below.
            let mut allocs = self.allocations.write().unwrap_or_else(|e| e.into_inner());
            allocs.insert(id, size);
        }
        // fetch_add returns the previous value; add `size` back to get the
        // post-allocation total for the peak update.
        let current =
            self.current_bytes.fetch_add(size, std::sync::atomic::Ordering::Relaxed) + size;
        self.update_peak(current);
        let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
        bufs.insert(id, vec![0u8; len]);
        Ok(BufferHandle::new(id))
    }

    /// Copies `data` into the buffer; whole-buffer writes only.
    fn upload_buffer(&self, handle: &BufferHandle, data: &[u8]) -> AccelResult<()> {
        let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
        match bufs.get_mut(&handle.0) {
            Some(buf) if buf.len() == data.len() => {
                buf.copy_from_slice(data);
                Ok(())
            }
            Some(buf) => Err(AccelError::BufferSizeMismatch {
                expected: buf.len(),
                actual: data.len(),
            }),
            None => Err(AccelError::BufferAllocation(
                format!("Invalid buffer handle: {}", handle.0),
            )),
        }
    }

    /// Returns a copy of the buffer's current contents.
    fn download_buffer(&self, handle: &BufferHandle) -> AccelResult<Vec<u8>> {
        let bufs = self.buffers.read().unwrap_or_else(|e| e.into_inner());
        bufs.get(&handle.0)
            .cloned()
            .ok_or_else(|| AccelError::BufferAllocation(format!("Invalid buffer handle: {}", handle.0)))
    }

    /// Releases the buffer's storage and accounting entry; erroring on
    /// handles that were never allocated or already freed.
    fn free_buffer(&self, handle: BufferHandle) -> AccelResult<()> {
        // The allocations map is the source of truth for double-free
        // detection; the lock is released before touching the buffers map.
        let size = {
            let mut allocs = self.allocations.write().unwrap_or_else(|e| e.into_inner());
            allocs.remove(&handle.0)
        };
        match size {
            Some(s) => {
                self.current_bytes.fetch_sub(s, std::sync::atomic::Ordering::Relaxed);
                let mut bufs = self.buffers.write().unwrap_or_else(|e| e.into_inner());
                bufs.remove(&handle.0);
                Ok(())
            }
            None => Err(AccelError::BufferAllocation(
                format!("Double-free of buffer handle: {}", handle.0),
            )),
        }
    }

    /// Kernel dispatch is a logged no-op on the CPU fallback.
    fn dispatch_kernel(
        &self,
        kernel_name: &str,
        _buffers: &[&BufferHandle],
        _dispatch: DispatchParams,
    ) -> AccelResult<()> {
        tracing::debug!(
            "CpuFallbackBackend::dispatch_kernel: '{}' (CPU path; use CpuAccel for real work)",
            kernel_name
        );
        Ok(())
    }

    /// All operations here complete synchronously; nothing to wait for.
    fn synchronize(&self) -> AccelResult<()> {
        Ok(())
    }

    fn backend_name(&self) -> &str {
        "CPU Fallback"
    }

    fn is_gpu(&self) -> bool {
        false
    }

    /// Snapshot of the allocation counters. The count and byte totals are
    /// read separately, so they may be momentarily inconsistent under
    /// concurrent allocation.
    fn memory_stats(&self) -> GpuMemoryStats {
        let allocs = self.allocations.read().unwrap_or_else(|e| e.into_inner());
        GpuMemoryStats {
            allocated_bytes: self.current_bytes.load(std::sync::atomic::Ordering::Relaxed),
            peak_bytes: self.peak_bytes.load(std::sync::atomic::Ordering::Relaxed),
            allocation_count: allocs.len() as u64,
        }
    }
}
/// Builds the best available [`ComputeBackend`]: Vulkan when a device was
/// detected, otherwise the CPU fallback. The choice is logged at info
/// level.
#[must_use]
pub fn create_backend() -> Arc<dyn ComputeBackend> {
    let vulkan = VulkanComputeBackend::new();
    if !vulkan.is_available() {
        tracing::info!("ComputeBackend: Vulkan unavailable, using CPU fallback");
        return Arc::new(CpuFallbackBackend::new());
    }
    tracing::info!("ComputeBackend: using Vulkan ({})", vulkan.name);
    Arc::new(vulkan)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh CPU backend for each test.
    fn make_cpu_backend() -> CpuFallbackBackend {
        CpuFallbackBackend::new()
    }

    // Round-trip: data uploaded to a buffer comes back unchanged.
    #[test]
    fn test_cpu_backend_alloc_upload_download_free() {
        let b = make_cpu_backend();
        let data = vec![1u8, 2, 3, 4, 5, 6];
        let h = b.allocate_buffer(data.len() as u64).expect("h should be valid");
        b.upload_buffer(&h, &data).expect("upload_buffer should succeed");
        let out = b.download_buffer(&h).expect("out should be valid");
        assert_eq!(out, data);
        b.free_buffer(h).expect("free_buffer should succeed");
    }

    // Uploading fewer bytes than the buffer holds must be rejected.
    #[test]
    fn test_cpu_backend_size_mismatch() {
        let b = make_cpu_backend();
        let h = b.allocate_buffer(4).expect("h should be valid");
        let result = b.upload_buffer(&h, &[1u8, 2, 3]);
        assert!(result.is_err());
        b.free_buffer(h).expect("free_buffer should succeed");
    }

    // free_buffer consumes the handle; freeing again via a handle with the
    // same id must error.
    #[test]
    fn test_cpu_backend_double_free() {
        let b = make_cpu_backend();
        let h = b.allocate_buffer(8).expect("h should be valid");
        let h2 = BufferHandle::new(h.0);
        b.free_buffer(h).expect("free_buffer should succeed");
        assert!(b.free_buffer(h2).is_err());
    }

    // Kernel dispatch is a logged no-op on the CPU fallback and must not
    // fail even with no buffers bound.
    #[test]
    fn test_cpu_backend_dispatch_noop() {
        let b = make_cpu_backend();
        let dispatch = DispatchParams::new_2d(8, 8);
        b.dispatch_kernel("my_kernel", &[], dispatch).expect("dispatch_kernel should succeed");
    }

    // synchronize is trivially successful on the synchronous CPU path.
    #[test]
    fn test_cpu_backend_synchronize() {
        let b = make_cpu_backend();
        b.synchronize().expect("synchronize should succeed");
    }

    // Counters track live buffers; the peak stays at its high-water mark
    // after a free.
    #[test]
    fn test_cpu_backend_memory_stats() {
        let b = make_cpu_backend();
        let h1 = b.allocate_buffer(1024).expect("h1 should be valid");
        let h2 = b.allocate_buffer(2048).expect("h2 should be valid");
        let stats = b.memory_stats();
        assert_eq!(stats.allocated_bytes, 3072);
        assert_eq!(stats.allocation_count, 2);
        b.free_buffer(h1).expect("free_buffer should succeed");
        let stats2 = b.memory_stats();
        assert_eq!(stats2.allocated_bytes, 2048);
        assert_eq!(stats2.peak_bytes, 3072);
        b.free_buffer(h2).expect("free_buffer should succeed");
    }

    // Register/lookup round-trip; unknown names return None.
    #[test]
    fn test_kernel_registry() {
        let reg = KernelRegistry::new();
        assert!(reg.is_empty());
        reg.register("bilinear", b"\x03\x02\x23\x07", "Bilinear scale");
        assert_eq!(reg.len(), 1);
        assert!(reg.get("bilinear").is_some());
        assert!(reg.get("unknown").is_none());
        let names = reg.kernel_names();
        assert!(names.contains(&"bilinear".to_string()));
    }

    // 4:2:0 plane sizes for 1080p: chroma is quarter-area per plane.
    #[test]
    fn test_yuv_frame_info() {
        let info = YuvFrameInfo::yuv420(1920, 1080);
        assert_eq!(info.luma_size(), 1920 * 1080);
        assert_eq!(info.chroma_size(), 960 * 540);
        assert_eq!(info.total_size(), 1920 * 1080 + 2 * 960 * 540);
    }

    // Helper round-trip through a backend preserves the frame length.
    #[test]
    fn test_upload_download_yuv_frame() {
        let b = CpuFallbackBackend::new();
        let info = YuvFrameInfo::yuv420(4, 4);
        let data = vec![0u8; info.total_size() as usize];
        let h = upload_yuv_frame(&b, &data, &info).expect("h should be valid");
        let out = download_yuv_frame(&b, &h).expect("out should be valid");
        assert_eq!(out.len(), data.len());
        b.free_buffer(h).expect("free_buffer should succeed");
    }

    // 1920/16 = 120 exactly; 1080/16 = 67.5 rounds up to 68.
    #[test]
    fn test_dispatch_params() {
        let p = DispatchParams::for_image(1920, 1080, 16, 16);
        assert_eq!(p.groups_x, 120);
        assert_eq!(p.groups_y, 68);
    }

    // Construction must succeed whether or not a device is present.
    #[test]
    fn test_vulkan_backend_creation() {
        let b = VulkanComputeBackend::new();
        assert!(!b.backend_name().is_empty());
    }

    // The Vulkan backend's host-staging path round-trips data too.
    #[test]
    fn test_vulkan_backend_alloc_upload_download() {
        let b = VulkanComputeBackend::new();
        let data = vec![42u8; 16];
        let h = b.allocate_buffer(16).expect("h should be valid");
        b.upload_buffer(&h, &data).expect("upload_buffer should succeed");
        let out = b.download_buffer(&h).expect("out should be valid");
        assert_eq!(out, data);
        b.free_buffer(h).expect("free_buffer should succeed");
    }
}