use crate::api::{
copy, copy_device_to_host_vec, copy_host_vec_to_device, copy_to_device, copy_to_host,
};
use crate::error::{tensor_error_result, Error};
use crate::tile_kernel::UnwrapPartition;
use anyhow::Result;
use candle_core::{DType, WithDType};
use cuda_async::device_box::{DeviceBox, DevicePointer};
use cuda_async::device_operation;
use cuda_async::device_operation::{value, DeviceOperation};
use cuda_async::error::DeviceError;
use cuda_core::sys::CUdeviceptr;
use cuda_core::{malloc_async, CudaStream};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Index;
use std::sync::Arc;
/// A value (typically a [`Tensor`]) paired with the shape and strides of the
/// tile used to partition it for kernel launches.
pub struct Partition<T> {
    pub(crate) object: T,
    /// Extent of a single partition tile along each dimension.
    pub partition_shape: Vec<i32>,
    /// Strides recorded at partition time (copied from the underlying
    /// tensor's strides — see the `IntoPartition` impls in this file).
    pub partition_strides: Vec<i32>,
}
impl<T> Partition<T> {
    /// Consumes the partition and returns the wrapped object, discarding the
    /// partition shape/stride metadata.
    pub fn unpartition(self) -> T {
        self.object
    }
}
impl<T: WithDType> Partition<Tensor<T>> {
    /// Total size of the underlying tensor data in bytes.
    pub fn num_bytes(&self) -> usize {
        self.object.size() * size_of::<T>()
    }
    /// Size in decimal megabytes (10^6 bytes), truncated.
    pub fn num_mb(&self) -> usize {
        self.num_bytes() / 10usize.pow(6)
    }
    /// Size in decimal gigabytes (10^9 bytes), truncated.
    pub fn num_gb(&self) -> usize {
        self.num_bytes() / 10usize.pow(9)
    }
    /// The candle dtype corresponding to the element type `T`.
    pub fn dtype(&self) -> DType {
        T::DTYPE
    }
    /// Computes the CUDA launch grid: one block per partition tile along
    /// each dimension, rounded up so partial edge tiles are covered.
    ///
    /// # Errors
    /// Returns an error if any tensor-shape or partition-shape dimension is
    /// non-positive, if the partition shape has fewer dimensions than the
    /// tensor, or if the tensor rank exceeds 3.
    pub fn grid(&self) -> Result<(u32, u32, u32), Error> {
        let check_i32 = |x: &i32| *x > 0;
        if !self.object.shape.iter().all(check_i32) {
            return tensor_error_result("Shape dimensions must be positive.");
        }
        // Previously unchecked: a zero partition dimension divided by zero
        // and a short partition shape indexed out of bounds below.
        if !self.partition_shape.iter().all(check_i32) {
            return tensor_error_result("Partition shape dimensions must be positive.");
        }
        if self.partition_shape.len() < self.object.shape.len() {
            return tensor_error_result("Partition shape has fewer dimensions than the tensor shape.");
        }
        let to_u32 = |x: &i32| *x as u32;
        let shape = self.object.shape.iter().map(to_u32).collect::<Vec<u32>>();
        let partition_shape = self
            .partition_shape
            .iter()
            .map(to_u32)
            .collect::<Vec<u32>>();
        let rank = shape.len();
        match rank {
            1 => Ok((u32::div_ceil(shape[0], partition_shape[0]), 1, 1)),
            2 => Ok((
                u32::div_ceil(shape[0], partition_shape[0]),
                u32::div_ceil(shape[1], partition_shape[1]),
                1,
            )),
            3 => Ok((
                u32::div_ceil(shape[0], partition_shape[0]),
                u32::div_ceil(shape[1], partition_shape[1]),
                u32::div_ceil(shape[2], partition_shape[2]),
            )),
            _ => tensor_error_result("Mutable tensor must be at most rank 3."),
        }
    }
}
// Implemented as `Into` rather than the usually preferred `From`: a blanket
// `impl<T> From<Partition<T>> for Arc<T>` would violate the orphan rules
// (`T` is an uncovered type parameter of the foreign type `Arc`).
impl<T> Into<Arc<T>> for Partition<T> {
    /// Unwraps the partition and places the inner object behind an `Arc`.
    fn into(self) -> Arc<T> {
        Arc::new(self.unpartition())
    }
}
/// Wraps an owned value into a [`Partition`] with the given tile shape.
pub trait IntoPartition {
    fn partition<const RANK: usize>(self, partition_shape: [i32; RANK]) -> Partition<Self>
    where
        Self: Sized;
}
/// Like [`IntoPartition`], but takes an `Arc`-wrapped receiver; impls may
/// require unique ownership of the `Arc` (see the `Tensor` impl below).
pub trait IntoPartitionArc {
    fn partition<const RANK: usize>(
        self: Arc<Self>,
        partition_shape: [i32; RANK],
    ) -> Partition<Self>
    where
        Self: Sized;
}
/// A dense tensor whose element storage lives in device memory.
#[derive(Debug)]
pub struct Tensor<T: WithDType> {
    /// Owning handle to the device allocation backing the elements.
    pub device_box: DeviceBox<[T]>,
    /// Extent along each dimension.
    pub shape: Vec<i32>,
    /// Element strides per dimension; the reshape methods in this file
    /// always produce row-major (C-contiguous) strides.
    pub strides: Vec<i32>,
}
impl<T: WithDType> Tensor<T> {
    /// Allocates an uninitialized rank-1 tensor of `len` elements on the
    /// operation context's stream.
    ///
    /// The result is wrapped in `MaybeUninit` because the device memory
    /// contents are undefined until something writes to them.
    ///
    /// # Panics
    /// Panics if `len == 0`.
    pub fn uninitialized(len: usize) -> impl DeviceOperation<Output = MaybeUninit<Self>> {
        assert!(len > 0, "Non-zero length required.");
        device_operation::with_context(move |ctx| {
            let num_bytes = len * size_of::<T>();
            // SAFETY: the allocation from `malloc_async` spans
            // `num_bytes = len * size_of::<T>()` bytes on this context's
            // device, matching the length passed to `from_raw_parts`.
            value(MaybeUninit::new(unsafe {
                Self {
                    device_box: DeviceBox::from_raw_parts(
                        malloc_async(num_bytes, ctx.get_cuda_stream()),
                        len,
                        ctx.get_device_id(),
                    ),
                    shape: vec![len as i32],
                    strides: vec![1],
                }
            }))
        })
    }
    /// The candle dtype corresponding to the element type `T`.
    pub fn dtype(&self) -> DType {
        T::DTYPE
    }
    /// Raw CUDA device pointer of the underlying allocation.
    pub fn cu_deviceptr(&self) -> CUdeviceptr {
        self.device_box.cu_deviceptr()
    }
    /// Typed device pointer of the underlying allocation.
    pub fn device_pointer(&self) -> DevicePointer<T> {
        self.device_box.device_pointer()
    }
    /// Number of elements in the underlying allocation.
    pub fn size(&self) -> usize {
        self.device_box.len()
    }
    /// Schedules a copy of this tensor as a device operation.
    pub fn copy(self: &Arc<Self>) -> impl DeviceOperation<Output = Self> {
        copy(self)
    }
    /// Copies this tensor, synchronizing on `stream` before returning.
    pub fn copy_sync(self: &Arc<Self>, stream: &Arc<CudaStream>) -> Result<Self, DeviceError> {
        copy(self).sync_on(stream)
    }
    /// Total size of the tensor data in bytes.
    pub fn num_bytes(self: &Arc<Self>) -> usize {
        self.size() * size_of::<T>()
    }
    /// Size in decimal megabytes (10^6 bytes), truncated.
    pub fn num_mb(self: &Arc<Self>) -> usize {
        self.num_bytes() / 10usize.pow(6)
    }
    /// Size in decimal gigabytes (10^9 bytes), truncated.
    pub fn num_gb(self: &Arc<Self>) -> usize {
        self.num_bytes() / 10usize.pow(9)
    }
    /// Reshapes to a statically-known rank (1..=4), recomputing row-major
    /// strides.
    ///
    /// # Panics
    /// Panics if the element counts differ or `RANK > 4`.
    pub fn reshape<const RANK: usize>(mut self, shape: [usize; RANK]) -> Self {
        let shape = shape.iter().map(|x| *x as i32).collect::<Vec<_>>();
        assert_eq!(
            shape.iter().product::<i32>(),
            self.shape.iter().product::<i32>()
        );
        match RANK {
            1 => self.strides = vec![1],
            2 => self.strides = vec![shape[1], 1],
            3 => self.strides = vec![shape[1] * shape[2], shape[2], 1],
            4 => {
                self.strides = vec![
                    shape[1] * shape[2] * shape[3],
                    shape[2] * shape[3],
                    shape[3],
                    1,
                ]
            }
            _ => unimplemented!("Static reshape of rank {}", RANK),
        }
        // Move the already-owned Vec instead of the previous redundant
        // `shape.to_vec()` clone.
        self.shape = shape;
        self
    }
    /// Reshapes to an arbitrary rank, recomputing row-major strides.
    ///
    /// # Panics
    /// Panics if the element counts differ.
    pub fn reshape_dyn(mut self, shape: &[usize]) -> Self {
        let shape = shape.iter().map(|x| *x as i32).collect::<Vec<_>>();
        assert_eq!(
            shape.iter().product::<i32>(),
            self.shape.iter().product::<i32>()
        );
        // Build strides back-to-front in O(n); the previous
        // `insert(0, _)` loop was accidentally O(n^2).
        let mut strides = Vec::with_capacity(shape.len());
        let mut stride = 1;
        for dim in shape.iter().rev() {
            strides.push(stride);
            stride *= *dim;
        }
        strides.reverse();
        self.strides = strides;
        self.shape = shape;
        self
    }
}
/// Copies device-resident tensor data into a host `Vec<T>`.
pub trait ToHostVec<T: Send> {
    fn to_host_vec(self) -> impl DeviceOperation<Output = Vec<T>>;
}
impl<T: WithDType> ToHostVec<T> for Tensor<T> {
    /// Moves the tensor behind a temporary `Arc` and schedules a
    /// device-to-host copy of its contents.
    fn to_host_vec(self) -> impl DeviceOperation<Output = Vec<T>> {
        copy_device_to_host_vec(&Arc::new(self))
    }
}
impl<T: WithDType> ToHostVec<T> for Arc<Tensor<T>> {
    /// Schedules a device-to-host copy of the shared tensor's contents.
    fn to_host_vec(self) -> impl DeviceOperation<Output = Vec<T>> {
        copy_device_to_host_vec(&self)
    }
}
impl<T: WithDType> ToHostVec<T> for &Arc<Tensor<T>> {
    /// Borrowed-receiver convenience: schedules a device-to-host copy
    /// without consuming the caller's `Arc`.
    fn to_host_vec(self) -> impl DeviceOperation<Output = Vec<T>> {
        copy_device_to_host_vec(self)
    }
}
impl<T: WithDType + Debug> IntoPartitionArc for Tensor<T> {
    /// Reclaims unique ownership of the tensor and wraps it in a
    /// [`Partition`] with the requested tile shape, inheriting the
    /// tensor's current strides.
    ///
    /// # Panics
    /// Panics if other `Arc` clones of the tensor are still alive.
    fn partition<const RANK: usize>(
        self: Arc<Tensor<T>>,
        partition_shape: [i32; RANK],
    ) -> Partition<Tensor<T>> {
        let tensor = Arc::try_unwrap(self).expect("Failed to convert Arc to Partition.");
        let partition_strides = tensor.strides.clone();
        Partition::<Tensor<T>> {
            object: tensor,
            partition_shape: partition_shape.to_vec(),
            partition_strides,
        }
    }
}
impl<T: WithDType> IntoPartition for Tensor<T> {
    /// Wraps the tensor in a [`Partition`], recording the requested tile
    /// shape and inheriting the tensor's current strides.
    fn partition<const RANK: usize>(self, partition_shape: [i32; RANK]) -> Partition<Tensor<T>> {
        let partition_strides = self.strides.clone();
        Partition::<Tensor<T>> {
            object: self,
            partition_shape: partition_shape.to_vec(),
            partition_strides,
        }
    }
}
/// Copies a host-side value to the device, yielding a shared device tensor.
pub trait CopyToDevice {
    fn copy_to_device<T: WithDType>(
        self: &Arc<Self>,
    ) -> impl DeviceOperation<Output = Arc<Tensor<T>>>;
}
/// Copies device-resident tensor data back into a host-side candle tensor.
pub trait CopyToHost {
    fn copy_to_host(self) -> impl DeviceOperation<Output = candle_core::Tensor>;
}
impl CopyToDevice for candle_core::Tensor {
    /// Schedules a host-to-device copy of the candle tensor and wraps the
    /// resulting device tensor in an `Arc`.
    fn copy_to_device<T: WithDType>(
        self: &Arc<Self>,
    ) -> impl DeviceOperation<Output = Arc<Tensor<T>>> {
        copy_to_device(self).arc()
    }
}
/// Copies a host-side container of `T` into a (non-`Arc`) device tensor.
pub trait CopyToDeviceTensor<T: WithDType> {
    fn copy_to_device_tensor(self: &Arc<Self>) -> impl DeviceOperation<Output = Tensor<T>>;
}
impl<T: WithDType> CopyToDeviceTensor<T> for Vec<T> {
    /// Schedules a host-to-device copy of the vector's elements.
    fn copy_to_device_tensor(self: &Arc<Self>) -> impl DeviceOperation<Output = Tensor<T>> {
        copy_host_vec_to_device(self)
    }
}
impl<T: WithDType> CopyToHost for &Arc<Tensor<T>> {
    /// Schedules a device-to-host copy without consuming the caller's `Arc`.
    fn copy_to_host(self) -> impl DeviceOperation<Output = candle_core::Tensor> {
        copy_to_host(self)
    }
}
impl<T: WithDType> CopyToHost for Tensor<T> {
    /// Moves the tensor behind a temporary `Arc` and schedules a
    /// device-to-host copy.
    fn copy_to_host(self) -> impl DeviceOperation<Output = candle_core::Tensor> {
        let owned = Arc::new(self);
        copy_to_host(&owned)
    }
}
/// Adapter for device operations that yield a `Partition<Tensor<T>>`:
/// discards the partition metadata, yielding the bare tensor.
pub trait Unpartition<T: WithDType> {
    fn unpartition(self) -> impl DeviceOperation<Output = Tensor<T>>;
}
// Blanket impl: any device operation producing a partitioned tensor can be
// chained through `UnwrapPartition` to strip the partition metadata.
impl<T: WithDType, DI: DeviceOperation<Output = Partition<Tensor<T>>>> Unpartition<T> for DI {
    fn unpartition(self) -> impl DeviceOperation<Output = Tensor<T>> {
        UnwrapPartition { op: self }
    }
}
/// A host vector of `Arc`-shared items plus a device-resident table of raw
/// device addresses (stored as `i64`), built by `DeviceVec::from`.
#[derive(Clone, Debug)]
pub struct DeviceVec<T> {
    // NOTE(review): `PhantomData<T>` looks redundant — `host_vec` already
    // mentions `T` — but removing it would change struct construction.
    _ty: PhantomData<T>,
    // Keeps each item alive while its raw pointer sits in `device_vec`.
    host_vec: Vec<Arc<T>>,
    // Device-side array of the items' device pointers cast to `i64`.
    device_vec: Arc<Tensor<i64>>,
}
impl<T: WithDType> DeviceVec<Tensor<T>> {
    /// Builds a `DeviceVec` from a vector of device tensors.
    ///
    /// Each tensor's device pointer (cast to `i64`) is gathered into a
    /// device-side pointer table, while the tensors themselves are retained
    /// on the host behind `Arc`s so the raw pointers stay valid.
    ///
    /// # Panics
    /// Panics if the synchronous copy of the pointer table to the device
    /// fails.
    pub fn from(v: Vec<Tensor<T>>) -> DeviceVec<Tensor<T>> {
        let i64vec: Arc<Vec<i64>> = v
            .iter()
            .map(|x| x.cu_deviceptr() as i64)
            .collect::<Vec<_>>()
            .into();
        let device_vec: Arc<Tensor<i64>> = i64vec
            .copy_to_device_tensor()
            .sync()
            .expect("Failed to execute device operation.")
            .reshape([v.len()])
            .into();
        let host_vec: Vec<Arc<Tensor<T>>> = v.into_iter().map(Arc::new).collect::<Vec<_>>();
        DeviceVec {
            _ty: PhantomData,
            host_vec,
            device_vec,
        }
    }
    /// Number of tensors held.
    pub fn len(&self) -> usize {
        self.host_vec.len()
    }
    /// Returns `true` if no tensors are held.
    /// (Added for parity with `len`, per `clippy::len_without_is_empty`.)
    pub fn is_empty(&self) -> bool {
        self.host_vec.is_empty()
    }
    /// Direct access to the device-side pointer table.
    ///
    /// # Safety
    /// The returned tensor holds raw device addresses of the tensors in
    /// this collection; the caller must not use it after those tensors are
    /// dropped, nor treat its contents as ordinary `i64` data.
    pub unsafe fn inner(&self) -> &Arc<Tensor<i64>> {
        &self.device_vec
    }
}
impl<T: WithDType> From<Vec<Tensor<T>>> for DeviceVec<Tensor<T>> {
    /// Delegates to the inherent constructor `DeviceVec::from`.
    fn from(v: Vec<Tensor<T>>) -> Self {
        DeviceVec::from(v)
    }
}
impl<T: WithDType> Index<usize> for DeviceVec<Tensor<T>> {
    type Output = Arc<Tensor<T>>;
    /// Borrows the host-side handle at `index`; panics if out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        &self.host_vec[index]
    }
}
/// Consuming iterator over a [`DeviceVec`], yielding owned items.
pub struct DeviceVecIntoIter<Item> {
    items: DeviceVec<Item>,
}
impl<T: WithDType + Debug> Iterator for DeviceVecIntoIter<Tensor<T>> {
    type Item = Tensor<T>;
    /// Pops tensors from the front in order, unwrapping each `Arc`.
    ///
    /// NOTE(review): `remove(0)` is O(n) per element; acceptable for small
    /// collections but quadratic overall.
    fn next(&mut self) -> Option<Self::Item> {
        if self.items.host_vec.is_empty() {
            return None;
        }
        let front = self.items.host_vec.remove(0);
        Some(Arc::try_unwrap(front).expect("Unable to perform into_iter from non-unique Arc."))
    }
}
impl<T: WithDType + Debug> IntoIterator for DeviceVec<Tensor<T>> {
    type Item = Tensor<T>;
    type IntoIter = DeviceVecIntoIter<Tensor<T>>;
    /// Consumes the vector; iteration unwraps each `Arc` and panics if a
    /// tensor is still shared (e.g. via a clone obtained through `Index`).
    fn into_iter(self) -> Self::IntoIter {
        DeviceVecIntoIter { items: self }
    }
}