use crate::api::{copy_device_to_host_vec, copy_host_vec_to_device};
use crate::error::{tensor_error_result, Error};
use crate::tile_kernel::UnwrapPartition;
use anyhow::Result;
use cuda_async::device_buffer::{DeviceBuffer, DevicePointer};
use cuda_async::device_operation;
use cuda_async::device_operation::{value, DeviceOp, IntoDeviceOp, Value};
use cuda_core::malloc_async;
use cuda_core::sys::CUdeviceptr;
use cuda_core::{DType, DTypeId};
use std::marker::PhantomData;
use std::mem::{align_of, size_of, MaybeUninit};
use std::ops::Index;
use std::sync::Arc;
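/// A tensor (or mutable tensor reference) paired with the tile shape and
/// strides used to map it onto a kernel launch grid.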
pub struct Partition<T> {
pub(crate) object: T,
pub partition_shape: Vec<usize>,
pub partition_strides: Vec<usize>,
}
impl<T> Partition<T> {
pub fn unpartition(self) -> T {
self.object
}
}
impl<T: DType> Partition<Tensor<T>> {
pub fn num_bytes(&self) -> usize {
self.object.size() * size_of::<T>()
}
pub fn num_mb(&self) -> usize {
self.num_bytes() / 10usize.pow(6)
}
pub fn num_gb(&self) -> usize {
self.num_bytes() / 10usize.pow(9)
}
pub fn dtype(&self) -> DTypeId {
T::DTYPE
}
pub fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
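/// Computes the launch grid as the ceiling division of each tensor dimension
/// by the corresponding partition dimension. For example, a [1024, 1024]
/// tensor partitioned as [128, 64] yields a grid of (8, 16, 1).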
pub fn grid(&self) -> Result<(u32, u32, u32), Error> {
if !self.object.shape.iter().all(|&x| x > 0) {
return tensor_error_result("Shape dimensions must be positive.");
}
let shape: Vec<u32> = self.object.shape.iter().map(|&x| x as u32).collect();
let partition_shape: Vec<u32> = self.partition_shape.iter().map(|&x| x as u32).collect();
let rank = shape.len();
if partition_shape.len() != rank {
return tensor_error_result("Partition rank must match tensor rank.");
}
match rank {
1 => Ok((u32::div_ceil(shape[0], partition_shape[0]), 1, 1)),
2 => Ok((
u32::div_ceil(shape[0], partition_shape[0]),
u32::div_ceil(shape[1], partition_shape[1]),
1,
)),
3 => Ok((
u32::div_ceil(shape[0], partition_shape[0]),
u32::div_ceil(shape[1], partition_shape[1]),
u32::div_ceil(shape[2], partition_shape[2]),
)),
_ => tensor_error_result("Partitioned tensor must be at most rank 3."),
}
}
}
impl<T> From<Partition<T>> for Arc<T> {
fn from(val: Partition<T>) -> Self {
Arc::new(val.unpartition())
}
}
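/// Partitions an owned object into tiles of `partition_shape`.
///
/// A sketch of typical usage (identifiers are illustrative; assumes a rank-2
/// tensor with uniquely owned storage):
///
/// ```ignore
/// let p = tensor.partition([128, 64]);
/// let (gx, gy, gz) = p.grid()?;
/// ```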
pub trait IntoPartition {
fn partition<const RANK: usize>(self, partition_shape: [usize; RANK]) -> Partition<Self>
where
Self: Sized;
}
pub trait IntoPartitionArc {
fn partition<const RANK: usize>(
self: Arc<Self>,
partition_shape: [usize; RANK],
) -> Partition<Self>
where
Self: Sized;
}
pub use cutile_compiler::specialization::{compute_spec, SpecializationBits};
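/// A device-resident, dtype-tagged tensor: reference-counted `DeviceBuffer`
/// storage plus shape/stride metadata and precomputed specialization bits.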
#[derive(Debug)]
pub struct Tensor<T: DType> {
pub(crate) storage: Arc<DeviceBuffer>,
pub(crate) shape: Vec<i32>,
pub(crate) strides: Vec<i32>,
pub(crate) spec: SpecializationBits,
_dtype: PhantomData<T>,
}
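/// Row-major (C-contiguous) strides for `shape`, in elements.
/// For example, shape [2, 3, 4] yields strides [12, 4, 1].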
fn contiguous_strides(shape: &[i32]) -> Vec<i32> {
let mut stride = 1;
let mut strides = Vec::with_capacity(shape.len());
for dim in shape.iter().rev() {
strides.push(stride);
stride *= *dim;
}
strides.reverse();
strides
}
fn checked_num_elements(shape: &[usize]) -> Result<usize, Error> {
shape.iter().try_fold(1usize, |acc, dim| {
acc.checked_mul(*dim)
.ok_or_else(|| crate::error::tensor_error("Tensor shape overflowed usize."))
})
}
fn checked_num_bytes<T>(shape: &[usize]) -> Result<usize, Error> {
checked_num_elements(shape)?
.checked_mul(size_of::<T>())
.ok_or_else(|| crate::error::tensor_error("Tensor byte size overflowed usize."))
}
fn checked_num_elements_i32(shape: &[i32]) -> Result<usize, Error> {
shape.iter().try_fold(1usize, |acc, dim| {
let dim = usize::try_from(*dim)
.map_err(|_| crate::error::tensor_error("Tensor shape contains negative dimension."))?;
acc.checked_mul(dim)
.ok_or_else(|| crate::error::tensor_error("Tensor shape overflowed usize."))
})
}
fn checked_num_bytes_i32<T>(shape: &[i32]) -> Result<usize, Error> {
checked_num_elements_i32(shape)?
.checked_mul(size_of::<T>())
.ok_or_else(|| crate::error::tensor_error("Tensor byte size overflowed usize."))
}
impl<T: DType> Tensor<T> {
fn assert_valid_metadata(shape: &[i32], strides: &[i32], storage_num_bytes: usize) {
assert_eq!(
shape.len(),
strides.len(),
"Tensor shape/stride rank mismatch."
);
let num_bytes = checked_num_bytes_i32::<T>(shape)
.expect("Tensor shape contains invalid dimensions or overflows.");
assert_eq!(
num_bytes, storage_num_bytes,
"Tensor logical byte size must match storage byte size."
);
}
pub(crate) fn from_device_buffer(
device_buffer: DeviceBuffer,
shape: Vec<i32>,
strides: Vec<i32>,
) -> Self {
Self::assert_valid_metadata(&shape, &strides, device_buffer.len_bytes());
let storage = Arc::new(device_buffer);
let spec = compute_spec(
storage.cu_deviceptr(),
&shape,
&strides,
size_of::<T>() as i32,
);
Self {
storage,
shape,
strides,
spec,
_dtype: PhantomData,
}
}
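/// # Safety
///
/// `dptr` must point to a live device allocation of at least `len_bytes`
/// bytes on device `device_id`, and the allocation must remain valid for the
/// lifetime of the returned tensor without being freed elsewhere.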
pub unsafe fn from_raw_parts(
dptr: CUdeviceptr,
len_bytes: usize,
device_id: usize,
shape: Vec<i32>,
strides: Vec<i32>,
) -> Self {
Self::assert_valid_metadata(&shape, &strides, len_bytes);
Self::from_device_buffer(
DeviceBuffer::from_raw_parts(dptr, len_bytes, device_id),
shape,
strides,
)
}
fn storage_num_bytes(&self) -> usize {
self.storage.len_bytes()
}
fn num_elements(&self) -> usize {
checked_num_elements_i32(&self.shape)
.expect("Tensor shape contains invalid dimensions or overflows.")
}
fn typed_num_bytes(&self) -> usize {
checked_num_bytes_i32::<T>(&self.shape)
.expect("Tensor shape contains invalid dimensions or overflows.")
}
fn validate_view_shape(&self, shape: &[usize]) -> Result<(), Error> {
if !self.is_contiguous() {
return tensor_error_result("Zero-copy tensor views require contiguous storage.");
}
let target_num_bytes = checked_num_bytes::<T>(shape)?;
if target_num_bytes != self.typed_num_bytes() {
return tensor_error_result("View shape must preserve tensor size.");
}
Ok(())
}
fn validate_reinterpret_shape<U: DType>(&self, shape: &[usize]) -> Result<(), Error> {
if !self.is_contiguous() {
return tensor_error_result("Zero-copy reinterpret requires contiguous storage.");
}
let target_num_bytes = checked_num_bytes::<U>(shape)?;
if target_num_bytes != self.typed_num_bytes() {
return tensor_error_result("Reinterpret shape must preserve total byte size.");
}
let alignment = align_of::<U>() as u64;
if alignment > 1 && self.cu_deviceptr() % alignment != 0 {
return tensor_error_result(
"Tensor storage alignment is incompatible with reinterpret target type.",
);
}
Ok(())
}
fn assert_unique_storage(&self) {
assert!(
Arc::strong_count(&self.storage) == 1,
"Cannot create mutable partition from shared tensor storage."
);
}
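/// Returns a device op that allocates an uninitialized rank-1 tensor of
/// `len` elements on the op's stream.
///
/// A sketch of typical usage (assumes the op is run with `sync()`, as done
/// elsewhere in this module, and that the buffer is written before calling
/// `assume_init`):
///
/// ```ignore
/// let op = Tensor::<f32>::uninitialized(1 << 20);
/// let tensor = unsafe { op.sync()?.assume_init() };
/// ```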
pub fn uninitialized(len: usize) -> impl DeviceOp<Output = MaybeUninit<Self>> {
assert!(len > 0, "Non-zero length required.");
device_operation::with_context(move |ctx| {
let num_bytes = len * size_of::<T>();
value(MaybeUninit::new(unsafe {
Self::from_raw_parts(
malloc_async(num_bytes, ctx.get_cuda_stream()),
num_bytes,
ctx.get_device_id(),
vec![len as i32],
vec![1],
)
}))
})
}
pub fn dtype(&self) -> DTypeId {
T::DTYPE
}
pub(crate) fn cu_deviceptr(&self) -> CUdeviceptr {
self.storage.cu_deviceptr()
}
pub fn device_id(&self) -> usize {
self.storage.device_id()
}
pub fn device_pointer(&self) -> DevicePointer<T> {
unsafe { DevicePointer::from_cu_deviceptr(self.cu_deviceptr()) }
}
pub fn shape(&self) -> &[i32] {
&self.shape
}
pub fn strides(&self) -> &[i32] {
&self.strides
}
pub fn spec(&self) -> &SpecializationBits {
&self.spec
}
pub fn size(&self) -> usize {
debug_assert_eq!(self.typed_num_bytes(), self.storage_num_bytes());
self.num_elements()
}
pub fn dup(&self) -> impl DeviceOp<Output = Self> {
crate::api::dup(self)
}
pub fn num_bytes(&self) -> usize {
self.typed_num_bytes()
}
pub fn is_contiguous(&self) -> bool {
self.strides == contiguous_strides(&self.shape)
}
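/// # Safety
///
/// The returned handle aliases `self`'s storage; the caller must ensure that
/// uses of the alias do not race with mutation through other handles.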
pub unsafe fn into_shared_alias(&self) -> Arc<Self> {
Arc::new(Self {
storage: self.storage.clone(),
shape: self.shape.clone(),
strides: self.strides.clone(),
spec: self.spec.clone(),
_dtype: PhantomData,
})
}
pub(crate) fn reshape_unchecked(mut self, shape: &[usize]) -> Self {
let shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
self.strides = contiguous_strides(&shape);
self.spec = compute_spec(
self.storage.cu_deviceptr(),
&shape,
&self.strides,
size_of::<T>() as i32,
);
self.shape = shape;
self
}
pub(crate) fn reshape_shared(self: &Arc<Self>, shape: &[usize]) -> Result<Arc<Self>, Error> {
self.validate_view_shape(shape)?;
let new_shape: Vec<i32> = shape.iter().map(|x| *x as i32).collect();
let new_strides = contiguous_strides(&new_shape);
let spec = compute_spec(
self.storage.cu_deviceptr(),
&new_shape,
&new_strides,
size_of::<T>() as i32,
);
Ok(Arc::new(Self {
storage: self.storage.clone(),
strides: new_strides,
shape: new_shape,
spec,
_dtype: PhantomData,
}))
}
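/// Zero-copy reinterpretation of contiguous storage as another dtype with
/// the same total byte size and compatible alignment.
///
/// A sketch (identifiers are illustrative; assumes `t: Arc<Tensor<f32>>` is
/// contiguous with 1024 elements and that `i32` implements `DType`):
///
/// ```ignore
/// let bits: Arc<Tensor<i32>> = t.reinterpret(&[1024])?;
/// ```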
pub fn reinterpret<U: DType>(
self: &Arc<Self>,
shape: &[usize],
) -> Result<Arc<Tensor<U>>, Error> {
self.validate_reinterpret_shape::<U>(shape)?;
let new_shape: Vec<i32> = shape.iter().map(|x| *x as i32).collect();
let new_strides = contiguous_strides(&new_shape);
let spec = compute_spec(
self.storage.cu_deviceptr(),
&new_shape,
&new_strides,
size_of::<U>() as i32,
);
Ok(Arc::new(Tensor::<U> {
storage: self.storage.clone(),
strides: new_strides,
shape: new_shape,
spec,
_dtype: PhantomData,
}))
}
}
pub trait ToHostVec<T: Send> {
fn to_host_vec(self) -> impl DeviceOp<Output = Vec<T>>;
}
impl<T: DType> ToHostVec<T> for Tensor<T> {
fn to_host_vec(self) -> impl DeviceOp<Output = Vec<T>> {
let arc_self = Arc::new(self);
copy_device_to_host_vec(&arc_self)
}
}
impl<T: DType> ToHostVec<T> for Arc<Tensor<T>> {
fn to_host_vec(self) -> impl DeviceOp<Output = Vec<T>> {
copy_device_to_host_vec(&self)
}
}
impl<T: DType> ToHostVec<T> for &Arc<Tensor<T>> {
fn to_host_vec(self) -> impl DeviceOp<Output = Vec<T>> {
copy_device_to_host_vec(self)
}
}
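/// Zero-copy reshape to a shape with the same element count. The owned form
/// consumes the tensor; the `Arc` form returns a new shared handle.
///
/// A sketch (identifiers are illustrative):
///
/// ```ignore
/// let t2 = t.reshape(&[64, 16])?;
/// ```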
pub trait Reshape {
type Output;
fn reshape(self, shape: &[usize]) -> Result<Self::Output, Error>;
}
impl<T: DType> Reshape for Tensor<T> {
type Output = Tensor<T>;
fn reshape(self, shape: &[usize]) -> Result<Tensor<T>, Error> {
if !self.is_contiguous() {
return tensor_error_result("reshape: tensor must be contiguous.");
}
if checked_num_elements(shape)? != checked_num_elements_i32(&self.shape)? {
return tensor_error_result("reshape: new shape must preserve element count.");
}
Ok(self.reshape_unchecked(shape))
}
}
impl<'a, T: DType> Reshape for &'a Arc<Tensor<T>> {
type Output = Arc<Tensor<T>>;
fn reshape(self, shape: &[usize]) -> Result<Arc<Tensor<T>>, Error> {
self.reshape_shared(shape)
}
}
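/// A borrowed, zero-copy window into a `Tensor`: a byte offset into the base
/// storage plus the window's own shape, strides, and specialization bits.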
pub struct TensorView<'a, T: DType> {
base: &'a Tensor<T>,
offset_bytes: usize,
shape: Vec<i32>,
strides: Vec<i32>,
spec: SpecializationBits,
}
impl<'a, T: DType> TensorView<'a, T> {
pub fn shape(&self) -> &[i32] {
&self.shape
}
pub fn strides(&self) -> &[i32] {
&self.strides
}
pub fn spec(&self) -> &SpecializationBits {
&self.spec
}
pub fn size(&self) -> usize {
self.shape.iter().map(|&x| x as usize).product()
}
pub fn view(&self, shape: &[usize]) -> Result<TensorView<'_, T>, Error> {
if self.strides != contiguous_strides(&self.shape) {
return tensor_error_result("view: cannot reshape a non-contiguous view.");
}
let current_elems = checked_num_elements_i32(&self.shape)?;
let new_elems = checked_num_elements(shape)?;
if new_elems != current_elems {
return tensor_error_result("view: new shape must preserve element count.");
}
let new_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
let new_strides = contiguous_strides(&new_shape);
let spec = compute_spec(
self.base.storage.cu_deviceptr() + self.offset_bytes as u64,
&new_shape,
&new_strides,
size_of::<T>() as i32,
);
Ok(TensorView {
base: self.base,
offset_bytes: self.offset_bytes,
shape: new_shape,
strides: new_strides,
spec,
})
}
pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Result<TensorView<'_, T>, Error> {
if ranges.len() > self.shape.len() {
return tensor_error_result("slice: more ranges than axes.");
}
let mut offset_elems: usize = 0;
let mut new_shape = self.shape.clone();
for (axis, range) in ranges.iter().enumerate() {
let dim = self.shape[axis] as usize;
if range.start > range.end || range.end > dim {
return tensor_error_result("slice: range out of bounds.");
}
offset_elems += range.start * self.strides[axis] as usize;
new_shape[axis] = (range.end - range.start) as i32;
}
let new_strides = self.strides.clone();
let spec = compute_spec(
self.base.storage.cu_deviceptr()
+ (self.offset_bytes + offset_elems * size_of::<T>()) as u64,
&new_shape,
&new_strides,
size_of::<T>() as i32,
);
Ok(TensorView {
base: self.base,
offset_bytes: self.offset_bytes + offset_elems * size_of::<T>(),
shape: new_shape,
strides: new_strides,
spec,
})
}
}
impl<T: DType> Tensor<T> {
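/// Creates a zero-copy view over the whole tensor with a new shape;
/// requires contiguous storage and an unchanged element count.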
pub fn view(&self, shape: &[usize]) -> Result<TensorView<'_, T>, Error> {
if !self.is_contiguous() {
return tensor_error_result("view: tensor must be contiguous.");
}
let current_elems = checked_num_elements_i32(&self.shape)?;
let new_elems = checked_num_elements(shape)?;
if new_elems != current_elems {
return tensor_error_result("view: new shape must preserve element count.");
}
let new_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
let new_strides = contiguous_strides(&new_shape);
let spec = compute_spec(
self.storage.cu_deviceptr(),
&new_shape,
&new_strides,
size_of::<T>() as i32,
);
Ok(TensorView {
base: self,
offset_bytes: 0,
shape: new_shape,
strides: new_strides,
spec,
})
}
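/// Takes a zero-copy rectangular slice; axes without a corresponding range
/// are kept whole, and strides are preserved.
///
/// A sketch (identifiers are illustrative; assumes `t` has shape [8, 16]):
///
/// ```ignore
/// let v = t.slice(&[2..4])?; // shape [2, 16], same strides
/// ```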
pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Result<TensorView<'_, T>, Error> {
if ranges.len() > self.shape.len() {
return tensor_error_result("slice: more ranges than axes.");
}
let mut offset_elems: usize = 0;
let mut new_shape = self.shape.clone();
for (axis, range) in ranges.iter().enumerate() {
let dim = self.shape[axis] as usize;
if range.start > range.end || range.end > dim {
return tensor_error_result("slice: range out of bounds.");
}
offset_elems += range.start * self.strides[axis] as usize;
new_shape[axis] = (range.end - range.start) as i32;
}
let new_strides = self.strides.clone();
let spec = compute_spec(
self.storage.cu_deviceptr() + (offset_elems * size_of::<T>()) as u64,
&new_shape,
&new_strides,
size_of::<T>() as i32,
);
Ok(TensorView {
base: self,
offset_bytes: offset_elems * size_of::<T>(),
shape: new_shape,
strides: new_strides,
spec,
})
}
}
impl<T: DType> IntoPartitionArc for Tensor<T> {
fn partition<const RANK: usize>(
self: Arc<Tensor<T>>,
partition_shape: [usize; RANK],
) -> Partition<Tensor<T>> {
let partition_shape = partition_shape.to_vec();
let partition_strides: Vec<usize> = self.strides.iter().map(|&s| s as usize).collect();
let tensor = Arc::try_unwrap(self).expect("partition: Arc<Tensor> has multiple owners.");
tensor.assert_unique_storage();
Partition::<Tensor<T>> {
object: tensor,
partition_shape,
partition_strides,
}
}
}
impl<T: DType> IntoPartition for Tensor<T> {
fn partition<const RANK: usize>(self, partition_shape: [usize; RANK]) -> Partition<Tensor<T>> {
let partition_shape = partition_shape.to_vec();
let partition_strides: Vec<usize> = self.strides.iter().map(|&s| s as usize).collect();
self.assert_unique_storage();
Partition::<Tensor<T>> {
object: self,
partition_shape,
partition_strides,
}
}
}
pub trait PartitionMut<'a, T: DType> {
fn partition<const RANK: usize>(
self,
partition_shape: [usize; RANK],
) -> Partition<&'a mut Tensor<T>>;
}
impl<'a, T: DType> PartitionMut<'a, T> for &'a mut Tensor<T> {
fn partition<const RANK: usize>(
self,
partition_shape: [usize; RANK],
) -> Partition<&'a mut Tensor<T>> {
let partition_shape = partition_shape.to_vec();
let partition_strides: Vec<usize> = self.strides.iter().map(|&s| s as usize).collect();
Partition {
object: self,
partition_shape,
partition_strides,
}
}
}
impl<'a, T: DType> Partition<&'a mut Tensor<T>> {
pub fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
pub fn grid(&self) -> Result<(u32, u32, u32), Error> {
if !self.object.shape.iter().all(|&x| x > 0) {
return tensor_error_result("Shape dimensions must be positive.");
}
let shape: Vec<u32> = self.object.shape.iter().map(|&x| x as u32).collect();
let partition_shape: Vec<u32> = self.partition_shape.iter().map(|&x| x as u32).collect();
let rank = shape.len();
if partition_shape.len() != rank {
return tensor_error_result("Partition rank must match tensor rank.");
}
match rank {
1 => Ok((u32::div_ceil(shape[0], partition_shape[0]), 1, 1)),
2 => Ok((
u32::div_ceil(shape[0], partition_shape[0]),
u32::div_ceil(shape[1], partition_shape[1]),
1,
)),
3 => Ok((
u32::div_ceil(shape[0], partition_shape[0]),
u32::div_ceil(shape[1], partition_shape[1]),
u32::div_ceil(shape[2], partition_shape[2]),
)),
_ => tensor_error_result("Mutable tensor must be at most rank 3."),
}
}
}
impl<'a, T: DType + Sync> IntoDeviceOp<Partition<&'a mut Tensor<T>>>
for Partition<&'a mut Tensor<T>>
{
type Op = Value<Partition<&'a mut Tensor<T>>>;
fn into_op(self) -> Value<Partition<&'a mut Tensor<T>>> {
value(self)
}
}
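/// Fallible counterpart of `IntoPartitionArc::partition`: returns an error
/// instead of panicking when the `Arc` has other owners.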
pub trait TryPartition<T: DType> {
fn try_partition<const RANK: usize>(
self,
partition_shape: [usize; RANK],
) -> Result<Partition<Tensor<T>>, Error>;
}
impl<T: DType> TryPartition<T> for Arc<Tensor<T>> {
fn try_partition<const RANK: usize>(
self,
partition_shape: [usize; RANK],
) -> Result<Partition<Tensor<T>>, Error> {
let tensor = Arc::try_unwrap(self).map_err(|_| {
crate::error::tensor_error("try_partition: Arc<Tensor> has multiple owners")
})?;
Ok(tensor.partition(partition_shape))
}
}
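/// Adapts a device op that yields a `Partition<Tensor<T>>` into one that
/// yields the underlying `Tensor<T>`.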
pub trait Unpartition<T: DType> {
fn unpartition(self) -> impl DeviceOp<Output = Tensor<T>>;
}
impl<T: DType, DI: DeviceOp<Output = Partition<Tensor<T>>>> Unpartition<T> for DI {
fn unpartition(self) -> impl DeviceOp<Output = Tensor<T>> {
UnwrapPartition { op: self }
}
}
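/// A vector of device tensors mirrored by a device-side table of their raw
/// pointers (a `Tensor<i64>`). The host-side `Arc`s keep each tensor's
/// storage alive while the table may still be read on device.
///
/// A sketch of construction (identifiers are illustrative):
///
/// ```ignore
/// let dv: DeviceVec<Tensor<f32>> = vec![a, b, c].into();
/// let first = &dv[0]; // host-side access by index
/// ```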
#[derive(Clone, Debug)]
pub struct DeviceVec<T> {
_ty: PhantomData<T>,
host_vec: Vec<Arc<T>>,
device_vec: Arc<Tensor<i64>>,
}
impl<T: DType> DeviceVec<Tensor<T>> {
pub fn from(v: Vec<Tensor<T>>) -> DeviceVec<Tensor<T>> {
let i64vec: Arc<Vec<i64>> = v
.iter()
.map(|x| x.cu_deviceptr() as i64)
.collect::<Vec<_>>()
.into();
let device_vec: Arc<Tensor<i64>> = copy_host_vec_to_device(&i64vec)
.sync()
.expect("Failed to execute device operation.")
.reshape_unchecked(&[v.len()])
.into();
let host_vec: Vec<Arc<Tensor<T>>> = v.into_iter().map(Arc::new).collect::<Vec<_>>();
DeviceVec {
_ty: PhantomData,
host_vec,
device_vec,
}
}
pub fn is_empty(&self) -> bool {
self.host_vec.len() == 0
}
pub fn len(&self) -> usize {
self.host_vec.len()
}
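/// # Safety
///
/// Exposes the raw device-pointer table. Cloning the returned `Arc` can let
/// the table outlive the tensors it points to; callers must keep this
/// `DeviceVec` alive while the table is used on device.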
pub unsafe fn inner(&self) -> &Arc<Tensor<i64>> {
&self.device_vec
}
}
impl<T: DType> From<Vec<Tensor<T>>> for DeviceVec<Tensor<T>> {
fn from(v: Vec<Tensor<T>>) -> Self {
DeviceVec::from(v)
}
}
impl<T: DType> Index<usize> for DeviceVec<Tensor<T>> {
type Output = Arc<Tensor<T>>;
fn index(&self, index: usize) -> &Self::Output {
&self.host_vec[index]
}
}
pub struct DeviceVecIntoIter<Item> {
items: DeviceVec<Item>,
}
impl<T: DType> Iterator for DeviceVecIntoIter<Tensor<T>> {
type Item = Tensor<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.items.is_empty() {
return None;
}
// Remove from the front so iteration preserves insertion order; unwrapping
// requires that no other Arc clones of the element are alive.
let x = self.items.host_vec.remove(0);
Some(Arc::try_unwrap(x).expect("into_iter: Arc<Tensor> has multiple owners."))
}
}
impl<T: DType> IntoIterator for DeviceVec<Tensor<T>> {
type Item = Tensor<T>;
type IntoIter = DeviceVecIntoIter<Tensor<T>>;
fn into_iter(self) -> Self::IntoIter {
DeviceVecIntoIter { items: self }
}
}
impl<T: DType> IntoDeviceOp<Partition<Tensor<T>>> for Partition<Tensor<T>> {
type Op = Value<Partition<Tensor<T>>>;
fn into_op(self) -> Value<Partition<Tensor<T>>> {
value(self)
}
}
impl<T: DType> IntoDeviceOp<Tensor<T>> for Tensor<T> {
type Op = Value<Tensor<T>>;
fn into_op(self) -> Value<Tensor<T>> {
value(self)
}
}
impl<'a, T: DType + Sync> IntoDeviceOp<&'a Tensor<T>> for &'a Tensor<T> {
type Op = Value<&'a Tensor<T>>;
fn into_op(self) -> Value<&'a Tensor<T>> {
value(self)
}
}
use cuda_async::launch::AsyncKernelLaunch;
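/// Launch-time metadata for a kernel output: pushes the pointer, shape,
/// stride, and partition arguments, and reports the grid, dtype, and
/// specialization bits used to select a kernel variant.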
pub trait KernelOutputStored<T: DType>: Send {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch);
fn grid(&self) -> Result<(u32, u32, u32), Error>;
fn dtype_str(&self) -> &'static str;
fn partition_shape_as_i32(&self) -> Vec<i32>;
fn strides_hint(&self) -> Vec<i32>;
fn spec(&self) -> &SpecializationBits;
fn shape_as_i32(&self) -> Vec<i32>;
}
pub trait KernelOutput<T: DType>: Send + Sized {
type Stored: KernelOutputStored<T>;
type Returned: Send;
fn prepare(self) -> Self::Stored;
fn recover(stored: Self::Stored) -> Self::Returned;
}
impl<T: DType> KernelOutputStored<T> for Partition<Tensor<T>> {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch) {
unsafe {
launcher.push_device_ptr(self.object.cu_deviceptr());
}
for dim in self.object.shape.iter() {
launcher.push_arg(*dim);
}
for stride in self.object.strides.iter() {
launcher.push_arg(*stride);
}
for dim in self.partition_shape.iter() {
launcher.push_arg(*dim as i32);
}
for stride in self.partition_strides.iter() {
launcher.push_arg(*stride as i32);
}
}
fn grid(&self) -> Result<(u32, u32, u32), Error> {
let shape: Vec<u32> = self.shape_as_i32().iter().map(|&x| x as u32).collect();
let pshape: Vec<u32> = self
.partition_shape_as_i32()
.iter()
.map(|&x| x as u32)
.collect();
if pshape.len() != shape.len() {
return tensor_error_result("Partition rank must match tensor rank.");
}
match shape.len() {
1 => Ok((u32::div_ceil(shape[0], pshape[0]), 1, 1)),
2 => Ok((
u32::div_ceil(shape[0], pshape[0]),
u32::div_ceil(shape[1], pshape[1]),
1,
)),
3 => Ok((
u32::div_ceil(shape[0], pshape[0]),
u32::div_ceil(shape[1], pshape[1]),
u32::div_ceil(shape[2], pshape[2]),
)),
_ => tensor_error_result("Partitioned tensor must be at most rank 3."),
}
}
fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
fn partition_shape_as_i32(&self) -> Vec<i32> {
self.partition_shape.iter().map(|&x| x as i32).collect()
}
fn strides_hint(&self) -> Vec<i32> {
self.object
.spec
.stride_one
.iter()
.map(|&is_one| if is_one { 1 } else { -1 })
.collect()
}
fn spec(&self) -> &SpecializationBits {
&self.object.spec
}
fn shape_as_i32(&self) -> Vec<i32> {
self.object.shape.clone()
}
}
impl<'a, T: DType> KernelOutputStored<T> for Partition<&'a mut Tensor<T>> {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch) {
unsafe {
launcher.push_device_ptr(self.object.cu_deviceptr());
}
for dim in self.object.shape.iter() {
launcher.push_arg(*dim);
}
for stride in self.object.strides.iter() {
launcher.push_arg(*stride);
}
for dim in self.partition_shape.iter() {
launcher.push_arg(*dim as i32);
}
for stride in self.partition_strides.iter() {
launcher.push_arg(*stride as i32);
}
}
fn grid(&self) -> Result<(u32, u32, u32), Error> {
let shape: Vec<u32> = self.shape_as_i32().iter().map(|&x| x as u32).collect();
let pshape: Vec<u32> = self
.partition_shape_as_i32()
.iter()
.map(|&x| x as u32)
.collect();
if pshape.len() != shape.len() {
return tensor_error_result("Partition rank must match tensor rank.");
}
match shape.len() {
1 => Ok((u32::div_ceil(shape[0], pshape[0]), 1, 1)),
2 => Ok((
u32::div_ceil(shape[0], pshape[0]),
u32::div_ceil(shape[1], pshape[1]),
1,
)),
3 => Ok((
u32::div_ceil(shape[0], pshape[0]),
u32::div_ceil(shape[1], pshape[1]),
u32::div_ceil(shape[2], pshape[2]),
)),
_ => tensor_error_result("Mutable tensor must be at most rank 3."),
}
}
fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
fn partition_shape_as_i32(&self) -> Vec<i32> {
self.partition_shape.iter().map(|&x| x as i32).collect()
}
fn strides_hint(&self) -> Vec<i32> {
self.object
.spec
.stride_one
.iter()
.map(|&is_one| if is_one { 1 } else { -1 })
.collect()
}
fn spec(&self) -> &SpecializationBits {
&self.object.spec
}
fn shape_as_i32(&self) -> Vec<i32> {
self.object.shape.clone()
}
}
impl<T: DType> KernelOutput<T> for Partition<Tensor<T>> {
type Stored = Partition<Tensor<T>>;
type Returned = Partition<Tensor<T>>;
fn prepare(self) -> Self::Stored {
self
}
fn recover(stored: Self::Stored) -> Self::Returned {
stored
}
}
impl<'a, T: DType> KernelOutput<T> for Partition<&'a mut Tensor<T>> {
type Stored = Partition<&'a mut Tensor<T>>;
type Returned = Partition<&'a mut Tensor<T>>;
fn prepare(self) -> Self::Stored {
self
}
fn recover(stored: Self::Stored) -> Self::Returned {
stored
}
}
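/// Launch-time metadata for a kernel input: pushes the pointer, shape, and
/// stride arguments and exposes the metadata used for kernel specialization.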
pub trait KernelInputStored: Send {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch);
fn shape(&self) -> &[i32];
fn strides(&self) -> &[i32];
fn spec(&self) -> &SpecializationBits;
fn dtype_str(&self) -> &'static str;
}
pub trait KernelInput<T: DType>: Send + Sized {
type Stored: KernelInputStored;
type Returned: Send;
fn prepare(self) -> Self::Stored;
fn recover(stored: Self::Stored) -> Self::Returned;
}
impl<T: DType> KernelInputStored for Arc<Tensor<T>> {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch) {
unsafe {
launcher.push_device_ptr(self.cu_deviceptr());
}
for dim in self.shape.iter() {
launcher.push_arg(*dim);
}
for stride in self.strides.iter() {
launcher.push_arg(*stride);
}
}
fn shape(&self) -> &[i32] {
Tensor::shape(self)
}
fn strides(&self) -> &[i32] {
Tensor::strides(self)
}
fn spec(&self) -> &SpecializationBits {
Tensor::spec(self)
}
fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
}
impl<'a, T: DType + Sync> KernelInputStored for &'a Tensor<T> {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch) {
unsafe {
launcher.push_device_ptr(self.cu_deviceptr());
}
for dim in self.shape.iter() {
launcher.push_arg(*dim);
}
for stride in self.strides.iter() {
launcher.push_arg(*stride);
}
}
fn shape(&self) -> &[i32] {
Tensor::shape(self)
}
fn strides(&self) -> &[i32] {
Tensor::strides(self)
}
fn spec(&self) -> &SpecializationBits {
Tensor::spec(self)
}
fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
}
impl<T: DType> KernelInput<T> for Tensor<T> {
type Stored = Arc<Tensor<T>>;
type Returned = Tensor<T>;
fn prepare(self) -> Arc<Tensor<T>> {
Arc::new(self)
}
fn recover(stored: Arc<Tensor<T>>) -> Tensor<T> {
Arc::try_unwrap(stored).expect("KernelInput::recover: Arc has multiple owners")
}
}
impl<T: DType> KernelInput<T> for Arc<Tensor<T>> {
type Stored = Arc<Tensor<T>>;
type Returned = Arc<Tensor<T>>;
fn prepare(self) -> Arc<Tensor<T>> {
self
}
fn recover(stored: Arc<Tensor<T>>) -> Arc<Tensor<T>> {
stored
}
}
impl<'a, T: DType + Sync> KernelInput<T> for &'a Tensor<T> {
type Stored = &'a Tensor<T>;
type Returned = &'a Tensor<T>;
fn prepare(self) -> &'a Tensor<T> {
self
}
fn recover(stored: &'a Tensor<T>) -> &'a Tensor<T> {
stored
}
}
impl<'a, T: DType + Sync> KernelInputStored for &'a TensorView<'a, T> {
fn push_kernel_args(&self, launcher: &mut AsyncKernelLaunch) {
unsafe {
launcher.push_device_ptr(self.base.cu_deviceptr() + self.offset_bytes as u64);
}
for dim in self.shape.iter() {
launcher.push_arg(*dim);
}
for stride in self.strides.iter() {
launcher.push_arg(*stride);
}
}
fn shape(&self) -> &[i32] {
&self.shape
}
fn strides(&self) -> &[i32] {
&self.strides
}
fn spec(&self) -> &SpecializationBits {
TensorView::spec(self)
}
fn dtype_str(&self) -> &'static str {
T::DTYPE.as_str()
}
}
impl<'a, T: DType + Sync> KernelInput<T> for &'a TensorView<'a, T> {
type Stored = &'a TensorView<'a, T>;
type Returned = &'a TensorView<'a, T>;
fn prepare(self) -> Self::Stored {
self
}
fn recover(stored: Self::Stored) -> Self::Returned {
stored
}
}
impl<'a, T: DType + Sync> IntoDeviceOp<&'a TensorView<'a, T>> for &'a TensorView<'a, T> {
type Op = Value<&'a TensorView<'a, T>>;
fn into_op(self) -> Value<&'a TensorView<'a, T>> {
value(self)
}
}