//! BitNet-style Mamba LM on candle: binarized `BitLinear` layers, a
//! sequential selective-scan `BitMambaBlock`, and a `BitMambaStudent` causal LM.

use anyhow::Result;
use candle_core::{DType, Device, Module, Tensor, D};
use candle_nn::VarBuilder;

/// Linear layer with BitNet-style inference-time quantization: weights are
/// binarized with `sign()`, activations get per-token absmax int8 scaling,
/// and `alpha` is a learned per-output-channel scale.
pub struct BitLinear {
    weight: Tensor, // (out_features, in_features), stored full precision
    alpha: Tensor,  // (out_features,)
}

impl BitLinear {
    pub fn load(vb: VarBuilder, in_features: usize, out_features: usize) -> Result<Self> {
        let weight = vb.get((out_features, in_features), "weight")?;
        let alpha = vb.get(out_features, "alpha")?;
        Ok(Self { weight, alpha })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let eps = 1e-5f64;
        // Weight quantization: w ~ sign(w / w_scale) * w_scale, where
        // w_scale is the per-output-row mean |w|.
        let w_scale = self.weight.abs()?.mean_keepdim(1)?.clamp(eps, f64::MAX)?; // (out, 1)
        let w_normalized = self.weight.broadcast_div(&w_scale)?;
        let w_quant = w_normalized.sign()?;
        // Activation quantization: per-token absmax scaling into [-127, 127];
        // the clamp is only a guard against rounding past the int8 range.
        let a_scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(eps, f64::MAX)?; // (b, s, 1)
        let a_scaled = (x.broadcast_div(&a_scale)? * 127.0)?;
        let a_quant = a_scaled.round()?.clamp(-128.0, 127.0)?;
        // Quantized matmul, then dequantize: y ~ (a_q @ w_q^T) * a_scale * w_scale / 127.
        let y = a_quant.broadcast_matmul(&w_quant.t()?)?;
        let w_scale_t = w_scale.squeeze(1)?; // (out,)
        let rescale = (a_scale.broadcast_mul(&w_scale_t)? / 127.0)?;
        let y = y.broadcast_mul(&rescale)?;
        let y = y.broadcast_mul(&self.alpha)?;
        Ok(y)
    }
}
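
// A minimal sanity check for `BitLinear` (a sketch: random weights and a
// deliberately loose error bound, not a guarantee about quantization
// quality). With alpha = 1 the quantized output should roughly track x @ w^T.
#[cfg(test)]
mod bitlinear_tests {
    use super::*;

    #[test]
    fn shape_and_rough_accuracy() -> Result<()> {
        let device = Device::Cpu;
        let weight = Tensor::randn(0f32, 1f32, (8, 4), &device)?;
        let alpha = Tensor::ones(8, DType::F32, &device)?;
        let layer = BitLinear { weight: weight.clone(), alpha };
        let x = Tensor::randn(0f32, 1f32, (2, 3, 4), &device)?;
        let y = layer.forward(&x)?;
        assert_eq!(y.dims3()?, (2, 3, 8));
        let y_full = x.broadcast_matmul(&weight.t()?)?;
        let err = (y - y_full)?.abs()?.mean_all()?.to_scalar::<f32>()?;
        assert!(err < 2.0, "quantization error unexpectedly large: {err}");
        Ok(())
    }
}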

/// Root-mean-square layer norm: no mean subtraction, no bias.
pub struct RMSNorm {
    weight: Tensor,
    eps: f64,
}

impl RMSNorm {
    pub fn load(vb: VarBuilder, dim: usize) -> Result<Self> {
        let weight = vb.get(dim, "weight")?;
        Ok(Self { weight, eps: 1e-5 })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // x / sqrt(mean(x^2) + eps) * weight, over the last dimension.
        let variance = x.sqr()?.mean_keepdim(D::Minus1)?;
        let rms = (variance + self.eps)?.sqrt()?;
        let normed = x.broadcast_div(&rms)?;
        Ok(normed.broadcast_mul(&self.weight)?)
    }
}
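
// A quick sanity check (a sketch: the 1e-3 tolerance is an arbitrary
// choice). On an all-ones input the RMS is ~1, so the output should be
// approximately the weight vector itself.
#[cfg(test)]
mod rmsnorm_tests {
    use super::*;

    #[test]
    fn ones_input_returns_weight() -> Result<()> {
        let device = Device::Cpu;
        let weight = Tensor::new(&[2f32, 3., 4.], &device)?;
        let norm = RMSNorm { weight, eps: 1e-5 };
        let x = Tensor::ones((1, 1, 3), DType::F32, &device)?;
        let y = norm.forward(&x)?.flatten_all()?.to_vec1::<f32>()?;
        assert!((y[0] - 2.0).abs() < 1e-3 && (y[2] - 4.0).abs() < 1e-3);
        Ok(())
    }
}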

/// One Mamba block with `BitLinear` in/out projections; the small SSM
/// parameter projections (`x_proj`, `dt_proj`) remain full precision.
pub struct BitMambaBlock {
    in_proj: BitLinear,
    out_proj: BitLinear,
    conv1d: candle_nn::Conv1d,
    x_proj: candle_nn::Linear,
    dt_proj: candle_nn::Linear,
    log_a: Tensor, // (d_inner, d_state); A = -exp(log_A)
    d: Tensor,     // (d_inner,) skip-connection gain
    norm: RMSNorm,
    d_state: usize,
    dt_rank: usize,
}

impl BitMambaBlock {
    pub fn load(vb: VarBuilder, d_model: usize, d_state: usize, expand: usize) -> Result<Self> {
        let d_inner = d_model * expand;
        let dt_rank = (d_model + 15) / 16; // ceil(d_model / 16), the Mamba default
        let in_proj = BitLinear::load(vb.pp("in_proj"), d_model, d_inner * 2)?;
        let out_proj = BitLinear::load(vb.pp("out_proj"), d_inner, d_model)?;
        // Depthwise conv with padding = kernel_size - 1; the extra trailing
        // outputs are trimmed in `forward` to keep the conv causal.
        let conv_cfg = candle_nn::Conv1dConfig {
            groups: d_inner,
            padding: 3,
            ..Default::default()
        };
        let conv1d = candle_nn::conv1d(d_inner, d_inner, 4, conv_cfg, vb.pp("conv1d"))?;
        let x_proj = candle_nn::linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
        let dt_proj = candle_nn::linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
        let log_a = vb.get((d_inner, d_state), "log_A")?;
        let d = vb.get(d_inner, "D")?;
        let norm = RMSNorm::load(vb.pp("norm"), d_model)?;
        Ok(Self { in_proj, out_proj, conv1d, x_proj, dt_proj, log_a, d, norm, d_state, dt_rank })
    }

    /// Sequential selective scan:
    ///   h_t = exp(dt_t * A) . h_{t-1} + (dt_t * B_t) * x_t
    ///   y_t = sum over the state dim of C_t . h_t
    /// with input-dependent dt, B, C (x/dt: (b, s, d_inner); B/C: (b, s, d_state)).
    fn ssm_scan(&self, x: &Tensor, dt: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
        let (batch, seq_len, d_inner) = x.dims3()?;
        let d_state = self.d_state;
        let device = x.device();
        // A = -exp(log_A) keeps the continuous-time dynamics stable.
        let neg_exp_log_a = self.log_a.exp()?.neg()?;
        let mut h = Tensor::zeros((batch, d_inner, d_state), DType::F32, device)?;
        let mut ys: Vec<Tensor> = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            let x_t = x.narrow(1, t, 1)?.squeeze(1)?; // (b, d_inner)
            let dt_t = dt.narrow(1, t, 1)?.squeeze(1)?;
            let b_t = b.narrow(1, t, 1)?.squeeze(1)?; // (b, d_state)
            let c_t = c.narrow(1, t, 1)?.squeeze(1)?;
            let dt_expanded = dt_t.unsqueeze(2)?; // (b, d_inner, 1)
            let a_expanded = neg_exp_log_a.unsqueeze(0)?; // (1, d_inner, d_state)
            // Zero-order-hold discretization of A; Euler step for the input term.
            let da_t = dt_expanded.broadcast_mul(&a_expanded)?.exp()?;
            let db_t = dt_expanded.broadcast_mul(&b_t.unsqueeze(1)?)?;
            let u_t = db_t.broadcast_mul(&x_t.unsqueeze(2)?)?;
            h = (da_t.broadcast_mul(&h)? + u_t)?;
            // Read out by contracting the state dimension against C_t.
            let y_t = h.broadcast_mul(&c_t.unsqueeze(1)?)?.sum(2)?;
            ys.push(y_t.unsqueeze(1)?);
        }
        Ok(Tensor::cat(&ys, 1)?)
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let residual = x;
        let x = self.norm.forward(x)?;
        // Project to 2 * d_inner, then split into the SSM path and the gate.
        let xz = self.in_proj.forward(&x)?;
        let d_inner = xz.dim(2)? / 2;
        let x_part = xz.narrow(2, 0, d_inner)?;
        let z = xz.narrow(2, d_inner, d_inner)?;
        // Causal depthwise conv over time; trim the trailing padding back to seq_len.
        let seq_len = x_part.dim(1)?;
        let x_t = x_part.transpose(1, 2)?; // (b, d_inner, s)
        let x_conv = self.conv1d.forward(&x_t)?;
        let x_conv = x_conv.narrow(2, 0, seq_len)?;
        let x_conv = x_conv.transpose(1, 2)?;
        let x_act = candle_nn::ops::silu(&x_conv)?;
        // Input-dependent SSM parameters: low-rank dt, plus B and C.
        let ssm_params = self.x_proj.forward(&x_act)?;
        let dt_raw = ssm_params.narrow(2, 0, self.dt_rank)?;
        let b = ssm_params.narrow(2, self.dt_rank, self.d_state)?;
        let c = ssm_params.narrow(2, self.dt_rank + self.d_state, self.d_state)?;
        let dt = self.dt_proj.forward(&dt_raw)?;
        let dt = softplus(&dt)?; // step sizes must stay positive
        let y_ssm = self.ssm_scan(&x_act, &dt, &b, &c)?;
        // Per-channel D skip from the SSM input, then SiLU gating by z.
        let y = (y_ssm + x_act.broadcast_mul(&self.d)?)?;
        let z_act = candle_nn::ops::silu(&z)?;
        let y = (y * z_act)?;
        let out = self.out_proj.forward(&y)?;
        Ok((out + residual)?)
    }
}

/// Clamped softplus, log(1 + e^x), with x limited to [-20, 20] to avoid
/// overflow in `exp` (softplus(20) is already ~20 to float precision).
fn softplus(x: &Tensor) -> Result<Tensor> {
    let x_clamped = x.clamp(-20.0f32, 20.0f32)?;
    let ones = Tensor::ones_like(&x_clamped)?;
    Ok((ones + x_clamped.exp()?)?.log()?)
}
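
// Spot check for the helper (a sketch): softplus(0) = ln(2), about 0.6931.
#[cfg(test)]
mod softplus_tests {
    use super::*;

    #[test]
    fn softplus_at_zero() -> Result<()> {
        let x = Tensor::zeros(3, DType::F32, &Device::Cpu)?;
        let y = softplus(&x)?.to_vec1::<f32>()?;
        assert!((y[0] - 0.6931).abs() < 1e-3);
        Ok(())
    }
}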

/// The full student model: token embeddings, a stack of BitMamba blocks,
/// a final RMSNorm, and a full-precision LM head.
pub struct BitMambaStudent {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<BitMambaBlock>,
    norm: RMSNorm,
    lm_head: candle_nn::Linear,
    device: Device,
}

impl BitMambaStudent {
    pub fn load(vb: VarBuilder, device: Device) -> Result<Self> {
        // Fixed hyperparameters matching the exported checkpoint; the
        // vocabulary follows the Qwen tokenizer, whose special-token ids
        // are used as stop tokens in `generate`.
        let vocab_size = 151665;
        let hidden_dim = 768;
        let num_layers = 12;
        let d_state = 16;
        let expand = 2;
        let embed_tokens = candle_nn::embedding(vocab_size, hidden_dim, vb.pp("embed_tokens"))?;
        let norm = RMSNorm::load(vb.pp("norm"), hidden_dim)?;
        let lm_head = candle_nn::linear_no_bias(hidden_dim, vocab_size, vb.pp("lm_head"))?;
        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(BitMambaBlock::load(vb.pp(format!("layers.{}", i)), hidden_dim, d_state, expand)?);
        }
        Ok(Self { embed_tokens, layers, norm, lm_head, device })
    }

    /// Runs the stack and returns logits for the last position, (batch, vocab).
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let mut x = self.embed_tokens.forward(input_ids)?;
        for layer in &self.layers {
            x = layer.forward(&x)?;
        }
        x = self.norm.forward(&x)?;
        let seq_len = x.dim(1)?;
        let x = x.narrow(1, seq_len - 1, 1)?;
        let logits = self.lm_head.forward(&x)?;
        Ok(logits.squeeze(1)?)
    }

    /// Autoregressive decoding; the whole prefix is re-run at every step
    /// since the scan keeps no cached state between calls.
    pub fn generate(&self, input_ids: &[u32], max_new_tokens: usize, temperature: f64) -> Result<Vec<u32>> {
        let mut token_ids = input_ids.to_vec();
        for _ in 0..max_new_tokens {
            let input_tensor = Tensor::new(&token_ids[..], &self.device)?.unsqueeze(0)?;
            let logits = self.forward(&input_tensor)?;
            let next_token_id = sample(&logits, temperature)?;
            // Stop on Qwen's <|endoftext|> (151643) or <|im_end|> (151645).
            if next_token_id == 151643 || next_token_id == 151645 {
                break;
            }
            token_ids.push(next_token_id);
        }
        Ok(token_ids)
    }

    pub fn device(&self) -> &Device {
        &self.device
    }
}

/// Draws a token id from `(1, vocab)` logits: greedy argmax when
/// `temperature <= 0`, otherwise inverse-CDF sampling over the
/// temperature-scaled softmax.
pub fn sample(logits: &Tensor, temperature: f64) -> Result<u32> {
    if temperature <= 0.0 {
        return Ok(logits.argmax(1)?.flatten_all()?.to_vec1::<u32>()?[0]);
    }
    use rand::Rng;
    let logits = (logits / temperature)?;
    // Subtract the row max before softmax for numerical stability.
    let max_logit = logits.max(1)?.unsqueeze(1)?;
    let logits = logits.broadcast_sub(&max_logit)?;
    let probs = candle_nn::ops::softmax(&logits, 1)?;
    let probs_vec = probs.flatten_all()?.to_vec1::<f32>()?;
    let mut rng = rand::thread_rng();
    let r: f32 = rng.gen();
    let mut cumsum = 0.0;
    for (i, p) in probs_vec.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }
    // Round-off can leave the cumulative sum fractionally below 1.0.
    Ok((probs_vec.len() - 1) as u32)
}
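
// Minimal usage sketch. The checkpoint path "bitmamba.safetensors" and the
// prompt ids are placeholders (assumptions, not part of this module); the
// tensor names inside the file must match the `VarBuilder` paths used above.
#[allow(dead_code)]
pub fn demo_generate() -> Result<Vec<u32>> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["bitmamba.safetensors"], DType::F32, &device)?
    };
    let model = BitMambaStudent::load(vb, device)?;
    // Sample up to 32 new tokens at temperature 0.8 from a dummy prompt.
    model.generate(&[151644, 872], 32, 0.8)
}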