cortex_rust 0.6.0

High-performance LLM inference with 4-bit quantization and Test-Time Training (TTT)
Documentation
//! BitLlamaBlock - Transformer block with TTT + MLP

use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use std::collections::HashMap;

use crate::layers::{KVCache, RMSNorm, SwiGLU, TTTLayer};
use crate::model::config::{BitLlamaConfig, ModelArch};

/// Epsilon for RMSNorm
const RMS_NORM_EPS: f64 = 1e-5;

/// Enum to dispatch between TTT and Attention layers
pub enum LayerDispatch {
    TTT(Box<TTTLayer>),
    Attention(Box<crate::layers::BitAttention>),
}

/// Single transformer block: TTT/Attn + MLP with residual connections
pub struct BitLlamaBlock {
    pub norm1: RMSNorm,
    pub core: LayerDispatch,
    pub norm2: RMSNorm,
    pub mlp: SwiGLU,
}

impl BitLlamaBlock {
    /// Load directly from pre-loaded tensor HashMap (bypasses VarBuilder).
    /// This avoids U8→F32 conversion for packed tensors.
    pub fn load_direct(
        tensors: &HashMap<String, Tensor>,
        prefix: &str, // e.g., "model.layers.0"
        cfg: &BitLlamaConfig,
        device: &candle_core::Device,
    ) -> Result<Self> {
        let dim = cfg.hidden_dim;

        // Try HF keys first, then legacy
        let norm1_key = if tensors.contains_key(&format!("{}.input_layernorm.weight", prefix)) {
            format!("{}.input_layernorm.weight", prefix)
        } else {
            format!("{}.norm1.weight", prefix)
        };

        let norm2_key =
            if tensors.contains_key(&format!("{}.post_attention_layernorm.weight", prefix)) {
                format!("{}.post_attention_layernorm.weight", prefix)
            } else {
                format!("{}.norm2.weight", prefix)
            };

        let norm1 = RMSNorm::load_direct(tensors, &norm1_key, dim, RMS_NORM_EPS, device)?;
        let norm2 = RMSNorm::load_direct(tensors, &norm2_key, dim, RMS_NORM_EPS, device)?;

        let mlp_dim = cfg.intermediate_dim.unwrap_or(dim * 4);
        let mlp = SwiGLU::load_direct(
            tensors,
            &format!("{}.mlp", prefix),
            dim,
            mlp_dim,
            device,
            &cfg.quantization,
        )?;

        // Dispatch Layer Loading based on Config
        let core = match cfg.arch {
            ModelArch::TTT => {
                let ttt = TTTLayer::load_direct(
                    tensors,
                    &format!("{}.ttt", prefix),
                    dim,
                    cfg.inner_lr,
                    device,
                    &cfg.quantization,
                )?;
                LayerDispatch::TTT(Box::new(ttt))
            }
            ModelArch::Llama | ModelArch::Gemma | ModelArch::Gemma2 => {
                let attn = crate::layers::BitAttention::load_direct(
                    tensors,
                    &format!("{}.self_attn", prefix),
                    dim,
                    cfg.n_heads,
                    cfg.n_kv_heads,
                    cfg.rope_theta,
                    cfg.max_position_embeddings,
                    device,
                    &cfg.quantization,
                )?;
                LayerDispatch::Attention(Box::new(attn))
            }
        };

        Ok(Self {
            norm1,
            core,
            norm2,
            mlp,
        })
    }

    pub fn load(
        cfg: &BitLlamaConfig,
        vb: VarBuilder, // Usually "layers.N" or "model.layers.N"
        device: &candle_core::Device,
    ) -> Result<Self> {
        let dim = cfg.hidden_dim;
        let norm1 = RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("norm1").pp("model.norm"), device)
            .or_else(|_| RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("norm1"), device))
            .or_else(|_| RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("input_layernorm"), device))?;

        let norm2 = RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("norm2").pp("model.norm"), device)
            .or_else(|_| RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("norm2"), device))
            .or_else(|_| {
                RMSNorm::load(dim, RMS_NORM_EPS, vb.pp("post_attention_layernorm"), device)
            })?;

        let mlp_dim = cfg.intermediate_dim.unwrap_or(dim * 4);
        let mlp = SwiGLU::load(dim, mlp_dim, vb.pp("mlp"), device)?;

        // Dispatch Layer Loading based on Config
        let core = match cfg.arch {
            ModelArch::TTT => {
                let ttt = TTTLayer::load(dim, cfg.inner_lr, vb.pp("ttt"), device)?;
                LayerDispatch::TTT(Box::new(ttt))
            }
            ModelArch::Llama | ModelArch::Gemma | ModelArch::Gemma2 => {
                let attn = crate::layers::BitAttention::load(
                    dim,
                    cfg.n_heads,
                    cfg.n_kv_heads,
                    cfg.rope_theta,
                    cfg.max_position_embeddings,
                    vb.pp("self_attn"),
                    device,
                )?;
                LayerDispatch::Attention(Box::new(attn))
            }
        };

        Ok(Self {
            norm1,
            core,
            norm2,
            mlp,
        })
    }

    pub fn device(&self) -> &candle_core::Device {
        self.norm1.weight.device()
    }

    pub fn precompute_packed(&mut self) -> Result<()> {
        match &mut self.core {
            LayerDispatch::TTT(t) => t.precompute_packed()?,
            LayerDispatch::Attention(_) => {} // No precompute needed yet
        }
        self.mlp.precompute_packed()?;
        Ok(())
    }

    pub fn forward(
        &self,
        x: &Tensor,
        w_state: &Tensor,
        kv_cache: &mut Option<KVCache>,
        pos: usize,
    ) -> Result<(Tensor, Tensor)> {
        // NOTE: Device transfer is now handled by llama.rs::forward_one
        // using stored gpu_device/cpu_device for correct layer-to-device mapping

        let residual = x;
        let x_norm = self.norm1.forward(x)?;

        let (mixed_out, w_new) = match &self.core {
            LayerDispatch::TTT(t) => {
                // TTT Path: uses w_state, ignores kv_cache/pos
                t.forward_update(w_state, &x_norm)?
            }
            LayerDispatch::Attention(a) => {
                // Attention Path: uses kv_cache/pos, ignores w_state (passthrough)
                let out = a.forward(&x_norm, kv_cache, pos)?;
                (out, w_state.clone())
            }
        };

        // [Hybrid Guard] Ensure mixed output is on same device as residual before adding
        let mixed_out = if mixed_out.device().same_device(residual.device()) {
            mixed_out
        } else {
            mixed_out.to_device(residual.device())?
        };

        let x_mid = (residual + mixed_out)?;
        let residual = &x_mid;
        let x_norm2 = self.norm2.forward(&x_mid)?;
        let mlp_out = self.mlp.forward(&x_norm2)?;

        // [Hybrid Guard] Ensure MLP output is on same device as residual before adding
        let mlp_out = if mlp_out.device().same_device(residual.device()) {
            mlp_out
        } else {
            mlp_out.to_device(residual.device())?
        };

        let x_out = (residual + mlp_out)?;

        Ok((x_out, w_new))
    }

    pub fn forward_chunkwise(
        &self,
        x: &Tensor,
        w_state: &Tensor,
        chunk_size: usize,
    ) -> Result<(Tensor, Tensor)> {
        let residual = x;
        let x_norm = self.norm1.forward(x)?;

        let (mixed_out, w_final) = match &self.core {
            LayerDispatch::TTT(t) => t.forward_chunkwise(w_state, &x_norm, chunk_size)?,
            LayerDispatch::Attention(a) => {
                // Critical: Chunkwise training for Attention not yet supported with KV Cache state
                // Use a temporary cache or stateless mode (causal mask only)
                // For now, assume training = no past cache
                let mut cache = None;
                let out = a.forward(&x_norm, &mut cache, 0)?;
                (out, w_state.clone())
            }
        };

        let x_mid = (residual + mixed_out)?;
        let residual = &x_mid;
        let x_norm2 = self.norm2.forward(&x_mid)?;
        let mlp_out = self.mlp.forward(&x_norm2)?;
        let x_out = (residual + mlp_out)?;

        Ok((x_out, w_final))
    }
}