use crate::{Error, OpenCL};
use min_cl::api::{
build_program, create_kernels_in_program, create_program_with_source, Kernel, OCLErrorKind,
};
use std::collections::HashMap;
#[derive(Debug, Default)]
pub struct KernelCacheCL {
pub kernel_cache: HashMap<String, Kernel>,
}
impl KernelCacheCL {
pub fn kernel(&mut self, device: &OpenCL, src: &str) -> Result<&Kernel, Error> {
if self.kernel_cache.contains_key(src) {
return Ok(self.kernel_cache.get(src).unwrap());
}
let program = create_program_with_source(device.ctx(), src)?;
build_program(&program, &[device.device()], Some("-cl-std=CL1.2"))?;
let kernel = create_kernels_in_program(&program)?
.into_iter()
.next()
.ok_or(OCLErrorKind::InvalidKernel)?;
self.kernel_cache.insert(src.to_string(), kernel);
Ok(self.kernel_cache.get(src).unwrap())
}
}
#[cfg(test)]
mod tests {
    use super::KernelCacheCL;
    use crate::OpenCL;

    /// Source of a trivial one-argument kernel.
    const FOO_SRC: &str = "\n__kernel void foo(__global float* test) {}\n";
    /// Source of a different, two-argument kernel.
    const BAR_SRC: &str = "\n__kernel void bar(__global float* test, __global float* out) {}\n";

    /// Requesting the same source twice must yield the same kernel handle
    /// (served from the cache), while different source must yield a different one.
    #[test]
    fn test_kernel_cache() -> crate::Result<()> {
        let device = OpenCL::new(0)?;
        let mut cache = KernelCacheCL::default();

        let first = cache.kernel(&device, FOO_SRC)?.0;
        let again = cache.kernel(&device, FOO_SRC)?.0;
        assert_eq!(first, again);

        let other = cache.kernel(&device, BAR_SRC)?.0;
        assert_ne!(first, other);

        Ok(())
    }
}