1use std::fmt::Debug;
2use thiserror::Error;
3
4use crate::tensor::{TensorView, TensorViewMut};
5
6#[derive(Error, Debug)]
7pub enum KernelError {
8 #[error("Backend error: {0}")]
9 BackendError(String),
10 #[error("Unsupported operation: {0}")]
11 Unsupported(String),
12 #[error("Invalid input: {0}")]
13 InvalidInput(String),
14}
15
/// Per-call options passed to the attention kernels.
///
/// `AttentionConfig::default()` yields `scale: None` (derive the scale from the
/// head width) and `causal: false`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct AttentionConfig {
    /// Explicit attention scaling factor. When `None`, [`AttentionConfig::scale_for`]
    /// falls back to `1 / sqrt(head_dim)`.
    /// NOTE(review): presumably multiplied into the query-key scores — confirm per backend.
    pub scale: Option<f32>,
    /// Requests causal masking. How it is applied is up to each kernel implementation.
    pub causal: bool,
}

impl AttentionConfig {
    /// Returns the scaling factor to use for attention heads of width `head_dim`.
    ///
    /// An explicit [`scale`](Self::scale) always takes precedence and `head_dim` is then
    /// ignored. Otherwise the conventional `1 / sqrt(head_dim)` is computed.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when no explicit scale is set and `head_dim == 0`, because
    /// the fallback would be `+inf`. Release builds still return `+inf` in that case.
    #[must_use]
    pub fn scale_for(&self, head_dim: usize) -> f32 {
        self.scale.unwrap_or_else(|| {
            // A zero-width head makes the fallback infinite. That poisons later math with
            // NaN, so surface it early in debug builds.
            debug_assert!(
                head_dim > 0,
                "head_dim must be non-zero to derive a default attention scale"
            );
            // `usize -> f32` is exact for every head width below 2^24.
            1.0 / (head_dim as f32).sqrt()
        })
    }
}
27
28pub trait AttentionKernel: Send + Sync + Debug {
29 fn flash_attention_v2(
30 &self,
31 query: &TensorView<f32>,
32 key: &TensorView<f32>,
33 value: &TensorView<f32>,
34 output: &mut TensorViewMut<f32>,
35 config: AttentionConfig,
36 ) -> Result<(), KernelError>;
37
38 fn paged_attention_v1(
39 &self,
40 query: &TensorView<f32>,
41 key: &TensorView<f32>,
42 value: &TensorView<f32>,
43 output: &mut TensorViewMut<f32>,
44 config: AttentionConfig,
45 ) -> Result<(), KernelError>;
46}
47
48pub trait MlpKernel: Send + Sync + Debug {
49 fn fused_swiglu(
50 &self,
51 gate: &TensorView<f32>,
52 up: &TensorView<f32>,
53 out: &mut TensorViewMut<f32>,
54 ) -> Result<(), KernelError>;
55}
56
/// Identifies which kind of backend a [`KernelBackend`] is.
///
/// Besides `Eq`, the type derives `Hash`, so it can key a `HashMap` or `HashSet`
/// (for example, a registry of available backends).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelBackendType {
    /// The CPU backend.
    Cpu,
    /// The CUDA backend.
    Cuda,
    /// The ROCm backend.
    Rocm,
}
63
64pub trait KernelBackend: Send + Sync + Debug {
65 fn backend_type(&self) -> KernelBackendType;
66
67 fn attention(&self) -> Box<dyn AttentionKernel>;
68 fn mlp(&self) -> Box<dyn MlpKernel>;
69}