impl AprTransformer {
    /// Runs a full forward pass over `token_ids` while recording activation
    /// statistics at every stage, for debugging and numerical diagnostics.
    ///
    /// Per layer, statistics are captured after:
    /// - the attention norm,
    /// - the fused QKV projection (plus optional bias),
    /// - the attention output projection,
    /// - the FFN norm (or the raw residual stream when the layer has no FFN norm),
    /// - the FFN output,
    /// - the residual stream at the end of the layer.
    ///
    /// Statistics for the embeddings, the final norm and the logits are also
    /// recorded. Logits are computed only for the last position of the sequence.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::InvalidShape`] when:
    /// - `token_ids` is empty;
    /// - the attention configuration is unusable (`hidden_dim == 0`,
    ///   `num_heads == 0`, `num_kv_heads == 0`, or `num_heads` is not a
    ///   multiple of `num_kv_heads`);
    /// - a layer's fused QKV weight is narrower than the
    ///   `hidden_dim + 2 * kv_dim` columns this routine reads.
    pub fn forward_traced(&self, token_ids: &[u32]) -> Result<ForwardTrace> {
        if token_ids.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Token sequence cannot be empty".to_string(),
            });
        }
        let hidden_dim = self.config.hidden_dim;
        let intermediate_dim = self.config.intermediate_dim;
        let num_heads = self.config.num_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let seq_len = token_ids.len();

        // Without this guard the divisions below (`qkv_weight.len() / hidden_dim`,
        // `hidden_dim / num_heads`, `num_heads / num_kv_heads`, `head / group_size`)
        // panic, and a non-multiple head count maps query heads onto KV heads
        // that do not exist. `||` short-circuits, so `%` never sees a zero divisor.
        if hidden_dim == 0
            || num_heads == 0
            || num_kv_heads == 0
            || num_heads % num_kv_heads != 0
        {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Invalid attention configuration: hidden_dim={}, num_heads={}, \
                     num_kv_heads={} (num_heads must be a non-zero multiple of num_kv_heads)",
                    hidden_dim, num_heads, num_kv_heads
                ),
            });
        }

        // Attention geometry is identical for every layer, so compute it once.
        let head_dim = hidden_dim / num_heads;
        let kv_dim = num_kv_heads * head_dim;
        let group_size = num_heads / num_kv_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let mut hidden = self.embed(token_ids);
        let embed_stats = ActivationStats::from_slice(&hidden);
        let mut layer_activations = Vec::with_capacity(self.layers.len());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            // ---------------- Attention block ----------------
            let normed = self.layer_norm(
                &hidden,
                &layer.attn_norm_weight,
                layer.attn_norm_bias.as_deref(),
                self.config.eps,
            );
            let attn_norm_stats = ActivationStats::from_slice(&normed);

            // Each position's fused row is laid out as [Q (hidden_dim) | K (kv_dim) | V (kv_dim)].
            let qkv_dim = layer.qkv_weight.len() / hidden_dim;
            let qkv_needed = hidden_dim + 2 * kv_dim;
            if qkv_dim < qkv_needed {
                return Err(RealizarError::InvalidShape {
                    reason: format!(
                        "Layer {} QKV width {} is smaller than the required q+k+v width {}",
                        layer_idx, qkv_dim, qkv_needed
                    ),
                });
            }
            let mut qkv = self.matmul(&normed, &layer.qkv_weight, hidden_dim, qkv_dim);
            if let Some(ref bias) = layer.qkv_bias {
                self.add_bias(&mut qkv, bias);
            }
            let qkv_stats = ActivationStats::from_slice(&qkv);

            // Split Q/K/V per position and apply rotary embeddings to Q and K.
            let mut q_all = Vec::with_capacity(seq_len * hidden_dim);
            let mut k_all = Vec::with_capacity(seq_len * kv_dim);
            let mut v_all = Vec::with_capacity(seq_len * kv_dim);
            for s in 0..seq_len {
                let q_base = s * qkv_dim;
                let k_base = q_base + hidden_dim;
                let v_base = k_base + kv_dim;
                let mut q_pos = qkv[q_base..k_base].to_vec();
                let mut k_pos = qkv[k_base..v_base].to_vec();
                self.apply_rope_f32(&mut q_pos, s, num_heads, head_dim);
                self.apply_rope_f32(&mut k_pos, s, num_kv_heads, head_dim);
                q_all.extend_from_slice(&q_pos);
                k_all.extend_from_slice(&k_pos);
                v_all.extend_from_slice(&qkv[v_base..v_base + kv_dim]);
            }

            // Causal grouped-query attention: query head `h` reads KV head
            // `h / group_size`, and position `i` attends only to positions `0..=i`.
            let mut attn_out = vec![0.0f32; seq_len * hidden_dim];
            // One buffer reused for every (head, position) row: holds scaled
            // scores, then exponentials, then normalized probabilities.
            let mut probs: Vec<f32> = Vec::with_capacity(seq_len);
            for head in 0..num_heads {
                let q_head_offset = head * head_dim;
                let kv_head_offset = (head / group_size) * head_dim;
                for i in 0..seq_len {
                    let q_start = i * hidden_dim + q_head_offset;
                    let q = &q_all[q_start..q_start + head_dim];

                    probs.clear();
                    for j in 0..=i {
                        let k_start = j * kv_dim + kv_head_offset;
                        let k = &k_all[k_start..k_start + head_dim];
                        let dot = q.iter().zip(k).fold(0.0f32, |acc, (a, b)| acc + a * b);
                        probs.push(dot * scale);
                    }

                    // Numerically stable softmax: shift by the row max before exp.
                    let max_score = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    for p in probs.iter_mut() {
                        *p = (*p - max_score).exp();
                    }
                    let sum_exp: f32 = probs.iter().sum();
                    for p in probs.iter_mut() {
                        *p /= sum_exp;
                    }

                    let out_start = i * hidden_dim + q_head_offset;
                    let out = &mut attn_out[out_start..out_start + head_dim];
                    for (j, &p) in probs.iter().enumerate() {
                        let v_start = j * kv_dim + kv_head_offset;
                        let v = &v_all[v_start..v_start + head_dim];
                        for (o, &vd) in out.iter_mut().zip(v) {
                            *o += p * vd;
                        }
                    }
                }
            }

            let mut attn_output =
                self.matmul(&attn_out, &layer.attn_output_weight, hidden_dim, hidden_dim);
            if let Some(ref bias) = layer.attn_output_bias {
                self.add_bias(&mut attn_output, bias);
            }
            let attn_out_stats = ActivationStats::from_slice(&attn_output);
            debug_assert_eq!(hidden.len(), attn_output.len());
            for (h, &a) in hidden.iter_mut().zip(&attn_output) {
                *h += a;
            }

            // ---------------- Feed-forward block ----------------
            // Layers without an FFN norm feed the raw residual stream into the FFN.
            let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
                self.layer_norm(
                    &hidden,
                    norm_weight,
                    layer.ffn_norm_bias.as_deref(),
                    self.config.eps,
                )
            } else {
                hidden.clone()
            };
            let ffn_norm_stats = ActivationStats::from_slice(&ffn_input);

            let ffn_hidden = if let Some(ref gate_weight) = layer.ffn_gate_weight {
                // Gated path: silu(x · W_gate) * (x · W_up), elementwise.
                // NOTE(review): this path does not apply `ffn_up_bias`; confirm that
                // matches the untraced `forward`.
                let gate = self.matmul(&ffn_input, gate_weight, hidden_dim, intermediate_dim);
                let up = self.matmul(
                    &ffn_input,
                    &layer.ffn_up_weight,
                    hidden_dim,
                    intermediate_dim,
                );
                gate.iter()
                    .zip(&up)
                    .map(|(&g, &u)| g / (1.0 + (-g).exp()) * u)
                    .collect::<Vec<f32>>()
            } else {
                // Non-gated path: tanh approximation of GELU over x · W_up (+ bias).
                let mut up = self.matmul(
                    &ffn_input,
                    &layer.ffn_up_weight,
                    hidden_dim,
                    intermediate_dim,
                );
                if let Some(ref bias) = layer.ffn_up_bias {
                    self.add_bias(&mut up, bias);
                }
                for h in &mut up {
                    *h = 0.5 * *h * (1.0 + (0.797_884_6 * (*h + 0.044_715 * *h * *h * *h)).tanh());
                }
                up
            };

            // Both FFN variants share the down projection and its optional bias.
            let mut ffn_output = self.matmul(
                &ffn_hidden,
                &layer.ffn_down_weight,
                intermediate_dim,
                hidden_dim,
            );
            if let Some(ref bias) = layer.ffn_down_bias {
                self.add_bias(&mut ffn_output, bias);
            }
            let ffn_out_stats = ActivationStats::from_slice(&ffn_output);
            debug_assert_eq!(hidden.len(), ffn_output.len());
            for (h, &f) in hidden.iter_mut().zip(&ffn_output) {
                *h += f;
            }

            let output_stats = ActivationStats::from_slice(&hidden);
            layer_activations.push(LayerActivation {
                layer_idx,
                attn_norm_stats,
                qkv_stats,
                attn_out_stats,
                ffn_norm_stats,
                ffn_out_stats,
                output_stats,
            });
        }

        let normed = self.layer_norm(
            &hidden,
            &self.output_norm_weight,
            self.output_norm_bias.as_deref(),
            self.config.eps,
        );
        let final_norm_stats = ActivationStats::from_slice(&normed);

        // Only the last position's hidden state is projected to vocabulary logits.
        let last_hidden_start = (seq_len - 1) * hidden_dim;
        let last_hidden = &normed[last_hidden_start..last_hidden_start + hidden_dim];
        let mut logits = self.matmul(
            last_hidden,
            &self.lm_head_weight,
            hidden_dim,
            self.config.vocab_size,
        );
        if let Some(ref bias) = self.lm_head_bias {
            self.add_bias(&mut logits, bias);
        }
        let logits_stats = ActivationStats::from_slice(&logits);

        Ok(ForwardTrace {
            input_tokens: token_ids.to_vec(),
            embed_stats,
            layer_activations,
            final_norm_stats,
            logits_stats,
            logits,
        })
    }

    /// Returns the greedy (argmax) next-token id for `token_ids`.
    ///
    /// Runs `forward` and picks the index of the largest logit. Comparisons
    /// involving NaN are treated as equal. `Iterator::max_by` returns the last
    /// of several equal maxima, so exact ties resolve to the higher index.
    ///
    /// # Errors
    ///
    /// Propagates any error from `forward`. Returns
    /// [`RealizarError::InvalidShape`] if `forward` yields no logits.
    pub fn predict_next(&self, token_ids: &[u32]) -> Result<u32> {
        let logits = self.forward(token_ids)?;
        let (max_idx, _) = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: "Empty logits".to_string(),
            })?;
        Ok(max_idx as u32)
    }
}