//! gllm 0.10.6
//!
//! Pure Rust library for local embeddings, reranking, and text generation
//! with MoE-optimized inference and aggressive performance tuning.
//!
//! This module implements tiled (FlashAttention-style) scaled dot-product
//! attention on top of the `burn` tensor framework.
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

/// Tiling configuration for flash attention: how many query rows
/// (`block_q`) and key/value rows (`block_kv`) are processed per tile.
///
/// Larger blocks mean fewer passes but a bigger intermediate score tile
/// (`block_q` x `block_kv`). Both values are clamped to at least 1 at the
/// point of use, so a zero here does not cause an infinite loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FlashAttentionConfig {
    /// Number of query rows per tile.
    pub block_q: usize,
    /// Number of key/value rows per tile.
    pub block_kv: usize,
}

impl Default for FlashAttentionConfig {
    /// 64x64 tiles: a conventional default balancing the number of passes
    /// against per-tile memory for typical head dimensions.
    fn default() -> Self {
        Self {
            block_q: 64,
            block_kv: 64,
        }
    }
}

/// Scaled dot-product attention computed tile-by-tile (FlashAttention style).
pub trait FlashAttention<B: Backend> {
    /// Runs attention over the full key/value length.
    ///
    /// `q`, `k`, and `v` are rank-4 tensors; the slicing in the provided
    /// impl treats dim 2 as the sequence axis and dim 3 as the head
    /// dimension. `causal` enables a lower-triangular mask. Returns a
    /// tensor with `q`'s sequence length on dim 2.
    fn forward(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
    ) -> Tensor<B, 4>;
}

/// A [`FlashAttentionConfig`] can run attention directly: its own tile
/// sizes, no sliding window, and a zero position offset.
impl<B: Backend> FlashAttention<B> for FlashAttentionConfig {
    fn forward(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
    ) -> Tensor<B, 4> {
        // Attend over every key row present in `k` (dim 2 is the sequence axis).
        let full_key_len = k.dims()[2];
        flash_attention_forward(q, k, v, causal, 0, full_key_len, None, *self)
    }
}

/// Tiled (FlashAttention-style) scaled dot-product attention.
///
/// Streams `q` in tiles of `config.block_q` rows against `k`/`v` tiles of
/// `config.block_kv` rows, maintaining the online-softmax running maximum
/// (`m_i`), normalizer (`l_i`), and unnormalized accumulator (`o_i`) so
/// the full `[query_len, key_len]` score matrix is never materialized.
///
/// * `q`, `k`, `v` — rank-4 tensors; dims are read as
///   `[batch, heads, seq, head_dim]`.
/// * `causal` — apply a causal mask. NOTE: `sliding_window` is only
///   honored when `causal` is true; non-causal calls ignore it.
/// * `position_offset` — absolute position of `q`'s first row (e.g. the
///   KV-cache length during incremental decoding).
/// * `key_len` — number of valid key/value rows; clamped to what `k`/`v`
///   actually hold so callers cannot over-read.
/// * `sliding_window` — if set, each query attends to at most this many
///   most-recent keys.
/// * `config` — tile sizes; both are clamped to at least 1.
///
/// Returns `[batch, heads, query_len, head_dim]`; all zeros when either
/// sequence length is 0.
pub fn flash_attention_forward<B: Backend>(
    q: Tensor<B, 4>,
    k: Tensor<B, 4>,
    v: Tensor<B, 4>,
    causal: bool,
    position_offset: usize,
    key_len: usize,
    sliding_window: Option<usize>,
    config: FlashAttentionConfig,
) -> Tensor<B, 4> {
    let device = q.device();
    let [batch_size, num_heads, query_len, head_dim] = q.dims();
    // Never read past the rows k/v actually hold, even if the caller
    // over-reports key_len.
    let key_len = key_len.min(k.dims()[2]).min(v.dims()[2]);

    if query_len == 0 || key_len == 0 {
        return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
    }

    // Guard against a zero block size, which would make the loops below spin.
    let block_q = config.block_q.max(1);
    let block_kv = config.block_kv.max(1);
    // No sliding window means every key is visible (window spans all keys).
    let window = sliding_window.unwrap_or(key_len);
    let scale = (head_dim as f32).sqrt();

    // One finished output tile per query block, concatenated on dim 2 at the end.
    let mut outputs = Vec::with_capacity(query_len.div_ceil(block_q));
    let mut q_start = 0usize;
    while q_start < query_len {
        let q_end = (q_start + block_q).min(query_len);
        let q_block_len = q_end - q_start;
        let q_block = q
            .clone()
            .slice([0..batch_size, 0..num_heads, q_start..q_end, 0..head_dim]);

        // Online-softmax state for this query tile:
        // m_i: running per-row max (starts at -inf so the first KV tile
        //      replaces it), l_i: running normalizer, o_i: unnormalized output.
        let mut m_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut l_i =
            Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
        let mut o_i =
            Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, head_dim], &device);

        let mut kv_start = 0usize;
        while kv_start < key_len {
            let kv_end = (kv_start + block_kv).min(key_len);
            let kv_block_len = kv_end - kv_start;

            let k_block = k.clone().slice([
                0..batch_size,
                0..num_heads,
                kv_start..kv_end,
                0..head_dim,
            ]);
            let v_block = v.clone().slice([
                0..batch_size,
                0..num_heads,
                kv_start..kv_end,
                0..head_dim,
            ]);

            // Raw scaled scores for this (q tile) x (kv tile) pair.
            let mut scores = q_block.clone().matmul(k_block.transpose()) / scale;

            if causal {
                // Additive mask: 0 where visible, a large negative value
                // (not -inf) elsewhere, so every row's normalizer stays
                // finite and the final division never yields NaN.
                let mask = build_block_causal_mask(
                    &device,
                    q_block_len,
                    kv_block_len,
                    q_start,
                    kv_start,
                    position_offset,
                    window,
                );
                scores = scores + mask;
            }

            // Online softmax update: fold this tile into the running state
            // without revisiting earlier tiles.
            let m_ij = scores.clone().max_dim(3); // per-row tile max (dim kept)
            let m_new = m_i.clone().max_pair(m_ij); // new running row max
            let p_ij = (scores - m_new.clone()).exp(); // tile numerators
            let m_scale = (m_i.clone() - m_new.clone()).exp(); // rescales old state

            let l_new = m_scale.clone() * l_i + p_ij.clone().sum_dim(3);
            let o_new = m_scale * o_i + p_ij.matmul(v_block);

            m_i = m_new;
            l_i = l_new;
            o_i = o_new;
            kv_start = kv_end;
        }

        // Apply the softmax denominator once per query tile.
        outputs.push(o_i / l_i);
        q_start = q_end;
    }

    Tensor::cat(outputs, 2)
}

/// Builds the additive causal / sliding-window mask for one score tile.
///
/// Visible positions contribute `0.0`; hidden ones contribute a large
/// negative value (`-1e4`, deliberately not `-inf`) so softmax drives
/// their weight toward zero while every row's normalizer stays finite.
/// `q_start`/`kv_start` locate this tile inside the full matrix, and
/// `position_offset` shifts query rows to absolute positions (KV-cache
/// decoding). A `window` of 0 disables the sliding window entirely.
fn build_block_causal_mask<B: Backend>(
    device: &B::Device,
    query_len: usize,
    key_len: usize,
    q_start: usize,
    kv_start: usize,
    position_offset: usize,
    window: usize,
) -> Tensor<B, 4> {
    const MASKED: f32 = -1.0e4;

    let data: Vec<f32> = (0..query_len)
        .flat_map(|row| {
            let query_pos = position_offset + q_start + row;
            // Oldest key this query row may still see.
            let first_visible = if window > 0 {
                query_pos.saturating_sub(window.saturating_sub(1))
            } else {
                0
            };
            (0..key_len).map(move |col| {
                let key_pos = kv_start + col;
                let visible = key_pos >= first_visible && key_pos <= query_pos;
                if visible { 0.0 } else { MASKED }
            })
        })
        .collect();

    Tensor::<B, 2>::from_data(TensorData::new(data, [query_len, key_len]), device)
        .reshape([1, 1, query_len, key_len])
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArray;
    use burn::tensor::activation::softmax;

    /// Reference mask covering the whole score matrix at once: delegates to
    /// the tile builder with zero tile offsets so the tests compare tiled
    /// output against an un-tiled dense computation using the same mask.
    fn build_full_causal_mask<B: Backend>(
        device: &B::Device,
        query_len: usize,
        key_len: usize,
        position_offset: usize,
        sliding_window: Option<usize>,
    ) -> Tensor<B, 4> {
        let window = sliding_window.unwrap_or(key_len);
        build_block_causal_mask(
            device,
            query_len,
            key_len,
            0,
            0,
            position_offset,
            window,
        )
    }

    /// Non-causal path: tiled flash attention must match dense
    /// softmax(QK^T / sqrt(d)) V elementwise within tolerance.
    #[test]
    fn test_flash_attention_matches_standard() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 5;
        let head_dim = 4;
        let total = batch_size * num_heads * seq_len * head_dim;

        // Deterministic pseudo-random data; distinct moduli/scales keep
        // q, k, v small and mutually uncorrelated.
        let q_data: Vec<f32> = (0..total)
            .map(|i| (i as f32 % 11.0) * 0.05)
            .collect();
        let k_data: Vec<f32> = (0..total)
            .map(|i| ((i + 3) as f32 % 13.0) * 0.04)
            .collect();
        let v_data: Vec<f32> = (0..total)
            .map(|i| ((i + 7) as f32 % 17.0) * 0.03)
            .collect();

        let q = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(q_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );
        let k = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(k_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );
        let v = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(v_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );

        // Block sizes deliberately do not divide seq_len (5), exercising
        // partial tiles on both the query and key/value axes.
        let config = FlashAttentionConfig {
            block_q: 2,
            block_kv: 3,
        };
        let output = flash_attention_forward(q.clone(), k.clone(), v.clone(), false, 0, seq_len, None, config);

        // Dense reference computation (no tiling, no mask).
        let scale = (head_dim as f32).sqrt();
        let scores = q.matmul(k.transpose()) / scale;
        let attn = softmax(scores, 3);
        let expected = attn.matmul(v);

        let output_data = output
            .into_data()
            .into_vec::<f32>()
            .expect("flash attention output should be float data");
        let expected_data = expected
            .into_data()
            .into_vec::<f32>()
            .expect("reference attention output should be float data");

        // 1e-3 tolerance absorbs the reordering of f32 ops in the tiled path.
        for (idx, (left, right)) in output_data.iter().zip(expected_data.iter()).enumerate() {
            let diff = (left - right).abs();
            assert!(
                diff < 1e-3,
                "mismatch at index {idx}: {left} vs {right} (diff {diff})"
            );
        }
    }

    /// Causal path: tiled output must match the dense computation with the
    /// same additive mask applied before softmax.
    #[test]
    fn test_flash_attention_causal_mask() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let batch_size = 1;
        let num_heads = 1;
        let seq_len = 4;
        let head_dim = 3;
        let total = batch_size * num_heads * seq_len * head_dim;

        // Deterministic pseudo-random data, as in the non-causal test.
        let q_data: Vec<f32> = (0..total)
            .map(|i| (i as f32 % 7.0) * 0.07)
            .collect();
        let k_data: Vec<f32> = (0..total)
            .map(|i| ((i + 5) as f32 % 9.0) * 0.06)
            .collect();
        let v_data: Vec<f32> = (0..total)
            .map(|i| ((i + 11) as f32 % 11.0) * 0.05)
            .collect();

        let q = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(q_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );
        let k = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(k_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );
        let v = Tensor::<NdArray<f32>, 4>::from_data(
            TensorData::new(v_data, [batch_size, num_heads, seq_len, head_dim]),
            &device,
        );

        // 2x2 tiles over a length-4 sequence: every tile boundary crosses
        // the causal diagonal somewhere.
        let config = FlashAttentionConfig {
            block_q: 2,
            block_kv: 2,
        };
        let output = flash_attention_forward(q.clone(), k.clone(), v.clone(), true, 0, seq_len, None, config);

        // Dense reference with the full (un-tiled) causal mask.
        let scale = (head_dim as f32).sqrt();
        let scores = q.matmul(k.transpose()) / scale;
        let mask = build_full_causal_mask(&device, seq_len, seq_len, 0, None);
        let attn = softmax(scores + mask, 3);
        let expected = attn.matmul(v);

        let output_data = output
            .into_data()
            .into_vec::<f32>()
            .expect("flash attention output should be float data");
        let expected_data = expected
            .into_data()
            .into_vec::<f32>()
            .expect("reference attention output should be float data");

        for (idx, (left, right)) in output_data.iter().zip(expected_data.iter()).enumerate() {
            let diff = (left - right).abs();
            assert!(
                diff < 1e-3,
                "mismatch at index {idx}: {left} vs {right} (diff {diff})"
            );
        }
    }
}