kn_cuda_eval/autokernel/
common.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use itertools::Itertools;
5use lazy_static::lazy_static;
6
7use kn_cuda_sys::wrapper::handle::CudaDevice;
8use kn_cuda_sys::wrapper::rtc::core::{CuFunction, CuModule};
9
10// TODO cache kernel compilation on disk
11//   * make sure to invalidate old files?
12//   * user-configurable cache dir, either env var or actual code?
13//   * disabled by default to ensure it always works, even on read-only fs
14lazy_static! {
15    static ref KERNEL_CACHE: Mutex<HashMap<KernelKey, CuFunction>> = Mutex::new(HashMap::new());
16    static ref HEADERS: HashMap<&'static str, &'static str> = {
17        let mut map = HashMap::new();
18        map.insert("util.cu", include_str!("util.cu"));
19        map
20    };
21}
22
23#[derive(Debug, Eq, PartialEq, Hash)]
24pub struct KernelKey {
25    pub device: CudaDevice,
26    pub source: String,
27    pub func_name: String,
28}
29
30pub fn compile_cached_kernel(key: KernelKey) -> CuFunction {
31    // keep locked for the duration of compilation
32    let mut cache = KERNEL_CACHE.lock().unwrap();
33
34    let func = cache.entry(key).or_insert_with_key(|key| {
35        let module = CuModule::from_source(key.device, &key.source, None, &[&key.func_name], &HEADERS);
36
37        if !module.log.is_empty() {
38            let source_numbered = module.source_with_line_numbers();
39            eprintln!("Kernel source:\n{}\nLog:\n{}\n", source_numbered, module.log);
40        }
41
42        module.get_function_by_name(&key.func_name).unwrap().unwrap()
43    });
44
45    func.clone()
46}
47
/// Substitutes every `$KEY$` placeholder in `src` with its replacement value.
///
/// Replacements are applied one after another, in slice order, each to the
/// output of the previous one. All occurrences of a key are replaced.
///
/// # Panics
///
/// * if a key does not both start and end with `$`,
/// * if a key does not occur in the (partially substituted) source,
/// * if any `$` is left once all replacements have been applied; the
///   resulting source is printed to stderr first.
pub fn fill_replacements(src: &str, replacements: &[(&str, String)]) -> String {
    let mut filled = src.to_owned();

    for &(placeholder, ref substitute) in replacements {
        let well_formed = placeholder.starts_with('$') && placeholder.ends_with('$');
        assert!(
            well_formed,
            "Key '{}' should start and end with '$'",
            placeholder
        );
        // An unused key almost certainly means a typo in the template or the caller.
        assert!(
            filled.contains(placeholder),
            "Source does not contain key '{}'",
            placeholder
        );

        filled = filled.replace(placeholder, substitute);
    }

    // Any remaining `$` indicates a placeholder nobody supplied a value for.
    if filled.contains('$') {
        eprintln!("Source after replacements:\n{}", filled);
        panic!("Source still contains $");
    }

    filled
}
66
67pub fn c_nested_array_string(values: &[Vec<isize>]) -> String {
68    assert!(values.len() > 0, "C array cannot be empty");
69    format!("{{{}}}", values.iter().map(|a| c_array_string(a)).join(", "))
70}
71
/// Formats `values` as a C brace initializer, e.g. `{1, -2, 3}`.
///
/// Elements are written in decimal, separated by `", "`.
///
/// # Panics
///
/// Panics if `values` is empty.
pub fn c_array_string(values: &[isize]) -> String {
    assert!(!values.is_empty(), "C array cannot be empty");
    let elements: Vec<String> = values.iter().map(|v| v.to_string()).collect();
    format!("{{{}}}", elements.join(", "))
}
76
/// Returns `x / y` rounded up to the next integer.
///
/// Computed from the quotient and remainder rather than as `(x + y - 1) / y`,
/// so the result is correct for every `x` and non-zero `y`, including values
/// near `u32::MAX` where the intermediate sum would overflow.
///
/// # Panics
///
/// Panics if `y` is zero.
pub fn ceil_div(x: u32, y: u32) -> u32 {
    let quotient = x / y;
    // A non-zero remainder implies `y >= 2`, hence `quotient <= u32::MAX / 2`
    // and the increment cannot overflow.
    quotient + u32::from(x % y != 0)
}