use crate::Device;
use torsh_core::{
dtype::DType,
error::{Result, TorshError},
shape::Shape,
};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
#[cfg(not(feature = "std"))]
use core::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "std")]
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide counter backing [`generate_buffer_id`]; the first id handed out is 1.
static BUFFER_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns a new buffer identifier, distinct from every earlier result in this
/// process until the counter wraps around past `usize::MAX`.
///
/// `Relaxed` ordering is sufficient: the atomic read-modify-write alone makes
/// each returned value distinct, and no other memory is published through
/// this counter.
pub fn generate_buffer_id() -> usize {
    AtomicUsize::fetch_add(&BUFFER_ID_COUNTER, 1, Ordering::Relaxed)
}
#[derive(Debug, Clone)]
pub struct Buffer {
pub id: usize,
pub device: Device,
pub size: usize,
pub usage: BufferUsage,
pub descriptor: BufferDescriptor,
pub handle: BufferHandle,
}
impl Buffer {
pub fn new(
id: usize,
device: Device,
size: usize,
usage: BufferUsage,
descriptor: BufferDescriptor,
handle: BufferHandle,
) -> Self {
Self {
id,
device,
size,
usage,
descriptor,
handle,
}
}
pub fn id(&self) -> usize {
self.id
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn size(&self) -> usize {
self.size
}
pub fn usage(&self) -> BufferUsage {
self.usage
}
pub fn handle(&self) -> &BufferHandle {
&self.handle
}
pub fn supports_usage(&self, usage: BufferUsage) -> bool {
self.usage.contains(usage)
}
}
/// Requested properties of a buffer allocation.
///
/// Start from [`BufferDescriptor::new`] and refine it with the `with_*`
/// builder methods, each of which consumes and returns the descriptor.
#[derive(Debug, Clone, PartialEq)]
pub struct BufferDescriptor {
    /// Requested size.
    pub size: usize,
    /// Requested usage flags.
    pub usage: BufferUsage,
    /// Requested placement; `new` selects `MemoryLocation::Device`.
    pub location: MemoryLocation,
    /// Optional element type of the contents.
    pub dtype: Option<DType>,
    /// Optional logical shape of the contents.
    pub shape: Option<Shape>,
    /// Optional initial contents. NOTE(review): the length is not checked
    /// against `size` anywhere in this descriptor.
    pub initial_data: Option<Vec<u8>>,
    /// Optional alignment requirement; not validated here (e.g. power of two).
    pub alignment: Option<usize>,
    /// Whether zero-initialisation was requested.
    pub zero_init: bool,
}

impl BufferDescriptor {
    /// Descriptor for `size` with `usage`. Placement is `MemoryLocation::Device`,
    /// every optional field is unset, and zero-initialisation is off.
    pub fn new(size: usize, usage: BufferUsage) -> Self {
        let location = MemoryLocation::Device;
        Self {
            size,
            usage,
            location,
            dtype: None,
            shape: None,
            initial_data: None,
            alignment: None,
            zero_init: false,
        }
    }

    /// Sets the requested memory placement.
    pub fn with_location(self, location: MemoryLocation) -> Self {
        Self { location, ..self }
    }

    /// Records the element type of the contents.
    pub fn with_dtype(self, dtype: DType) -> Self {
        Self {
            dtype: Some(dtype),
            ..self
        }
    }

    /// Records the logical shape of the contents.
    pub fn with_shape(self, shape: Shape) -> Self {
        Self {
            shape: Some(shape),
            ..self
        }
    }

    /// Attaches initial contents, replacing any previously attached data.
    pub fn with_initial_data(self, data: Vec<u8>) -> Self {
        Self {
            initial_data: Some(data),
            ..self
        }
    }

    /// Sets the requested alignment (stored as given, not validated).
    pub fn with_alignment(self, alignment: usize) -> Self {
        Self {
            alignment: Some(alignment),
            ..self
        }
    }

    /// Requests zero-initialisation of the allocation.
    pub fn with_zero_init(self) -> Self {
        Self {
            zero_init: true,
            ..self
        }
    }
}
/// Bit set describing how a buffer may be used.
///
/// Combine flags with `|` / `|=` or [`BufferUsage::union`], remove them with
/// [`BufferUsage::difference`], and test them with [`BufferUsage::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferUsage {
    bits: u32,
}

impl BufferUsage {
    /// Empty set (no bits).
    pub const NONE: Self = Self { bits: 0 };
    /// `READ` flag (bit 0).
    pub const READ: Self = Self { bits: 1 << 0 };
    /// `WRITE` flag (bit 1).
    pub const WRITE: Self = Self { bits: 1 << 1 };
    /// `STORAGE` flag (bit 2).
    pub const STORAGE: Self = Self { bits: 1 << 2 };
    /// `UNIFORM` flag (bit 3).
    pub const UNIFORM: Self = Self { bits: 1 << 3 };
    /// `VERTEX` flag (bit 4).
    pub const VERTEX: Self = Self { bits: 1 << 4 };
    /// `INDEX` flag (bit 5).
    pub const INDEX: Self = Self { bits: 1 << 5 };
    /// `COPY_SRC` flag (bit 6).
    pub const COPY_SRC: Self = Self { bits: 1 << 6 };
    /// `COPY_DST` flag (bit 7).
    pub const COPY_DST: Self = Self { bits: 1 << 7 };
    /// `MAP_READ` flag (bit 8).
    pub const MAP_READ: Self = Self { bits: 1 << 8 };
    /// `MAP_WRITE` flag (bit 9).
    pub const MAP_WRITE: Self = Self { bits: 1 << 9 };
    /// `READ | WRITE`.
    pub const READ_WRITE: Self = Self {
        bits: Self::READ.bits | Self::WRITE.bits,
    };
    /// `STORAGE | READ | WRITE`.
    pub const STORAGE_READ_WRITE: Self = Self {
        bits: Self::STORAGE.bits | Self::READ.bits | Self::WRITE.bits,
    };

    /// Builds a set from raw bits. Bits without a named constant are kept as-is.
    pub const fn new(bits: u32) -> Self {
        Self { bits }
    }

    /// Returns `true` when every bit of `other` is also set in `self`.
    /// Because of this, `contains(NONE)` is always `true`.
    pub const fn contains(self, other: Self) -> bool {
        (self.bits & other.bits) == other.bits
    }

    /// Bits set in either `self` or `other`.
    pub const fn union(self, other: Self) -> Self {
        Self {
            bits: self.bits | other.bits,
        }
    }

    /// Bits set in `self` but not in `other`.
    pub const fn difference(self, other: Self) -> Self {
        Self {
            bits: self.bits & !other.bits,
        }
    }

    /// Raw bit representation.
    pub const fn bits(self) -> u32 {
        self.bits
    }
}

// `core::ops` instead of `std::ops`: in std builds these are the very same
// traits, but only the `core` path resolves when the file is compiled without
// the `std` feature (the file has `not(feature = "std")` imports at the top).
impl core::ops::BitOr for BufferUsage {
    type Output = Self;

    /// Same as [`BufferUsage::union`].
    fn bitor(self, rhs: Self) -> Self::Output {
        self.union(rhs)
    }
}

impl core::ops::BitOrAssign for BufferUsage {
    /// In-place [`BufferUsage::union`].
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs;
    }
}
/// Requested placement of a buffer's memory; defaults to `Device`.
///
/// Nothing in this file interprets the variants beyond storing and comparing
/// them. The notes below follow the variant names, and the allocating backend
/// is presumably what gives them their exact meaning.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLocation {
    /// Device-local memory (the default).
    #[default]
    Device,
    /// Host memory, per the name.
    Host,
    /// Memory shared between host and device, per the name.
    Unified,
    /// Cached host memory, per the name.
    HostCached,
    /// Device memory visible to the host, per the name.
    DeviceHost,
}
/// Backend-specific reference to the memory behind a [`Buffer`].
///
/// Each variant pairs the backend's native identifier with the allocation
/// size. The CUDA, Metal and WebGPU variants exist only when the matching
/// cargo feature (and, for Metal, an aarch64 macOS target) is enabled.
#[derive(Debug)]
pub enum BufferHandle {
    /// Host allocation addressed by a raw pointer.
    Cpu { ptr: *mut u8, size: usize },
    /// CUDA allocation addressed by a device pointer value.
    #[cfg(feature = "cuda")]
    Cuda { device_ptr: u64, size: usize },
    /// Metal allocation identified by a numeric buffer id.
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { buffer_id: u64, size: usize },
    /// WebGPU allocation identified by a numeric buffer pointer value.
    #[cfg(feature = "webgpu")]
    WebGpu { buffer_ptr: u64, size: usize },
    /// Opaque, type-erased backend object. It cannot be cloned and has no
    /// numeric id (see [`BufferHandle::id`]).
    Generic {
        handle: Box<dyn core::any::Any + Send + Sync>,
        size: usize,
    },
}

/// Duplicates the handle value only. The clone refers to the same backend
/// allocation as the original: nothing is copied or reference-counted.
///
/// # Panics
///
/// Panics on [`BufferHandle::Generic`], because its `Box<dyn Any>` payload
/// cannot be cloned. `Buffer` derives `Clone`, so cloning a `Buffer` that holds
/// a generic handle panics as well.
impl Clone for BufferHandle {
    fn clone(&self) -> Self {
        match self {
            BufferHandle::Cpu { ptr, size } => BufferHandle::Cpu {
                ptr: *ptr,
                size: *size,
            },
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { device_ptr, size } => BufferHandle::Cuda {
                device_ptr: *device_ptr,
                size: *size,
            },
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { buffer_id, size } => BufferHandle::Metal {
                buffer_id: *buffer_id,
                size: *size,
            },
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { buffer_ptr, size } => BufferHandle::WebGpu {
                buffer_ptr: *buffer_ptr,
                size: *size,
            },
            BufferHandle::Generic { .. } => {
                panic!("Cannot clone Generic buffer handles")
            }
        }
    }
}

impl BufferHandle {
    /// Allocation size recorded in the handle.
    pub fn size(&self) -> usize {
        match self {
            BufferHandle::Cpu { size, .. } => *size,
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { size, .. } => *size,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { size, .. } => *size,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { size, .. } => *size,
            BufferHandle::Generic { size, .. } => *size,
        }
    }

    /// Numeric identity of the handle: the pointer or backend id converted to
    /// `usize`. Generic handles have no such value and always return `0`.
    ///
    /// NOTE(review): the `u64 as usize` conversions truncate on 32-bit targets.
    pub fn id(&self) -> usize {
        match self {
            BufferHandle::Cpu { ptr, .. } => *ptr as usize,
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { device_ptr, .. } => *device_ptr as usize,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { buffer_id, .. } => *buffer_id as usize,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { buffer_ptr, .. } => *buffer_ptr as usize,
            BufferHandle::Generic { .. } => 0,
        }
    }

    /// Cheap sanity check: the identifier is non-null/non-zero and the size is
    /// non-zero. Generic handles only check the size.
    ///
    /// This inspects stored values only and cannot tell whether the
    /// allocation is still live.
    pub fn is_valid(&self) -> bool {
        match self {
            BufferHandle::Cpu { ptr, size } => !ptr.is_null() && *size > 0,
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { device_ptr, size } => *device_ptr != 0 && *size > 0,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { buffer_id, size } => *buffer_id != 0 && *size > 0,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { buffer_ptr, size } => *buffer_ptr != 0 && *size > 0,
            BufferHandle::Generic { size, .. } => *size > 0,
        }
    }
}

// SAFETY: NOTE(review) — `Cpu` holds a `*mut u8`, which is neither `Send` nor
// `Sync`, so these impls assert that moving and sharing the handle across
// threads is sound. The `Generic` payload is already `Send + Sync`, and the
// other variants hold plain integers. Nothing in this file dereferences `ptr`,
// so soundness rests on whoever reads or writes through it synchronizing that
// access. Confirm against the code that creates `Cpu` handles.
unsafe impl Send for BufferHandle {}
unsafe impl Sync for BufferHandle {}

/// Two handles are equal when they are the same variant with the same
/// identifier and size.
///
/// `Generic` handles have no numeric identifier, so they compare by payload
/// identity (the address of the boxed object) plus size. Previously any two
/// generic handles of the same size compared equal, which merged distinct
/// allocations in sets and maps. Zero-sized payloads may share an address, so
/// two such handles can still compare equal.
impl PartialEq for BufferHandle {
    fn eq(&self, other: &Self) -> bool {
        /// Thin data address of a boxed payload. The vtable half of the wide
        /// pointer is discarded so that duplicated vtables cannot make one
        /// object look like two.
        fn payload_addr(payload: &(dyn core::any::Any + Send + Sync)) -> *const () {
            payload as *const (dyn core::any::Any + Send + Sync) as *const ()
        }

        match (self, other) {
            (
                BufferHandle::Cpu {
                    ptr: ptr1,
                    size: size1,
                },
                BufferHandle::Cpu {
                    ptr: ptr2,
                    size: size2,
                },
            ) => ptr1 == ptr2 && size1 == size2,
            #[cfg(feature = "cuda")]
            (
                BufferHandle::Cuda {
                    device_ptr: ptr1,
                    size: size1,
                },
                BufferHandle::Cuda {
                    device_ptr: ptr2,
                    size: size2,
                },
            ) => ptr1 == ptr2 && size1 == size2,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            (
                BufferHandle::Metal {
                    buffer_id: id1,
                    size: size1,
                },
                BufferHandle::Metal {
                    buffer_id: id2,
                    size: size2,
                },
            ) => id1 == id2 && size1 == size2,
            #[cfg(feature = "webgpu")]
            (
                BufferHandle::WebGpu {
                    buffer_ptr: ptr1,
                    size: size1,
                },
                BufferHandle::WebGpu {
                    buffer_ptr: ptr2,
                    size: size2,
                },
            ) => ptr1 == ptr2 && size1 == size2,
            (
                BufferHandle::Generic {
                    handle: handle1,
                    size: size1,
                },
                BufferHandle::Generic {
                    handle: handle2,
                    size: size2,
                },
            ) => {
                // `&**handle` reaches the boxed object itself. Passing `handle`
                // directly would unsize the `Box` into a `dyn Any` instead.
                payload_addr(&**handle1) == payload_addr(&**handle2) && size1 == size2
            }
            _ => false,
        }
    }
}

/// Equality is reflexive for every variant, so `Eq` holds.
impl Eq for BufferHandle {}

impl core::hash::Hash for BufferHandle {
    /// Hashes a per-variant tag, then the backend identifier, then the size.
    ///
    /// Generic handles hash only their tag and size. This is consistent with
    /// `PartialEq`, which requires equal sizes for equality.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // The primitive `Hash` impls for u8/u64/usize call exactly these
        // `write_*` methods, so the byte stream is identical to hashing the
        // fields directly.
        match self {
            BufferHandle::Cpu { ptr, size } => {
                state.write_u8(0);
                state.write_usize(*ptr as usize);
                state.write_usize(*size);
            }
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { device_ptr, size } => {
                state.write_u8(1);
                state.write_u64(*device_ptr);
                state.write_usize(*size);
            }
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { buffer_id, size } => {
                state.write_u8(2);
                state.write_u64(*buffer_id);
                state.write_usize(*size);
            }
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { buffer_ptr, size } => {
                state.write_u8(3);
                state.write_u64(*buffer_ptr);
                state.write_usize(*size);
            }
            BufferHandle::Generic { size, .. } => {
                state.write_u8(4);
                state.write_usize(*size);
            }
        }
    }
}
/// A sub-range `[offset, offset + size)` of a [`Buffer`], optionally tagged
/// with an element type and a shape.
///
/// The view holds its `Buffer` by value. [`BufferView::new`] bounds-checks the
/// range. The fields are public, though, so a view assembled by hand is not
/// checked.
#[derive(Debug)]
pub struct BufferView {
    /// Buffer the view points into.
    pub buffer: Buffer,
    /// Start of the view, relative to the start of `buffer`.
    pub offset: usize,
    /// Length of the view.
    pub size: usize,
    /// Optional element type of the viewed data.
    pub dtype: Option<DType>,
    /// Optional logical shape of the viewed data.
    pub shape: Option<Shape>,
}

impl BufferView {
    /// Creates an untyped, unshaped view of `size` units starting at `offset`.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when `offset + size` exceeds
    /// `buffer.size`, and also when that sum overflows `usize`.
    pub fn new(buffer: Buffer, offset: usize, size: usize) -> Result<Self> {
        // A plain `offset + size` panics on overflow in debug builds and wraps
        // in release builds, where a huge offset would then slip past the
        // bounds check. `checked_add` turns overflow into a rejection.
        let fits = matches!(offset.checked_add(size), Some(end) if end <= buffer.size);
        if !fits {
            // `String::from` needs nothing beyond `String`, which is in scope
            // for both std and no_std builds. `.to_string()` would also need
            // `ToString` imported under no_std.
            return Err(TorshError::InvalidArgument(String::from(
                "Buffer view exceeds buffer bounds",
            )));
        }
        Ok(Self {
            buffer,
            offset,
            size,
            dtype: None,
            shape: None,
        })
    }

    /// Tags the view with an element type.
    pub fn typed(mut self, dtype: DType) -> Self {
        self.dtype = Some(dtype);
        self
    }

    /// Tags the view with a logical shape.
    pub fn shaped(mut self, shape: Shape) -> Self {
        self.shape = Some(shape);
        self
    }

    /// Buffer the view points into.
    pub fn buffer(&self) -> &Buffer {
        &self.buffer
    }

    /// Start of the view within the buffer.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Length of the view.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Exclusive end of the view (`offset + size`).
    ///
    /// Cannot overflow for views created by [`BufferView::new`]. A view built
    /// by hand with out-of-range fields may overflow here.
    pub fn end_offset(&self) -> usize {
        self.offset + self.size
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType, shape::Shape};

    /// CPU device with id 0 and default info, used as the owner of test buffers.
    fn create_test_device() -> Device {
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), DeviceInfo::default())
    }

    /// READ_WRITE CPU buffer of size `len` wrapping a fake address that these
    /// tests never dereference.
    fn fake_cpu_buffer(id: usize, addr: usize, len: usize) -> Buffer {
        let handle = BufferHandle::Cpu {
            ptr: addr as *mut u8,
            size: len,
        };
        let descriptor = BufferDescriptor::new(len, BufferUsage::READ_WRITE);
        Buffer::new(
            id,
            create_test_device(),
            len,
            BufferUsage::READ_WRITE,
            descriptor,
            handle,
        )
    }

    #[test]
    fn test_buffer_descriptor_creation() {
        let d = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);
        assert_eq!(
            (d.size, d.usage, d.location),
            (1024, BufferUsage::READ_WRITE, MemoryLocation::Device)
        );
        assert!(d.dtype.is_none());
        assert!(d.shape.is_none());
        assert!(d.initial_data.is_none());
        assert!(d.alignment.is_none());
        assert!(!d.zero_init);
    }

    #[test]
    fn test_buffer_descriptor_builder() {
        let d = BufferDescriptor::new(2048, BufferUsage::STORAGE)
            .with_location(MemoryLocation::Host)
            .with_dtype(DType::F32)
            .with_shape(Shape::new(vec![64, 32]))
            .with_alignment(64)
            .with_zero_init();
        assert_eq!(
            (d.size, d.usage, d.location),
            (2048, BufferUsage::STORAGE, MemoryLocation::Host)
        );
        assert_eq!(d.dtype, Some(DType::F32));
        assert!(d.shape.is_some());
        assert_eq!(d.alignment, Some(64));
        assert!(d.zero_init);
    }

    #[test]
    fn test_buffer_usage_flags() {
        let rw = BufferUsage::READ | BufferUsage::WRITE;
        for flag in [BufferUsage::READ, BufferUsage::WRITE] {
            assert!(rw.contains(flag));
        }
        assert!(!rw.contains(BufferUsage::STORAGE));

        let storage_rw = BufferUsage::STORAGE_READ_WRITE;
        for flag in [BufferUsage::STORAGE, BufferUsage::READ, BufferUsage::WRITE] {
            assert!(storage_rw.contains(flag));
        }
    }

    #[test]
    fn test_buffer_handle_validation() {
        let live = BufferHandle::Cpu {
            ptr: 0x1000 as *mut u8,
            size: 1024,
        };
        assert!(live.is_valid());
        assert_eq!(live.size(), 1024);

        let null = BufferHandle::Cpu {
            ptr: std::ptr::null_mut(),
            size: 1024,
        };
        assert!(!null.is_valid());
    }

    #[test]
    fn test_buffer_creation() {
        let device = create_test_device();
        let buffer = Buffer::new(
            1,
            device.clone(),
            512,
            BufferUsage::READ_WRITE,
            BufferDescriptor::new(512, BufferUsage::READ_WRITE),
            BufferHandle::Cpu {
                ptr: 0x2000 as *mut u8,
                size: 512,
            },
        );
        assert_eq!(buffer.id(), 1);
        assert_eq!(buffer.size(), 512);
        assert_eq!(buffer.usage(), BufferUsage::READ_WRITE);
        assert_eq!(buffer.device().id(), device.id());
        assert!(buffer.supports_usage(BufferUsage::READ));
        assert!(buffer.supports_usage(BufferUsage::WRITE));
        assert!(!buffer.supports_usage(BufferUsage::STORAGE));
    }

    #[test]
    fn test_buffer_view_creation() {
        let view = BufferView::new(fake_cpu_buffer(1, 0x3000, 1024), 256, 512).unwrap();
        assert_eq!(view.offset(), 256);
        assert_eq!(view.size(), 512);
        assert_eq!(view.end_offset(), 768);

        // 800 + 512 runs past the end of a buffer of size 1024.
        let past_end = BufferView::new(fake_cpu_buffer(2, 0x3001, 1024), 800, 512);
        assert!(past_end.is_err());
    }

    #[test]
    fn test_buffer_view_typed() {
        let view = BufferView::new(fake_cpu_buffer(1, 0x4000, 1024), 0, 1024)
            .unwrap()
            .typed(DType::F32)
            .shaped(Shape::new(vec![256]));
        assert_eq!(view.dtype, Some(DType::F32));
        assert!(view.shape.is_some());
    }

    #[test]
    fn test_memory_location_variants() {
        assert_eq!(MemoryLocation::default(), MemoryLocation::Device);
        let all = [
            MemoryLocation::Device,
            MemoryLocation::Host,
            MemoryLocation::Unified,
            MemoryLocation::HostCached,
            MemoryLocation::DeviceHost,
        ];
        for &location in &all {
            let d = BufferDescriptor::new(1024, BufferUsage::READ_WRITE).with_location(location);
            assert_eq!(d.location, location);
        }
    }
}