burn_dragon_core 0.21.0

Burn Dragon core: model and utilities.
Documentation
use super::*;

impl<B: Backend> DragonModel<B> {
    pub(super) fn rollout_executor_mode(&self) -> RolloutExecutorMode {
        if self.sequence_kernel.memory_system == SequenceMemorySystem::LinearAttention
            && self.sequence_kernel.executor == SequenceTrainingExecutor::Reference
            && self.kernel.enabled
            && self.kernel.wgpu_recurrent_kernel
            && self.kernel.wgpu_rollout_fused
            && supports_recurrent_backend::<B>()
        {
            return RolloutExecutorMode::WgpuFused;
        }
        RolloutExecutorMode::HostLoop
    }

    pub(super) fn recurrent_attention_reference(
        &self,
        query: Tensor<B, 4>,
        value: Tensor<B, 4>,
        rho_state: Option<Tensor<B, 4>>,
        decay: Option<Tensor<B, 1>>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        recurrent_attention_reference(query, value, rho_state, decay)
    }

    pub(super) fn recurrent_attention_dense_score_reference(
        &self,
        query: Tensor<B, 4>,
        value: Tensor<B, 4>,
        rho_state: Option<Tensor<B, 4>>,
        decay: Option<Tensor<B, 1>>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        recurrent_attention_dense_score_reference(query, value, rho_state, decay)
    }

    pub(super) fn recurrent_attention_dense_score_final_rho_reference(
        &self,
        query: Tensor<B, 4>,
        value: Tensor<B, 4>,
        rho_state: Option<Tensor<B, 4>>,
        decay: Option<Tensor<B, 1>>,
    ) -> Tensor<B, 4> {
        recurrent_attention_dense_score_final_rho_reference(query, value, rho_state, decay)
    }

    pub(super) fn recurrent_attention_dense_score_initial_context_reference(
        &self,
        query: Tensor<B, 4>,
        rho_state: Option<Tensor<B, 4>>,
        decay: Option<Tensor<B, 1>>,
        n_embd: usize,
    ) -> Tensor<B, 4> {
        recurrent_attention_dense_score_initial_context_reference(query, rho_state, decay, n_embd)
    }

    pub(super) fn recurrent_attention_with_plan(
        &self,
        query: Tensor<B, 4>,
        value: Tensor<B, 4>,
        layer_state: &mut LayerState<B>,
        position: usize,
        position_mode: RecurrentPositionMode,
        fused_plan: Option<&CompiledRecurrentAttentionPlan<B>>,
    ) -> Tensor<B, 4> {
        match (
            self.sequence_kernel.memory_system,
            self.sequence_kernel.executor,
        ) {
            (SequenceMemorySystem::LinearAttention, SequenceTrainingExecutor::Reference) => {
                let query = match position_mode {
                    RecurrentPositionMode::Sequential => {
                        self.attention.rotate_positions(query, position)
                    }
                    RecurrentPositionMode::Fixed => {
                        self.attention.rotate_positions_fixed(query, position)
                    }
                };
                let decay = self.attention.alibi_decay();
                let device = query.device();
                let initial_rho = self.resolve_linear_attention_rho_state(layer_state, &device);

                if self.kernel.enabled && self.kernel.wgpu_recurrent_kernel {
                    let fused = if let Some(plan) = fused_plan {
                        try_fused_recurrent_attention_wgpu_with_plan(
                            &query,
                            &value,
                            initial_rho.as_ref(),
                            decay.as_ref(),
                            plan,
                        )
                    } else {
                        try_fused_recurrent_attention_wgpu(
                            &query,
                            &value,
                            initial_rho.as_ref(),
                            decay.as_ref(),
                        )
                    };
                    if let Some(output) = fused {
                        self.write_linear_attention_rho_state(layer_state, output.rho);
                        return output.context;
                    }
                }

                let (context, rho) =
                    self.recurrent_attention_reference(query, value, initial_rho, decay);
                self.write_linear_attention_rho_state(layer_state, rho);
                context
            }
            (
                SequenceMemorySystem::LinearAttention,
                SequenceTrainingExecutor::DenseScoreShortContext,
            ) => {
                let query = match position_mode {
                    RecurrentPositionMode::Sequential => {
                        self.attention.rotate_positions(query, position)
                    }
                    RecurrentPositionMode::Fixed => {
                        self.attention.rotate_positions_fixed(query, position)
                    }
                };
                let decay = self.attention.alibi_decay();
                let device = query.device();
                let initial_rho = self.resolve_linear_attention_rho_state(layer_state, &device);
                if self.kernel.enabled
                    && self.kernel.wgpu_rollout_fused
                    && supports_dense_causal_attention_backend::<B>()
                {
                    let decay_tensor = decay
                        .clone()
                        .unwrap_or_else(|| Tensor::<B, 1>::ones([self.n_head], &device));
                    if let Some(fused_context) =
                        try_fused_dense_causal_attention_wgpu(&query, &value, &decay_tensor)
                    {
                        let initial_context = self
                            .recurrent_attention_dense_score_initial_context_reference(
                                query.clone(),
                                initial_rho.clone(),
                                decay.clone(),
                                value.shape().dims::<4>()[3],
                            );
                        let rho = self.recurrent_attention_dense_score_final_rho_reference(
                            query.clone(),
                            value.clone(),
                            initial_rho.clone(),
                            decay.clone(),
                        );
                        self.write_linear_attention_rho_state(layer_state, rho);
                        return initial_context + fused_context;
                    }
                }
                let (context, rho) = self.recurrent_attention_dense_score_reference(
                    query,
                    value,
                    initial_rho,
                    decay,
                );
                self.write_linear_attention_rho_state(layer_state, rho);
                context
            }
            (
                SequenceMemorySystem::Mamba3StateSpaceDuality,
                SequenceTrainingExecutor::Reference,
            ) => {
                let params = self
                    .mamba
                    .as_ref()
                    .expect("mamba3 sequence family requires initialized mamba params");
                let [batch, views, _time, dim] = value.shape().dims::<4>();
                assert_eq!(views, 1, "Mamba3 expects a single dense stream view");
                assert_eq!(
                    dim, self.n_embd,
                    "Mamba3 dense stream dim {} must match model dim {}",
                    dim, self.n_embd
                );
                let config = self.mamba_config;
                let device = value.device();
                let initial_state = mamba3_state(
                    layer_state,
                    batch,
                    config.nheads,
                    config.headdim,
                    config.d_state,
                    config.num_rope_angles,
                    &device,
                );
                if self.kernel.enabled
                    && config.use_fast_path
                    && use_tensorized_mamba3_forward_experimental()
                {
                    let params = params.mamba3();
                    let output = tensorized_mamba3_forward(
                        value,
                        config.d_inner,
                        config.d_state,
                        config.headdim,
                        config.ngroups,
                        config.num_rope_angles,
                        config.norm_eps,
                        config.a_floor,
                        config.chunk_size,
                        params.in_proj_tensor(),
                        params.dt_bias_tensor(),
                        params.b_bias_tensor(),
                        params.c_bias_tensor(),
                        params.b_norm_weight_tensor(),
                        params.c_norm_weight_tensor(),
                        params.d_skip_tensor(),
                        params.out_proj_tensor(),
                        Some(Mamba3TensorizedState {
                            ssm: initial_state.ssm,
                            angle: initial_state.angle,
                            k: initial_state.k,
                            v: initial_state.v,
                        }),
                    );
                    write_mamba3_state(
                        layer_state,
                        output.state.ssm,
                        output.state.angle,
                        output.state.k,
                        output.state.v,
                    );
                    return output.context;
                }
                let (context, next_state) = mamba_reference(
                    value,
                    params,
                    Some(MambaReferenceState {
                        ssm: initial_state.ssm,
                        angle: Some(initial_state.angle),
                        k: Some(initial_state.k),
                        v: Some(initial_state.v),
                    }),
                );
                write_mamba3_state(
                    layer_state,
                    next_state.ssm,
                    next_state.angle.expect("mamba3 next angle state"),
                    next_state.k.expect("mamba3 next k state"),
                    next_state.v.expect("mamba3 next v state"),
                );
                context
            }
            (family, executor) => panic!(
                "sequence kernel family {:?} with executor {:?} is not implemented in DragonModel yet",
                family, executor
            ),
        }
    }
}