pub mod compare_tensor;
pub mod convolution;
pub mod elemwise;
pub mod index_slicing;
pub mod linalg;
pub mod reduction;
use std::rc::Rc;
use std::cell::RefCell;
#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize, Serializer};
use cuda11_cudart_sys::{self, cudaMalloc, cudaStreamCreate, cudaMemcpy, cudaStreamSynchronize, cudaFree, cudaStreamDestroy, cudaMemcpyKind, check_cuda_status, cudaStream_t};
use cuda11_cutensor_sys::{self, cutensorHandle_t, check_cutensor_status, cutensorInit, cudaDataType_t,cutensorOperator_t_CUTENSOR_OP_IDENTITY, cutensorTensorDescriptor_t, cutensorInitTensorDescriptor, cutensorPermutation, cutensorOperator_t_CUTENSOR_OP_ADD,cutensorElementwiseBinary};
use crate::tensor::gen_tensor::{GenTensor};
use crate::tensor::cuda_helper::*;
/// A tensor of `f32` values whose backing storage lives in CUDA device memory.
///
/// The host-side mirror type is `GenTensor<f32>`; use `to_GenTensor` /
/// `from_GenTensor` to move data between host and device.
pub struct CudaTensor {
// Raw pointer returned by cudaMalloc; null for a tensor created by `new()`.
device_data: *mut f32,
// Tensor shape (row-major); element count is the product of these entries.
dim: Vec<usize>,
// CUDA stream the tensor's work is queued on; shared (Rc) across clones.
stream: Rc<StreamCell>,
}
impl CudaTensor {
pub fn new() -> CudaTensor {
CudaTensor {
device_data: std::ptr::null_mut(),
dim: Vec::new(),
stream: Rc::new(StreamCell::new()),
}
}
pub fn new_raw(data: &[f32], shape: &[usize]) -> CudaTensor {
let mut device_data: *mut f32 = std::ptr::null_mut();
let elems: usize = shape.iter().product();
if elems != data.len() {
panic!();
}
unsafe {
check_cuda_status(cudaMalloc(&mut device_data as *mut _ as *mut _,
std::mem::size_of::<f32>()*elems));
cudaMemcpy(device_data as *mut _,
data.as_ptr() as *mut _,
std::mem::size_of::<f32>()*elems,
cudaMemcpyKind::cudaMemcpyHostToDevice);
}
CudaTensor {
device_data: device_data,
dim: shape.to_vec(),
stream: Rc::new(StreamCell::new()),
}
}
pub fn new_move(data: Vec::<f32>, shape: Vec::<usize>) -> CudaTensor {
let mut device_data: *mut f32 = std::ptr::null_mut();
let elems: usize = shape.iter().product();
if elems != data.len() {
panic!();
}
unsafe {
check_cuda_status(cudaMalloc(&mut device_data as *mut _ as *mut _,
std::mem::size_of::<f32>()*elems));
cudaMemcpy(device_data as *mut _,
data.as_ptr() as *mut _,
std::mem::size_of::<f32>()*elems,
cudaMemcpyKind::cudaMemcpyHostToDevice);
}
CudaTensor {
device_data: device_data,
dim: shape.to_vec(),
stream: Rc::new(StreamCell::new()),
}
}
pub fn _get_stream(&self) -> cudaStream_t {
self.stream.get_stream().raw_stream()
}
pub fn _flush(&self) {
unsafe {
cudaStreamSynchronize(self._get_stream() as _);
}
}
pub fn _get_cutensor(&self) -> Option<Rc<CudaCutensor>>{
todo!();
}
pub fn to_GenTensor(&self) -> GenTensor<f32> {
let mut data = vec![0.; self.numel()];
self._flush();
unsafe {
cudaMemcpy(data.as_mut_ptr() as *mut _,
self.device_data as *mut _,
std::mem::size_of::<f32>()*self.numel(),
cudaMemcpyKind::cudaMemcpyDeviceToHost);
}
GenTensor::<f32>::new_move(data, self.dim.clone())
}
pub fn from_GenTensor(data: &GenTensor<f32>) -> CudaTensor {
CudaTensor::new_raw(data.get_data(), data.size())
}
pub fn index2dimpos(&self, index: usize) -> Vec::<usize> {
if index >= self.numel() {
panic!("index out of range, {:?}, {:?}", index, self.numel());
}
let mut ret = Vec::new();
let mut reminder = index;
for i in &self.stride() {
ret.push(reminder/i);
reminder %= i;
}
ret
}
pub fn dimpos2index(&self, dimpos: &[usize]) -> usize {
if dimpos.len() != self.dim.len() {
panic!("get expects the same dim self.dim: {:?}, o: {:?}", self.dim, dimpos);
}
for (i, j) in self.dim.iter().zip(dimpos.iter()) {
if j >= i {
panic!("get expects the dim within range self.dim: {:?}, o: {:?}", self.dim, dimpos);
}
}
let mut ret = 0;
for (st, i) in self.stride().iter().zip(dimpos.iter()) {
ret += st*i;
}
ret
}
pub fn zeros(size: &[usize]) -> CudaTensor {
let cap = size.iter().product();
CudaTensor::new_raw(&vec![0.; cap], size)
}
pub fn zeros_like(&self) -> CudaTensor {
let cap = self.dim.iter().product();
CudaTensor::new_raw(&vec![0.; cap], &self.dim)
}
pub fn ones(size: &[usize]) -> CudaTensor {
let cap = size.iter().product();
CudaTensor::new_raw(&vec![1.; cap], size)
}
pub fn ones_like(&self) -> CudaTensor {
let cap = self.dim.iter().product();
CudaTensor::new_raw(&vec![1.; cap], &self.dim)
}
pub fn arange(end: usize) -> CudaTensor {
let mut d: Vec<f32> = vec![0.; end];
for i in 0..end {
d[i] = i as f32;
}
CudaTensor::new_raw(&d, &vec![1])
}
pub fn empty(shape: &[usize]) -> CudaTensor {
let mut device_data: *mut f32 = std::ptr::null_mut();
let elems: usize = shape.iter().product();
unsafe {
check_cuda_status(cudaMalloc(&mut device_data as *mut _ as *mut _,
std::mem::size_of::<f32>()*elems));
}
let mut ret = CudaTensor {
device_data: device_data,
dim: shape.to_vec(),
stream: Rc::new(StreamCell::new()),
};
ret
}
pub fn empty_like(&self) -> CudaTensor {
let mut device_data: *mut f32 = std::ptr::null_mut();
let elems: usize = self.dim.iter().product();
unsafe {
check_cuda_status(cudaMalloc(&mut device_data as *mut _ as *mut _,
std::mem::size_of::<f32>()*elems));
}
let mut ret = CudaTensor {
device_data: device_data,
dim: self.dim.to_vec(),
stream: self.stream.clone(), };
ret
}
pub fn fill(d: f32, shape: &[usize]) -> CudaTensor {
let elems: usize = shape.iter().product();
let d: Vec<f32> = vec![d; elems];
CudaTensor::new_raw(&d, shape)
}
pub fn from_record(&mut self, row: usize, record: &[f32]) -> Result<(), ()> {
if record.len() != self.dim[self.dim.len() - 1] {
Err(())
} else {
unsafe {
cudaMemcpy(((self.device_data as usize)
+ row*self.dim[self.dim.len()-1]*std::mem::size_of::<f32>()) as _,
record.as_ptr() as *mut _,
std::mem::size_of::<f32>()*record.len(),
cudaMemcpyKind::cudaMemcpyHostToDevice);
}
Ok(())
}
}
pub fn stride(&self) -> Vec<usize> {
let mut ret = vec![0; self.dim.len()];
let dsize = ret.len();
for i in 0..dsize {
if i == 0 {
ret[dsize-1] = 1;
} else {
ret[dsize-i-1] = ret[dsize-i]*self.dim[dsize-i];
}
}
ret
}
pub fn get(&self, o: &[usize]) -> f32 {
let index = self.dimpos2index(o);
let mut data: Vec<f32> = vec![0.0];
unsafe {
cudaMemcpy(data.as_mut_ptr() as *mut _,
((self.device_data as usize)
+ std::mem::size_of::<f32>()*index) as *mut _,
std::mem::size_of::<f32>(),
cudaMemcpyKind::cudaMemcpyDeviceToHost);
}
data[0]
}
pub fn set(&mut self, o: &[usize], v: f32) {
let index = self.dimpos2index(o);
let mut data: Vec<f32> = vec![v];
unsafe {
cudaMemcpy(((self.device_data as usize)
+ std::mem::size_of::<f32>()*index) as *mut _,
data.as_mut_ptr() as *mut _,
std::mem::size_of::<f32>(),
cudaMemcpyKind::cudaMemcpyHostToDevice);
}
}
pub fn set_1d(&mut self, o: usize, v: f32) {
let mut data: Vec<f32> = vec![v];
unsafe {
cudaMemcpy(((self.device_data as usize)
+ std::mem::size_of::<f32>()*o) as *mut _,
data.as_mut_ptr() as *mut _,
std::mem::size_of::<f32>(),
cudaMemcpyKind::cudaMemcpyHostToDevice);
}
}
pub fn get_mut(&mut self, o: &[usize]) -> &mut f32 {
unimplemented!("This deprecated, use set()");
}
pub fn get_raw(&self) -> Vec<f32> {
let mut data: Vec<f32> = vec![0.0; self.numel()];
unsafe {
cudaMemcpy(data.as_mut_ptr() as *mut _,
self.device_data as *mut _,
std::mem::size_of::<f32>()*self.numel(),
cudaMemcpyKind::cudaMemcpyDeviceToHost);
}
data
}
pub fn get_u8(&self) -> Option<Vec<u8>> {
self.to_GenTensor().get_u8()
}
pub fn get_scale(&self) -> f32 {
if self.dim.len() <= 1 && self.dim[0] == 1 {
return self.to_GenTensor().get_scale();
} else {
panic!("Only one element tensor can get_scale()");
}
}
pub fn get_n(&self) -> CudaTensor {
CudaTensor::new_raw(&vec![self.dim[0] as f32], &vec![1])
}
pub fn get_c(&self) -> CudaTensor {
CudaTensor::new_raw(&vec![self.dim[1] as f32], &vec![1])
}
pub fn get_d(&self) -> CudaTensor {
if self.dim.len() == 5 {
CudaTensor::new_raw(&vec![self.dim[2] as f32], &vec![1])
} else {
panic!("Bad shape for get_D");
}
}
pub fn get_h(&self) -> CudaTensor {
if self.dim.len() == 5 {
CudaTensor::new_raw(&vec![self.dim[3] as f32], &vec![1])
} else if self.dim.len() == 4 {
CudaTensor::new_raw(&vec![self.dim[2] as f32], &vec![1])
} else {
panic!("Bad shape for get_D");
}
}
pub fn get_w(&self) -> CudaTensor {
if self.dim.len() == 5 {
CudaTensor::new_raw(&vec![self.dim[4] as f32], &vec![1])
} else if self.dim.len() == 4 {
CudaTensor::new_raw(&vec![self.dim[3] as f32], &vec![1])
} else {
panic!("Bad shape for get_D");
}
}
pub fn size(&self) -> &Vec<usize> {
&self.dim
}
pub fn get_data(&self) -> &Vec<f32> {
unimplemented!("tensor on device cannot get mut reference");
}
pub fn get_data_mut(&mut self) -> &mut Vec<f32> {
unimplemented!("tensor on device cannot get mut reference");
}
pub fn _get_device_data(&self) -> *mut f32 {
self.device_data
}
pub fn numel(&self) -> usize {
self.dim.iter().product()
}
pub fn numel_tensor(&self) -> CudaTensor {
CudaTensor::new_move(vec![self.dim.iter().map(|x| *x as f32).product()], vec![1])
}
pub fn get_patch(&self, range: &[(usize, usize)], step: Option<&[usize]>) -> CudaTensor {
todo!();
}
pub fn set_patch(&mut self, val: &CudaTensor, range: &[(usize, usize)], step: Option<&[usize]>) {
todo!();
}
pub fn add(&self, o: &CudaTensor) -> CudaTensor {
let mut ret = o.clone();
unsafe {
let mut stream: cudaStream_t = self._get_stream();
let mut handle:cutensorHandle_t = std::mem::uninitialized();
check_cutensor_status(cutensorInit(&mut handle as *mut _));
let alpha: f32 = 1.0;
let gamma: f32 = 1.0;
let extent: Vec<i64> = self.size().iter().map(|x| *x as i64).collect();
let mut descA: cutensorTensorDescriptor_t = std::mem::uninitialized();
let mut descC: cutensorTensorDescriptor_t = std::mem::uninitialized();
check_cutensor_status(cutensorInitTensorDescriptor( &mut handle,
&mut descA,
self.size().len() as _,
extent.as_ptr(),
std::ptr::null(),
cudaDataType_t::CUDA_R_32F,
cutensorOperator_t_CUTENSOR_OP_IDENTITY));
check_cutensor_status(cutensorInitTensorDescriptor( &mut handle,
&mut descC,
self.size().len() as _,
extent.as_ptr(),
std::ptr::null(),
cudaDataType_t::CUDA_R_32F,
cutensorOperator_t_CUTENSOR_OP_IDENTITY));
let mut modeA: Vec<i32> = vec![32; self.size().len()];
let mut modeC: Vec<i32> = vec![32; self.size().len()];
for i in 0..self.size().len() {
modeA[i] = modeA[i] + i as i32;
modeC[i] = modeC[i] + i as i32;
}
check_cutensor_status(cutensorElementwiseBinary(&handle,
&alpha as *const _ as _,
self._get_device_data() as _,
&descA as _,
modeA.as_ptr(),
&gamma as *const _ as _,
ret._get_device_data() as _,
&descC as _,
modeC.as_ptr(),
ret._get_device_data() as _,
&descC as _,
modeC.as_ptr(),
cutensorOperator_t_CUTENSOR_OP_ADD,
cudaDataType_t::CUDA_R_32F,
stream as _
));
}
ret
}
pub fn sub(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn mul(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn div(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn mm(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn matmul(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn outer(&self, o: &CudaTensor, avg: Option<bool>) -> CudaTensor {
unimplemented!();
}
pub fn squared_error(t1: &Self, t2: &Self) -> CudaTensor {
unimplemented!();
}
pub fn all_close(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn arg_sort(&self, dim: usize, descending: bool) -> CudaTensor {
unimplemented!();
}
pub fn eq_t(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn equal(&self, o: &CudaTensor) -> bool {
unimplemented!();
}
pub fn ge(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn gt(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn le(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn lt(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
pub fn ne(&self, o: &CudaTensor) -> CudaTensor {
unimplemented!();
}
}
impl Drop for CudaTensor {
    /// Release the device allocation, if any.
    fn drop(&mut self) {
        // `new()` leaves the pointer null; skip the FFI call for such
        // tensors. `.is_null()` replaces the clippy-flagged null comparison.
        if !self.device_data.is_null() {
            unsafe {
                check_cuda_status(cudaFree(self.device_data as _));
            }
        }
    }
}
impl std::fmt::Debug for CudaTensor {
    /// Debug-print by downloading to the host and delegating to
    /// `GenTensor`'s Debug impl (one full device-to-host copy per call).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // writeln! replaces the clippy-flagged write!(.., "{:?}\n", ..).
        writeln!(f, "{:?}", self.to_GenTensor())
    }
}
impl Clone for CudaTensor {
    /// Deep copy: allocates fresh device memory and copies device-to-device.
    /// The CUDA stream handle is shared (Rc) with the original.
    fn clone(&self) -> Self {
        let mut device_data: *mut f32 = std::ptr::null_mut();
        let bytes = std::mem::size_of::<f32>() * self.numel();
        unsafe {
            // SAFETY: destination is freshly allocated with exactly `bytes`
            // bytes; source allocation is owned by `self` and at least as big.
            check_cuda_status(cudaMalloc(&mut device_data as *mut _ as *mut _,
                                         bytes));
            check_cuda_status(cudaMemcpy(device_data as _,
                                         self.device_data as _,
                                         bytes,
                                         cudaMemcpyKind::cudaMemcpyDeviceToDevice));
        }
        CudaTensor {
            device_data,
            dim: self.dim.clone(),
            stream: self.stream.clone(),
        }
    }
}
#[cfg(feature = "use-serde")]
impl Serialize for CudaTensor {
    /// Serialize by staging the device buffer back to the host first.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `serialize_struct`'s methods live on this trait; the original code
        // neither imported it nor wrote any field (it declared 3 fields and
        // serialized none, producing an empty/invalid struct).
        use serde::ser::SerializeStruct;
        let t = self.to_GenTensor();
        let mut state = serializer.serialize_struct("CudaTensor", 2)?;
        state.serialize_field("data", t.get_data())?;
        state.serialize_field("dim", &self.dim)?;
        state.end()
    }
}
#[cfg(all(test, feature = "use-cuda"))]
mod tests {
    use super::*;

    // Stream creation/teardown must not crash.
    #[test]
    fn cuda_stream() {
        let _stream = CudaStream::new();
    }

    // Allocation + host->device upload must not crash.
    #[test]
    fn cuda_memcpy() {
        let _input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                         &[1, 1, 3, 3]);
    }

    // Round-trip: upload then download and compare.
    #[test]
    #[allow(non_snake_case)]
    fn cuda_to_GenTensor() {
        let input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                        &[1, 1, 3, 3]);
        let local = input.to_GenTensor();
        assert_eq!(local.numel(), 9);
        assert_eq!(local.get_data().clone(), vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
    }

    #[test]
    fn cuda_from_record() {
        let mut input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                            &[1, 1, 3, 3]);
        // The Result used to be dropped, so a failed copy would still pass.
        input.from_record(1, &[11., 12., 13.]).unwrap();
        assert_eq!(input.to_GenTensor().get_data().clone(),
                   vec![1.0, 2.0, 3.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn cuda_get() {
        let input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                        &[1, 1, 3, 3]);
        assert_eq!(input.get(&[0, 0, 1, 1]), 5.);
    }

    #[test]
    fn cuda_set() {
        let mut input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                            &[1, 1, 3, 3]);
        input.set(&[0, 0, 1, 1], 15.);
        assert_eq!(input.get(&[0, 0, 1, 1]), 15.);
    }

    #[test]
    fn cuda_set_1d() {
        let mut input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                            &[1, 1, 3, 3]);
        // Flat index 4 == coordinate (0, 0, 1, 1) in a [1,1,3,3] tensor.
        input.set_1d(4, 15.);
        assert_eq!(input.get(&[0, 0, 1, 1]), 15.);
    }

    #[test]
    fn cuda_numel() {
        let input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                        &[1, 1, 3, 3]);
        assert_eq!(input.numel(), 9);
    }

    #[test]
    fn cuda_add() {
        let m1 = CudaTensor::new_raw(&[1., 2., 3., 4.], &[2, 2]);
        let m2 = CudaTensor::new_raw(&[1., 2., 3., 4.], &[2, 2]);
        let m3 = m1.add(&m2);
        assert_eq!(m3.to_GenTensor().get_data().clone(), vec![2., 4., 6., 8.]);
    }

    #[test]
    fn cuda_clone() {
        let input = CudaTensor::new_raw(&[1., 2., 3., 4., 5., 6., 7., 8., 9.],
                                        &[1, 1, 3, 3]);
        let input2 = input.clone();
        assert_eq!(input2.to_GenTensor(), input.to_GenTensor());
    }
}