kn_cuda_eval/autokernel/
common.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use itertools::Itertools;
5use lazy_static::lazy_static;
6
7use kn_cuda_sys::wrapper::handle::CudaDevice;
8use kn_cuda_sys::wrapper::rtc::core::{CuFunction, CuModule};
9
10lazy_static! {
15 static ref KERNEL_CACHE: Mutex<HashMap<KernelKey, CuFunction>> = Mutex::new(HashMap::new());
16 static ref HEADERS: HashMap<&'static str, &'static str> = {
17 let mut map = HashMap::new();
18 map.insert("util.cu", include_str!("util.cu"));
19 map
20 };
21}
22
23#[derive(Debug, Eq, PartialEq, Hash)]
24pub struct KernelKey {
25 pub device: CudaDevice,
26 pub source: String,
27 pub func_name: String,
28}
29
30pub fn compile_cached_kernel(key: KernelKey) -> CuFunction {
31 let mut cache = KERNEL_CACHE.lock().unwrap();
33
34 let func = cache.entry(key).or_insert_with_key(|key| {
35 let module = CuModule::from_source(key.device, &key.source, None, &[&key.func_name], &HEADERS);
36
37 if !module.log.is_empty() {
38 let source_numbered = module.source_with_line_numbers();
39 eprintln!("Kernel source:\n{}\nLog:\n{}\n", source_numbered, module.log);
40 }
41
42 module.get_function_by_name(&key.func_name).unwrap().unwrap()
43 });
44
45 func.clone()
46}
47
/// Substitute every `$KEY$` placeholder in `src` with its value and return the filled text.
///
/// Replacements are applied in order. Each key replaces all of its occurrences in the text
/// produced by the previous replacements.
///
/// # Panics
///
/// * if a key does not both start and end with `$`;
/// * if a key does not occur in the text at the point it is applied;
/// * if any `$` remains after all replacements. The partially filled source is printed to
///   stderr before panicking.
pub fn fill_replacements(src: &str, replacements: &[(&str, String)]) -> String {
    let mut filled = src.to_owned();

    for &(key, ref value) in replacements {
        assert!(
            key.starts_with('$') && key.ends_with('$'),
            "Key '{}' should start and end with '$'",
            key
        );
        assert!(filled.contains(key), "Source does not contain key '{}'", key);
        filled = filled.replace(key, value);
    }

    // Any leftover `$` means a placeholder was never filled in.
    let has_leftover_placeholder = filled.contains('$');
    if has_leftover_placeholder {
        eprintln!("Source after replacements:\n{}", filled);
        panic!("Source still contains $");
    }

    filled
}
66
67pub fn c_nested_array_string(values: &[Vec<isize>]) -> String {
68 assert!(values.len() > 0, "C array cannot be empty");
69 format!("{{{}}}", values.iter().map(|a| c_array_string(a)).join(", "))
70}
71
/// Format `values` as a C brace-initializer, e.g. `[1, -2, 3]` becomes `{1, -2, 3}`.
///
/// # Panics
///
/// Panics if `values` is empty, since a C array cannot have zero length.
pub fn c_array_string(values: &[isize]) -> String {
    assert!(!values.is_empty(), "C array cannot be empty");
    let elements: Vec<String> = values.iter().map(|v| v.to_string()).collect();
    format!("{{{}}}", elements.join(", "))
}
76
/// Divide `x` by `y`, rounding the quotient up.
///
/// Computed as quotient plus one when there is a remainder. Unlike `(x + y - 1) / y`, this
/// cannot overflow for values of `x` close to `u32::MAX`.
///
/// # Panics
///
/// Panics if `y` is zero.
pub fn ceil_div(x: u32, y: u32) -> u32 {
    x / y + u32::from(x % y != 0)
}