1use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28#[cfg(feature = "cuda")]
29use axonml_autograd::functions::FusedAttentionBackward;
30#[cfg(feature = "cuda")]
31use axonml_autograd::grad_fn::GradFn;
32use axonml_tensor::Tensor;
33
34use crate::layers::Linear;
35use crate::module::Module;
36use crate::parameter::Parameter;
37
/// Multi-head scaled dot-product attention with learned Q/K/V and output
/// projections. Input layout is `[batch, seq, embed]` when `batch_first`
/// is true, otherwise `[seq, batch, embed]`.
pub struct MultiHeadAttention {
    /// Query projection (embed_dim -> embed_dim).
    q_proj: Linear,
    /// Key projection (embed_dim -> embed_dim).
    k_proj: Linear,
    /// Value projection (embed_dim -> embed_dim).
    v_proj: Linear,
    /// Output projection applied after the heads are merged.
    out_proj: Linear,
    /// Total embedding dimension.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Per-head dimension: embed_dim / num_heads.
    head_dim: usize,
    /// Score scaling factor, 1 / sqrt(head_dim).
    scale: f32,
    /// Whether inputs/outputs use `[batch, seq, embed]` layout.
    batch_first: bool,
}
77
78impl MultiHeadAttention {
79 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
81 Self::with_options(embed_dim, num_heads, 0.0, true)
82 }
83
84 pub fn with_options(
86 embed_dim: usize,
87 num_heads: usize,
88 _dropout: f32,
89 batch_first: bool,
90 ) -> Self {
91 assert!(
92 embed_dim % num_heads == 0,
93 "embed_dim must be divisible by num_heads"
94 );
95
96 let head_dim = embed_dim / num_heads;
97 let scale = (head_dim as f32).sqrt().recip();
98
99 Self {
100 q_proj: Linear::new(embed_dim, embed_dim),
101 k_proj: Linear::new(embed_dim, embed_dim),
102 v_proj: Linear::new(embed_dim, embed_dim),
103 out_proj: Linear::new(embed_dim, embed_dim),
104 embed_dim,
105 num_heads,
106 head_dim,
107 scale,
108 batch_first,
109 }
110 }
111
112 #[allow(dead_code)]
121 fn expand_mask(
122 mask: &Variable,
123 batch_size: usize,
124 num_heads: usize,
125 tgt_len: usize,
126 src_len: usize,
127 ) -> Variable {
128 let mask_shape = mask.shape();
129 let target = [batch_size, num_heads, tgt_len, src_len];
130
131 if mask_shape == target {
132 return mask.clone();
133 }
134
135 if mask_shape.len() == 2 {
137 let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
138 return reshaped.expand(&target);
139 }
140
141 if mask_shape.len() == 4 && mask_shape[1] == 1 {
143 return mask.expand(&target);
144 }
145
146 if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
148 return mask.expand(&target);
149 }
150
151 mask.clone()
153 }
154
155 pub fn attention(
157 &self,
158 query: &Variable,
159 key: &Variable,
160 value: &Variable,
161 attn_mask: Option<&Variable>,
162 ) -> Variable {
163 let q_shape = query.shape();
164 let (batch_size, tgt_len, _) = if self.batch_first {
165 (q_shape[0], q_shape[1], q_shape[2])
166 } else {
167 (q_shape[1], q_shape[0], q_shape[2])
168 };
169 let src_len = if self.batch_first {
170 key.shape()[1]
171 } else {
172 key.shape()[0]
173 };
174
175 let q = self.q_proj.forward(query);
177 let k = self.k_proj.forward(key);
178 let v = self.v_proj.forward(value);
179
180 let q = q
183 .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
184 .transpose(1, 2);
185 let k = k
186 .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
187 .transpose(1, 2);
188 let v = v
189 .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
190 .transpose(1, 2);
191
192 #[cfg(feature = "cuda")]
198 if q.data().device().is_gpu() && attn_mask.is_none() {
199 let is_training = axonml_autograd::no_grad::is_grad_enabled();
200 let q_tensor = q.data();
201 let k_tensor = k.data();
202 let v_tensor = v.data();
203
204 if let Some(attn_out) = q_tensor.fused_attention_cuda(
205 &k_tensor, &v_tensor, self.scale,
206 false, ) {
208 let attn_output = if is_training
209 && (q.requires_grad() || k.requires_grad() || v.requires_grad())
210 {
211 let backward = FusedAttentionBackward::new(
213 q.grad_fn().cloned(),
214 k.grad_fn().cloned(),
215 v.grad_fn().cloned(),
216 q_tensor,
217 k_tensor,
218 v_tensor,
219 attn_out.clone(),
220 self.scale,
221 false,
222 );
223 Variable::from_operation(attn_out, GradFn::new(backward), true)
224 } else {
225 Variable::new(attn_out, false)
226 };
227 let attn_output =
228 attn_output
229 .transpose(1, 2)
230 .reshape(&[batch_size, tgt_len, self.embed_dim]);
231 return self.out_proj.forward(&attn_output);
232 }
233 }
235
236 let k_t = k.transpose(2, 3);
239 let scores = q.matmul(&k_t).mul_scalar(self.scale);
241
242 let scores = if let Some(mask) = attn_mask {
246 let mask_shape = mask.shape();
247 let mask_data = mask.data();
248 let scores_shape = scores.shape();
249 let total = scores_shape.iter().product::<usize>();
250
251 #[cfg(feature = "cuda")]
254 if scores.data().device().is_gpu() {
255 let mask_gpu = if mask_data.device().is_gpu() {
257 mask_data.clone()
258 } else {
259 mask_data.to_device(scores.data().device()).unwrap()
260 };
261
262 if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
263 &scores_shape,
264 batch_size,
265 self.num_heads,
266 tgt_len,
267 src_len,
268 ) {
269 let additive_mask = Variable::new(expanded_tensor, false);
270 return self.finish_attention(
271 scores.add_var(&additive_mask),
272 &v,
273 batch_size,
274 tgt_len,
275 );
276 }
277 }
279
280 let mask_vec = mask_data.to_vec();
282 let additive: Vec<f32> = mask_vec
283 .iter()
284 .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
285 .collect();
286
287 let mut expanded = vec![0.0f32; total];
288
289 if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
290 for b in 0..batch_size {
292 for h in 0..self.num_heads {
293 for i in 0..tgt_len {
294 for j in 0..src_len {
295 let idx = b * self.num_heads * tgt_len * src_len
296 + h * tgt_len * src_len
297 + i * src_len
298 + j;
299 expanded[idx] = additive[i * src_len + j];
300 }
301 }
302 }
303 }
304 } else if mask_shape.len() == 2
305 && mask_shape[0] == batch_size
306 && mask_shape[1] == src_len
307 {
308 for b in 0..batch_size {
310 for h in 0..self.num_heads {
311 for i in 0..tgt_len {
312 for j in 0..src_len {
313 let idx = b * self.num_heads * tgt_len * src_len
314 + h * tgt_len * src_len
315 + i * src_len
316 + j;
317 expanded[idx] = additive[b * src_len + j];
318 }
319 }
320 }
321 }
322 } else {
323 for (i, val) in expanded.iter_mut().enumerate() {
325 *val = additive[i % additive.len()];
326 }
327 }
328
329 let mut additive_tensor =
330 Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
331 let scores_device = scores.data().device();
332 if scores_device.is_gpu() {
333 additive_tensor = additive_tensor
334 .to_device(scores_device)
335 .expect("device transfer failed");
336 }
337 let additive_mask = Variable::new(additive_tensor, false);
338 scores.add_var(&additive_mask)
339 } else {
340 scores
341 };
342
343 self.finish_attention(scores, &v, batch_size, tgt_len)
344 }
345
346 fn finish_attention(
349 &self,
350 scores: Variable,
351 v: &Variable,
352 batch_size: usize,
353 tgt_len: usize,
354 ) -> Variable {
355 let attn_weights = scores.softmax(-1);
356 let attn_output = attn_weights.matmul(v);
357 let attn_output =
358 attn_output
359 .transpose(1, 2)
360 .reshape(&[batch_size, tgt_len, self.embed_dim]);
361 self.out_proj.forward(&attn_output)
362 }
363}
364
365impl Module for MultiHeadAttention {
366 fn forward(&self, input: &Variable) -> Variable {
367 self.attention(input, input, input, None)
369 }
370
371 fn parameters(&self) -> Vec<Parameter> {
372 let mut params = Vec::new();
373 params.extend(self.q_proj.parameters());
374 params.extend(self.k_proj.parameters());
375 params.extend(self.v_proj.parameters());
376 params.extend(self.out_proj.parameters());
377 params
378 }
379
380 fn named_parameters(&self) -> HashMap<String, Parameter> {
381 let mut params = HashMap::new();
382 for (name, param) in self.q_proj.named_parameters() {
383 params.insert(format!("q_proj.{name}"), param);
384 }
385 for (name, param) in self.k_proj.named_parameters() {
386 params.insert(format!("k_proj.{name}"), param);
387 }
388 for (name, param) in self.v_proj.named_parameters() {
389 params.insert(format!("v_proj.{name}"), param);
390 }
391 for (name, param) in self.out_proj.named_parameters() {
392 params.insert(format!("out_proj.{name}"), param);
393 }
394 params
395 }
396
397 fn name(&self) -> &'static str {
398 "MultiHeadAttention"
399 }
400}
401
/// Cross-attention: queries come from one sequence, keys and values from
/// another (the "memory"). A thin wrapper around [`MultiHeadAttention`].
pub struct CrossAttention {
    /// Underlying multi-head attention layer.
    mha: MultiHeadAttention,
}
422
423impl CrossAttention {
424 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
426 Self {
427 mha: MultiHeadAttention::new(embed_dim, num_heads),
428 }
429 }
430
431 pub fn with_options(
433 embed_dim: usize,
434 num_heads: usize,
435 dropout: f32,
436 batch_first: bool,
437 ) -> Self {
438 Self {
439 mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
440 }
441 }
442
443 pub fn cross_attention(
450 &self,
451 query: &Variable,
452 memory: &Variable,
453 attn_mask: Option<&Variable>,
454 ) -> Variable {
455 self.mha.attention(query, memory, memory, attn_mask)
456 }
457
458 pub fn embed_dim(&self) -> usize {
460 self.mha.embed_dim
461 }
462
463 pub fn num_heads(&self) -> usize {
465 self.mha.num_heads
466 }
467}
468
469impl Module for CrossAttention {
470 fn forward(&self, input: &Variable) -> Variable {
471 self.mha.forward(input)
474 }
475
476 fn parameters(&self) -> Vec<Parameter> {
477 self.mha.parameters()
478 }
479
480 fn named_parameters(&self) -> HashMap<String, Parameter> {
481 let mut params = HashMap::new();
482 for (name, param) in self.mha.named_parameters() {
483 params.insert(format!("mha.{name}"), param);
484 }
485 params
486 }
487
488 fn name(&self) -> &'static str {
489 "CrossAttention"
490 }
491}
492
493pub fn scaled_dot_product_attention_fused(
526 q: &Tensor<f32>,
527 k: &Tensor<f32>,
528 v: &Tensor<f32>,
529 scale: f32,
530 is_causal: bool,
531) -> Tensor<f32> {
532 #[cfg(feature = "cuda")]
534 if q.device().is_gpu() {
535 if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
536 return result;
537 }
538 }
539
540 let shape = q.shape();
542 let batch_size = shape[0];
543 let num_heads = shape[1];
544 let tgt_len = shape[2];
545 let head_dim = shape[3];
546 let src_len = k.shape()[2];
547
548 let q_data = q.to_vec();
549 let k_data = k.to_vec();
550 let v_data = v.to_vec();
551
552 let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
553
554 for b in 0..batch_size {
555 for h in 0..num_heads {
556 for i in 0..tgt_len {
557 let mut scores = vec![0.0f32; src_len];
559 let mut max_score = f32::NEG_INFINITY;
560
561 for j in 0..src_len {
562 if is_causal && j > i {
563 scores[j] = f32::NEG_INFINITY;
564 continue;
565 }
566 let mut score = 0.0f32;
567 for d in 0..head_dim {
568 let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
569 let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
570 score += q_data[q_idx] * k_data[k_idx];
571 }
572 score *= scale;
573 scores[j] = score;
574 if score > max_score {
575 max_score = score;
576 }
577 }
578
579 let mut sum_exp = 0.0f32;
581 for s in &mut scores {
582 if *s > f32::NEG_INFINITY {
583 *s = (*s - max_score).exp();
584 sum_exp += *s;
585 } else {
586 *s = 0.0;
587 }
588 }
589 let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
590
591 for d in 0..head_dim {
593 let mut val = 0.0f32;
594 for j in 0..src_len {
595 let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
596 val += scores[j] * v_data[v_idx];
597 }
598 let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
599 output[out_idx] = val * inv_sum;
600 }
601 }
602 }
603 }
604
605 Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim])
606 .expect("tensor creation failed")
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor stores dimensions and derives head_dim = embed_dim / num_heads.
    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    // Self-attention preserves the [batch, seq, embed] input shape.
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    // Cross-attention output follows the query's sequence length, not the key's.
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // Four Linear layers contribute two parameters each (weight + bias).
    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        assert_eq!(params.len(), 8);
    }

    // CrossAttention exposes the wrapped layer's dimensions via getters.
    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    // cross_attention output shape follows the query, memory supplies K/V.
    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // Module::forward on CrossAttention degenerates to self-attention.
    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    // Delegated parameters keep their count and gain the "mha." prefix.
    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8);
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    // With identical Q/K rows the softmax is uniform, so the output equals the
    // constant value 0.5 everywhere.
    #[test]
    fn test_fused_attention_cpu() {
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    // With a causal mask, row 0 attends only to itself, so its output is
    // exactly v's first row (identity matrix rows here).
    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    // End-to-end backward through the CPU attention path: the input gradient
    // must exist, keep the input's shape, and be non-zero somewhere.
    #[test]
    fn test_multihead_attention_backward_cpu() {
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    // FusedAttentionBackward applied directly: shapes match the inputs and
    // every gradient entry is finite; grad_V must be non-trivial.
    #[test]
    fn test_fused_attention_backward_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q =
            Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k =
            Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v =
            Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    // Same as above but with the causal flag: all gradients stay finite.
    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}