kn_cuda_eval/autokernel/scalar.rs

use std::fmt::Write;

use itertools::zip_eq;

use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
use kn_cuda_sys::wrapper::status::Status;

use crate::autokernel::common::{
    c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
};
use crate::device_tensor::DeviceTensor;
use crate::shape::StridedShape;

/// An instance of a scalar/elementwise kernel. Can be built for any operation, rank, strides, and sizes.
/// The first axis is runtime-dynamic: its size can change between launches without recompiling the kernel.
#[derive(Debug)]
pub struct ScalarKernel {
    #[allow(dead_code)]
    operation: String,

    inner_size: usize,
    inner_shape: Vec<usize>,
    operand_types: Vec<String>,
    operand_strides: Vec<Vec<isize>>,

    function: CuFunction,
}

const SCALAR_SOURCE: &str = include_str!("scalar.cu");

impl ScalarKernel {
    /// Compile a new scalar kernel instance.
    ///
    /// `operation` operates on the operand pointers `x0`, `x1`, ... and has the format `*x0 = *x1 + *x2;`.
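    ///
    /// A minimal usage sketch (illustrative only: the `device` handle and the shapes
    /// are assumptions, not taken from the rest of this crate):
    ///
    /// ```ignore
    /// // Elementwise add `x0 = x1 + x2` over [batch, 8, 3] tensors, with all three
    /// // operands stored contiguously (strides [24, 3, 1], batch stride first).
    /// let kernel = ScalarKernel::new(
    ///     device,
    ///     "*x0 = *x1 + *x2;",
    ///     vec![8, 3],
    ///     vec!["float".to_owned(); 3],
    ///     vec![vec![24, 3, 1]; 3],
    /// );
    /// ```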
    pub fn new(
        device: CudaDevice,
        operation: &str,
        inner_shape: Vec<usize>,
        operand_types: Vec<String>,
        operand_strides: Vec<Vec<isize>>,
    ) -> Self {
        // TODO try to simplify shape and operand strides if they are contiguous
        assert!(operand_types.len() > 0);
        assert_eq!(operand_strides.len(), operand_types.len());
        for stride in &operand_strides {
            assert_eq!(stride.len(), inner_shape.len() + 1);
        }

        assert!(
            operation.trim_end().ends_with(";"),
            "Operation should end with ';', got {:?}",
            operation
        );

        let full_operation = build_operation(&operand_types, operation);

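        // Prepend a placeholder batch dimension; its extent doesn't matter here, since
        // only the rank and the contiguous strides of `dense` are substituted into the
        // template, and the real batch size is passed at launch time in `run_custom`.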
        let mut full_shape = vec![0];
        full_shape.extend_from_slice(&inner_shape);

        let dense = StridedShape::new_simple(full_shape.to_vec());

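        // Substitute the template parameters in scalar.cu: the rank, the number of
        // operands, the dense strides of the full shape, the per-operand strides,
        // and the operation body built above.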
        let replacements = vec![
            ("$RANK$", format!("{}", dense.rank())),
            ("$OPERANDS$", format!("{}", operand_strides.len())),
            ("$STRIDES_DENSE$", c_array_string(dense.strides())),
            ("$STRIDES$", c_nested_array_string(&operand_strides)),
            ("$OPERATION$", full_operation.to_owned()),
        ];
        let source = fill_replacements(SCALAR_SOURCE, &replacements);

        let key = KernelKey {
            device,
            source,
            func_name: "scalar_kernel".to_owned(),
        };

        let function = compile_cached_kernel(key);
        let inner_size = inner_shape.iter().product();

        ScalarKernel {
            operation: operation.to_owned(),
            function,
            inner_size,
            inner_shape,
            operand_types,
            operand_strides,
        }
    }

    /// Wrapper around [Self::new] that's a bit easier to use if you know the full shape of the operands up front.
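    ///
    /// A minimal usage sketch (illustrative shapes; the `device` handle is assumed to exist):
    ///
    /// ```ignore
    /// // Copy x1 into x0 for two contiguous [16, 8, 3] tensors.
    /// let a = StridedShape::new_simple(vec![16, 8, 3]);
    /// let b = StridedShape::new_simple(vec![16, 8, 3]);
    /// let kernel = ScalarKernel::new_for_shapes(
    ///     device,
    ///     "*x0 = *x1;",
    ///     &[a, b],
    ///     vec!["float".to_owned(); 2],
    /// );
    /// ```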
    pub fn new_for_shapes(
        device: CudaDevice,
        operation: &str,
        shapes: &[StridedShape],
        operand_types: Vec<String>,
    ) -> Self {
        assert!(shapes.len() > 0);
        let expected_shape = shapes[0].shape();
        assert!(expected_shape.len() > 0);

        for shape in shapes {
            assert_eq!(shape.shape(), expected_shape);
        }

        let inner_shape = shapes[0].shape()[1..].to_vec();
        let operand_strides = shapes.iter().map(|s| s.strides().to_vec()).collect();

        Self::new(device, operation, inner_shape, operand_types, operand_strides)
    }

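    /// Launch the kernel on `tensors` using default launch settings
    /// (64 items per thread, 64 threads per block). Use [Self::run_custom]
    /// to choose these values yourself.
    ///
    /// # Safety
    /// The usual CUDA requirements apply: `tensors` must be live device allocations
    /// usable on `stream`. Shapes and strides are checked against the values this
    /// kernel was compiled for.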
    pub unsafe fn run(&self, stream: &CudaStream, tensors: &[DeviceTensor]) {
        let items_per_thread = 64;
        let threads_per_block = 64;
        self.run_custom(stream, tensors, items_per_thread, threads_per_block);
    }

    pub unsafe fn run_custom(
        &self,
        stream: &CudaStream,
        tensors: &[DeviceTensor],
        items_per_thread: u32,
        threads_per_block: u32,
    ) {
        assert_eq!(tensors.len(), self.operand_types.len());
        // TODO verify tensor types once that's implemented

        let batch_size = tensors[0].strided_shape().shape()[0];

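        // Kernel arguments: the runtime batch size first, then one device pointer per
        // operand, each checked against the shape and strides this kernel was compiled for.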
        let mut args = KernelArgs::new();
        args.push_int(batch_size as i32);

        for (expected_strides, tensor) in zip_eq(&self.operand_strides, tensors) {
            assert_eq!(1 + self.inner_shape.len(), tensor.strided_shape().rank());
            assert_eq!(batch_size, tensor.strided_shape().shape()[0]);
            assert_eq!(self.inner_shape, tensor.strided_shape().shape()[1..]);
            assert_eq!(expected_strides, tensor.strided_shape().strides());

            args.push(tensor.ptr().ptr());
        }

        let args = args.finish();

        let items = batch_size * self.inner_size;

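        // Each block covers `threads_per_block * items_per_thread` elements, so round up.
        // With the defaults (64 * 64 = 4096) and e.g. 100_000 items this launches
        // ceil(100_000 / 4096) = 25 blocks.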
        let blocks = ceil_div(items as u32, items_per_thread * threads_per_block);

        // TODO cache all of this so we just have to call launch_kernel at the end?
        self.function
            .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
            .unwrap();
    }
}

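/// Build the full operation body: for every operand, declare a typed pointer `xN`
/// into the corresponding buffer at `offsets[N]`, then append the user-provided
/// operation.
///
/// For example, two `float` operands with the operation `*x0 = *x1;` would produce:
///
/// ```text
/// float *x0 = &((float *) pointers[0])[offsets[0]];
/// float *x1 = &((float *) pointers[1])[offsets[1]];
/// *x0 = *x1;
/// ```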
fn build_operation(operand_types: &[String], operation: &str) -> String {
    let mut full_operation = String::new();
    let f = &mut full_operation;

    for (i, ty) in operand_types.iter().enumerate() {
        writeln!(
            f,
            "{ty} *x{i} = &(({ty} *) pointers[{i}])[offsets[{i}]];",
            ty = ty,
            i = i
        )
        .unwrap();
    }
    writeln!(f, "{}", operation).unwrap();

    full_operation
}