kn_cuda_eval/autokernel/
softmax.rs1use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
2use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
3use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
4use kn_cuda_sys::wrapper::status::Status;
5use kn_graph::dtype::DisplayCFloat;
6
7use crate::autokernel::common::{
8 c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
9};
10use crate::device_tensor::DeviceTensor;
11use crate::shape::StridedShape;
12
13#[derive(Debug)]
14pub struct SoftmaxKernel {
15 input_shape: StridedShape,
16 output_shape: StridedShape,
17
18 _softmax_axis: usize,
19 _input_scale: f32,
20 static_size: usize,
21
22 function: CuFunction,
23}
24
25const SOFTMAX_SOURCE: &str = include_str!("softmax.cu");
26
27impl SoftmaxKernel {
28 pub fn new(
29 device: CudaDevice,
30 input_shape: &StridedShape,
31 output_shape: &StridedShape,
32 softmax_axis: usize,
33 input_scale: f32,
34 ) -> Self {
35 assert_eq!(input_shape.shape(), output_shape.shape());
36
37 let softmax_size = input_shape.shape()[softmax_axis];
38 let static_size = input_shape.size() / softmax_size;
39
40 let input_static = input_shape.remove(softmax_axis);
41 let output_static = output_shape.remove(softmax_axis);
42
43 let static_dense = StridedShape::new_simple(input_static.shape().to_vec());
44
45 let mut static_strides = [input_static.strides().to_vec(), output_static.strides().to_vec()];
46 let mut static_dense_strides = static_dense.strides().to_vec();
47
48 let softmax_strides = [
49 input_shape.strides()[softmax_axis],
50 output_shape.strides()[softmax_axis],
51 ];
52
53 static_strides[0].push(0);
55 static_strides[1].push(0);
56 static_dense_strides.push(1);
57
58 let replacements = vec![
59 ("$RANK$", format!("{}", input_shape.rank())),
60 ("$STATIC_SIZE$", format!("{}", static_size)),
61 ("$SOFTMAX_SIZE$", format!("{}", softmax_size)),
62 ("$INPUT_SCALE$", format!("{}", DisplayCFloat(input_scale as f64))),
63 ("$STATIC_DENSE_STRIDES$", c_array_string(&static_dense_strides)),
64 ("$STATIC_STRIDES$", c_nested_array_string(&static_strides)),
65 ("$SOFTMAX_STRIDES$", c_array_string(&softmax_strides)),
66 ];
67
68 let source = fill_replacements(SOFTMAX_SOURCE, &replacements);
70 let key = KernelKey {
71 device,
72 source,
73 func_name: "softmax_kernel".to_owned(),
74 };
75 let function = compile_cached_kernel(key);
76
77 SoftmaxKernel {
79 function,
80 input_shape: input_shape.clone(),
81 output_shape: output_shape.clone(),
82 _softmax_axis: softmax_axis,
83 _input_scale: input_scale,
84 static_size,
85 }
86 }
87
88 pub unsafe fn run(&self, stream: &CudaStream, input: &DeviceTensor, output: &DeviceTensor) {
89 assert_eq!(input.strided_shape(), &self.input_shape);
90 assert_eq!(output.strided_shape(), &self.output_shape);
91
92 let mut args = KernelArgs::new();
93 args.push(input.ptr().ptr());
94 args.push(output.ptr().ptr());
95 let args = args.finish();
96
97 let warps = self.static_size;
98 let warps_per_block = 4;
99 let threads_per_warp = 32;
100
101 let threads_per_block = (threads_per_warp * warps_per_block) as u32;
102 let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block as u32);
103
104 self.function
105 .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
106 .unwrap();
107 }
108}