mod config;
pub(crate) mod norm;
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::{Embedding, Linear, VarBuilder};
use crate::backend::MIBackend;
use crate::error::Result;
use crate::hooks::{HookCache, HookPoint, HookSpec};
use self::norm::LayerNorm;
pub use config::{RwkvConfig, RwkvLoraDims, RwkvVersion, SUPPORTED_RWKV_MODEL_TYPES};
/// Per-layer recurrent state carried across one forward pass.
struct RwkvState {
    // Last-token hidden state per layer, seeds the attention token-shift.
    attn_x: Vec<Option<Tensor>>,
    // Per-layer WKV state matrix, (batch, heads, head_dim, head_dim).
    attn_kv: Vec<Option<Tensor>>,
    // Last-token hidden state per layer, seeds the FFN token-shift.
    ffn_x: Vec<Option<Tensor>>,
}
impl RwkvState {
fn new(n_layers: usize) -> Self {
Self {
attn_x: vec![None; n_layers],
attn_kv: vec![None; n_layers],
ffn_x: vec![None; n_layers],
}
}
}
/// Version-dispatched time-mixing (attention-analogue) module.
enum TimeMix {
    V6(TimeMixV6),
    V7(TimeMixV7),
}
/// Version-dispatched channel-mixing (FFN-analogue) module.
enum ChannelMix {
    V6(ChannelMixV6),
    V7(ChannelMixV7),
}
/// One RWKV layer: time-mix plus channel-mix, each with its own pre-norm.
struct RwkvBlock {
    // Extra input LayerNorm, present on layer 0 only.
    pre_ln: Option<LayerNorm>,
    // Norm before time-mix.
    ln1: LayerNorm,
    // Norm before channel-mix.
    ln2: LayerNorm,
    time_mix: TimeMix,
    channel_mix: ChannelMix,
}
impl RwkvBlock {
    /// Build a v6 block from `vb`; only layer 0 carries a `pre_ln`.
    #[allow(clippy::needless_pass_by_value)]
    fn load_v6(config: &RwkvConfig, vb: VarBuilder<'_>, layer_id: usize) -> Result<Self> {
        let h = config.hidden_size;
        let eps = config.norm_eps;
        // Only the first layer normalizes its input before ln1/ln2.
        let pre_ln = (layer_id == 0)
            .then(|| LayerNorm::load(h, eps, vb.pp("pre_ln")))
            .transpose()?;
        Ok(Self {
            pre_ln,
            ln1: LayerNorm::load(h, eps, vb.pp("ln1"))?,
            ln2: LayerNorm::load(h, eps, vb.pp("ln2"))?,
            time_mix: TimeMix::V6(TimeMixV6::load(config, vb.pp("attention"))?),
            channel_mix: ChannelMix::V6(ChannelMixV6::load(config, vb.pp("feed_forward"))?),
        })
    }
#[allow(clippy::needless_pass_by_value)]
fn load_v7(config: &RwkvConfig, vb: VarBuilder<'_>, layer_id: usize) -> Result<Self> {
let eps = config.norm_eps;
let h = config.hidden_size;
let pre_ln = if layer_id == 0 {
Some(LayerNorm::load(h, eps, vb.pp("pre_norm"))?)
} else {
None
};
let ln1 = LayerNorm::load(h, eps, vb.pp("attn_norm"))?;
let ln2 = LayerNorm::load(h, eps, vb.pp("ffn_norm"))?;
let time_mix = TimeMix::V7(TimeMixV7::load(config, vb.pp("attn"), layer_id)?);
let channel_mix = ChannelMix::V7(ChannelMixV7::load(config, vb.pp("ffn"))?);
Ok(Self {
pre_ln,
ln1,
ln2,
time_mix,
channel_mix,
})
}
    /// Run one block: optional embedding pre-norm, time-mix with residual,
    /// then channel-mix with residual.
    ///
    /// Returns `(hidden, new_attn_x, new_attn_kv, new_ffn_x, v_out, decay, eff_attn)`.
    /// For v6 `v_out` is a 1-element dummy; v7 returns the value stream used
    /// for cross-layer value mixing. `eff_attn` is `Some` only when
    /// `compute_eff_attn` is set.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn forward(
        &self,
        hidden: &Tensor,
        attn_x_state: Option<&Tensor>,
        attn_kv_state: Option<&Tensor>,
        ffn_x_state: Option<&Tensor>,
        v_first: Option<&Tensor>,
        compute_eff_attn: bool,
        intervention_positions: Option<&std::collections::HashSet<usize>>,
        kv_scale: f32,
    ) -> Result<(
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Option<Tensor>,
    )> {
        // Layer 0 normalizes the raw embeddings before anything else.
        let hidden = if let Some(ref pre_ln) = self.pre_ln {
            pre_ln.forward(hidden)?
        } else {
            hidden.clone()
        };
        let (attn_out, new_attn_x, new_attn_kv, v_out, decay, eff_attn) = match &self.time_mix {
            TimeMix::V6(tm) => {
                let (out, ax, akv, d, eff) = tm.forward(
                    &self.ln1.forward(&hidden)?,
                    attn_x_state,
                    attn_kv_state,
                    compute_eff_attn,
                    intervention_positions,
                    kv_scale,
                )?;
                // v6 has no cross-layer value stream; fill the slot with a dummy.
                let dummy_v = Tensor::zeros(1, DType::F32, hidden.device())?;
                (out, ax, akv, dummy_v, d, eff)
            }
            TimeMix::V7(tm) => {
                let (out, ax, akv, v_out, d, eff) = tm.forward(
                    &self.ln1.forward(&hidden)?,
                    attn_x_state,
                    attn_kv_state,
                    v_first,
                    compute_eff_attn,
                    intervention_positions,
                    kv_scale,
                )?;
                (out, ax, akv, v_out, d, eff)
            }
        };
        // Residual connection around time-mix.
        let hidden = (&hidden + attn_out)?;
        let (ffn_out, new_ffn_x) = match &self.channel_mix {
            ChannelMix::V6(cm) => cm.forward(&self.ln2.forward(&hidden)?, ffn_x_state)?,
            ChannelMix::V7(cm) => cm.forward(&self.ln2.forward(&hidden)?, ffn_x_state)?,
        };
        // Residual connection around channel-mix.
        let hidden = (&hidden + ffn_out)?;
        Ok((
            hidden,
            new_attn_x,
            new_attn_kv,
            new_ffn_x,
            v_out,
            decay,
            eff_attn,
        ))
    }
}
/// RWKV v6 time-mixing (attention-analogue) weights.
struct TimeMixV6 {
    // Per-channel token-shift interpolation coefficients, shape (1, 1, hidden).
    time_maa_x: Tensor,
    time_maa_w: Tensor,
    time_maa_k: Tensor,
    time_maa_v: Tensor,
    time_maa_r: Tensor,
    time_maa_g: Tensor,
    // Low-rank pair producing the five data-dependent mixing deltas (w,k,v,r,g).
    time_maa_w1: Tensor,
    time_maa_w2: Tensor,
    // Base per-channel decay plus its low-rank data-dependent projection.
    time_decay: Tensor,
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    // Per-head bonus weighting the current step's kv in the WKV read.
    time_faaaa: Tensor,
    receptance: Linear,
    key: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    // Per-head group-norm affine parameters applied to the WKV output.
    ln_x_weight: Tensor,
    ln_x_bias: Tensor,
    num_heads: usize,
    head_dim: usize,
    group_norm_eps: f64,
    time_mix_extra_dim: usize,
}
impl TimeMixV6 {
    /// Load all v6 time-mix weights from `vb`.
    #[allow(clippy::needless_pass_by_value)]
    fn load(config: &RwkvConfig, vb: VarBuilder<'_>) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_heads;
        let head_size = config.head_dim;
        let attn_dim = heads * head_size;
        let mix_rank = config.lora_dims.time_mix_extra_dim;
        let decay_rank = config.lora_dims.time_decay_extra_dim;
        Ok(Self {
            time_maa_x: vb.get((1, 1, hidden), "time_maa_x")?,
            time_maa_w: vb.get((1, 1, hidden), "time_maa_w")?,
            time_maa_k: vb.get((1, 1, hidden), "time_maa_k")?,
            time_maa_v: vb.get((1, 1, hidden), "time_maa_v")?,
            time_maa_r: vb.get((1, 1, hidden), "time_maa_r")?,
            time_maa_g: vb.get((1, 1, hidden), "time_maa_g")?,
            // Low-rank pair generating five data-dependent mix deltas at once.
            time_maa_w1: vb.get((hidden, mix_rank * 5), "time_maa_w1")?,
            time_maa_w2: vb.get((5, mix_rank, hidden), "time_maa_w2")?,
            time_decay: vb.get((1, 1, attn_dim), "time_decay")?,
            time_decay_w1: vb.get((hidden, decay_rank), "time_decay_w1")?,
            time_decay_w2: vb.get((decay_rank, attn_dim), "time_decay_w2")?,
            time_faaaa: vb.get((heads, head_size), "time_faaaa")?,
            receptance: candle_nn::linear_no_bias(hidden, attn_dim, vb.pp("receptance"))?,
            key: candle_nn::linear_no_bias(hidden, attn_dim, vb.pp("key"))?,
            value: candle_nn::linear_no_bias(hidden, attn_dim, vb.pp("value"))?,
            gate: candle_nn::linear_no_bias(hidden, attn_dim, vb.pp("gate"))?,
            output: candle_nn::linear_no_bias(attn_dim, hidden, vb.pp("output"))?,
            ln_x_weight: vb.get(attn_dim, "ln_x.weight")?,
            ln_x_bias: vb.get(attn_dim, "ln_x.bias")?,
            num_heads: heads,
            head_dim: head_size,
            group_norm_eps: config.group_norm_eps(),
            time_mix_extra_dim: mix_rank,
        })
    }
    /// v6 time-mixing over a full chunk of `seq_len` tokens.
    ///
    /// Returns `(output, new_attn_x, new_attn_kv, decay, eff_attn)`; the two
    /// state tensors seed the next chunk, `decay` is the f32 per-token decay.
    /// At positions in `intervention_positions` the fresh kv outer product is
    /// dropped (`kv_scale == 0`, knockout) or rescaled (steering) before it
    /// enters the recurrent state.
    #[allow(clippy::many_single_char_names, clippy::too_many_lines)]
    fn forward(
        &self,
        hidden: &Tensor,
        attn_x_state: Option<&Tensor>,
        attn_kv_state: Option<&Tensor>,
        compute_eff_attn: bool,
        intervention_positions: Option<&std::collections::HashSet<usize>>,
        kv_scale: f32,
    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Option<Tensor>)> {
        let (batch, seq_len, channels) = hidden.dims3()?;
        let nh = self.num_heads;
        let hs = self.head_dim;
        // Previous-token view of the sequence (position 0 from state or zeros).
        let shifted = token_shift(hidden, attn_x_state, batch, seq_len, channels)?;
        // Last token becomes the token-shift state for the next chunk.
        let new_attn_x_state = hidden.i((.., seq_len - 1, ..))?;
        let xx = shifted.broadcast_sub(hidden)?;
        let xxx = hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_x)?)?;
        let bt = batch * seq_len;
        // Data-dependent mixing: project through the low-rank time_maa_w1/w2
        // pair to get five per-token deltas (w, k, v, r, g).
        let xxx_flat = xxx.reshape((bt, channels))?;
        let projected = xxx_flat.matmul(&self.time_maa_w1)?;
        let projected = projected.tanh()?;
        let projected = projected.reshape((bt, 5, self.time_mix_extra_dim))?;
        let projected = projected.transpose(0, 1)?.contiguous()?;
        let mixed = projected.matmul(&self.time_maa_w2)?;
        let mixed = mixed.reshape((5, batch, seq_len, channels))?;
        let mw = mixed.i(0)?;
        let mk = mixed.i(1)?;
        let mv = mixed.i(2)?;
        let mr = mixed.i(3)?;
        let mg = mixed.i(4)?;
        // Interpolate current vs. shifted token, one blend per projection.
        let time_decay_input =
            hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_w.broadcast_add(&mw)?)?)?;
        let key_input =
            hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_k.broadcast_add(&mk)?)?)?;
        let value_input =
            hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_v.broadcast_add(&mv)?)?)?;
        let receptance_input =
            hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_r.broadcast_add(&mr)?)?)?;
        let gate_input =
            hidden.broadcast_add(&xx.broadcast_mul(&self.time_maa_g.broadcast_add(&mg)?)?)?;
        let rec = self.receptance.forward(&receptance_input)?;
        let key = self.key.forward(&key_input)?;
        let val = self.value.forward(&value_input)?;
        let gate_val = candle_nn::ops::silu(&self.gate.forward(&gate_input)?)?;
        // Data-dependent decay through its own low-rank projection.
        let td_flat = time_decay_input.reshape((bt, channels))?;
        let td_proj = td_flat.matmul(&self.time_decay_w1)?.tanh()?;
        let td_proj = td_proj.matmul(&self.time_decay_w2)?;
        let td_proj = td_proj.reshape((batch, seq_len, nh * hs))?;
        let decay_raw = self.time_decay.broadcast_add(&td_proj)?;
        // decay = exp(-exp(x)) in f32, always in (0, 1).
        let decay = decay_raw.to_dtype(DType::F32)?.exp()?.neg()?.exp()?;
        // Split channels into heads; the recurrence runs in f32.
        let rec = rec
            .to_dtype(DType::F32)?
            .reshape((batch, seq_len, nh, hs))?;
        let key = key
            .to_dtype(DType::F32)?
            .reshape((batch, seq_len, nh, hs))?;
        let val = val
            .to_dtype(DType::F32)?
            .reshape((batch, seq_len, nh, hs))?;
        let decay = decay.reshape((batch, seq_len, nh, hs))?;
        let time_first = self.time_faaaa.to_dtype(DType::F32)?;
        let time_first = time_first.reshape((1, 1, nh, hs))?;
        // WKV state: (batch, heads, head_dim, head_dim), carried across chunks.
        let mut state = match attn_kv_state {
            Some(prev) => prev.to_dtype(DType::F32)?,
            None => Tensor::zeros((batch, nh, hs, hs), DType::F32, hidden.device())?,
        };
        let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
        // Sequential WKV recurrence over time steps.
        for ti in 0..seq_len {
            let r_t = rec.i((.., ti, .., ..))?;
            let k_t = key.i((.., ti, .., ..))?;
            let v_t = val.i((.., ti, .., ..))?;
            let decay_t = decay.i((.., ti, .., ..))?;
            let time_first_t = time_first.i((.., 0, .., ..))?;
            // Outer product k ⊗ v for this step.
            let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
            let v_row = v_t.unsqueeze(2)?;
            let kv = k_col.matmul(&v_row)?;
            // Output reads the bonus-weighted current kv plus accumulated state.
            let time_first_expanded = time_first_t.unsqueeze(candle_core::D::Minus1)?;
            let weighted_kv = kv.broadcast_mul(&time_first_expanded)?;
            let combined = (&weighted_kv + &state)?;
            let r_row = r_t.unsqueeze(2)?;
            let out_t = r_row.matmul(&combined)?;
            let out_t = out_t.squeeze(2)?;
            outputs.push(out_t);
            let decay_expanded = decay_t.unsqueeze(candle_core::D::Minus1)?;
            let should_intervene =
                intervention_positions.is_some_and(|positions| positions.contains(&ti));
            if should_intervene {
                // Knockout (kv_scale == 0) drops this step's kv from the state;
                // steering rescales it before it enters the state.
                let decayed = state.broadcast_mul(&decay_expanded)?;
                if kv_scale == 0.0 {
                    state = decayed;
                } else {
                    state = ((kv * f64::from(kv_scale))? + decayed)?;
                }
            } else {
                // Normal update: state <- kv + decay * state.
                state = (kv + state.broadcast_mul(&decay_expanded)?)?;
            }
        }
        let out = Tensor::stack(&outputs, 1)?;
        let new_attn_kv_state = state;
        let eff_attn = if compute_eff_attn {
            Some(Self::compute_effective_attention_v6(
                &rec,
                &key,
                &decay,
                &self.time_faaaa.to_dtype(DType::F32)?,
                batch,
                seq_len,
                nh,
                hs,
                hidden.device(),
            )?)
        } else {
            None
        };
        // Per-head group norm, then SiLU gate and projection back to hidden.
        let out = out.reshape((bt, nh * hs))?;
        let out = norm::group_norm(
            &out,
            self.num_heads,
            &self.ln_x_weight.to_dtype(DType::F32)?,
            &self.ln_x_bias.to_dtype(DType::F32)?,
            self.group_norm_eps,
        )?;
        let out = out
            .reshape((batch, seq_len, nh * hs))?
            .to_dtype(hidden.dtype())?;
        let out = (out * gate_val)?;
        let out = self.output.forward(&out)?;
        Ok((out, new_attn_x_state, new_attn_kv_state, decay, eff_attn))
    }
    /// Reconstruct an attention-like map for v6: the contribution of each
    /// source token j to each target token ti, using cumulative log-decay
    /// between positions plus the `time_faaaa` diagonal bonus. Rows are
    /// ReLU'd and normalized to sum to 1; output is (batch, nh, ti, j).
    #[allow(
        clippy::too_many_arguments,
        clippy::too_many_lines,
        clippy::many_single_char_names,
        clippy::needless_range_loop
    )]
    fn compute_effective_attention_v6(
        r: &Tensor,
        k: &Tensor,
        decay: &Tensor,
        time_first: &Tensor,
        batch: usize,
        seq_len: usize,
        nh: usize,
        hs: usize,
        device: &Device,
    ) -> Result<Tensor> {
        // Prefix sums of log-decay so decay over j..ti is exp(prefix diff).
        let log_decay = decay.log()?;
        let mut cum_ld: Vec<Tensor> = Vec::with_capacity(seq_len + 1);
        cum_ld.push(Tensor::zeros((batch, nh, hs), DType::F32, device)?);
        for ti in 0..seq_len {
            let ld_ti = log_decay.i((.., ti, .., ..))?;
            let prev = cum_ld.get(ti).ok_or_else(|| {
                crate::error::MIError::Model(candle_core::Error::Msg(
                    "prefix sum index out of bounds".into(),
                ))
            })?;
            cum_ld.push((prev + &ld_ti)?);
        }
        let prefix = Tensor::stack(&cum_ld, 1)?;
        let time_first_2d = time_first.reshape((nh, hs))?;
        let mut eff_attn_rows: Vec<Tensor> = Vec::with_capacity(seq_len);
        for ti in 0..seq_len {
            let r_ti = r.i((.., ti, .., ..))?;
            let k_ti = k.i((.., ti, .., ..))?;
            // Diagonal term: current token, weighted by the time_first bonus.
            let diag = (&r_ti * &k_ti)?.broadcast_mul(&time_first_2d)?;
            let diag_alpha = diag
                .sum(candle_core::D::Minus1)?
                .unsqueeze(candle_core::D::Minus1)?;
            let alpha_raw = if ti > 0 {
                // Past tokens: r_ti · k_j scaled by accumulated decay j -> ti.
                let k_past = k.i((.., ..ti, .., ..))?;
                let pref_ti = prefix.i((.., ti, .., ..))?.unsqueeze(1)?;
                let pref_src = prefix.i((.., 1..=ti, .., ..))?;
                let log_cd = pref_ti.broadcast_sub(&pref_src)?;
                let cd = log_cd.exp()?;
                let r_exp = r_ti.unsqueeze(1)?;
                let per_ch = r_exp.broadcast_mul(&k_past)?.broadcast_mul(&cd)?;
                let a_past = per_ch.sum(candle_core::D::Minus1)?.transpose(1, 2)?;
                Tensor::cat(&[&a_past, &diag_alpha], candle_core::D::Minus1)?
            } else {
                diag_alpha
            };
            // Zero-pad future positions so every row has length seq_len.
            let alpha_raw = if ti + 1 < seq_len {
                let pad = Tensor::zeros((batch, nh, seq_len - ti - 1), DType::F32, device)?;
                Tensor::cat(&[&alpha_raw, &pad], candle_core::D::Minus1)?
            } else {
                alpha_raw
            };
            // Keep positive contributions; epsilon guards the row normalizer.
            let alpha_relu = alpha_raw.relu()?;
            let row_sum = (alpha_relu.sum_keepdim(candle_core::D::Minus1)? + 1e-10_f64)?;
            let alpha_normed = alpha_relu.broadcast_div(&row_sum)?;
            eff_attn_rows.push(alpha_normed);
        }
        Tensor::stack(&eff_attn_rows, 2).map_err(Into::into)
    }
}
/// Two-layer low-rank projection: `up(activation(down(x)))`.
struct LoraBlock {
    down: Linear,
    up: Linear,
}
impl LoraBlock {
    /// Load the pair of projections, stored under names "0" (down) and "2"
    /// (up); the up projection optionally carries a bias.
    #[allow(clippy::needless_pass_by_value)]
    fn load(
        input_dim: usize,
        low_rank: usize,
        output_dim: usize,
        has_up_bias: bool,
        vb: VarBuilder<'_>,
    ) -> Result<Self> {
        let down = candle_nn::linear_no_bias(input_dim, low_rank, vb.pp("0"))?;
        let up = match has_up_bias {
            true => candle_nn::linear(low_rank, output_dim, vb.pp("2"))?,
            false => candle_nn::linear_no_bias(low_rank, output_dim, vb.pp("2"))?,
        };
        Ok(Self { down, up })
    }
    /// down -> tanh -> up.
    fn forward_tanh(&self, x: &Tensor) -> Result<Tensor> {
        let mid = self.down.forward(x)?.tanh()?;
        Ok(self.up.forward(&mid)?)
    }
    /// down -> sigmoid -> up.
    fn forward_sigmoid(&self, x: &Tensor) -> Result<Tensor> {
        let mid = candle_nn::ops::sigmoid(&self.down.forward(x)?)?;
        Ok(self.up.forward(&mid)?)
    }
    /// down -> up with no activation.
    fn forward_linear(&self, x: &Tensor) -> Result<Tensor> {
        let mid = self.down.forward(x)?;
        Ok(self.up.forward(&mid)?)
    }
}
/// RWKV v7 time-mixing weights.
struct TimeMixV7 {
    // Token-shift interpolation coefficients, one per projection, (1, 1, hidden).
    x_r: Tensor,
    x_w: Tensor,
    x_k: Tensor,
    x_v: Tensor,
    x_a: Tensor,
    x_g: Tensor,
    // Key rescaling for the L2-normalized "removal" direction.
    k_k: Tensor,
    // Strength of the `a`-dependent key modulation.
    k_a: Tensor,
    // Per-head weights for the r·k output correction term.
    r_k: Tensor,
    r_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    // Low-rank adapters: decay (w), in-context rate (a), output gate (g).
    w_lora: LoraBlock,
    a_lora: LoraBlock,
    g_lora: LoraBlock,
    // Present for layers > 0 only; blends in the layer-0 value stream.
    v_lora: Option<LoraBlock>,
    // Per-head group-norm affine parameters applied to the recurrence output.
    g_norm_weight: Tensor,
    g_norm_bias: Tensor,
    num_heads: usize,
    head_dim: usize,
    norm_eps: f64,
}
impl TimeMixV7 {
    /// Load v7 time-mix weights; layers beyond the first also carry a
    /// `v_lora` used to blend the layer-0 value stream into their own.
    #[allow(clippy::needless_pass_by_value)]
    fn load(config: &RwkvConfig, vb: VarBuilder<'_>, layer_idx: usize) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_heads;
        let head_size = config.head_dim;
        let lora = &config.lora_dims;
        let x_r = vb.get((1, 1, hidden), "x_r")?;
        let x_w = vb.get((1, 1, hidden), "x_w")?;
        let x_k = vb.get((1, 1, hidden), "x_k")?;
        let x_v = vb.get((1, 1, hidden), "x_v")?;
        let x_a = vb.get((1, 1, hidden), "x_a")?;
        let x_g = vb.get((1, 1, hidden), "x_g")?;
        let k_k = vb.get(hidden, "k_k")?;
        let k_a = vb.get(hidden, "k_a")?;
        let r_k = vb.get((heads, head_size), "r_k")?;
        let r_proj = candle_nn::linear_no_bias(hidden, hidden, vb.pp("r_proj"))?;
        let k_proj = candle_nn::linear_no_bias(hidden, hidden, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear_no_bias(hidden, hidden, vb.pp("v_proj"))?;
        let o_proj = candle_nn::linear_no_bias(hidden, hidden, vb.pp("o_proj"))?;
        let w_lora = LoraBlock::load(hidden, lora.decay_low_rank_dim, hidden, true, vb.pp("w_lora.lora"))?;
        let a_lora = LoraBlock::load(hidden, lora.a_low_rank_dim, hidden, true, vb.pp("a_lora.lora"))?;
        let g_lora = LoraBlock::load(hidden, lora.gate_low_rank_dim, hidden, false, vb.pp("g_lora.lora"))?;
        // Layer 0 defines the shared value stream, so it has no value LoRA.
        let v_lora = (layer_idx > 0)
            .then(|| LoraBlock::load(hidden, lora.v_low_rank_dim, hidden, true, vb.pp("v_lora.lora")))
            .transpose()?;
        let g_norm_weight = vb.get(hidden, "g_norm.weight")?;
        let g_norm_bias = vb.get(hidden, "g_norm.bias")?;
        Ok(Self {
            x_r,
            x_w,
            x_k,
            x_v,
            x_a,
            x_g,
            k_k,
            k_a,
            r_k,
            r_proj,
            k_proj,
            v_proj,
            o_proj,
            w_lora,
            a_lora,
            g_lora,
            v_lora,
            g_norm_weight,
            g_norm_bias,
            num_heads: heads,
            head_dim: head_size,
            norm_eps: config.norm_eps,
        })
    }
    /// v7 time-mixing over a full chunk.
    ///
    /// Returns `(output, new_attn_x, new_attn_kv, v_out, decay, eff_attn)`.
    /// `v_first` is the layer-0 value stream; layers with a `v_lora` blend it
    /// into their own values and error if it is missing. Intervention
    /// positions drop (`kv_scale == 0`) or rescale the fresh kv term of the
    /// state update.
    #[allow(
        clippy::many_single_char_names,
        clippy::too_many_lines,
        clippy::similar_names,
        clippy::cast_precision_loss,
        clippy::as_conversions,
        clippy::too_many_arguments
    )]
    fn forward(
        &self,
        hidden: &Tensor,
        attn_x_state: Option<&Tensor>,
        attn_kv_state: Option<&Tensor>,
        v_first: Option<&Tensor>,
        compute_eff_attn: bool,
        intervention_positions: Option<&std::collections::HashSet<usize>>,
        kv_scale: f32,
    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor, Option<Tensor>)> {
        let (batch, seq_len, _channels) = hidden.dims3()?;
        let nh = self.num_heads;
        let hd = self.head_dim;
        let h = nh * hd;
        let bt = batch * seq_len;
        let shifted = token_shift(hidden, attn_x_state, batch, seq_len, h)?;
        // Last token seeds the token-shift of the next chunk.
        let new_attn_x = hidden.i((.., seq_len - 1, ..))?;
        let delta = shifted.broadcast_sub(hidden)?;
        // Six token-shift interpolations, one per downstream projection.
        let xr = hidden.broadcast_add(&delta.broadcast_mul(&self.x_r)?)?;
        let xw = hidden.broadcast_add(&delta.broadcast_mul(&self.x_w)?)?;
        let xk = hidden.broadcast_add(&delta.broadcast_mul(&self.x_k)?)?;
        let xv = hidden.broadcast_add(&delta.broadcast_mul(&self.x_v)?)?;
        let xa = hidden.broadcast_add(&delta.broadcast_mul(&self.x_a)?)?;
        let xg = hidden.broadcast_add(&delta.broadcast_mul(&self.x_g)?)?;
        let r = self.r_proj.forward(&xr)?;
        let k = self.k_proj.forward(&xk)?;
        let v = self.v_proj.forward(&xv)?;
        // Log-decay w in (-e^{-1/2}, 0): sigmoid of the decay LoRA output,
        // scaled by the constant -exp(-0.5) ≈ -0.6065.
        let w_lora_out = self.w_lora.forward_tanh(&xw)?;
        let w = (candle_nn::ops::sigmoid(&w_lora_out.to_dtype(DType::F32)?)?
            * (-0.606_530_659_712_633_4_f64))?;
        // Layers > 0 mix the layer-0 value stream into their own values.
        let v_out = if let Some(v_lora) = &self.v_lora {
            let v_lora_out = v_lora.forward_linear(&xv)?;
            let mix = candle_nn::ops::sigmoid(&v_lora_out)?;
            let v_first_t = v_first.ok_or_else(|| {
                crate::error::MIError::Config("v_first required for layer > 0".into())
            })?;
            (&v + &(&(v_first_t - &v)? * mix)?)?
        } else {
            v
        };
        // In-context rate `a` and output gate `g`.
        let a = candle_nn::ops::sigmoid(&self.a_lora.forward_linear(&xa)?)?;
        let g = self.g_lora.forward_sigmoid(&xg)?;
        // Unit-norm key direction used by the state "removal" term.
        let k_scaled = k.broadcast_mul(&self.k_k)?;
        let k_scaled_4d = k_scaled.reshape((batch, seq_len, nh, hd))?;
        let kk = l2_norm(&k_scaled_4d)?;
        // k modulated towards zero where a < 1, gated by k_a.
        let a_minus_1 = (a.clone() - 1.0_f64)?;
        let k_mod = (&k + &(&k * &a_minus_1)?.broadcast_mul(&self.k_a)?)?;
        // Head-split f32 views for the recurrence.
        let r_4d = r.to_dtype(DType::F32)?.reshape((batch, seq_len, nh, hd))?;
        let k_4d = k_mod
            .to_dtype(DType::F32)?
            .reshape((batch, seq_len, nh, hd))?;
        let v_4d = v_out
            .to_dtype(DType::F32)?
            .reshape((batch, seq_len, nh, hd))?;
        let w_4d = w.reshape((batch, seq_len, nh, hd))?;
        let kk_f32 = kk.to_dtype(DType::F32)?;
        let a_4d = a.to_dtype(DType::F32)?.reshape((batch, seq_len, nh, hd))?;
        let mut state = match attn_kv_state {
            Some(prev) => prev.to_dtype(DType::F32)?,
            None => Tensor::zeros((batch, nh, hd, hd), DType::F32, hidden.device())?,
        };
        let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
        // Sequential v7 recurrence:
        //   state <- diag(exp(w)) * state + b ((-kk) · state) + k ⊗ v.
        for ti in 0..seq_len {
            let r_t = r_4d.i((.., ti, .., ..))?;
            let k_t = k_4d.i((.., ti, .., ..))?;
            let v_t = v_4d.i((.., ti, .., ..))?;
            let w_t = w_4d.i((.., ti, .., ..))?;
            let kk_t = kk_f32.i((.., ti, .., ..))?;
            let a_t = a_4d.i((.., ti, .., ..))?;
            let act_a = kk_t.neg()?;
            let b_t = (&kk_t * &a_t)?;
            // term1: decayed previous state.
            let exp_w = w_t.exp()?;
            let exp_w_col = exp_w.unsqueeze(candle_core::D::Minus1)?;
            let term1 = state.broadcast_mul(&exp_w_col)?;
            // term2: rank-1 state-to-state transition via -kk and kk*a.
            let act_a_row = act_a.unsqueeze(2)?;
            let a_times_s = act_a_row.matmul(&state)?;
            let b_col = b_t.unsqueeze(candle_core::D::Minus1)?;
            let term2 = b_col.matmul(&a_times_s)?;
            // term3: fresh k ⊗ v outer product.
            let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
            let v_row = v_t.unsqueeze(2)?;
            let term3 = k_col.matmul(&v_row)?;
            let should_intervene =
                intervention_positions.is_some_and(|positions| positions.contains(&ti));
            if should_intervene {
                // Knockout (kv_scale == 0) drops the kv term; steering rescales it.
                let state_no_kv = (&term1 + &term2)?;
                if kv_scale == 0.0 {
                    state = state_no_kv;
                } else {
                    state = (state_no_kv + (term3 * f64::from(kv_scale))?)?;
                }
            } else {
                state = ((&term1 + &term2)? + &term3)?;
            }
            let r_row = r_t.unsqueeze(2)?;
            let out_t = r_row.matmul(&state)?;
            let out_t = out_t.squeeze(2)?;
            outputs.push(out_t);
        }
        let out = Tensor::stack(&outputs, 1)?;
        let new_attn_kv = state;
        // The returned decay is the per-token log-decay tensor.
        let decay = w_4d.clone();
        // Group-norm epsilon is scaled by the head size.
        let gn_eps = (self.head_dim as f64) * self.norm_eps;
        let out_flat = out.reshape((bt, h))?;
        let out_gn = norm::group_norm(
            &out_flat,
            self.num_heads,
            &self.g_norm_weight.to_dtype(DType::F32)?,
            &self.g_norm_bias.to_dtype(DType::F32)?,
            gn_eps,
        )?;
        let out_gn = out_gn.reshape((batch, seq_len, h))?;
        // Correction term: per-head sum of r*k*r_k, scaled onto the values,
        // added before gating.
        let r_k_flat = self.r_k.to_dtype(DType::F32)?.reshape((1, 1, h))?;
        let r_f32 = r.to_dtype(DType::F32)?;
        let k_mod_f32 = k_mod.to_dtype(DType::F32)?;
        let rkrk = (&r_f32 * &k_mod_f32)?.broadcast_mul(&r_k_flat)?;
        let rkrk_4d = rkrk.reshape((batch, seq_len, nh, hd))?;
        let v_f32 = v_out.to_dtype(DType::F32)?;
        let rkrk_sum_4d = rkrk_4d
            .sum_keepdim(candle_core::D::Minus1)?
            .reshape((batch, seq_len, nh, 1))?;
        let v_4d_corr = v_f32.reshape((batch, seq_len, nh, hd))?;
        let correction = rkrk_sum_4d
            .broadcast_mul(&v_4d_corr)?
            .reshape((batch, seq_len, h))?;
        let g_f32 = g.to_dtype(DType::F32)?;
        let out_corrected = ((&out_gn + &correction)? * &g_f32)?;
        let out_final = out_corrected.to_dtype(hidden.dtype())?;
        let out_final = self.o_proj.forward(&out_final)?;
        let eff_attn = if compute_eff_attn {
            Some(Self::compute_effective_attention_v7(
                &r_4d,
                &k_4d,
                &w_4d,
                &kk_f32,
                &a_4d,
                batch,
                seq_len,
                nh,
                hidden.device(),
            )?)
        } else {
            None
        };
        Ok((out_final, new_attn_x, new_attn_kv, v_out, decay, eff_attn))
    }
    /// Reconstruct an attention-like map for v7 by pulling the receptance of
    /// each target step backwards through the per-step (decay + rank-1)
    /// transition and dotting it with earlier keys. Rows are ReLU'd and
    /// normalized to sum to 1; rows are stacked on dim 2.
    #[allow(
        clippy::too_many_arguments,
        clippy::too_many_lines,
        clippy::many_single_char_names,
        clippy::needless_range_loop
    )]
    fn compute_effective_attention_v7(
        r: &Tensor,
        k: &Tensor,
        w: &Tensor,
        kk: &Tensor,
        a: &Tensor,
        batch: usize,
        seq_len: usize,
        nh: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let exp_w = w.exp()?;
        let b_all = (kk * a)?;
        let act_a_all = kk.neg()?;
        let mut eff_attn_rows: Vec<Tensor> = Vec::with_capacity(seq_len);
        for ti in 0..seq_len {
            // `l` starts as the receptance at ti and is propagated backwards.
            let mut l = r.i((.., ti, .., ..))?;
            let mut alphas: Vec<Tensor> = Vec::with_capacity(ti + 1);
            let k_ti = k.i((.., ti, .., ..))?;
            // Diagonal (self) contribution first.
            let diag_alpha = (&l * &k_ti)?
                .sum(candle_core::D::Minus1)?
                .unsqueeze(candle_core::D::Minus1)?;
            alphas.push(diag_alpha);
            // Walk backwards: apply step j's transposed transition to `l`,
            // then measure the overlap with the key at step j-1.
            for j in (1..=ti).rev() {
                let exp_w_j = exp_w.i((.., j, .., ..))?;
                let b_j = b_all.i((.., j, .., ..))?;
                let act_a_j = act_a_all.i((.., j, .., ..))?;
                let l_dot_b = (&l * &b_j)?.sum_keepdim(candle_core::D::Minus1)?;
                l = (l.broadcast_mul(&exp_w_j)? + l_dot_b.broadcast_mul(&act_a_j)?)?;
                let k_prev = k.i((.., j - 1, .., ..))?;
                let alpha = (&l * &k_prev)?
                    .sum(candle_core::D::Minus1)?
                    .unsqueeze(candle_core::D::Minus1)?;
                alphas.push(alpha);
            }
            // alphas were collected newest-first; restore source order.
            alphas.reverse();
            let alpha_raw = Tensor::cat(&alphas, candle_core::D::Minus1)?;
            // Zero-pad future positions so every row has length seq_len.
            let alpha_raw = if ti + 1 < seq_len {
                let pad = Tensor::zeros((batch, nh, seq_len - ti - 1), DType::F32, device)?;
                Tensor::cat(&[&alpha_raw, &pad], candle_core::D::Minus1)?
            } else {
                alpha_raw
            };
            let alpha_relu = alpha_raw.relu()?;
            let row_sum = (alpha_relu.sum_keepdim(candle_core::D::Minus1)? + 1e-10_f64)?;
            let alpha_normed = alpha_relu.broadcast_div(&row_sum)?;
            eff_attn_rows.push(alpha_normed);
        }
        Tensor::stack(&eff_attn_rows, 2).map_err(Into::into)
    }
}
/// RWKV v7 channel-mix (FFN analogue): token-shift key mix, no receptance gate.
struct ChannelMixV7 {
    // Token-shift interpolation coefficients for the key input.
    x_k: Tensor,
    key: Linear,
    value: Linear,
}
impl ChannelMixV7 {
    /// Load v7 channel-mix weights from `vb`.
    #[allow(clippy::needless_pass_by_value)]
    fn load(config: &RwkvConfig, vb: VarBuilder<'_>) -> Result<Self> {
        let hidden = config.hidden_size;
        let inner = config.intermediate_size;
        Ok(Self {
            x_k: vb.get(hidden, "x_k")?,
            key: candle_nn::linear_no_bias(hidden, inner, vb.pp("key"))?,
            value: candle_nn::linear_no_bias(inner, hidden, vb.pp("value"))?,
        })
    }
    /// Token-shift mix followed by a squared-ReLU two-layer FFN.
    /// Returns the output plus the last-token state for the next chunk.
    fn forward(&self, hidden: &Tensor, ffn_x_state: Option<&Tensor>) -> Result<(Tensor, Tensor)> {
        let (b, t, c) = hidden.dims3()?;
        let prev = token_shift(hidden, ffn_x_state, b, t, c)?;
        let next_state = hidden.i((.., t - 1, ..))?;
        let diff = prev.broadcast_sub(hidden)?;
        let k_in = hidden.broadcast_add(&diff.broadcast_mul(&self.x_k)?)?;
        // Squared-ReLU activation between the two projections.
        let activated = self.key.forward(&k_in)?.relu()?.sqr()?;
        Ok((self.value.forward(&activated)?, next_state))
    }
}
/// RWKV v6 channel-mix: token-shift key/receptance mix with a sigmoid gate.
struct ChannelMixV6 {
    // Token-shift interpolation coefficients for key and receptance inputs.
    time_maa_k: Tensor,
    time_maa_r: Tensor,
    key: Linear,
    receptance: Linear,
    value: Linear,
}
impl ChannelMixV6 {
    /// Load v6 channel-mix weights from `vb`.
    #[allow(clippy::needless_pass_by_value)]
    fn load(config: &RwkvConfig, vb: VarBuilder<'_>) -> Result<Self> {
        let hidden = config.hidden_size;
        let inner = config.intermediate_size;
        Ok(Self {
            time_maa_k: vb.get((1, 1, hidden), "time_maa_k")?,
            time_maa_r: vb.get((1, 1, hidden), "time_maa_r")?,
            key: candle_nn::linear_no_bias(hidden, inner, vb.pp("key"))?,
            receptance: candle_nn::linear_no_bias(hidden, hidden, vb.pp("receptance"))?,
            value: candle_nn::linear_no_bias(inner, hidden, vb.pp("value"))?,
        })
    }
    /// Token-shift mix, squared-ReLU FFN, then a sigmoid receptance gate.
    /// Returns the output plus the last-token state for the next chunk.
    fn forward(&self, hidden: &Tensor, ffn_x_state: Option<&Tensor>) -> Result<(Tensor, Tensor)> {
        let (b, t, c) = hidden.dims3()?;
        let prev = token_shift(hidden, ffn_x_state, b, t, c)?;
        let next_state = hidden.i((.., t - 1, ..))?;
        let diff = prev.broadcast_sub(hidden)?;
        let k_in = hidden.broadcast_add(&diff.broadcast_mul(&self.time_maa_k)?)?;
        let r_in = hidden.broadcast_add(&diff.broadcast_mul(&self.time_maa_r)?)?;
        let ffn = self.value.forward(&self.key.forward(&k_in)?.relu()?.sqr()?)?;
        let gate = candle_nn::ops::sigmoid(&self.receptance.forward(&r_in)?)?;
        Ok(((gate * ffn)?, next_state))
    }
}
/// L2-normalize along the last dimension, in f32, with a 1e-12 epsilon
/// added inside the square root for numerical safety.
fn l2_norm(x: &Tensor) -> Result<Tensor> {
    let x32 = x.to_dtype(DType::F32)?;
    let denom = (x32.sqr()?.sum_keepdim(candle_core::D::Minus1)? + 1e-12_f64)?.sqrt()?;
    Ok(x32.broadcast_div(&denom)?)
}
/// Shift the sequence right by one position along dim 1, filling slot 0 with
/// `state` (the last token of the previous chunk, shape (batch, channels))
/// or zeros when there is no carried state.
///
/// Returns a tensor of shape (batch, seq_len, channels) in `hidden`'s dtype.
fn token_shift(
    hidden: &Tensor,
    state: Option<&Tensor>,
    batch: usize,
    seq_len: usize,
    channels: usize,
) -> Result<Tensor> {
    // First row: carried-over state unsqueezed to (batch, 1, channels), or zeros.
    let first = match state {
        Some(prev) => prev.unsqueeze(1)?,
        None => Tensor::zeros((batch, 1, channels), hidden.dtype(), hidden.device())?,
    };
    if seq_len == 1 {
        Ok(first)
    } else {
        // Rows 1.. are tokens 0..seq_len-1. Single concatenation; the previous
        // implementation built a zero-padded shift and then concatenated a
        // second time to splice the state in — same result, extra work.
        let prev_tokens = hidden.i((.., ..seq_len - 1, ..))?;
        Ok(Tensor::cat(&[&first, &prev_tokens], 1)?)
    }
}
fn resolve_layer_intervention<'a>(
hooks: &HookSpec,
layer_idx: usize,
knockout_positions: Option<&'a std::collections::HashSet<usize>>,
steering_positions: Option<&'a std::collections::HashSet<usize>>,
steering_scale: f32,
) -> (Option<&'a std::collections::HashSet<usize>>, f32) {
if let Some(ko) = hooks.state_knockout()
&& ko.applies_to_layer(layer_idx)
{
return (knockout_positions, 0.0);
}
if let Some(st) = hooks.state_steering()
&& st.applies_to_layer(layer_idx)
{
return (steering_positions, steering_scale);
}
(None, 1.0)
}
/// Version-agnostic RWKV language model (v6 or v7) with a hookable forward.
pub struct GenericRwkv {
    embeddings: Embedding,
    blocks: Vec<RwkvBlock>,
    // Final layer norm applied before the LM head.
    ln_out: LayerNorm,
    lm_head: Linear,
    config: RwkvConfig,
}
impl GenericRwkv {
    /// Load an RWKV model, dispatching on the configured version.
    /// `_device` and `_dtype` are accepted for interface parity but unused here.
    #[allow(clippy::needless_pass_by_value)]
    pub fn load(
        config: RwkvConfig,
        _device: &Device,
        _dtype: DType,
        vb: VarBuilder<'_>,
    ) -> Result<Self> {
        // Pick the version-specific loader, then invoke it once.
        let loader = match config.version {
            RwkvVersion::V6 => Self::load_v6,
            RwkvVersion::V7 => Self::load_v7,
        };
        loader(config, vb)
    }
#[allow(clippy::needless_pass_by_value)] fn load_v6(config: RwkvConfig, vb: VarBuilder<'_>) -> Result<Self> {
let vb_rwkv = vb.pp("rwkv");
let embeddings = candle_nn::embedding(
config.vocab_size,
config.hidden_size,
vb_rwkv.pp("embeddings"),
)?;
let mut blocks = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
blocks.push(RwkvBlock::load_v6(
&config,
vb_rwkv.pp(format!("blocks.{i}")),
i,
)?);
}
let ln_out = LayerNorm::load(config.hidden_size, config.norm_eps, vb_rwkv.pp("ln_out"))?;
let lm_head = if config.tie_word_embeddings {
Linear::new(embeddings.embeddings().clone(), None)
} else {
candle_nn::linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("head"))?
};
Ok(Self {
embeddings,
blocks,
ln_out,
lm_head,
config,
})
}
#[allow(clippy::needless_pass_by_value)] fn load_v7(config: RwkvConfig, vb: VarBuilder<'_>) -> Result<Self> {
let vb_model = vb.pp("model");
let embeddings = candle_nn::embedding(
config.vocab_size,
config.hidden_size,
vb_model.pp("embeddings"),
)?;
let mut blocks = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
blocks.push(RwkvBlock::load_v7(
&config,
vb_model.pp(format!("layers.{i}")),
i,
)?);
}
let ln_out = LayerNorm::load(config.hidden_size, config.norm_eps, vb_model.pp("norm"))?;
let lm_head = if config.tie_word_embeddings {
Linear::new(embeddings.embeddings().clone(), None)
} else {
candle_nn::linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?
};
Ok(Self {
embeddings,
blocks,
ln_out,
lm_head,
config,
})
}
    /// Borrow the model configuration.
    #[must_use]
    pub const fn config(&self) -> &RwkvConfig {
        &self.config
    }
}
// Mechanistic-interpretability backend: dimension accessors plus a hookable
// forward pass with capture points and state interventions.
impl MIBackend for GenericRwkv {
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }
    fn hidden_size(&self) -> usize {
        self.config.hidden_size
    }
    fn vocab_size(&self) -> usize {
        self.config.vocab_size
    }
    fn num_heads(&self) -> usize {
        self.config.num_heads
    }
    /// Full forward pass with hook capture and state interventions.
    ///
    /// Captures (when requested in `hooks`): embeddings, per-layer residual
    /// pre/post, WKV state, decay, effective attention, and the final-norm
    /// output; logits are stored as the cache output.
    fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
        let device = input_ids.device();
        let mut hidden = self.embeddings.forward(input_ids)?;
        // Placeholder output tensor; replaced by the logits via set_output below.
        let mut cache = HookCache::new(Tensor::zeros(1, DType::F32, device)?);
        if hooks.is_captured(&HookPoint::Embed) {
            cache.store(HookPoint::Embed, hidden.clone());
        }
        // Embedding-level interventions apply before any block runs.
        for intervention in hooks.interventions_at(&HookPoint::Embed) {
            hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
        }
        // Token positions targeted by state knockout / steering, if configured.
        let knockout_positions = hooks
            .state_knockout()
            .map(crate::interp::intervention::StateKnockoutSpec::position_set);
        let steering_positions = hooks
            .state_steering()
            .map(crate::interp::intervention::StateSteeringSpec::position_set);
        let steering_scale = hooks.state_steering().map_or(1.0, |s| s.scale);
        let mut state = RwkvState::new(self.config.num_layers);
        // v7 only: value stream of layer 0, fed to later layers.
        let mut v_first: Option<Tensor> = None;
        for (layer_idx, block) in self.blocks.iter().enumerate() {
            if hooks.is_captured(&HookPoint::ResidPre(layer_idx)) {
                cache.store(HookPoint::ResidPre(layer_idx), hidden.clone());
            }
            let compute_eff_attn = hooks.is_captured(&HookPoint::RwkvEffectiveAttn(layer_idx));
            // Which positions/scale (knockout vs. steering) apply to this layer.
            let (layer_int_positions, layer_kv_scale) = resolve_layer_intervention(
                hooks,
                layer_idx,
                knockout_positions.as_ref(),
                steering_positions.as_ref(),
                steering_scale,
            );
            let (new_hidden, new_attn_x, new_attn_kv, new_ffn_x, v_out, decay, eff_attn) = block
                .forward(
                    &hidden,
                    state.attn_x.get(layer_idx).and_then(Option::as_ref),
                    state.attn_kv.get(layer_idx).and_then(Option::as_ref),
                    state.ffn_x.get(layer_idx).and_then(Option::as_ref),
                    v_first.as_ref(),
                    compute_eff_attn,
                    layer_int_positions,
                    layer_kv_scale,
                )?;
            // Capture layer 0's value stream for later v7 layers.
            if layer_idx == 0 && self.config.version == RwkvVersion::V7 {
                v_first = Some(v_out);
            }
            hidden = new_hidden;
            if hooks.is_captured(&HookPoint::RwkvState(layer_idx)) {
                cache.store(HookPoint::RwkvState(layer_idx), new_attn_kv.clone());
            }
            if hooks.is_captured(&HookPoint::RwkvDecay(layer_idx)) {
                cache.store(HookPoint::RwkvDecay(layer_idx), decay);
            }
            // eff_attn is only Some when its capture was requested above.
            if let Some(ea) = eff_attn {
                cache.store(HookPoint::RwkvEffectiveAttn(layer_idx), ea);
            }
            // Record the refreshed per-layer recurrent states.
            if let Some(slot) = state.attn_x.get_mut(layer_idx) {
                *slot = Some(new_attn_x);
            }
            if let Some(slot) = state.attn_kv.get_mut(layer_idx) {
                *slot = Some(new_attn_kv);
            }
            if let Some(slot) = state.ffn_x.get_mut(layer_idx) {
                *slot = Some(new_ffn_x);
            }
            if hooks.is_captured(&HookPoint::ResidPost(layer_idx)) {
                cache.store(HookPoint::ResidPost(layer_idx), hidden.clone());
            }
        }
        hidden = self.ln_out.forward(&hidden)?;
        if hooks.is_captured(&HookPoint::FinalNorm) {
            cache.store(HookPoint::FinalNorm, hidden.clone());
        }
        let logits = self.lm_head.forward(&hidden)?;
        cache.set_output(logits);
        Ok(cache)
    }
    /// Project arbitrary hidden states to vocabulary logits via the final
    /// norm and LM head.
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
        let normed = self.ln_out.forward(hidden)?;
        Ok(self.lm_head.forward(&normed)?)
    }
}