kapsl-hal 0.1.0

Hardware abstraction layer for CPU, CUDA, and ROCm devices in Kapsl
Documentation
use std::fmt::Debug;
use thiserror::Error;

use crate::tensor::{TensorView, TensorViewMut};

#[derive(Error, Debug)]
pub enum KernelError {
    #[error("Backend error: {0}")]
    BackendError(String),
    #[error("Unsupported operation: {0}")]
    Unsupported(String),
    #[error("Invalid input: {0}")]
    InvalidInput(String),
}

/// Per-call options handed to the attention kernels.
#[derive(Clone, Copy, Debug)]
pub struct AttentionConfig {
    /// Explicit scale factor. When `None`, [`AttentionConfig::scale_for`]
    /// falls back to `1 / sqrt(head_dim)`.
    pub scale: Option<f32>,
    /// Whether causal masking is requested. How it is applied is up to the
    /// kernel implementation.
    pub causal: bool,
}

impl AttentionConfig {
    /// Resolves the scale factor for heads of width `head_dim`.
    ///
    /// Returns the explicit `scale` when one is set. Otherwise returns
    /// `1 / sqrt(head_dim)`. In that fallback, `head_dim == 0` produces
    /// `f32::INFINITY`.
    pub fn scale_for(&self, head_dim: usize) -> f32 {
        match self.scale {
            Some(explicit) => explicit,
            // `recip()` is defined as `1.0 / self`, so this matches the
            // classic `1 / sqrt(d)` formula bit-for-bit.
            None => (head_dim as f32).sqrt().recip(),
        }
    }
}

/// Attention operations a backend can provide.
///
/// Each method reads `f32` query/key/value views, writes its result into a
/// caller-owned mutable `output` view, and reports failure via [`KernelError`].
pub trait AttentionKernel: Debug + Send + Sync {
    /// Attention entry point named after FlashAttention v2.
    ///
    /// The result is written into `output`. `config` supplies the optional
    /// scale override and the causal flag. The concrete algorithm is
    /// implementation-defined.
    ///
    /// # Errors
    /// Returns a [`KernelError`] when the backend cannot complete the call.
    fn flash_attention_v2(
        &self,
        query: &TensorView<f32>,
        key: &TensorView<f32>,
        value: &TensorView<f32>,
        output: &mut TensorViewMut<f32>,
        config: AttentionConfig,
    ) -> Result<(), KernelError>;

    /// Attention entry point named after paged attention (v1).
    ///
    /// NOTE(review): the signature is identical to `flash_attention_v2` and
    /// carries no block table or page mapping. Confirm how implementations
    /// expect paged key/value layout to be expressed.
    ///
    /// # Errors
    /// Returns a [`KernelError`] when the backend cannot complete the call.
    fn paged_attention_v1(
        &self,
        query: &TensorView<f32>,
        key: &TensorView<f32>,
        value: &TensorView<f32>,
        output: &mut TensorViewMut<f32>,
        config: AttentionConfig,
    ) -> Result<(), KernelError>;
}

/// Feed-forward (MLP) operations a backend can provide.
pub trait MlpKernel: Debug + Send + Sync {
    /// Fused SwiGLU combining `gate` and `up` into `out`.
    ///
    /// NOTE(review): presumably computes `out = silu(gate) * up` element-wise.
    /// Confirm against the backend implementations.
    ///
    /// # Errors
    /// Returns a [`KernelError`] when the backend cannot complete the call.
    fn fused_swiglu(
        &self,
        gate: &TensorView<f32>,
        up: &TensorView<f32>,
        out: &mut TensorViewMut<f32>,
    ) -> Result<(), KernelError>;
}

/// The hardware family a [`KernelBackend`] targets.
///
/// `Hash` is derived alongside `Eq` so the backend kind can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelBackendType {
    /// Host CPU execution.
    Cpu,
    /// NVIDIA CUDA devices.
    Cuda,
    /// AMD ROCm devices.
    Rocm,
}

/// A device backend that hands out its kernel implementations.
pub trait KernelBackend: Debug + Send + Sync {
    /// Reports which hardware family this backend targets.
    fn backend_type(&self) -> KernelBackendType;

    /// Returns this backend's attention kernel as an owned trait object.
    ///
    /// NOTE(review): a fresh `Box` is returned on every call. If
    /// implementations allocate here, hot paths may want to cache the result.
    fn attention(&self) -> Box<dyn AttentionKernel>;

    /// Returns this backend's MLP kernel as an owned trait object.
    fn mlp(&self) -> Box<dyn MlpKernel>;
}