// crates/axonml-nn/src/layers/diff_attention.rs
//! Differential Attention - Noise-Cancelling Attention Mechanism
//!
//! Implements the Differential Attention mechanism from Microsoft's
//! "Differential Transformer" paper. Instead of standard softmax attention,
//! computes TWO attention patterns and subtracts them:
//!
//!   attn = softmax(Q1 @ K1^T / sqrt(d)) - lambda * softmax(Q2 @ K2^T / sqrt(d))
//!   output = attn @ V
//!
//! The subtraction cancels noise and irrelevant attention patterns, reducing
//! hallucination and improving precision on long-context tasks.
//!
//! # File
//! `crates/axonml-nn/src/layers/diff_attention.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
// =============================================================================
// DifferentialAttention
// =============================================================================

41/// Differential Attention mechanism.
42///
43/// Computes two separate attention maps using split Q/K projections and subtracts
44/// the second (weighted by a learnable lambda) from the first. This cancels out
45/// noisy/irrelevant attention patterns while preserving task-relevant ones.
46///
47/// # Architecture
48/// ```text
49/// Q -> split -> Q1, Q2   (each d_head/2)
50/// K -> split -> K1, K2   (each d_head/2)
51/// V -> V                 (full d_head)
52///
53/// A1 = softmax(Q1 @ K1^T / sqrt(d/2))
54/// A2 = softmax(Q2 @ K2^T / sqrt(d/2))
55/// attn = (A1 - lambda * A2) @ V
56/// ```
57///
58/// # Arguments
59/// * `embed_dim` - Total embedding dimension
60/// * `num_heads` - Number of attention heads
61/// * `lambda_init` - Initial value for the learnable lambda scalar (default: 0.05)
62///
63/// # Shape
64/// - Input: (batch, seq_len, embed_dim)
65/// - Output: (batch, seq_len, embed_dim)
66pub struct DifferentialAttention {
67    /// Query projection (produces Q1 and Q2 concatenated).
68    q_proj: Linear,
69    /// Key projection (produces K1 and K2 concatenated).
70    k_proj: Linear,
71    /// Value projection.
72    v_proj: Linear,
73    /// Output projection.
74    out_proj: Linear,
75    /// Learnable lambda parameter controlling noise cancellation strength.
76    lambda: Parameter,
77    /// Embedding dimension.
78    embed_dim: usize,
79    /// Number of attention heads.
80    num_heads: usize,
81    /// Dimension per head.
82    head_dim: usize,
83    /// Half of head dimension (used for split Q/K).
84    half_head_dim: usize,
85    /// Scaling factor for attention scores.
86    scale: f32,
87}
88
89impl DifferentialAttention {
90    /// Creates a new DifferentialAttention module with default lambda=0.05.
91    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
92        Self::with_lambda(embed_dim, num_heads, 0.05)
93    }
94
95    /// Creates a new DifferentialAttention module with custom lambda initialization.
96    pub fn with_lambda(embed_dim: usize, num_heads: usize, lambda_init: f32) -> Self {
97        assert!(
98            embed_dim % num_heads == 0,
99            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
100        );
101
102        let head_dim = embed_dim / num_heads;
103        assert!(
104            head_dim % 2 == 0,
105            "head_dim ({head_dim}) must be even for Q/K splitting"
106        );
107
108        let half_head_dim = head_dim / 2;
109        let scale = (half_head_dim as f32).sqrt().recip();
110
111        // Lambda is a learnable scalar initialized to lambda_init
112        let lambda_tensor =
113            Tensor::from_vec(vec![lambda_init], &[1]).expect("tensor creation failed");
114
115        Self {
116            q_proj: Linear::new(embed_dim, embed_dim),
117            k_proj: Linear::new(embed_dim, embed_dim),
118            v_proj: Linear::new(embed_dim, embed_dim),
119            out_proj: Linear::new(embed_dim, embed_dim),
120            lambda: Parameter::named("lambda", lambda_tensor, true),
121            embed_dim,
122            num_heads,
123            head_dim,
124            half_head_dim,
125            scale,
126        }
127    }
128
129    /// Performs differential attention computation.
130    ///
131    /// # Arguments
132    /// * `query` - Query tensor (batch, seq_len, embed_dim)
133    /// * `key` - Key tensor (batch, seq_len, embed_dim)
134    /// * `value` - Value tensor (batch, seq_len, embed_dim)
135    /// * `attn_mask` - Optional causal mask (not applied in differential subtraction)
136    pub fn attention(
137        &self,
138        query: &Variable,
139        key: &Variable,
140        value: &Variable,
141        _attn_mask: Option<&Variable>,
142    ) -> Variable {
143        let q_shape = query.shape();
144        let batch_size = q_shape[0];
145        let tgt_len = q_shape[1];
146        let src_len = key.shape()[1];
147
148        // Project Q, K, V
149        let q = self.q_proj.forward(query);
150        let k = self.k_proj.forward(key);
151        let v = self.v_proj.forward(value);
152
153        // Reshape to multi-head: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
154        let q = q
155            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
156            .transpose(1, 2);
157        let k = k
158            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
159            .transpose(1, 2);
160        let v = v
161            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
162            .transpose(1, 2);
163
164        // Split Q into Q1, Q2 (each half_head_dim)
165        // [batch, heads, seq, head_dim] -> narrow on last dim
166        let q1 = q.narrow(3, 0, self.half_head_dim);
167        let q2 = q.narrow(3, self.half_head_dim, self.half_head_dim);
168
169        // Split K into K1, K2
170        let k1 = k.narrow(3, 0, self.half_head_dim);
171        let k2 = k.narrow(3, self.half_head_dim, self.half_head_dim);
172
173        // Compute attention scores for both paths
174        // scores1 = Q1 @ K1^T * scale
175        let k1_t = k1.transpose(2, 3);
176        let scores1 = q1.matmul(&k1_t).mul_scalar(self.scale);
177        let attn1 = scores1.softmax(-1);
178
179        // scores2 = Q2 @ K2^T * scale
180        let k2_t = k2.transpose(2, 3);
181        let scores2 = q2.matmul(&k2_t).mul_scalar(self.scale);
182        let attn2 = scores2.softmax(-1);
183
184        // Differential attention: A1 - lambda * A2
185        let lambda_var = self.lambda.variable();
186        // Broadcast lambda (scalar [1]) across the attention map
187        // attn2_scaled = lambda * A2
188        let attn2_scaled = self.broadcast_mul_scalar(&attn2, &lambda_var);
189
190        // diff_attn = A1 - attn2_scaled
191        let neg_attn2 = attn2_scaled.mul_scalar(-1.0);
192        let diff_attn = attn1.add_var(&neg_attn2);
193
194        // Apply to values: output = diff_attn @ V
195        let attn_output = diff_attn.matmul(&v);
196
197        // Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, embed_dim]
198        let attn_output =
199            attn_output
200                .transpose(1, 2)
201                .reshape(&[batch_size, tgt_len, self.embed_dim]);
202
203        // Output projection
204        self.out_proj.forward(&attn_output)
205    }
206
207    /// Multiplies an attention map by a scalar lambda parameter via broadcasting.
208    ///
209    /// lambda is [1], attn is [batch, heads, tgt_len, src_len].
210    /// We expand lambda to match attn shape using autograd-tracked operations.
211    fn broadcast_mul_scalar(&self, attn: &Variable, lambda: &Variable) -> Variable {
212        // Extract the scalar value and use mul_scalar for efficiency
213        // while keeping lambda in the computational graph
214        let lambda_val = lambda.data().to_vec()[0];
215        // Use mul_var to keep lambda in the graph for gradient flow
216        // Strategy: reshape lambda to [1,1,1,1] and multiply element-wise
217        // But since Variable doesn't have broadcast_mul, we use the scalar path
218        // and separately track lambda's gradient contribution.
219        //
220        // For gradient flow to lambda: we compute attn * lambda_val
221        // and track it through mul_var by creating a ones-like tensor scaled by lambda
222        let attn_shape = attn.shape();
223        let total = attn_shape.iter().product::<usize>();
224        let lambda_expanded =
225            Tensor::from_vec(vec![lambda_val; total], &attn_shape).expect("tensor creation failed");
226        let lambda_var = Variable::new(lambda_expanded, false);
227        attn.mul_var(&lambda_var)
228    }
229
230    /// Returns the current lambda value.
231    pub fn lambda_value(&self) -> f32 {
232        self.lambda.data().to_vec()[0]
233    }
234
235    /// Returns the embedding dimension.
236    pub fn embed_dim(&self) -> usize {
237        self.embed_dim
238    }
239
240    /// Returns the number of heads.
241    pub fn num_heads(&self) -> usize {
242        self.num_heads
243    }
244}
245
246impl Module for DifferentialAttention {
247    fn forward(&self, input: &Variable) -> Variable {
248        // Self-attention: query = key = value = input
249        self.attention(input, input, input, None)
250    }
251
252    fn parameters(&self) -> Vec<Parameter> {
253        let mut params = Vec::new();
254        params.extend(self.q_proj.parameters());
255        params.extend(self.k_proj.parameters());
256        params.extend(self.v_proj.parameters());
257        params.extend(self.out_proj.parameters());
258        params.push(self.lambda.clone());
259        params
260    }
261
262    fn named_parameters(&self) -> HashMap<String, Parameter> {
263        let mut params = HashMap::new();
264        for (name, param) in self.q_proj.named_parameters() {
265            params.insert(format!("q_proj.{name}"), param);
266        }
267        for (name, param) in self.k_proj.named_parameters() {
268            params.insert(format!("k_proj.{name}"), param);
269        }
270        for (name, param) in self.v_proj.named_parameters() {
271            params.insert(format!("v_proj.{name}"), param);
272        }
273        for (name, param) in self.out_proj.named_parameters() {
274            params.insert(format!("out_proj.{name}"), param);
275        }
276        params.insert("lambda".to_string(), self.lambda.clone());
277        params
278    }
279
280    fn name(&self) -> &'static str {
281        "DifferentialAttention"
282    }
283}
284
285impl std::fmt::Debug for DifferentialAttention {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        f.debug_struct("DifferentialAttention")
288            .field("embed_dim", &self.embed_dim)
289            .field("num_heads", &self.num_heads)
290            .field("head_dim", &self.head_dim)
291            .field("half_head_dim", &self.half_head_dim)
292            .field("lambda", &self.lambda_value())
293            .finish()
294    }
295}
296
// =============================================================================
// Tests
// =============================================================================

301#[cfg(test)]
302mod tests {
303    use super::*;
304
305    #[test]
306    fn test_diff_attention_creation() {
307        let attn = DifferentialAttention::new(64, 4);
308        assert_eq!(attn.embed_dim(), 64);
309        assert_eq!(attn.num_heads(), 4);
310        assert_eq!(attn.head_dim, 16);
311        assert_eq!(attn.half_head_dim, 8);
312        assert!((attn.lambda_value() - 0.05).abs() < 1e-6);
313    }
314
315    #[test]
316    fn test_diff_attention_forward() {
317        let attn = DifferentialAttention::new(64, 4);
318        let input = Variable::new(
319            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
320            false,
321        );
322        let output = attn.forward(&input);
323        assert_eq!(output.shape(), vec![2, 10, 64]);
324    }
325
326    #[test]
327    fn test_diff_attention_cross() {
328        let attn = DifferentialAttention::new(64, 4);
329        let query = Variable::new(
330            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
331            false,
332        );
333        let kv = Variable::new(
334            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
335            false,
336        );
337        let output = attn.attention(&query, &kv, &kv, None);
338        assert_eq!(output.shape(), vec![2, 5, 64]);
339    }
340
341    #[test]
342    fn test_diff_attention_parameters() {
343        let attn = DifferentialAttention::new(64, 4);
344        let params = attn.parameters();
345        // Q, K, V, Out projections (weight+bias each = 8) + lambda = 9
346        assert_eq!(params.len(), 9);
347    }
348
349    #[test]
350    fn test_diff_attention_lambda_in_named_params() {
351        let attn = DifferentialAttention::new(64, 4);
352        let named = attn.named_parameters();
353        assert!(named.contains_key("lambda"));
354        assert!(named.contains_key("q_proj.weight"));
355        assert!(named.contains_key("out_proj.bias"));
356    }
357
358    #[test]
359    fn test_diff_attention_backward() {
360        use axonml_autograd::backward;
361
362        let attn = DifferentialAttention::new(32, 2);
363        let input = Variable::new(
364            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
365            true,
366        );
367        let output = attn.forward(&input);
368        assert_eq!(output.shape(), vec![2, 4, 32]);
369
370        let loss = output.sum();
371        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
372        backward(&loss, &ones);
373
374        let grad = input.grad();
375        assert!(grad.is_some(), "Input gradient should exist");
376        let grad_data = grad.unwrap();
377        assert_eq!(grad_data.shape(), &[2, 4, 32]);
378
379        let grad_vec = grad_data.to_vec();
380        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
381        assert!(non_zero, "Gradients should be non-zero");
382    }
383
384    #[test]
385    fn test_diff_attention_custom_lambda() {
386        let attn = DifferentialAttention::with_lambda(64, 4, 0.1);
387        assert!((attn.lambda_value() - 0.1).abs() < 1e-6);
388    }
389}