numrs/ops/shape/
flatten.rs1
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4
5pub fn flatten(input: &Array, start_dim: usize, end_dim: usize) -> Result<Array> {
14 let ndim = input.shape.len();
15
16 let end_dim = if end_dim >= ndim { ndim - 1 } else { end_dim };
18
19 if start_dim > end_dim {
20 return Err(anyhow!("flatten: start_dim ({}) cannot come after end_dim ({})", start_dim, end_dim));
21 }
22
23 let mut new_shape = Vec::new();
24
25 for i in 0..start_dim {
27 new_shape.push(input.shape[i]);
28 }
29
30 let mut flat_size = 1;
32 for i in start_dim..=end_dim {
33 flat_size *= input.shape[i];
34 }
35 new_shape.push(flat_size);
36
37 for i in (end_dim + 1)..ndim {
39 new_shape.push(input.shape[i]);
40 }
41
42 let new_shape_isize: Vec<isize> = new_shape.iter().map(|&x| x as isize).collect();
44 crate::ops::reshape(input, &new_shape_isize)
45}