use std::fmt::Debug;
use kn_cuda_sys::bindings::cublasOperation_t;
use kn_cuda_sys::wrapper::group::MatMulOperand;
use kn_graph::dtype::DType;
use kn_graph::graph::SliceRange;
use crate::shape::{StridedShape, ViewError};
/// A cloneable, debuggable pointer-like handle that can be shifted by a byte offset.
///
/// Tensor views use this to move their base pointer when slicing or flipping.
pub trait OffsetPtr: Clone + Debug {
    /// Consume this pointer and return it shifted by `offset` bytes.
    ///
    /// `offset` is signed, so callers may pass negative values.
    fn offset_bytes(self, offset: isize) -> Self;
}
/// A tensor view: a pointer `P` to its data plus the element type and the strided layout.
///
/// The struct itself does not own or interpret the memory behind `ptr`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PtrTensor<P> {
    /// Base pointer of the view; the view-building methods shift it by element offsets.
    ptr: P,
    /// Element type, used to convert element offsets into byte offsets.
    dtype: DType,
    /// Logical shape together with per-axis strides (in elements).
    shape: StridedShape,
}
impl<P> PtrTensor<P> {
    /// Assemble a tensor view from a pointer, its strided shape and its element type.
    pub fn from_parts(ptr: P, shape: StridedShape, dtype: DType) -> Self {
        Self { ptr, dtype, shape }
    }

    /// Consume the view and hand back only the underlying pointer.
    pub fn into_ptr(self) -> P {
        let Self { ptr, .. } = self;
        ptr
    }

    /// Borrow the underlying pointer.
    pub fn ptr(&self) -> &P {
        &self.ptr
    }

    /// Borrow the strided shape describing this view's layout.
    pub fn strided_shape(&self) -> &StridedShape {
        &self.shape
    }

    /// Element type of this view.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Byte size the elements would occupy if stored densely:
    /// the shape's element count times the byte size of one element.
    pub fn dense_size_bytes(&self) -> usize {
        let element_count = self.strided_shape().size();
        let bytes_per_element = self.dtype().size().bytes();
        element_count * bytes_per_element
    }

    /// Replace the pointer with `f(ptr)`, keeping shape and dtype unchanged.
    pub fn map_ptr<K>(self, f: impl FnOnce(P) -> K) -> PtrTensor<K> {
        let Self { ptr, shape, dtype } = self;
        let mapped = f(ptr);
        PtrTensor::from_parts(mapped, shape, dtype)
    }
}
impl<P: OffsetPtr> PtrTensor<P> {
    /// Build a sibling view with the same dtype and layout `shape`.
    ///
    /// The pointer is shifted by `offset_elem` elements, converted to bytes using this
    /// view's dtype size.
    fn offset(&self, offset_elem: isize, shape: StridedShape) -> Self {
        let bytes_per_element = self.dtype.size().bytes() as isize;
        let shifted = self.ptr.clone().offset_bytes(bytes_per_element * offset_elem);
        Self::from_parts(shifted, shape, self.dtype)
    }

    /// Reorder the axes according to `permutation`; the pointer is not moved.
    pub fn permute(&self, permutation: &[usize]) -> Self {
        let permuted = self.shape.permute(permutation);
        self.offset(0, permuted)
    }

    /// Reinterpret this view with `new_shape` without moving the pointer.
    ///
    /// Propagates the `ViewError` produced by `StridedShape::view` when the layout
    /// cannot be viewed with that shape.
    pub fn view(&self, new_shape: Vec<usize>) -> Result<Self, ViewError> {
        let viewed = self.shape.view(new_shape)?;
        Ok(self.offset(0, viewed))
    }

    /// Broadcast to `new_shape` via `StridedShape::broadcast`; the pointer is not moved.
    pub fn broadcast(&self, new_shape: Vec<usize>) -> Self {
        let broadcasted = self.shape.broadcast(new_shape);
        self.offset(0, broadcasted)
    }

    /// Restrict `axis` to `range`.
    ///
    /// The pointer moves to `range.start` along `axis`. If the resulting view is empty,
    /// the pointer is left unchanged instead.
    pub fn slice(&self, axis: usize, range: impl Into<SliceRange>) -> Self {
        let range: SliceRange = range.into();
        let sliced = self.shape.slice(axis, range);
        // Presumably guards against shifting the pointer past the end of the data
        // when the slice selects nothing.
        let start_offset = match sliced.size() {
            0 => 0,
            _ => range.start as isize * self.shape.strides()[axis],
        };
        self.offset(start_offset, sliced)
    }

    /// Select position `index` along `axis` and drop that axis from the shape.
    ///
    /// Panics if `axis` is out of range or if the one-element slice cannot be
    /// viewed with the reduced shape.
    pub fn index(&self, axis: usize, index: usize) -> Self {
        let mut reduced_shape: Vec<usize> = self.shape.shape().to_vec();
        reduced_shape.remove(axis);

        let single = self.slice(axis, SliceRange::simple(index, index + 1));
        single.view(reduced_shape).unwrap()
    }

    /// Reverse the order of elements along `axis`.
    ///
    /// For a non-empty view, the pointer moves to the last element along `axis`.
    pub fn flip(&self, axis: usize) -> Self {
        let flipped = self.shape.flip(axis);
        let len = self.shape.shape()[axis];

        let is_empty = self.shape.size() == 0 || len == 0;
        let last_offset = if is_empty {
            0
        } else {
            self.shape.strides()[axis] * (len - 1) as isize
        };

        self.offset(last_offset, flipped)
    }

    /// Repeat `axis` `count` times via `StridedShape::repeat_unary`; the pointer is not moved.
    ///
    /// NOTE(review): the name suggests `axis` must have length 1 — confirm in `StridedShape`.
    pub fn repeat_unary(&self, axis: usize, count: usize) -> Self {
        self.offset(0, self.shape.repeat_unary(axis, count))
    }
}
impl<P: Clone> PtrTensor<P> {
    /// Describe this rank-3 tensor as a cuBLAS matrix-multiplication operand.
    ///
    /// The stride of axis 0 becomes the operand's `stride` (presumably the batch stride).
    /// Axes 1 and 2 form the matrix, which must be dense along one of them:
    /// * `strides[1] == 1` (column-major): passed as-is with `CUBLAS_OP_N`,
    ///   leading dimension `strides[2]`;
    /// * otherwise, if `strides[2] == 1` (row-major): passed as its transpose with
    ///   `CUBLAS_OP_T`, leading dimension `strides[1]`.
    ///
    /// # Panics
    /// * if the tensor does not have rank 3;
    /// * if neither axis 1 nor axis 2 has stride 1;
    /// * if the leading-dimension stride does not fit in an `i32`.
    pub fn to_mat_mul_arg(&self) -> MatMulOperand<P> {
        assert_eq!(self.strided_shape().rank(), 3);
        let strides = self.shape.strides();

        let (trans, lead_axis) = if strides[1] == 1 {
            (cublasOperation_t::CUBLAS_OP_N, 2)
        } else if strides[2] == 1 {
            (cublasOperation_t::CUBLAS_OP_T, 1)
        } else {
            panic!(
                "GPU matmul operand must be either col- or row-dense, got {:?}",
                self.shape
            )
        };

        // `MatMulOperand::ld` is an `i32`. A plain `as` cast would silently truncate an
        // oversized stride into a wrong but plausible leading dimension, so check instead.
        let lead_stride = strides[lead_axis];
        let ld: i32 = std::convert::TryFrom::try_from(lead_stride).unwrap_or_else(|_| {
            panic!(
                "GPU matmul leading dimension {} does not fit in i32, got {:?}",
                lead_stride, self.shape
            )
        });

        MatMulOperand {
            ptr: self.ptr().clone(),
            trans,
            ld,
            // isize -> i64 is lossless on every target with pointers of at most 64 bits.
            stride: strides[0] as i64,
        }
    }
}