1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20#[cfg(feature = "cuda")]
21use axonml_autograd::functions::FusedAttentionBackward;
22#[cfg(feature = "cuda")]
23use axonml_autograd::grad_fn::GradFn;
24use axonml_tensor::Tensor;
25
26use crate::layers::Linear;
27use crate::module::Module;
28use crate::parameter::Parameter;
29
/// Multi-head scaled dot-product attention layer.
///
/// Splits `embed_dim` into `num_heads` independent heads, attends each
/// head separately, then merges the heads through an output projection.
pub struct MultiHeadAttention {
    /// Query projection: `embed_dim -> embed_dim`.
    q_proj: Linear,
    /// Key projection: `embed_dim -> embed_dim`.
    k_proj: Linear,
    /// Value projection: `embed_dim -> embed_dim`.
    v_proj: Linear,
    /// Output projection applied to the merged head outputs.
    out_proj: Linear,
    /// Total model dimension (= `num_heads * head_dim`).
    embed_dim: usize,
    /// Number of parallel attention heads.
    num_heads: usize,
    /// Per-head dimension: `embed_dim / num_heads`.
    head_dim: usize,
    /// Score scaling factor: `1 / sqrt(head_dim)`.
    scale: f32,
    /// When true, inputs are `[batch, seq, embed]`; otherwise `[seq, batch, embed]`.
    batch_first: bool,
}
69
70impl MultiHeadAttention {
71 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
73 Self::with_options(embed_dim, num_heads, 0.0, true)
74 }
75
76 pub fn with_options(
78 embed_dim: usize,
79 num_heads: usize,
80 _dropout: f32,
81 batch_first: bool,
82 ) -> Self {
83 assert!(
84 embed_dim % num_heads == 0,
85 "embed_dim must be divisible by num_heads"
86 );
87
88 let head_dim = embed_dim / num_heads;
89 let scale = (head_dim as f32).sqrt().recip();
90
91 Self {
92 q_proj: Linear::new(embed_dim, embed_dim),
93 k_proj: Linear::new(embed_dim, embed_dim),
94 v_proj: Linear::new(embed_dim, embed_dim),
95 out_proj: Linear::new(embed_dim, embed_dim),
96 embed_dim,
97 num_heads,
98 head_dim,
99 scale,
100 batch_first,
101 }
102 }
103
104 fn expand_mask(
113 mask: &Variable,
114 batch_size: usize,
115 num_heads: usize,
116 tgt_len: usize,
117 src_len: usize,
118 ) -> Variable {
119 let mask_shape = mask.shape();
120 let target = [batch_size, num_heads, tgt_len, src_len];
121
122 if mask_shape == target {
123 return mask.clone();
124 }
125
126 if mask_shape.len() == 2 {
128 let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
129 return reshaped.expand(&target);
130 }
131
132 if mask_shape.len() == 4 && mask_shape[1] == 1 {
134 return mask.expand(&target);
135 }
136
137 if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
139 return mask.expand(&target);
140 }
141
142 mask.clone()
144 }
145
    /// Scaled dot-product multi-head attention over projected Q/K/V.
    ///
    /// `query`/`key`/`value` are `[batch, seq, embed_dim]` when
    /// `batch_first` is true, `[seq, batch, embed_dim]` otherwise.
    /// When `attn_mask` is given, entries equal to `0.0` are masked out
    /// (replaced with `-1e9` before softmax on the CPU path); all other
    /// entries pass through unchanged.
    ///
    /// Returns the attended, output-projected result.
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        // Resolve batch/sequence axes according to `batch_first`.
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        // Project inputs into query/key/value spaces (all embed_dim wide).
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Split heads: [batch, seq, embed] -> [batch, heads, seq, head_dim].
        // NOTE(review): this reshape assumes batch-first memory layout; when
        // `batch_first == false` the projected data is [seq, batch, embed]
        // and would seem to need a transpose(0, 1) first — confirm intended
        // behavior for the seq-first configuration.
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // Fast path: fused CUDA attention kernel (unmasked inputs only).
        #[cfg(feature = "cuda")]
        if q.data().device().is_gpu() && attn_mask.is_none() {
            let is_training = axonml_autograd::no_grad::is_grad_enabled();
            let q_tensor = q.data();
            let k_tensor = k.data();
            let v_tensor = v.data();

            // The kernel returns None when it cannot handle the input, in
            // which case we fall through to the generic path below.
            if let Some(attn_out) =
                q_tensor.fused_attention_cuda(&k_tensor, &v_tensor, self.scale, false)
            {
                // Attach the fused backward op only when gradients are
                // actually required.
                let attn_output = if is_training
                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
                {
                    let backward = FusedAttentionBackward::new(
                        q.grad_fn().cloned(),
                        k.grad_fn().cloned(),
                        v.grad_fn().cloned(),
                        q_tensor,
                        k_tensor,
                        v_tensor,
                        attn_out.clone(),
                        self.scale,
                        false,
                    );
                    Variable::from_operation(attn_out, GradFn::new(backward), true)
                } else {
                    Variable::new(attn_out, false)
                };
                // Merge heads back and apply the output projection.
                let attn_output = attn_output
                    .transpose(1, 2)
                    .reshape(&[batch_size, tgt_len, self.embed_dim]);
                return self.out_proj.forward(&attn_output);
            }
        }

        // Generic path: scores = (Q K^T) * scale.
        let k_t = k.transpose(2, 3);
        let scores = q.matmul(&k_t).mul_scalar(self.scale);

        let scores = if let Some(mask) = attn_mask {
            let mask_shape = mask.shape();
            let mask_data = mask.data();
            let scores_shape = scores.shape();
            let total = scores_shape.iter().product::<usize>();

            // GPU mask expansion when a CUDA kernel is available.
            #[cfg(feature = "cuda")]
            if scores.data().device().is_gpu() {
                let mask_gpu = if mask_data.device().is_gpu() {
                    mask_data.clone()
                } else {
                    mask_data.to_device(scores.data().device()).unwrap()
                };

                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
                    &scores_shape,
                    batch_size,
                    self.num_heads,
                    tgt_len,
                    src_len,
                ) {
                    let additive_mask = Variable::new(expanded_tensor, false);
                    return self.finish_attention(
                        scores.add_var(&additive_mask),
                        &v,
                        batch_size,
                        tgt_len,
                    );
                }
            }

            // CPU fallback: build an additive mask where 0.0 entries become
            // -1e9 (masked out) and everything else becomes 0.0 (kept).
            let mask_vec = mask_data.to_vec();
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();

            let mut expanded = vec![0.0f32; total];

            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // [tgt_len, src_len] mask: same pattern for every batch/head.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[i * src_len + j];
                            }
                        }
                    }
                }
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // [batch, src_len] (padding-style) mask: broadcast over
                // heads and query positions.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[b * src_len + j];
                            }
                        }
                    }
                }
            } else {
                // Unknown mask layout: tile the additive values cyclically.
                // NOTE(review): this silently accepts mismatched shapes —
                // consider rejecting them explicitly instead.
                for (i, val) in expanded.iter_mut().enumerate() {
                    *val = additive[i % additive.len()];
                }
            }

            let mut additive_tensor =
                Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
            // Keep the mask on the same device as the scores.
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
                additive_tensor = additive_tensor
                    .to_device(scores_device)
                    .expect("device transfer failed");
            }
            let additive_mask = Variable::new(additive_tensor, false);
            scores.add_var(&additive_mask)
        } else {
            scores
        };

        // Softmax, weighted sum over V, head merge, output projection.
        self.finish_attention(scores, &v, batch_size, tgt_len)
    }
333
334 fn finish_attention(
337 &self,
338 scores: Variable,
339 v: &Variable,
340 batch_size: usize,
341 tgt_len: usize,
342 ) -> Variable {
343 let attn_weights = scores.softmax(-1);
344 let attn_output = attn_weights.matmul(v);
345 let attn_output =
346 attn_output
347 .transpose(1, 2)
348 .reshape(&[batch_size, tgt_len, self.embed_dim]);
349 self.out_proj.forward(&attn_output)
350 }
351}
352
353impl Module for MultiHeadAttention {
354 fn forward(&self, input: &Variable) -> Variable {
355 self.attention(input, input, input, None)
357 }
358
359 fn parameters(&self) -> Vec<Parameter> {
360 let mut params = Vec::new();
361 params.extend(self.q_proj.parameters());
362 params.extend(self.k_proj.parameters());
363 params.extend(self.v_proj.parameters());
364 params.extend(self.out_proj.parameters());
365 params
366 }
367
368 fn named_parameters(&self) -> HashMap<String, Parameter> {
369 let mut params = HashMap::new();
370 for (name, param) in self.q_proj.named_parameters() {
371 params.insert(format!("q_proj.{name}"), param);
372 }
373 for (name, param) in self.k_proj.named_parameters() {
374 params.insert(format!("k_proj.{name}"), param);
375 }
376 for (name, param) in self.v_proj.named_parameters() {
377 params.insert(format!("v_proj.{name}"), param);
378 }
379 for (name, param) in self.out_proj.named_parameters() {
380 params.insert(format!("out_proj.{name}"), param);
381 }
382 params
383 }
384
385 fn name(&self) -> &'static str {
386 "MultiHeadAttention"
387 }
388}
389
/// Cross-attention: queries attend over a separate memory sequence.
///
/// Thin wrapper around a `MultiHeadAttention` where the key and value
/// inputs both come from the memory.
pub struct CrossAttention {
    /// Underlying multi-head attention layer.
    mha: MultiHeadAttention,
}
410
411impl CrossAttention {
412 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
414 Self {
415 mha: MultiHeadAttention::new(embed_dim, num_heads),
416 }
417 }
418
419 pub fn with_options(
421 embed_dim: usize,
422 num_heads: usize,
423 dropout: f32,
424 batch_first: bool,
425 ) -> Self {
426 Self {
427 mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
428 }
429 }
430
431 pub fn cross_attention(
438 &self,
439 query: &Variable,
440 memory: &Variable,
441 attn_mask: Option<&Variable>,
442 ) -> Variable {
443 self.mha.attention(query, memory, memory, attn_mask)
444 }
445
446 pub fn embed_dim(&self) -> usize {
448 self.mha.embed_dim
449 }
450
451 pub fn num_heads(&self) -> usize {
453 self.mha.num_heads
454 }
455}
456
457impl Module for CrossAttention {
458 fn forward(&self, input: &Variable) -> Variable {
459 self.mha.forward(input)
462 }
463
464 fn parameters(&self) -> Vec<Parameter> {
465 self.mha.parameters()
466 }
467
468 fn named_parameters(&self) -> HashMap<String, Parameter> {
469 let mut params = HashMap::new();
470 for (name, param) in self.mha.named_parameters() {
471 params.insert(format!("mha.{name}"), param);
472 }
473 params
474 }
475
476 fn name(&self) -> &'static str {
477 "CrossAttention"
478 }
479}
480
481pub fn scaled_dot_product_attention_fused(
514 q: &Tensor<f32>,
515 k: &Tensor<f32>,
516 v: &Tensor<f32>,
517 scale: f32,
518 is_causal: bool,
519) -> Tensor<f32> {
520 #[cfg(feature = "cuda")]
522 if q.device().is_gpu() {
523 if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
524 return result;
525 }
526 }
527
528 let shape = q.shape();
530 let batch_size = shape[0];
531 let num_heads = shape[1];
532 let tgt_len = shape[2];
533 let head_dim = shape[3];
534 let src_len = k.shape()[2];
535
536 let q_data = q.to_vec();
537 let k_data = k.to_vec();
538 let v_data = v.to_vec();
539
540 let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
541
542 for b in 0..batch_size {
543 for h in 0..num_heads {
544 for i in 0..tgt_len {
545 let mut scores = vec![0.0f32; src_len];
547 let mut max_score = f32::NEG_INFINITY;
548
549 for j in 0..src_len {
550 if is_causal && j > i {
551 scores[j] = f32::NEG_INFINITY;
552 continue;
553 }
554 let mut score = 0.0f32;
555 for d in 0..head_dim {
556 let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
557 let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
558 score += q_data[q_idx] * k_data[k_idx];
559 }
560 score *= scale;
561 scores[j] = score;
562 if score > max_score {
563 max_score = score;
564 }
565 }
566
567 let mut sum_exp = 0.0f32;
569 for s in &mut scores {
570 if *s > f32::NEG_INFINITY {
571 *s = (*s - max_score).exp();
572 sum_exp += *s;
573 } else {
574 *s = 0.0;
575 }
576 }
577 let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
578
579 for d in 0..head_dim {
581 let mut val = 0.0f32;
582 for j in 0..src_len {
583 let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
584 val += scores[j] * v_data[v_idx];
585 }
586 let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
587 output[out_idx] = val * inv_sum;
588 }
589 }
590 }
591 }
592
593 Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim]).expect("tensor creation failed")
594}
595
596#[cfg(test)]
601mod tests {
602 use super::*;
603
    #[test]
    fn test_multihead_attention_creation() {
        // 512 embed dims over 8 heads -> 64 dims per head.
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }
611
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        // Batch 2, sequence length 10, embedding 64.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        // Self-attention preserves the input shape.
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }
622
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        // Query of length 5 attends over key/value of length 10.
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        // Output follows the query's sequence length.
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
637
    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // 4 projections x (weight + bias) = 8 parameters.
        assert_eq!(params.len(), 8);
    }
645
    #[test]
    fn test_cross_attention_creation() {
        // Accessors should reflect the wrapped attention's configuration.
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }
652
    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Query length 5 attends over a memory of length 10.
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        // Output shape follows the query sequence length.
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
669
    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        // `Module::forward` degrades to self-attention on a single input.
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }
681
    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        // 4 projections x (weight + bias) = 8 parameters.
        assert_eq!(params.len(), 8);
        // Named parameters are re-keyed under the `mha.` prefix.
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }
691
    #[test]
    fn test_fused_attention_cpu() {
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // Softmax weights sum to 1 and V is constant 0.5, so every output
        // element should be ~0.5 regardless of the attention pattern.
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }
726
    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        // V rows form an identity matrix, so each output row reveals which
        // V rows the corresponding query position attended to.
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // Position 0 can only attend to itself under causal masking, so
        // its output must equal V's row 0 (a one-hot vector).
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }
765
    #[test]
    fn test_multihead_attention_backward_cpu() {
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        // Reduce to a scalar loss and backpropagate a unit gradient.
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        // The input gradient must exist, match the input shape, and carry
        // at least one non-trivial value.
        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }
795
    #[test]
    fn test_fused_attention_backward_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Smooth, non-constant inputs so the gradients are non-degenerate.
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q = Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k = Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v = Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        // Leaf inputs (no upstream grad fns), non-causal.
        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        // Expect exactly [grad_Q, grad_K, grad_V].
        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        // All gradients must be finite...
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        // ...and grad_V in particular must be non-trivial.
        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }
872
    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        // Forward with causal masking, then backprop a ones gradient.
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        // Causal masking must not introduce NaN/inf into any gradient.
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
938}