use crate::kernels::conversion::convert_apply;
use crate::kernels::creation::{arange_apply, full_apply};
use crate::tensor::{IntoPartition, Tensor, Unpartition};
use candle_core::{FloatDType, WithDType};
use cuda_async::device_box::DeviceBox;
use cuda_async::device_context::with_default_device_policy;
use cuda_async::device_future::DeviceFuture;
use cuda_async::device_operation::{
value, DeviceOperation, ExecutionContext, Unzippable1, Unzippable2,
};
use cuda_async::error::{device_error, DeviceError};
use cuda_async::scheduling_policies::SchedulingPolicy;
use half::f16;
use cuda_core::curand::RNG;
use cuda_core::{malloc_async, memcpy_dtod_async, memcpy_dtoh_async, memcpy_htod_async};
use std::alloc::{alloc, Layout};
use std::cmp::min;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::sync::Arc;
/// Device operation that deep-copies a device-resident tensor into a freshly
/// allocated buffer on the same device.
pub struct CopyDeviceToDevice<T: WithDType + Send> {
    // Source tensor, shared so the operation can be scheduled without moving it.
    tensor: Arc<Tensor<T>>,
}
impl<T: WithDType + Send> DeviceOperation for CopyDeviceToDevice<T> {
    type Output = Tensor<T>;

    /// Allocates device memory on the context's stream and performs an async
    /// device-to-device copy, producing an independent tensor with the same
    /// shape and strides as the source.
    ///
    /// # Safety
    /// Executed by the scheduler; the `DeviceOperation::execute` contract
    /// (valid stream / device context) must be upheld by the caller.
    unsafe fn execute(
        self,
        ctx: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let tensor = self.tensor;
        let num_elements = tensor.size();
        let num_bytes = std::mem::size_of::<T>() * num_elements;
        let src = tensor.cu_deviceptr();
        // Both calls are stream-ordered, so the copy sees the fresh allocation.
        let dst = malloc_async(num_bytes, ctx.get_cuda_stream());
        memcpy_dtod_async::<T>(dst, src, num_elements, ctx.get_cuda_stream());
        let device_box = DeviceBox::<[T]>::from_raw_parts(dst, num_elements, ctx.get_device_id());
        // Clone shape/strides exactly once (the original cloned them twice).
        Ok(Tensor {
            device_box,
            shape: tensor.shape.clone(),
            strides: tensor.strides.clone(),
        })
    }
}
impl<T: WithDType + Send> IntoFuture for CopyDeviceToDevice<T> {
    type Output = Result<Tensor<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Tensor<T>, CopyDeviceToDevice<T>>;

    /// Schedules the copy on the default device policy; any scheduling error
    /// is surfaced as an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| policy.schedule(self));
        match scheduled {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
pub fn copy<T: WithDType + Send>(
tensor: &Arc<Tensor<T>>,
) -> impl DeviceOperation<Output = Tensor<T>> {
CopyDeviceToDevice {
tensor: tensor.clone(),
}
}
/// Device operation that uploads a host-side candle tensor to the device,
/// producing a `Tensor<T>`.
pub struct CopyHostToDevice<T: WithDType + Send> {
    // Marker tying the operation to the element type `T` of the device tensor.
    dtype: PhantomData<T>,
    // Host tensor to upload; shared so scheduling does not move it.
    tensor: Arc<candle_core::Tensor>,
}
impl<T: WithDType + Send> DeviceOperation for CopyHostToDevice<T> {
    type Output = Tensor<T>;

    /// Flattens the host candle tensor, copies its contents to a device
    /// allocation on the context's stream, and wraps the result as a device
    /// `Tensor` with the original shape and strides.
    unsafe fn execute(
        self,
        ctx: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let (vec, shape, strides) = candle_tensor_to_vec::<T>(&self.tensor);
        let num_elements = vec.len();
        let num_bytes = std::mem::size_of::<T>() * num_elements;
        let dptr = malloc_async(num_bytes, ctx.get_cuda_stream());
        // NOTE(review): the copy is asynchronous, so `vec` must stay alive until
        // the stream completes — presumably guaranteed by the scheduler's
        // synchronization; TODO confirm.
        memcpy_htod_async(dptr, vec.as_ptr(), num_elements, ctx.get_cuda_stream());
        let device_box = DeviceBox::<[T]>::from_raw_parts(dptr, num_elements, ctx.get_device_id());
        // `shape`/`strides` are owned locals used once — no clone needed.
        Ok(Tensor {
            device_box,
            shape,
            strides,
        })
    }
}
impl<T: WithDType + Send> IntoFuture for CopyHostToDevice<T> {
    type Output = Result<Tensor<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Tensor<T>, CopyHostToDevice<T>>;

    /// Schedules the upload on the default device policy; scheduling errors
    /// become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let outcome = with_default_device_policy(|policy| policy.schedule(self));
        match outcome {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
pub fn copy_to_device<T: WithDType + Send>(
tensor: &Arc<candle_core::Tensor>,
) -> CopyHostToDevice<T> {
CopyHostToDevice {
tensor: tensor.clone(),
dtype: PhantomData,
}
}
/// Device operation that downloads a device tensor into a host-side CPU
/// candle tensor with the same shape.
pub struct CopyDeviceToHost<T: WithDType + Send> {
    // Device tensor to download; shared so scheduling does not move it.
    tensor: Arc<Tensor<T>>,
}
impl<T: WithDType + Send> DeviceOperation for CopyDeviceToHost<T> {
    type Output = candle_core::Tensor;

    /// Copies the device tensor's buffer to freshly allocated host memory and
    /// rebuilds it as a CPU candle tensor with the same shape.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let src = self.tensor.device_box.cu_deviceptr();
        let num_elements = self.tensor.size();
        let shape: Vec<usize> = self.tensor.shape.iter().map(|x| *x as usize).collect();
        // NOTE(review): assumes a non-empty tensor — `alloc` on a zero-size
        // layout is UB; TODO confirm callers never produce empty tensors.
        let layout = Layout::array::<T>(num_elements).expect("overflow cannot happen");
        let dst = alloc(layout).cast::<T>();
        // `alloc` returns null on failure; building a Vec from a null pointer
        // is UB, so route through the standard allocation-error handler.
        if dst.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        // NOTE(review): the copy is asynchronous; reading `data` is only sound
        // once the stream has completed — presumably the scheduler synchronizes
        // before the result is observed; TODO confirm.
        memcpy_dtoh_async(dst, src, num_elements, context.get_cuda_stream());
        let data = Vec::from_raw_parts(dst, num_elements, num_elements);
        let shape = candle_core::Shape::from(shape);
        candle_core::Tensor::from_vec(data, shape, &candle_core::Device::Cpu)
            .map_err(|err| device_error(context.get_device_id(), err.to_string().as_str()))
    }
}
impl<T: WithDType + Send> IntoFuture for CopyDeviceToHost<T> {
    type Output = Result<candle_core::Tensor, DeviceError>;
    type IntoFuture = DeviceFuture<candle_core::Tensor, CopyDeviceToHost<T>>;

    /// Schedules the download on the default device policy; scheduling errors
    /// become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| policy.schedule(self));
        match scheduled {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
/// Returns an operation that downloads `tensor` to a host candle tensor.
///
/// The `+ Send` bound matches the sibling constructors (`copy`,
/// `copy_to_device`) and the bound already required by `CopyDeviceToHost`.
pub fn copy_to_host<T: WithDType + Send>(tensor: &Arc<Tensor<T>>) -> CopyDeviceToHost<T> {
    CopyDeviceToHost {
        tensor: tensor.clone(),
    }
}
/// Device operation that downloads a device tensor's raw buffer into a host
/// `Vec<T>` (flat, no shape information).
struct CopyDeviceToHostVec<T: WithDType + Send> {
    // Device tensor whose buffer will be downloaded.
    tensor: Arc<Tensor<T>>,
}
impl<T: WithDType + Send> DeviceOperation for CopyDeviceToHostVec<T> {
    type Output = Vec<T>;

    /// Downloads the tensor's device buffer into a freshly allocated host
    /// `Vec<T>` of `tensor.size()` elements.
    unsafe fn execute(
        self,
        ctx: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let cu_deviceptr = self.tensor.device_box.cu_deviceptr();
        let size = self.tensor.size();
        // NOTE(review): assumes a non-empty tensor — zero-size `alloc` is UB;
        // TODO confirm callers never produce empty tensors.
        let layout = Layout::array::<T>(size).expect("overflow cannot happen");
        // Already inside an `unsafe fn`, so no extra `unsafe` blocks are
        // needed (matches the sibling `CopyDeviceToHost` implementation).
        let host_ptr = alloc(layout).cast::<T>();
        // Guard against allocation failure: `Vec::from_raw_parts` on null is UB.
        if host_ptr.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        // NOTE(review): asynchronous copy — the Vec must not be read until the
        // stream completes; presumably enforced by the scheduler. TODO confirm.
        memcpy_dtoh_async(host_ptr, cu_deviceptr, size, ctx.get_cuda_stream());
        Ok(Vec::from_raw_parts(host_ptr, size, size))
    }
}
impl<T: WithDType + Send> IntoFuture for CopyDeviceToHostVec<T> {
    type Output = Result<Vec<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Vec<T>, CopyDeviceToHostVec<T>>;

    /// Schedules the download on the default device policy; scheduling errors
    /// become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let outcome = with_default_device_policy(|policy| policy.schedule(self));
        match outcome {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
/// Returns an operation that downloads `tensor`'s device buffer to a flat
/// host vector.
///
/// The `+ Send` bound matches the bound already required by
/// `CopyDeviceToHostVec` and keeps the constructors consistent.
pub fn copy_device_to_host_vec<T: WithDType + Send>(
    tensor: &Arc<Tensor<T>>,
) -> impl DeviceOperation<Output = Vec<T>> {
    CopyDeviceToHostVec {
        tensor: tensor.clone(),
    }
}
/// Device operation that uploads a host `Vec<T>` to the device as a rank-1
/// tensor.
struct CopyHostVecToDevice<T: WithDType + Send> {
    // Host data to upload; shared so scheduling does not move it.
    vec: Arc<Vec<T>>,
}
impl<T: WithDType + Send> DeviceOperation for CopyHostVecToDevice<T> {
    type Output = Tensor<T>;

    /// Uploads the host vector to a device allocation on the context's
    /// stream and wraps it as a rank-1 tensor of length `vec.len()` with
    /// unit stride.
    unsafe fn execute(
        self,
        ctx: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        let vec = self.vec;
        let num_elements = vec.len();
        let num_bytes = std::mem::size_of::<T>() * num_elements;
        let dptr = malloc_async(num_bytes, ctx.get_cuda_stream());
        // NOTE(review): asynchronous copy — the host data behind the Arc must
        // outlive the stream work; presumably guaranteed by the scheduler.
        // TODO confirm.
        memcpy_htod_async(dptr, vec.as_ptr(), num_elements, ctx.get_cuda_stream());
        let device_box = DeviceBox::<[T]>::from_raw_parts(dptr, num_elements, ctx.get_device_id());
        // Shape/strides are built in place; the original's `.clone()` calls on
        // these freshly created vectors were redundant.
        Ok(Tensor {
            device_box,
            shape: vec![num_elements as i32],
            strides: vec![1],
        })
    }
}
impl<T: WithDType + Send> IntoFuture for CopyHostVecToDevice<T> {
    type Output = Result<Tensor<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Tensor<T>, CopyHostVecToDevice<T>>;

    /// Schedules the upload on the default device policy; scheduling errors
    /// become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| policy.schedule(self));
        match scheduled {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
/// Returns an operation that uploads `vec` to the device as a 1-D tensor.
///
/// The `+ Send` bound matches the bound already required by
/// `CopyHostVecToDevice` and keeps the constructors consistent.
pub fn copy_host_vec_to_device<T: WithDType + Send>(
    vec: &Arc<Vec<T>>,
) -> impl DeviceOperation<Output = Tensor<T>> {
    CopyHostVecToDevice { vec: vec.clone() }
}
/// Flattens a host candle tensor into `(data, shape, strides)`, with shape
/// and stride entries narrowed to the `i32` representation used on device.
///
/// # Panics
/// Panics if the tensor cannot be reshaped to 1-D or read as a vector of `T`
/// (e.g. on a dtype mismatch).
pub(crate) fn candle_tensor_to_vec<T: WithDType>(
    tensor: &Arc<candle_core::Tensor>,
) -> (Vec<T>, Vec<i32>, Vec<i32>) {
    let dims = tensor.shape().dims();
    let shape: Vec<i32> = dims.iter().map(|&x| x as i32).collect();
    let strides: Vec<i32> = tensor.stride().iter().map(|&x| x as i32).collect();
    // `.product()` replaces the original hand-rolled fold.
    let size: usize = dims.iter().product();
    let vec = tensor.reshape((size,)).unwrap().to_vec1().unwrap();
    (vec, shape, strides)
}
/// Creates a device tensor of the given shape with every element set to
/// `T::zero()`.
pub fn zeros<const RANK: usize, T: WithDType>(
    shape: [usize; RANK],
) -> impl DeviceOperation<Output = Tensor<T>> {
    full(T::zero(), shape)
}
/// Creates a device tensor of the given shape with every element set to
/// `T::one()`.
pub fn ones<const RANK: usize, T: WithDType>(
    shape: [usize; RANK],
) -> impl DeviceOperation<Output = Tensor<T>> {
    full(T::one(), shape)
}
/// Creates a device tensor of the given shape where every element is `val`.
pub fn full<const RANK: usize, T: WithDType>(
    val: T,
    shape: [usize; RANK],
) -> impl DeviceOperation<Output = Tensor<T>> {
    let len = shape.iter().product::<usize>();
    Tensor::<T>::uninitialized(len).and_then(move |tensor| {
        // The fill kernel processes the buffer in chunks of at most 128 elements.
        let chunk = min(len, 128) as i32;
        let partitioned = unsafe { tensor.assume_init() }.partition([chunk]);
        let (_, filled) = value((val, partitioned)).apply(full_apply).unzip();
        filled.unpartition().reshape::<RANK>(shape)
    })
}
/// Overwrites every element of `tensor` with `val` and yields the tensor.
pub fn fill<T: WithDType>(tensor: Tensor<T>, val: T) -> impl DeviceOperation<Output = Tensor<T>> {
    value(tensor).and_then(move |t| {
        let len = t.shape.iter().product::<i32>() as usize;
        // The fill kernel processes the buffer in chunks of at most 128 elements.
        let chunk = min(len, 128) as i32;
        let partitioned = t.partition([chunk]);
        let (_, filled) = value((val, partitioned)).apply(full_apply).unzip();
        filled.unpartition()
    })
}
/// Creates a 1-D device tensor of length `len` holding the sequence
/// `0, 1, …, len - 1`.
pub fn arange<T: WithDType>(len: usize) -> impl DeviceOperation<Output = Tensor<T>> {
    Tensor::<T>::uninitialized(len).and_then(move |tensor| {
        // The kernel processes the buffer in chunks of at most 128 elements.
        let chunk = min(len, 128) as i32;
        let partitioned = unsafe { tensor.assume_init() }.partition([chunk]);
        let (filled,) = value((partitioned,)).apply(arange_apply).unzip();
        filled.unpartition()
    })
}
/// Converts a tensor elementwise from `FromType` to `ToType`, preserving its
/// shape.
pub fn convert<FromType: WithDType, ToType: WithDType>(
    src: Arc<Tensor<FromType>>,
) -> impl DeviceOperation<Output = Tensor<ToType>> {
    // Iterate the shape directly — the original cloned the whole Vec first.
    let len = src.shape.iter().product::<i32>() as usize;
    Tensor::<ToType>::uninitialized(len).and_then(move |t| {
        let partition_size = min(len, 128);
        let dst = unsafe { t.assume_init() }.partition([partition_size as i32]);
        let res = value((src.clone(), dst)).apply(convert_apply).unzip();
        res.1
            .unpartition()
            .reshape_dyn(src.shape.iter().map(|x| *x as usize).collect::<Vec<_>>())
    })
}
/// Creates an f16 tensor with normally distributed values by sampling in f32
/// and converting elementwise, then reshaping to `shape`.
pub fn randn_f16<const RANK: usize>(
    mean: f16,
    std: f16,
    shape: [usize; RANK],
    seed: Option<u64>,
) -> impl DeviceOperation<Output = Tensor<f16>> {
    // `[usize; RANK]` is `Copy`; the original's `shape.clone()` was redundant.
    let len = shape.iter().product::<usize>();
    randn_f32(mean.to_f32(), std.to_f32(), [len], seed).and_then(move |src_tensor| {
        Tensor::<f16>::uninitialized(len).and_then(move |dst_tensor| {
            let partition_size = min(len, 128);
            let dst = unsafe { dst_tensor.assume_init() }.partition([partition_size as i32]);
            let res = value((Arc::new(src_tensor), dst))
                .apply(convert_apply)
                .unzip();
            res.1.unpartition().reshape_dyn(shape.to_vec())
        })
    })
}
/// Creates an f32 device tensor of the given shape with values drawn from a
/// normal distribution with the given `mean` and `std`, optionally seeded.
pub fn randn_f32<const RANK: usize>(
    mean: f32,
    std: f32,
    shape: [usize; RANK],
    seed: Option<u64>,
) -> impl DeviceOperation<Output = Tensor<f32>> {
    let len = shape.iter().product::<usize>();
    Tensor::<f32>::uninitialized(len).and_then(move |uninit| unsafe {
        let tensor = uninit.assume_init();
        // Keep the generator alive until the end of the closure, matching the
        // original's drop order.
        let rng = RNG::new(seed);
        rng.generate_normal_f32(tensor.cu_deviceptr(), len, mean, std);
        value(tensor.reshape::<RANK>(shape))
    })
}
/// Creates an f64 device tensor of the given shape with values drawn from a
/// normal distribution with the given `mean` and `std`, optionally seeded.
pub fn randn_f64<const RANK: usize>(
    mean: f64,
    std: f64,
    shape: [usize; RANK],
    seed: Option<u64>,
) -> impl DeviceOperation<Output = Tensor<f64>> {
    let len = shape.iter().product::<usize>();
    Tensor::<f64>::uninitialized(len).and_then(move |uninit| unsafe {
        let tensor = uninit.assume_init();
        // Keep the generator alive until the end of the closure, matching the
        // original's drop order.
        let rng = RNG::new(seed);
        rng.generate_normal_f64(tensor.cu_deviceptr(), len, mean, std);
        value(tensor.reshape::<RANK>(shape))
    })
}
/// Creates an f32 device tensor of the given shape with uniformly distributed
/// values, optionally seeded.
pub fn rand_f32<const RANK: usize>(
    shape: [usize; RANK],
    seed: Option<u64>,
) -> impl DeviceOperation<Output = Tensor<f32>> {
    let len = shape.iter().product::<usize>();
    Tensor::<f32>::uninitialized(len).and_then(move |uninit| unsafe {
        let tensor = uninit.assume_init();
        // Keep the generator alive until the end of the closure, matching the
        // original's drop order.
        let rng = RNG::new(seed);
        rng.generate_uniform_f32(tensor.cu_deviceptr(), len);
        value(tensor.reshape::<RANK>(shape))
    })
}
/// Creates an f64 device tensor of the given shape with uniformly distributed
/// values, optionally seeded.
pub fn rand_f64<const RANK: usize>(
    shape: [usize; RANK],
    seed: Option<u64>,
) -> impl DeviceOperation<Output = Tensor<f64>> {
    let len = shape.iter().product::<usize>();
    Tensor::<f64>::uninitialized(len).and_then(move |uninit| unsafe {
        let tensor = uninit.assume_init();
        // Keep the generator alive until the end of the closure, matching the
        // original's drop order.
        let rng = RNG::new(seed);
        rng.generate_uniform_f64(tensor.cu_deviceptr(), len);
        value(tensor.reshape::<RANK>(shape))
    })
}
/// Creates a normally-distributed tensor by sampling on the host via candle
/// and uploading the result to the device.
///
/// # Panics
/// Panics if host-side sampling fails.
pub fn randn<const RANK: usize, T: FloatDType>(
    mean: T,
    std: T,
    shape: [usize; RANK],
) -> impl DeviceOperation<Output = Tensor<T>> {
    let host = candle_core::Tensor::randn(mean, std, &shape, &candle_core::Device::Cpu)
        .expect("randn failed.");
    copy_to_device(&Arc::new(host))
}
/// Device operation adapter that reshapes the tensor produced by an inner
/// operation to a statically-ranked shape.
pub struct Reshape<const RANK: usize, T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>>
{
    // Target shape applied to the inner operation's output.
    shape: [usize; RANK],
    // The operation whose output will be reshaped.
    input: DI,
}
impl<const RANK: usize, T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>>
    DeviceOperation for Reshape<RANK, T, DI>
{
    type Output = Tensor<T>;

    /// Runs the inner operation and reshapes its output to `self.shape`.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        self.input
            .execute(context)
            .map(|tensor| tensor.reshape(self.shape))
    }
}
impl<const RANK: usize, T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>> IntoFuture
    for Reshape<RANK, T, DI>
{
    type Output = Result<Tensor<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Tensor<T>, Reshape<RANK, T, DI>>;

    /// Schedules the reshape chain on the default device policy; scheduling
    /// errors become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| policy.schedule(self));
        match scheduled {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
/// Extension trait adding a `reshape` combinator to any device operation
/// producing a `Tensor<T>`.
pub trait DeviceOperationReshape<T, DI>
where
    T: Send + WithDType,
    DI: DeviceOperation<Output = Tensor<T>>,
{
    /// Wraps `self` so that its output tensor is reshaped to `shape`.
    fn reshape<const RANK: usize>(self, shape: [usize; RANK]) -> Reshape<RANK, T, DI>;
}
impl<T, DI> DeviceOperationReshape<T, DI> for DI
where
    T: Send + WithDType,
    DI: DeviceOperation<Output = Tensor<T>>,
{
    /// Wraps `self` in a `Reshape` that applies `shape` to its output.
    fn reshape<const RANK: usize>(self, shape: [usize; RANK]) -> Reshape<RANK, T, DI> {
        // `DI` is a type parameter and therefore already `Sized`; the original
        // `where Self: Sized` clause on this method was redundant.
        Reshape::<RANK, T, DI> { shape, input: self }
    }
}
/// Device operation adapter that reshapes the tensor produced by an inner
/// operation to a runtime-determined shape.
pub struct DynamicReshape<T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>> {
    // Target shape applied to the inner operation's output.
    shape: Vec<usize>,
    // The operation whose output will be reshaped.
    input: DI,
}
impl<T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>> DeviceOperation
    for DynamicReshape<T, DI>
{
    type Output = Tensor<T>;

    /// Runs the inner operation and reshapes its output to `self.shape`.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
        self.input
            .execute(context)
            .map(|tensor| tensor.reshape_dyn(&self.shape))
    }
}
impl<T: WithDType + Send, DI: DeviceOperation<Output = Tensor<T>>> IntoFuture
    for DynamicReshape<T, DI>
{
    type Output = Result<Tensor<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Tensor<T>, DynamicReshape<T, DI>>;

    /// Schedules the reshape chain on the default device policy; scheduling
    /// errors become an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let outcome = with_default_device_policy(|policy| policy.schedule(self));
        match outcome {
            Ok(inner) => match inner {
                Ok(future) => future,
                Err(e) => DeviceFuture::failed(e),
            },
            Err(e) => DeviceFuture::failed(e),
        }
    }
}
/// Extension trait adding a runtime-shaped `reshape_dyn` combinator to any
/// device operation producing a `Tensor<T>`.
pub trait DeviceOperationDynamicReshape<T, DI>
where
    T: Send + WithDType,
    DI: DeviceOperation<Output = Tensor<T>>,
{
    /// Wraps `self` so that its output tensor is reshaped to `shape`.
    fn reshape_dyn(self, shape: Vec<usize>) -> DynamicReshape<T, DI>;
}
impl<T, DI> DeviceOperationDynamicReshape<T, DI> for DI
where
    T: Send + WithDType,
    DI: DeviceOperation<Output = Tensor<T>>,
{
    /// Wraps `self` in a `DynamicReshape` that applies `shape` to its output.
    fn reshape_dyn(self, shape: Vec<usize>) -> DynamicReshape<T, DI> {
        // `DI` is a type parameter and therefore already `Sized`; the original
        // `where Self: Sized` clause on this method was redundant.
        DynamicReshape::<T, DI> { shape, input: self }
    }
}