Skip to main content

ferrotorch_nn/
linear.rs

//! Fully connected (dense) linear layer: `y = input @ weight^T + bias`,
//! plus the two-input `Bilinear` layer (`y = x1^T W x2 + b`).
//!
//! `Linear` is the fundamental building block for feedforward networks. Its
//! weight matrix has shape `[out_features, in_features]` (same convention
//! as PyTorch) and the optional bias has shape `[out_features]`.
6//!
7//! # Autograd
8//!
//! `Linear::forward` is a call to the differentiable `linear_fused` op (with
//! a `reshape` before and after it for non-2-D inputs), and
//! `Bilinear::forward_pair` composes differentiable `reshape`,
//! `einsum_differentiable`, and `add` calls. The backward graph is therefore
//! constructed automatically: gradients for the weight, the bias, and the
//! input flow back through the backward nodes recorded by those ops.
16//!
17//! ## REQ status (per `.design/ferrotorch-nn/linear.md`)
18//!
19//! | REQ | Status | Evidence |
20//! |---|---|---|
21//! | REQ-1 | SHIPPED | impl: `pub struct Linear<T: Float>` here, mirroring `torch/nn/modules/linear.py:91-115`; non-test consumer: `pub use linear::Linear` in `lib.rs` exposes the type to `ferrotorch_llama::mlp::FeedForward::gate_proj` and similar fields. |
22//! | REQ-2 | SHIPPED | impl: the `Linear::new` constructor here, mirroring `linear.py:96-115`; non-test consumer: `Linear::new(cfg.hidden_size, cfg.intermediate_size, false)?` in `ferrotorch-llama/src/mlp.rs`. |
23//! | REQ-3 | SHIPPED | impl: shape flatten/reshape pre/post `linear_fused` inside `<Linear as Module>::forward` here, mirroring `linear.py:67-70`; non-test consumer: transformer blocks in `ferrotorch-nn/src/transformer.rs` and `ferrotorch-llama/src/attention.rs` feed 3-D `[B, T, H]` tensors through `Linear::forward` for QKV projection. |
24//! | REQ-4 | SHIPPED | impl: the `linear_fused(&input_2d, weight.tensor(), bias_opt)` call inside `<Linear as Module>::forward` mirroring `linear.py:130-134`'s `F.linear`; non-test consumer: every model in `ferrotorch-vision/src/models/` invokes `Linear::forward` through its classifier head. |
25//! | REQ-5 | SHIPPED | impl: `kaiming_uniform(&mut weight, NonLinearity::ReLU)` call inside `Linear::new` here; non-test consumer: `Linear::new` is the construction path used by every consumer above. NOTE: gain divergence from upstream `linear.py:124`. |
26//! | REQ-6 | SHIPPED | impl: `crate::init::uniform(&mut b, -bound, bound)?` with `bound = 1/sqrt(in_features)` call inside `Linear::new` here mirroring `torch/nn/modules/linear.py:124-128`; non-test consumer: same as REQ-5. |
27//! | REQ-7 | SHIPPED | impl: `impl<T: Float> Module<T> for Linear<T>` block here providing `forward`/`parameters`/`parameters_mut`/`named_parameters`/`train`/`eval`/`is_training`; non-test consumer: `ferrotorch_optim::Optimizer` consumes `Module::parameters_mut()` to apply updates. |
28//! | REQ-8 | SHIPPED | impl: `impl<T: Float> Display for Linear<T>` block here matching upstream `linear.py:136-140`'s `extra_repr`; non-test consumer: `format!("{layer}")` in model summary printing (e.g. `ferrotorch_train` learner emits module displays in logs). |
29//! | REQ-9 | SHIPPED | `Linear` carries only `Parameter<T>` fields which are `Send + Sync`; verified at compile time via `assert_send_sync::<Linear<f32>>()` in tests; non-test consumer: any multi-threaded `DataParallel`-style training scaffolding in `ferrotorch-train` requires `Send + Sync`. |
30//! | REQ-10 | SHIPPED | impl: `last_dim != self.in_features` guard inside `<Linear as Module>::forward` here; non-test consumer: every production caller is shielded from silent shape mismatches by this guard. |
31//! | REQ-11 | SHIPPED | impl: `pub struct Bilinear<T: Float>` here with `weight` `[out, in1, in2]` + optional `bias` `[out]`. `forward_pair` accepts arbitrary leading batch dims `(*, in)` -> `(*, out)`: flattens all-but-last to `[N, in]` (explicit batch product, handles `N == 0`), runs two `einsum_differentiable` contractions (`"bi,oij->boj"` then `"boj,bj->bo"`) + bias broadcast, then reshapes back to `(*, out)`, mirroring `torch/nn/modules/linear.py:162-256` + `aten/src/ATen/native/Linear.cpp:792-802`; non-test consumer: `pub use linear::Bilinear` in `lib.rs` re-export so downstream model crates (e.g. attention-fusion and FiLM-style conditioning) can construct it directly. Closes #1442, #1603. |
32//! | REQ-12 | NOT-STARTED | blocker #1441 — parity-sweep runner has no arm for `nn.functional.linear`; sweep reports `0/144 passed, 144 skipped`. The forward path itself is end-to-end verified by 22 lib tests; only the runner-arm wiring is missing. |
33
34use ferrotorch_core::grad_fns::linalg::linear_fused;
35use ferrotorch_core::grad_fns::shape::reshape;
36use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
37
38use crate::init::{NonLinearity, kaiming_uniform};
39use crate::module::Module;
40use crate::parameter::Parameter;
41
42/// A fully connected (dense) linear layer.
43///
44/// Applies the transformation `y = x @ W^T + b` where `W` has shape
45/// `[out_features, in_features]` and `b` (if present) has shape
46/// `[out_features]`.
47///
48/// # Initialization
49///
50/// - **Weight**: Kaiming uniform with `gain = sqrt(2)` (ReLU). This is
51///   the PyTorch default for `nn.Linear`.
52/// - **Bias**: Uniform `U(-bound, bound)` with `bound = 1/sqrt(in_features)`,
53///   mirroring `torch/nn/modules/linear.py:124-128`.
54///
55/// # Examples
56///
57/// ```ignore
58/// let layer = Linear::<f32>::new(784, 256, true)?;
59/// let output = layer.forward(&input)?; // input: [batch, 784] -> output: [batch, 256]
60/// ```
61#[derive(Debug)]
62pub struct Linear<T: Float> {
63    /// Weight matrix of shape `[out_features, in_features]`.
64    pub weight: Parameter<T>,
65    /// Optional bias vector of shape `[out_features]`.
66    pub bias: Option<Parameter<T>>,
67    /// Number of input features.
68    in_features: usize,
69    /// Number of output features.
70    out_features: usize,
71    /// Whether the module is in training mode.
72    training: bool,
73}
74
75impl<T: Float> Linear<T> {
76    /// Create a new linear layer.
77    ///
78    /// # Arguments
79    ///
80    /// - `in_features` — Size of each input sample.
81    /// - `out_features` — Size of each output sample.
82    /// - `bias` — If `true`, adds a learnable bias to the output.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if `in_features` or `out_features` is zero, or if
87    /// parameter allocation fails.
88    pub fn new(in_features: usize, out_features: usize, bias: bool) -> FerrotorchResult<Self> {
89        if in_features == 0 {
90            return Err(FerrotorchError::InvalidArgument {
91                message: "Linear: in_features must be > 0".into(),
92            });
93        }
94        if out_features == 0 {
95            return Err(FerrotorchError::InvalidArgument {
96                message: "Linear: out_features must be > 0".into(),
97            });
98        }
99
100        // Initialize weight with Kaiming uniform (fan_in mode, ReLU gain).
101        let mut weight = Parameter::zeros(&[out_features, in_features])?;
102        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
103
104        // Initialize bias U(-bound, bound) with bound = 1/sqrt(fan_in),
105        // fan_in = in_features. Mirrors `torch/nn/modules/linear.py:124-128`:
106        //   `fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)`
107        //   `bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0`
108        //   `init.uniform_(self.bias, -bound, bound)`
109        let bias_param = if bias {
110            let mut b = Parameter::zeros(&[out_features])?;
111            let bound = if in_features > 0 {
112                1.0 / (in_features as f64).sqrt()
113            } else {
114                0.0
115            };
116            crate::init::uniform(&mut b, -bound, bound)?;
117            Some(b)
118        } else {
119            None
120        };
121
122        Ok(Self {
123            weight,
124            bias: bias_param,
125            in_features,
126            out_features,
127            training: true,
128        })
129    }
130
131    /// Number of input features.
132    #[inline]
133    pub fn in_features(&self) -> usize {
134        self.in_features
135    }
136
137    /// Number of output features.
138    #[inline]
139    pub fn out_features(&self) -> usize {
140        self.out_features
141    }
142}
143
144impl<T: Float> Module<T> for Linear<T> {
145    /// Forward pass: `y = input @ weight^T + bias`.
146    ///
147    /// # Input shape
148    ///
149    /// Accepts any input with shape `(*batch, in_features)`:
150    /// - 1D: `[in_features]` — single sample, no batch dim.
151    /// - 2D: `[batch, in_features]` — standard batched forward.
152    /// - 3D: `[batch, seq_len, in_features]` — e.g. transformer inputs.
153    /// - ND: `[d0, d1, ..., in_features]` — arbitrary leading dimensions.
154    ///
155    /// # Output shape
156    ///
157    /// - `(*batch, out_features)` — same leading dimensions as input.
158    ///
159    /// # Autograd
160    ///
161    /// When gradient tracking is enabled, the returned tensor participates
162    /// in the computation graph through the composed differentiable
163    /// operations (`mm_differentiable` + `add` + `reshape`). Calling
164    /// `.backward()` on a downstream scalar loss will propagate gradients
165    /// to `weight` and `bias` automatically.
166    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
167        if input.ndim() == 0 {
168            return Err(FerrotorchError::ShapeMismatch {
169                message: "Linear: scalar input not supported".into(),
170            });
171        }
172
173        // Validate the last dimension is in_features.
174        let last_dim = input.shape()[input.ndim() - 1];
175        if last_dim != self.in_features {
176            return Err(FerrotorchError::ShapeMismatch {
177                message: format!(
178                    "Linear: input has {} features but layer expects {}",
179                    last_dim, self.in_features
180                ),
181            });
182        }
183
184        // For inputs with ndim != 2, flatten leading dims to get [N, in_features],
185        // apply the fused linear transform, then reshape back to (*batch, out_features).
186        let input_shape = input.shape().to_vec();
187        let batch_shape = &input_shape[..input_shape.len() - 1];
188        let n: usize = batch_shape.iter().product::<usize>().max(1);
189        let needs_reshape = input.ndim() != 2;
190
191        let input_2d = if needs_reshape {
192            reshape(input, &[n as isize, self.in_features as isize])?
193        } else {
194            input.clone()
195        };
196
197        // Fused linear: input @ weight^T + bias in a single operation.
198        let output_2d = linear_fused(
199            &input_2d,
200            self.weight.tensor(),
201            self.bias.as_ref().map(|b| b.tensor()),
202        )?;
203
204        // Reshape back to (*batch, out_features).
205        if needs_reshape {
206            let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
207            out_shape.push(self.out_features as isize);
208            reshape(&output_2d, &out_shape)
209        } else {
210            Ok(output_2d)
211        }
212    }
213
214    fn parameters(&self) -> Vec<&Parameter<T>> {
215        let mut params = vec![&self.weight];
216        if let Some(ref b) = self.bias {
217            params.push(b);
218        }
219        params
220    }
221
222    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
223        let mut params = vec![&mut self.weight];
224        if let Some(ref mut b) = self.bias {
225            params.push(b);
226        }
227        params
228    }
229
230    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
231        let mut params = vec![("weight".to_string(), &self.weight)];
232        if let Some(ref b) = self.bias {
233            params.push(("bias".to_string(), b));
234        }
235        params
236    }
237
238    fn train(&mut self) {
239        self.training = true;
240    }
241
242    fn eval(&mut self) {
243        self.training = false;
244    }
245
246    fn is_training(&self) -> bool {
247        self.training
248    }
249}
250
251// ---------------------------------------------------------------------------
252// Display
253// ---------------------------------------------------------------------------
254
255impl<T: Float> std::fmt::Display for Linear<T> {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        write!(
258            f,
259            "Linear(in_features={}, out_features={}, bias={})",
260            self.in_features,
261            self.out_features,
262            self.bias.is_some()
263        )
264    }
265}
266
267// ---------------------------------------------------------------------------
268// Bilinear — closes #1442
269// ---------------------------------------------------------------------------
270
271/// Bilinear layer: `y = x1^T @ W @ x2 + b`.
272///
273/// Applies a learnable bilinear transformation to two input vectors,
274/// mirroring `torch.nn.Bilinear` (`torch/nn/modules/linear.py:162-260`).
275/// The weight tensor has shape `[out_features, in1_features, in2_features]`
276/// and bias (if present) has shape `[out_features]`. For a 2-D batched input
277/// pair `(x1, x2)` of shape `[B, in1]` and `[B, in2]`, the output has shape
278/// `[B, out]`:
279///
280/// ```text
281/// y[b, o] = sum_i sum_j x1[b, i] * W[o, i, j] * x2[b, j]  + b[o]
282/// ```
283///
284/// # Initialization
285///
286/// - **Weight**: `U(-bound, bound)` with `bound = 1/sqrt(in1_features)`,
287///   matching `torch/nn/modules/linear.py:191-194`.
288/// - **Bias**: `U(-bound, bound)` with the same bound.
289#[derive(Debug)]
290pub struct Bilinear<T: Float> {
291    /// Weight tensor of shape `[out_features, in1_features, in2_features]`.
292    pub weight: Parameter<T>,
293    /// Optional bias of shape `[out_features]`.
294    pub bias: Option<Parameter<T>>,
295    in1_features: usize,
296    in2_features: usize,
297    out_features: usize,
298    training: bool,
299}
300
301impl<T: Float> Bilinear<T> {
302    /// Create a new bilinear layer.
303    ///
304    /// # Arguments
305    ///
306    /// - `in1_features` — size of each `x1` sample.
307    /// - `in2_features` — size of each `x2` sample.
308    /// - `out_features` — size of the output sample.
309    /// - `bias` — if `true`, adds a learnable bias.
310    ///
311    /// # Errors
312    ///
313    /// Returns an error if any feature count is zero, or allocation fails.
314    pub fn new(
315        in1_features: usize,
316        in2_features: usize,
317        out_features: usize,
318        bias: bool,
319    ) -> FerrotorchResult<Self> {
320        if in1_features == 0 || in2_features == 0 || out_features == 0 {
321            return Err(FerrotorchError::InvalidArgument {
322                message: format!(
323                    "Bilinear: in1/in2/out_features must all be > 0, got ({in1_features}, {in2_features}, {out_features})"
324                ),
325            });
326        }
327
328        // bound = 1/sqrt(in1_features) per `torch/nn/modules/linear.py:191-194`.
329        let bound = if in1_features > 0 {
330            1.0 / (in1_features as f64).sqrt()
331        } else {
332            0.0
333        };
334
335        let mut weight = Parameter::zeros(&[out_features, in1_features, in2_features])?;
336        crate::init::uniform(&mut weight, -bound, bound)?;
337
338        let bias_param = if bias {
339            let mut b = Parameter::zeros(&[out_features])?;
340            crate::init::uniform(&mut b, -bound, bound)?;
341            Some(b)
342        } else {
343            None
344        };
345
346        Ok(Self {
347            weight,
348            bias: bias_param,
349            in1_features,
350            in2_features,
351            out_features,
352            training: true,
353        })
354    }
355
356    /// Number of features in the first input.
357    #[inline]
358    pub fn in1_features(&self) -> usize {
359        self.in1_features
360    }
361
362    /// Number of features in the second input.
363    #[inline]
364    pub fn in2_features(&self) -> usize {
365        self.in2_features
366    }
367
368    /// Number of features in the output.
369    #[inline]
370    pub fn out_features(&self) -> usize {
371        self.out_features
372    }
373
374    /// Bilinear forward pass: `y = x1 W x2 + b`.
375    ///
376    /// Accepts arbitrary leading batch dims, mirroring `torch.nn.Bilinear`'s
377    /// `(*, H_in)` shape contract (`torch/nn/modules/linear.py:172-178`):
378    ///
379    /// - `x1`: `(*, in1_features)`, `x2`: `(*, in2_features)` where `*` is
380    ///   any number of additional dimensions (including none, i.e. 1-D).
381    /// - Both inputs must share the **same** leading shape `*`.
382    /// - Returns `(*, out_features)`.
383    ///
384    /// The contraction is `y[*, o] = sum_{i,j} x1[*, i] * W[o, i, j] *
385    /// x2[*, j] + b[o]`. Following the upstream ATen implementation
386    /// (`aten/src/ATen/native/Linear.cpp:792-802`), the leading dims are
387    /// flattened into a single batch axis `N`, the bilinear contraction
388    /// runs on `[N, in]`, and the output `[N, out]` is reshaped back to
389    /// `(*, out_features)`. A zero-size leading dim (`N == 0`) yields the
390    /// correctly-shaped empty output.
391    pub fn forward_pair(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
392        // Both inputs must have the same rank, and at least 1-D (the last
393        // axis is the feature axis). Mirrors `Linear.cpp:777` (`input1.dim()
394        // == input2.dim()`).
395        if x1.ndim() == 0 || x2.ndim() == 0 {
396            return Err(FerrotorchError::ShapeMismatch {
397                message: "Bilinear: scalar (0-D) inputs not supported; expected (*, features)"
398                    .into(),
399            });
400        }
401        if x1.ndim() != x2.ndim() {
402            return Err(FerrotorchError::ShapeMismatch {
403                message: format!(
404                    "Bilinear: input dimensions do not match: got {} and {}",
405                    x1.ndim(),
406                    x2.ndim(),
407                ),
408            });
409        }
410
411        let x1_shape = x1.shape().to_vec();
412        let x2_shape = x2.shape().to_vec();
413
414        // All but the last dimension (the leading shape `*`) must match.
415        // Mirrors `Linear.cpp:778-781` (per-dim batch-shape equality).
416        let lead_len = x1_shape.len() - 1;
417        for d in 0..lead_len {
418            if x1_shape[d] != x2_shape[d] {
419                return Err(FerrotorchError::ShapeMismatch {
420                    message: format!(
421                        "Bilinear: input batch dimensions do not match at dim {}: got {} and {}",
422                        d, x1_shape[d], x2_shape[d],
423                    ),
424                });
425            }
426        }
427
428        // Feature-axis (last dim) checks. Mirrors `Linear.cpp:782-787`.
429        if x1_shape[lead_len] != self.in1_features {
430            return Err(FerrotorchError::ShapeMismatch {
431                message: format!(
432                    "Bilinear: x1 last dim {} != in1_features {}",
433                    x1_shape[lead_len], self.in1_features,
434                ),
435            });
436        }
437        if x2_shape[lead_len] != self.in2_features {
438            return Err(FerrotorchError::ShapeMismatch {
439                message: format!(
440                    "Bilinear: x2 last dim {} != in2_features {}",
441                    x2_shape[lead_len], self.in2_features,
442                ),
443            });
444        }
445
446        // Flatten the leading `*` dims into a single batch axis `N`.
447        // `N` is the explicit product of the leading dims (NOT `-1`), so a
448        // zero-size leading dim flattens to `N == 0` correctly — the einsum
449        // empty-output path (`einsum.rs`, #1605) then returns the right
450        // empty shape rather than panicking. Mirrors `Linear.cpp:796-797`
451        // (`input1.reshape({-1, input1.size(-1)})`).
452        let batch_shape = &x1_shape[..lead_len];
453        let n: usize = batch_shape.iter().product();
454        let x1_2d = ferrotorch_core::grad_fns::shape::reshape(
455            x1,
456            &[n as isize, self.in1_features as isize],
457        )?;
458        let x2_2d = ferrotorch_core::grad_fns::shape::reshape(
459            x2,
460            &[n as isize, self.in2_features as isize],
461        )?;
462
463        // y = einsum("bi,oij,bj->bo", x1, W, x2). Decompose via two
464        // 2-tensor einsums (the workspace einsum primitive supports up to
465        // two operands per call): first contract `i` to get
466        // `boj = sum_i x1[b,i] * W[o,i,j]`, then contract `j` with x2 to
467        // get `bo = sum_j boj * x2[b,j]`.
468        let boj = ferrotorch_core::einsum::einsum_differentiable(
469            "bi,oij->boj",
470            &[&x1_2d, self.weight.tensor()],
471        )?;
472        let bo = ferrotorch_core::einsum::einsum_differentiable("boj,bj->bo", &[&boj, &x2_2d])?;
473
474        // Add bias (broadcast `[out]` over `[N, out]`). Upstream adds the
475        // bias AFTER the reshape-back (`Linear.cpp:799-801`); broadcasting
476        // `[out]` over the flattened `[N, out]` is equivalent and keeps the
477        // add in the 2-D regime the einsum primitive already produced.
478        let out_2d = if let Some(ref bias) = self.bias {
479            let bias_2d = ferrotorch_core::grad_fns::shape::reshape(
480                bias.tensor(),
481                &[1, self.out_features as isize],
482            )?;
483            ferrotorch_core::grad_fns::arithmetic::add(&bo, &bias_2d)?
484        } else {
485            bo
486        };
487
488        // Reshape the output's batch axis back to the original leading
489        // shape `(*, out_features)`. Mirrors `Linear.cpp:792-798`
490        // (`output_size = size1[:-1] + [weight.size(0)]`).
491        let mut out_shape: Vec<isize> = batch_shape.iter().map(|&d| d as isize).collect();
492        out_shape.push(self.out_features as isize);
493        ferrotorch_core::grad_fns::shape::reshape(&out_2d, &out_shape)
494    }
495}
496
497impl<T: Float> Module<T> for Bilinear<T> {
498    /// `Module::forward` for `Bilinear` requires both inputs. The single-
499    /// tensor `Module` trait can't carry the second operand; use
500    /// [`Bilinear::forward_pair`] directly for the bilinear contraction.
501    /// Calling this `forward` returns an error to flag the misuse.
502    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
503        Err(FerrotorchError::InvalidArgument {
504            message: "Bilinear requires two inputs; call `forward_pair(x1, x2)` instead of \
505                      `Module::forward`."
506                .into(),
507        })
508    }
509
510    fn parameters(&self) -> Vec<&Parameter<T>> {
511        let mut params = vec![&self.weight];
512        if let Some(ref b) = self.bias {
513            params.push(b);
514        }
515        params
516    }
517
518    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
519        let mut params = vec![&mut self.weight];
520        if let Some(ref mut b) = self.bias {
521            params.push(b);
522        }
523        params
524    }
525
526    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
527        let mut params = vec![("weight".to_string(), &self.weight)];
528        if let Some(ref b) = self.bias {
529            params.push(("bias".to_string(), b));
530        }
531        params
532    }
533
534    fn train(&mut self) {
535        self.training = true;
536    }
537
538    fn eval(&mut self) {
539        self.training = false;
540    }
541
542    fn is_training(&self) -> bool {
543        self.training
544    }
545}
546
547impl<T: Float> std::fmt::Display for Bilinear<T> {
548    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549        write!(
550            f,
551            "Bilinear(in1_features={}, in2_features={}, out_features={}, bias={})",
552            self.in1_features,
553            self.in2_features,
554            self.out_features,
555            self.bias.is_some()
556        )
557    }
558}
559
560// ---------------------------------------------------------------------------
561// Tests
562// ---------------------------------------------------------------------------
563
564#[cfg(test)]
565mod tests {
566    use super::*;
567    use ferrotorch_core::{Tensor, TensorStorage};
568
569    /// Create a leaf tensor with given data and shape, optionally with grad.
570    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
571        Tensor::from_storage(
572            TensorStorage::cpu(data.to_vec()),
573            shape.to_vec(),
574            requires_grad,
575        )
576        .unwrap()
577    }
578
579    /// Assert two float slices are element-wise close.
580    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
581        assert_eq!(
582            actual.len(),
583            expected.len(),
584            "length mismatch: {} vs {}",
585            actual.len(),
586            expected.len()
587        );
588        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
589            assert!(
590                (a - e).abs() < tol,
591                "index {i}: actual={a} expected={e} diff={}",
592                (a - e).abs()
593            );
594        }
595    }
596
597    // -----------------------------------------------------------------------
598    // Construction
599    // -----------------------------------------------------------------------
600
601    #[test]
602    fn test_construction_with_bias() {
603        let layer = Linear::<f32>::new(10, 5, true).unwrap();
604        assert_eq!(layer.in_features(), 10);
605        assert_eq!(layer.out_features(), 5);
606        assert_eq!(layer.weight.shape(), &[5, 10]);
607        assert!(layer.bias.is_some());
608        assert_eq!(layer.bias.as_ref().unwrap().shape(), &[5]);
609    }
610
611    #[test]
612    fn test_construction_without_bias() {
613        let layer = Linear::<f32>::new(8, 4, false).unwrap();
614        assert_eq!(layer.weight.shape(), &[4, 8]);
615        assert!(layer.bias.is_none());
616    }
617
618    #[test]
619    fn test_construction_zero_in_features() {
620        assert!(Linear::<f32>::new(0, 5, true).is_err());
621    }
622
623    #[test]
624    fn test_construction_zero_out_features() {
625        assert!(Linear::<f32>::new(5, 0, true).is_err());
626    }
627
628    #[test]
629    fn test_weight_requires_grad() {
630        let layer = Linear::<f32>::new(4, 3, true).unwrap();
631        assert!(layer.weight.requires_grad());
632        assert!(layer.bias.as_ref().unwrap().requires_grad());
633    }
634
635    // -----------------------------------------------------------------------
636    // Forward shape
637    // -----------------------------------------------------------------------
638
639    #[test]
640    fn test_forward_shape() {
641        let layer = Linear::<f32>::new(4, 3, true).unwrap();
642        let input = leaf(&[0.0; 8], &[2, 4], false);
643        let output = layer.forward(&input).unwrap();
644        assert_eq!(output.shape(), &[2, 3]);
645    }
646
647    #[test]
648    fn test_forward_shape_no_bias() {
649        let layer = Linear::<f32>::new(6, 2, false).unwrap();
650        let input = leaf(&[0.0; 18], &[3, 6], false);
651        let output = layer.forward(&input).unwrap();
652        assert_eq!(output.shape(), &[3, 2]);
653    }
654
655    #[test]
656    fn test_forward_wrong_input_features() {
657        let layer = Linear::<f32>::new(4, 3, true).unwrap();
658        let input = leaf(&[0.0; 15], &[3, 5], false);
659        assert!(layer.forward(&input).is_err());
660    }
661
662    #[test]
663    fn test_forward_1d_input_accepted() {
664        // PyTorch accepts 1D input: (in_features,) -> (out_features,).
665        let mut layer = Linear::<f32>::new(3, 2, false).unwrap();
666        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
667        let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
668        let output = layer.forward(&input).unwrap();
669        assert_eq!(output.shape(), &[2]);
670        assert_close(output.data().unwrap(), &[1.0, 2.0], 1e-6);
671    }
672
673    // -----------------------------------------------------------------------
674    // Forward shape — multi-dimensional inputs
675    // -----------------------------------------------------------------------
676
677    #[test]
678    fn test_forward_3d_input_shape() {
679        // (batch, seq_len, in_features) -> (batch, seq_len, out_features)
680        let layer = Linear::<f32>::new(4, 3, true).unwrap();
681        let input = leaf(&[0.0; 2 * 5 * 4], &[2, 5, 4], false);
682        let output = layer.forward(&input).unwrap();
683        assert_eq!(output.shape(), &[2, 5, 3]);
684    }
685
686    #[test]
687    fn test_forward_4d_input_shape() {
688        // (batch, x, y, features) -> (batch, x, y, out_features)
689        let layer = Linear::<f32>::new(8, 4, false).unwrap();
690        let input = leaf(&[0.0; 2 * 3 * 4 * 8], &[2, 3, 4, 8], false);
691        let output = layer.forward(&input).unwrap();
692        assert_eq!(output.shape(), &[2, 3, 4, 4]);
693    }
694
695    #[test]
696    fn test_forward_3d_correctness() {
697        // Verify 3D gives same results as manually flattening to 2D.
698        let mut layer = Linear::<f32>::new(3, 2, false).unwrap();
699        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
700
701        // 3D input: (2, 2, 3)
702        let data = [
703            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
704        ];
705        let input_3d = leaf(&data, &[2, 2, 3], false);
706        let out_3d = layer.forward(&input_3d).unwrap();
707        assert_eq!(out_3d.shape(), &[2, 2, 2]);
708
709        // Equivalent 2D input.
710        let input_2d = leaf(&data, &[4, 3], false);
711        let out_2d = layer.forward(&input_2d).unwrap();
712        assert_eq!(out_2d.shape(), &[4, 2]);
713
714        // Data should be identical.
715        assert_close(out_3d.data().unwrap(), out_2d.data().unwrap(), 1e-6);
716    }
717
718    // -----------------------------------------------------------------------
719    // Forward correctness (manual weight/bias)
720    // -----------------------------------------------------------------------
721
722    #[test]
723    fn test_forward_correctness_no_bias() {
724        // Build a layer then manually set the weight.
725        let mut layer = Linear::<f32>::new(3, 2, false).unwrap();
726
727        // weight = [[1, 0, 0], [0, 1, 0]]  (2x3)
728        // This selects the first two features.
729        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
730
731        // input = [[1, 2, 3], [4, 5, 6]]  (2x3)
732        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
733        let output = layer.forward(&input).unwrap();
734
735        // output = input @ weight^T = [[1,2,3],[4,5,6]] @ [[1,0],[0,1],[0,0]]
736        //        = [[1, 2], [4, 5]]
737        assert_eq!(output.shape(), &[2, 2]);
738        assert_close(output.data().unwrap(), &[1.0, 2.0, 4.0, 5.0], 1e-6);
739    }
740
741    #[test]
742    fn test_forward_correctness_with_bias() {
743        let mut layer = Linear::<f32>::new(2, 2, true).unwrap();
744
745        // weight = [[1, 0], [0, 1]]  (identity)
746        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
747        // bias = [10, 20]
748        *layer.bias.as_mut().unwrap() = Parameter::from_slice(&[10.0, 20.0], &[2]).unwrap();
749
750        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
751        let output = layer.forward(&input).unwrap();
752
753        // output = input @ I + [10, 20] = [[11, 22], [13, 24]]
754        assert_close(output.data().unwrap(), &[11.0, 22.0, 13.0, 24.0], 1e-6);
755    }
756
757    // -----------------------------------------------------------------------
758    // Backward gradients
759    // -----------------------------------------------------------------------
760
761    #[test]
762    fn test_backward_gradients_no_bias() {
763        // Linear: y = input @ W^T, loss = sum(y)
764        // W = [[1, 2], [3, 4]]  (out=2, in=2)
765        // input = [[1, 0], [0, 1]]  (batch=2, in=2)
766        //
767        // W^T = [[1, 3], [2, 4]]
768        // y = input @ W^T = [[1, 3], [2, 4]]  shape [2, 2]
769        // loss = 1 + 3 + 2 + 4 = 10
770        //
771        // dL/dy = ones(2, 2)
772        // dL/d(input) = dL/dy @ W = [[1,1],[1,1]] @ [[1,2],[3,4]] = [[4,6],[4,6]]
773        // dL/d(W^T) = input^T @ dL/dy = [[1,0],[0,1]] @ [[1,1],[1,1]] = [[1,1],[1,1]]
774        // => dL/d(W) = [[1,1],[1,1]]^T = [[1,1],[1,1]]
775        let mut layer = Linear::<f32>::new(2, 2, false).unwrap();
776        layer.weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
777
778        let input = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2], true);
779        let output = layer.forward(&input).unwrap();
780
781        // Reduce to scalar via differentiable sum.
782        let loss = ferrotorch_core::grad_fns::reduction::sum(&output).unwrap();
783        loss.backward().unwrap();
784
785        // Check input grad.
786        let input_grad = input.grad().unwrap().expect("input should have grad");
787        assert_eq!(input_grad.shape(), &[2, 2]);
788        assert_close(input_grad.data().unwrap(), &[4.0, 6.0, 4.0, 6.0], 1e-5);
789    }
790
791    #[test]
792    fn test_backward_weight_grad() {
793        // Use a known configuration to verify weight gradients.
794        // W = [[1, 0], [0, 1]]  (out=2, in=2) — identity
795        // input = [[2, 3]]  (batch=1, in=2)
796        // y = [[2, 3]] @ I = [[2, 3]]
797        // loss = sum(y) = 5
798        // dL/dy = ones(1, 2) = [[1, 1]]
799        //
800        // For mm(input, W^T):
801        //   dL/d(W^T) = input^T @ dL/dy = [[2],[3]] @ [[1,1]] = [[2,2],[3,3]]
802        //   => dL/d(W) by chain through transpose
803        //
804        // PyTorch reference: W.grad = dL/dy^T @ input = [[1],[1]] @ [[2,3]] = [[2,3],[2,3]]
805        let mut layer = Linear::<f32>::new(2, 2, false).unwrap();
806        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
807
808        let input = leaf(&[2.0, 3.0], &[1, 2], false);
809        let output = layer.forward(&input).unwrap();
810        let loss = ferrotorch_core::grad_fns::reduction::sum(&output).unwrap();
811        loss.backward().unwrap();
812
813        // The weight gradient flows through mm(input, W^T):
814        // dL/d(W^T) = input^T @ dL/dy = [[2],[3]] @ [[1,1]] = [[2,2],[3,3]]
815        // Since W^T was created via transpose(W), the gradient accumulates on
816        // the original W parameter through the transpose operation.
817        // The transpose of [[2,2],[3,3]] is [[2,3],[2,3]], matching W's shape.
818        let w_grad = layer
819            .weight
820            .grad()
821            .unwrap()
822            .expect("weight should have grad");
823        assert_eq!(w_grad.shape(), &[2, 2]);
824        assert_close(w_grad.data().unwrap(), &[2.0, 3.0, 2.0, 3.0], 1e-5);
825    }
826
827    #[test]
828    fn test_backward_numerical_gradient() {
829        // Numerical gradient check for a small Linear layer.
830        // Perturb each weight element by eps and check finite-difference
831        // gradient matches autograd gradient.
832        let eps = 1e-4f32;
833
834        let mut layer = Linear::<f32>::new(2, 2, false).unwrap();
835        layer.weight = Parameter::from_slice(&[0.5, -0.3, 0.2, 0.8], &[2, 2]).unwrap();
836
837        let input_data = [1.0f32, 2.0, 3.0, 4.0];
838        let input = leaf(&input_data, &[2, 2], false);
839
840        // Forward + backward for analytic gradient.
841        let output = layer.forward(&input).unwrap();
842        let loss = ferrotorch_core::grad_fns::reduction::sum(&output).unwrap();
843        loss.backward().unwrap();
844
845        let analytic_grad = layer.weight.grad().unwrap().unwrap();
846        let analytic = analytic_grad.data().unwrap().to_vec();
847
848        // Numerical gradient for each weight element.
849        let base_weight = [0.5f32, -0.3, 0.2, 0.8];
850        for idx in 0..4 {
851            let mut w_plus = base_weight;
852            w_plus[idx] += eps;
853            let mut layer_plus = Linear::<f32>::new(2, 2, false).unwrap();
854            layer_plus.weight = Parameter::from_slice(&w_plus, &[2, 2]).unwrap();
855            let input_ng = leaf(&input_data, &[2, 2], false);
856            let out_plus = ferrotorch_core::no_grad(|| {
857                let o = layer_plus.forward(&input_ng).unwrap();
858                ferrotorch_core::grad_fns::reduction::sum(&o).unwrap()
859            });
860            let loss_plus = out_plus.item().unwrap();
861
862            let mut w_minus = base_weight;
863            w_minus[idx] -= eps;
864            let mut layer_minus = Linear::<f32>::new(2, 2, false).unwrap();
865            layer_minus.weight = Parameter::from_slice(&w_minus, &[2, 2]).unwrap();
866            let input_ng2 = leaf(&input_data, &[2, 2], false);
867            let out_minus = ferrotorch_core::no_grad(|| {
868                let o = layer_minus.forward(&input_ng2).unwrap();
869                ferrotorch_core::grad_fns::reduction::sum(&o).unwrap()
870            });
871            let loss_minus = out_minus.item().unwrap();
872
873            let numerical = (loss_plus - loss_minus) / (2.0 * eps);
874            assert!(
875                (numerical - analytic[idx]).abs() < 1e-2,
876                "weight[{idx}]: numerical={numerical}, analytic={}, diff={}",
877                analytic[idx],
878                (numerical - analytic[idx]).abs()
879            );
880        }
881    }
882
883    // -----------------------------------------------------------------------
884    // Parameter count
885    // -----------------------------------------------------------------------
886
887    #[test]
888    fn test_parameter_count_with_bias() {
889        let layer = Linear::<f32>::new(10, 5, true).unwrap();
890        let params = layer.parameters();
891        assert_eq!(params.len(), 2);
892        // weight: 10 * 5 = 50 elements, bias: 5 elements
893        let total: usize = params.iter().map(|p| p.numel()).sum();
894        assert_eq!(total, 55);
895    }
896
897    #[test]
898    fn test_parameter_count_without_bias() {
899        let layer = Linear::<f32>::new(10, 5, false).unwrap();
900        let params = layer.parameters();
901        assert_eq!(params.len(), 1);
902        let total: usize = params.iter().map(|p| p.numel()).sum();
903        assert_eq!(total, 50);
904    }
905
906    // -----------------------------------------------------------------------
907    // State dict roundtrip
908    // -----------------------------------------------------------------------
909
910    #[test]
911    fn test_state_dict_roundtrip_with_bias() {
912        let layer = Linear::<f32>::new(4, 3, true).unwrap();
913        let sd = layer.state_dict();
914        assert!(sd.contains_key("weight"));
915        assert!(sd.contains_key("bias"));
916        assert_eq!(sd["weight"].shape(), &[3, 4]);
917        assert_eq!(sd["bias"].shape(), &[3]);
918
919        let mut layer2 = Linear::<f32>::new(4, 3, true).unwrap();
920        layer2.load_state_dict(&sd, true).unwrap();
921
922        // Verify loaded weights match.
923        assert_close(
924            layer2.weight.data().unwrap(),
925            layer.weight.data().unwrap(),
926            1e-7,
927        );
928        assert_close(
929            layer2.bias.as_ref().unwrap().data().unwrap(),
930            layer.bias.as_ref().unwrap().data().unwrap(),
931            1e-7,
932        );
933    }
934
935    #[test]
936    fn test_state_dict_roundtrip_without_bias() {
937        let layer = Linear::<f32>::new(6, 2, false).unwrap();
938        let sd = layer.state_dict();
939        assert!(sd.contains_key("weight"));
940        assert!(!sd.contains_key("bias"));
941
942        let mut layer2 = Linear::<f32>::new(6, 2, false).unwrap();
943        layer2.load_state_dict(&sd, true).unwrap();
944
945        assert_close(
946            layer2.weight.data().unwrap(),
947            layer.weight.data().unwrap(),
948            1e-7,
949        );
950    }
951
952    #[test]
953    fn test_state_dict_shape_mismatch_rejected() {
954        let layer_a = Linear::<f32>::new(4, 3, true).unwrap();
955        let sd = layer_a.state_dict();
956
957        let mut layer_b = Linear::<f32>::new(4, 5, true).unwrap();
958        assert!(layer_b.load_state_dict(&sd, true).is_err());
959    }
960
961    // -----------------------------------------------------------------------
962    // Named parameters
963    // -----------------------------------------------------------------------
964
965    #[test]
966    fn test_named_parameters_with_bias() {
967        let layer = Linear::<f32>::new(3, 2, true).unwrap();
968        let named = layer.named_parameters();
969        assert_eq!(named.len(), 2);
970        assert_eq!(named[0].0, "weight");
971        assert_eq!(named[1].0, "bias");
972    }
973
974    #[test]
975    fn test_named_parameters_without_bias() {
976        let layer = Linear::<f32>::new(3, 2, false).unwrap();
977        let named = layer.named_parameters();
978        assert_eq!(named.len(), 1);
979        assert_eq!(named[0].0, "weight");
980    }
981
982    // -----------------------------------------------------------------------
983    // Train / Eval
984    // -----------------------------------------------------------------------
985
986    #[test]
987    fn test_train_eval() {
988        let mut layer = Linear::<f32>::new(4, 3, true).unwrap();
989        assert!(layer.is_training());
990        layer.eval();
991        assert!(!layer.is_training());
992        layer.train();
993        assert!(layer.is_training());
994    }
995
996    // -----------------------------------------------------------------------
997    // Display
998    // -----------------------------------------------------------------------
999
1000    #[test]
1001    fn test_display() {
1002        let layer = Linear::<f32>::new(10, 5, true).unwrap();
1003        let s = format!("{layer}");
1004        assert_eq!(s, "Linear(in_features=10, out_features=5, bias=true)");
1005    }
1006
1007    #[test]
1008    fn test_display_no_bias() {
1009        let layer = Linear::<f32>::new(10, 5, false).unwrap();
1010        let s = format!("{layer}");
1011        assert_eq!(s, "Linear(in_features=10, out_features=5, bias=false)");
1012    }
1013
1014    // -----------------------------------------------------------------------
1015    // Send + Sync
1016    // -----------------------------------------------------------------------
1017
1018    #[test]
1019    fn test_linear_is_send_sync() {
1020        fn assert_send_sync<T: Send + Sync>() {}
1021        assert_send_sync::<Linear<f32>>();
1022        assert_send_sync::<Linear<f64>>();
1023    }
1024
1025    // -----------------------------------------------------------------------
1026    // Bias init bounds — REQ-6 / closes #1450
1027    // -----------------------------------------------------------------------
1028
1029    /// Verifies bias is initialized within `U(-bound, bound)` where
1030    /// `bound = 1/sqrt(in_features)` per `torch/nn/modules/linear.py:124-128`.
1031    /// Pre-fix the bias was identically 0.0 (zeros_init), which would FAIL
1032    /// the `nonzero` assertion below with overwhelming probability.
1033    #[test]
1034    fn test_linear_bias_init_bounded_uniform() {
1035        let in_features = 64usize;
1036        let out_features = 128usize;
1037        let layer = Linear::<f32>::new(in_features, out_features, true).unwrap();
1038        let bias = layer.bias.as_ref().expect("bias requested");
1039        let bias_data = bias.tensor().data_vec().unwrap();
1040        let bound = 1.0_f32 / (in_features as f32).sqrt();
1041        let mut nonzero = 0usize;
1042        for &b in &bias_data {
1043            assert!(
1044                b.abs() <= bound + 1e-6,
1045                "bias element {b} exceeds bound {bound}"
1046            );
1047            if b != 0.0 {
1048                nonzero += 1;
1049            }
1050        }
1051        assert!(
1052            nonzero > out_features / 2,
1053            "expected most bias entries to be nonzero (got {nonzero}/{out_features}); \
1054             would FAIL pre-fix when bias was zeros_init"
1055        );
1056    }
1057
1058    // -----------------------------------------------------------------------
1059    // Device transfer
1060    // -----------------------------------------------------------------------
1061
1062    #[test]
1063    fn test_to_device_cpu_preserves_weights() {
1064        let mut layer = Linear::<f32>::new(4, 3, true).unwrap();
1065        layer.weight = Parameter::from_slice(
1066            &[
1067                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1068            ],
1069            &[3, 4],
1070        )
1071        .unwrap();
1072        *layer.bias.as_mut().unwrap() = Parameter::from_slice(&[0.1, 0.2, 0.3], &[3]).unwrap();
1073
1074        layer.to_device(ferrotorch_core::Device::Cpu).unwrap();
1075
1076        assert_eq!(layer.weight.shape(), &[3, 4]);
1077        assert_close(
1078            layer.weight.data().unwrap(),
1079            &[
1080                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1081            ],
1082            1e-7,
1083        );
1084        assert_close(
1085            layer.bias.as_ref().unwrap().data().unwrap(),
1086            &[0.1, 0.2, 0.3],
1087            1e-7,
1088        );
1089        assert!(layer.weight.requires_grad());
1090        assert!(layer.bias.as_ref().unwrap().requires_grad());
1091    }
1092
1093    #[test]
1094    fn test_to_device_cuda_returns_device_unavailable() {
1095        let mut layer = Linear::<f32>::new(4, 3, true).unwrap();
1096        let result = layer.to_device(ferrotorch_core::Device::Cuda(0));
1097        assert!(result.is_err());
1098    }
1099
1100    // -----------------------------------------------------------------------
1101    // Bilinear N-D input — closes #1603
1102    //
1103    // Oracle values constructed by live-calling PyTorch 2.11 (R-CHAR-3):
1104    //   import torch
1105    //   y = torch.nn.functional.bilinear(x1, x2, W, b)
1106    //   y.sum().backward()  # for the four gradients
1107    // Each test documents the exact torch invocation that produced its
1108    // expected tensor. The bilinear contract is
1109    // `torch/nn/modules/linear.py:172-178` (shape `(*, H_in)`) and
1110    // `aten/src/ATen/native/Linear.cpp:792-802` (flatten-2D-then-reshape).
1111    // -----------------------------------------------------------------------
1112
1113    /// Build the shared deterministic weight `[out=2, in1=3, in2=2]` and bias
1114    /// `[out=2]` used by the 3-D forward/backward oracle tests below. These
1115    /// exact values are what was fed to `torch.nn.functional.bilinear` to
1116    /// produce the expected outputs/gradients.
1117    fn bilinear_3d_layer() -> Bilinear<f32> {
1118        let mut layer = Bilinear::<f32>::new(3, 2, 2, true).unwrap();
1119        // W[o,i,j], row-major flatten of the [2,3,2] tensor.
1120        layer.weight = Parameter::from_slice(
1121            &[
1122                0.1, 0.2, 0.3, -0.1, -0.2, 0.05, // o=0
1123                0.0, 0.4, -0.3, 0.2, 0.1, -0.15, // o=1
1124            ],
1125            &[2, 3, 2],
1126        )
1127        .unwrap();
1128        *layer.bias.as_mut().unwrap() = Parameter::from_slice(&[0.5, -0.25], &[2]).unwrap();
1129        layer
1130    }
1131
1132    #[test]
1133    fn test_bilinear_3d_forward_matches_torch() {
1134        // torch:
1135        //   x1 = [[[1,2,3],[-1,0.5,2]],[[0,1,-1],[2,-2,1]]]  # (2,2,3)
1136        //   x2 = [[[1,-1],[0.5,2]],[[-1,1],[3,0]]]            # (2,2,2)
1137        //   F.bilinear(x1, x2, W, b).shape == (2,2,2)
1138        let layer = bilinear_3d_layer();
1139        let x1 = leaf(
1140            &[
1141                1.0, 2.0, 3.0, -1.0, 0.5, 2.0, 0.0, 1.0, -1.0, 2.0, -2.0, 1.0,
1142            ],
1143            &[2, 2, 3],
1144            false,
1145        );
1146        let x2 = leaf(
1147            &[1.0, -1.0, 0.5, 2.0, -1.0, 1.0, 3.0, 0.0],
1148            &[2, 2, 2],
1149            false,
1150        );
1151        let y = layer.forward_pair(&x1, &x2).unwrap();
1152        assert_eq!(y.shape(), &[2, 2, 2]);
1153        // FWD3D_out from torch oracle.
1154        assert_close(
1155            y.data().unwrap(),
1156            &[0.45, -0.9, 0.025, -1.425, -0.15, 0.5, -1.3, 1.85],
1157            1e-5,
1158        );
1159    }
1160
1161    #[test]
1162    fn test_bilinear_3d_backward_matches_torch() {
1163        // Same inputs as the forward test; loss = y.sum().
1164        // Expected grads are GRAD_x1 / GRAD_x2 / GRAD_W / GRAD_b from torch.
1165        let layer = bilinear_3d_layer();
1166        let x1 = leaf(
1167            &[
1168                1.0, 2.0, 3.0, -1.0, 0.5, 2.0, 0.0, 1.0, -1.0, 2.0, -2.0, 1.0,
1169            ],
1170            &[2, 2, 3],
1171            true,
1172        );
1173        let x2 = leaf(
1174            &[1.0, -1.0, 0.5, 2.0, -1.0, 1.0, 3.0, 0.0],
1175            &[2, 2, 2],
1176            true,
1177        );
1178        let y = layer.forward_pair(&x1, &x2).unwrap();
1179        let loss = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
1180        loss.backward().unwrap();
1181
1182        let g_x1 = x1.grad().unwrap().expect("x1 should have grad");
1183        assert_eq!(g_x1.shape(), &[2, 2, 3]);
1184        assert_close(
1185            g_x1.data().unwrap(),
1186            &[
1187                -0.5, -0.1, 0.0, 1.25, 0.2, -0.25, 0.5, 0.1, 0.0, 0.3, 0.0, -0.3,
1188            ],
1189            1e-5,
1190        );
1191
1192        let g_x2 = x2.grad().unwrap().expect("x2 should have grad");
1193        assert_eq!(g_x2.shape(), &[2, 2, 2]);
1194        assert_close(
1195            g_x2.data().unwrap(),
1196            &[-0.2, 0.5, -0.3, -0.75, 0.1, 0.2, 0.1, 0.9],
1197            1e-5,
1198        );
1199
1200        let g_w = layer.weight.grad().unwrap().expect("W should have grad");
1201        assert_eq!(g_w.shape(), &[2, 3, 2]);
1202        assert_close(
1203            g_w.data().unwrap(),
1204            &[
1205                6.5, -3.0, -4.75, 0.0, 8.0, 0.0, 6.5, -3.0, -4.75, 0.0, 8.0, 0.0,
1206            ],
1207            1e-5,
1208        );
1209
1210        let g_b = layer
1211            .bias
1212            .as_ref()
1213            .unwrap()
1214            .grad()
1215            .unwrap()
1216            .expect("bias should have grad");
1217        assert_eq!(g_b.shape(), &[2]);
1218        assert_close(g_b.data().unwrap(), &[4.0, 4.0], 1e-5);
1219    }
1220
1221    #[test]
1222    fn test_bilinear_4d_forward_matches_torch() {
1223        // torch:
1224        //   W = [[[1,0],[0,1]]]  (out=1,in1=2,in2=2 -> identity contraction)
1225        //   x1 = [[[[1,2],[3,4]]],[[[5,6],[7,8]]]]  # (2,1,2,2)
1226        //   x2 = [[[[1,1],[1,1]]],[[[2,2],[2,2]]]]  # (2,1,2,2)
1227        //   F.bilinear(x1,x2,W).shape == (2,1,2,1); data == [3,7,22,30]
1228        let mut layer = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
1229        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
1230        let x1 = leaf(
1231            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
1232            &[2, 1, 2, 2],
1233            false,
1234        );
1235        let x2 = leaf(
1236            &[1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
1237            &[2, 1, 2, 2],
1238            false,
1239        );
1240        let y = layer.forward_pair(&x1, &x2).unwrap();
1241        assert_eq!(y.shape(), &[2, 1, 2, 1]);
1242        assert_close(y.data().unwrap(), &[3.0, 7.0, 22.0, 30.0], 1e-5);
1243    }
1244
1245    #[test]
1246    fn test_bilinear_2d_still_matches_torch() {
1247        // Regression guard: the pre-existing 2-D path must keep working.
1248        // torch:
1249        //   W = [[[1,0],[0,1]]] (out=1,in1=2,in2=2), x1=[[1,2],[3,4]],
1250        //   x2=[[1,1],[1,1]] -> y = [[1*1+2*1],[3*1+4*1]] = [[3],[7]]
1251        let mut layer = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
1252        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
1253        let x1 = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
1254        let x2 = leaf(&[1.0, 1.0, 1.0, 1.0], &[2, 2], false);
1255        let y = layer.forward_pair(&x1, &x2).unwrap();
1256        assert_eq!(y.shape(), &[2, 1]);
1257        assert_close(y.data().unwrap(), &[3.0, 7.0], 1e-5);
1258    }
1259
1260    #[test]
1261    fn test_bilinear_1d_still_matches_torch() {
1262        // Regression guard: a 1-D pair (no batch dim) -> (out,).
1263        // torch: W=[[[1,0],[0,1]]], x1=[2,3], x2=[1,1] -> y=[2*1+3*1]=[5]
1264        let mut layer = Bilinear::<f32>::new(2, 2, 1, false).unwrap();
1265        layer.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]).unwrap();
1266        let x1 = leaf(&[2.0, 3.0], &[2], false);
1267        let x2 = leaf(&[1.0, 1.0], &[2], false);
1268        let y = layer.forward_pair(&x1, &x2).unwrap();
1269        assert_eq!(y.shape(), &[1]);
1270        assert_close(y.data().unwrap(), &[5.0], 1e-5);
1271    }
1272
1273    #[test]
1274    fn test_bilinear_empty_leading_dim_2d() {
1275        // torch: F.bilinear(zeros(0,3), zeros(0,2), W, b).shape == (0,2)
1276        let layer = bilinear_3d_layer();
1277        let x1 = leaf(&[], &[0, 3], false);
1278        let x2 = leaf(&[], &[0, 2], false);
1279        let y = layer.forward_pair(&x1, &x2).unwrap();
1280        assert_eq!(y.shape(), &[0, 2]);
1281        assert_eq!(y.numel(), 0);
1282    }
1283
1284    #[test]
1285    fn test_bilinear_empty_leading_dim_3d() {
1286        // torch: F.bilinear(zeros(0,4,3), zeros(0,4,2), W, b).shape == (0,4,2)
1287        let layer = bilinear_3d_layer();
1288        let x1 = leaf(&[], &[0, 4, 3], false);
1289        let x2 = leaf(&[], &[0, 4, 2], false);
1290        let y = layer.forward_pair(&x1, &x2).unwrap();
1291        assert_eq!(y.shape(), &[0, 4, 2]);
1292        assert_eq!(y.numel(), 0);
1293    }
1294
1295    #[test]
1296    fn test_bilinear_zero_middle_dim_3d() {
1297        // torch: F.bilinear(zeros(2,0,3), zeros(2,0,2), W, b).shape == (2,0,2)
1298        let layer = bilinear_3d_layer();
1299        let x1 = leaf(&[], &[2, 0, 3], false);
1300        let x2 = leaf(&[], &[2, 0, 2], false);
1301        let y = layer.forward_pair(&x1, &x2).unwrap();
1302        assert_eq!(y.shape(), &[2, 0, 2]);
1303        assert_eq!(y.numel(), 0);
1304    }
1305
1306    #[test]
1307    fn test_bilinear_mismatched_ndim_rejected() {
1308        // torch raises: "bilinear(): input dimensions do not match: got 3 and 2"
1309        let layer = bilinear_3d_layer();
1310        let x1 = leaf(&[0.0; 2 * 2 * 3], &[2, 2, 3], false);
1311        let x2 = leaf(&[0.0; 2 * 2], &[2, 2], false);
1312        assert!(layer.forward_pair(&x1, &x2).is_err());
1313    }
1314
1315    #[test]
1316    fn test_bilinear_mismatched_leading_dim_rejected() {
1317        // torch raises: "bilinear(): input batch dimensions do not match at
1318        // dim 1: got 3 and 4"
1319        let layer = bilinear_3d_layer();
1320        let x1 = leaf(&[0.0; 2 * 3 * 3], &[2, 3, 3], false);
1321        let x2 = leaf(&[0.0; 2 * 4 * 2], &[2, 4, 2], false);
1322        assert!(layer.forward_pair(&x1, &x2).is_err());
1323    }
1324
1325    #[test]
1326    fn test_bilinear_wrong_feature_dim_rejected() {
1327        // torch raises: "input1 size does not match weight size".
1328        let layer = bilinear_3d_layer(); // in1=3, in2=2
1329        let bad_x1 = leaf(&[0.0; 2 * 2 * 4], &[2, 2, 4], false); // last dim 4 != 3
1330        let x2 = leaf(&[0.0; 2 * 2 * 2], &[2, 2, 2], false);
1331        assert!(layer.forward_pair(&bad_x1, &x2).is_err());
1332    }
1333}