use crate::kernels::conversion::convert_apply;
use crate::kernels::creation::{arange_apply, eye_apply, full_apply, linspace as linspace_kernel};
use crate::tensor::{IntoPartition, Reshape, Tensor, Unpartition};
use cuda_async::device_buffer::DeviceBuffer;
use cuda_async::device_context::with_default_device_policy;
use cuda_async::device_future::DeviceFuture;
use cuda_async::device_operation::{value, DeviceOp, ExecutionContext, Unzippable1, Unzippable2};
use cuda_async::error::DeviceError;
use cuda_core::curand::{RandNormal, RandUniform, RNG};
use cuda_core::sys::CUdeviceptr;
use cuda_core::DType;
use cuda_core::{memcpy_dtod_async, memcpy_dtoh_async, memcpy_htod_async};
use half::f16;
use std::alloc::{alloc, Layout};
use std::future::IntoFuture;
use std::sync::Arc;
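/// DeviceOp that copies an existing tensor's elements into a freshly allocated
/// device buffer, producing an independent `Tensor<T>` with the same shape and
/// strides. Built by [`dup`].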
pub struct CopyDeviceToDevice<T: DType> {
_storage: Arc<DeviceBuffer>,
src_ptr: CUdeviceptr,
shape: Vec<i32>,
strides: Vec<i32>,
num_elements: usize,
_dtype: std::marker::PhantomData<T>,
}
impl<T: DType> DeviceOp for CopyDeviceToDevice<T> {
type Output = Tensor<T>;
unsafe fn execute(
self,
ctx: &ExecutionContext,
) -> Result<<Self as DeviceOp>::Output, DeviceError> {
// Allocate a destination buffer through the execution context and enqueue an
// asynchronous device-to-device copy of the source elements.
let num_bytes = self.num_elements * std::mem::size_of::<T>();
let dst = ctx.alloc_async(num_bytes);
memcpy_dtod_async::<T>(dst, self.src_ptr, self.num_elements, ctx.get_cuda_stream());
Ok(Tensor::from_raw_parts(
dst,
num_bytes,
ctx.get_device_id(),
self.shape,
self.strides,
))
}
}
impl<T: DType> IntoFuture for CopyDeviceToDevice<T> {
type Output = Result<Tensor<T>, DeviceError>;
type IntoFuture = DeviceFuture<Tensor<T>, CopyDeviceToDevice<T>>;
fn into_future(self) -> Self::IntoFuture {
// Schedule the op on the next stream from the default device policy; any
// policy or stream error is surfaced as an already-failed future below.
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
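/// Returns a DeviceOp that duplicates `tensor` into a new device allocation,
/// preserving its shape and strides. The op holds an `Arc` to the source
/// storage so the underlying buffer stays alive while the copy is scheduled.
///
/// A minimal usage sketch (names are illustrative; the op is driven through
/// the crate's `DeviceFuture` machinery):
///
/// ```ignore
/// let copy_op = dup(&t); // t: Tensor<f32>
/// ```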
pub fn dup<T: DType>(tensor: &Tensor<T>) -> impl DeviceOp<Output = Tensor<T>> {
CopyDeviceToDevice {
_storage: tensor.storage.clone(),
src_ptr: tensor.cu_deviceptr(),
shape: tensor.shape.clone(),
strides: tensor.strides.clone(),
num_elements: tensor.size(),
_dtype: std::marker::PhantomData,
}
}
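/// Returns a [`Memcpy`] op that copies the contents of `src` into `dst`
/// (device-to-device, byte-for-byte). Panics if the two tensors do not hold
/// the same number of elements.
///
/// Sketch (assuming an async context; `Memcpy` implements `IntoFuture`):
///
/// ```ignore
/// memcpy(&mut dst, &src).await?;
/// ```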
pub fn memcpy<T: DType>(dst: &mut Tensor<T>, src: &Tensor<T>) -> Memcpy {
assert_eq!(
src.size(),
dst.size(),
"memcpy: src length ({}) != dst length ({})",
src.size(),
dst.size(),
);
Memcpy {
src_ptr: src.cu_deviceptr(),
dst_ptr: dst.cu_deviceptr(),
len: dst.num_bytes(),
}
}
pub struct Memcpy {
src_ptr: cuda_core::sys::CUdeviceptr,
dst_ptr: cuda_core::sys::CUdeviceptr,
len: usize,
}
impl DeviceOp for Memcpy {
type Output = ();
unsafe fn execute(self, ctx: &ExecutionContext) -> Result<(), DeviceError> {
memcpy_dtod_async::<u8>(self.dst_ptr, self.src_ptr, self.len, ctx.get_cuda_stream());
Ok(())
}
}
impl cuda_async::device_operation::GraphNode for Memcpy {}
impl IntoFuture for Memcpy {
type Output = Result<(), DeviceError>;
type IntoFuture = DeviceFuture<(), Memcpy>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
struct CopyDeviceToHostVec<T: DType> {
tensor: Arc<Tensor<T>>,
}
impl<T: DType> DeviceOp for CopyDeviceToHostVec<T> {
type Output = Vec<T>;
unsafe fn execute(
self,
ctx: &ExecutionContext,
) -> Result<<Self as DeviceOp>::Output, DeviceError> {
let cu_deviceptr = self.tensor.cu_deviceptr();
let size = self.tensor.size();
// Allocate pageable host memory, enqueue an asynchronous device-to-host copy,
// and hand ownership of the buffer to the returned Vec.
let layout = Layout::array::<T>(size).expect("host allocation size overflows Layout");
let async_ptr = unsafe { alloc(layout).cast::<T>() };
memcpy_dtoh_async(async_ptr, cu_deviceptr, size, ctx.get_cuda_stream());
Ok(unsafe { Vec::from_raw_parts(async_ptr, size, size) })
}
}
impl<T: DType> IntoFuture for CopyDeviceToHostVec<T> {
type Output = Result<Vec<T>, DeviceError>;
type IntoFuture = DeviceFuture<Vec<T>, CopyDeviceToHostVec<T>>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
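/// Returns a DeviceOp that downloads the tensor's elements into a host `Vec<T>`.
/// The copy is asynchronous; the `Vec` contents are valid once the op's stream
/// work has completed.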
pub fn copy_device_to_host_vec<T: DType>(
tensor: &Arc<Tensor<T>>,
) -> impl DeviceOp<Output = Vec<T>> {
CopyDeviceToHostVec {
tensor: tensor.clone(),
}
}
struct CopyHostVecToDevice<T: DType> {
vec: Arc<Vec<T>>,
}
impl<T: DType> DeviceOp for CopyHostVecToDevice<T> {
type Output = Tensor<T>;
unsafe fn execute(
self,
ctx: &ExecutionContext,
) -> Result<<Self as DeviceOp>::Output, DeviceError> {
let vec = self.vec;
let element_size = std::mem::size_of::<T>();
let num_elements = vec.len();
// The uploaded tensor is 1-D and contiguous: shape [len], stride 1.
let shape = vec![num_elements as i32];
let strides = vec![1];
let dptr = ctx.alloc_async(element_size * num_elements);
memcpy_htod_async(dptr, vec.as_ptr(), num_elements, ctx.get_cuda_stream());
Ok(Tensor::from_raw_parts(
dptr,
element_size * num_elements,
ctx.get_device_id(),
shape,
strides,
))
}
}
impl<T: DType> IntoFuture for CopyHostVecToDevice<T> {
type Output = Result<Tensor<T>, DeviceError>;
type IntoFuture = DeviceFuture<Tensor<T>, CopyHostVecToDevice<T>>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
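/// Returns a DeviceOp that uploads a host `Vec<T>` to the device as a 1-D,
/// contiguous tensor of length `vec.len()`. The `Arc` keeps the host buffer
/// alive while the upload is scheduled.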
pub fn copy_host_vec_to_device<T: DType>(vec: &Arc<Vec<T>>) -> impl DeviceOp<Output = Tensor<T>> {
CopyHostVecToDevice { vec: vec.clone() }
}
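/// Creates a tensor of the given shape filled with `T::zero()`.
///
/// Sketch (type and shape are illustrative):
///
/// ```ignore
/// let z = zeros::<f32>(&[2, 3]); // DeviceOp<Output = Tensor<f32>>
/// ```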
pub fn zeros<T: DType>(shape: &[usize]) -> impl DeviceOp<Output = Tensor<T>> {
full(T::zero(), shape)
}
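/// Creates a tensor of the given shape filled with `T::one()`.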
pub fn ones<T: DType>(shape: &[usize]) -> impl DeviceOp<Output = Tensor<T>> {
full(T::one(), shape)
}
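/// Creates a tensor of the given shape with every element set to `val`.
/// The buffer is allocated uninitialized and written by the fill kernel.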
pub fn full<T: DType>(val: T, shape: &[usize]) -> impl DeviceOp<Output = Tensor<T>> {
let shape = shape.to_vec();
let len = shape.iter().product::<usize>();
Tensor::<T>::uninitialized(len).then(move |t| {
// Partition the flat buffer into fixed-size chunks, run the fill kernel on
// each chunk, then stitch the pieces back together and apply the shape.
let partition_size = 128;
let result = unsafe { t.assume_init() }.partition([partition_size]);
let (_, res) = value((val, result)).then(full_apply).unzip();
res.unpartition().reshape(&shape)
})
}
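/// Overwrites every element of an existing tensor with `val`, yielding the
/// same tensor from the op.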
pub fn fill<T: DType>(tensor: Tensor<T>, val: T) -> impl DeviceOp<Output = Tensor<T>> {
value(tensor).then(move |t| {
let partition_size = 128;
let result = t.partition([partition_size]);
let (_, res) = value((val, result)).then(full_apply).unzip();
res.unpartition()
})
}
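/// Creates a 1-D tensor of length `len` filled by the `arange_apply` kernel,
/// i.e. an ascending index sequence (conventionally 0, 1, ..., len - 1).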
pub fn arange<T: DType>(len: usize) -> impl DeviceOp<Output = Tensor<T>> {
Tensor::<T>::uninitialized(len).then(move |t| {
let partition_size = 128;
let result = unsafe { t.assume_init() }.partition([partition_size]);
let res = value((result,)).then(arange_apply).unzip();
res.0.unpartition()
})
}
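/// Creates a 1-D `f32` tensor of `n` evenly spaced values from `start` towards
/// `stop`, using step `(stop - start) / (n - 1)` (0.0 when `n <= 1`).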
pub fn linspace(start: f32, stop: f32, n: usize) -> impl DeviceOp<Output = Tensor<f32>> {
let step = if n > 1 {
(stop - start) / (n - 1) as f32
} else {
0.0
};
Tensor::<f32>::uninitialized(n).then(move |t| {
let partition_size = 128;
let result = unsafe { t.assume_init() }.partition([partition_size]);
linspace_kernel(result, start, step)
.then(|(tensor, _, _)| value(tensor))
.unpartition()
})
}
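/// Creates an `n` x `n` `f32` identity matrix. Shorthand for [`eye_rect`].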
pub fn eye(n: usize) -> impl DeviceOp<Output = Tensor<f32>> {
eye_rect(n, n)
}
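/// Creates a `rows` x `cols` `f32` identity matrix (ones on the main diagonal,
/// zeros elsewhere). The buffer is processed in 16 x 16 tiles by `eye_apply`.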
pub fn eye_rect(rows: usize, cols: usize) -> impl DeviceOp<Output = Tensor<f32>> {
let len = rows * cols;
let br = 16;
let bc = 16;
Tensor::<f32>::uninitialized(len).then(move |t| {
let t2d = unsafe { t.assume_init() }
.reshape(&[rows, cols])
.expect("eye: reshape failed");
let result = t2d.partition([br, bc]);
let res = value((result,)).then(eye_apply).unzip();
res.0.unpartition()
})
}
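/// Converts a tensor element-wise from `FromType` to `ToType` on the device,
/// preserving the source shape.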
pub fn convert<FromType: DType, ToType: DType>(
src: Arc<Tensor<FromType>>,
) -> impl DeviceOp<Output = Tensor<ToType>> {
let len = src.shape.iter().product::<i32>() as usize;
Tensor::<ToType>::uninitialized(len).then(move |t| {
let partition_size = 128;
let dst = unsafe { t.assume_init() }.partition([partition_size]);
let res = value((src.clone(), dst)).then(convert_apply).unzip();
res.1
.unpartition()
.reshape(&src.shape.iter().map(|x| *x as usize).collect::<Vec<_>>())
})
}
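/// Creates a tensor of the given shape whose elements are drawn from a normal
/// distribution with the given `mean` and `std` via cuRAND; the generator is
/// seeded with `seed` when one is provided.
///
/// Sketch (type, shape, and seed are illustrative):
///
/// ```ignore
/// let op = randn::<f32, 2>(0.0, 1.0, [2, 3], Some(42));
/// ```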
pub fn randn<T: DType + RandNormal, const RANK: usize>(
mean: T,
std: T,
shape: [usize; RANK],
seed: Option<u64>,
) -> impl DeviceOp<Output = Tensor<T>> {
let len = shape.iter().product::<usize>();
Tensor::<T>::uninitialized(len).then(move |t| unsafe {
let t = t.assume_init();
let rng = RNG::new(seed);
T::generate_normal(&rng, t.cu_deviceptr(), len, mean, std);
value(t.reshape_unchecked(&shape))
})
}
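/// `f16` variant of [`randn`]: the values are generated in `f32` and converted
/// to `f16` on the device before the result is reshaped.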
pub fn randn_f16<const RANK: usize>(
mean: f16,
std: f16,
shape: [usize; RANK],
seed: Option<u64>,
) -> impl DeviceOp<Output = Tensor<f16>> {
let len = shape.iter().product::<usize>();
randn(mean.to_f32(), std.to_f32(), [len], seed).then(move |src_tensor| {
let dst = Tensor::<f16>::uninitialized(len);
dst.then(move |dst_tensor| {
let partition_size = 128;
let dst = unsafe { dst_tensor.assume_init() }.partition([partition_size]);
let res = value((Arc::new(src_tensor), dst))
.then(convert_apply)
.unzip();
res.1.unpartition().reshape(&shape.to_vec())
})
})
}
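/// Creates a tensor of the given shape whose elements are drawn from a uniform
/// distribution via cuRAND; the generator is seeded with `seed` when provided.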
pub fn rand<T: DType + RandUniform, const RANK: usize>(
shape: [usize; RANK],
seed: Option<u64>,
) -> impl DeviceOp<Output = Tensor<T>> {
let len = shape.iter().product::<usize>();
Tensor::<T>::uninitialized(len).then(move |t| unsafe {
let t = t.assume_init();
let rng = RNG::new(seed);
T::generate_uniform(&rng, t.cu_deviceptr(), len);
value(t.reshape_unchecked(&shape))
})
}
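/// DeviceOp adapter that executes an inner op and reshapes its output.
/// The `Tensor<T>` variant uses `reshape_unchecked`, so the element count of
/// `shape` is not validated here; the `Arc<Tensor<T>>` variant goes through
/// `reshape_shared` and reports failures as `DeviceError::Internal`.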
pub struct ReshapeOp<O: Send, DI: DeviceOp<Output = O>> {
shape: Vec<usize>,
input: DI,
}
impl<T: DType, DI: DeviceOp<Output = Tensor<T>>> DeviceOp for ReshapeOp<Tensor<T>, DI> {
type Output = Tensor<T>;
unsafe fn execute(self, context: &ExecutionContext) -> Result<Tensor<T>, DeviceError> {
let tensor = self.input.execute(context)?;
Ok(tensor.reshape_unchecked(&self.shape))
}
}
impl<T: DType, DI: DeviceOp<Output = Tensor<T>>> IntoFuture for ReshapeOp<Tensor<T>, DI> {
type Output = Result<Tensor<T>, DeviceError>;
type IntoFuture = DeviceFuture<Tensor<T>, Self>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
impl<T: DType + Send, DI: DeviceOp<Output = Arc<Tensor<T>>>> DeviceOp
for ReshapeOp<Arc<Tensor<T>>, DI>
{
type Output = Arc<Tensor<T>>;
unsafe fn execute(self, context: &ExecutionContext) -> Result<Arc<Tensor<T>>, DeviceError> {
let arc_tensor = self.input.execute(context)?;
arc_tensor
.reshape_shared(&self.shape)
.map_err(|e| DeviceError::Internal(e.to_string()))
}
}
impl<T: DType + Send, DI: DeviceOp<Output = Arc<Tensor<T>>>> IntoFuture
for ReshapeOp<Arc<Tensor<T>>, DI>
{
type Output = Result<Arc<Tensor<T>>, DeviceError>;
type IntoFuture = DeviceFuture<Arc<Tensor<T>>, Self>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| {
let stream = policy.next_stream()?;
Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream)))
}) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
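/// Extension trait adding a `.reshape(..)` combinator to any DeviceOp that
/// yields a `Tensor<T>`.
///
/// Sketch (op and shape are illustrative):
///
/// ```ignore
/// let op = arange::<f32>(12).reshape(&[3, 4]);
/// ```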
pub trait DeviceOpReshape<T: DType>: DeviceOp<Output = Tensor<T>> + Sized {
fn reshape(self, shape: &[usize]) -> ReshapeOp<Tensor<T>, Self> {
ReshapeOp {
shape: shape.to_vec(),
input: self,
}
}
}
impl<T: DType, DI: DeviceOp<Output = Tensor<T>>> DeviceOpReshape<T> for DI {}
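/// As [`DeviceOpReshape`], but for ops that yield an `Arc<Tensor<T>>`; the
/// reshape is performed through `Tensor::reshape_shared`.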
pub trait DeviceOpReshapeShared<T: DType + Send>:
DeviceOp<Output = Arc<Tensor<T>>> + Sized
{
fn reshape(self, shape: &[usize]) -> ReshapeOp<Arc<Tensor<T>>, Self> {
ReshapeOp {
shape: shape.to_vec(),
input: self,
}
}
}
impl<T: DType + Send, DI: DeviceOp<Output = Arc<Tensor<T>>>> DeviceOpReshapeShared<T> for DI {}