// fib_quant/kv/attention_ref.rs
1use crate::{metrics, FibQuantError, Result};
2
3use super::quality::{kl_divergence, topk_agreement, total_variation, KvAttentionQualityReportV1};
4
5/// Compute reference attention logits for one query and a flat key matrix.
6pub fn reference_attention_logits(
7    query: &[f32],
8    keys: &[f32],
9    head_dim: usize,
10) -> Result<Vec<f32>> {
11    if head_dim == 0 || query.len() != head_dim || keys.len() % head_dim != 0 {
12        return Err(FibQuantError::CorruptPayload(
13            "invalid attention logit dimensions".into(),
14        ));
15    }
16    check_finite(query)?;
17    check_finite(keys)?;
18    let scale = (head_dim as f64).sqrt();
19    let mut logits = Vec::with_capacity(keys.len() / head_dim);
20    for key in keys.chunks_exact(head_dim) {
21        let dot = query
22            .iter()
23            .zip(key)
24            .map(|(a, b)| f64::from(*a) * f64::from(*b))
25            .sum::<f64>();
26        logits.push((dot / scale) as f32);
27    }
28    Ok(logits)
29}
30
31/// Compute reference value aggregation from attention probabilities.
32pub fn reference_value_aggregation(
33    probabilities: &[f32],
34    values: &[f32],
35    head_dim: usize,
36) -> Result<Vec<f32>> {
37    if head_dim == 0
38        || values.len() % head_dim != 0
39        || values.len() / head_dim != probabilities.len()
40    {
41        return Err(FibQuantError::CorruptPayload(
42            "invalid value aggregation dimensions".into(),
43        ));
44    }
45    check_finite(probabilities)?;
46    check_finite(values)?;
47    let mut out = vec![0.0f64; head_dim];
48    for (prob, value) in probabilities.iter().zip(values.chunks_exact(head_dim)) {
49        for (idx, channel) in value.iter().enumerate() {
50            out[idx] += f64::from(*prob) * f64::from(*channel);
51        }
52    }
53    Ok(out.into_iter().map(|value| value as f32).collect())
54}
55
56/// Compare raw and decoded synthetic attention fixtures.
57pub fn compare_attention_fixture(
58    query: &[f32],
59    raw_keys: &[f32],
60    decoded_keys: &[f32],
61    raw_values: &[f32],
62    decoded_values: &[f32],
63    head_dim: usize,
64    top_k: usize,
65) -> Result<KvAttentionQualityReportV1> {
66    let raw_logits = reference_attention_logits(query, raw_keys, head_dim)?;
67    let decoded_logits = reference_attention_logits(query, decoded_keys, head_dim)?;
68    let raw_probs = softmax(&raw_logits)?;
69    let decoded_probs = softmax(&decoded_logits)?;
70    let raw_agg = reference_value_aggregation(&raw_probs, raw_values, head_dim)?;
71    let decoded_agg = reference_value_aggregation(&decoded_probs, decoded_values, head_dim)?;
72    let mut report = KvAttentionQualityReportV1::reconstruction_only(raw_keys, decoded_keys)?;
73    report.key_logit_mse = Some(metrics::mse(&raw_logits, &decoded_logits)?);
74    report.attention_tv = Some(total_variation(&raw_probs, &decoded_probs)?);
75    report.attention_kl = Some(kl_divergence(&raw_probs, &decoded_probs)?);
76    report.topk_attention_agreement = Some(topk_agreement(&raw_logits, &decoded_logits, top_k)?);
77    report.value_aggregation_mse = Some(metrics::mse(&raw_agg, &decoded_agg)?);
78    report.validate()?;
79    Ok(report)
80}
81
82fn softmax(logits: &[f32]) -> Result<Vec<f32>> {
83    if logits.is_empty() {
84        return Err(FibQuantError::ZeroDimension);
85    }
86    check_finite(logits)?;
87    let max = logits
88        .iter()
89        .copied()
90        .fold(f32::NEG_INFINITY, |acc, value| acc.max(value));
91    let mut sum = 0.0f64;
92    let mut out = Vec::with_capacity(logits.len());
93    for value in logits {
94        let exp = f64::from(*value - max).exp();
95        sum += exp;
96        out.push(exp);
97    }
98    if !sum.is_finite() || sum <= 0.0 {
99        return Err(FibQuantError::NumericalFailure(
100            "attention softmax underflow".into(),
101        ));
102    }
103    Ok(out.into_iter().map(|value| (value / sum) as f32).collect())
104}
105
106fn check_finite(values: &[f32]) -> Result<()> {
107    if values.iter().any(|value| !value.is_finite()) {
108        return Err(FibQuantError::CorruptPayload(
109            "attention input contains non-finite value".into(),
110        ));
111    }
112    Ok(())
113}