pub(crate) mod attention;
pub(crate) mod mlp;
pub(crate) mod norm;
pub(crate) mod recurrent;
pub(crate) mod rope;
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::{Embedding, Linear, VarBuilder};
use crate::backend::{self, MIBackend};
use crate::config::TransformerConfig;
use crate::error::Result;
use crate::hooks::{HookCache, HookPoint, HookSpec};
use crate::util::masks;
use self::attention::Attention;
use self::mlp::Mlp;
use self::norm::{Norm, create_norm};
use self::recurrent::RecurrentPassSpec;
use self::rope::RopeCache;
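/// A single decoder layer: input norm, self-attention, a mid norm, and the MLP,
/// with optional post-attention and post-feedforward norms for architectures
/// that use them (`use_post_norms`).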
struct TransformerLayer {
input_norm: Norm,
attention: Attention,
post_attention_norm: Option<Norm>,
mid_norm: Norm,
post_feedforward_norm: Option<Norm>,
mlp: Mlp,
}
impl TransformerLayer {
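/// Loads the layer's norms, attention, and MLP weights from `vb`. When post-norms
/// are enabled the mid norm reads `pre_feedforward_layernorm`; otherwise it reads
/// `post_attention_layernorm`.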
#[allow(clippy::needless_pass_by_value)]
fn load(config: &TransformerConfig, vb: VarBuilder<'_>) -> Result<Self> {
let input_norm = create_norm(
config.norm_type,
config.hidden_size,
config.norm_eps,
vb.pp("input_layernorm"),
)?;
let attention = Attention::load(config, vb.pp("self_attn"))?;
let post_attention_norm = if config.use_post_norms {
Some(create_norm(
config.norm_type,
config.hidden_size,
config.norm_eps,
vb.pp("post_attention_layernorm"),
)?)
} else {
None
};
let mid_norm_name = if config.use_post_norms {
"pre_feedforward_layernorm"
} else {
"post_attention_layernorm"
};
let mid_norm = create_norm(
config.norm_type,
config.hidden_size,
config.norm_eps,
vb.pp(mid_norm_name),
)?;
let post_feedforward_norm = if config.use_post_norms {
Some(create_norm(
config.norm_type,
config.hidden_size,
config.norm_eps,
vb.pp("post_feedforward_layernorm"),
)?)
} else {
None
};
let mlp = Mlp::load(config, vb.pp("mlp"))?;
Ok(Self {
input_norm,
attention,
post_attention_norm,
mid_norm,
post_feedforward_norm,
mlp,
})
}
}
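/// Decoder-only transformer driven by a `TransformerConfig`, instrumented with
/// hook points for capturing activations and applying interventions at every
/// residual-stream, attention, and MLP site.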
pub struct GenericTransformer {
embed_tokens: Embedding,
layers: Vec<TransformerLayer>,
final_norm: Norm,
lm_head: Option<Linear>,
rope_cache: RopeCache,
config: TransformerConfig,
}
impl GenericTransformer {
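/// Loads the token embeddings, all decoder layers, the final norm, and the LM
/// head (omitted when word embeddings are tied), and precomputes the RoPE cache
/// on `device` with `dtype`.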
#[allow(clippy::needless_pass_by_value)]
pub fn load(
config: TransformerConfig,
device: &Device,
dtype: DType,
vb: VarBuilder<'_>,
) -> Result<Self> {
let vb_model = vb.pp("model");
let embed_tokens = candle_nn::embedding(
config.vocab_size,
config.hidden_size,
vb_model.pp("embed_tokens"),
)?;
let mut layers = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
let vb_layer = vb_model.pp(format!("layers.{i}"));
let layer = TransformerLayer::load(&config, vb_layer)?;
layers.push(layer);
}
let final_norm = create_norm(
config.norm_type,
config.hidden_size,
config.norm_eps,
vb_model.pp("norm"),
)?;
let lm_head = if config.tie_word_embeddings {
None
} else {
Some(candle_nn::linear_no_bias(
config.hidden_size,
config.vocab_size,
vb.pp("lm_head"),
)?)
};
let rope_cache = RopeCache::new(
config.head_dim,
config.max_position_embeddings,
config.rope_theta,
device,
dtype,
)?;
Ok(Self {
embed_tokens,
layers,
final_norm,
lm_head,
rope_cache,
config,
})
}
#[must_use]
pub const fn config(&self) -> &TransformerConfig {
&self.config
}
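/// Runs layers `start..end` over `hidden`, capturing and intervening at the
/// ResidPre, AttnOut, ResidMid, MlpPre, MlpPost, MlpOut, and ResidPost hook
/// points of each layer along the way.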
#[allow(clippy::too_many_arguments)]
fn forward_layer_range(
&self,
mut hidden: Tensor,
start: usize,
end: usize,
seq_len: usize,
device: &Device,
dtype: DType,
hooks: &HookSpec,
cache: &mut HookCache,
) -> Result<Tensor> {
let layer_slice = self.layers.get(start..end).ok_or_else(|| {
crate::error::MIError::Intervention(format!(
"layer range {start}..{end} out of bounds (n_layers={})",
self.layers.len()
))
})?;
for (offset, layer) in layer_slice.iter().enumerate() {
let layer_idx = start + offset;
if hooks.is_captured(&HookPoint::ResidPre(layer_idx)) {
cache.store(HookPoint::ResidPre(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::ResidPre(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
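// Attention sub-block: save the residual, apply the input norm and attention
// (plus optional post-attention norm), then add the residual back after the
// AttnOut hooks have run.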
let residual = hidden.clone();
hidden = layer.input_norm.forward(&hidden)?;
let mask = self.mask_for_layer(layer_idx, seq_len, device, dtype)?;
hidden = layer.attention.forward(
&hidden,
&mask,
&self.rope_cache,
layer_idx,
hooks,
cache,
)?;
if let Some(ref norm) = layer.post_attention_norm {
hidden = norm.forward(&hidden)?;
}
if hooks.is_captured(&HookPoint::AttnOut(layer_idx)) {
cache.store(HookPoint::AttnOut(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnOut(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
hidden = (residual + &hidden)?;
if hooks.is_captured(&HookPoint::ResidMid(layer_idx)) {
cache.store(HookPoint::ResidMid(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::ResidMid(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
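// MLP sub-block: save the residual, apply the mid norm and MLP (plus optional
// post-feedforward norm), then add the residual back after the MlpOut hooks.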
let residual = hidden.clone();
hidden = layer.mid_norm.forward(&hidden)?;
if hooks.is_captured(&HookPoint::MlpPre(layer_idx)) {
cache.store(HookPoint::MlpPre(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::MlpPre(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
hidden = layer.mlp.forward(&hidden)?;
if hooks.is_captured(&HookPoint::MlpPost(layer_idx)) {
cache.store(HookPoint::MlpPost(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::MlpPost(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
if let Some(ref norm) = layer.post_feedforward_norm {
hidden = norm.forward(&hidden)?;
}
if hooks.is_captured(&HookPoint::MlpOut(layer_idx)) {
cache.store(HookPoint::MlpOut(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::MlpOut(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
hidden = (residual + &hidden)?;
if hooks.is_captured(&HookPoint::ResidPost(layer_idx)) {
cache.store(HookPoint::ResidPost(layer_idx), hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::ResidPost(layer_idx)) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
}
Ok(hidden)
}
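/// Embeds `input_ids`, applies the optional embedding scale, and runs the
/// `Embed` hook point. Returns the hidden states, their dtype, and a fresh cache.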
fn embed_with_hooks(
&self,
input_ids: &Tensor,
hooks: &HookSpec,
) -> Result<(Tensor, DType, HookCache)> {
let device = input_ids.device();
let mut hidden = self.embed_tokens.forward(input_ids)?;
let dtype = hidden.dtype();
if let Some(scale) = self.config.embedding_scale {
hidden = (hidden * scale)?;
}
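// HookCache requires an output tensor at construction; this placeholder is
// replaced via `set_output` once the caller has computed logits.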
let mut cache = HookCache::new(Tensor::zeros(1, DType::F32, device)?);
if hooks.is_captured(&HookPoint::Embed) {
cache.store(HookPoint::Embed, hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::Embed) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
Ok((hidden, dtype, cache))
}
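/// Applies the final norm, runs the `FinalNorm` hook point, projects to vocab
/// logits, and applies the optional final logit soft-capping.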
fn finalize_logits(
&self,
mut hidden: Tensor,
hooks: &HookSpec,
cache: &mut HookCache,
) -> Result<Tensor> {
hidden = self.final_norm.forward(&hidden)?;
if hooks.is_captured(&HookPoint::FinalNorm) {
cache.store(HookPoint::FinalNorm, hidden.clone());
}
for intervention in hooks.interventions_at(&HookPoint::FinalNorm) {
hidden = crate::hooks::apply_intervention(&hidden, intervention)?;
}
let mut logits = self.project_logits(&hidden)?;
if let Some(cap) = self.config.final_logit_softcapping {
logits = ((logits / cap)?.tanh()? * cap)?;
}
Ok(logits)
}
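/// Forward pass that repeats layers `loop_start..=loop_end` `spec.depth` times.
/// Every pass after the first restarts from the activations that entered the
/// loop, with each feedback vector scaled by its strength and injected at its
/// position, before the remaining layers and the LM head run once.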
pub fn forward_recurrent(
&self,
input_ids: &Tensor,
hooks: &HookSpec,
spec: &RecurrentPassSpec,
) -> Result<HookCache> {
let (mut hidden, dtype, mut cache) = self.embed_with_hooks(input_ids, hooks)?;
let device = input_ids.device();
let (_, seq_len, _) = hidden.dims3()?;
spec.validate(self.config.num_layers, seq_len, self.config.hidden_size)?;
hidden = self.forward_layer_range(
hidden,
0,
spec.loop_start,
seq_len,
device,
dtype,
hooks,
&mut cache,
)?;
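// Keep the activations entering the loop so later passes can restart from them.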
let saved_input = if spec.feedback.is_empty() {
None
} else {
Some(hidden.clone())
};
hidden = self.forward_layer_range(
hidden,
spec.loop_start,
spec.loop_end + 1,
seq_len,
device,
dtype,
hooks,
&mut cache,
)?;
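// Extra passes over the loop span; each restarts from the saved loop input
// with the scaled feedback vectors injected at their positions.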
for _pass in 1..spec.depth {
if let Some(ref saved) = saved_input {
hidden = saved.clone();
for entry in &spec.feedback {
let scaled = (&entry.vector * f64::from(entry.strength))?;
hidden = inject_feedback_at_position(&hidden, &scaled, entry.position)?;
}
}
hidden = self.forward_layer_range(
hidden,
spec.loop_start,
spec.loop_end + 1,
seq_len,
device,
dtype,
hooks,
&mut cache,
)?;
}
hidden = self.forward_layer_range(
hidden,
spec.loop_end + 1,
self.config.num_layers,
seq_len,
device,
dtype,
hooks,
&mut cache,
)?;
let logits = self.finalize_logits(hidden, hooks, &mut cache)?;
cache.set_output(logits);
Ok(cache)
}
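/// Samples up to `max_tokens` tokens autoregressively, running `forward_recurrent`
/// over the whole sequence at every step. With `sustained` feedback, every
/// feedback entry is additionally re-applied at the current last token position.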
pub fn generate_recurrent(
&self,
prompt_tokens: &[u32],
max_tokens: usize,
temperature: f32,
stop_tokens: &[u32],
spec: &RecurrentPassSpec,
) -> Result<Vec<u32>> {
let device = self.embed_tokens.embeddings().device();
let mut tokens = prompt_tokens.to_vec();
for _ in 0..max_tokens {
let input_tensor = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let effective_spec = if spec.sustained && !spec.feedback.is_empty() {
let mut step_spec = spec.clone();
let last_pos = tokens.len() - 1;
for entry in &spec.feedback {
step_spec.feedback.push(recurrent::RecurrentFeedbackEntry {
position: last_pos,
vector: entry.vector.clone(),
strength: entry.strength,
});
}
step_spec
} else {
spec.clone()
};
let hook_cache =
self.forward_recurrent(&input_tensor, &HookSpec::new(), &effective_spec)?;
let logits = hook_cache.output();
let seq_len = logits.dim(1)?;
let last_logits = logits.i((.., seq_len - 1, ..))?.squeeze(1)?;
let last_logits_flat = last_logits.flatten_all()?;
let next_token = backend::sample_token(&last_logits_flat, temperature)?;
if stop_tokens.contains(&next_token) {
break;
}
tokens.push(next_token);
}
Ok(tokens)
}
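/// Projects hidden states to vocabulary logits, using the untied `lm_head` when
/// present and the transposed (tied) embedding matrix otherwise.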
fn project_logits(&self, hidden: &Tensor) -> Result<Tensor> {
if let Some(head) = &self.lm_head {
Ok(head.forward(hidden)?)
} else {
let embed_weight = self.embed_tokens.embeddings();
let logits = hidden.broadcast_matmul(&embed_weight.t()?)?;
Ok(logits)
}
}
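/// Builds the attention mask for `layer_idx`: a sliding-window mask on even
/// layers when alternating sliding attention is configured, a sliding-window
/// mask on every layer when it is not alternating, and a plain causal mask when
/// no sliding window is set.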
fn mask_for_layer(
&self,
layer_idx: usize,
seq_len: usize,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let use_sliding = match (
self.config.sliding_window,
self.config.alternating_sliding_window,
) {
(Some(_), true) => layer_idx.is_multiple_of(2),
(Some(_), false) => true,
(None, _) => false,
};
if use_sliding && let Some(window) = self.config.sliding_window {
return create_sliding_window_mask(seq_len, window, device, dtype);
}
masks::create_causal_mask(seq_len, device, dtype)
}
}
impl MIBackend for GenericTransformer {
fn num_layers(&self) -> usize {
self.config.num_layers
}
fn hidden_size(&self) -> usize {
self.config.hidden_size
}
fn vocab_size(&self) -> usize {
self.config.vocab_size
}
fn num_heads(&self) -> usize {
self.config.num_attention_heads
}
fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
let device = input_ids.device();
let (hidden, dtype, mut cache) = self.embed_with_hooks(input_ids, hooks)?;
let (_, seq_len, _) = hidden.dims3()?;
let hidden = self.forward_layer_range(
hidden,
0,
self.config.num_layers,
seq_len,
device,
dtype,
hooks,
&mut cache,
)?;
let logits = self.finalize_logits(hidden, hooks, &mut cache)?;
cache.set_output(logits);
Ok(cache)
}
fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
let normed = self.final_norm.forward(hidden)?;
self.project_logits(&normed)
}
fn embedding_vector(&self, token_id: u32) -> Result<Tensor> {
let device = self.embed_tokens.embeddings().device();
let ids = Tensor::new(&[token_id], device)?;
let emb = self.embed_tokens.forward(&ids)?;
Ok(emb.squeeze(0)?)
}
}
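/// Adds `vector` to the hidden state at a single sequence `position` by building
/// a zero delta of shape `(1, seq_len, d_model)` and broadcast-adding it to `hidden`.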
fn inject_feedback_at_position(
hidden: &Tensor,
vector: &Tensor,
position: usize,
) -> Result<Tensor> {
let seq_len = hidden.dim(1)?;
let d_model = hidden.dim(2)?;
let mut delta_data = vec![0.0_f32; seq_len * d_model];
let vec_f32: Vec<f32> = vector
.to_dtype(candle_core::DType::F32)?
.flatten_all()?
.to_vec1()?;
let start = position * d_model;
let dest = delta_data.get_mut(start..start + d_model).ok_or_else(|| {
crate::error::MIError::Intervention(format!(
"feedback position {position} out of bounds (seq_len={seq_len})"
))
})?;
dest.copy_from_slice(&vec_f32);
let delta = Tensor::from_vec(delta_data, (1, seq_len, d_model), hidden.device())?
.to_dtype(hidden.dtype())?;
Ok(hidden.broadcast_add(&delta)?)
}
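/// Builds a causal sliding-window mask: position `i` may attend to `j` only when
/// `j <= i` and `i - j <= window`; all other entries are set to `-inf`.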
fn create_sliding_window_mask(
seq_len: usize,
window: usize,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let mut mask_data = vec![0.0_f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
let idx = i * seq_len + j;
if j > i || i.saturating_sub(j) > window {
if let Some(cell) = mask_data.get_mut(idx) {
*cell = f32::NEG_INFINITY;
}
}
}
}
Ok(Tensor::from_vec(mask_data, (1, 1, seq_len, seq_len), device)?.to_dtype(dtype)?)
}