kn_cuda_eval/autokernel/layernorm.rs

use std::ptr::null_mut;

use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
use kn_cuda_sys::wrapper::status::Status;
use kn_graph::dtype::DisplayCFloat;

use crate::autokernel::common::{
    c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
};
use crate::device_tensor::DeviceTensor;
use crate::shape::StridedShape;

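/// A runtime-compiled layernorm kernel, specialized for a fixed pair of
/// input/output strided shapes, a normalization axis, and the scalar parameters
/// `eps`, `alpha_0`, `alpha_1`, and `beta`, all of which are baked into the
/// generated CUDA source.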
#[derive(Debug)]
pub struct LayernormKernel {
    input_shape: StridedShape,
    output_shape: StridedShape,

    _norm_axis: usize,
    static_size: usize,

    _eps: f32,
    _alpha0: f32,
    _alpha1: f32,
    _beta: f32,

    function: CuFunction,
}

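/// CUDA source template; the `$NAME$` placeholders in it are substituted with
/// concrete values for each kernel instance before compilation.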
const LAYERNORM_SOURCE: &str = include_str!("layernorm.cu");

impl LayernormKernel {
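    /// Build a layernorm kernel specialized for the given shapes, normalization
    /// axis, and scalar parameters. Compilation goes through
    /// `compile_cached_kernel`, so identical specializations share one build.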
    pub fn new(
        device: CudaDevice,
        input_shape: &StridedShape,
        output_shape: &StridedShape,
        norm_axis: usize,
        eps: f32,
        alpha_0: f32,
        alpha_1: f32,
        beta: f32,
    ) -> Self {
        assert_eq!(input_shape.shape(), output_shape.shape());

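        // Split the total element count into the normalized axis (`norm_size`)
        // and the product of all remaining axes (`static_size`).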
        let norm_size = input_shape.shape()[norm_axis];
        let static_size = input_shape.size() / norm_size;

        let input_static = input_shape.remove(norm_axis);
        let output_static = output_shape.remove(norm_axis);

        let static_dense = StridedShape::new_simple(input_static.shape().to_vec());

        let mut static_strides = [input_static.strides().to_vec(), output_static.strides().to_vec()];
        let mut static_dense_strides = static_dense.strides().to_vec();

        let norm_strides = [input_shape.strides()[norm_axis], output_shape.strides()[norm_axis]];

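        // Pad the operand strides with a trailing 0 and the dense strides with a
        // trailing 1, so every array substituted into the template has exactly
        // $RANK$ entries; the extra stride-0 axis contributes no offset.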
        static_strides[0].push(0);
        static_strides[1].push(0);
        static_dense_strides.push(1);

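        // Bake every shape and scalar parameter into the CUDA source as a
        // literal; `DisplayCFloat` renders the float constants in C-compatible
        // syntax.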
        let replacements = vec![
            ("$RANK$", format!("{}", input_shape.rank())),
            ("$STATIC_SIZE$", format!("{}", static_size)),
            ("$NORM_SIZE$", format!("{}", norm_size)),
            ("$EPS$", format!("{}", DisplayCFloat(eps as f64))),
            ("$ALPHA_0$", format!("{}", DisplayCFloat(alpha_0 as f64))),
            ("$ALPHA_1$", format!("{}", DisplayCFloat(alpha_1 as f64))),
            ("$BETA$", format!("{}", DisplayCFloat(beta as f64))),
            ("$STATIC_DENSE_STRIDES$", c_array_string(&static_dense_strides)),
            ("$STATIC_STRIDES$", c_nested_array_string(&static_strides)),
            ("$NORM_STRIDES$", c_array_string(&norm_strides)),
        ];

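        // Instantiate the template and compile it, or fetch a previously cached
        // build for the same key.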
        let source = fill_replacements(LAYERNORM_SOURCE, &replacements);
        let key = KernelKey {
            device,
            source,
            func_name: "layernorm_kernel".to_owned(),
        };
        let function = compile_cached_kernel(key);

        LayernormKernel {
            function,
            input_shape: input_shape.clone(),
            output_shape: output_shape.clone(),
            _norm_axis: norm_axis,
            static_size,
            _eps: eps,
            _alpha0: alpha_0,
            _alpha1: alpha_1,
            _beta: beta,
        }
    }

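    /// Launch the kernel on `stream`, reading `input0` (and `input1`, if the
    /// kernel was built with a nonzero `alpha_1`) and writing `output`.
    ///
    /// # Safety
    /// The tensors must be valid device allocations on the same device as
    /// `stream` and must remain alive until the launched kernel has finished.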
    pub unsafe fn run(
        &self,
        stream: &CudaStream,
        input0: &DeviceTensor,
        input1: Option<&DeviceTensor>,
        output: &DeviceTensor,
    ) {
        assert_eq!(input0.strided_shape(), &self.input_shape);
        if let Some(input1) = input1 {
            assert_eq!(input1.strided_shape(), &self.input_shape);
        }
        assert_eq!(output.strided_shape(), &self.output_shape);

        if self._alpha1 != 0.0 {
            assert!(input1.is_some());
        }

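        // A missing second input is passed as a null pointer; the guard above
        // guarantees this only happens when `alpha_1` is zero, so the generated
        // kernel should never need to read through it.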
        let mut args = KernelArgs::new();
        args.push(input0.ptr().ptr());
        args.push(input1.map_or(null_mut(), |x| x.ptr().ptr()));
        args.push(output.ptr().ptr());
        let args = args.finish();

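        // Launch one warp per static slice, grouped into blocks of 4 warps
        // (128 threads each).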
        let warps = self.static_size;
        let warps_per_block = 4;
        let threads_per_warp = 32;

        let threads_per_block = (threads_per_warp * warps_per_block) as u32;
        let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block);

        self.function
            .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, stream, &args)
            .unwrap();
    }
}