// burn_dragon_core 0.5.0
//
// burn-dragon core: model and utilities.
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Normalizes a single-step query along its last (latent) axis.
///
/// Callers pass `query_t` indexed `[batch, heads, latent]`. Each `[batch, head]` row is
/// divided by its own latent sum plus `1.0e-6`, so rows whose sum is much larger than
/// the epsilon end up summing to approximately one. An all-zero row stays finite.
pub(crate) fn rwkv8_query_weights<B: Backend>(query_t: Tensor<B, 3>) -> Tensor<B, 3> {
    let [batch, heads, _] = query_t.shape().dims::<3>();
    let row_totals = query_t
        .clone()
        .sum_dim(2)
        .add_scalar(1.0e-6)
        .reshape([batch, heads, 1]);
    query_t.div(row_totals)
}

/// Reads one step of context out of the recurrent state.
///
/// - `rho` is `[batch, heads, latent, embd]`.
/// - `rho_norm` is `[batch, heads, latent]`.
/// - `query_t` is `[batch, heads, latent]`.
///
/// Each latent row of `rho` is divided by its `rho_norm` entry (plus `1.0e-6`). It is
/// then weighted by the normalized query from [`rwkv8_query_weights`] and summed over
/// the latent axis. The result is `[batch, heads, 1, embd]`.
pub(crate) fn rwkv8_read_context_step<B: Backend>(
    rho: Tensor<B, 4>,
    rho_norm: Tensor<B, 3>,
    query_t: Tensor<B, 3>,
) -> Tensor<B, 4> {
    let [batch, heads, latent, embd] = rho.shape().dims::<4>();
    // Trailing singleton axis so per-latent scalars broadcast across `embd`.
    let per_latent = [batch, heads, latent, 1];

    let weights = rwkv8_query_weights(query_t).reshape(per_latent);
    let normalizer = rho_norm.add_scalar(1.0e-6).reshape(per_latent);

    let weighted_rows = rho.div(normalizer) * weights;
    weighted_rows.sum_dim(2).reshape([batch, heads, 1, embd])
}

/// Advances the recurrent state by one step and returns `(next_rho, next_rho_norm)`.
///
/// `decay` is broadcast over the batch axis and must hold `heads * latent` elements.
/// - `next_rho = rho * decay + query_t ⊗ value_t`, where the outer product spans
///   latent × embd per `[batch, head]`.
/// - `next_rho_norm = rho_norm * decay + query_t`.
pub(crate) fn rwkv8_update_state_step<B: Backend>(
    rho: Tensor<B, 4>,
    rho_norm: Tensor<B, 3>,
    query_t: Tensor<B, 3>,
    value_t: Tensor<B, 3>,
    decay: Tensor<B, 3>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
    let [batch, heads, latent] = query_t.shape().dims::<3>();
    let [_, _, n_embd] = value_t.shape().dims::<3>();

    let decay_per_row = decay.clone().reshape([1, heads, latent, 1]);
    let outer_product = query_t.clone().reshape([batch, heads, latent, 1])
        * value_t.reshape([batch, heads, 1, n_embd]);
    let next_rho = rho.mul(decay_per_row).add(outer_product);

    let next_rho_norm = rho_norm.mul(decay).add(query_t);
    (next_rho, next_rho_norm)
}

/// Returns `state` when its shape is exactly `dims`; otherwise a zero tensor of that shape.
///
/// A carried-in state whose shape does not match, or no state at all, is replaced by
/// zeros rather than rejected.
fn rwkv8_state_or_zeros<B: Backend, const D: usize>(
    state: Option<Tensor<B, D>>,
    dims: [usize; D],
    device: &B::Device,
) -> Tensor<B, D> {
    state
        .filter(|existing| existing.shape().dims::<D>() == dims)
        .unwrap_or_else(|| Tensor::<B, D>::zeros(dims, device))
}

/// Step-by-step reference implementation of the RWKV-8 recurrent state-space read/update.
///
/// Shapes:
/// - `query`: `[batch, heads, time, latent]`.
/// - `value`: `[batch, 1, time, n_embd]` (shared by every head) or
///   `[batch, heads, time, n_embd]` (one value stream per head).
/// - `rho_state` / `rho_norm_state`: optional carried state of shape
///   `[batch, heads, latent, n_embd]` / `[batch, heads, latent]`. A missing state, or
///   one with any other shape, starts from zeros.
/// - `decay`: any tensor with `heads * latent` elements; it is reshaped to
///   `[1, heads, latent]` and broadcast over the batch.
///
/// For every step `t`, the context is read from the state *before* step `t` is folded in:
/// `context_t = Σ_latent rho / (rho_norm + 1e-6) * q_t / (Σ q_t + 1e-6)`.
/// The state is then advanced:
/// `rho = rho * decay + q_t ⊗ v_t`, `rho_norm = rho_norm * decay + q_t`.
///
/// Returns `(context, rho, rho_norm)`. `context` is `[batch, heads, time, n_embd]`; the
/// two states are the final state, which can be fed back in to continue the sequence.
///
/// # Panics
/// If `value`'s head dimension is neither `1` nor `heads`. Presumably also when
/// `time == 0`, since the per-step outputs are joined with `Tensor::cat`, which needs
/// at least one tensor.
pub fn recurrent_rwkv8_state_space_reference<B: Backend>(
    query: Tensor<B, 4>,
    value: Tensor<B, 4>,
    rho_state: Option<Tensor<B, 4>>,
    rho_norm_state: Option<Tensor<B, 3>>,
    decay: Tensor<B, 3>,
) -> (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 3>) {
    let [batch, heads, time, latent] = query.shape().dims::<4>();
    let [_, value_heads, _, n_embd] = value.shape().dims::<4>();
    assert!(
        value_heads == 1 || value_heads == heads,
        "value must have 1 (shared) or {heads} heads, got {value_heads}"
    );
    let device = value.device();
    let decay = decay.reshape([1, heads, latent]);

    // Broadcast shared values to every head once, rather than on every time step.
    let value = if value_heads == heads {
        value
    } else {
        value.repeat_dim(1, heads)
    };

    let mut rho = rwkv8_state_or_zeros(rho_state, [batch, heads, latent, n_embd], &device);
    let mut rho_norm = rwkv8_state_or_zeros(rho_norm_state, [batch, heads, latent], &device);

    let mut outputs: Vec<Tensor<B, 4>> = Vec::with_capacity(time);
    for t in 0..time {
        let q_t = query
            .clone()
            .slice_dim(2, t..t + 1)
            .reshape([batch, heads, latent]);
        let value_t = value
            .clone()
            .slice_dim(2, t..t + 1)
            .reshape([batch, heads, n_embd]);

        // Read with the pre-update state, then fold step `t` into the state.
        outputs.push(rwkv8_read_context_step(
            rho.clone(),
            rho_norm.clone(),
            q_t.clone(),
        ));
        (rho, rho_norm) = rwkv8_update_state_step(rho, rho_norm, q_t, value_t, decay.clone());
    }

    (Tensor::cat(outputs, 2), rho, rho_norm)
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::TensorData;
    use burn::tensor::backend::Backend as BackendTrait;
    use burn_ndarray::NdArray;

    type Backend = NdArray<f32>;

    /// Absolute tolerance for comparing computations that should agree exactly up to
    /// f32 rounding.
    const TOLERANCE: f32 = 1.0e-6;

    /// Deterministic inputs:
    /// - `query` is `[1, 2, 5, 4]`.
    /// - `value` is shared across heads (`[1, 1, 5, 3]`).
    /// - `decay` is `[1, 2, 4]`.
    fn test_inputs(
        device: &<Backend as BackendTrait>::Device,
    ) -> (Tensor<Backend, 4>, Tensor<Backend, 4>, Tensor<Backend, 3>) {
        let batch = 1;
        let heads = 2;
        let time = 5;
        let latent = 4;
        let n_embd = 3;
        let query = Tensor::<Backend, 4>::from_data(
            TensorData::new(
                vec![
                    0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4, 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.8,
                    0.6, 0.7, 0.8, 0.9, 0.4, 0.3, 0.2, 0.1, 0.5, 0.4, 0.3, 0.2, 0.6, 0.5, 0.4, 0.3,
                    0.7, 0.6, 0.5, 0.4, 0.8, 0.7, 0.6, 0.5,
                ],
                [batch, heads, time, latent],
            ),
            device,
        );
        let value = Tensor::<Backend, 4>::from_data(
            TensorData::new(
                vec![
                    0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.3, 0.4, 0.5, 0.4, 0.5, 0.6, 0.5, 0.6, 0.7,
                ],
                [batch, 1, time, n_embd],
            ),
            device,
        );
        let decay = Tensor::<Backend, 3>::from_data(
            TensorData::new(
                vec![0.97, 0.95, 0.93, 0.91, 0.96, 0.94, 0.92, 0.90],
                [1, heads, latent],
            ),
            device,
        );
        (query, value, decay)
    }

    /// Largest absolute elementwise difference between two same-shaped tensors.
    fn max_abs<const D: usize>(lhs: Tensor<Backend, D>, rhs: Tensor<Backend, D>) -> f32 {
        lhs.sub(rhs).abs().max().into_scalar()
    }

    #[test]
    fn rwkv8_query_weights_sum_to_one() {
        let device = <Backend as BackendTrait>::Device::default();
        // Both rows sum to exactly 1.0 before normalization.
        let query_t = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![0.2, 0.3, 0.5, 0.4, 0.1, 0.5], [1, 2, 3]),
            &device,
        );
        let weights = rwkv8_query_weights(query_t);
        let sums = weights
            .sum_dim(2)
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .unwrap();
        // The 1e-6 epsilon in the denominator alone shifts a unit-sum row by ~1e-6, so
        // compare with headroom for it plus f32 rounding.
        for actual in sums {
            assert!((actual - 1.0).abs() <= 1.0e-5, "weight sum {actual}");
        }
    }

    #[test]
    fn rwkv8_step_helpers_match_single_step_reference() {
        let device = <Backend as BackendTrait>::Device::default();
        let query_t = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![0.4, 0.6, 0.3, 0.7], [1, 1, 4]),
            &device,
        );
        let value_t = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![0.25, 0.5, 0.75], [1, 1, 3]),
            &device,
        );
        let rho = Tensor::<Backend, 4>::from_data(
            TensorData::new(
                vec![
                    0.5, 1.0, 1.5, 0.75, 1.25, 1.75, 1.0, 1.5, 2.0, 1.25, 1.75, 2.25,
                ],
                [1, 1, 4, 3],
            ),
            &device,
        );
        let rho_norm = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![1.2, 0.8, 1.5, 0.9], [1, 1, 4]),
            &device,
        );
        let decay = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![0.95, 0.9, 0.85, 0.8], [1, 1, 4]),
            &device,
        );

        let helper_context =
            rwkv8_read_context_step(rho.clone(), rho_norm.clone(), query_t.clone());
        let (helper_rho, helper_rho_norm) = rwkv8_update_state_step(
            rho.clone(),
            rho_norm.clone(),
            query_t.clone(),
            value_t.clone(),
            decay.clone(),
        );

        let (reference_context, reference_rho, reference_rho_norm) =
            recurrent_rwkv8_state_space_reference(
                query_t.reshape([1, 1, 1, 4]),
                value_t.reshape([1, 1, 1, 3]),
                Some(rho),
                Some(rho_norm),
                decay,
            );

        let context_diff = max_abs(helper_context, reference_context);
        assert!(context_diff <= TOLERANCE, "context diff {context_diff}");
        let rho_diff = max_abs(helper_rho, reference_rho);
        assert!(rho_diff <= TOLERANCE, "rho diff {rho_diff}");
        let rho_norm_diff = max_abs(helper_rho_norm, reference_rho_norm);
        assert!(rho_norm_diff <= TOLERANCE, "rho_norm diff {rho_norm_diff}");
    }

    #[test]
    fn rwkv8_reference_step_state_matches_full_sequence() {
        let device = <Backend as BackendTrait>::Device::default();
        let (query, value, decay) = test_inputs(&device);
        let [_, _, time, _] = query.shape().dims::<4>();

        let (full_out, full_rho, full_rho_norm) = recurrent_rwkv8_state_space_reference(
            query.clone(),
            value.clone(),
            None,
            None,
            decay.clone(),
        );

        let mut outputs = Vec::with_capacity(time);
        let mut rho_state = None;
        let mut rho_norm_state = None;
        for step in 0..time {
            let step_query = query.clone().slice_dim(2, step..step + 1);
            let step_value = value.clone().slice_dim(2, step..step + 1);
            let (step_out, next_rho, next_rho_norm) = recurrent_rwkv8_state_space_reference(
                step_query,
                step_value,
                rho_state,
                rho_norm_state,
                decay.clone(),
            );
            outputs.push(step_out);
            rho_state = Some(next_rho);
            rho_norm_state = Some(next_rho_norm);
        }

        let step_out = Tensor::cat(outputs, 2);
        let step_rho = rho_state.expect("rwkv rho state");
        let step_rho_norm = rho_norm_state.expect("rwkv rho_norm state");

        let out_diff = max_abs(step_out, full_out);
        assert!(out_diff <= TOLERANCE, "output diff {out_diff}");
        let rho_diff = max_abs(step_rho, full_rho);
        assert!(rho_diff <= TOLERANCE, "rho diff {rho_diff}");
        let rho_norm_diff = max_abs(step_rho_norm, full_rho_norm);
        assert!(rho_norm_diff <= TOLERANCE, "rho_norm diff {rho_norm_diff}");
    }

    #[test]
    fn rwkv8_reference_chunked_state_matches_full_sequence() {
        let device = <Backend as BackendTrait>::Device::default();
        let (query, value, decay) = test_inputs(&device);
        let [_, _, time, _] = query.shape().dims::<4>();
        // Split point for the two chunks; the suffix runs to the end of the sequence.
        let split = 2;

        let (full_out, full_rho, full_rho_norm) = recurrent_rwkv8_state_space_reference(
            query.clone(),
            value.clone(),
            None,
            None,
            decay.clone(),
        );

        let (prefix_out, prefix_rho, prefix_rho_norm) = recurrent_rwkv8_state_space_reference(
            query.clone().slice_dim(2, 0..split),
            value.clone().slice_dim(2, 0..split),
            None,
            None,
            decay.clone(),
        );
        let (suffix_out, suffix_rho, suffix_rho_norm) = recurrent_rwkv8_state_space_reference(
            query.clone().slice_dim(2, split..time),
            value.clone().slice_dim(2, split..time),
            Some(prefix_rho),
            Some(prefix_rho_norm),
            decay.clone(),
        );

        let chunked_out = Tensor::cat(vec![prefix_out, suffix_out], 2);
        let out_diff = max_abs(chunked_out, full_out);
        assert!(out_diff <= TOLERANCE, "output diff {out_diff}");
        let rho_diff = max_abs(suffix_rho, full_rho);
        assert!(rho_diff <= TOLERANCE, "rho diff {rho_diff}");
        let rho_norm_diff = max_abs(suffix_rho_norm, full_rho_norm);
        assert!(rho_norm_diff <= TOLERANCE, "rho_norm diff {rho_norm_diff}");
    }
}