numrs/ops/shape/
flatten.rs

1
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4
5/// Flattens a contiguous range of dims into a tensor.
6/// 
7/// For use with batches, typically `start_dim` = 1.
8/// 
9/// # Arguments
10/// * `input` - Input tensor
11/// * `start_dim` - First dim to flatten (0-indexed)
12/// * `end_dim` - Last dim to flatten (inclusive). Use -1 (or huge number) for "until end".
13pub fn flatten(input: &Array, start_dim: usize, end_dim: usize) -> Result<Array> {
14    let ndim = input.shape.len();
15    
16    // Clamp end_dim
17    let end_dim = if end_dim >= ndim { ndim - 1 } else { end_dim };
18    
19    if start_dim > end_dim {
20        return Err(anyhow!("flatten: start_dim ({}) cannot come after end_dim ({})", start_dim, end_dim));
21    }
22    
23    let mut new_shape = Vec::new();
24    
25    // 1. Dims before start_dim
26    for i in 0..start_dim {
27        new_shape.push(input.shape[i]);
28    }
29    
30    // 2. Flattened dim
31    let mut flat_size = 1;
32    for i in start_dim..=end_dim {
33        flat_size *= input.shape[i];
34    }
35    new_shape.push(flat_size);
36    
37    // 3. Dims after end_dim
38    for i in (end_dim + 1)..ndim {
39        new_shape.push(input.shape[i]);
40    }
41    
42    // Delegate to reshape logic (requires isize shape)
43    let new_shape_isize: Vec<isize> = new_shape.iter().map(|&x| x as isize).collect();
44    crate::ops::reshape(input, &new_shape_isize)
45}