hodu_cuda_kernels 0.2.4

hodu cuda kernels
use crate::{compat::*, cuda::*, error::CudaKernelError, source::Source};

/// Name of a kernel function, stored either as a borrowed static string or as an owned `String`.
///
/// Both variants convert to `&str` through the `AsRef<str>` impl below; the
/// hand-written `Hash`, `PartialEq` and `Ord` impls work on that text.
#[derive(Debug, Clone)]
pub enum KernelName {
    /// A name borrowed for the `'static` lifetime (for example a string literal).
    Ref(&'static str),
    /// A name whose text is owned by this value.
    Value(String),
}

impl AsRef<str> for KernelName {
    fn as_ref(&self) -> &str {
        match self {
            Self::Ref(r) => r,
            Self::Value(v) => v.as_str(),
        }
    }
}

impl Hash for KernelName {
    /// Feeds the name's text into `state`.
    ///
    /// Both variants are hashed as a plain `str`, so a `Ref` and a `Value`
    /// holding the same text hash identically (matching the `PartialEq` impl).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let text: &str = match self {
            Self::Ref(name) => name,
            Self::Value(name) => name.as_str(),
        };
        text.hash(state);
    }
}

impl PartialEq for KernelName {
    /// Two names are equal when their text is equal, regardless of variant.
    fn eq(&self, other: &Self) -> bool {
        AsRef::<str>::as_ref(self) == AsRef::<str>::as_ref(other)
    }
}

// `PartialEq` compares the names as `str`, which is a total equivalence
// relation, so the marker trait `Eq` can be implemented.
impl Eq for KernelName {}

impl PartialOrd for KernelName {
    /// Always returns `Some`: the ordering is total and delegated to the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

impl Ord for KernelName {
    /// Orders names by their text using `str`'s ordering; the variant is ignored.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let lhs = AsRef::<str>::as_ref(self);
        let rhs = AsRef::<str>::as_ref(other);
        Ord::cmp(lhs, rhs)
    }
}

impl From<&'static str> for KernelName {
    fn from(value: &'static str) -> Self {
        Self::Ref(value)
    }
}

impl From<String> for KernelName {
    fn from(value: String) -> Self {
        Self::Value(value)
    }
}

/// PTX objects keyed by the kernel source they were built from.
type Ptxs = HashMap<Source, Ptx>;
/// Loaded CUDA modules keyed by kernel source; shared through `Arc`.
type Modules = HashMap<Source, Arc<CudaModule>>;
/// Resolved kernel functions keyed by (source, kernel name).
type Functions = HashMap<(Source, KernelName), CudaFunction>;

/// Three-level cache for CUDA kernels: PTX, loaded modules, and resolved functions.
///
/// Each map sits behind its own `RwLock`, so the cache can be used through a
/// shared reference (`&self`).
#[derive(Debug)]
pub struct Kernels {
    /// PTX per kernel source.
    ptxs: RwLock<Ptxs>,
    /// Loaded module per kernel source.
    modules: RwLock<Modules>,
    /// Resolved function per (source, kernel name) pair.
    functions: RwLock<Functions>,
}

impl Default for Kernels {
    fn default() -> Self {
        Self::new()
    }
}

impl Kernels {
    /// Create a new Kernels instance for managing CUDA kernel compilation and caching
    ///
    /// Each instance maintains its own cache of compiled PTX, modules, and functions.
    /// For multi-device scenarios, create separate Kernels instances per device.
    pub fn new() -> Self {
        Self {
            ptxs: RwLock::new(Ptxs::new()),
            modules: RwLock::new(Modules::new()),
            functions: RwLock::new(Functions::new()),
        }
    }

    /// Returns the embedded source text for `source` by dispatching to the
    /// matching `crate::source::get_*` accessor.
    fn get_source_code(&self, source: Source) -> &'static str {
        match source {
            Source::OpsBinary => crate::source::get_ops_binary(),
            Source::OpsCast => crate::source::get_ops_cast(),
            Source::OpsConcatSplit => crate::source::get_ops_concat_split(),
            Source::OpsConv => crate::source::get_ops_conv(),
            Source::OpsMatrix => crate::source::get_ops_matrix(),
            Source::OpsIndexing => crate::source::get_ops_indexing(),
            Source::OpsReduce => crate::source::get_ops_reduce(),
            Source::OpsUnary => crate::source::get_ops_unary(),
            Source::OpsMemory => crate::source::get_ops_memory(),
            Source::OpsWindowing => crate::source::get_ops_windowing(),
            Source::Storage => crate::source::get_storage(),
        }
    }

    /// Returns the `Ptx` for `source`, creating and caching it on first use.
    ///
    /// # Errors
    ///
    /// Returns `CudaKernelError::Message` if the PTX cache lock cannot be acquired.
    pub fn load_ptx(&self, source: Source) -> Result<Ptx, CudaKernelError> {
        // Fast path: a cache hit only needs the shared lock, so concurrent
        // lookups do not serialize behind one another.
        {
            let ptxs = self.ptxs.read_compat().map_err(CudaKernelError::Message)?;
            if let Some(ptx) = ptxs.get(&source) {
                return Ok(ptx.clone());
            }
        }

        // Slow path: re-check under the exclusive lock (another caller may have
        // filled the slot in between) and build the PTX only if still missing.
        let mut ptxs = self.ptxs.write_compat().map_err(CudaKernelError::Message)?;
        let ptx = ptxs
            .entry(source)
            .or_insert_with(|| Ptx::from_src(self.get_source_code(source)))
            .clone();
        Ok(ptx)
    }

    /// Returns the kernel function `name` from the module built from `source`,
    /// loading and caching the module and the function as needed.
    ///
    /// The caches are keyed by `source` (and `name`) only, not by `context`;
    /// use one `Kernels` instance per device.
    ///
    /// # Errors
    ///
    /// * `CudaKernelError::Message` if a cache lock cannot be acquired.
    /// * `CudaKernelError::LaunchError` if the context fails to load the module.
    /// * `CudaKernelError::InvalidKernel` if the module has no function `name`.
    pub fn load_function(
        &self,
        context: &Arc<CudaContext>,
        source: Source,
        name: impl Into<KernelName>,
    ) -> Result<CudaFunction, CudaKernelError> {
        // The name is moved into the key; it is read back through `key.1` below.
        let key = (source, name.into());

        {
            let functions = self.functions.read_compat().map_err(CudaKernelError::Message)?;
            if let Some(func) = functions.get(&key) {
                return Ok(func.clone());
            }
        }

        let module = self.get_or_load_module(context, source)?;

        // Load function from module
        let func = module
            .load_function(key.1.as_ref())
            .map_err(|e| CudaKernelError::InvalidKernel(format!("Failed to load function: {:?}", e)))?;

        // Keep whichever handle reached the cache first so every caller gets the same one.
        let mut functions = self.functions.write_compat().map_err(CudaKernelError::Message)?;
        let cached = functions.entry(key).or_insert(func).clone();
        Ok(cached)
    }

    /// Returns the cached module for `source`, loading it into `context` on a miss.
    fn get_or_load_module(
        &self,
        context: &Arc<CudaContext>,
        source: Source,
    ) -> Result<Arc<CudaModule>, CudaKernelError> {
        {
            let modules = self.modules.read_compat().map_err(CudaKernelError::Message)?;
            if let Some(module) = modules.get(&source) {
                return Ok(module.clone());
            }
        }

        // No module lock is held while loading, so lookups for other sources are not blocked.
        let ptx = self.load_ptx(source)?;
        let module = context
            .load_module(ptx)
            .map_err(|e| CudaKernelError::LaunchError(format!("Failed to load module: {:?}", e)))?;

        // If another caller cached a module for this source in the meantime, keep
        // theirs and drop ours instead of replacing the cached instance.
        let mut modules = self.modules.write_compat().map_err(CudaKernelError::Message)?;
        let cached = modules.entry(source).or_insert(module).clone();
        Ok(cached)
    }
}