Skip to main content

AttentionKernel

Trait AttentionKernel 

Source
/// Common interface for attention kernel implementations.
///
/// Every implementor must also be [`Send`], [`Sync`] and [`Debug`].
///
/// Both required methods share the same signature: they borrow `query`,
/// `key` and `value` as read-only `f32` tensor views, take `output` as a
/// mutable `f32` tensor view, receive an `AttentionConfig` by value, and
/// return `Ok(())` or a `KernelError`.
pub trait AttentionKernel: Send + Sync + Debug {
    /// Runs the `flash_attention_v2` kernel.
    ///
    /// NOTE(review): the name suggests a FlashAttention-2 style algorithm
    /// whose result is stored in `output`; this declaration does not show
    /// that, so confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` on failure. The failure conditions are
    /// defined by each implementor and are not visible here.
    fn flash_attention_v2(
        &self,
        query: &TensorView<'_, f32>,
        key: &TensorView<'_, f32>,
        value: &TensorView<'_, f32>,
        output: &mut TensorViewMut<'_, f32>,
        config: AttentionConfig,
    ) -> Result<(), KernelError>;

    /// Runs the `paged_attention_v1` kernel.
    ///
    /// NOTE(review): the name suggests a paged (block-table) attention
    /// variant whose result is stored in `output`; this declaration does
    /// not show that, so confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` on failure. The failure conditions are
    /// defined by each implementor and are not visible here.
    fn paged_attention_v1(
        &self,
        query: &TensorView<'_, f32>,
        key: &TensorView<'_, f32>,
        value: &TensorView<'_, f32>,
        output: &mut TensorViewMut<'_, f32>,
        config: AttentionConfig,
    ) -> Result<(), KernelError>;
}

Required Methods§

Source

fn flash_attention_v2(&self, query: &TensorView<'_, f32>, key: &TensorView<'_, f32>, value: &TensorView<'_, f32>, output: &mut TensorViewMut<'_, f32>, config: AttentionConfig) -> Result<(), KernelError>

Source

fn paged_attention_v1(&self, query: &TensorView<'_, f32>, key: &TensorView<'_, f32>, value: &TensorView<'_, f32>, output: &mut TensorViewMut<'_, f32>, config: AttentionConfig) -> Result<(), KernelError>

Implementors§