fib_quant/kv/
attention_ref.rs1use crate::{metrics, FibQuantError, Result};
2
3use super::quality::{kl_divergence, topk_agreement, total_variation, KvAttentionQualityReportV1};
4
5pub fn reference_attention_logits(
7 query: &[f32],
8 keys: &[f32],
9 head_dim: usize,
10) -> Result<Vec<f32>> {
11 if head_dim == 0 || query.len() != head_dim || keys.len() % head_dim != 0 {
12 return Err(FibQuantError::CorruptPayload(
13 "invalid attention logit dimensions".into(),
14 ));
15 }
16 check_finite(query)?;
17 check_finite(keys)?;
18 let scale = (head_dim as f64).sqrt();
19 let mut logits = Vec::with_capacity(keys.len() / head_dim);
20 for key in keys.chunks_exact(head_dim) {
21 let dot = query
22 .iter()
23 .zip(key)
24 .map(|(a, b)| f64::from(*a) * f64::from(*b))
25 .sum::<f64>();
26 logits.push((dot / scale) as f32);
27 }
28 Ok(logits)
29}
30
31pub fn reference_value_aggregation(
33 probabilities: &[f32],
34 values: &[f32],
35 head_dim: usize,
36) -> Result<Vec<f32>> {
37 if head_dim == 0
38 || values.len() % head_dim != 0
39 || values.len() / head_dim != probabilities.len()
40 {
41 return Err(FibQuantError::CorruptPayload(
42 "invalid value aggregation dimensions".into(),
43 ));
44 }
45 check_finite(probabilities)?;
46 check_finite(values)?;
47 let mut out = vec![0.0f64; head_dim];
48 for (prob, value) in probabilities.iter().zip(values.chunks_exact(head_dim)) {
49 for (idx, channel) in value.iter().enumerate() {
50 out[idx] += f64::from(*prob) * f64::from(*channel);
51 }
52 }
53 Ok(out.into_iter().map(|value| value as f32).collect())
54}
55
56pub fn compare_attention_fixture(
58 query: &[f32],
59 raw_keys: &[f32],
60 decoded_keys: &[f32],
61 raw_values: &[f32],
62 decoded_values: &[f32],
63 head_dim: usize,
64 top_k: usize,
65) -> Result<KvAttentionQualityReportV1> {
66 let raw_logits = reference_attention_logits(query, raw_keys, head_dim)?;
67 let decoded_logits = reference_attention_logits(query, decoded_keys, head_dim)?;
68 let raw_probs = softmax(&raw_logits)?;
69 let decoded_probs = softmax(&decoded_logits)?;
70 let raw_agg = reference_value_aggregation(&raw_probs, raw_values, head_dim)?;
71 let decoded_agg = reference_value_aggregation(&decoded_probs, decoded_values, head_dim)?;
72 let mut report = KvAttentionQualityReportV1::reconstruction_only(raw_keys, decoded_keys)?;
73 report.key_logit_mse = Some(metrics::mse(&raw_logits, &decoded_logits)?);
74 report.attention_tv = Some(total_variation(&raw_probs, &decoded_probs)?);
75 report.attention_kl = Some(kl_divergence(&raw_probs, &decoded_probs)?);
76 report.topk_attention_agreement = Some(topk_agreement(&raw_logits, &decoded_logits, top_k)?);
77 report.value_aggregation_mse = Some(metrics::mse(&raw_agg, &decoded_agg)?);
78 report.validate()?;
79 Ok(report)
80}
81
82fn softmax(logits: &[f32]) -> Result<Vec<f32>> {
83 if logits.is_empty() {
84 return Err(FibQuantError::ZeroDimension);
85 }
86 check_finite(logits)?;
87 let max = logits
88 .iter()
89 .copied()
90 .fold(f32::NEG_INFINITY, |acc, value| acc.max(value));
91 let mut sum = 0.0f64;
92 let mut out = Vec::with_capacity(logits.len());
93 for value in logits {
94 let exp = f64::from(*value - max).exp();
95 sum += exp;
96 out.push(exp);
97 }
98 if !sum.is_finite() || sum <= 0.0 {
99 return Err(FibQuantError::NumericalFailure(
100 "attention softmax underflow".into(),
101 ));
102 }
103 Ok(out.into_iter().map(|value| (value / sum) as f32).collect())
104}
105
106fn check_finite(values: &[f32]) -> Result<()> {
107 if values.iter().any(|value| !value.is_finite()) {
108 return Err(FibQuantError::CorruptPayload(
109 "attention input contains non-finite value".into(),
110 ));
111 }
112 Ok(())
113}