//! Differential attention: attention scores are formed as the difference of two
//! softmax maps built from split query/key halves, with the second map scaled by
//! a learnable `lambda`.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

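/// Multi-head differential attention.
///
/// Queries and keys are split into two halves per head, producing two softmax
/// attention maps; the second map, scaled by a learnable `lambda`, is
/// subtracted from the first before being applied to the values, and the
/// merged heads are passed through an output projection.
///
/// # Examples
///
/// A minimal self-attention sketch, mirroring the unit tests at the bottom of
/// this file. The re-export path in the `use` line is an assumption, so the
/// example is not compiled as a doc test.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
/// // Assumed re-export path for this crate's types:
/// use axonml_nn::{DifferentialAttention, Module};
///
/// let attn = DifferentialAttention::new(64, 4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
///     false,
/// );
/// // `forward` comes from the `Module` trait and performs self-attention.
/// let output = attn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 10, 64]);
/// ```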
pub struct DifferentialAttention {
    /// Query projection.
    q_proj: Linear,
    /// Key projection.
    k_proj: Linear,
    /// Value projection.
    v_proj: Linear,
    /// Output projection applied after the heads are merged.
    out_proj: Linear,
    /// Learnable scale applied to the second attention map.
    lambda: Parameter,
    /// Total embedding dimension.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Per-head dimension (`embed_dim / num_heads`).
    head_dim: usize,
    /// Half of `head_dim`; size of each Q/K split.
    half_head_dim: usize,
    /// Attention score scale, `1 / sqrt(half_head_dim)`.
    scale: f32,
}

impl DifferentialAttention {
    /// Creates a differential attention layer with the default initial
    /// `lambda` of 0.05.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_lambda(embed_dim, num_heads, 0.05)
    }

    /// Creates a differential attention layer with a custom initial `lambda`.
    ///
    /// # Panics
    ///
    /// Panics if `embed_dim` is not divisible by `num_heads`, or if the
    /// resulting head dimension is odd (each head must split into two halves).
    pub fn with_lambda(embed_dim: usize, num_heads: usize, lambda_init: f32) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        );

        let head_dim = embed_dim / num_heads;
        assert!(
            head_dim % 2 == 0,
            "head_dim ({head_dim}) must be even for Q/K splitting"
        );

        let half_head_dim = head_dim / 2;
        let scale = (half_head_dim as f32).sqrt().recip();

        let lambda_tensor =
            Tensor::from_vec(vec![lambda_init], &[1]).expect("tensor creation failed");

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            lambda: Parameter::named("lambda", lambda_tensor, true),
            embed_dim,
            num_heads,
            head_dim,
            half_head_dim,
            scale,
        }
    }

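    /// Computes differential attention for separate `query`, `key`, and
    /// `value` inputs (cross-attention when they differ, self-attention when
    /// they are the same).
    ///
    /// Shapes: `query` is `[batch, tgt_len, embed_dim]`, `key` and `value` are
    /// `[batch, src_len, embed_dim]`; the result is `[batch, tgt_len, embed_dim]`.
    ///
    /// Per head, queries and keys are split into two halves along the feature
    /// axis and the output is
    /// `(softmax(Q1 * K1^T * scale) - lambda * softmax(Q2 * K2^T * scale)) * V`,
    /// with `scale = 1 / sqrt(half_head_dim)`.
    ///
    /// The `_attn_mask` argument is currently accepted but not applied.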
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        _attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        let batch_size = q_shape[0];
        let tgt_len = q_shape[1];
        let src_len = key.shape()[1];

        // Project the inputs into query, key, and value spaces.
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Reshape to [batch, num_heads, seq_len, head_dim].
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // Split queries and keys into two halves along the head dimension.
        let q1 = q.narrow(3, 0, self.half_head_dim);
        let q2 = q.narrow(3, self.half_head_dim, self.half_head_dim);

        let k1 = k.narrow(3, 0, self.half_head_dim);
        let k2 = k.narrow(3, self.half_head_dim, self.half_head_dim);

        // First attention map: softmax(Q1 K1^T * scale).
        let k1_t = k1.transpose(2, 3);
        let scores1 = q1.matmul(&k1_t).mul_scalar(self.scale);
        let attn1 = scores1.softmax(-1);

        // Second attention map: softmax(Q2 K2^T * scale).
        let k2_t = k2.transpose(2, 3);
        let scores2 = q2.matmul(&k2_t).mul_scalar(self.scale);
        let attn2 = scores2.softmax(-1);

        // Differential attention: attn1 - lambda * attn2.
        let lambda_var = self.lambda.variable();
        let attn2_scaled = self.broadcast_mul_scalar(&attn2, &lambda_var);

        let neg_attn2 = attn2_scaled.mul_scalar(-1.0);
        let diff_attn = attn1.add_var(&neg_attn2);

        let attn_output = diff_attn.matmul(&v);

        // Merge the heads back to [batch, tgt_len, embed_dim] and project out.
        let attn_output = attn_output
            .transpose(1, 2)
            .reshape(&[batch_size, tgt_len, self.embed_dim]);

        self.out_proj.forward(&attn_output)
    }

    /// Multiplies every element of `attn` by the scalar value stored in
    /// `lambda` by expanding that scalar to the shape of `attn`.
    ///
    /// Note: `lambda_var` is built with `requires_grad = false`, so this path
    /// does not propagate gradients back to the `lambda` parameter.
    fn broadcast_mul_scalar(&self, attn: &Variable, lambda: &Variable) -> Variable {
        let lambda_val = lambda.data().to_vec()[0];
        let attn_shape = attn.shape();
        let total = attn_shape.iter().product::<usize>();
        let lambda_expanded =
            Tensor::from_vec(vec![lambda_val; total], &attn_shape).expect("tensor creation failed");
        let lambda_var = Variable::new(lambda_expanded, false);
        attn.mul_var(&lambda_var)
    }

    /// Returns the current value of `lambda`.
    pub fn lambda_value(&self) -> f32 {
        self.lambda.data().to_vec()[0]
    }

    /// Returns the embedding dimension.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Returns the number of attention heads.
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }
}

impl Module for DifferentialAttention {
    fn forward(&self, input: &Variable) -> Variable {
        // Self-attention: query, key, and value all come from the same input.
        self.attention(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params.push(self.lambda.clone());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.q_proj.named_parameters() {
            params.insert(format!("q_proj.{name}"), param);
        }
        for (name, param) in self.k_proj.named_parameters() {
            params.insert(format!("k_proj.{name}"), param);
        }
        for (name, param) in self.v_proj.named_parameters() {
            params.insert(format!("v_proj.{name}"), param);
        }
        for (name, param) in self.out_proj.named_parameters() {
            params.insert(format!("out_proj.{name}"), param);
        }
        params.insert("lambda".to_string(), self.lambda.clone());
        params
    }

    fn name(&self) -> &'static str {
        "DifferentialAttention"
    }
}

impl std::fmt::Debug for DifferentialAttention {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DifferentialAttention")
            .field("embed_dim", &self.embed_dim)
            .field("num_heads", &self.num_heads)
            .field("head_dim", &self.head_dim)
            .field("half_head_dim", &self.half_head_dim)
            .field("lambda", &self.lambda_value())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diff_attention_creation() {
        let attn = DifferentialAttention::new(64, 4);
        assert_eq!(attn.embed_dim(), 64);
        assert_eq!(attn.num_heads(), 4);
        assert_eq!(attn.head_dim, 16);
        assert_eq!(attn.half_head_dim, 8);
        assert!((attn.lambda_value() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_diff_attention_forward() {
        let attn = DifferentialAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_diff_attention_cross() {
        let attn = DifferentialAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let kv = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = attn.attention(&query, &kv, &kv, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_diff_attention_parameters() {
        let attn = DifferentialAttention::new(64, 4);
        let params = attn.parameters();
        // 4 Linear layers x (weight + bias) + lambda = 9 parameters.
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_diff_attention_lambda_in_named_params() {
        let attn = DifferentialAttention::new(64, 4);
        let named = attn.named_parameters();
        assert!(named.contains_key("lambda"));
        assert!(named.contains_key("q_proj.weight"));
        assert!(named.contains_key("out_proj.bias"));
    }

    #[test]
    fn test_diff_attention_backward() {
        use axonml_autograd::backward;

        let attn = DifferentialAttention::new(32, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        let loss = output.sum();
        // Seed the backward pass with a gradient of 1.0 for the scalar loss.
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_diff_attention_custom_lambda() {
        let attn = DifferentialAttention::with_lambda(64, 4, 0.1);
        assert!((attn.lambda_value() - 0.1).abs() < 1e-6);
    }
}