//! bitmamba 0.1.0
//!
//! BitMamba: 1.58-bit Mamba language model with infinite context window —
//! includes an OpenAI-compatible API server. Documentation:
//! BitMamba Model - Shared model code

use anyhow::Result;
use candle_core::{DType, Device, Tensor, Module, D};
use candle_nn::VarBuilder;

// =========================================================================
// BitLinear
// =========================================================================
/// 1.58-bit (ternary) linear layer in the BitNet style.
///
/// Stores the full-precision shadow weight; quantization to {-1, 0, +1}
/// happens on the fly in `forward`.
pub struct BitLinear {
    // Full-precision (out_features, in_features) weight, ternarized per call.
    weight: Tensor,
    // Learned per-output-channel scale applied after dequantization.
    alpha: Tensor,
}

impl BitLinear {
    /// Loads the full-precision shadow weight and the per-output-channel
    /// scale `alpha` from the checkpoint.
    pub fn load(vb: VarBuilder, in_features: usize, out_features: usize) -> Result<Self> {
        let weight = vb.get((out_features, in_features), "weight")?;
        let alpha = vb.get(out_features, "alpha")?;
        Ok(Self { weight, alpha })
    }

    /// BitNet-style forward pass: ternarize the weight via the sign of the
    /// mean-abs-normalized rows, quantize activations to the int8 range
    /// [-128, 127], matmul, then undo both scalings and apply `alpha`.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let eps = 1e-5f64;

        // Per-row weight scale (mean |w|), clamped away from zero.
        let row_scale = self.weight.abs()?.mean_keepdim(1)?.clamp(eps, f64::MAX)?;
        // Ternary {-1, 0, +1} weights.
        let ternary_w = self.weight.broadcast_div(&row_scale)?.sign()?;

        // Per-position activation scale (max |x| along the feature axis).
        let act_scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(eps, f64::MAX)?;
        let scaled_x = (x.broadcast_div(&act_scale)? * 127.0)?;
        let int8_x = scaled_x.round()?.clamp(-128.0, 127.0)?;

        // Quantized-domain matmul, then dequantize both factors at once.
        let raw = int8_x.broadcast_matmul(&ternary_w.t()?)?;
        let dequant = (act_scale.broadcast_mul(&row_scale.squeeze(1)?)? / 127.0)?;
        let out = raw.broadcast_mul(&dequant)?.broadcast_mul(&self.alpha)?;
        Ok(out)
    }
}

// =========================================================================
// RMSNorm
// =========================================================================
/// Root-mean-square layer normalization over the last dimension.
pub struct RMSNorm {
    // Learned per-feature gain.
    weight: Tensor,
    // Small constant added to the mean-square before the sqrt (fixed 1e-5).
    eps: f64,
}

impl RMSNorm {
    /// Reads the learned gain vector; epsilon is fixed at 1e-5.
    pub fn load(vb: VarBuilder, dim: usize) -> Result<Self> {
        Ok(Self {
            weight: vb.get(dim, "weight")?,
            eps: 1e-5,
        })
    }

    /// y = x / sqrt(mean(x²) + eps) * weight, along the last dimension.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
        let denom = (mean_sq + self.eps)?.sqrt()?;
        let normalized = x.broadcast_div(&denom)?;
        Ok(normalized.broadcast_mul(&self.weight)?)
    }
}

// =========================================================================
// BitMambaBlock
// =========================================================================
/// One pre-norm residual Mamba block with BitLinear input/output projections.
pub struct BitMambaBlock {
    // Fused projection: d_model -> 2 * d_inner (SSM input x and gate z).
    in_proj: BitLinear,
    // Projection back: d_inner -> d_model.
    out_proj: BitLinear,
    // Depthwise causal conv over the time axis (kernel 4, padding 3).
    conv1d: candle_nn::Conv1d,
    // Emits input-dependent [delta | B | C] per position.
    x_proj: candle_nn::Linear,
    // Expands low-rank delta (dt_rank) back to d_inner.
    dt_proj: candle_nn::Linear,
    // State matrix stored in log-space; A = -exp(log_A) at scan time.
    log_a: Tensor,
    // Per-channel D parameter (see NOTE in forward about how it is applied).
    d: Tensor,
    // Pre-norm applied to the block input.
    norm: RMSNorm,
    d_state: usize,
    dt_rank: usize,
}

impl BitMambaBlock {
    /// Loads one residual Mamba block from `vb`.
    ///
    /// `d_inner = d_model * expand`; `dt_rank = ceil(d_model / 16)` is the
    /// low-rank bottleneck for the delta (timestep) projection.
    pub fn load(vb: VarBuilder, d_model: usize, d_state: usize, expand: usize) -> Result<Self> {
        let d_inner = d_model * expand;
        let dt_rank = (d_model + 15) / 16;

        // Fused projection: the first d_inner channels feed the SSM path,
        // the second d_inner channels are the SiLU gate (split in forward()).
        let in_proj = BitLinear::load(vb.pp("in_proj"), d_model, d_inner * 2)?;
        let out_proj = BitLinear::load(vb.pp("out_proj"), d_inner, d_model)?;

        // Depthwise (groups == channels) conv, kernel 4, padding 3; forward()
        // trims the trailing outputs so the convolution stays causal.
        let conv_cfg = candle_nn::Conv1dConfig { 
            groups: d_inner, 
            padding: 3,
            ..Default::default() 
        };
        let conv1d = candle_nn::conv1d(d_inner, d_inner, 4, conv_cfg, vb.pp("conv1d"))?;

        // x_proj emits (dt_rank + 2 * d_state) features per position:
        // delta (low-rank), then B, then C.
        let x_proj = candle_nn::linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
        let dt_proj = candle_nn::linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
        
        // log_A is stored in log-space; A = -exp(log_A) in ssm_scan keeps the
        // continuous-time dynamics decaying.
        let log_a = vb.get((d_inner, d_state), "log_A")?;
        let d = vb.get(d_inner, "D")?;
        let norm = RMSNorm::load(vb.pp("norm"), d_model)?;

        Ok(Self { in_proj, out_proj, conv1d, x_proj, dt_proj, log_a, d, norm, d_state, dt_rank })
    }

    /// Sequential selective scan.
    ///
    /// Per timestep: h_t = exp(dt_t * A) ⊙ h_{t-1} + (dt_t * B_t) ⊙ x_t and
    /// y_t = Σ_state h_t ⊙ C_t. x/dt are (batch, seq, d_inner), b/c are
    /// (batch, seq, d_state); returns (batch, seq, d_inner). This is an
    /// O(seq_len) host-side loop — no fused kernel — and the recurrent state
    /// is NOT cached across calls.
    fn ssm_scan(&self, x: &Tensor, dt: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
        let (batch, seq_len, d_inner) = x.dims3()?;
        let d_state = self.d_state;
        let device = x.device();

        // A = -exp(log_A): strictly negative, so exp(dt * A) below is in (0, 1).
        let neg_exp_log_a = self.log_a.exp()?.neg()?;
        
        // Recurrent state, zero-initialized for every call.
        let mut h = Tensor::zeros((batch, d_inner, d_state), DType::F32, device)?;
        let mut ys: Vec<Tensor> = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out timestep t: (batch, d_inner) resp. (batch, d_state).
            let x_t = x.narrow(1, t, 1)?.squeeze(1)?;
            let dt_t = dt.narrow(1, t, 1)?.squeeze(1)?;
            let b_t = b.narrow(1, t, 1)?.squeeze(1)?;
            let c_t = c.narrow(1, t, 1)?.squeeze(1)?;

            // Zero-order-hold discretization of the decay: dA = exp(dt ⊗ A).
            let dt_expanded = dt_t.unsqueeze(2)?;
            let a_expanded = neg_exp_log_a.unsqueeze(0)?;
            let da_t = dt_expanded.broadcast_mul(&a_expanded)?.exp()?;

            // Input contribution: dB = dt ⊗ B, then u = dB ⊙ x.
            let db_t = dt_expanded.broadcast_mul(&b_t.unsqueeze(1)?)?;
            let u_t = db_t.broadcast_mul(&x_t.unsqueeze(2)?)?;

            h = (da_t.broadcast_mul(&h)? + u_t)?;

            // Readout: contract the state dimension against C_t.
            let y_t = h.broadcast_mul(&c_t.unsqueeze(1)?)?.sum(2)?;
            ys.push(y_t.unsqueeze(1)?);
        }

        Ok(Tensor::cat(&ys, 1)?)
    }

    /// Pre-norm residual block: norm → in_proj → (causal conv + SiLU) →
    /// selective scan → gate → out_proj → + residual.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let residual = x;
        let x = self.norm.forward(x)?;

        // Split the fused projection into the SSM input and the gate halves.
        let xz = self.in_proj.forward(&x)?;
        let d_inner = xz.dim(2)? / 2;
        let x_part = xz.narrow(2, 0, d_inner)?;
        let z = xz.narrow(2, d_inner, d_inner)?;

        // Causal depthwise conv over time (channels-first layout for conv1d);
        // narrow(..., seq_len) drops the padding-induced future positions.
        let seq_len = x_part.dim(1)?;
        let x_t = x_part.transpose(1, 2)?;
        let x_conv = self.conv1d.forward(&x_t)?;
        let x_conv = x_conv.narrow(2, 0, seq_len)?;
        let x_conv = x_conv.transpose(1, 2)?;
        let x_act = candle_nn::ops::silu(&x_conv)?;

        // Input-dependent SSM parameters: [delta | B | C] per position.
        let ssm_params = self.x_proj.forward(&x_act)?;
        let dt_raw = ssm_params.narrow(2, 0, self.dt_rank)?;
        let b = ssm_params.narrow(2, self.dt_rank, self.d_state)?;
        let c = ssm_params.narrow(2, self.dt_rank + self.d_state, self.d_state)?;

        // softplus keeps the discretization step strictly positive.
        let dt = self.dt_proj.forward(&dt_raw)?;
        let dt = softplus(&dt)?;

        let y_ssm = self.ssm_scan(&x_act, &dt, &b, &c)?;

        let z_act = candle_nn::ops::silu(&z)?;
        // NOTE(review): here D multiplies the gated output
        // (y = y_ssm * silu(z) * D). Reference Mamba instead adds a skip term
        // before gating (y = (y_ssm + x * D) * silu(z)) — confirm this matches
        // the training code before changing it; "fixing" it would break
        // checkpoint compatibility.
        let y = (y_ssm * z_act)?.broadcast_mul(&self.d)?;

        let out = self.out_proj.forward(&y)?;
        Ok((out + residual)?)
    }
}

/// Numerically stable softplus: ln(1 + eˣ), applied element-wise.
///
/// Uses the identity softplus(x) = max(x, 0) + ln(1 + exp(-|x|)): the exp
/// argument is always ≤ 0, so it can never overflow, and the result stays
/// exact for large positive x. The previous implementation hard-clamped x to
/// [-20, 20], which silently capped the output near 20 for any x > 20
/// (true softplus(x) ≈ x there), biasing large discretization steps.
fn softplus(x: &Tensor) -> Result<Tensor> {
    // max(x, 0)
    let positive_part = x.relu()?;
    // ln(1 + exp(-|x|)) — argument of exp is in (-inf, 0], value in (0, ln 2].
    let neg_abs = x.abs()?.neg()?;
    let log_term = (neg_abs.exp()? + 1.0)?.log()?;
    Ok((positive_part + log_term)?)
}

// =========================================================================
// BitMambaStudent
// =========================================================================
/// Full BitMamba language model: embedding, a stack of BitMambaBlocks,
/// a final RMSNorm, and an untied LM head.
pub struct BitMambaStudent {
    // Token-id -> hidden-vector lookup table.
    embed_tokens: candle_nn::Embedding,
    // The residual Mamba blocks, applied in order.
    layers: Vec<BitMambaBlock>,
    // Final normalization before the LM head.
    norm: RMSNorm,
    // hidden_dim -> vocab_size projection (no bias).
    lm_head: candle_nn::Linear,
    // Device every input tensor is created on (see generate()).
    device: Device,
}

impl BitMambaStudent {
    /// Builds the student model from checkpoint weights.
    ///
    /// Hyper-parameters are fixed to match the training run: 151665-token
    /// vocabulary, 768 hidden, 12 layers, state dim 16, expansion factor 2.
    pub fn load(vb: VarBuilder, device: Device) -> Result<Self> {
        let vocab_size = 151665;
        let hidden_dim = 768;
        let num_layers = 12;
        let d_state = 16;
        let expand = 2;

        let embed_tokens = candle_nn::embedding(vocab_size, hidden_dim, vb.pp("embed_tokens"))?;
        let norm = RMSNorm::load(vb.pp("norm"), hidden_dim)?;
        let lm_head = candle_nn::linear_no_bias(hidden_dim, vocab_size, vb.pp("lm_head"))?;

        let layers = (0..num_layers)
            .map(|i| BitMambaBlock::load(vb.pp(format!("layers.{}", i)), hidden_dim, d_state, expand))
            .collect::<Result<Vec<_>>>()?;

        Ok(Self { embed_tokens, layers, norm, lm_head, device })
    }

    /// Runs the whole stack and returns logits for the FINAL position only,
    /// shaped (batch, vocab_size).
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let mut hidden = self.embed_tokens.forward(input_ids)?;

        for block in self.layers.iter() {
            hidden = block.forward(&hidden)?;
        }

        let hidden = self.norm.forward(&hidden)?;
        let last = hidden.dim(1)? - 1;
        let last_hidden = hidden.narrow(1, last, 1)?;
        let logits = self.lm_head.forward(&last_hidden)?;
        Ok(logits.squeeze(1)?)
    }

    /// Autoregressive generation; returns prompt + generated ids.
    ///
    /// No recurrent-state cache: each step re-runs the full prefix, so the
    /// total cost grows quadratically with the number of generated tokens.
    pub fn generate(&self, input_ids: &[u32], max_new_tokens: usize, temperature: f64) -> Result<Vec<u32>> {
        let mut tokens = input_ids.to_vec();

        for _step in 0..max_new_tokens {
            let batch = Tensor::new(&tokens[..], &self.device)?.unsqueeze(0)?;
            let logits = self.forward(&batch)?;
            let next = sample(&logits, temperature)?;

            // Stop on either end-of-text sentinel (presumably the Qwen
            // tokenizer's EOS ids, given the 151665 vocab — confirm).
            if matches!(next, 151643 | 151645) {
                break;
            }

            tokens.push(next);
        }

        Ok(tokens)
    }

    /// The device model inputs must live on.
    pub fn device(&self) -> &Device {
        &self.device
    }
}

/// Draws the next token id from single-row `logits` (shape (1, vocab)).
///
/// Temperature <= 0 selects the argmax deterministically; otherwise the
/// logits are temperature-scaled, max-shifted for stability, softmaxed,
/// and sampled by inverse-CDF walk over the probability vector.
pub fn sample(logits: &Tensor, temperature: f64) -> Result<u32> {
    if temperature <= 0.0 {
        // Greedy decoding.
        let best = logits.argmax(1)?.flatten_all()?.to_vec1::<u32>()?;
        return Ok(best[0]);
    }

    use rand::Rng;
    let scaled = (logits / temperature)?;
    // Subtract the row max before softmax (defensive; keeps exp bounded).
    let peak = scaled.max(1)?.unsqueeze(1)?;
    let shifted = scaled.broadcast_sub(&peak)?;
    let probs = candle_nn::ops::softmax(&shifted, 1)?
        .flatten_all()?
        .to_vec1::<f32>()?;

    // Inverse-CDF sampling: walk the cumulative mass until it crosses r.
    let r: f32 = rand::thread_rng().gen();
    let mut acc = 0.0f32;
    for (idx, &p) in probs.iter().enumerate() {
        acc += p;
        if acc >= r {
            return Ok(idx as u32);
        }
    }
    // Floating-point slack can leave the cumsum just below 1; fall back to
    // the last index.
    Ok((probs.len() - 1) as u32)
}