impl AprTransformer {
/// Traced forward pass with no tensor-save plan: convenience wrapper over
/// [`Self::forward_traced_with_plan`] that only collects activation statistics.
pub fn forward_traced(&self, token_ids: &[u32]) -> Result<ForwardTrace> {
    self.forward_traced_with_plan(token_ids, None)
}
/// Full traced forward pass over `token_ids`.
///
/// Runs the transformer stack in f32, recording `ActivationStats` for every
/// stage of every layer (plus last-token-only stats), and — when `plan`
/// requests it — dumping intermediate tensors stage by stage. Logits are
/// computed for the last token only.
///
/// # Errors
/// - `RealizarError::InvalidShape` when `token_ids` is empty.
/// - `RealizarError::IoError` when saving a plan-requested tensor fails.
pub fn forward_traced_with_plan(
    &self,
    token_ids: &[u32],
    plan: Option<&crate::inference_trace::save_tensor_plan::SaveTensorPlan>,
) -> Result<ForwardTrace> {
    use crate::inference_trace::save_tensor::WHOLE_MODEL_LAYER;
    use crate::inference_trace::save_tensor_emit::maybe_save_stage;
    use crate::inference_trace::save_tensor_stage::SaveTensorStage;
    // Dump `values` for (stage, layer) if the plan asks for it; any I/O
    // failure is tagged with the stage and layer for diagnosis.
    let emit = |stage: SaveTensorStage, layer: u32, values: &[f32]| -> Result<()> {
        maybe_save_stage(plan, stage, layer, values).map_err(|e| RealizarError::IoError {
            message: format!("save_tensor::{stage:?} L{layer}: {e}"),
        })
    };
    if token_ids.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Token sequence cannot be empty".to_string(),
        });
    }
    let hidden_dim = self.config.hidden_dim;
    let intermediate_dim = self.config.intermediate_dim;
    // `hidden` is the residual stream: [seq_len * hidden_dim], row-major per token.
    let mut hidden = self.embed(token_ids);
    emit(SaveTensorStage::Embedding, 0, &hidden)?;
    let embed_stats = ActivationStats::from_slice(&hidden);
    let mut layer_activations = Vec::with_capacity(self.layers.len());
    for (layer_idx, layer) in self.layers.iter().enumerate() {
        // Looked up but deliberately unused (underscore) in this f32 trace
        // path — presumably kept for parity with a quantized path; TODO confirm.
        let _q4k_layer = self.q4k_layers.as_ref().and_then(|l| l.get(layer_idx));
        // Pre-attention normalization.
        let normed = self.layer_norm(
            &hidden,
            &layer.attn_norm_weight,
            layer.attn_norm_bias.as_deref(),
            self.config.eps,
        );
        emit(SaveTensorStage::AttnNorm, layer_idx as u32, &normed)?;
        let attn_norm_stats = ActivationStats::from_slice(&normed);
        let seq_len_for_last = token_ids.len();
        // Stats restricted to the final token's row of the activation.
        let last_token_attn_norm_stats = ActivationStats::from_slice(
            &normed[(seq_len_for_last - 1) * hidden_dim..],
        );
        // Fused QKV projection; output width is inferred from the weight size,
        // so GQA (kv_dim < hidden_dim) is handled transparently.
        let qkv_dim = layer.qkv_weight.len() / hidden_dim;
        let mut qkv = self.matmul(&normed, &layer.qkv_weight, hidden_dim, qkv_dim);
        emit(SaveTensorStage::QkvMatmul, layer_idx as u32, &qkv)?;
        if let Some(ref bias) = layer.qkv_bias {
            self.add_bias(&mut qkv, bias);
            emit(SaveTensorStage::QkvBias, layer_idx as u32, &qkv)?;
        }
        let qkv_stats = ActivationStats::from_slice(&qkv);
        let last_token_qkv_stats =
            ActivationStats::from_slice(&qkv[(seq_len_for_last - 1) * qkv_dim..]);
        // Attention geometry: grouped-query attention where `group_size`
        // query heads share one KV head; standard 1/sqrt(head_dim) scaling.
        let seq_len = token_ids.len();
        let head_dim = hidden_dim / self.config.num_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let kv_dim = num_kv_heads * head_dim;
        let group_size = self.config.num_heads / num_kv_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mut q_all = Vec::with_capacity(seq_len * hidden_dim);
        let mut k_all = Vec::with_capacity(seq_len * kv_dim);
        let mut v_all = Vec::with_capacity(seq_len * kv_dim);
        // Per position: slice the fused row as [Q | K | V], then apply rotary
        // position embedding to Q and K (V is passed through unrotated).
        for s in 0..seq_len {
            let qkv_start = s * qkv_dim;
            let mut q_pos = qkv[qkv_start..qkv_start + hidden_dim].to_vec();
            let mut k_pos =
                qkv[qkv_start + hidden_dim..qkv_start + hidden_dim + kv_dim].to_vec();
            let v_pos =
                &qkv[qkv_start + hidden_dim + kv_dim..qkv_start + hidden_dim + 2 * kv_dim];
            self.apply_rope_f32(&mut q_pos, s, self.config.num_heads, head_dim);
            self.apply_rope_f32(&mut k_pos, s, num_kv_heads, head_dim);
            q_all.extend_from_slice(&q_pos);
            k_all.extend_from_slice(&k_pos);
            v_all.extend_from_slice(v_pos);
        }
        emit(SaveTensorStage::QPostRope, layer_idx as u32, &q_all)?;
        emit(SaveTensorStage::KPostRope, layer_idx as u32, &k_all)?;
        // Score/softmax capture buffers are only allocated when the plan
        // wants them; layout is [head][query][key], causal upper triangle
        // (j > i) is left at 0.0.
        let want_scores = plan
            .is_some_and(|p| p.should_save(SaveTensorStage::AttnScores, layer_idx as u32));
        let want_softmax = plan
            .is_some_and(|p| p.should_save(SaveTensorStage::AttnSoftmax, layer_idx as u32));
        let mut scores_all: Option<Vec<f32>> = if want_scores {
            Some(vec![0.0f32; self.config.num_heads * seq_len * seq_len])
        } else {
            None
        };
        let mut softmax_all: Option<Vec<f32>> = if want_softmax {
            Some(vec![0.0f32; self.config.num_heads * seq_len * seq_len])
        } else {
            None
        };
        let mut attn_out = vec![0.0f32; seq_len * hidden_dim];
        for head in 0..self.config.num_heads {
            // GQA: map each query head onto its shared KV head.
            let kv_head = head / group_size;
            let q_head_offset = head * head_dim;
            let kv_head_offset = kv_head * head_dim;
            for i in 0..seq_len {
                // Causal attention: query i only attends to keys j <= i,
                // so `scores` always holds at least one entry.
                let mut scores = Vec::with_capacity(i + 1);
                let q_start = i * hidden_dim + q_head_offset;
                for j in 0..=i {
                    let k_start = j * kv_dim + kv_head_offset;
                    let mut score = 0.0f32;
                    for d in 0..head_dim {
                        score += q_all[q_start + d] * k_all[k_start + d];
                    }
                    scores.push(score * scale);
                }
                if let Some(ref mut buf) = scores_all {
                    let row_base = head * seq_len * seq_len + i * seq_len;
                    for (j, &s) in scores.iter().enumerate() {
                        buf[row_base + j] = s;
                    }
                }
                // Numerically stable softmax: subtract the row max before exp.
                let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let exp_scores: Vec<f32> =
                    scores.iter().map(|s| (s - max_score).exp()).collect();
                let sum_exp: f32 = exp_scores.iter().sum();
                let probs: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
                if let Some(ref mut buf) = softmax_all {
                    let row_base = head * seq_len * seq_len + i * seq_len;
                    for (j, &p) in probs.iter().enumerate() {
                        buf[row_base + j] = p;
                    }
                }
                // Probability-weighted sum of V rows, accumulated into this
                // head's slice of the attention output.
                let out_start = i * hidden_dim + q_head_offset;
                for (j, &p) in probs.iter().enumerate() {
                    let v_start = j * kv_dim + kv_head_offset;
                    for d in 0..head_dim {
                        attn_out[out_start + d] += p * v_all[v_start + d];
                    }
                }
            }
        }
        if let Some(buf) = scores_all.as_deref() {
            emit(SaveTensorStage::AttnScores, layer_idx as u32, buf)?;
        }
        if let Some(buf) = softmax_all.as_deref() {
            emit(SaveTensorStage::AttnSoftmax, layer_idx as u32, buf)?;
        }
        emit(SaveTensorStage::Attention, layer_idx as u32, &attn_out)?;
        // Output projection back into the residual stream width.
        let mut attn_output =
            self.matmul(&attn_out, &layer.attn_output_weight, hidden_dim, hidden_dim);
        if let Some(ref bias) = layer.attn_output_bias {
            self.add_bias(&mut attn_output, bias);
        }
        emit(SaveTensorStage::AttnOut, layer_idx as u32, &attn_output)?;
        let attn_out_stats = ActivationStats::from_slice(&attn_output);
        let last_token_attn_out_stats = ActivationStats::from_slice(
            &attn_output[(seq_len_for_last - 1) * hidden_dim..],
        );
        // Residual connection around attention.
        for i in 0..hidden.len() {
            hidden[i] += attn_output[i];
        }
        emit(SaveTensorStage::PostAttnResidual, layer_idx as u32, &hidden)?;
        // Pre-FFN normalization is optional; without it the FFN reads the
        // residual stream directly (cloned so `hidden` stays mutable later).
        let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
            let normed = self.layer_norm(
                &hidden,
                norm_weight,
                layer.ffn_norm_bias.as_deref(),
                self.config.eps,
            );
            normed
        } else {
            hidden.clone()
        };
        emit(SaveTensorStage::FfnNorm, layer_idx as u32, &ffn_input)?;
        let ffn_norm_stats = ActivationStats::from_slice(&ffn_input);
        let last_token_ffn_norm_stats = ActivationStats::from_slice(
            &ffn_input[(seq_len_for_last - 1) * hidden_dim..],
        );
        // Gate/SiLU/SwiGLU stats stay at their defaults on the non-gated
        // (GELU) FFN path below.
        let mut ffn_gate_stats = ActivationStats::default();
        let mut ffn_up_stats = ActivationStats::default();
        let mut ffn_silu_gate_stats = ActivationStats::default();
        let mut ffn_swiglu_inner_stats = ActivationStats::default();
        let mut last_token_ffn_gate_stats = ActivationStats::default();
        let mut last_token_ffn_up_stats = ActivationStats::default();
        let mut last_token_ffn_silu_gate_stats = ActivationStats::default();
        let mut last_token_ffn_swiglu_inner_stats = ActivationStats::default();
        let ffn_output = if let Some(ref gate_weight) = layer.ffn_gate_weight {
            // SwiGLU FFN: down( silu(gate(x)) * up(x) ).
            let gate = self.matmul(&ffn_input, gate_weight, hidden_dim, intermediate_dim);
            emit(SaveTensorStage::FfnGate, layer_idx as u32, &gate)?;
            let up = self.matmul(
                &ffn_input,
                &layer.ffn_up_weight,
                hidden_dim,
                intermediate_dim,
            );
            emit(SaveTensorStage::FfnUp, layer_idx as u32, &up)?;
            ffn_gate_stats = ActivationStats::from_slice(&gate);
            ffn_up_stats = ActivationStats::from_slice(&up);
            last_token_ffn_gate_stats = ActivationStats::from_slice(
                &gate[(seq_len_for_last - 1) * intermediate_dim..],
            );
            last_token_ffn_up_stats = ActivationStats::from_slice(
                &up[(seq_len_for_last - 1) * intermediate_dim..],
            );
            let mut silu_gate = Vec::with_capacity(gate.len());
            let mut ffn_hidden = Vec::with_capacity(gate.len());
            for (g, u) in gate.iter().zip(up.iter()) {
                // silu(g) = g * sigmoid(g) = g / (1 + e^{-g})
                let silu_g = g / (1.0 + (-g).exp());
                silu_gate.push(silu_g);
                ffn_hidden.push(silu_g * u);
            }
            emit(SaveTensorStage::FfnSilu, layer_idx as u32, &silu_gate)?;
            emit(SaveTensorStage::FfnSwigl, layer_idx as u32, &ffn_hidden)?;
            ffn_silu_gate_stats = ActivationStats::from_slice(&silu_gate);
            ffn_swiglu_inner_stats = ActivationStats::from_slice(&ffn_hidden);
            last_token_ffn_silu_gate_stats = ActivationStats::from_slice(
                &silu_gate[(seq_len_for_last - 1) * intermediate_dim..],
            );
            last_token_ffn_swiglu_inner_stats = ActivationStats::from_slice(
                &ffn_hidden[(seq_len_for_last - 1) * intermediate_dim..],
            );
            let mut out = self.matmul(
                &ffn_hidden,
                &layer.ffn_down_weight,
                intermediate_dim,
                hidden_dim,
            );
            if let Some(ref bias) = layer.ffn_down_bias {
                self.add_bias(&mut out, bias);
            }
            out
        } else {
            // Classic two-matrix FFN: down( gelu(up(x)) ).
            let mut ffn_hidden = self.matmul(
                &ffn_input,
                &layer.ffn_up_weight,
                hidden_dim,
                intermediate_dim,
            );
            if let Some(ref bias) = layer.ffn_up_bias {
                self.add_bias(&mut ffn_hidden, bias);
            }
            ffn_up_stats = ActivationStats::from_slice(&ffn_hidden);
            last_token_ffn_up_stats = ActivationStats::from_slice(
                &ffn_hidden[(seq_len_for_last - 1) * intermediate_dim..],
            );
            // tanh-approximation GELU; 0.797_884_6 ≈ sqrt(2/π).
            for h in &mut ffn_hidden {
                let gelu_approx =
                    0.5 * *h * (1.0 + (0.797_884_6 * (*h + 0.044_715 * *h * *h * *h)).tanh());
                *h = gelu_approx;
            }
            let mut out = self.matmul(
                &ffn_hidden,
                &layer.ffn_down_weight,
                intermediate_dim,
                hidden_dim,
            );
            if let Some(ref bias) = layer.ffn_down_bias {
                self.add_bias(&mut out, bias);
            }
            out
        };
        emit(SaveTensorStage::FfnOut, layer_idx as u32, &ffn_output)?;
        let ffn_out_stats = ActivationStats::from_slice(&ffn_output);
        let last_token_ffn_out_stats = ActivationStats::from_slice(
            &ffn_output[(seq_len_for_last - 1) * hidden_dim..],
        );
        // Residual connection around the FFN.
        for i in 0..hidden.len() {
            hidden[i] += ffn_output[i];
        }
        emit(SaveTensorStage::PostFfnResidual, layer_idx as u32, &hidden)?;
        let output_stats = ActivationStats::from_slice(&hidden);
        let last_token_output_stats = ActivationStats::from_slice(
            &hidden[(seq_len_for_last - 1) * hidden_dim..],
        );
        let last_token = Some(crate::apr_transformer::LastTokenStats {
            attn_norm_stats: last_token_attn_norm_stats,
            qkv_stats: last_token_qkv_stats,
            attn_out_stats: last_token_attn_out_stats,
            ffn_norm_stats: last_token_ffn_norm_stats,
            ffn_gate_stats: last_token_ffn_gate_stats,
            ffn_up_stats: last_token_ffn_up_stats,
            ffn_silu_gate_stats: last_token_ffn_silu_gate_stats,
            ffn_swiglu_inner_stats: last_token_ffn_swiglu_inner_stats,
            ffn_out_stats: last_token_ffn_out_stats,
            output_stats: last_token_output_stats,
        });
        layer_activations.push(LayerActivation {
            layer_idx,
            attn_norm_stats,
            qkv_stats,
            attn_out_stats,
            ffn_norm_stats,
            ffn_gate_stats,
            ffn_up_stats,
            ffn_silu_gate_stats,
            ffn_swiglu_inner_stats,
            ffn_out_stats,
            output_stats,
            last_token,
        });
    }
    // Final normalization over the whole residual stream.
    let normed = self.layer_norm(
        &hidden,
        &self.output_norm_weight,
        self.output_norm_bias.as_deref(),
        self.config.eps,
    );
    emit(SaveTensorStage::FinalNorm, WHOLE_MODEL_LAYER, &normed)?;
    let final_norm_stats = ActivationStats::from_slice(&normed);
    // LM head is applied to the last token's hidden state only.
    let seq_len = token_ids.len();
    let last_hidden_start = (seq_len - 1) * hidden_dim;
    let last_hidden = &normed[last_hidden_start..last_hidden_start + hidden_dim];
    let mut logits = self.matmul(
        last_hidden,
        &self.lm_head_weight,
        hidden_dim,
        self.config.vocab_size,
    );
    if let Some(ref bias) = self.lm_head_bias {
        self.add_bias(&mut logits, bias);
    }
    emit(SaveTensorStage::LmHead, WHOLE_MODEL_LAYER, &logits)?;
    let logits_stats = ActivationStats::from_slice(&logits);
    Ok(ForwardTrace {
        input_tokens: token_ids.to_vec(),
        embed_stats,
        layer_activations,
        final_norm_stats,
        logits_stats,
        logits,
    })
}
/// Greedy decoding step: runs a forward pass and returns the index of the
/// largest logit as the predicted next-token id.
///
/// # Errors
/// Propagates any error from `forward`; returns `InvalidShape` when the
/// logits vector is empty.
pub fn predict_next(&self, token_ids: &[u32]) -> Result<u32> {
    let logits = self.forward(token_ids)?;
    // Argmax with a NaN-tolerant comparator: incomparable pairs compare
    // equal, so `max_by` keeps the later element on ties.
    let argmax = logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(idx, _)| idx);
    match argmax {
        Some(idx) => Ok(idx as u32),
        None => Err(RealizarError::InvalidShape {
            reason: "Empty logits".to_string(),
        }),
    }
}
}