1use ferrotorch_core::grad_fns::linalg::linear_fused;
35use ferrotorch_core::grad_fns::shape::reshape;
36use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
37
38use crate::init::{NonLinearity, kaiming_uniform};
39use crate::module::Module;
40use crate::parameter::Parameter;
41
42#[derive(Debug)]
62pub struct Linear<T: Float> {
63 pub weight: Parameter<T>,
65 pub bias: Option<Parameter<T>>,
67 in_features: usize,
69 out_features: usize,
71 training: bool,
73}
74
75impl<T: Float> Linear<T> {
76 pub fn new(in_features: usize, out_features: usize, bias: bool) -> FerrotorchResult<Self> {
89 if in_features == 0 {
90 return Err(FerrotorchError::InvalidArgument {
91 message: "Linear: in_features must be > 0".into(),
92 });
93 }
94 if out_features == 0 {
95 return Err(FerrotorchError::InvalidArgument {
96 message: "Linear: out_features must be > 0".into(),
97 });
98 }
99
100 let mut weight = Parameter::zeros(&[out_features, in_features])?;
102 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
103
104 let bias_param = if bias {
110 let mut b = Parameter::zeros(&[out_features])?;
111 let bound = if in_features > 0 {
112 1.0 / (in_features as f64).sqrt()
113 } else {
114 0.0
115 };
116 crate::init::uniform(&mut b, -bound, bound)?;
117 Some(b)
118 } else {
119 None
120 };
121
122 Ok(Self {
123 weight,
124 bias: bias_param,
125 in_features,
126 out_features,
127 training: true,
128 })
129 }
130
131 #[inline]
133 pub fn in_features(&self) -> usize {
134 self.in_features
135 }
136
137 #[inline]
139 pub fn out_features(&self) -> usize {
140 self.out_features
141 }
142}
143
144impl<T: Float> Module<T> for Linear<T> {
145 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
167 if input.ndim() == 0 {
168 return Err(FerrotorchError::ShapeMismatch {
169 message: "Linear: scalar input not supported".into(),
170 });
171 }
172
173 let last_dim = input.shape()[input.ndim() - 1];
175 if last_dim != self.in_features {
176 return Err(FerrotorchError::ShapeMismatch {
177 message: format!(
178 "Linear: input has {} features but layer expects {}",
179 last_dim, self.in_features
180 ),
181 });
182 }
183
184 let input_shape = input.shape().to_vec();
187 let batch_shape = &input_shape[..input_shape.len() - 1];
188 let n: usize = batch_shape.iter().product::<usize>().max(1);
189 let needs_reshape = input.ndim() != 2;
190
191 let input_2d = if needs_reshape {
192 reshape(input, &[n as isize, self.in_features as isize])?
193 } else {
194 input.clone()
195 };
196
197 let output_2d = linear_fused(
199 &input_2d,
200 self.weight.tensor(),
201 self.bias.as_ref().map(|b| b.tensor()),
202 )?;
203
204 if needs_reshape {
206 let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
207 out_shape.push(self.out_features as isize);
208 reshape(&output_2d, &out_shape)
209 } else {
210 Ok(output_2d)
211 }
212 }
213
214 fn parameters(&self) -> Vec<&Parameter<T>> {
215 let mut params = vec![&self.weight];
216 if let Some(ref b) = self.bias {
217 params.push(b);
218 }
219 params
220 }
221
222 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
223 let mut params = vec![&mut self.weight];
224 if let Some(ref mut b) = self.bias {
225 params.push(b);
226 }
227 params
228 }
229
230 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
231 let mut params = vec![("weight".to_string(), &self.weight)];
232 if let Some(ref b) = self.bias {
233 params.push(("bias".to_string(), b));
234 }
235 params
236 }
237
238 fn train(&mut self) {
239 self.training = true;
240 }
241
242 fn eval(&mut self) {
243 self.training = false;
244 }
245
246 fn is_training(&self) -> bool {
247 self.training
248 }
249}
250
251impl<T: Float> std::fmt::Display for Linear<T> {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 write!(
258 f,
259 "Linear(in_features={}, out_features={}, bias={})",
260 self.in_features,
261 self.out_features,
262 self.bias.is_some()
263 )
264 }
265}
266
267#[derive(Debug)]
290pub struct Bilinear<T: Float> {
291 pub weight: Parameter<T>,
293 pub bias: Option<Parameter<T>>,
295 in1_features: usize,
296 in2_features: usize,
297 out_features: usize,
298 training: bool,
299}
300
301impl<T: Float> Bilinear<T> {
302 pub fn new(
315 in1_features: usize,
316 in2_features: usize,
317 out_features: usize,
318 bias: bool,
319 ) -> FerrotorchResult<Self> {
320 if in1_features == 0 || in2_features == 0 || out_features == 0 {
321 return Err(FerrotorchError::InvalidArgument {
322 message: format!(
323 "Bilinear: in1/in2/out_features must all be > 0, got ({in1_features}, {in2_features}, {out_features})"
324 ),
325 });
326 }
327
328 let bound = if in1_features > 0 {
330 1.0 / (in1_features as f64).sqrt()
331 } else {
332 0.0
333 };
334
335 let mut weight = Parameter::zeros(&[out_features, in1_features, in2_features])?;
336 crate::init::uniform(&mut weight, -bound, bound)?;
337
338 let bias_param = if bias {
339 let mut b = Parameter::zeros(&[out_features])?;
340 crate::init::uniform(&mut b, -bound, bound)?;
341 Some(b)
342 } else {
343 None
344 };
345
346 Ok(Self {
347 weight,
348 bias: bias_param,
349 in1_features,
350 in2_features,
351 out_features,
352 training: true,
353 })
354 }
355
356 #[inline]
358 pub fn in1_features(&self) -> usize {
359 self.in1_features
360 }
361
362 #[inline]
364 pub fn in2_features(&self) -> usize {
365 self.in2_features
366 }
367
368 #[inline]
370 pub fn out_features(&self) -> usize {
371 self.out_features
372 }
373
374 pub fn forward_pair(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
392 if x1.ndim() == 0 || x2.ndim() == 0 {
396 return Err(FerrotorchError::ShapeMismatch {
397 message: "Bilinear: scalar (0-D) inputs not supported; expected (*, features)"
398 .into(),
399 });
400 }
401 if x1.ndim() != x2.ndim() {
402 return Err(FerrotorchError::ShapeMismatch {
403 message: format!(
404 "Bilinear: input dimensions do not match: got {} and {}",
405 x1.ndim(),
406 x2.ndim(),
407 ),
408 });
409 }
410
411 let x1_shape = x1.shape().to_vec();
412 let x2_shape = x2.shape().to_vec();
413
414 let lead_len = x1_shape.len() - 1;
417 for d in 0..lead_len {
418 if x1_shape[d] != x2_shape[d] {
419 return Err(FerrotorchError::ShapeMismatch {
420 message: format!(
421 "Bilinear: input batch dimensions do not match at dim {}: got {} and {}",
422 d, x1_shape[d], x2_shape[d],
423 ),
424 });
425 }
426 }
427
428 if x1_shape[lead_len] != self.in1_features {
430 return Err(FerrotorchError::ShapeMismatch {
431 message: format!(
432 "Bilinear: x1 last dim {} != in1_features {}",
433 x1_shape[lead_len], self.in1_features,
434 ),
435 });
436 }
437 if x2_shape[lead_len] != self.in2_features {
438 return Err(FerrotorchError::ShapeMismatch {
439 message: format!(
440 "Bilinear: x2 last dim {} != in2_features {}",
441 x2_shape[lead_len], self.in2_features,
442 ),
443 });
444 }
445
446 let batch_shape = &x1_shape[..lead_len];
453 let n: usize = batch_shape.iter().product();
454 let x1_2d = ferrotorch_core::grad_fns::shape::reshape(
455 x1,
456 &[n as isize, self.in1_features as isize],
457 )?;
458 let x2_2d = ferrotorch_core::grad_fns::shape::reshape(
459 x2,
460 &[n as isize, self.in2_features as isize],
461 )?;
462
463 let boj = ferrotorch_core::einsum::einsum_differentiable(
469 "bi,oij->boj",
470 &[&x1_2d, self.weight.tensor()],
471 )?;
472 let bo = ferrotorch_core::einsum::einsum_differentiable("boj,bj->bo", &[&boj, &x2_2d])?;
473
474 let out_2d = if let Some(ref bias) = self.bias {
479 let bias_2d = ferrotorch_core::grad_fns::shape::reshape(
480 bias.tensor(),
481 &[1, self.out_features as isize],
482 )?;
483 ferrotorch_core::grad_fns::arithmetic::add(&bo, &bias_2d)?
484 } else {
485 bo
486 };
487
488 let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
492 out_shape.push(self.out_features as isize);
493 ferrotorch_core::grad_fns::shape::reshape(&out_2d, &out_shape)
494 }
495}
496
497impl<T: Float> Module<T> for Bilinear<T> {
498 fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
503 Err(FerrotorchError::InvalidArgument {
504 message: "Bilinear requires two inputs; call `forward_pair(x1, x2)` instead of \
505 `Module::forward`."
506 .into(),
507 })
508 }
509
510 fn parameters(&self) -> Vec<&Parameter<T>> {
511 let mut params = vec![&self.weight];
512 if let Some(ref b) = self.bias {
513 params.push(b);
514 }
515 params
516 }
517
518 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
519 let mut params = vec![&mut self.weight];
520 if let Some(ref mut b) = self.bias {
521 params.push(b);
522 }
523 params
524 }
525
526 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
527 let mut params = vec![("weight".to_string(), &self.weight)];
528 if let Some(ref b) = self.bias {
529 params.push(("bias".to_string(), b));
530 }
531 params
532 }
533
534 fn train(&mut self) {
535 self.training = true;
536 }
537
538 fn eval(&mut self) {
539 self.training = false;
540 }
541
542 fn is_training(&self) -> bool {
543 self.training
544 }
545}
546
547impl<T: Float> std::fmt::Display for Bilinear<T> {
548 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 write!(
550 f,
551 "Bilinear(in1_features={}, in2_features={}, out_features={}, bias={})",
552 self.in1_features,
553 self.in2_features,
554 self.out_features,
555 self.bias.is_some()
556 )
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    /// Builds a CPU `f32` tensor from `data` with the given `shape`.
    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), requires_grad).unwrap()
    }

    /// Asserts both slices have equal length and agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            let diff = (a - e).abs();
            assert!(diff < tol, "index {i}: actual={a} expected={e} diff={diff}");
        }
    }

    // ----- Linear: construction -----

    #[test]
    fn test_construction_with_bias() {
        let linear = Linear::<f32>::new(10, 5, true).unwrap();
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert_eq!(linear.weight.shape(), &[5, 10]);
        assert!(linear.bias.is_some());
        assert_eq!(linear.bias.as_ref().unwrap().shape(), &[5]);
    }

    #[test]
    fn test_construction_without_bias() {
        let linear = Linear::<f32>::new(8, 4, false).unwrap();
        assert_eq!(linear.weight.shape(), &[4, 8]);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_construction_zero_in_features() {
        let built = Linear::<f32>::new(0, 5, true);
        assert!(built.is_err());
    }

    #[test]
    fn test_construction_zero_out_features() {
        let built = Linear::<f32>::new(5, 0, true);
        assert!(built.is_err());
    }

    #[test]
    fn test_weight_requires_grad() {
        let linear = Linear::<f32>::new(4, 3, true).unwrap();
        assert!(linear.weight.requires_grad());
        assert!(linear.bias.as_ref().unwrap().requires_grad());
    }

    // ----- Linear: forward shapes -----

    #[test]
    fn test_forward_shape() {
        let linear = Linear::<f32>::new(4, 3, true).unwrap();
        let x = leaf(&[0.0; 8], &[2, 4], false);
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 3]);
    }

    #[test]
    fn test_forward_shape_no_bias() {
        let linear = Linear::<f32>::new(6, 2, false).unwrap();
        let x = leaf(&[0.0; 18], &[3, 6], false);
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
    }

    #[test]
    fn test_forward_wrong_input_features() {
        let linear = Linear::<f32>::new(4, 3, true).unwrap();
        let x = leaf(&[0.0; 15], &[3, 5], false);
        assert!(linear.forward(&x).is_err());
    }

    #[test]
    fn test_forward_1d_input_accepted() {
        let mut linear = Linear::<f32>::new(3, 2, false).unwrap();
        // Weight rows select the first two input features.
        linear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2]);
        assert_close(y.data().unwrap(), &[1.0, 2.0], 1e-6);
    }

    // ----- Linear: batched (rank > 2) inputs -----

    #[test]
    fn test_forward_3d_input_shape() {
        let linear = Linear::<f32>::new(4, 3, true).unwrap();
        let x = leaf(&[0.0; 2 * 5 * 4], &[2, 5, 4], false);
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 5, 3]);
    }

    #[test]
    fn test_forward_4d_input_shape() {
        let linear = Linear::<f32>::new(8, 4, false).unwrap();
        let x = leaf(&[0.0; 2 * 3 * 4 * 8], &[2, 3, 4, 8], false);
        let y = linear.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 3, 4, 4]);
    }

    #[test]
    fn test_forward_3d_correctness() {
        let mut linear = Linear::<f32>::new(3, 2, false).unwrap();
        linear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();

        let data = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];

        // The same buffer viewed as [2, 2, 3] ...
        let batched_in = leaf(&data, &[2, 2, 3], false);
        let batched_out = linear.forward(&batched_in).unwrap();
        assert_eq!(batched_out.shape(), &[2, 2, 2]);

        // ... and as [4, 3] must produce the same flat output values.
        let flat_in = leaf(&data, &[4, 3], false);
        let flat_out = linear.forward(&flat_in).unwrap();
        assert_eq!(flat_out.shape(), &[4, 2]);

        assert_close(
            batched_out.data().unwrap(),
            flat_out.data().unwrap(),
            1e-6,
        );
    }

    // ----- Linear: forward values -----

    #[test]
    fn test_forward_correctness_no_bias() {
        let mut linear = Linear::<f32>::new(3, 2, false).unwrap();
        linear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();

        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let y = linear.forward(&x).unwrap();

        assert_eq!(y.shape(), &[2, 2]);
        assert_close(y.data().unwrap(), &[1.0, 2.0, 4.0, 5.0], 1e-6);
    }

    #[test]
    fn test_forward_correctness_with_bias() {
        let mut linear = Linear::<f32>::new(2, 2, true).unwrap();
        linear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        *linear.bias.as_mut().unwrap() = Parameter::from_slice(&[10.0, 20.0], &[2]).unwrap();

        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let y = linear.forward(&x).unwrap();

        assert_close(y.data().unwrap(), &[11.0, 22.0, 13.0, 24.0], 1e-6);
    }

    // ----- Linear: backward -----

    #[test]
    fn test_backward_gradients_no_bias() {
        let mut linear = Linear::<f32>::new(2, 2, false).unwrap();
        linear.weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let x = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2], true);
        let y = linear.forward(&x).unwrap();

        let total = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        total.backward().unwrap();

        let x_grad = x.grad().unwrap().expect("input should have grad");
        assert_eq!(x_grad.shape(), &[2, 2]);
        assert_close(x_grad.data().unwrap(), &[4.0, 6.0, 4.0, 6.0], 1e-5);
    }

    #[test]
    fn test_backward_weight_grad() {
        let mut linear = Linear::<f32>::new(2, 2, false).unwrap();
        linear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();

        let x = leaf(&[2.0, 3.0], &[1, 2], false);
        let y = linear.forward(&x).unwrap();
        let total = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        total.backward().unwrap();

        let weight_grad = linear
            .weight
            .grad()
            .unwrap()
            .expect("weight should have grad");
        assert_eq!(weight_grad.shape(), &[2, 2]);
        assert_close(weight_grad.data().unwrap(), &[2.0, 3.0, 2.0, 3.0], 1e-5);
    }

    #[test]
    fn test_backward_numerical_gradient() {
        let eps = 1e-4f32;

        let mut linear = Linear::<f32>::new(2, 2, false).unwrap();
        linear.weight = Parameter::from_slice(&[0.5, -0.3, 0.2, 0.8], &[2, 2]).unwrap();

        let input_data = [1.0f32, 2.0, 3.0, 4.0];
        let x = leaf(&input_data, &[2, 2], false);

        let y = linear.forward(&x).unwrap();
        let total = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        total.backward().unwrap();

        let analytic_grad = linear.weight.grad().unwrap().unwrap();
        let analytic = analytic_grad.data().unwrap().to_vec();

        // Evaluates sum(forward(input)) for a fresh layer carrying weights `w`,
        // with autograd recording disabled.
        let loss_for = |w: [f32; 4]| -> f32 {
            let mut probe = Linear::<f32>::new(2, 2, false).unwrap();
            probe.weight = Parameter::from_slice(&w, &[2, 2]).unwrap();
            let probe_input = leaf(&input_data, &[2, 2], false);
            let summed = ferrotorch_core::no_grad(|| {
                let o = probe.forward(&probe_input).unwrap();
                ferrotorch_core::grad_fns::reduction::sum(&o).unwrap()
            });
            summed.item().unwrap()
        };

        let base_weight = [0.5f32, -0.3, 0.2, 0.8];
        for idx in 0..4 {
            let mut w_plus = base_weight;
            w_plus[idx] += eps;
            let loss_plus = loss_for(w_plus);

            let mut w_minus = base_weight;
            w_minus[idx] -= eps;
            let loss_minus = loss_for(w_minus);

            // Central finite difference against the autograd result.
            let numerical = (loss_plus - loss_minus) / (2.0 * eps);
            assert!(
                (numerical - analytic[idx]).abs() < 1e-2,
                "weight[{idx}]: numerical={numerical}, analytic={}, diff={}",
                analytic[idx],
                (numerical - analytic[idx]).abs()
            );
        }
    }

    // ----- Linear: parameter bookkeeping -----

    #[test]
    fn test_parameter_count_with_bias() {
        let linear = Linear::<f32>::new(10, 5, true).unwrap();
        let params = linear.parameters();
        assert_eq!(params.len(), 2);
        // 10 * 5 weight entries + 5 bias entries.
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 55);
    }

    #[test]
    fn test_parameter_count_without_bias() {
        let linear = Linear::<f32>::new(10, 5, false).unwrap();
        let params = linear.parameters();
        assert_eq!(params.len(), 1);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 50);
    }

    // ----- Linear: state dict -----

    #[test]
    fn test_state_dict_roundtrip_with_bias() {
        let source = Linear::<f32>::new(4, 3, true).unwrap();
        let sd = source.state_dict();
        assert!(sd.contains_key("weight"));
        assert!(sd.contains_key("bias"));
        assert_eq!(sd["weight"].shape(), &[3, 4]);
        assert_eq!(sd["bias"].shape(), &[3]);

        let mut target = Linear::<f32>::new(4, 3, true).unwrap();
        target.load_state_dict(&sd, true).unwrap();

        assert_close(
            target.weight.data().unwrap(),
            source.weight.data().unwrap(),
            1e-7,
        );
        assert_close(
            target.bias.as_ref().unwrap().data().unwrap(),
            source.bias.as_ref().unwrap().data().unwrap(),
            1e-7,
        );
    }

    #[test]
    fn test_state_dict_roundtrip_without_bias() {
        let source = Linear::<f32>::new(6, 2, false).unwrap();
        let sd = source.state_dict();
        assert!(sd.contains_key("weight"));
        assert!(!sd.contains_key("bias"));

        let mut target = Linear::<f32>::new(6, 2, false).unwrap();
        target.load_state_dict(&sd, true).unwrap();

        assert_close(
            target.weight.data().unwrap(),
            source.weight.data().unwrap(),
            1e-7,
        );
    }

    #[test]
    fn test_state_dict_shape_mismatch_rejected() {
        let source = Linear::<f32>::new(4, 3, true).unwrap();
        let sd = source.state_dict();

        let mut target = Linear::<f32>::new(4, 5, true).unwrap();
        assert!(target.load_state_dict(&sd, true).is_err());
    }

    // ----- Linear: named parameters -----

    #[test]
    fn test_named_parameters_with_bias() {
        let linear = Linear::<f32>::new(3, 2, true).unwrap();
        let named = linear.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");
    }

    #[test]
    fn test_named_parameters_without_bias() {
        let linear = Linear::<f32>::new(3, 2, false).unwrap();
        let named = linear.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    // ----- Linear: train / eval -----

    #[test]
    fn test_train_eval() {
        let mut linear = Linear::<f32>::new(4, 3, true).unwrap();
        assert!(linear.is_training());
        linear.eval();
        assert!(!linear.is_training());
        linear.train();
        assert!(linear.is_training());
    }

    // ----- Linear: Display -----

    #[test]
    fn test_display() {
        let linear = Linear::<f32>::new(10, 5, true).unwrap();
        let rendered = format!("{linear}");
        assert_eq!(rendered, "Linear(in_features=10, out_features=5, bias=true)");
    }

    #[test]
    fn test_display_no_bias() {
        let linear = Linear::<f32>::new(10, 5, false).unwrap();
        let rendered = format!("{linear}");
        assert_eq!(
            rendered,
            "Linear(in_features=10, out_features=5, bias=false)"
        );
    }

    // ----- Linear: auto traits -----

    #[test]
    fn test_linear_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Linear<f32>>();
        assert_send_sync::<Linear<f64>>();
    }

    // ----- Linear: bias initialization -----

    #[test]
    fn test_linear_bias_init_bounded_uniform() {
        let in_features = 64usize;
        let out_features = 128usize;
        let linear = Linear::<f32>::new(in_features, out_features, true).unwrap();
        let bias = linear.bias.as_ref().expect("bias requested");
        let bias_data = bias.tensor().data_vec().unwrap();
        let bound = 1.0_f32 / (in_features as f32).sqrt();

        // Every entry must lie inside [-bound, bound] (with a small slack).
        for &b in &bias_data {
            assert!(
                b.abs() <= bound + 1e-6,
                "bias element {b} exceeds bound {bound}"
            );
        }

        // And the bias must not be (mostly) zeros.
        let nonzero = bias_data.iter().filter(|&&b| b != 0.0).count();
        assert!(
            nonzero > out_features / 2,
            "expected most bias entries to be nonzero (got {nonzero}/{out_features}); \
             would FAIL pre-fix when bias was zeros_init"
        );
    }

    // ----- Linear: device transfer -----

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let weight_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let bias_values = [0.1, 0.2, 0.3];

        let mut linear = Linear::<f32>::new(4, 3, true).unwrap();
        linear.weight = Parameter::from_slice(&weight_values, &[3, 4]).unwrap();
        *linear.bias.as_mut().unwrap() = Parameter::from_slice(&bias_values, &[3]).unwrap();

        linear.to_device(ferrotorch_core::Device::Cpu).unwrap();

        assert_eq!(linear.weight.shape(), &[3, 4]);
        assert_close(linear.weight.data().unwrap(), &weight_values, 1e-7);
        assert_close(
            linear.bias.as_ref().unwrap().data().unwrap(),
            &bias_values,
            1e-7,
        );
        assert!(linear.weight.requires_grad());
        assert!(linear.bias.as_ref().unwrap().requires_grad());
    }

    #[test]
    fn test_to_device_cuda_returns_device_unavailable() {
        let mut linear = Linear::<f32>::new(4, 3, true).unwrap();
        let moved = linear.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(moved.is_err());
    }

    // ----- Bilinear -----

    /// Bilinear layer (in1 = 3, in2 = 2, out = 2, with bias) carrying fixed
    /// weight and bias values shared by the tests below.
    fn bilinear_3d_layer() -> Bilinear<f32> {
        let mut bilinear = Bilinear::<f32>::new(3, 2, 2, true).unwrap();
        bilinear.weight = Parameter::from_slice(
            &[
                0.1, 0.2, 0.3, -0.1, -0.2, 0.05, 0.0, 0.4, -0.3, 0.2, 0.1, -0.15,
            ],
            &[2, 3, 2],
        )
        .unwrap();
        *bilinear.bias.as_mut().unwrap() = Parameter::from_slice(&[0.5, -0.25], &[2]).unwrap();
        bilinear
    }

    #[test]
    fn test_bilinear_3d_forward_matches_torch() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(
            &[
                1.0, 2.0, 3.0, -1.0, 0.5, 2.0, 0.0, 1.0, -1.0, 2.0, -2.0, 1.0,
            ],
            &[2, 2, 3],
            false,
        );
        let x2 = leaf(
            &[1.0, -1.0, 0.5, 2.0, -1.0, 1.0, 3.0, 0.0],
            &[2, 2, 2],
            false,
        );
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[2, 2, 2]);
        assert_close(
            out.data().unwrap(),
            &[0.45, -0.9, 0.025, -1.425, -0.15, 0.5, -1.3, 1.85],
            1e-5,
        );
    }

    #[test]
    fn test_bilinear_3d_backward_matches_torch() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(
            &[
                1.0, 2.0, 3.0, -1.0, 0.5, 2.0, 0.0, 1.0, -1.0, 2.0, -2.0, 1.0,
            ],
            &[2, 2, 3],
            true,
        );
        let x2 = leaf(
            &[1.0, -1.0, 0.5, 2.0, -1.0, 1.0, 3.0, 0.0],
            &[2, 2, 2],
            true,
        );
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        let total = ferrotorch_core::grad_fns::reduction::sum(&out).unwrap();
        total.backward().unwrap();

        let grad_x1 = x1.grad().unwrap().expect("x1 should have grad");
        assert_eq!(grad_x1.shape(), &[2, 2, 3]);
        assert_close(
            grad_x1.data().unwrap(),
            &[
                -0.5, -0.1, 0.0, 1.25, 0.2, -0.25, 0.5, 0.1, 0.0, 0.3, 0.0, -0.3,
            ],
            1e-5,
        );

        let grad_x2 = x2.grad().unwrap().expect("x2 should have grad");
        assert_eq!(grad_x2.shape(), &[2, 2, 2]);
        assert_close(
            grad_x2.data().unwrap(),
            &[-0.2, 0.5, -0.3, -0.75, 0.1, 0.2, 0.1, 0.9],
            1e-5,
        );

        let grad_w = bilinear.weight.grad().unwrap().expect("W should have grad");
        assert_eq!(grad_w.shape(), &[2, 3, 2]);
        assert_close(
            grad_w.data().unwrap(),
            &[
                6.5, -3.0, -4.75, 0.0, 8.0, 0.0, 6.5, -3.0, -4.75, 0.0, 8.0, 0.0,
            ],
            1e-5,
        );

        let grad_b = bilinear
            .bias
            .as_ref()
            .unwrap()
            .grad()
            .unwrap()
            .expect("bias should have grad");
        assert_eq!(grad_b.shape(), &[2]);
        assert_close(grad_b.data().unwrap(), &[4.0, 4.0], 1e-5);
    }

    #[test]
    fn test_bilinear_4d_forward_matches_torch() {
        let mut bilinear = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
        // Identity weight: the output reduces to the dot product of x1 and x2.
        bilinear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
        let x1 = leaf(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[2, 1, 2, 2],
            false,
        );
        let x2 = leaf(
            &[1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
            &[2, 1, 2, 2],
            false,
        );
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[2, 1, 2, 1]);
        assert_close(out.data().unwrap(), &[3.0, 7.0, 22.0, 30.0], 1e-5);
    }

    #[test]
    fn test_bilinear_2d_still_matches_torch() {
        let mut bilinear = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
        bilinear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
        let x1 = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let x2 = leaf(&[1.0, 1.0, 1.0, 1.0], &[2, 2], false);
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[2, 1]);
        assert_close(out.data().unwrap(), &[3.0, 7.0], 1e-5);
    }

    #[test]
    fn test_bilinear_1d_still_matches_torch() {
        let mut bilinear = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
        bilinear.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
        let x1 = leaf(&[2.0, 3.0], &[2], false);
        let x2 = leaf(&[1.0, 1.0], &[2], false);
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[1]);
        assert_close(out.data().unwrap(), &[5.0], 1e-5);
    }

    // Zero-sized leading dimensions must flow through and yield empty outputs.

    #[test]
    fn test_bilinear_empty_leading_dim_2d() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(&[], &[0, 3], false);
        let x2 = leaf(&[], &[0, 2], false);
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[0, 2]);
        assert_eq!(out.numel(), 0);
    }

    #[test]
    fn test_bilinear_empty_leading_dim_3d() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(&[], &[0, 4, 3], false);
        let x2 = leaf(&[], &[0, 4, 2], false);
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[0, 4, 2]);
        assert_eq!(out.numel(), 0);
    }

    #[test]
    fn test_bilinear_zero_middle_dim_3d() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(&[], &[2, 0, 3], false);
        let x2 = leaf(&[], &[2, 0, 2], false);
        let out = bilinear.forward_pair(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[2, 0, 2]);
        assert_eq!(out.numel(), 0);
    }

    // Shape validation failures.

    #[test]
    fn test_bilinear_mismatched_ndim_rejected() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(&[0.0; 2 * 2 * 3], &[2, 2, 3], false);
        let x2 = leaf(&[0.0; 2 * 2], &[2, 2], false);
        assert!(bilinear.forward_pair(&x1, &x2).is_err());
    }

    #[test]
    fn test_bilinear_mismatched_leading_dim_rejected() {
        let bilinear = bilinear_3d_layer();
        let x1 = leaf(&[0.0; 2 * 3 * 3], &[2, 3, 3], false);
        let x2 = leaf(&[0.0; 2 * 4 * 2], &[2, 4, 2], false);
        assert!(bilinear.forward_pair(&x1, &x2).is_err());
    }

    #[test]
    fn test_bilinear_wrong_feature_dim_rejected() {
        let bilinear = bilinear_3d_layer();
        // x1 carries 4 trailing features while the layer expects 3.
        let bad_x1 = leaf(&[0.0; 2 * 2 * 4], &[2, 2, 4], false);
        let x2 = leaf(&[0.0; 2 * 2 * 2], &[2, 2, 2], false);
        assert!(bilinear.forward_pair(&bad_x1, &x2).is_err());
    }
}