kn_cuda_eval/autokernel/
reduce.rs

1use itertools::Itertools;
2
3use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
4use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
5use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
6use kn_cuda_sys::wrapper::status::Status;
7
8use crate::autokernel::common::{
9    c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
10};
11use crate::device_tensor::DeviceTensor;
12use crate::shape::StridedShape;
13
/// A compiled CUDA kernel that reduces a set of axes of a strided input
/// tensor into an output tensor holding only the kept axes.
///
/// Built by [`ReduceKernel::new`] (which compiles `reduce.cu` with the given
/// code snippets substituted in) and launched with [`ReduceKernel::run`].
#[derive(Debug)]
pub struct ReduceKernel {
    // Retained only for `Debug` output; not read after compilation.
    _code: ReduceCode,
    _reduced_axes: Vec<usize>,

    /// The compiled `reduce_kernel` function, ready to launch.
    function: CuFunction,

    // Exact (strided) shapes the kernel was specialized for;
    // `run` asserts its arguments match these.
    input_shape: StridedShape,
    output_shape: StridedShape,
}
24
25const REDUCE_SOURCE: &str = include_str!("reduce.cu");
26
/// C source snippets substituted into the kernel template (`reduce.cu`)
/// to specialize it for a particular element type and reduction operation.
#[derive(Debug, Clone)]
pub struct ReduceCode {
    /// Element type name, substituted for `$TYPE$` (e.g. `"float"` — confirm against `reduce.cu`).
    pub ty: String,
    /// Initial accumulator value, substituted for `$IDENTITY$`.
    pub identity: String,
    /// Combining expression, substituted for `$OPERATION$`.
    pub operation: String,
    /// Final transform applied to the accumulated value, substituted for `$POST_PROCESS$`.
    pub post_process: String,
}
34
35impl ReduceKernel {
36    pub fn new(
37        device: CudaDevice,
38        code: ReduceCode,
39        input_shape: &StridedShape,
40        output_shape: &StridedShape,
41        reduced_axes: &[usize],
42    ) -> Self {
43        // check that axes are unique and in-bounds
44        assert!(
45            reduced_axes.iter().all_unique(),
46            "Reduced axes must be unique, got {:?}",
47            reduced_axes
48        );
49        for &axis in reduced_axes {
50            assert!(
51                axis < input_shape.rank(),
52                "Reduced axis out of bounds for shape {:?}",
53                input_shape
54            );
55        }
56
57        // split strides and shapes
58        let mut input_kept_shape = vec![];
59        let mut input_reduced_shape = vec![];
60        let mut input_kept_strides = vec![];
61        let mut input_reduced_strides = vec![];
62
63        for axis in 0..input_shape.rank() {
64            let size = input_shape.shape()[axis];
65            let stride = input_shape.strides()[axis];
66
67            if reduced_axes.contains(&axis) {
68                input_reduced_shape.push(size);
69                input_reduced_strides.push(stride);
70            } else {
71                input_kept_shape.push(size);
72                input_kept_strides.push(stride);
73            }
74        }
75
76        // check that things make sense
77        let kept_size: usize = input_kept_shape.iter().copied().product();
78        let reduction_size: usize = input_reduced_shape.iter().copied().product();
79
80        assert_eq!(input_kept_shape, output_shape.shape(), "Output shape mismatch");
81        assert_eq!(kept_size, output_shape.size());
82        assert_eq!(input_shape.size(), kept_size * reduction_size);
83
84        // build replacements
85        let kept_shape_dense = StridedShape::new_simple(input_kept_shape.clone());
86        let reduced_shape_dense = StridedShape::new_simple(input_reduced_shape.clone());
87
88        // pad arrays to ensure they never become zero-sized
89        let mut kept_stides_dense = kept_shape_dense.strides().to_vec();
90        kept_stides_dense.push(0);
91        input_kept_strides.push(0);
92        let mut output_kept_strides = output_shape.strides().to_vec();
93        output_kept_strides.push(0);
94
95        let replacements = vec![
96            ("$KEPT_RANK$", format!("{}", input_kept_shape.len())),
97            ("$REDUCED_RANK$", format!("{}", input_reduced_shape.len())),
98            ("$KEPT_SIZE$", format!("{}", kept_size)),
99            ("$REDUCTION_SIZE$", format!("{}", reduction_size)),
100            ("$KEPT_STRIDES_DENSE$", c_array_string(&kept_stides_dense)),
101            ("$REDUCED_STRIDES_DENSE$", c_array_string(reduced_shape_dense.strides())),
102            (
103                "$KEPT_STRIDES$",
104                c_nested_array_string(&[input_kept_strides, output_kept_strides]),
105            ),
106            ("$REDUCED_STRIDES$", c_array_string(&input_reduced_strides)),
107            ("$TYPE$", code.ty.clone()),
108            ("$IDENTITY$", code.identity.clone()),
109            ("$OPERATION$", code.operation.clone()),
110            ("$POST_PROCESS$", code.post_process.clone()),
111        ];
112
113        // compile the kernel
114        let source = fill_replacements(REDUCE_SOURCE, &replacements);
115        let key = KernelKey {
116            device,
117            source,
118            func_name: "reduce_kernel".to_owned(),
119        };
120        let function = compile_cached_kernel(key);
121
122        // wrap everything up
123        ReduceKernel {
124            function,
125            _code: code,
126            _reduced_axes: reduced_axes.to_owned(),
127            input_shape: input_shape.clone(),
128            output_shape: output_shape.clone(),
129        }
130    }
131
132    pub unsafe fn run(&self, stream: &CudaStream, input: &DeviceTensor, output: &DeviceTensor) {
133        assert_eq!(input.strided_shape(), &self.input_shape);
134        assert_eq!(output.strided_shape(), &self.output_shape);
135
136        let mut args = KernelArgs::new();
137        args.push(input.ptr().ptr());
138        args.push(output.ptr().ptr());
139        let args = args.finish();
140
141        let warps = self.output_shape.size();
142        // TODO see what the effect of increasing this is
143        let warps_per_block = 16;
144        let threads_per_warp = 32;
145
146        let threads_per_block = (threads_per_warp * warps_per_block) as u32;
147        let blocks = ceil_div((warps * threads_per_warp) as u32, threads_per_block as u32);
148
149        self.function
150            .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
151            .unwrap();
152    }
153}