// kn_cuda_eval/offset_tensor.rs

1use std::fmt::Debug;
2
3use kn_cuda_sys::bindings::cublasOperation_t;
4use kn_cuda_sys::wrapper::group::MatMulOperand;
5use kn_graph::dtype::DType;
6use kn_graph::graph::SliceRange;
7
8use crate::shape::{StridedShape, ViewError};
9
/// A pointer-like value that can be moved by a byte offset.
///
/// [`PtrTensor`] uses this to derive views (`slice`, `index`, `flip`, ...) by
/// adjusting the base pointer instead of copying any data.
pub trait OffsetPtr: Debug + Clone {
    /// Returns this pointer shifted by `offset` bytes (may be negative).
    fn offset_bytes(self, offset: isize) -> Self;
}
13
/// A generic Tensor representation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PtrTensor<P> {
    // base pointer; the concrete type `P` decides what it points into
    ptr: P,
    // element type of the tensor
    dtype: DType,
    // logical dimensions plus the per-axis strides used to address into `ptr`
    shape: StridedShape,
}
21
22impl<P> PtrTensor<P> {
23    pub fn from_parts(ptr: P, shape: StridedShape, dtype: DType) -> Self {
24        PtrTensor { ptr, shape, dtype }
25    }
26
27    pub fn into_ptr(self) -> P {
28        self.ptr
29    }
30
31    pub fn ptr(&self) -> &P {
32        &self.ptr
33    }
34
35    pub fn strided_shape(&self) -> &StridedShape {
36        &self.shape
37    }
38
39    pub fn dtype(&self) -> DType {
40        self.dtype
41    }
42
43    pub fn dense_size_bytes(&self) -> usize {
44        self.strided_shape().size() * self.dtype().size().bytes()
45    }
46
47    pub fn map_ptr<K>(self, f: impl FnOnce(P) -> K) -> PtrTensor<K> {
48        PtrTensor::from_parts(f(self.ptr), self.shape, self.dtype)
49    }
50}
51
52impl<P: OffsetPtr> PtrTensor<P> {
53    fn offset(&self, offset_elem: isize, shape: StridedShape) -> Self {
54        let offset_bytes = self.dtype.size().bytes() as isize * offset_elem;
55        Self::from_parts(self.ptr.clone().offset_bytes(offset_bytes), shape, self.dtype)
56    }
57
58    pub fn permute(&self, permutation: &[usize]) -> Self {
59        self.offset(0, self.shape.permute(permutation))
60    }
61
62    pub fn view(&self, new_shape: Vec<usize>) -> Result<Self, ViewError> {
63        self.shape.view(new_shape).map(|shape| self.offset(0, shape))
64    }
65
66    pub fn broadcast(&self, new_shape: Vec<usize>) -> Self {
67        self.offset(0, self.shape.broadcast(new_shape))
68    }
69
70    pub fn slice(&self, axis: usize, range: impl Into<SliceRange>) -> Self {
71        let range = range.into();
72
73        // use the new shape & strides (which only change along `axis`)
74        let result_shape = self.shape.slice(axis, range);
75
76        let offset = if result_shape.size() != 0 {
77            // offset initial pointer to account for `start`
78            self.strided_shape().strides()[axis] * range.start as isize
79        } else {
80            0
81        };
82
83        self.offset(offset, result_shape)
84    }
85
86    pub fn index(&self, axis: usize, index: usize) -> Self {
87        let mut new_shape = self.shape.shape().to_vec();
88        new_shape.remove(axis);
89
90        self.slice(axis, SliceRange::simple(index, index + 1))
91            .view(new_shape)
92            .unwrap()
93    }
94
95    pub fn flip(&self, axis: usize) -> Self {
96        // invert the axis stride
97        let result_shape = self.shape.flip(axis);
98
99        let axis_len = self.shape.shape()[axis];
100        let offset = if self.shape.size() != 0 && axis_len != 0 {
101            // offset so index 0 gets the last element along the axis
102            (axis_len - 1) as isize * self.shape.strides()[axis]
103        } else {
104            0
105        };
106
107        self.offset(offset, result_shape)
108    }
109
110    pub fn repeat_unary(&self, axis: usize, count: usize) -> Self {
111        let result_shape = self.shape.repeat_unary(axis, count);
112        self.offset(0, result_shape)
113    }
114}
115
116impl<P: Clone> PtrTensor<P> {
117    //TODO move this somewhere else, this is pretty random
118    pub fn to_mat_mul_arg(&self) -> MatMulOperand<P> {
119        assert_eq!(self.strided_shape().rank(), 3);
120
121        // prefer col-major in case of a tie, since cublas likes that more
122        let (trans, lead_axis) = if self.shape.strides()[1] == 1 {
123            (cublasOperation_t::CUBLAS_OP_N, 2)
124        } else if self.shape.strides()[2] == 1 {
125            (cublasOperation_t::CUBLAS_OP_T, 1)
126        } else {
127            panic!(
128                "GPU matmul operand must be either col- or row-dense, got {:?}",
129                self.shape
130            )
131        };
132
133        MatMulOperand {
134            ptr: self.ptr().clone(),
135            trans,
136            ld: self.shape.strides()[lead_axis] as i32,
137            stride: self.strided_shape().strides()[0] as i64,
138        }
139    }
140}