//! Differential attention (cf. the Differential Transformer): each head
//! computes two softmax attention maps from split halves of Q and K, then
//! subtracts a learnable lambda-scaled copy of the second map from the first
//! to cancel common-mode attention noise.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

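/// Multi-head differential attention layer.
///
/// A minimal usage sketch, assuming the constructor and shapes exercised by
/// the tests at the bottom of this file (crate paths omitted, hence `ignore`):
///
/// ```ignore
/// // 64-dim embeddings across 4 heads => head_dim = 16, half_head_dim = 8.
/// let attn = DifferentialAttention::new(64, 4);
/// let x = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// // Self-attention keeps the input shape [batch, seq, embed_dim].
/// let y = attn.attention(&x, &x, &x, None);
/// assert_eq!(y.shape(), vec![2, 10, 64]);
/// ```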
pub struct DifferentialAttention {
    /// Query projection.
    q_proj: Linear,
    /// Key projection.
    k_proj: Linear,
    /// Value projection.
    v_proj: Linear,
    /// Output projection.
    out_proj: Linear,
    /// Learnable subtraction weight for the second attention map.
    lambda: Parameter,
    /// Total embedding dimension.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension per head (`embed_dim / num_heads`).
    head_dim: usize,
    /// Half of `head_dim`; the width of each Q/K split.
    half_head_dim: usize,
    /// Softmax scaling factor, `1 / sqrt(half_head_dim)`.
    scale: f32,
}

impl DifferentialAttention {
    /// Creates a layer with the default `lambda` initialization of 0.05.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_lambda(embed_dim, num_heads, 0.05)
    }

    /// Creates a layer with an explicit initial `lambda`.
    ///
    /// Panics if `embed_dim` is not divisible by `num_heads`, or if the
    /// resulting head dimension is odd (each head's Q and K are split in two).
    pub fn with_lambda(embed_dim: usize, num_heads: usize, lambda_init: f32) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        );

        let head_dim = embed_dim / num_heads;
        assert!(
            head_dim % 2 == 0,
            "head_dim ({head_dim}) must be even for Q/K splitting"
        );

        let half_head_dim = head_dim / 2;
        let scale = (half_head_dim as f32).sqrt().recip();

        let lambda_tensor =
            Tensor::from_vec(vec![lambda_init], &[1]).expect("tensor creation failed");

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            lambda: Parameter::named("lambda", lambda_tensor, true),
            embed_dim,
            num_heads,
            head_dim,
            half_head_dim,
            scale,
        }
    }

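    /// Applies differential attention to separate query/key/value inputs.
    ///
    /// Per head, Q and K are split into two halves and two softmax attention
    /// maps are formed, then combined as
    ///
    /// ```text
    /// out = (softmax(Q1 * K1^T * scale) - lambda * softmax(Q2 * K2^T * scale)) * V
    /// ```
    ///
    /// where `scale = 1 / sqrt(half_head_dim)`. Note that `_attn_mask` is
    /// accepted for API symmetry but not yet applied.
    ///
    /// A cross-attention sketch (shapes follow the tests below; crate paths
    /// omitted, hence `ignore`):
    ///
    /// ```ignore
    /// let attn = DifferentialAttention::new(64, 4);
    /// let q = Variable::new(
    ///     Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
    ///     false,
    /// );
    /// let kv = Variable::new(
    ///     Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
    ///     false,
    /// );
    /// // Output length follows the query: [2, 5, 64].
    /// let out = attn.attention(&q, &kv, &kv, None);
    /// ```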
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        _attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        let batch_size = q_shape[0];
        let tgt_len = q_shape[1];
        let src_len = key.shape()[1];

        // Project inputs to Q/K/V.
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Reshape to [batch, heads, seq, head_dim].
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // Split Q and K into halves along the head dimension.
        let q1 = q.narrow(3, 0, self.half_head_dim);
        let q2 = q.narrow(3, self.half_head_dim, self.half_head_dim);

        let k1 = k.narrow(3, 0, self.half_head_dim);
        let k2 = k.narrow(3, self.half_head_dim, self.half_head_dim);

        // First attention map: softmax(Q1 * K1^T * scale).
        let k1_t = k1.transpose(2, 3);
        let scores1 = q1.matmul(&k1_t).mul_scalar(self.scale);
        let attn1 = scores1.softmax(-1);

        // Second attention map: softmax(Q2 * K2^T * scale).
        let k2_t = k2.transpose(2, 3);
        let scores2 = q2.matmul(&k2_t).mul_scalar(self.scale);
        let attn2 = scores2.softmax(-1);

        // Differential step: attn1 - lambda * attn2.
        let lambda_var = self.lambda.variable();
        let attn2_scaled = self.broadcast_mul_scalar(&attn2, &lambda_var);

        let neg_attn2 = attn2_scaled.mul_scalar(-1.0);
        let diff_attn = attn1.add_var(&neg_attn2);

        let attn_output = diff_attn.matmul(&v);

        // Merge heads back to [batch, tgt_len, embed_dim] and project out.
        let attn_output =
            attn_output
                .transpose(1, 2)
                .reshape(&[batch_size, tgt_len, self.embed_dim]);

        self.out_proj.forward(&attn_output)
    }
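    /// Multiplies `attn` elementwise by the scalar held in `lambda`.
    ///
    /// Note: the scalar is read out by value and re-wrapped in a fresh
    /// `Variable` with gradients disabled, so gradients do not currently flow
    /// back to `lambda` through this path.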
    fn broadcast_mul_scalar(&self, attn: &Variable, lambda: &Variable) -> Variable {
        let lambda_val = lambda.data().to_vec()[0];
        let attn_shape = attn.shape();
        let total = attn_shape.iter().product::<usize>();
        let lambda_expanded =
            Tensor::from_vec(vec![lambda_val; total], &attn_shape).expect("tensor creation failed");
        let lambda_var = Variable::new(lambda_expanded, false);
        attn.mul_var(&lambda_var)
    }

    /// Returns the current scalar value of `lambda`.
    pub fn lambda_value(&self) -> f32 {
        self.lambda.data().to_vec()[0]
    }

    /// Returns the total embedding dimension.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Returns the number of attention heads.
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }
}
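/// `Module` implementation: `forward` performs self-attention (the input is
/// used as query, key, and value).
///
/// A parameter-inspection sketch, assuming only the accessors shown in this
/// file (crate paths omitted, hence `ignore`):
///
/// ```ignore
/// let attn = DifferentialAttention::new(64, 4);
/// // Four Linear layers (weight + bias each) plus lambda = 9 parameters.
/// assert_eq!(attn.parameters().len(), 9);
/// let named = attn.named_parameters();
/// assert!(named.contains_key("q_proj.weight"));
/// assert!(named.contains_key("lambda"));
/// ```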
impl Module for DifferentialAttention {
    fn forward(&self, input: &Variable) -> Variable {
        self.attention(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params.push(self.lambda.clone());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.q_proj.named_parameters() {
            params.insert(format!("q_proj.{name}"), param);
        }
        for (name, param) in self.k_proj.named_parameters() {
            params.insert(format!("k_proj.{name}"), param);
        }
        for (name, param) in self.v_proj.named_parameters() {
            params.insert(format!("v_proj.{name}"), param);
        }
        for (name, param) in self.out_proj.named_parameters() {
            params.insert(format!("out_proj.{name}"), param);
        }
        params.insert("lambda".to_string(), self.lambda.clone());
        params
    }

    fn name(&self) -> &'static str {
        "DifferentialAttention"
    }
}

impl std::fmt::Debug for DifferentialAttention {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DifferentialAttention")
            .field("embed_dim", &self.embed_dim)
            .field("num_heads", &self.num_heads)
            .field("head_dim", &self.head_dim)
            .field("half_head_dim", &self.half_head_dim)
            .field("lambda", &self.lambda_value())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diff_attention_creation() {
        let attn = DifferentialAttention::new(64, 4);
        assert_eq!(attn.embed_dim(), 64);
        assert_eq!(attn.num_heads(), 4);
        assert_eq!(attn.head_dim, 16);
        assert_eq!(attn.half_head_dim, 8);
        assert!((attn.lambda_value() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_diff_attention_forward() {
        let attn = DifferentialAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_diff_attention_cross() {
        let attn = DifferentialAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let kv = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = attn.attention(&query, &kv, &kv, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_diff_attention_parameters() {
        let attn = DifferentialAttention::new(64, 4);
        let params = attn.parameters();
        // Four Linear layers contribute weight + bias each, plus lambda.
        assert_eq!(params.len(), 9);
    }
347
348 #[test]
349 fn test_diff_attention_lambda_in_named_params() {
350 let attn = DifferentialAttention::new(64, 4);
351 let named = attn.named_parameters();
352 assert!(named.contains_key("lambda"));
353 assert!(named.contains_key("q_proj.weight"));
354 assert!(named.contains_key("out_proj.bias"));
355 }
356
357 #[test]
358 fn test_diff_attention_backward() {
359 use axonml_autograd::backward;
360
361 let attn = DifferentialAttention::new(32, 2);
362 let input = Variable::new(
363 Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
364 true,
365 );
366 let output = attn.forward(&input);
367 assert_eq!(output.shape(), vec![2, 4, 32]);
368
369 let loss = output.sum();
370 let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
371 backward(&loss, &ones);
372
373 let grad = input.grad();
374 assert!(grad.is_some(), "Input gradient should exist");
375 let grad_data = grad.unwrap();
376 assert_eq!(grad_data.shape(), &[2, 4, 32]);
377
378 let grad_vec = grad_data.to_vec();
379 let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
380 assert!(non_zero, "Gradients should be non-zero");
381 }
382
383 #[test]
384 fn test_diff_attention_custom_lambda() {
385 let attn = DifferentialAttention::with_lambda(64, 4, 0.1);
386 assert!((attn.lambda_value() - 0.1).abs() < 1e-6);
387 }
388}