// kn_cuda_eval/autokernel/scalar.rs
use std::fmt::Write;
2
3use itertools::zip_eq;
4
5use kn_cuda_sys::wrapper::handle::{CudaDevice, CudaStream};
6use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
7use kn_cuda_sys::wrapper::rtc::core::{CuFunction, Dim3};
8use kn_cuda_sys::wrapper::status::Status;
9
10use crate::autokernel::common::{
11 c_array_string, c_nested_array_string, ceil_div, compile_cached_kernel, fill_replacements, KernelKey,
12};
13use crate::device_tensor::DeviceTensor;
14use crate::shape::StridedShape;
15
/// A compiled elementwise ("scalar") CUDA kernel: a user-supplied C operation
/// applied once per element, specialized for a fixed operand layout.
#[derive(Debug)]
pub struct ScalarKernel {
    // The raw operation string passed by the user; retained only for debugging.
    #[allow(dead_code)]
    operation: String,

    // Number of elements per batch item (product of `inner_shape`).
    inner_size: usize,
    // Operand shape excluding the leading batch axis.
    inner_shape: Vec<usize>,
    // C type name for each operand (e.g. "float"); one entry per operand.
    operand_types: Vec<String>,
    // Per-operand strides, batch axis first — length is `inner_shape.len() + 1`.
    operand_strides: Vec<Vec<isize>>,

    // The compiled (and cached) CUDA function; launched by `run`/`run_custom`.
    function: CuFunction,
}
30
// CUDA kernel template; `$…$` placeholders are substituted in `ScalarKernel::new`.
const SCALAR_SOURCE: &str = include_str!("scalar.cu");
32
33impl ScalarKernel {
34 pub fn new(
38 device: CudaDevice,
39 operation: &str,
40 inner_shape: Vec<usize>,
41 operand_types: Vec<String>,
42 operand_strides: Vec<Vec<isize>>,
43 ) -> Self {
44 assert!(operand_types.len() > 0);
46 assert_eq!(operand_strides.len(), operand_types.len());
47 for stride in &operand_strides {
48 assert_eq!(stride.len(), inner_shape.len() + 1);
49 }
50
51 assert!(
52 operation.trim_end().ends_with(";"),
53 "Operation should end with ';', got {:?}",
54 operation
55 );
56
57 let full_operation = build_operation(&operand_types, operation);
58
59 let mut full_shape = vec![0];
60 full_shape.extend_from_slice(&inner_shape);
61
62 let dense = StridedShape::new_simple(full_shape.to_vec());
63
64 let replacements = vec![
65 ("$RANK$", format!("{}", dense.rank())),
66 ("$OPERANDS$", format!("{}", operand_strides.len())),
67 ("$STRIDES_DENSE$", c_array_string(dense.strides())),
68 ("$STRIDES$", c_nested_array_string(&operand_strides)),
69 ("$OPERATION$", full_operation.to_owned()),
70 ];
71 let source = fill_replacements(SCALAR_SOURCE, &replacements);
72
73 let key = KernelKey {
74 device,
75 source,
76 func_name: "scalar_kernel".to_owned(),
77 };
78
79 let function = compile_cached_kernel(key);
80 let inner_size = inner_shape.iter().product();
81
82 ScalarKernel {
83 operation: operation.to_owned(),
84 function,
85 inner_size,
86 inner_shape,
87 operand_types,
88 operand_strides,
89 }
90 }
91
92 pub fn new_for_shapes(
94 device: CudaDevice,
95 operation: &str,
96 shapes: &[StridedShape],
97 operand_types: Vec<String>,
98 ) -> Self {
99 assert!(shapes.len() > 0);
100 let expected_shape = shapes[0].shape();
101 assert!(expected_shape.len() > 0);
102
103 for shape in shapes {
104 assert_eq!(shape.shape(), expected_shape);
105 }
106
107 let inner_shape = shapes[0].shape()[1..].to_vec();
108 let operand_strides = shapes.iter().map(|s| s.strides().to_vec()).collect();
109
110 Self::new(device, operation, inner_shape, operand_types, operand_strides)
111 }
112
113 pub unsafe fn run(&self, stream: &CudaStream, tensors: &[DeviceTensor]) {
114 let items_per_thread = 64;
115 let threads_per_block = 64;
116 self.run_custom(stream, tensors, items_per_thread, threads_per_block);
117 }
118
119 pub unsafe fn run_custom(
120 &self,
121 stream: &CudaStream,
122 tensors: &[DeviceTensor],
123 items_per_thread: u32,
124 threads_per_block: u32,
125 ) {
126 assert_eq!(tensors.len(), self.operand_types.len());
127 let batch_size = tensors[0].strided_shape().shape()[0];
130
131 let mut args = KernelArgs::new();
132 args.push_int(batch_size as i32);
133
134 for (expected_strides, tensor) in zip_eq(&self.operand_strides, tensors) {
135 assert_eq!(1 + self.inner_shape.len(), tensor.strided_shape().rank());
136 assert_eq!(batch_size, tensor.strided_shape().shape()[0]);
137 assert_eq!(self.inner_shape, tensor.strided_shape().shape()[1..]);
138 assert_eq!(expected_strides, tensor.strided_shape().strides());
139
140 args.push(tensor.ptr().ptr());
141 }
142
143 let args = args.finish();
144
145 let items = batch_size * self.inner_size;
146
147 let blocks = ceil_div(items as u32, items_per_thread * threads_per_block);
148
149 self.function
151 .launch_kernel(Dim3::single(blocks), Dim3::single(threads_per_block), 0, &stream, &args)
152 .unwrap();
153 }
154}
155
/// Assemble the C snippet spliced into the kernel template as `$OPERATION$`.
///
/// For every operand `i`, a typed pointer `x{i}` is declared that points at the
/// element selected by the template-provided `pointers`/`offsets` arrays; the
/// user-supplied `operation` statement then follows on its own line.
fn build_operation(operand_types: &[String], operation: &str) -> String {
    // One declaration line per operand, concatenated in order.
    let declarations: String = operand_types
        .iter()
        .enumerate()
        .map(|(i, ty)| {
            format!(
                "{ty} *x{i} = &(({ty} *) pointers[{i}])[offsets[{i}]];\n",
                ty = ty,
                i = i
            )
        })
        .collect();

    let mut result = declarations;
    result.push_str(operation);
    result.push('\n');
    result
}