// kn_cuda_eval/autokernel/layernorm.rs

1use std::ptr::null_mut;
2
3use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
4use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
5use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
6use kn_cuda_sys::wrapper::status::Status;
7use kn_graph::dtype::DisplayCFloat;
8
9use crate::autokernel::common::{
10    c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
11};
12use crate::device_tensor::DeviceTensor;
13use crate::shape::StridedShape;
14
/// A layernorm CUDA kernel compiled (via NVRTC, with caching) for one fixed
/// combination of shapes, normalization axis, and scalar parameters.
///
/// All scalars and strides are substituted into the kernel source as
/// compile-time constants, so a distinct `LayernormKernel` is needed per
/// configuration.
#[derive(Debug)]
pub struct LayernormKernel {
    // strided shapes this kernel was specialized for; checked against the
    // tensors passed to `run`
    input_shape: StridedShape,
    output_shape: StridedShape,

    // axis that is normalized over (kept for Debug output only)
    _norm_axis: usize,
    // product of all non-norm dimensions; used to size the launch grid
    static_size: usize,

    // scalar parameters already baked into the compiled source; stored mostly
    // for Debug output (`_alpha1` is also consulted by `run`'s input1 check)
    _eps: f32,
    _alpha0: f32,
    _alpha1: f32,
    _beta: f32,

    // the compiled and loaded kernel function
    function: CuFunction,
}
30
31const LAYERNORM_SOURCE: &str = include_str!("layernorm.cu");
32
33impl LayernormKernel {
34    pub fn new(
35        device: CudaDevice,
36        input_shape: &StridedShape,
37        output_shape: &StridedShape,
38        norm_axis: usize,
39        eps: f32,
40        alpha_0: f32,
41        alpha_1: f32,
42        beta: f32,
43    ) -> Self {
44        assert_eq!(input_shape.shape(), output_shape.shape());
45
46        let norm_size = input_shape.shape()[norm_axis];
47        let static_size = input_shape.size() / norm_size;
48
49        let input_static = input_shape.remove(norm_axis);
50        let output_static = output_shape.remove(norm_axis);
51
52        let static_dense = StridedShape::new_simple(input_static.shape().to_vec());
53
54        let mut static_strides = [input_static.strides().to_vec(), output_static.strides().to_vec()];
55        let mut static_dense_strides = static_dense.strides().to_vec();
56
57        let norm_strides = [input_shape.strides()[norm_axis], output_shape.strides()[norm_axis]];
58
59        // pad arrays to ensure they never become zero-sized
60        static_strides[0].push(0);
61        static_strides[1].push(0);
62        static_dense_strides.push(1);
63
64        let replacements = vec![
65            ("$RANK$", format!("{}", input_shape.rank())),
66            ("$STATIC_SIZE$", format!("{}", static_size)),
67            ("$NORM_SIZE$", format!("{}", norm_size)),
68            ("$EPS$", format!("{}", DisplayCFloat(eps as f64))),
69            ("$ALPHA_0$", format!("{}", DisplayCFloat(alpha_0 as f64))),
70            ("$ALPHA_1$", format!("{}", DisplayCFloat(alpha_1 as f64))),
71            ("$BETA$", format!("{}", DisplayCFloat(beta as f64))),
72            ("$STATIC_DENSE_STRIDES$", c_array_string(&static_dense_strides)),
73            ("$STATIC_STRIDES$", c_nested_array_string(&static_strides)),
74            ("$NORM_STRIDES$", c_array_string(&norm_strides)),
75        ];
76
77        // compile the kernel
78        let source = fill_replacements(LAYERNORM_SOURCE, &replacements);
79        let key = KernelKey {
80            device,
81            source,
82            func_name: "layernorm_kernel".to_owned(),
83        };
84        let function = compile_cached_kernel(key);
85
86        // wrap everything up
87        LayernormKernel {
88            function,
89            input_shape: input_shape.clone(),
90            output_shape: output_shape.clone(),
91            _norm_axis: norm_axis,
92            static_size,
93            _eps: eps,
94            _alpha0: alpha_0,
95            _alpha1: alpha_1,
96            _beta: beta,
97        }
98    }
99
100    pub unsafe fn run(
101        &self,
102        stream: &CudaStream,
103        input0: &DeviceTensor,
104        input1: Option<&DeviceTensor>,
105        output: &DeviceTensor,
106    ) {
107        assert_eq!(input0.strided_shape(), &self.input_shape);
108        if let Some(input1) = input1 {
109            assert_eq!(input1.strided_shape(), &self.input_shape);
110        }
111        assert_eq!(output.strided_shape(), &self.output_shape);
112
113        if self._alpha1 != 0.0 {
114            assert_eq!(input1.is_some(), true);
115        }
116
117        let mut args = KernelArgs::new();
118        args.push(input0.ptr().ptr());
119        args.push(input1.map_or(null_mut(), |x| x.ptr().ptr()));
120        args.push(output.ptr().ptr());
121        let args = args.finish();
122
123        //TODO see if these settings make sense for the typically larger layernorm sizes
124
125        let warps = self.static_size;
126        let warps_per_block = 4;
127        let threads_per_warp = 32;
128
129        let threads_per_block = (threads_per_warp * warps_per_block) as u32;
130        let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block as u32);
131
132        self.function
133            .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
134            .unwrap();
135    }
136}