Skip to main content

ferrotorch_nn/
identity.rs

1//! Identity and Flatten modules + small shape/distance modules.
2//!
3//! [`Identity`] passes input through unchanged — useful for model composition,
4//! conditional layers, and debugging.
5//!
6//! [`Flatten`] reshapes input by flattening contiguous dimensions from
7//! `start_dim` to `end_dim` into a single dimension. The default
8//! (`start_dim=1, end_dim=-1`) flattens everything except the batch dimension.
9//!
10//! [`Unflatten`] is Flatten's inverse; [`ChannelShuffle`] implements the
11//! ShuffleNet channel permutation; [`CosineSimilarity`] and
12//! [`PairwiseDistance`] are paired distance helpers.
13//!
14//! ## REQ status (per `.design/ferrotorch-nn/identity.md`)
15//!
16//! | REQ | Status | Evidence |
17//! |---|---|---|
18//! | REQ-1 | SHIPPED | `pub struct Identity` (`Default`, `Clone`, `Copy`) + `impl Module<T> for Identity` mirror `torch/nn/modules/linear.py:18-42`; consumed by `pub use identity::Identity` at `lib.rs:204` and downstream conditional-no-op CNN composition patterns. |
19//! | REQ-2 | SHIPPED | `pub struct Flatten` with `start_dim: usize`, `end_dim: isize`, `Default::default() == Flatten::new(1, -1)` mirror `torch/nn/modules/flatten.py:8-60`; consumed by `lib.rs:204` re-export and canonical CNN→FC transition patterns. |
20//! | REQ-3 | SHIPPED | `Flatten::forward` full edge-case ladder (0-D error, 1-D clone, negative-`end_dim` resolution, start>end check, no-op, final `grad_fns::shape::reshape` dispatch); consumed by every downstream model that flattens after pooling; `grad_fns::shape::reshape` is the autograd-aware production path. |
21//! | REQ-4 | SHIPPED | `pub struct Unflatten` with inherent typed forward + `impl Module<T> for Unflatten` delegating to it mirror `torch/nn/modules/flatten.py:62-167`; consumed by `lib.rs:204` re-export and downstream reshape-after-flatten decoder construction. |
22//! | REQ-5 | SHIPPED | `pub struct ChannelShuffle` with `[N, g, cpg, *] → [N, cpg, g, *]` permutation + CUDA `NotImplementedOnCuda` error mirrors `torch/nn/modules/channelshuffle.py`; consumed by `lib.rs:204` re-export and ShuffleNet-family vision architectures. |
23//! | REQ-6 | SHIPPED | `pub struct CosineSimilarity` with `dim`, `eps`, `Default::default() == (1, 1e-8)`, CPU forward, CUDA error mirrors `torch/nn/modules/distance.py:72-100`; consumed by `lib.rs:204` re-export and contrastive-learning / embedding-similarity code. |
24//! | REQ-7 | SHIPPED | `pub struct PairwiseDistance` with `p`, `eps`, `keepdim`, `Default::default() == (2.0, 1e-6, false)`, CPU forward, CUDA error mirrors `torch/nn/modules/distance.py:8-70`; consumed by `lib.rs:204` re-export and embedding-distance training drivers. |
25//! | REQ-8 | SHIPPED | Explicit `is_cuda()` guards returning `Err(FerrotorchError::NotImplementedOnCuda)` in `ChannelShuffle::forward`, `CosineSimilarity::forward`, `PairwiseDistance::forward` (R-CODE-4 — no silent CPU↔GPU round-trips); consumed by every CUDA-tensor invocation surfacing the error to the caller. |
26
27use ferrotorch_core::grad_fns::shape::reshape;
28use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
29
30use crate::module::Module;
31use crate::parameter::Parameter;
32
33// ===========================================================================
34// Identity
35// ===========================================================================
36
37/// A module that returns its input unchanged.
38///
39/// Useful as a placeholder in model architectures where a layer is
40/// conditionally applied, or for debugging / hook attachment points.
41///
42/// Has zero learnable parameters.
43///
44/// # Examples
45///
46/// ```ignore
47/// let id = Identity;
48/// let output = id.forward(&input)?; // output == input
49/// ```
50#[derive(Debug, Clone, Copy)]
51pub struct Identity {
52    training: bool,
53}
54
55impl Identity {
56    /// Create a new `Identity` module.
57    pub fn new() -> Self {
58        Self { training: true }
59    }
60}
61
62impl Default for Identity {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl<T: Float> Module<T> for Identity {
69    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
70        Ok(input.clone())
71    }
72
73    fn parameters(&self) -> Vec<&Parameter<T>> {
74        vec![]
75    }
76
77    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
78        vec![]
79    }
80
81    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
82        vec![]
83    }
84
85    fn train(&mut self) {
86        self.training = true;
87    }
88
89    fn eval(&mut self) {
90        self.training = false;
91    }
92
93    fn is_training(&self) -> bool {
94        self.training
95    }
96}
97
98// ===========================================================================
99// Flatten
100// ===========================================================================
101
102/// Flattens a contiguous range of dimensions in a tensor.
103///
104/// By default, flattens all dimensions except the batch dimension
105/// (`start_dim=1, end_dim=-1`), producing output of shape `[B, *]`.
106///
107/// Negative `end_dim` values are resolved relative to the input's
108/// number of dimensions (`-1` = last dim).
109///
110/// # Examples
111///
112/// ```ignore
113/// // Input: [2, 3, 4, 5]
114/// let flatten = Flatten::new(1, -1);
115/// let output = flatten.forward(&input)?;
116/// // Output: [2, 60]
117///
118/// // Flatten specific range
119/// let flatten = Flatten::new(2, 3);
120/// let output = flatten.forward(&input)?;
121/// // Output: [2, 3, 20]
122/// ```
123#[derive(Debug, Clone, Copy)]
124pub struct Flatten {
125    /// First dimension to flatten (inclusive).
126    pub start_dim: usize,
127    /// Last dimension to flatten (inclusive). Negative values count from the end.
128    pub end_dim: isize,
129    training: bool,
130}
131
132impl Flatten {
133    /// Create a new `Flatten` module.
134    ///
135    /// # Arguments
136    ///
137    /// * `start_dim` - First dimension to flatten (inclusive, 0-indexed).
138    /// * `end_dim` - Last dimension to flatten (inclusive). Use `-1` for the
139    ///   last dimension, `-2` for second-to-last, etc.
140    pub fn new(start_dim: usize, end_dim: isize) -> Self {
141        Self {
142            start_dim,
143            end_dim,
144            training: true,
145        }
146    }
147
148    /// Resolve `end_dim` to a concrete dimension index.
149    fn resolve_end_dim(&self, ndim: usize) -> FerrotorchResult<usize> {
150        let resolved = if self.end_dim < 0 {
151            let d = ndim as isize + self.end_dim;
152            if d < 0 {
153                return Err(FerrotorchError::InvalidArgument {
154                    message: format!(
155                        "Flatten: end_dim {} is out of range for input with {} dims",
156                        self.end_dim, ndim
157                    ),
158                });
159            }
160            d as usize
161        } else {
162            self.end_dim as usize
163        };
164
165        if resolved >= ndim {
166            return Err(FerrotorchError::InvalidArgument {
167                message: format!(
168                    "Flatten: resolved end_dim {} is out of range for input with {} dims",
169                    resolved, ndim
170                ),
171            });
172        }
173
174        Ok(resolved)
175    }
176}
177
178impl Default for Flatten {
179    fn default() -> Self {
180        Self::new(1, -1)
181    }
182}
183
184impl<T: Float> Module<T> for Flatten {
185    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
186        let shape = input.shape();
187        let ndim = shape.len();
188
189        // 0-D tensor: nothing to flatten.
190        if ndim == 0 {
191            return Err(FerrotorchError::InvalidArgument {
192                message: "Flatten: cannot flatten a 0-D (scalar) tensor".into(),
193            });
194        }
195
196        // 1-D tensor: already flat.
197        if ndim == 1 {
198            return Ok(input.clone());
199        }
200
201        if self.start_dim >= ndim {
202            return Err(FerrotorchError::InvalidArgument {
203                message: format!(
204                    "Flatten: start_dim {} is out of range for input with {} dims",
205                    self.start_dim, ndim
206                ),
207            });
208        }
209
210        let end_dim = self.resolve_end_dim(ndim)?;
211
212        if self.start_dim > end_dim {
213            return Err(FerrotorchError::InvalidArgument {
214                message: format!(
215                    "Flatten: start_dim ({}) must be <= end_dim ({})",
216                    self.start_dim, end_dim
217                ),
218            });
219        }
220
221        // If start == end, no flattening needed.
222        if self.start_dim == end_dim {
223            return Ok(input.clone());
224        }
225
226        // Build new shape: [dims before start, flattened, dims after end].
227        let mut new_shape: Vec<isize> = Vec::with_capacity(ndim - (end_dim - self.start_dim));
228
229        for &d in &shape[..self.start_dim] {
230            new_shape.push(d as isize);
231        }
232
233        // Flatten the range [start_dim..=end_dim] into one dim.
234        let flattened: usize = shape[self.start_dim..=end_dim].iter().product();
235        new_shape.push(flattened as isize);
236
237        for &d in &shape[end_dim + 1..] {
238            new_shape.push(d as isize);
239        }
240
241        reshape(input, &new_shape)
242    }
243
244    fn parameters(&self) -> Vec<&Parameter<T>> {
245        vec![]
246    }
247
248    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
249        vec![]
250    }
251
252    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
253        vec![]
254    }
255
256    fn train(&mut self) {
257        self.training = true;
258    }
259
260    fn eval(&mut self) {
261        self.training = false;
262    }
263
264    fn is_training(&self) -> bool {
265        self.training
266    }
267}
268
269// ===========================================================================
270// Unflatten
271// ===========================================================================
272
273/// Unflattens a dimension, expanding it into multiple dimensions.
274///
275/// The inverse of [`Flatten`]. Given an input where dimension `dim` has
276/// size equal to the product of `unflattened_size`, reshapes that
277/// dimension into the specified shape.
278///
279/// Matches PyTorch's `nn.Unflatten`.
280#[derive(Debug, Clone)]
281pub struct Unflatten {
282    /// The dimension to unflatten.
283    pub dim: usize,
284    /// The target shape for the unflattened dimension.
285    pub unflattened_size: Vec<usize>,
286    training: bool,
287}
288
289impl Unflatten {
290    pub fn new(dim: usize, unflattened_size: Vec<usize>) -> Self {
291        Self {
292            dim,
293            unflattened_size,
294            training: true,
295        }
296    }
297
298    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
299        let shape = input.shape();
300        if self.dim >= shape.len() {
301            return Err(FerrotorchError::InvalidArgument {
302                message: format!(
303                    "Unflatten: dim {} out of range for input with {} dims",
304                    self.dim,
305                    shape.len()
306                ),
307            });
308        }
309
310        let expected_size: usize = self.unflattened_size.iter().product();
311        if expected_size != shape[self.dim] {
312            return Err(FerrotorchError::InvalidArgument {
313                message: format!(
314                    "Unflatten: unflattened_size {:?} (product={}) doesn't match dim {} size {}",
315                    self.unflattened_size, expected_size, self.dim, shape[self.dim]
316                ),
317            });
318        }
319
320        let mut new_shape = Vec::with_capacity(shape.len() - 1 + self.unflattened_size.len());
321        new_shape.extend_from_slice(&shape[..self.dim]);
322        new_shape.extend_from_slice(&self.unflattened_size);
323        new_shape.extend_from_slice(&shape[self.dim + 1..]);
324
325        input.view_reshape(new_shape)
326    }
327}
328
329impl<T: Float> Module<T> for Unflatten {
330    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
331        Unflatten::forward(self, input)
332    }
333
334    fn parameters(&self) -> Vec<&Parameter<T>> {
335        vec![]
336    }
337    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
338        vec![]
339    }
340    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
341        vec![]
342    }
343    fn train(&mut self) {
344        self.training = true;
345    }
346    fn eval(&mut self) {
347        self.training = false;
348    }
349    fn is_training(&self) -> bool {
350        self.training
351    }
352}
353
354// ===========================================================================
355// ChannelShuffle
356// ===========================================================================
357
358/// Rearranges channels in a [N, C, H, W] tensor by dividing them into
359/// groups and interleaving.
360///
361/// Used in ShuffleNet architectures. With `groups=g`, the channel
362/// dimension is reshaped to `[g, C/g]`, transposed to `[C/g, g]`,
363/// then flattened back to `[C]`.
364///
365/// Matches PyTorch's `nn.ChannelShuffle`.
366#[derive(Debug, Clone)]
367pub struct ChannelShuffle {
368    pub groups: usize,
369    training: bool,
370}
371
372impl ChannelShuffle {
373    pub fn new(groups: usize) -> Self {
374        Self {
375            groups,
376            training: true,
377        }
378    }
379
380    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
381        if input.ndim() < 2 {
382            return Err(FerrotorchError::InvalidArgument {
383                message: format!(
384                    "ChannelShuffle: input must have at least 2 dims, got {:?}",
385                    input.shape()
386                ),
387            });
388        }
389        if input.is_cuda() {
390            return Err(FerrotorchError::NotImplementedOnCuda {
391                op: "ChannelShuffle",
392            });
393        }
394
395        let shape = input.shape();
396        let channels = shape[1];
397        if channels % self.groups != 0 {
398            return Err(FerrotorchError::InvalidArgument {
399                message: format!(
400                    "ChannelShuffle: channels ({}) must be divisible by groups ({})",
401                    channels, self.groups
402                ),
403            });
404        }
405
406        let g = self.groups;
407        let cpg = channels / g; // channels per group
408        let batch = shape[0];
409        let spatial: usize = shape[2..].iter().product();
410        let data = input.data()?;
411
412        // Reshape [N, C, *] → [N, g, cpg, *] → transpose → [N, cpg, g, *] → [N, C, *]
413        let mut out = vec![<T as num_traits::Zero>::zero(); data.len()];
414        for n in 0..batch {
415            for c_out in 0..channels {
416                // c_out in the shuffled order: group index = c_out % g, within-group = c_out / g
417                let c_in = (c_out % g) * cpg + (c_out / g);
418                for s in 0..spatial {
419                    out[n * channels * spatial + c_out * spatial + s] =
420                        data[n * channels * spatial + c_in * spatial + s];
421                }
422            }
423        }
424
425        Tensor::from_storage(
426            ferrotorch_core::storage::TensorStorage::cpu(out),
427            shape.to_vec(),
428            false,
429        )
430    }
431}
432
433impl<T: Float> Module<T> for ChannelShuffle {
434    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
435        ChannelShuffle::forward(self, input)
436    }
437
438    fn parameters(&self) -> Vec<&Parameter<T>> {
439        vec![]
440    }
441    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
442        vec![]
443    }
444    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
445        vec![]
446    }
447    fn train(&mut self) {
448        self.training = true;
449    }
450    fn eval(&mut self) {
451        self.training = false;
452    }
453    fn is_training(&self) -> bool {
454        self.training
455    }
456}
457
458// ===========================================================================
459// CosineSimilarity
460// ===========================================================================
461
462/// Computes cosine similarity between two tensors along a dimension.
463///
464/// `cos(x1, x2) = (x1 . x2) / (||x1|| * ||x2||)`
465///
466/// Matches PyTorch's `nn.CosineSimilarity`.
467#[derive(Debug, Clone)]
468pub struct CosineSimilarity {
469    /// Dimension along which to compute cosine similarity.
470    pub dim: usize,
471    /// Small value to avoid division by zero.
472    pub eps: f64,
473}
474
475impl CosineSimilarity {
476    pub fn new(dim: usize, eps: f64) -> Self {
477        Self { dim, eps }
478    }
479
480    pub fn forward<T: Float>(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
481        if x1.shape() != x2.shape() {
482            return Err(FerrotorchError::ShapeMismatch {
483                message: format!(
484                    "CosineSimilarity: shapes must match, got {:?} and {:?}",
485                    x1.shape(),
486                    x2.shape()
487                ),
488            });
489        }
490        if x1.is_cuda() || x2.is_cuda() {
491            return Err(FerrotorchError::NotImplementedOnCuda {
492                op: "CosineSimilarity",
493            });
494        }
495
496        let shape = x1.shape();
497        if self.dim >= shape.len() {
498            return Err(FerrotorchError::InvalidArgument {
499                message: format!(
500                    "CosineSimilarity: dim {} out of range for shape {:?}",
501                    self.dim, shape
502                ),
503            });
504        }
505
506        let d1 = x1.data()?;
507        let d2 = x2.data()?;
508        let dim_size = shape[self.dim];
509        let outer: usize = shape[..self.dim].iter().product();
510        let inner: usize = shape[self.dim + 1..].iter().product();
511        let eps_t = T::from(self.eps).unwrap();
512
513        let out_numel = outer * inner;
514        let mut result = Vec::with_capacity(out_numel);
515
516        for o in 0..outer {
517            for i in 0..inner {
518                let mut dot = <T as num_traits::Zero>::zero();
519                let mut n1 = <T as num_traits::Zero>::zero();
520                let mut n2 = <T as num_traits::Zero>::zero();
521                for d in 0..dim_size {
522                    let idx = o * dim_size * inner + d * inner + i;
523                    dot += d1[idx] * d2[idx];
524                    n1 += d1[idx] * d1[idx];
525                    n2 += d2[idx] * d2[idx];
526                }
527                let denom = (n1.sqrt() * n2.sqrt()).max(eps_t);
528                result.push(dot / denom);
529            }
530        }
531
532        let mut out_shape = shape.to_vec();
533        out_shape.remove(self.dim);
534        if out_shape.is_empty() {
535            out_shape.push(1);
536        }
537        Tensor::from_storage(
538            ferrotorch_core::storage::TensorStorage::cpu(result),
539            out_shape,
540            false,
541        )
542    }
543}
544
545impl Default for CosineSimilarity {
546    fn default() -> Self {
547        Self::new(1, 1e-8)
548    }
549}
550
551// ===========================================================================
552// PairwiseDistance
553// ===========================================================================
554
555/// Computes the pairwise distance between two tensors using the p-norm.
556///
557/// `d(x1, x2) = ||x1 - x2||_p`
558///
559/// Matches PyTorch's `nn.PairwiseDistance`.
560#[derive(Debug, Clone)]
561pub struct PairwiseDistance {
562    /// The norm degree (default: 2.0 for Euclidean).
563    pub p: f64,
564    /// Small value to avoid division by zero.
565    pub eps: f64,
566    /// Whether to keep the output dimension.
567    pub keepdim: bool,
568}
569
570impl PairwiseDistance {
571    pub fn new(p: f64, eps: f64, keepdim: bool) -> Self {
572        Self { p, eps, keepdim }
573    }
574
575    pub fn forward<T: Float>(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
576        if x1.shape() != x2.shape() {
577            return Err(FerrotorchError::ShapeMismatch {
578                message: format!(
579                    "PairwiseDistance: shapes must match, got {:?} and {:?}",
580                    x1.shape(),
581                    x2.shape()
582                ),
583            });
584        }
585        if x1.is_cuda() || x2.is_cuda() {
586            return Err(FerrotorchError::NotImplementedOnCuda {
587                op: "PairwiseDistance",
588            });
589        }
590
591        let shape = x1.shape();
592        let ndim = shape.len();
593        if ndim == 0 {
594            return Err(FerrotorchError::InvalidArgument {
595                message: "PairwiseDistance: input must have at least 1 dimension".into(),
596            });
597        }
598
599        let d1 = x1.data()?;
600        let d2 = x2.data()?;
601        let last_dim = shape[ndim - 1];
602        let outer: usize = d1.len() / last_dim;
603        let p_t = T::from(self.p).unwrap();
604        let inv_p = T::from(1.0 / self.p).unwrap();
605        let eps_t = T::from(self.eps).unwrap();
606
607        let mut result = Vec::with_capacity(outer);
608        for o in 0..outer {
609            let mut norm = <T as num_traits::Zero>::zero();
610            for i in 0..last_dim {
611                let diff = d1[o * last_dim + i] - d2[o * last_dim + i];
612                let abs_diff = if diff < <T as num_traits::Zero>::zero() {
613                    <T as num_traits::Zero>::zero() - diff
614                } else {
615                    diff
616                };
617                norm += (abs_diff + eps_t).powf(p_t);
618            }
619            result.push(norm.powf(inv_p));
620        }
621
622        let mut out_shape: Vec<usize> = shape[..ndim - 1].to_vec();
623        if self.keepdim {
624            out_shape.push(1);
625        }
626        if out_shape.is_empty() {
627            out_shape.push(1);
628        }
629        Tensor::from_storage(
630            ferrotorch_core::storage::TensorStorage::cpu(result),
631            out_shape,
632            false,
633        )
634    }
635}
636
637impl Default for PairwiseDistance {
638    fn default() -> Self {
639        Self::new(2.0, 1e-6, false)
640    }
641}
642
643// ===========================================================================
644// Tests
645// ===========================================================================
646
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::autograd::graph::backward;
    use ferrotorch_core::storage::TensorStorage;

    /// Helper: create a leaf tensor with given data, shape, and requires_grad.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    // -----------------------------------------------------------------------
    // Identity tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_identity_forward() {
        let id = Identity::new();
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert_eq!(output.shape(), input.shape());
        assert_eq!(output.data_vec().unwrap(), input.data_vec().unwrap());
    }

    #[test]
    fn test_identity_no_parameters() {
        let id = Identity::new();
        assert!(Module::<f64>::parameters(&id).is_empty());
        assert!(Module::<f64>::named_parameters(&id).is_empty());
    }

    #[test]
    fn test_identity_preserves_grad() {
        let id = Identity::new();
        let input = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert!(output.requires_grad());
    }

    #[test]
    fn test_identity_train_eval() {
        let mut id = Identity::new();
        assert!(Module::<f64>::is_training(&id));
        Module::<f64>::eval(&mut id);
        assert!(!Module::<f64>::is_training(&id));
        Module::<f64>::train(&mut id);
        assert!(Module::<f64>::is_training(&id));
    }

    #[test]
    fn test_identity_empty_tensor() {
        let id = Identity::new();
        let input = leaf(&[], &[0], false);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert_eq!(output.shape(), &[0]);
        assert_eq!(output.numel(), 0);
    }

    #[test]
    fn test_identity_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Identity>();
    }

    // -----------------------------------------------------------------------
    // Flatten tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_flatten_default() {
        // Default: start_dim=1, end_dim=-1 => flatten all but batch.
        let flatten = Flatten::default();
        let input = leaf(
            &(0..120).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4, 5],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 60]);
    }

    #[test]
    fn test_flatten_specific_range() {
        // Flatten dims 2..3 of [2, 3, 4, 5] => [2, 3, 20].
        let flatten = Flatten::new(2, 3);
        let input = leaf(
            &(0..120).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4, 5],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 20]);
    }

    #[test]
    fn test_flatten_all_dims() {
        // start_dim=0, end_dim=-1 => flatten everything.
        let flatten = Flatten::new(0, -1);
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[24]);
    }

    #[test]
    fn test_flatten_noop_single_dim() {
        // start_dim == end_dim => no-op.
        let flatten = Flatten::new(1, 1);
        let input = leaf(
            &(0..12).map(|i| i as f64).collect::<Vec<_>>(),
            &[3, 4],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_flatten_1d_input() {
        // 1-D input: already flat, should return as-is.
        let flatten = Flatten::new(0, -1);
        let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_flatten_0d_error() {
        // 0-D tensor should error.
        let flatten = Flatten::new(0, -1);
        let input = leaf(&[42.0], &[], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_start_dim_out_of_range() {
        let flatten = Flatten::new(5, -1);
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_end_dim_out_of_range() {
        let flatten = Flatten::new(0, 10);
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_start_gt_end_error() {
        let flatten = Flatten::new(2, 1);
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            false,
        );
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_preserves_data() {
        let flatten = Flatten::default();
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let input = leaf(&data, &[2, 3, 4], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.data_vec().unwrap(), data);
    }

    #[test]
    fn test_flatten_backward() {
        use ferrotorch_core::tensor::GradFn;
        use std::sync::Arc;

        /// Sum backward helper that propagates gradients.
        #[derive(Debug)]
        struct SumBackwardHelper {
            input: Tensor<f64>,
        }

        impl GradFn<f64> for SumBackwardHelper {
            fn backward(
                &self,
                _grad_output: &Tensor<f64>,
            ) -> FerrotorchResult<Vec<Option<Tensor<f64>>>> {
                let ones_data = vec![1.0f64; self.input.numel()];
                let ones = Tensor::from_storage(
                    TensorStorage::cpu(ones_data),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(ones)])
            }

            fn inputs(&self) -> Vec<&Tensor<f64>> {
                vec![&self.input]
            }

            fn name(&self) -> &'static str {
                "SumBackwardHelper"
            }
        }

        let flatten = Flatten::default();
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            true,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 12]);
        assert!(output.requires_grad());

        // Trigger backward through a differentiable sum.
        let out_data = output.data().unwrap();
        let total: f64 = out_data.iter().sum();
        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
        backward(&loss).unwrap();

        let grad = input.grad().unwrap().unwrap();
        assert_eq!(grad.shape(), &[2, 3, 4]);
        // Gradient of sum is all ones.
        for &v in grad.data().unwrap().iter() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_flatten_no_parameters() {
        let flatten = Flatten::default();
        assert!(Module::<f64>::parameters(&flatten).is_empty());
        assert!(Module::<f64>::named_parameters(&flatten).is_empty());
    }

    #[test]
    fn test_flatten_zero_size_dim() {
        // Tensor with a zero-size dimension should still work.
        let flatten = Flatten::default();
        let input = leaf(&[], &[2, 0, 4], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 0]);
    }

    #[test]
    fn test_flatten_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Flatten>();
    }

    // -----------------------------------------------------------------------
    // Unflatten tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_unflatten_forward() {
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let input = leaf(&data, &[2, 12], false);
        let unflatten = Unflatten::new(1, vec![3, 4]);
        let output: Tensor<f64> = unflatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4]);
        assert_eq!(output.data_vec().unwrap(), data);
    }

    #[test]
    fn test_unflatten_size_mismatch_error() {
        let input = leaf(&[0.0; 12], &[2, 6], false);
        let unflatten = Unflatten::new(1, vec![4, 2]);
        assert!(unflatten.forward(&input).is_err());
    }

    #[test]
    fn test_unflatten_dim_out_of_range() {
        let input = leaf(&[0.0; 6], &[2, 3], false);
        let unflatten = Unflatten::new(2, vec![3]);
        assert!(unflatten.forward(&input).is_err());
    }

    // -----------------------------------------------------------------------
    // ChannelShuffle tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_channel_shuffle_permutation() {
        // [1, 4, 2]: channel c holds [10c, 10c + 1].
        let input = leaf(
            &[0.0, 1.0, 10.0, 11.0, 20.0, 21.0, 30.0, 31.0],
            &[1, 4, 2],
            false,
        );
        let shuffle = ChannelShuffle::new(2);
        let output: Tensor<f64> = shuffle.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 4, 2]);
        // groups=2: [[c0, c1], [c2, c3]] -> transpose -> [c0, c2, c1, c3].
        assert_eq!(
            output.data_vec().unwrap(),
            vec![0.0, 1.0, 20.0, 21.0, 10.0, 11.0, 30.0, 31.0]
        );
    }

    #[test]
    fn test_channel_shuffle_indivisible_error() {
        let input = leaf(&[0.0; 6], &[2, 3], false);
        assert!(ChannelShuffle::new(2).forward(&input).is_err());
    }

    #[test]
    fn test_channel_shuffle_requires_2d() {
        let input = leaf(&[1.0, 2.0], &[2], false);
        assert!(ChannelShuffle::new(1).forward(&input).is_err());
    }

    // -----------------------------------------------------------------------
    // CosineSimilarity tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_basic() {
        // Row 0 is parallel (cos = 1), row 1 is orthogonal (cos = 0).
        let x1 = leaf(&[3.0, 4.0, 1.0, 0.0], &[2, 2], false);
        let x2 = leaf(&[6.0, 8.0, 0.0, 1.0], &[2, 2], false);
        let out = CosineSimilarity::default().forward(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[2]);
        let v = out.data_vec().unwrap();
        assert!((v[0] - 1.0).abs() < 1e-12);
        assert!(v[1].abs() < 1e-12);
    }

    #[test]
    fn test_cosine_similarity_shape_mismatch() {
        let x1 = leaf(&[1.0, 2.0], &[1, 2], false);
        let x2 = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
        assert!(CosineSimilarity::default().forward(&x1, &x2).is_err());
    }

    // -----------------------------------------------------------------------
    // PairwiseDistance tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_pairwise_distance_euclidean() {
        // diff = [3, 4] -> ||.||_2 ≈ 5 (eps is negligible at this tolerance).
        let x1 = leaf(&[4.0, 6.0], &[1, 2], false);
        let x2 = leaf(&[1.0, 2.0], &[1, 2], false);
        let out = PairwiseDistance::default().forward(&x1, &x2).unwrap();
        assert_eq!(out.shape(), &[1]);
        assert!((out.data_vec().unwrap()[0] - 5.0).abs() < 1e-4);
    }

    #[test]
    fn test_pairwise_distance_keepdim() {
        let x1 = leaf(&[4.0, 6.0, 0.0, 0.0], &[2, 2], false);
        let x2 = leaf(&[1.0, 2.0, 0.0, 0.0], &[2, 2], false);
        let out = PairwiseDistance::new(2.0, 1e-6, true)
            .forward(&x1, &x2)
            .unwrap();
        assert_eq!(out.shape(), &[2, 1]);
    }

    #[test]
    fn test_pairwise_distance_shape_mismatch() {
        let x1 = leaf(&[1.0, 2.0], &[1, 2], false);
        let x2 = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
        assert!(PairwiseDistance::default().forward(&x1, &x2).is_err());
    }
}