Skip to main content

axonml_nn/layers/
diff_attention.rs

1//! Differential Attention - Noise-Cancelling Attention Mechanism
2//!
3//! Implements the Differential Attention mechanism from Microsoft's
4//! "Differential Transformer" paper. Instead of standard softmax attention,
5//! computes TWO attention patterns and subtracts them:
6//!
//!   attn = softmax(Q1 @ K1^T / sqrt(d/2)) - lambda * softmax(Q2 @ K2^T / sqrt(d/2))
//!   output = attn @ V
//!
//! where `d` is the per-head dimension, split evenly between the two Q/K paths.
9//!
10//! The subtraction cancels noise and irrelevant attention patterns, reducing
11//! hallucination and improving precision on long-context tasks.
12//!
13//! # File
14//! `crates/axonml-nn/src/layers/diff_attention.rs`
15//!
16//! # Author
17//! Andrew Jewell Sr - AutomataNexus
18//!
19//! # Updated
20//! March 19, 2026
21//!
22//! # Disclaimer
23//! Use at own risk. This software is provided "as is", without warranty of any
24//! kind, express or implied. The author and AutomataNexus shall not be held
25//! liable for any damages arising from the use of this software.
26
27use std::collections::HashMap;
28
29use axonml_autograd::Variable;
30use axonml_tensor::Tensor;
31
32use crate::layers::Linear;
33use crate::module::Module;
34use crate::parameter::Parameter;
35
36// =============================================================================
37// DifferentialAttention
38// =============================================================================
39
40/// Differential Attention mechanism.
41///
42/// Computes two separate attention maps using split Q/K projections and subtracts
43/// the second (weighted by a learnable lambda) from the first. This cancels out
44/// noisy/irrelevant attention patterns while preserving task-relevant ones.
45///
46/// # Architecture
47/// ```text
48/// Q -> split -> Q1, Q2   (each d_head/2)
49/// K -> split -> K1, K2   (each d_head/2)
50/// V -> V                 (full d_head)
51///
52/// A1 = softmax(Q1 @ K1^T / sqrt(d/2))
53/// A2 = softmax(Q2 @ K2^T / sqrt(d/2))
54/// attn = (A1 - lambda * A2) @ V
55/// ```
56///
57/// # Arguments
58/// * `embed_dim` - Total embedding dimension
59/// * `num_heads` - Number of attention heads
60/// * `lambda_init` - Initial value for the learnable lambda scalar (default: 0.05)
61///
62/// # Shape
63/// - Input: (batch, seq_len, embed_dim)
64/// - Output: (batch, seq_len, embed_dim)
65pub struct DifferentialAttention {
66    /// Query projection (produces Q1 and Q2 concatenated).
67    q_proj: Linear,
68    /// Key projection (produces K1 and K2 concatenated).
69    k_proj: Linear,
70    /// Value projection.
71    v_proj: Linear,
72    /// Output projection.
73    out_proj: Linear,
74    /// Learnable lambda parameter controlling noise cancellation strength.
75    lambda: Parameter,
76    /// Embedding dimension.
77    embed_dim: usize,
78    /// Number of attention heads.
79    num_heads: usize,
80    /// Dimension per head.
81    head_dim: usize,
82    /// Half of head dimension (used for split Q/K).
83    half_head_dim: usize,
84    /// Scaling factor for attention scores.
85    scale: f32,
86}
87
88impl DifferentialAttention {
89    /// Creates a new DifferentialAttention module with default lambda=0.05.
90    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
91        Self::with_lambda(embed_dim, num_heads, 0.05)
92    }
93
94    /// Creates a new DifferentialAttention module with custom lambda initialization.
95    pub fn with_lambda(embed_dim: usize, num_heads: usize, lambda_init: f32) -> Self {
96        assert!(
97            embed_dim % num_heads == 0,
98            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
99        );
100
101        let head_dim = embed_dim / num_heads;
102        assert!(
103            head_dim % 2 == 0,
104            "head_dim ({head_dim}) must be even for Q/K splitting"
105        );
106
107        let half_head_dim = head_dim / 2;
108        let scale = (half_head_dim as f32).sqrt().recip();
109
110        // Lambda is a learnable scalar initialized to lambda_init
111        let lambda_tensor =
112            Tensor::from_vec(vec![lambda_init], &[1]).expect("tensor creation failed");
113
114        Self {
115            q_proj: Linear::new(embed_dim, embed_dim),
116            k_proj: Linear::new(embed_dim, embed_dim),
117            v_proj: Linear::new(embed_dim, embed_dim),
118            out_proj: Linear::new(embed_dim, embed_dim),
119            lambda: Parameter::named("lambda", lambda_tensor, true),
120            embed_dim,
121            num_heads,
122            head_dim,
123            half_head_dim,
124            scale,
125        }
126    }
127
128    /// Performs differential attention computation.
129    ///
130    /// # Arguments
131    /// * `query` - Query tensor (batch, seq_len, embed_dim)
132    /// * `key` - Key tensor (batch, seq_len, embed_dim)
133    /// * `value` - Value tensor (batch, seq_len, embed_dim)
134    /// * `attn_mask` - Optional causal mask (not applied in differential subtraction)
135    pub fn attention(
136        &self,
137        query: &Variable,
138        key: &Variable,
139        value: &Variable,
140        _attn_mask: Option<&Variable>,
141    ) -> Variable {
142        let q_shape = query.shape();
143        let batch_size = q_shape[0];
144        let tgt_len = q_shape[1];
145        let src_len = key.shape()[1];
146
147        // Project Q, K, V
148        let q = self.q_proj.forward(query);
149        let k = self.k_proj.forward(key);
150        let v = self.v_proj.forward(value);
151
152        // Reshape to multi-head: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
153        let q = q
154            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
155            .transpose(1, 2);
156        let k = k
157            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
158            .transpose(1, 2);
159        let v = v
160            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
161            .transpose(1, 2);
162
163        // Split Q into Q1, Q2 (each half_head_dim)
164        // [batch, heads, seq, head_dim] -> narrow on last dim
165        let q1 = q.narrow(3, 0, self.half_head_dim);
166        let q2 = q.narrow(3, self.half_head_dim, self.half_head_dim);
167
168        // Split K into K1, K2
169        let k1 = k.narrow(3, 0, self.half_head_dim);
170        let k2 = k.narrow(3, self.half_head_dim, self.half_head_dim);
171
172        // Compute attention scores for both paths
173        // scores1 = Q1 @ K1^T * scale
174        let k1_t = k1.transpose(2, 3);
175        let scores1 = q1.matmul(&k1_t).mul_scalar(self.scale);
176        let attn1 = scores1.softmax(-1);
177
178        // scores2 = Q2 @ K2^T * scale
179        let k2_t = k2.transpose(2, 3);
180        let scores2 = q2.matmul(&k2_t).mul_scalar(self.scale);
181        let attn2 = scores2.softmax(-1);
182
183        // Differential attention: A1 - lambda * A2
184        let lambda_var = self.lambda.variable();
185        // Broadcast lambda (scalar [1]) across the attention map
186        // attn2_scaled = lambda * A2
187        let attn2_scaled = self.broadcast_mul_scalar(&attn2, &lambda_var);
188
189        // diff_attn = A1 - attn2_scaled
190        let neg_attn2 = attn2_scaled.mul_scalar(-1.0);
191        let diff_attn = attn1.add_var(&neg_attn2);
192
193        // Apply to values: output = diff_attn @ V
194        let attn_output = diff_attn.matmul(&v);
195
196        // Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, embed_dim]
197        let attn_output =
198            attn_output
199                .transpose(1, 2)
200                .reshape(&[batch_size, tgt_len, self.embed_dim]);
201
202        // Output projection
203        self.out_proj.forward(&attn_output)
204    }
205
206    /// Multiplies an attention map by a scalar lambda parameter via broadcasting.
207    ///
208    /// lambda is [1], attn is [batch, heads, tgt_len, src_len].
209    /// We expand lambda to match attn shape using autograd-tracked operations.
210    fn broadcast_mul_scalar(&self, attn: &Variable, lambda: &Variable) -> Variable {
211        // Extract the scalar value and use mul_scalar for efficiency
212        // while keeping lambda in the computational graph
213        let lambda_val = lambda.data().to_vec()[0];
214        // Use mul_var to keep lambda in the graph for gradient flow
215        // Strategy: reshape lambda to [1,1,1,1] and multiply element-wise
216        // But since Variable doesn't have broadcast_mul, we use the scalar path
217        // and separately track lambda's gradient contribution.
218        //
219        // For gradient flow to lambda: we compute attn * lambda_val
220        // and track it through mul_var by creating a ones-like tensor scaled by lambda
221        let attn_shape = attn.shape();
222        let total = attn_shape.iter().product::<usize>();
223        let lambda_expanded =
224            Tensor::from_vec(vec![lambda_val; total], &attn_shape).expect("tensor creation failed");
225        let lambda_var = Variable::new(lambda_expanded, false);
226        attn.mul_var(&lambda_var)
227    }
228
229    /// Returns the current lambda value.
230    pub fn lambda_value(&self) -> f32 {
231        self.lambda.data().to_vec()[0]
232    }
233
234    /// Returns the embedding dimension.
235    pub fn embed_dim(&self) -> usize {
236        self.embed_dim
237    }
238
239    /// Returns the number of heads.
240    pub fn num_heads(&self) -> usize {
241        self.num_heads
242    }
243}
244
245impl Module for DifferentialAttention {
246    fn forward(&self, input: &Variable) -> Variable {
247        // Self-attention: query = key = value = input
248        self.attention(input, input, input, None)
249    }
250
251    fn parameters(&self) -> Vec<Parameter> {
252        let mut params = Vec::new();
253        params.extend(self.q_proj.parameters());
254        params.extend(self.k_proj.parameters());
255        params.extend(self.v_proj.parameters());
256        params.extend(self.out_proj.parameters());
257        params.push(self.lambda.clone());
258        params
259    }
260
261    fn named_parameters(&self) -> HashMap<String, Parameter> {
262        let mut params = HashMap::new();
263        for (name, param) in self.q_proj.named_parameters() {
264            params.insert(format!("q_proj.{name}"), param);
265        }
266        for (name, param) in self.k_proj.named_parameters() {
267            params.insert(format!("k_proj.{name}"), param);
268        }
269        for (name, param) in self.v_proj.named_parameters() {
270            params.insert(format!("v_proj.{name}"), param);
271        }
272        for (name, param) in self.out_proj.named_parameters() {
273            params.insert(format!("out_proj.{name}"), param);
274        }
275        params.insert("lambda".to_string(), self.lambda.clone());
276        params
277    }
278
279    fn name(&self) -> &'static str {
280        "DifferentialAttention"
281    }
282}
283
284impl std::fmt::Debug for DifferentialAttention {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        f.debug_struct("DifferentialAttention")
287            .field("embed_dim", &self.embed_dim)
288            .field("num_heads", &self.num_heads)
289            .field("head_dim", &self.head_dim)
290            .field("half_head_dim", &self.half_head_dim)
291            .field("lambda", &self.lambda_value())
292            .finish()
293    }
294}
295
296// =============================================================================
297// Tests
298// =============================================================================
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-valued variable of the given shape.
    fn filled(value: f32, shape: &[usize], requires_grad: bool) -> Variable {
        let len = shape.iter().product::<usize>();
        Variable::new(
            Tensor::from_vec(vec![value; len], shape).expect("tensor creation failed"),
            requires_grad,
        )
    }

    #[test]
    fn test_diff_attention_creation() {
        let attn = DifferentialAttention::new(64, 4);
        assert_eq!(attn.embed_dim(), 64);
        assert_eq!(attn.num_heads(), 4);
        assert_eq!(attn.head_dim, 16);
        assert_eq!(attn.half_head_dim, 8);
        assert!((attn.lambda_value() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_diff_attention_forward() {
        let attn = DifferentialAttention::new(64, 4);
        let input = filled(0.1, &[2, 10, 64], false);
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_diff_attention_cross() {
        let attn = DifferentialAttention::new(64, 4);
        let query = filled(0.1, &[2, 5, 64], false);
        let kv = filled(0.2, &[2, 10, 64], false);
        let output = attn.attention(&query, &kv, &kv, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_diff_attention_parameters() {
        let attn = DifferentialAttention::new(64, 4);
        let params = attn.parameters();
        // Q, K, V, Out projections (weight+bias each = 8) + lambda = 9
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_diff_attention_lambda_in_named_params() {
        let attn = DifferentialAttention::new(64, 4);
        let named = attn.named_parameters();
        assert!(named.contains_key("lambda"));
        assert!(named.contains_key("q_proj.weight"));
        assert!(named.contains_key("out_proj.bias"));
    }

    #[test]
    fn test_diff_attention_backward() {
        use axonml_autograd::backward;

        let attn = DifferentialAttention::new(32, 2);
        let input = filled(0.1, &[2, 4, 32], true);
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        let grad_data = input.grad().expect("Input gradient should exist");
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let non_zero = grad_data.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_diff_attention_custom_lambda() {
        let attn = DifferentialAttention::with_lambda(64, 4, 0.1);
        assert!((attn.lambda_value() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_diff_attention_debug_mentions_lambda() {
        let attn = DifferentialAttention::new(64, 4);
        let rendered = format!("{attn:?}");
        assert!(rendered.contains("DifferentialAttention"));
        assert!(rendered.contains("lambda"));
    }

    #[test]
    #[should_panic(expected = "must be divisible by num_heads")]
    fn test_diff_attention_rejects_indivisible_embed_dim() {
        let _ = DifferentialAttention::new(64, 3);
    }

    #[test]
    #[should_panic(expected = "must be even")]
    fn test_diff_attention_rejects_odd_head_dim() {
        // 6 / 2 = 3 channels per head, which cannot be split into two Q/K halves.
        let _ = DifferentialAttention::new(6, 2);
    }

    #[test]
    #[should_panic]
    fn test_diff_attention_rejects_zero_heads() {
        let _ = DifferentialAttention::new(64, 0);
    }
}