1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20#[cfg(feature = "cuda")]
21use axonml_autograd::functions::FusedAttentionBackward;
22#[cfg(feature = "cuda")]
23use axonml_autograd::grad_fn::GradFn;
24use axonml_tensor::Tensor;
25
26use crate::layers::Linear;
27use crate::module::Module;
28use crate::parameter::Parameter;
29
/// Multi-head scaled dot-product attention.
///
/// Four independent `embed_dim -> embed_dim` linear projections produce
/// Q, K, V and the output; attention is computed per head and the heads
/// are concatenated back before the output projection.
pub struct MultiHeadAttention {
    // Query projection (embed_dim -> embed_dim).
    q_proj: Linear,
    // Key projection (embed_dim -> embed_dim).
    k_proj: Linear,
    // Value projection (embed_dim -> embed_dim).
    v_proj: Linear,
    // Output projection applied after head concatenation.
    out_proj: Linear,
    // Total model dimension.
    embed_dim: usize,
    // Number of attention heads.
    num_heads: usize,
    // Per-head dimension: embed_dim / num_heads.
    head_dim: usize,
    // Score scaling factor: 1 / sqrt(head_dim).
    scale: f32,
    // If true, tensors are (batch, seq, embed); otherwise (seq, batch, embed).
    batch_first: bool,
}
69
impl MultiHeadAttention {
    /// Creates a layer with no dropout and batch-first inputs.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

    /// Full constructor.
    ///
    /// * `embed_dim` – model dimension; must be divisible by `num_heads`.
    /// * `num_heads` – number of attention heads.
    /// * `_dropout` – accepted for API compatibility but currently ignored
    ///   (no dropout is applied anywhere in this layer).
    /// * `batch_first` – input layout flag; see the struct docs.
    ///
    /// # Panics
    /// Panics if `embed_dim` is not divisible by `num_heads`.
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        // 1 / sqrt(head_dim), the standard attention score scaling.
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

    /// Broadcasts `mask` to `[batch_size, num_heads, tgt_len, src_len]`.
    ///
    /// Currently unused (see `#[allow(dead_code)]`); the `attention` method
    /// expands masks inline instead.
    #[allow(dead_code)]
    fn expand_mask(
        mask: &Variable,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Variable {
        let mask_shape = mask.shape();
        let target = [batch_size, num_heads, tgt_len, src_len];

        // Already the right shape: nothing to do.
        if mask_shape == target {
            return mask.clone();
        }

        // 2-D mask is treated as (tgt_len, src_len) and broadcast over
        // batch and heads.
        if mask_shape.len() == 2 {
            let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
            return reshaped.expand(&target);
        }

        // 4-D mask with a singleton head dimension.
        if mask_shape.len() == 4 && mask_shape[1] == 1 {
            return mask.expand(&target);
        }

        // NOTE(review): this branch is unreachable — any shape matching
        // `[1, 1, _, _]` already matched the `mask_shape[1] == 1` branch
        // above. Harmless, but consider removing or reordering.
        if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
            return mask.expand(&target);
        }

        // Unknown layout: pass through unchanged.
        mask.clone()
    }

    /// Computes multi-head attention for `query` against `key`/`value`.
    ///
    /// * `attn_mask` – optional mask where a value of `0.0` blocks a
    ///   position (converted to a `-1e9` additive bias before softmax);
    ///   non-zero values leave the score unchanged.
    ///
    /// Returns a tensor shaped like `query` (batch, tgt_len, embed_dim for
    /// batch-first inputs).
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        // Pick batch/sequence dims according to the layout flag.
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Split heads: (batch, seq, embed) -> (batch, heads, seq, head_dim).
        // NOTE(review): when `batch_first == false` the projected tensors are
        // still (seq, batch, embed), so this reshape reinterprets memory in
        // the wrong order — a transpose(0, 1) would be needed first. Confirm
        // against callers; all in-file tests use batch-first inputs.
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // CUDA fast path: fused attention kernel, only taken when no mask is
        // supplied (the fused kernel here is invoked non-causal, masks are
        // handled by the generic path below).
        #[cfg(feature = "cuda")]
        if q.data().device().is_gpu() && attn_mask.is_none() {
            let is_training = axonml_autograd::no_grad::is_grad_enabled();
            let q_tensor = q.data();
            let k_tensor = k.data();
            let v_tensor = v.data();

            if let Some(attn_out) =
                q_tensor.fused_attention_cuda(&k_tensor, &v_tensor, self.scale, false)
            {
                // If gradients are needed, attach a custom backward node so
                // the fused forward stays differentiable.
                let attn_output = if is_training
                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
                {
                    let backward = FusedAttentionBackward::new(
                        q.grad_fn().cloned(),
                        k.grad_fn().cloned(),
                        v.grad_fn().cloned(),
                        q_tensor,
                        k_tensor,
                        v_tensor,
                        attn_out.clone(),
                        self.scale,
                        false,
                    );
                    Variable::from_operation(attn_out, GradFn::new(backward), true)
                } else {
                    Variable::new(attn_out, false)
                };
                // Merge heads back and apply the output projection.
                let attn_output = attn_output
                    .transpose(1, 2)
                    .reshape(&[batch_size, tgt_len, self.embed_dim]);
                return self.out_proj.forward(&attn_output);
            }
        }

        // Generic path: scores = (Q K^T) / sqrt(head_dim).
        let k_t = k.transpose(2, 3);
        let scores = q.matmul(&k_t).mul_scalar(self.scale);

        let scores = if let Some(mask) = attn_mask {
            let mask_shape = mask.shape();
            let mask_data = mask.data();
            let scores_shape = scores.shape();
            let total = scores_shape.iter().product::<usize>();

            // GPU mask expansion: try the dedicated kernel first; on success
            // we can finish entirely on-device.
            #[cfg(feature = "cuda")]
            if scores.data().device().is_gpu() {
                let mask_gpu = if mask_data.device().is_gpu() {
                    mask_data.clone()
                } else {
                    mask_data.to_device(scores.data().device()).unwrap()
                };

                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
                    &scores_shape,
                    batch_size,
                    self.num_heads,
                    tgt_len,
                    src_len,
                ) {
                    let additive_mask = Variable::new(expanded_tensor, false);
                    return self.finish_attention(
                        scores.add_var(&additive_mask),
                        &v,
                        batch_size,
                        tgt_len,
                    );
                }
            }

            // CPU fallback: convert the boolean-style mask (0.0 = blocked)
            // into an additive bias of -1e9 per blocked position.
            let mask_vec = mask_data.to_vec();
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();

            let mut expanded = vec![0.0f32; total];

            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // (tgt_len, src_len) mask: broadcast over batch and heads.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[i * src_len + j];
                            }
                        }
                    }
                }
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // (batch, src_len) key-padding mask: broadcast over heads
                // and query positions.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[b * src_len + j];
                            }
                        }
                    }
                }
            } else {
                // Unknown mask layout: tile the flattened mask values.
                // NOTE(review): this modulo tiling only lines up when the
                // mask length divides `total` with matching strides; other
                // shapes will be silently misapplied — confirm intended use.
                for (i, val) in expanded.iter_mut().enumerate() {
                    *val = additive[i % additive.len()];
                }
            }

            let mut additive_tensor =
                Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
            // Keep the bias on the same device as the scores.
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
                additive_tensor = additive_tensor
                    .to_device(scores_device)
                    .expect("device transfer failed");
            }
            let additive_mask = Variable::new(additive_tensor, false);
            scores.add_var(&additive_mask)
        } else {
            scores
        };

        self.finish_attention(scores, &v, batch_size, tgt_len)
    }

    /// Shared tail of the attention computation: softmax over the last
    /// axis, weighted sum with V, head merge, and output projection.
    fn finish_attention(
        &self,
        scores: Variable,
        v: &Variable,
        batch_size: usize,
        tgt_len: usize,
    ) -> Variable {
        let attn_weights = scores.softmax(-1);
        let attn_output = attn_weights.matmul(v);
        // (batch, heads, tgt_len, head_dim) -> (batch, tgt_len, embed_dim).
        let attn_output = attn_output
            .transpose(1, 2)
            .reshape(&[batch_size, tgt_len, self.embed_dim]);
        self.out_proj.forward(&attn_output)
    }
}
356
357impl Module for MultiHeadAttention {
358 fn forward(&self, input: &Variable) -> Variable {
359 self.attention(input, input, input, None)
361 }
362
363 fn parameters(&self) -> Vec<Parameter> {
364 let mut params = Vec::new();
365 params.extend(self.q_proj.parameters());
366 params.extend(self.k_proj.parameters());
367 params.extend(self.v_proj.parameters());
368 params.extend(self.out_proj.parameters());
369 params
370 }
371
372 fn named_parameters(&self) -> HashMap<String, Parameter> {
373 let mut params = HashMap::new();
374 for (name, param) in self.q_proj.named_parameters() {
375 params.insert(format!("q_proj.{name}"), param);
376 }
377 for (name, param) in self.k_proj.named_parameters() {
378 params.insert(format!("k_proj.{name}"), param);
379 }
380 for (name, param) in self.v_proj.named_parameters() {
381 params.insert(format!("v_proj.{name}"), param);
382 }
383 for (name, param) in self.out_proj.named_parameters() {
384 params.insert(format!("out_proj.{name}"), param);
385 }
386 params
387 }
388
389 fn name(&self) -> &'static str {
390 "MultiHeadAttention"
391 }
392}
393
/// Cross-attention: queries come from one sequence while keys and values
/// come from another (e.g. a decoder attending over encoder memory).
///
/// Thin wrapper that delegates all computation to [`MultiHeadAttention`].
pub struct CrossAttention {
    // Underlying attention layer; keys and values share the memory input.
    mha: MultiHeadAttention,
}
414
415impl CrossAttention {
416 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
418 Self {
419 mha: MultiHeadAttention::new(embed_dim, num_heads),
420 }
421 }
422
423 pub fn with_options(
425 embed_dim: usize,
426 num_heads: usize,
427 dropout: f32,
428 batch_first: bool,
429 ) -> Self {
430 Self {
431 mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
432 }
433 }
434
435 pub fn cross_attention(
442 &self,
443 query: &Variable,
444 memory: &Variable,
445 attn_mask: Option<&Variable>,
446 ) -> Variable {
447 self.mha.attention(query, memory, memory, attn_mask)
448 }
449
450 pub fn embed_dim(&self) -> usize {
452 self.mha.embed_dim
453 }
454
455 pub fn num_heads(&self) -> usize {
457 self.mha.num_heads
458 }
459}
460
461impl Module for CrossAttention {
462 fn forward(&self, input: &Variable) -> Variable {
463 self.mha.forward(input)
466 }
467
468 fn parameters(&self) -> Vec<Parameter> {
469 self.mha.parameters()
470 }
471
472 fn named_parameters(&self) -> HashMap<String, Parameter> {
473 let mut params = HashMap::new();
474 for (name, param) in self.mha.named_parameters() {
475 params.insert(format!("mha.{name}"), param);
476 }
477 params
478 }
479
480 fn name(&self) -> &'static str {
481 "CrossAttention"
482 }
483}
484
485pub fn scaled_dot_product_attention_fused(
518 q: &Tensor<f32>,
519 k: &Tensor<f32>,
520 v: &Tensor<f32>,
521 scale: f32,
522 is_causal: bool,
523) -> Tensor<f32> {
524 #[cfg(feature = "cuda")]
526 if q.device().is_gpu() {
527 if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
528 return result;
529 }
530 }
531
532 let shape = q.shape();
534 let batch_size = shape[0];
535 let num_heads = shape[1];
536 let tgt_len = shape[2];
537 let head_dim = shape[3];
538 let src_len = k.shape()[2];
539
540 let q_data = q.to_vec();
541 let k_data = k.to_vec();
542 let v_data = v.to_vec();
543
544 let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
545
546 for b in 0..batch_size {
547 for h in 0..num_heads {
548 for i in 0..tgt_len {
549 let mut scores = vec![0.0f32; src_len];
551 let mut max_score = f32::NEG_INFINITY;
552
553 for j in 0..src_len {
554 if is_causal && j > i {
555 scores[j] = f32::NEG_INFINITY;
556 continue;
557 }
558 let mut score = 0.0f32;
559 for d in 0..head_dim {
560 let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
561 let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
562 score += q_data[q_idx] * k_data[k_idx];
563 }
564 score *= scale;
565 scores[j] = score;
566 if score > max_score {
567 max_score = score;
568 }
569 }
570
571 let mut sum_exp = 0.0f32;
573 for s in &mut scores {
574 if *s > f32::NEG_INFINITY {
575 *s = (*s - max_score).exp();
576 sum_exp += *s;
577 } else {
578 *s = 0.0;
579 }
580 }
581 let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
582
583 for d in 0..head_dim {
585 let mut val = 0.0f32;
586 for j in 0..src_len {
587 let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
588 val += scores[j] * v_data[v_idx];
589 }
590 let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
591 output[out_idx] = val * inv_sum;
592 }
593 }
594 }
595 }
596
597 Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim])
598 .expect("tensor creation failed")
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor derives head_dim = embed_dim / num_heads.
    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    // Self-attention forward preserves the (batch, seq, embed) shape.
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    // Cross-attention output takes its sequence length from the query.
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // Four Linear layers, each contributing weight + bias => 8 parameters.
    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // Module::forward on CrossAttention degenerates to self-attention.
    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    // Inner parameters are re-exported under the "mha." prefix.
    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8);
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    // With a constant V, the attention output equals that constant no
    // matter what the (uniform) attention weights are.
    #[test]
    fn test_fused_attention_cpu() {
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    // With a causal mask the first query row can only see the first key,
    // so row 0 of the output must equal row 0 of V exactly.
    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    // End-to-end backward through the generic CPU attention path.
    #[test]
    fn test_multihead_attention_backward_cpu() {
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    // FusedAttentionBackward applied directly to a CPU forward result:
    // gradients must have the right shapes and be finite.
    #[test]
    fn test_fused_attention_backward_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Smooth, non-constant inputs so gradients are non-trivial.
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q =
            Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k =
            Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v =
            Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    // Same as above but through the causal path: gradients must stay finite.
    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}