kn_cuda_eval/device_tensor.rs

use bytemuck::{cast_slice, cast_slice_mut};

use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
use kn_cuda_sys::wrapper::mem::device::DevicePtr;
use kn_graph::dtype::DType;

use crate::autokernel::scalar::ScalarKernel;
use crate::offset_tensor::{OffsetPtr, PtrTensor};
use crate::shape::StridedShape;
use crate::step::{OperandKind, ScalarOpArgs};

pub type DeviceTensor = PtrTensor<DevicePtr>;

impl OffsetPtr for DevicePtr {
    fn offset_bytes(self, offset: isize) -> Self {
        DevicePtr::offset_bytes(self, offset)
    }
}

impl DeviceTensor {
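    /// Allocate a new tensor with contiguous ("simple") strides on the given device.
    /// The backing memory is freshly allocated and left uninitialized.
    ///
    /// A minimal usage sketch (not a doctest, since it needs a CUDA device); the
    /// `CudaDevice::new(0)` constructor is an assumption, not defined in this file:
    ///
    /// ```ignore
    /// let device = CudaDevice::new(0); // hypothetical constructor for device 0
    /// // 2x3 tensor of f32: 6 elements * 4 bytes = 24 bytes of device memory
    /// let tensor = DeviceTensor::alloc_simple(device, vec![2, 3], DType::F32);
    /// assert_eq!(tensor.dense_size_bytes(), 24);
    /// ```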
    pub fn alloc_simple(device: CudaDevice, shape: Vec<usize>, dtype: DType) -> Self {
        let shape = StridedShape::new_simple(shape);
        let ptr = DevicePtr::alloc(device, shape.size() * dtype.size().bytes());
        DeviceTensor::from_parts(ptr, shape, dtype)
    }

    pub fn device(&self) -> CudaDevice {
        self.ptr().device()
    }

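    /// Return a copy of this tensor whose data lives in newly allocated contiguous
    /// memory, independent of the original allocation.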
    pub fn deep_clone(&self) -> DeviceTensor {
        let new = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
        unsafe {
            new.copy_from(self);
        }
        new
    }

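    /// Copy the raw bytes in `buffer` from the host into this tensor's device memory.
    ///
    /// # Safety
    /// The usual CUDA memory-safety rules apply. The tensor must have simple strides
    /// and `buffer` must be exactly `dense_size_bytes()` long; both are checked by asserts.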
    pub unsafe fn copy_simple_from_host(&self, buffer: &[u8]) {
        assert!(
            self.strided_shape().has_simple_strides(),
            "Tensor must have simple strides, got {:?}",
            self,
        );
        assert_eq!(
            buffer.len(),
            self.dense_size_bytes(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self,
        );
        self.ptr().copy_linear_from_host(cast_slice(buffer));
    }

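    /// Copy this tensor's device memory into `buffer` on the host.
    ///
    /// # Safety
    /// Same requirements as [`Self::copy_simple_from_host`]: simple strides and a
    /// `buffer` of exactly `dense_size_bytes()` bytes, both checked by asserts.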
    pub unsafe fn copy_simple_to_host(&self, buffer: &mut [u8]) {
        assert!(
            self.strided_shape().has_simple_strides(),
            "Tensor must have simple strides, got {:?}",
            self,
        );
        assert_eq!(
            self.dense_size_bytes(),
            buffer.len(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self,
        );
        self.ptr().copy_linear_to_host(cast_slice_mut(buffer));
    }

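    /// Build (but do not launch) a scalar-kernel copy of `other` into `self`.
    /// The kernel body is `*x0 = *x1`, with `x0` the output (`self`) and `x1` the
    /// input (`other`); since the kernel is compiled for both strided shapes, it
    /// supports arbitrary (non-dense) layouts on either side.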
    pub fn copy_from_as_scalar_op(&self, other: &DeviceTensor) -> ScalarOpArgs<DevicePtr> {
        assert_eq!(self.device(), other.device(), "Tensors must be on the same device");
        assert_eq!(self.dtype(), other.dtype(), "Tensors must have the same dtype");
        let device = self.device();
        let dtype = self.dtype();

        assert_eq!(
            self.strided_shape().shape(),
            other.strided_shape().shape(),
            "Tensors must have the same shape: {:?} vs {:?}",
            self,
            other
        );

        let dtype_str = dtype.as_c_str();
        let kernel = ScalarKernel::new_for_shapes(
            device,
            "*x0 = *x1",
            &[self.strided_shape().clone(), other.strided_shape().clone()],
            vec![dtype_str.to_owned(), dtype_str.to_owned()],
        );

        ScalarOpArgs {
            kernel,
            operands: vec![self.clone(), other.clone()],
            operand_kinds: vec![OperandKind::Out, OperandKind::In],
        }
    }

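    /// Copy the contents of `other` into `self`.
    ///
    /// When both tensors have identical dense layouts this is a single linear
    /// device-to-device copy; otherwise it launches the scalar copy kernel from
    /// [`Self::copy_from_as_scalar_op`] on a fresh stream and synchronizes.
    ///
    /// A sketch of the strided fallback path, assuming `PtrTensor` exposes a
    /// `permute` method with the usual axis-permutation semantics (an assumption,
    /// not shown in this file); `device` is obtained as in the
    /// [`Self::alloc_simple`] example:
    ///
    /// ```ignore
    /// let a = DeviceTensor::alloc_simple(device, vec![2, 3], DType::F32);
    /// let b = DeviceTensor::alloc_simple(device, vec![3, 2], DType::F32);
    /// // the permuted view of `b` is 2x3 but with non-simple strides,
    /// // so this copy takes the scalar-kernel path instead of the linear one
    /// unsafe { a.copy_from(&b.permute(&[1, 0])); }
    /// ```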
    pub unsafe fn copy_from(&self, other: &DeviceTensor) {
        assert_eq!(self.dtype(), other.dtype(), "Tensors must have the same dtype");
        let dtype = self.dtype();
        assert_eq!(self.device(), other.device(), "Tensors must be on the same device");
        let device = self.device();

        assert_eq!(
            self.strided_shape().shape(),
            other.strided_shape().shape(),
            "Tensors must have the same shape: {:?} vs {:?}",
            self,
            other
        );

        if self.strided_shape() == other.strided_shape() && self.strided_shape().has_dense_strides() {
            // identical dense layouts: a single linear device-to-device copy suffices
            self.ptr()
                .copy_linear_from_device(&other.ptr(), self.strided_shape().size() * dtype.size().bytes())
        } else {
            // layouts differ: fall back to the generic scalar copy kernel
            let stream = CudaStream::new(device);
            self.copy_from_as_scalar_op(other).run(&stream);
            stream.synchronize();
        }
    }

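    /// Copy `buffer` from the host into this tensor, staging through a temporary
    /// contiguous device tensor when this tensor does not have simple strides.
    /// Currently limited to `DType::F32`.
    ///
    /// A round-trip sketch using bytemuck's slice casts (already imported above);
    /// `device` is obtained as in the [`Self::alloc_simple`] example:
    ///
    /// ```ignore
    /// let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
    /// let mut back = vec![0f32; data.len()];
    /// let tensor = DeviceTensor::alloc_simple(device, vec![2, 3], DType::F32);
    /// unsafe {
    ///     tensor.copy_from_host_staged(cast_slice(&data));
    ///     tensor.copy_to_host_staged(cast_slice_mut(&mut back));
    /// }
    /// assert_eq!(back, data);
    /// ```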
    pub unsafe fn copy_from_host_staged(&self, buffer: &[u8]) {
        assert_eq!(self.dtype(), DType::F32, "Only f32 is supported for now");

        // `buffer` is a byte slice, so compare against the size in bytes
        // (matching the assert in `copy_simple_from_host`)
        assert_eq!(
            self.dense_size_bytes(),
            buffer.len(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self.strided_shape()
        );

        if self.strided_shape().has_simple_strides() {
            self.copy_simple_from_host(buffer);
        } else {
            let stage = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
            stage.copy_simple_from_host(buffer);
            self.copy_from(&stage);
        }
    }

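    /// Copy this tensor into `buffer` on the host, staging through a temporary
    /// contiguous device tensor when this tensor does not have simple strides.
    /// Currently limited to `DType::F32`; see [`Self::copy_from_host_staged`] for a
    /// round-trip example.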
    pub unsafe fn copy_to_host_staged(&self, buffer: &mut [u8]) {
        assert_eq!(self.dtype(), DType::F32, "Only f32 is supported for now");

        // `buffer` is a byte slice, so compare against the size in bytes
        // (matching the assert in `copy_simple_to_host`)
        assert_eq!(
            self.dense_size_bytes(),
            buffer.len(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self.strided_shape()
        );

        if self.strided_shape().has_simple_strides() {
            self.copy_simple_to_host(buffer);
        } else {
            let stage = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
            stage.copy_from(self);
            stage.copy_simple_to_host(buffer);
        }
    }
}