kn_cuda_eval/
offset_tensor.rs1use std::fmt::Debug;
2
3use kn_cuda_sys::bindings::cublasOperation_t;
4use kn_cuda_sys::wrapper::group::MatMulOperand;
5use kn_graph::dtype::DType;
6use kn_graph::graph::SliceRange;
7
8use crate::shape::{StridedShape, ViewError};
9
/// A pointer-like value that can be advanced by a raw byte offset.
///
/// Implemented by the concrete (device) pointer types stored in [`PtrTensor`],
/// so that tensor view operations (slice, index, flip) can relocate the base
/// pointer without knowing the pointer's representation.
pub trait OffsetPtr: Debug + Clone {
    /// Returns this pointer moved by `offset` bytes; `offset` may be negative.
    fn offset_bytes(self, offset: isize) -> Self;
}
13
/// A tensor view: a base pointer plus a strided shape and element dtype.
///
/// `P` is the pointer type (e.g. a device pointer). The tensor does not own
/// the underlying buffer; cloning a `PtrTensor` clones only the view.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PtrTensor<P> {
    // Base pointer to element 0 of this view.
    ptr: P,
    // Element data type; fixes the byte size of one element.
    dtype: DType,
    // Shape with per-axis strides, measured in elements (see `offset`, which
    // converts element offsets to bytes via the dtype size).
    shape: StridedShape,
}
21
22impl<P> PtrTensor<P> {
23 pub fn from_parts(ptr: P, shape: StridedShape, dtype: DType) -> Self {
24 PtrTensor { ptr, shape, dtype }
25 }
26
27 pub fn into_ptr(self) -> P {
28 self.ptr
29 }
30
31 pub fn ptr(&self) -> &P {
32 &self.ptr
33 }
34
35 pub fn strided_shape(&self) -> &StridedShape {
36 &self.shape
37 }
38
39 pub fn dtype(&self) -> DType {
40 self.dtype
41 }
42
43 pub fn dense_size_bytes(&self) -> usize {
44 self.strided_shape().size() * self.dtype().size().bytes()
45 }
46
47 pub fn map_ptr<K>(self, f: impl FnOnce(P) -> K) -> PtrTensor<K> {
48 PtrTensor::from_parts(f(self.ptr), self.shape, self.dtype)
49 }
50}
51
52impl<P: OffsetPtr> PtrTensor<P> {
53 fn offset(&self, offset_elem: isize, shape: StridedShape) -> Self {
54 let offset_bytes = self.dtype.size().bytes() as isize * offset_elem;
55 Self::from_parts(self.ptr.clone().offset_bytes(offset_bytes), shape, self.dtype)
56 }
57
58 pub fn permute(&self, permutation: &[usize]) -> Self {
59 self.offset(0, self.shape.permute(permutation))
60 }
61
62 pub fn view(&self, new_shape: Vec<usize>) -> Result<Self, ViewError> {
63 self.shape.view(new_shape).map(|shape| self.offset(0, shape))
64 }
65
66 pub fn broadcast(&self, new_shape: Vec<usize>) -> Self {
67 self.offset(0, self.shape.broadcast(new_shape))
68 }
69
70 pub fn slice(&self, axis: usize, range: impl Into<SliceRange>) -> Self {
71 let range = range.into();
72
73 let result_shape = self.shape.slice(axis, range);
75
76 let offset = if result_shape.size() != 0 {
77 self.strided_shape().strides()[axis] * range.start as isize
79 } else {
80 0
81 };
82
83 self.offset(offset, result_shape)
84 }
85
86 pub fn index(&self, axis: usize, index: usize) -> Self {
87 let mut new_shape = self.shape.shape().to_vec();
88 new_shape.remove(axis);
89
90 self.slice(axis, SliceRange::simple(index, index + 1))
91 .view(new_shape)
92 .unwrap()
93 }
94
95 pub fn flip(&self, axis: usize) -> Self {
96 let result_shape = self.shape.flip(axis);
98
99 let axis_len = self.shape.shape()[axis];
100 let offset = if self.shape.size() != 0 && axis_len != 0 {
101 (axis_len - 1) as isize * self.shape.strides()[axis]
103 } else {
104 0
105 };
106
107 self.offset(offset, result_shape)
108 }
109
110 pub fn repeat_unary(&self, axis: usize, count: usize) -> Self {
111 let result_shape = self.shape.repeat_unary(axis, count);
112 self.offset(0, result_shape)
113 }
114}
115
116impl<P: Clone> PtrTensor<P> {
117 pub fn to_mat_mul_arg(&self) -> MatMulOperand<P> {
119 assert_eq!(self.strided_shape().rank(), 3);
120
121 let (trans, lead_axis) = if self.shape.strides()[1] == 1 {
123 (cublasOperation_t::CUBLAS_OP_N, 2)
124 } else if self.shape.strides()[2] == 1 {
125 (cublasOperation_t::CUBLAS_OP_T, 1)
126 } else {
127 panic!(
128 "GPU matmul operand must be either col- or row-dense, got {:?}",
129 self.shape
130 )
131 };
132
133 MatMulOperand {
134 ptr: self.ptr().clone(),
135 trans,
136 ld: self.shape.strides()[lead_axis] as i32,
137 stride: self.strided_shape().strides()[0] as i64,
138 }
139 }
140}