// kn_cuda_eval/autokernel/reduce.rs
use itertools::Itertools;
2
3use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
4use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
5use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
6use kn_cuda_sys::wrapper::status::Status;
7
8use crate::autokernel::common::{
9 c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
10};
11use crate::device_tensor::DeviceTensor;
12use crate::shape::StridedShape;
13
/// A compiled, shape-specialized CUDA reduction kernel.
///
/// Built once via `ReduceKernel::new` (which generates and compiles the CUDA
/// source) and then launched repeatedly with `ReduceKernel::run`. The input
/// and output shapes are baked into the generated kernel at compile time.
#[derive(Debug)]
pub struct ReduceKernel {
    // Unused at runtime (underscore-prefixed); retained for `Debug` output,
    // since the code fragments are already baked into `function`.
    _code: ReduceCode,
    _reduced_axes: Vec<usize>,

    // The compiled, cached CUDA function to launch.
    function: CuFunction,

    // Exact shapes this kernel was specialized for; `run` asserts that its
    // tensor arguments match these.
    input_shape: StridedShape,
    output_shape: StridedShape,
}
24
// CUDA source template for the reduction kernel; the `$…$` placeholders are
// substituted by `fill_replacements` in `ReduceKernel::new`.
const REDUCE_SOURCE: &str = include_str!("reduce.cu");
26
/// C code fragments spliced into the `reduce.cu` template to define the
/// reduction's element type and semantics.
#[derive(Debug, Clone)]
pub struct ReduceCode {
    // Element type, substituted for `$TYPE$` (presumably a C scalar type
    // name such as "float" — TODO confirm against reduce.cu).
    pub ty: String,
    // Identity value of the reduction, substituted for `$IDENTITY$`.
    pub identity: String,
    // Combining operation, substituted for `$OPERATION$`.
    pub operation: String,
    // Transformation applied after the reduction, substituted for
    // `$POST_PROCESS$`.
    pub post_process: String,
}
34
35impl ReduceKernel {
36 pub fn new(
37 device: CudaDevice,
38 code: ReduceCode,
39 input_shape: &StridedShape,
40 output_shape: &StridedShape,
41 reduced_axes: &[usize],
42 ) -> Self {
43 assert!(
45 reduced_axes.iter().all_unique(),
46 "Reduced axes must be unique, got {:?}",
47 reduced_axes
48 );
49 for &axis in reduced_axes {
50 assert!(
51 axis < input_shape.rank(),
52 "Reduced axis out of bounds for shape {:?}",
53 input_shape
54 );
55 }
56
57 let mut input_kept_shape = vec![];
59 let mut input_reduced_shape = vec![];
60 let mut input_kept_strides = vec![];
61 let mut input_reduced_strides = vec![];
62
63 for axis in 0..input_shape.rank() {
64 let size = input_shape.shape()[axis];
65 let stride = input_shape.strides()[axis];
66
67 if reduced_axes.contains(&axis) {
68 input_reduced_shape.push(size);
69 input_reduced_strides.push(stride);
70 } else {
71 input_kept_shape.push(size);
72 input_kept_strides.push(stride);
73 }
74 }
75
76 let kept_size: usize = input_kept_shape.iter().copied().product();
78 let reduction_size: usize = input_reduced_shape.iter().copied().product();
79
80 assert_eq!(input_kept_shape, output_shape.shape(), "Output shape mismatch");
81 assert_eq!(kept_size, output_shape.size());
82 assert_eq!(input_shape.size(), kept_size * reduction_size);
83
84 let kept_shape_dense = StridedShape::new_simple(input_kept_shape.clone());
86 let reduced_shape_dense = StridedShape::new_simple(input_reduced_shape.clone());
87
88 let mut kept_stides_dense = kept_shape_dense.strides().to_vec();
90 kept_stides_dense.push(0);
91 input_kept_strides.push(0);
92 let mut output_kept_strides = output_shape.strides().to_vec();
93 output_kept_strides.push(0);
94
95 let replacements = vec![
96 ("$KEPT_RANK$", format!("{}", input_kept_shape.len())),
97 ("$REDUCED_RANK$", format!("{}", input_reduced_shape.len())),
98 ("$KEPT_SIZE$", format!("{}", kept_size)),
99 ("$REDUCTION_SIZE$", format!("{}", reduction_size)),
100 ("$KEPT_STRIDES_DENSE$", c_array_string(&kept_stides_dense)),
101 ("$REDUCED_STRIDES_DENSE$", c_array_string(reduced_shape_dense.strides())),
102 (
103 "$KEPT_STRIDES$",
104 c_nested_array_string(&[input_kept_strides, output_kept_strides]),
105 ),
106 ("$REDUCED_STRIDES$", c_array_string(&input_reduced_strides)),
107 ("$TYPE$", code.ty.clone()),
108 ("$IDENTITY$", code.identity.clone()),
109 ("$OPERATION$", code.operation.clone()),
110 ("$POST_PROCESS$", code.post_process.clone()),
111 ];
112
113 let source = fill_replacements(REDUCE_SOURCE, &replacements);
115 let key = KernelKey {
116 device,
117 source,
118 func_name: "reduce_kernel".to_owned(),
119 };
120 let function = compile_cached_kernel(key);
121
122 ReduceKernel {
124 function,
125 _code: code,
126 _reduced_axes: reduced_axes.to_owned(),
127 input_shape: input_shape.clone(),
128 output_shape: output_shape.clone(),
129 }
130 }
131
132 pub unsafe fn run(&self, stream: &CudaStream, input: &DeviceTensor, output: &DeviceTensor) {
133 assert_eq!(input.strided_shape(), &self.input_shape);
134 assert_eq!(output.strided_shape(), &self.output_shape);
135
136 let mut args = KernelArgs::new();
137 args.push(input.ptr().ptr());
138 args.push(output.ptr().ptr());
139 let args = args.finish();
140
141 let warps = self.output_shape.size();
142 let warps_per_block = 16;
144 let threads_per_warp = 32;
145
146 let threads_per_block = (threads_per_warp * warps_per_block) as u32;
147 let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block as u32);
148
149 self.function
150 .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
151 .unwrap();
152 }
153}