use burn::prelude::*;
use crate::attn_res_op::AttnResOp;
/// Per-op results of the batched inter-block pass produced by `phase1_batched`.
///
/// The three vectors are filled in lock-step: index `i` in each one belongs to
/// the `i`-th op handed to `phase1_batched`. All three are empty when that
/// function is called with no blocks.
pub struct Phase1Result<B: Backend> {
/// Sum over the stacked-block axis of `block * exp(logit - max_logit)`.
/// This is NOT yet divided by `sum_exp`; see `normalize_inter_output`.
pub outputs: Vec<Tensor<B, 3>>,
/// Maximum logit over the stacked-block axis, used as the softmax shift.
pub max_logits: Vec<Tensor<B, 2>>,
/// Sum over the stacked-block axis of `exp(logit - max_logit)`.
pub sum_exp: Vec<Tensor<B, 2>>,
}
/// Runs the inter-block pass of every op in `ops` against one shared stack of `blocks`.
///
/// The blocks are stacked once along a new leading axis and that stack is reused
/// for each op. For every op the function records, in order:
///
/// * the exp-weighted sum of the blocks (not divided by the softmax denominator),
/// * the maximum logit over the block axis,
/// * the sum of `exp(logit - max)` over the block axis.
///
/// When `blocks` is empty there is nothing to stack, so the returned vectors are
/// all empty regardless of how many ops were supplied.
pub fn phase1_batched<B: Backend>(
    ops: &[&AttnResOp<B>],
    blocks: &[Tensor<B, 3>],
) -> Phase1Result<B> {
    let op_count = ops.len();

    if blocks.is_empty() {
        return Phase1Result {
            outputs: Vec::with_capacity(op_count),
            max_logits: Vec::with_capacity(op_count),
            sum_exp: Vec::with_capacity(op_count),
        };
    }

    // Rank-4 value tensor: the blocks laid out along a new axis 0.
    let stacked: Tensor<B, 4> = Tensor::stack(blocks.to_vec(), 0);

    let mut result = Phase1Result {
        outputs: Vec::with_capacity(op_count),
        max_logits: Vec::with_capacity(op_count),
        sum_exp: Vec::with_capacity(op_count),
    };

    for &op in ops {
        let (weighted, peak, denom) = phase1_single_op(op, &stacked);
        result.outputs.push(weighted);
        result.max_logits.push(peak);
        result.sum_exp.push(denom);
    }

    result
}

/// Computes `(weighted_sum, max_logit, sum_exp)` for one op over the stacked blocks.
fn phase1_single_op<B: Backend>(
    op: &AttnResOp<B>,
    stacked: &Tensor<B, 4>,
) -> (Tensor<B, 3>, Tensor<B, 2>, Tensor<B, 2>) {
    let keys = op.norm.forward_4d(stacked.clone());

    // Lift the rank-1 pseudo-query to rank 4 by prepending three unit axes so it
    // broadcasts against the keys.
    let query = op
        .pseudo_query
        .val()
        .unsqueeze_dim::<2>(0)
        .unsqueeze_dim::<3>(0)
        .unsqueeze_dim::<4>(0);

    // Dot product along the last axis, then drop that reduced axis.
    let logits = (keys * query).sum_dim(3).squeeze_dim::<3>(3);

    // Softmax statistics over the block axis (axis 0), shifted by the max for stability.
    let peak = logits.clone().max_dim(0).squeeze_dim::<2>(0);
    let exp_shifted = (logits - peak.clone().unsqueeze_dim::<3>(0)).exp();
    let denom = exp_shifted.clone().sum_dim(0).squeeze_dim::<2>(0);

    // Weight each block by its unnormalised softmax weight and reduce over the blocks.
    let weights = exp_shifted.unsqueeze_dim::<4>(3);
    let weighted = (stacked.clone() * weights)
        .sum_dim(0)
        .squeeze_dim::<3>(0);

    (weighted, peak, denom)
}
/// Divides an unnormalised inter-block output by its softmax denominator.
///
/// `inter_sum_exp` gains a trailing unit axis so that it broadcasts across the
/// last axis of `inter_output`.
pub fn normalize_inter_output<B: Backend>(
    inter_output: Tensor<B, 3>,
    inter_sum_exp: Tensor<B, 2>,
) -> Tensor<B, 3> {
    let denominator = inter_sum_exp.unsqueeze_dim::<3>(2);
    inter_output / denominator
}
/// Merges the inter-block partial softmax with a single intra-block term.
///
/// Both sides are rescaled to a common maximum (the element-wise larger of
/// `inter_max` and `intra_logit`) before being added, and the combined weighted
/// sum is then divided by the combined denominator.
///
/// * `inter_output` / `inter_max` / `inter_sum_exp` — unnormalised weighted sum,
///   shift, and denominator of the inter-block part.
/// * `intra_logit` / `intra_value` — logit and value of the intra-block term.
pub fn online_softmax_merge<B: Backend>(
    inter_output: Tensor<B, 3>,
    inter_max: Tensor<B, 2>,
    inter_sum_exp: Tensor<B, 2>,
    intra_logit: Tensor<B, 2>,
    intra_value: Tensor<B, 3>,
) -> Tensor<B, 3> {
    // Common shift shared by both sides.
    let joint_max = inter_max.clone().max_pair(intra_logit.clone());

    // Correction factors that move each side from its own shift to the common one.
    let inter_rescale = (inter_max - joint_max.clone()).exp();
    let intra_rescale = (intra_logit - joint_max).exp();

    // The intra term contributes exactly one exponential to the denominator.
    let denominator = inter_sum_exp * inter_rescale.clone() + intra_rescale.clone();
    let numerator = inter_output * inter_rescale.unsqueeze_dim::<3>(2)
        + intra_value * intra_rescale.unsqueeze_dim::<3>(2);

    numerator / denominator.unsqueeze_dim::<3>(2)
}
/// Computes the logit of `partial` under `op`.
///
/// `partial` is passed through the op's norm, multiplied element-wise by the
/// op's pseudo-query (lifted to rank 3 with two leading unit axes), and summed
/// over the last axis, which is then dropped from the result.
pub fn compute_intra_logit<B: Backend>(op: &AttnResOp<B>, partial: &Tensor<B, 3>) -> Tensor<B, 2> {
    let query = op
        .pseudo_query
        .val()
        .unsqueeze_dim::<2>(0)
        .unsqueeze_dim::<3>(0);
    let keys = op.norm.forward(partial.clone());

    (keys * query).sum_dim(2).squeeze_dim::<2>(2)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    type TestBackend = NdArray;
    type TestDevice = <TestBackend as Backend>::Device;

    /// Configuration shared by every test in this module.
    fn test_config() -> crate::config::AttnResConfig {
        crate::config::AttnResConfig::new(32, 4, 2)
    }

    /// Draws `count` standard-normal blocks of shape `[1, 8, 32]`.
    fn random_blocks(device: &TestDevice, count: usize) -> Vec<Tensor<TestBackend, 3>> {
        (0..count)
            .map(|_| Tensor::random([1, 8, 32], Distribution::Normal(0.0, 1.0), device))
            .collect()
    }

    /// One entry per op in every result vector, each output shaped like a block.
    #[test]
    fn test_phase1_output_count() {
        let device: TestDevice = Default::default();
        let config = test_config();
        let first_op = config.init_op::<TestBackend>(&device);
        let second_op = config.init_op::<TestBackend>(&device);
        let blocks = random_blocks(&device, 2);

        let result = phase1_batched(&[&first_op, &second_op], &blocks);

        assert_eq!(result.outputs.len(), 2);
        assert_eq!(result.max_logits.len(), 2);
        assert_eq!(result.sum_exp.len(), 2);
        assert_eq!(result.outputs[0].dims(), [1, 8, 32]);
    }

    /// The intra logit drops the last axis of its input.
    #[test]
    fn test_compute_intra_logit_shape() {
        let device: TestDevice = Default::default();
        let op = test_config().init_op::<TestBackend>(&device);
        let partial = random_blocks(&device, 1).remove(0);

        let logit = compute_intra_logit(&op, &partial);

        assert_eq!(logit.dims(), [1, 8]);
    }

    /// Normalising the phase-1 output must reproduce the blocks-only forward pass.
    #[test]
    fn test_phase1_inter_only_matches_blocks_only_forward() {
        let device: TestDevice = Default::default();
        let op = test_config().init_op::<TestBackend>(&device);
        let blocks = random_blocks(&device, 2);

        let standard = op.forward_optional_partial(&blocks, None);

        let phase1 = phase1_batched(&[&op], &blocks);
        let inter_only =
            normalize_inter_output(phase1.outputs[0].clone(), phase1.sum_exp[0].clone());

        let diff: f32 = (standard - inter_only).abs().max().into_scalar();
        assert!(
            diff < 1e-5,
            "Phase 1 inter-only output should match blocks-only forward, diff={diff}"
        );
    }
}