kn_cuda_eval/
device_tensor.rs

use bytemuck::{cast_slice, cast_slice_mut};

use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
use kn_cuda_sys::wrapper::mem::device::DevicePtr;
use kn_graph::dtype::DType;

use crate::autokernel::scalar::ScalarKernel;
use crate::offset_tensor::{OffsetPtr, PtrTensor};
use crate::shape::StridedShape;
use crate::step::{OperandKind, ScalarOpArgs};

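/// A tensor living in device memory: a [DevicePtr] paired with a [StridedShape] and a [DType]
/// through the generic [PtrTensor] wrapper.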
pub type DeviceTensor = PtrTensor<DevicePtr>;

impl OffsetPtr for DevicePtr {
    fn offset_bytes(self, offset: isize) -> Self {
        DevicePtr::offset_bytes(self, offset)
    }
}

impl DeviceTensor {
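    /// Allocate a new tensor with the given `shape` and `dtype` on `device`,
    /// using simple (contiguous, row-major) strides.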
    pub fn alloc_simple(device: CudaDevice, shape: Vec<usize>, dtype: DType) -> Self {
        let shape = StridedShape::new_simple(shape);
        let ptr = DevicePtr::alloc(device, shape.size() * dtype.size().bytes());
        DeviceTensor::from_parts(ptr, shape, dtype)
    }

    pub fn device(&self) -> CudaDevice {
        self.ptr().device()
    }

    pub fn deep_clone(&self) -> DeviceTensor {
        let new = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
        unsafe {
            new.copy_from(self);
        }
        new
    }

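    /// Copy `buffer` from the host into this tensor.
    ///
    /// The tensor must have simple strides and `buffer` must contain exactly
    /// `dense_size_bytes()` bytes; both conditions are checked with asserts.
    ///
    /// A minimal usage sketch, assuming `device` is an existing [CudaDevice] handle:
    ///
    /// ```ignore
    /// let tensor = DeviceTensor::alloc_simple(device, vec![2, 3], DType::F32);
    /// let host: Vec<f32> = vec![0.0; 6];
    /// unsafe { tensor.copy_simple_from_host(bytemuck::cast_slice(&host)) };
    /// ```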
    pub unsafe fn copy_simple_from_host(&self, buffer: &[u8]) {
        assert!(
            self.strided_shape().has_simple_strides(),
            "Tensor must have simple strides, got {:?}",
            self,
        );
        assert_eq!(
            buffer.len(),
            self.dense_size_bytes(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self,
        );
        self.ptr().copy_linear_from_host(cast_slice(buffer));
    }

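    /// Copy the contents of this tensor from the device into `buffer` on the host.
    ///
    /// Mirrors [Self::copy_simple_from_host]: the tensor must have simple strides and
    /// `buffer` must be exactly `dense_size_bytes()` long.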
    pub unsafe fn copy_simple_to_host(&self, buffer: &mut [u8]) {
        assert!(
            self.strided_shape().has_simple_strides(),
            "Tensor must have simple strides, got {:?}",
            self,
        );
        assert_eq!(
            self.dense_size_bytes(),
            buffer.len(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self,
        );
        self.ptr().copy_linear_to_host(cast_slice_mut(buffer));
    }

    // TODO ideally we would decay to memcpy if possible,
    //   but callers can already do that themselves, this is the fallback!
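    /// Build a [ScalarOpArgs] that copies `other` into `self` element-wise, handling arbitrary
    /// (non-dense or mismatched) strides. The caller is responsible for running the returned op
    /// on a stream, as [Self::copy_from] does for its fallback path, e.g. (sketch, assuming a
    /// stream is available): `self.copy_from_as_scalar_op(other).run(&stream)`.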
    pub fn copy_from_as_scalar_op(&self, other: &DeviceTensor) -> ScalarOpArgs<DevicePtr> {
        assert_eq!(self.device(), other.device(), "Tensors must be on the same device");
        assert_eq!(self.dtype(), other.dtype(), "Tensors must have the same dtype");
        let device = self.device();
        let dtype = self.dtype();

        assert_eq!(
            self.strided_shape().shape(),
            other.strided_shape().shape(),
            "Tensors must have the same shape: {:?} vs {:?}",
            self,
            other
        );

        let dtype_str = dtype.as_c_str();
        let kernel = ScalarKernel::new_for_shapes(
            device,
            "*x0 = *x1",
            &[self.strided_shape().clone(), other.strided_shape().clone()],
            vec![dtype_str.to_owned(), dtype_str.to_owned()],
        );

        ScalarOpArgs {
            kernel,
            operands: vec![self.clone(), other.clone()],
            operand_kinds: vec![OperandKind::Out, OperandKind::In],
        }
    }

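    /// Copy the contents of `other` into `self`. If both tensors have identical, dense strides
    /// this is a single device-to-device memcpy; otherwise it falls back to the scalar copy
    /// kernel from [Self::copy_from_as_scalar_op], run on a freshly created stream.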
    pub unsafe fn copy_from(&self, other: &DeviceTensor) {
        assert_eq!(self.dtype(), other.dtype(), "Tensors must have the same dtype");
        let dtype = self.dtype();
        assert_eq!(self.device(), other.device(), "Tensors must be on the same device");
        let device = self.device();

        assert_eq!(
            self.strided_shape().shape(),
            other.strided_shape().shape(),
            "Tensors must have the same shape: {:?} vs {:?}",
            self,
            other
        );

        if self.strided_shape() == other.strided_shape() && self.strided_shape().has_dense_strides() {
            // if the strides are dense and match we can just do a simple memcpy
            self.ptr()
                .copy_linear_from_device(&other.ptr(), self.strided_shape().size() * dtype.size().bytes())
        } else {
            // otherwise fall back to the scalar copy kernel, which handles the restriding
            let stream = CudaStream::new(device);
            self.copy_from_as_scalar_op(other).run(&stream);
            stream.synchronize();
        }
    }

    /// A (potentially) slower version of [Self::copy_simple_from_host] that works for any strides,
    /// by potentially copying through an intermediate stage on the device.
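    ///
    /// A minimal usage sketch, assuming `permuted` is an existing non-contiguous f32 view:
    ///
    /// ```ignore
    /// let host: Vec<f32> = vec![1.0; permuted.strided_shape().size()];
    /// unsafe { permuted.copy_from_host_staged(bytemuck::cast_slice(&host)) };
    /// ```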
    pub unsafe fn copy_from_host_staged(&self, buffer: &[u8]) {
        assert_eq!(self.dtype(), DType::F32, "Only f32 is supported for now");

        assert_eq!(
            buffer.len(),
            self.strided_shape().size() * self.dtype().size().bytes(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self.strided_shape(),
        );

        if self.strided_shape().has_simple_strides() {
            self.copy_simple_from_host(buffer);
        } else {
            let stage = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
            stage.copy_simple_from_host(buffer);
            self.copy_from(&stage);
        }
    }

    /// A (potentially) slower version of [Self::copy_simple_to_host] that works for any strides,
    /// by potentially copying through an intermediate stage on the device.
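    ///
    /// A minimal usage sketch, assuming `permuted` is an existing non-contiguous f32 view:
    ///
    /// ```ignore
    /// let mut host = vec![0f32; permuted.strided_shape().size()];
    /// unsafe { permuted.copy_to_host_staged(bytemuck::cast_slice_mut(&mut host)) };
    /// ```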
    pub unsafe fn copy_to_host_staged(&self, buffer: &mut [u8]) {
        assert_eq!(self.dtype(), DType::F32, "Only f32 is supported for now");

        assert_eq!(
            buffer.len(),
            self.strided_shape().size() * self.dtype().size().bytes(),
            "Wrong buffer size {} for {:?}",
            buffer.len(),
            self.strided_shape(),
        );

        if self.strided_shape().has_simple_strides() {
            self.copy_simple_to_host(buffer);
        } else {
            let stage = DeviceTensor::alloc_simple(self.device(), self.strided_shape().shape().to_vec(), self.dtype());
            stage.copy_from(self);
            stage.copy_simple_to_host(buffer);
        }
    }
}
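
// A rough sketch of how these pieces fit together for a host -> device -> host round trip.
// `CudaDevice::new(0)` is only a stand-in for however a device handle is obtained:
//
//     let device = CudaDevice::new(0);
//     let tensor = DeviceTensor::alloc_simple(device, vec![4, 4], DType::F32);
//
//     let input = vec![1.0f32; 16];
//     unsafe { tensor.copy_simple_from_host(bytemuck::cast_slice(&input)) };
//
//     let copy = tensor.deep_clone();
//
//     let mut output = vec![0.0f32; 16];
//     unsafe { copy.copy_simple_to_host(bytemuck::cast_slice_mut(&mut output)) };
//     assert_eq!(input, output);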