// NOTE: this leans on `bs62` re-exporting `num_traits` (used by `to_owned`
// below); if `num-traits` is a direct dependency, importing it directly
// would be less fragile.
use bs62::num_traits;
use cudarc::driver::{CudaContext, CudaSlice, CudaStream, DevicePtr};
use dynamo_runtime::{error, raise, Result};
use ndarray::{ArrayViewMut, IxDyn};
use std::any::Any;
use std::ffi::c_void;
use std::ptr::NonNull;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StorageType {
Device(Arc<CudaContext>),
Pinned,
System,
}
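/// Element data types a `TensorView` can be typed as. `size_in_bytes` drives
/// all stride and bounds arithmetic below.
///
/// A minimal sizing sketch (the shape and dtype here are made up for the
/// example):
///
/// ```ignore
/// let shape = [2usize, 3, 4];
/// let bytes = shape.iter().product::<usize>() * DType::F16.size_in_bytes();
/// assert_eq!(bytes, 48); // 24 elements * 2 bytes each
/// ```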
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
F32,
F16,
BF16,
FP8,
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
DType::F32 => 4,
DType::F16 => 2,
DType::BF16 => 2,
DType::FP8 => 1,
DType::U8 => 1,
DType::U16 => 2,
DType::U32 => 4,
DType::U64 => 8,
DType::I8 => 1,
DType::I16 => 2,
DType::I32 => 4,
DType::I64 => 8,
}
}
}
extern "C" {
fn cuda_malloc_host(ptr: *mut *mut c_void, size: usize) -> i32;
fn cuda_free_host(ptr: *mut c_void) -> i32;
fn cuda_memcpy_async(
dst: *mut c_void,
src: *const c_void,
count: usize,
stream: *mut c_void,
) -> i32;
fn cuda_memcpy_sync(dst: *mut c_void, src: *const c_void, count: usize) -> i32;
}
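/// Minimal abstraction over a raw byte buffer (device, pinned, or system
/// memory) that a `TensorView` can be laid over.
///
/// A sketch of a custom implementation backed by a plain `Vec<u8>`; the
/// `HostBuffer` type is hypothetical, not part of this module:
///
/// ```ignore
/// #[derive(Debug)]
/// struct HostBuffer(Vec<u8>);
///
/// impl Storage for HostBuffer {
///     fn get_pointer(&self) -> u64 {
///         self.0.as_ptr() as u64
///     }
///     fn storage_size(&self) -> usize {
///         self.0.len()
///     }
///     fn storage_type(&self) -> StorageType {
///         StorageType::System
///     }
/// }
///
/// let buf = HostBuffer(vec![0u8; 96]);
/// let view = buf.view([2, 3, 4], DType::F32)?; // exactly 96 bytes
/// ```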
pub trait Storage: std::fmt::Debug {
fn get_pointer(&self) -> u64;
fn storage_size(&self) -> usize;
fn storage_type(&self) -> StorageType;
fn view<const D: usize>(
&self,
shape: [usize; D],
dtype: DType,
) -> Result<TensorView<'_, Self, D>>
where
Self: Sized,
{
TensorView::new(self, shape, dtype.size_in_bytes())
}
}
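/// Reference-counted, type-erased handle to any `Storage` implementation.
///
/// A sketch of allocating pinned host memory and viewing it as a 2x3 `f32`
/// tensor (assumes the CUDA shims above are linked and a driver is present):
///
/// ```ignore
/// let storage = OwnedStorage::create(6 * 4, StorageType::Pinned)?;
/// let view = storage.view([2, 3], DType::F32)?;
/// assert_eq!(view.num_elements(), 6);
/// ```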
#[derive(Clone)]
pub struct OwnedStorage {
storage: Arc<dyn Storage>,
}
impl OwnedStorage {
pub fn new(storage: Arc<dyn Storage>) -> Self {
Self { storage }
}
pub fn create(bytes: usize, storage_type: StorageType) -> Result<Self> {
match storage_type {
StorageType::Device(device) => Self::create_device_array(bytes, device),
StorageType::Pinned => Self::create_pinned_array(bytes),
StorageType::System => {
raise!("System memory not yet supported");
}
}
}
pub fn create_device_array(bytes: usize, device: Arc<CudaContext>) -> Result<Self> {
let device_storage = DeviceStorageOwned::new(bytes, device)?;
Ok(Self::new(Arc::new(device_storage)))
}
pub fn create_pinned_array(bytes: usize) -> Result<Self> {
let pinned_memory = CudaPinnedMemory::new(bytes)?;
Ok(Self::new(Arc::new(pinned_memory)))
}
pub fn byo_device_array(
device_ptr: u64,
bytes: usize,
device: Arc<CudaContext>,
owner: Arc<dyn Any + Send + Sync>,
) -> Result<Self> {
let device_storage = DeviceStorageFromAny::new(owner, device_ptr, bytes, device);
Ok(Self::new(Arc::new(device_storage)))
}
}
impl Storage for OwnedStorage {
fn get_pointer(&self) -> u64 {
self.storage.get_pointer()
}
fn storage_size(&self) -> usize {
self.storage.storage_size()
}
fn storage_type(&self) -> StorageType {
self.storage.storage_type()
}
}
impl std::fmt::Debug for OwnedStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OwnedStorage")
.field("storage_type", &self.storage.storage_type())
.finish()
}
}
pub struct DeviceStorageOwned {
bytes: usize,
cuda_device: Arc<CudaContext>,
cuda_slice: Arc<CudaSlice<u8>>,
}
impl DeviceStorageOwned {
pub fn new(bytes: usize, device: Arc<CudaContext>) -> Result<Self> {
let stream = device.default_stream();
let cuda_slice = stream.alloc_zeros::<u8>(bytes)?;
stream.synchronize()?;
Ok(Self {
bytes,
cuda_device: device,
cuda_slice: Arc::new(cuda_slice),
})
}
pub fn device_ptr(&self) -> *const c_void {
let stream = self.cuda_device.default_stream();
let (ptr, _) = self.cuda_slice.device_ptr(&stream);
ptr as *const c_void
}
pub fn context(&self) -> Arc<CudaContext> {
self.cuda_device.clone()
}
}
impl Storage for DeviceStorageOwned {
fn get_pointer(&self) -> u64 {
self.device_ptr() as u64
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Device(self.cuda_device.clone())
}
}
impl std::fmt::Debug for DeviceStorageOwned {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Storage")
.field("storage_type", &self.storage_type())
.field("storage_size", &self.storage_size())
.finish()
}
}
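/// Page-locked (pinned) host memory allocated via the `cuda_malloc_host`
/// shim. Pinned pages stay resident, so async H2D/D2H copies can DMA
/// directly instead of staging through pageable memory.
///
/// A minimal sketch (requires a working CUDA runtime):
///
/// ```ignore
/// let mut pinned = CudaPinnedMemory::new(1024)?; // zero-initialized
/// assert_eq!(pinned.size(), 1024);
/// let ptr = pinned.as_mut_ptr(); // valid until `pinned` is dropped
/// ```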
pub struct CudaPinnedMemory {
ptr: NonNull<c_void>,
bytes: usize,
}
unsafe impl Send for CudaPinnedMemory {}
unsafe impl Sync for CudaPinnedMemory {}
impl CudaPinnedMemory {
pub fn new(bytes: usize) -> Result<Self> {
if bytes == 0 {
raise!("Bytes must be greater than 0");
}
let mut ptr: *mut c_void = std::ptr::null_mut();
let result = unsafe { cuda_malloc_host(&mut ptr, bytes) };
if result != 0 {
raise!("Failed to allocate pinned memory");
}
let ptr =
NonNull::new(ptr).ok_or_else(|| anyhow::anyhow!("Null pointer after allocation"))?;
unsafe {
std::ptr::write_bytes(ptr.as_ptr() as *mut u8, 0, bytes);
}
Ok(Self { ptr, bytes })
}
pub fn as_ptr(&self) -> *const c_void {
self.ptr.as_ptr()
}
pub fn as_mut_ptr(&mut self) -> *mut c_void {
self.ptr.as_ptr()
}
pub fn size(&self) -> usize {
self.bytes
}
}
impl Drop for CudaPinnedMemory {
fn drop(&mut self) {
let result = unsafe { cuda_free_host(self.ptr.as_ptr()) };
if result != 0 {
eprintln!("Failed to free pinned memory");
}
}
}
impl Storage for CudaPinnedMemory {
fn get_pointer(&self) -> u64 {
self.ptr.as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Pinned
}
}
impl std::fmt::Debug for CudaPinnedMemory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CudaPinnedMemory")
.field("ptr", &(self.ptr.as_ptr() as usize))
.field("bytes", &self.bytes)
.field("storage_type", &self.storage_type())
.finish()
}
}
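/// A borrowed, strided, typed window over a `Storage` buffer.
///
/// For a row-major shape `[d0, ..., dk]`, element strides are the products of
/// the faster-varying dimension sizes, and byte strides are element strides
/// times `element_size`. A sketch over some existing host `storage`:
///
/// ```ignore
/// let view = TensorView::<_, 3>::new(&storage, [2, 3, 4], 4)?;
/// assert_eq!(view.strides(), &[12, 4, 1]);
/// assert_eq!(view.byte_strides(), &[48, 16, 4]);
/// ```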
#[derive(Clone)]
pub struct TensorView<'a, T: Storage, const D: usize> {
storage: &'a T,
shape: [usize; D],
strides: [usize; D],
byte_strides: [usize; D],
offset: usize,
element_size: usize,
total_elements: usize,
}
impl<'a, T: Storage, const D: usize> TensorView<'a, T, D> {
pub fn new(storage: &'a T, shape: [usize; D], element_size: usize) -> Result<Self> {
let mut strides = [0; D];
let mut byte_strides = [0; D];
if D > 0 {
strides[D - 1] = 1;
byte_strides[D - 1] = element_size;
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
byte_strides[i] = strides[i] * element_size;
}
}
let total_elements = shape.iter().product();
if total_elements * element_size > storage.storage_size() {
return Err(error!(
"Shape {:?} requires {} bytes, but storage only has {} bytes",
shape,
total_elements * element_size,
storage.storage_size()
));
}
Ok(Self {
storage,
shape,
strides,
byte_strides,
offset: 0,
element_size,
total_elements,
})
}
pub fn with_strides(
storage: &'a T,
shape: [usize; D],
strides: [usize; D],
offset: usize,
element_size: usize,
) -> Result<Self, String> {
let byte_strides = strides.map(|stride| stride * element_size);
let total_elements = shape.iter().product();
// The last element starts at offset + Σ (dim - 1) * byte_stride and ends
// element_size bytes later; bound-check the end of that element, not its
// start.
let max_offset = if D > 0 && total_elements > 0 {
offset + Self::calculate_max_offset(&shape, &byte_strides) + element_size
} else {
offset
};
if max_offset > storage.storage_size() {
return Err(format!(
"View would access up to byte offset {}, but storage size is only {} bytes",
max_offset,
storage.storage_size()
));
}
Ok(Self {
storage,
shape,
strides,
byte_strides,
offset,
element_size,
total_elements,
})
}
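/// Byte offset of the *start* of the last element: the sum over dimensions of
/// `(dim_size - 1) * byte_stride`. For shape `[2, 3, 4]` with byte strides
/// `[48, 16, 4]` this is `1*48 + 2*16 + 3*4 = 92`; `with_strides` adds
/// `element_size` on top so the bounds check covers the element's final byte.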
fn calculate_max_offset(shape: &[usize; D], byte_strides: &[usize; D]) -> usize {
shape
.iter()
.zip(byte_strides.iter())
.map(|(&dim_size, &stride)| {
if dim_size > 0 {
(dim_size - 1) * stride
} else {
0
}
})
.sum()
}
pub fn shape(&self) -> &[usize; D] {
&self.shape
}
pub fn strides(&self) -> &[usize; D] {
&self.strides
}
pub fn byte_strides(&self) -> &[usize; D] {
&self.byte_strides
}
pub fn element_size(&self) -> usize {
self.element_size
}
fn validate_indices(&self, indices: &[usize; D]) -> Result<(), String> {
for (dim, (&idx, &dim_size)) in indices.iter().zip(self.shape.iter()).enumerate() {
if idx >= dim_size {
return Err(format!(
"Index {} out of bounds for dimension {} with size {}",
idx, dim, dim_size
));
}
}
Ok(())
}
pub fn flat_index(&self, indices: &[usize; D]) -> Result<usize, String> {
self.validate_indices(indices)?;
let flat_idx = indices
.iter()
.zip(self.strides.iter())
.fold(0, |acc, (&idx, &stride)| acc + idx * stride);
Ok(flat_idx)
}
pub fn byte_offset(&self, indices: &[usize; D]) -> Result<usize> {
self.validate_indices(indices)
.map_err(|e| error!("{}", e))?;
let offset = indices
.iter()
.zip(self.byte_strides.iter())
.fold(self.offset, |acc, (&idx, &stride)| acc + idx * stride);
Ok(offset)
}
pub fn address(&self, indices: &[usize; D]) -> Result<u64> {
let byte_offset = self.byte_offset(indices)?;
Ok(self.storage.get_pointer() + byte_offset as u64)
}
pub fn in_bounds(&self, indices: &[usize; D]) -> bool {
indices
.iter()
.zip(self.shape.iter())
.all(|(&idx, &dim_size)| idx < dim_size)
}
pub fn get_element<E: bytemuck::Pod + Copy>(&self, indices: &[usize; D]) -> Result<E> {
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot directly access elements from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
let offset = self.byte_offset(indices)?;
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(offset) as *const E;
let value = unsafe { *ptr };
Ok(value)
}
pub fn set_element<E: bytemuck::Pod + Copy>(
&mut self,
indices: &[usize; D],
value: E,
) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot directly modify device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
let offset = self.byte_offset(indices)?;
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(offset) as *mut E;
unsafe { *ptr = value };
Ok(())
}
pub fn fill<E: bytemuck::Pod + Copy>(&mut self, value: E) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot directly modify device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot fill non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(self.offset) as *mut E;
let len = self.total_elements;
unsafe {
let slice = std::slice::from_raw_parts_mut(ptr, len);
slice.fill(value);
}
Ok(())
}
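/// A view is contiguous when its elements form one dense row-major run:
/// walking dimensions right to left, each stride must equal the product of
/// the faster-varying dimension sizes. Slicing the outermost dimension
/// preserves this; slicing an inner dimension generally breaks it. A sketch
/// over some existing `storage`:
///
/// ```ignore
/// let view = TensorView::<_, 3>::new(&storage, [2, 3, 4], 4)?;
/// assert!(view.is_contiguous());
/// assert!(view.slice(0, 0, Some(1))?.is_contiguous()); // outer slice: dense
/// assert!(!view.slice(1, 0, Some(2))?.is_contiguous()); // inner slice: gaps
/// ```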
pub fn is_contiguous(&self) -> bool {
if D == 0 {
return true;
}
let mut expected_stride = 1;
let mut expected_byte_stride = self.element_size;
for i in (0..D).rev() {
if self.strides[i] != expected_stride || self.byte_strides[i] != expected_byte_stride {
return false;
}
expected_stride *= self.shape[i];
expected_byte_stride *= self.shape[i];
}
true
}
pub fn num_elements(&self) -> usize {
self.total_elements
}
pub fn data(&self) -> u64 {
self.storage.get_pointer()
}
pub fn size_in_bytes(&self) -> usize {
self.total_elements * self.element_size
}
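/// Synchronous byte-for-byte copy into another contiguous view with the same
/// shape, strides, and element size, routed through the `cuda_memcpy_sync`
/// shim (which is assumed to infer the copy direction, as `cudaMemcpyDefault`
/// does). A sketch:
///
/// ```ignore
/// let src = host_storage.view([2, 3], DType::F32)?;
/// let mut dst = device_storage.view([2, 3], DType::F32)?;
/// src.copy_to_view_blocking(&mut dst)?;
/// ```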
pub fn copy_to_view_blocking<S: Storage>(
&self,
dst_view: &mut TensorView<'_, S, D>,
) -> Result<()> {
if self.shape != dst_view.shape || self.strides != dst_view.strides {
raise!(
"Shape or strides mismatch: shapes {:?} vs {:?}, strides {:?} vs {:?}",
self.shape,
dst_view.shape,
self.strides,
dst_view.strides
);
}
if !self.is_contiguous() {
raise!("Source is not contiguous");
}
if !dst_view.is_contiguous() {
raise!("Destination is not contiguous");
}
if self.element_size != dst_view.element_size {
raise!(
"Element size mismatch: {} vs {}",
self.element_size,
dst_view.element_size
);
}
debug_assert_eq!(self.size_in_bytes(), dst_view.size_in_bytes());
tracing::debug!("Copying from {:?} to {:?}", self, dst_view);
let rc = unsafe {
cuda_memcpy_sync(
(dst_view.data() as *mut u8).wrapping_add(dst_view.offset) as *mut c_void,
(self.data() as *const u8).wrapping_add(self.offset) as *const c_void,
self.size_in_bytes(),
)
};
if rc != 0 {
raise!("cuda_memcpy_sync failed during blocking copy");
}
Ok(())
}
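/// Restricts one dimension to `[start, end)` without copying: the shape
/// shrinks, strides are unchanged, and the view's byte offset advances by
/// `start * byte_strides[dim]`. `end = None` means "through the end of the
/// dimension".
///
/// ```ignore
/// let view = TensorView::<_, 3>::new(&storage, [2, 3, 4], 4)?;
/// let middle = view.slice(1, 1, Some(3))?; // indices 1..3 of dim 1
/// assert_eq!(middle.shape(), &[2, 2, 4]); // offset advanced by 16 bytes
/// ```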
pub fn slice(&self, dim: usize, start: usize, end: Option<usize>) -> Result<Self, String> {
if dim >= D {
return Err(format!(
"Dimension {} out of bounds for tensor with {} dimensions",
dim, D
));
}
let end_idx = end.unwrap_or(self.shape[dim]);
if end_idx > self.shape[dim] {
return Err(format!(
"End index {} out of bounds for dimension {} with size {}",
end_idx, dim, self.shape[dim]
));
}
if start >= end_idx {
return Err(format!(
"Invalid slice range: start={}, end={}",
start, end_idx
));
}
let mut new_shape = self.shape;
new_shape[dim] = end_idx - start;
let new_offset = self.offset + start * self.byte_strides[dim];
Ok(Self {
storage: self.storage,
shape: new_shape,
strides: self.strides,
byte_strides: self.byte_strides,
offset: new_offset,
element_size: self.element_size,
total_elements: new_shape.iter().product(),
})
}
pub fn as_ndarray_view<DT>(&self) -> Result<ndarray::ArrayView<'_, DT, IxDyn>> {
match self.storage.storage_type() {
StorageType::Device(_) => raise!("Cannot convert device tensor to ndarray"),
StorageType::System | StorageType::Pinned => {}
};
self.as_unsafe_ndarray_view::<DT>()
}
// `DT` carries no `bytemuck::Pod` bound here (unlike the mutable variant),
// so callers must guarantee `DT` is plain-old-data of the right size.
pub(crate) fn as_unsafe_ndarray_view<DT>(&self) -> Result<ndarray::ArrayView<'_, DT, IxDyn>> {
if std::mem::size_of::<DT>() != self.element_size {
return Err(anyhow::anyhow!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<DT>(),
self.element_size
));
}
if !self.is_contiguous() {
raise!("Cannot convert non-contiguous tensor to ndarray");
}
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const DT;
let size = self.shape.iter().product::<usize>();
let slice = unsafe { std::slice::from_raw_parts::<DT>(ptr, size) };
let dim = ndarray::IxDyn(&self.shape);
let array = ndarray::ArrayView::from_shape(dim, slice)?;
Ok(array)
}
pub fn as_ndarray_view_mut<DT>(&mut self) -> Result<ArrayViewMut<'_, DT, IxDyn>>
where
DT: bytemuck::Pod,
{
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(anyhow::anyhow!("Cannot convert device tensor to ndarray"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<DT>() != self.element_size {
return Err(anyhow::anyhow!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<DT>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(anyhow::anyhow!(
"Cannot convert non-contiguous tensor to ndarray"
));
}
let ptr =
(self.storage.get_pointer() as *mut u8).wrapping_add(self.offset) as *mut DT;
let size = self.shape.iter().product::<usize>();
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, size) };
let dim = ndarray::IxDyn(&self.shape);
let array = ndarray::ArrayViewMut::from_shape(dim, slice)?;
Ok(array)
}
pub fn storage_type(&self) -> StorageType {
self.storage.storage_type()
}
pub fn indices_iter(&self) -> impl Iterator<Item = [usize; D]> + '_ {
let shape = self.shape;
let total = self.total_elements;
(0..total).map(move |idx| tensor_indexing::unflatten_index(idx, &shape))
}
pub fn map_elements<E, R, F>(&self, f: F) -> Result<Vec<R>>
where
E: bytemuck::Pod + Copy,
F: Fn(E) -> R,
{
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot directly access elements from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot map over elements of non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const E;
let len = self.total_elements;
let result = unsafe {
let slice = std::slice::from_raw_parts(ptr, len);
slice.iter().map(|&e| f(e)).collect()
};
Ok(result)
}
pub fn as_slice<E: bytemuck::Pod>(&self) -> Result<&[E]> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot get slice from device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot get slice from non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const E;
let len = self.total_elements;
let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
Ok(slice)
}
pub fn as_slice_mut<E: bytemuck::Pod>(&mut self) -> Result<&mut [E]> {
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot get mutable slice from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!(
"Cannot get mutable slice from non-contiguous tensor"
));
}
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(self.offset) as *mut E;
let len = self.total_elements;
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
Ok(slice)
}
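/// Asynchronous host-to-device copy on the given stream (`d2h` below is the
/// mirror image). Both views must be contiguous with matching shape and
/// element size, and the caller must synchronize the stream before touching
/// the destination or reusing the source. A sketch:
///
/// ```ignore
/// let stream = context.default_stream();
/// host_view.h2d(&mut device_view, &stream)?;
/// device_view.d2h(&mut host_view2, &stream)?;
/// stream.synchronize()?; // copies complete only after this
/// ```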
pub fn h2d<S: Storage>(
&self,
device_view: &mut TensorView<'_, S, D>,
stream: &CudaStream,
) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Source must be a host tensor (System or Pinned)"))
}
StorageType::System | StorageType::Pinned => {}
};
match device_view.storage_type() {
StorageType::Device(_) => {}
_ => return Err(error!("Destination must be a device tensor")),
};
if self.shape != device_view.shape {
return Err(error!(
"Shape mismatch: {:?} vs {:?}",
self.shape, device_view.shape
));
}
if self.element_size != device_view.element_size {
return Err(error!(
"Element size mismatch: {} vs {}",
self.element_size, device_view.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Source tensor must be contiguous"));
}
if !device_view.is_contiguous() {
return Err(error!("Destination tensor must be contiguous"));
}
let src_ptr =
(self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const c_void;
let dst_ptr = (device_view.storage.get_pointer() as *mut u8)
.wrapping_add(device_view.offset) as *mut c_void;
let size_in_bytes = self.size_in_bytes();
let stream_id = stream.cu_stream();
let rc =
unsafe { cuda_memcpy_async(dst_ptr, src_ptr, size_in_bytes, stream_id as *mut c_void) };
if rc != 0 {
return Err(error!(
"cudaMemcpyAsync failed during host-to-device transfer"
));
}
Ok(())
}
pub fn d2h<S: Storage>(
&self,
host_view: &mut TensorView<'_, S, D>,
stream: &CudaStream,
) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => {}
_ => return Err(error!("Source must be a device tensor")),
};
match host_view.storage_type() {
StorageType::Device(_) => {
return Err(error!(
"Destination must be a host tensor (System or Pinned)"
))
}
StorageType::System | StorageType::Pinned => {}
};
if self.shape != host_view.shape {
return Err(error!(
"Shape mismatch: {:?} vs {:?}",
self.shape, host_view.shape
));
}
if self.element_size != host_view.element_size {
return Err(error!(
"Element size mismatch: {} vs {}",
self.element_size, host_view.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Source tensor must be contiguous"));
}
if !host_view.is_contiguous() {
return Err(error!("Destination tensor must be contiguous"));
}
let src_ptr =
(self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const c_void;
let dst_ptr = (host_view.storage.get_pointer() as *mut u8).wrapping_add(host_view.offset)
as *mut c_void;
let size_in_bytes = self.size_in_bytes();
let stream_id = stream.cu_stream();
let rc =
unsafe { cuda_memcpy_async(dst_ptr, src_ptr, size_in_bytes, stream_id as *mut c_void) };
if rc != 0 {
return Err(error!(
"cudaMemcpyAsync failed during device-to-host transfer"
));
}
Ok(())
}
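/// Materializes the view into an owned host `ndarray`, copying synchronously
/// from the device when necessary. The device path assumes a contiguous view
/// whose element layout matches `DT` exactly.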
pub fn to_owned<DT: std::fmt::Debug + Clone + num_traits::Zero>(
&self,
) -> Result<ndarray::Array<DT, IxDyn>> {
match self.storage.storage_type() {
StorageType::System | StorageType::Pinned => {
let nd = self.as_ndarray_view::<DT>()?;
Ok(nd.to_owned())
}
StorageType::Device(_device) => {
if !self.is_contiguous() {
return Err(error!("Cannot copy non-contiguous device tensor to host"));
}
let shape = self.shape.to_vec();
let dim = ndarray::IxDyn(&shape);
let mut nd = ndarray::Array::<DT, _>::zeros(dim);
tracing::debug!("Copying {} bytes from device to host", self.size_in_bytes());
let src = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset);
let rc = unsafe {
cuda_memcpy_sync(
nd.as_mut_ptr() as *mut c_void,
src as *const c_void,
self.size_in_bytes(),
)
};
if rc != 0 {
return Err(error!(
"cuda_memcpy_sync failed during device-to-host transfer"
));
}
Ok(nd)
}
}
}
}
impl<T: Storage, const D: usize> std::fmt::Debug for TensorView<'_, T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TensorView")
.field("shape", &self.shape)
.field("strides", &self.strides)
.field("byte_strides", &self.byte_strides)
.field("offset", &self.offset)
.field("element_size", &self.element_size)
.field("total_elements", &self.total_elements)
.field("storage_type", &self.storage.storage_type())
.finish()
}
}
pub mod tensor_indexing {
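/// Converts a row-major flat index back into per-dimension indices, the
/// inverse of `TensorView::flat_index`. For shape `[2, 3, 4]` (strides
/// `[12, 4, 1]`), flat index 23 maps to `[1, 2, 3]` since
/// `23 = 1*12 + 2*4 + 3*1`.
///
/// ```ignore
/// assert_eq!(unflatten_index(23, &[2, 3, 4]), [1, 2, 3]);
/// ```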
pub fn unflatten_index<const D: usize>(flat_idx: usize, shape: &[usize; D]) -> [usize; D] {
let mut indices = [0; D];
let mut remaining = flat_idx;
let mut strides = [0; D];
if D > 0 {
strides[D - 1] = 1;
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
for (i, &stride) in strides.iter().enumerate() {
indices[i] = remaining / stride;
remaining %= stride;
}
indices
}
pub fn calculate_strides<const D: usize>(shape: &[usize; D]) -> [usize; D] {
let mut strides = [0; D];
if D > 0 {
strides[D - 1] = 1;
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
strides
}
pub fn calculate_byte_strides<const D: usize>(
shape: &[usize; D],
element_size: usize,
) -> [usize; D] {
let mut byte_strides = [0; D];
if D > 0 {
byte_strides[D - 1] = element_size;
for i in (0..D - 1).rev() {
byte_strides[i] = byte_strides[i + 1] * shape[i + 1];
}
}
byte_strides
}
}
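/// Device storage borrowed from an external owner ("bring your own" memory).
/// The `source` handle keeps the real owner alive for the lifetime of this
/// storage; the raw `device_ptr` is trusted as-is. A sketch, assuming some
/// hypothetical framework tensor `external` that exposes its device pointer:
///
/// ```ignore
/// let owner: Arc<dyn Any + Send + Sync> = Arc::new(external);
/// let storage = OwnedStorage::byo_device_array(ptr, bytes, context, owner)?;
/// ```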
#[derive(Debug)]
pub struct DeviceStorageFromAny {
source: Arc<dyn Any + Send + Sync>,
device_ptr: u64,
bytes: usize,
device: Arc<CudaContext>,
}
impl DeviceStorageFromAny {
pub fn new(
source: Arc<dyn Any + Send + Sync>,
device_ptr: u64,
bytes: usize,
device: Arc<CudaContext>,
) -> Self {
Self {
source,
device_ptr,
bytes,
device,
}
}
pub fn source(&self) -> &Arc<dyn Any + Send + Sync> {
&self.source
}
pub fn downcast_source<T: 'static + Send + Sync>(&self) -> Option<&T> {
self.source.downcast_ref::<T>()
}
}
impl Storage for DeviceStorageFromAny {
fn get_pointer(&self) -> u64 {
self.device_ptr
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Device(self.device.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug)]
struct MockTensor {
data_ptr: u64,
storage_size_bytes: usize,
}
impl Storage for MockTensor {
fn get_pointer(&self) -> u64 {
self.data_ptr
}
fn storage_size(&self) -> usize {
self.storage_size_bytes
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
#[test]
fn test_tensor_view_creation() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let element_size = 4;
let view = TensorView::<_, 3>::new(&mock_tensor, shape, element_size).unwrap();
assert_eq!(view.shape(), &[2, 3, 4]);
assert_eq!(view.strides(), &[12, 4, 1]);
assert_eq!(view.byte_strides(), &[48, 16, 4]);
assert_eq!(view.num_elements(), 24);
assert_eq!(view.size_in_bytes(), 96);
assert!(view.is_contiguous());
}
#[test]
fn test_tensor_view_indexing() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
assert_eq!(view.flat_index(&[0, 0, 0]).unwrap(), 0);
assert_eq!(view.flat_index(&[0, 0, 1]).unwrap(), 1);
assert_eq!(view.flat_index(&[0, 1, 0]).unwrap(), 4);
assert_eq!(view.flat_index(&[1, 0, 0]).unwrap(), 12);
assert_eq!(view.flat_index(&[1, 2, 3]).unwrap(), 23);
assert_eq!(view.byte_offset(&[0, 0, 0]).unwrap(), 0);
assert_eq!(view.byte_offset(&[0, 0, 1]).unwrap(), 4);
assert_eq!(view.byte_offset(&[0, 1, 0]).unwrap(), 16);
assert_eq!(view.byte_offset(&[1, 0, 0]).unwrap(), 48);
assert_eq!(view.address(&[0, 0, 0]).unwrap(), 0x1000);
assert_eq!(view.address(&[0, 0, 1]).unwrap(), 0x1004);
assert_eq!(view.address(&[1, 2, 3]).unwrap(), 0x1000 + 23 * 4);
}
#[test]
fn test_tensor_view_slicing() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
let sliced = view.slice(1, 1, Some(3)).unwrap();
assert_eq!(sliced.shape(), &[2, 2, 4]);
assert_eq!(sliced.strides(), &[12, 4, 1]);
assert_eq!(sliced.byte_strides(), &[48, 16, 4]);
assert_eq!(sliced.offset, 16);
assert_eq!(sliced.address(&[0, 0, 0]).unwrap(), 0x1000 + 16);
assert_eq!(
sliced.address(&[1, 1, 3]).unwrap(),
0x1000 + 16 + 48 + 16 + 12
);
}
#[test]
fn test_tensor_views_with_custom_strides() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let contiguous_view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
assert!(contiguous_view.is_contiguous());
assert_eq!(contiguous_view.strides(), &[12, 4, 1]);
assert_eq!(contiguous_view.byte_strides(), &[48, 16, 4]);
let smaller_shape = [2, 2, 4];
let non_contiguous_strides = [12, 4, 1];
let non_contiguous = TensorView::<_, 3>::with_strides(
&mock_tensor,
smaller_shape,
non_contiguous_strides,
0,
4,
)
.unwrap();
assert!(!non_contiguous.is_contiguous());
assert_eq!(non_contiguous.strides(), &[12, 4, 1]);
assert_eq!(non_contiguous.byte_strides(), &[48, 16, 4]);
let last_index = [1, 1, 3];
let byte_offset = non_contiguous.byte_offset(&last_index).unwrap();
assert_eq!(byte_offset, (12 + 4 + 3) * 4);
assert!(
byte_offset < mock_tensor.storage_size(),
"Byte offset {} should be less than storage size {}",
byte_offset,
mock_tensor.storage_size()
);
let invalid_custom_strides = [16, 4, 1];
let result =
TensorView::<_, 3>::with_strides(&mock_tensor, shape, invalid_custom_strides, 0, 4);
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(
error_msg.contains("would access up to byte offset 108"),
"Expected error about exceeding storage, got: {}",
error_msg
);
}
#[test]
fn test_tensor_view_with_offset() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 120,
};
let shape = [2, 3, 4];
let offset_view =
TensorView::<_, 3>::with_strides(&mock_tensor, shape, [12, 4, 1], 16, 4).unwrap();
assert!(offset_view.is_contiguous());
assert_eq!(offset_view.offset, 16);
let first_byte_offset = offset_view.byte_offset(&[0, 0, 0]).unwrap();
assert_eq!(first_byte_offset, 16);
let last_byte_offset = offset_view.byte_offset(&[1, 2, 3]).unwrap();
assert_eq!(last_byte_offset, 16 + (12 + 2 * 4 + 3) * 4);
let result = TensorView::<_, 3>::with_strides(
&mock_tensor,
shape,
[12, 4, 1],
80,
4,
);
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(
error_msg.contains("would access up to byte offset"),
"Expected error about exceeding storage, got: {}",
error_msg
);
}
#[test]
fn test_in_bounds_method() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
assert!(view.in_bounds(&[0, 0, 0]));
assert!(view.in_bounds(&[1, 2, 3]));
assert!(!view.in_bounds(&[2, 0, 0]));
assert!(!view.in_bounds(&[0, 3, 0]));
assert!(!view.in_bounds(&[0, 0, 4]));
assert!(!view.in_bounds(&[2, 3, 4]));
}
#[test]
fn test_validate_indices() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96,
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
assert!(view.validate_indices(&[0, 0, 0]).is_ok());
assert!(view.validate_indices(&[1, 2, 3]).is_ok());
assert!(view.validate_indices(&[2, 0, 0]).is_err());
assert!(view.validate_indices(&[0, 3, 0]).is_err());
assert!(view.validate_indices(&[0, 0, 4]).is_err());
}
#[test]
fn test_indices_iter() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 24,
};
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&mock_tensor, shape, 4).unwrap();
let indices: Vec<[usize; 2]> = view.indices_iter().collect();
let expected_indices = vec![[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]];
assert_eq!(indices, expected_indices);
}
#[test]
fn test_get_set_element() {
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
let real_tensor = RealDataMock::new(24);
let shape = [2, 3];
let mut view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
view.set_element::<f32>(&[0, 0], 1.0).unwrap();
view.set_element::<f32>(&[0, 1], 2.0).unwrap();
view.set_element::<f32>(&[1, 2], 6.0).unwrap();
assert_eq!(view.get_element::<f32>(&[0, 0]).unwrap(), 1.0);
assert_eq!(view.get_element::<f32>(&[0, 1]).unwrap(), 2.0);
assert_eq!(view.get_element::<f32>(&[1, 2]).unwrap(), 6.0);
assert_eq!(view.get_element::<f32>(&[0, 2]).unwrap(), 0.0);
assert_eq!(view.get_element::<f32>(&[1, 0]).unwrap(), 0.0);
assert_eq!(view.get_element::<f32>(&[1, 1]).unwrap(), 0.0);
}
#[test]
fn test_fill_method() {
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
let real_tensor = RealDataMock::new(24);
let shape = [2, 3];
let mut view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
view.fill::<f32>(42.5).unwrap();
for i in 0..2 {
for j in 0..3 {
assert_eq!(view.get_element::<f32>(&[i, j]).unwrap(), 42.5);
}
}
}
#[test]
fn test_map_elements() {
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
fn set_f32_values(&self, values: &[f32]) {
let mut data = self.data.lock().unwrap();
for (i, val) in values.iter().enumerate() {
let bytes = val.to_ne_bytes();
data[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
let real_tensor = RealDataMock::new(24);
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
real_tensor.set_f32_values(&values);
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
let doubled: Vec<f32> = view.map_elements::<f32, f32, _>(|x| x * 2.0).unwrap();
let expected = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
assert_eq!(doubled, expected);
let as_ints: Vec<i32> = view.map_elements::<f32, i32, _>(|x| x as i32).unwrap();
let expected_ints = [1, 2, 3, 4, 5, 6];
assert_eq!(as_ints, expected_ints);
}
#[test]
fn test_as_slice() {
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
fn set_f32_values(&self, values: &[f32]) {
let mut data = self.data.lock().unwrap();
for (i, val) in values.iter().enumerate() {
let bytes = val.to_ne_bytes();
data[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
let real_tensor = RealDataMock::new(24);
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
real_tensor.set_f32_values(&values);
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
let slice = view.as_slice::<f32>().unwrap();
assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut mut_view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
{
let mut_slice = mut_view.as_slice_mut::<f32>().unwrap();
mut_slice[0] = 10.0;
mut_slice[5] = 60.0;
}
assert_eq!(mut_view.get_element::<f32>(&[0, 0]).unwrap(), 10.0);
assert_eq!(mut_view.get_element::<f32>(&[1, 2]).unwrap(), 60.0);
}
#[test]
fn test_ndarray_view_with_real_data() {
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
element_size_bytes: usize,
}
impl RealDataMock {
fn new(num_elements: usize, element_size: usize) -> Self {
let buffer = vec![0u8; num_elements * element_size];
Self {
data: Arc::new(Mutex::new(buffer)),
element_size_bytes: element_size,
}
}
fn set_element_value(&self, index: usize, value: u32) {
let mut data = self.data.lock().unwrap();
let bytes = value.to_ne_bytes();
let offset = index * self.element_size_bytes;
for i in 0..std::mem::size_of::<u32>() {
data[offset + i] = bytes[i];
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
let shape = [2, 3, 4];
let num_elements = shape.iter().product();
let element_size = std::mem::size_of::<u32>();
let mock_tensor = RealDataMock::new(num_elements, element_size);
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
let ndarray_view = view.as_ndarray_view::<u32>().unwrap();
for &value in ndarray_view.iter() {
assert_eq!(value, 0, "Expected all initial values to be 0");
}
let data_arc = mock_tensor.data.clone();
{
let mut data = data_arc.lock().unwrap();
for i in 0..num_elements {
let offset = i * element_size;
let bytes = 42u32.to_ne_bytes();
for j in 0..element_size {
data[offset + j] = bytes[j];
}
}
}
let updated_view = view.as_ndarray_view::<u32>().unwrap();
for &value in updated_view.iter() {
assert_eq!(value, 42, "Expected all values to be 42 after update");
}
mock_tensor.set_element_value(0, 0);
let final_view = view.as_ndarray_view::<u32>().unwrap();
assert_eq!(final_view[[0, 0, 0]], 0, "First element should be 0");
assert_eq!(updated_view[[0, 0, 0]], 0, "First element should be 0");
assert_eq!(
final_view[[0, 0, 1]],
42,
"Element [0,0,1] should still be 42"
);
assert_eq!(final_view[[1, 2, 3]], 42, "Last element should still be 42");
let zero_count = final_view.iter().filter(|&&x| x == 0).count();
assert_eq!(zero_count, 1, "There should be exactly one zero element");
let forty_two_count = final_view.iter().filter(|&&x| x == 42).count();
assert_eq!(
forty_two_count,
num_elements - 1,
"All other elements should be 42"
);
}
#[test]
fn test_host_device_transfers() {
use cudarc::driver::CudaContext;
let context = CudaContext::new(0).unwrap();
let stream = context.default_stream();
let pinned_storage = OwnedStorage::create_pinned_array(6 * 4).unwrap();
let shape = [2, 3];
let mut host_view = TensorView::<_, 2>::new(&pinned_storage, shape, 4).unwrap();
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
for i in 0..2 {
for j in 0..3 {
host_view
.set_element::<f32>(&[i, j], values[i * 3 + j])
.unwrap();
}
}
let device_storage = OwnedStorage::create_device_array(6 * 4, context.clone()).unwrap();
let mut device_view = TensorView::<_, 2>::new(&device_storage, shape, 4).unwrap();
host_view.h2d(&mut device_view, &stream).unwrap();
let pinned_storage2 = OwnedStorage::create_pinned_array(6 * 4).unwrap();
let mut host_view2 = TensorView::<_, 2>::new(&pinned_storage2, shape, 4).unwrap();
device_view.d2h(&mut host_view2, &stream).unwrap();
stream.synchronize().unwrap();
for i in 0..2 {
for j in 0..3 {
assert_eq!(
host_view2.get_element::<f32>(&[i, j]).unwrap(),
values[i * 3 + j]
);
}
}
let new_values = [10.0f32, 20.0, 30.0, 40.0, 50.0, 60.0];
for i in 0..2 {
for j in 0..3 {
host_view
.set_element::<f32>(&[i, j], new_values[i * 3 + j])
.unwrap();
}
}
host_view.h2d(&mut device_view, &stream).unwrap();
device_view.d2h(&mut host_view2, &stream).unwrap();
stream.synchronize().unwrap();
for i in 0..2 {
for j in 0..3 {
assert_eq!(
host_view2.get_element::<f32>(&[i, j]).unwrap(),
new_values[i * 3 + j]
);
}
}
}
}