// kn_cuda_eval/autokernel/softmax.rs

1use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
2use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
3use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
4use kn_cuda_sys::wrapper::status::Status;
5use kn_graph::dtype::DisplayCFloat;
6
7use crate::autokernel::common::{
8    c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
9};
10use crate::device_tensor::DeviceTensor;
11use crate::shape::StridedShape;
12
/// A compiled CUDA softmax kernel, specialized at compile time for a fixed
/// pair of input/output strided shapes, a softmax axis, and an input scale.
#[derive(Debug)]
pub struct SoftmaxKernel {
    // Shapes this kernel was compiled for; `run` asserts that the tensors
    // passed at launch time match exactly.
    input_shape: StridedShape,
    output_shape: StridedShape,

    // Kept only for `Debug` output — the values themselves are baked into
    // the generated kernel source in `new`.
    _softmax_axis: usize,
    _input_scale: f32,
    // Number of independent softmax lanes: total element count divided by
    // the size of the softmax axis. Used to size the launch grid in `run`.
    static_size: usize,

    // The compiled (and cache-shared) CUDA function.
    function: CuFunction,
}
24
// Kernel template; the `$...$` placeholders are substituted in `SoftmaxKernel::new`.
const SOFTMAX_SOURCE: &str = include_str!("softmax.cu");
26
27impl SoftmaxKernel {
28    pub fn new(
29        device: CudaDevice,
30        input_shape: &StridedShape,
31        output_shape: &StridedShape,
32        softmax_axis: usize,
33        input_scale: f32,
34    ) -> Self {
35        assert_eq!(input_shape.shape(), output_shape.shape());
36
37        let softmax_size = input_shape.shape()[softmax_axis];
38        let static_size = input_shape.size() / softmax_size;
39
40        let input_static = input_shape.remove(softmax_axis);
41        let output_static = output_shape.remove(softmax_axis);
42
43        let static_dense = StridedShape::new_simple(input_static.shape().to_vec());
44
45        let mut static_strides = [input_static.strides().to_vec(), output_static.strides().to_vec()];
46        let mut static_dense_strides = static_dense.strides().to_vec();
47
48        let softmax_strides = [
49            input_shape.strides()[softmax_axis],
50            output_shape.strides()[softmax_axis],
51        ];
52
53        // pad arrays to ensure they never become zero-sized
54        static_strides[0].push(0);
55        static_strides[1].push(0);
56        static_dense_strides.push(1);
57
58        let replacements = vec![
59            ("$RANK$", format!("{}", input_shape.rank())),
60            ("$STATIC_SIZE$", format!("{}", static_size)),
61            ("$SOFTMAX_SIZE$", format!("{}", softmax_size)),
62            ("$INPUT_SCALE$", format!("{}", DisplayCFloat(input_scale as f64))),
63            ("$STATIC_DENSE_STRIDES$", c_array_string(&static_dense_strides)),
64            ("$STATIC_STRIDES$", c_nested_array_string(&static_strides)),
65            ("$SOFTMAX_STRIDES$", c_array_string(&softmax_strides)),
66        ];
67
68        // compile the kernel
69        let source = fill_replacements(SOFTMAX_SOURCE, &replacements);
70        let key = KernelKey {
71            device,
72            source,
73            func_name: "softmax_kernel".to_owned(),
74        };
75        let function = compile_cached_kernel(key);
76
77        // wrap everything up
78        SoftmaxKernel {
79            function,
80            input_shape: input_shape.clone(),
81            output_shape: output_shape.clone(),
82            _softmax_axis: softmax_axis,
83            _input_scale: input_scale,
84            static_size,
85        }
86    }
87
88    pub unsafe fn run(&self, stream: &CudaStream, input: &DeviceTensor, output: &DeviceTensor) {
89        assert_eq!(input.strided_shape(), &self.input_shape);
90        assert_eq!(output.strided_shape(), &self.output_shape);
91
92        let mut args = KernelArgs::new();
93        args.push(input.ptr().ptr());
94        args.push(output.ptr().ptr());
95        let args = args.finish();
96
97        let warps = self.static_size;
98        let warps_per_block = 4;
99        let threads_per_warp = 32;
100
101        let threads_per_block = (threads_per_warp * warps_per_block) as u32;
102        let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block as u32);
103
104        self.function
105            .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
106            .unwrap();
107    }
108}