Skip to main content

kapsl_hal/
kernel.rs

1use std::fmt::Debug;
2use thiserror::Error;
3
4use crate::tensor::{TensorView, TensorViewMut};
5
6#[derive(Error, Debug)]
7pub enum KernelError {
8    #[error("Backend error: {0}")]
9    BackendError(String),
10    #[error("Unsupported operation: {0}")]
11    Unsupported(String),
12    #[error("Invalid input: {0}")]
13    InvalidInput(String),
14}
15
/// Per-call options passed by value to the attention kernels.
#[derive(Debug, Clone, Copy)]
pub struct AttentionConfig {
    /// Explicit scale factor. When `None`, [`AttentionConfig::scale_for`]
    /// derives one from the head dimension instead.
    pub scale: Option<f32>,
    /// Flag forwarded to the kernel implementation; it is not interpreted
    /// anywhere in this file.
    // NOTE(review): presumably toggles causal masking; confirm against the
    // kernel implementations.
    pub causal: bool,
}

impl AttentionConfig {
    /// Returns the scale to use for a head of size `head_dim`.
    ///
    /// An explicitly configured `scale` always wins. Otherwise the result is
    /// `1.0 / sqrt(head_dim)`.
    ///
    /// With no explicit scale and `head_dim == 0`, the fallback evaluates to
    /// `f32::INFINITY`; this method does not validate `head_dim`.
    pub fn scale_for(&self, head_dim: usize) -> f32 {
        // The closure form keeps the square root off the `Some` path.
        self.scale.unwrap_or_else(|| 1.0 / (head_dim as f32).sqrt())
    }
}
27
28pub trait AttentionKernel: Send + Sync + Debug {
29    fn flash_attention_v2(
30        &self,
31        query: &TensorView<f32>,
32        key: &TensorView<f32>,
33        value: &TensorView<f32>,
34        output: &mut TensorViewMut<f32>,
35        config: AttentionConfig,
36    ) -> Result<(), KernelError>;
37
38    fn paged_attention_v1(
39        &self,
40        query: &TensorView<f32>,
41        key: &TensorView<f32>,
42        value: &TensorView<f32>,
43        output: &mut TensorViewMut<f32>,
44        config: AttentionConfig,
45    ) -> Result<(), KernelError>;
46}
47
48pub trait MlpKernel: Send + Sync + Debug {
49    fn fused_swiglu(
50        &self,
51        gate: &TensorView<f32>,
52        up: &TensorView<f32>,
53        out: &mut TensorViewMut<f32>,
54    ) -> Result<(), KernelError>;
55}
56
/// Tag identifying which kind of backend a kernel provider is.
///
/// The type is `Copy` and comparable with `==`, so it can be passed and
/// matched on by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelBackendType {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
}
63
64pub trait KernelBackend: Send + Sync + Debug {
65    fn backend_type(&self) -> KernelBackendType;
66
67    fn attention(&self) -> Box<dyn AttentionKernel>;
68    fn mlp(&self) -> Box<dyn MlpKernel>;
69}