Skip to main content

ferray_core/manipulation/
mod.rs

1// ferray-core: Shape manipulation functions (REQ-20, REQ-21, REQ-22)
2//
3// Mirrors numpy's shape manipulation routines: reshape, ravel, flatten,
4// concatenate, stack, transpose, flip, etc.
5
6pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13// ============================================================================
14// REQ-20: Shape methods
15// ============================================================================
16
17/// Reshape an array to a new shape (returns a new owned array).
18///
19/// The total number of elements must remain the same.
20///
21/// Analogous to `numpy.reshape()`. Delegates to ndarray's `to_shape`,
22/// which re-uses the existing buffer (zero copy on the data side) for
23/// contiguous inputs and only does a bulk memcpy via `to_owned` — the
24/// old path ran an element-by-element `iter().cloned().collect()`
25/// instead (issue #82).
26///
27/// # Errors
28/// Returns `FerrayError::ShapeMismatch` if the new shape has a different
29/// total number of elements.
30pub fn reshape<T: Element, D: Dimension>(
31    a: &Array<T, D>,
32    new_shape: &[usize],
33) -> FerrayResult<Array<T, IxDyn>> {
34    let old_size = a.size();
35    let new_size: usize = new_shape.iter().product();
36    if old_size != new_size {
37        return Err(FerrayError::shape_mismatch(format!(
38            "cannot reshape array of size {old_size} into shape {new_shape:?} (size {new_size})",
39        )));
40    }
41    let view = a.inner.view().into_dyn();
42    let reshaped = view
43        .to_shape(ndarray::IxDyn(new_shape))
44        .map_err(|e| FerrayError::shape_mismatch(e.to_string()))?;
45    // `as_standard_layout` ensures the final buffer is C-contiguous so
46    // `as_slice()` works for callers, even when the source view was
47    // F-contiguous / strided.
48    Ok(Array::from_ndarray(
49        reshaped.as_standard_layout().into_owned(),
50    ))
51}
52
53/// Return a flattened (1-D) copy of the array.
54///
55/// Analogous to `numpy.ravel()`. Uses ndarray's `to_shape` rather than
56/// the element-wise collect path (issue #82).
57pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
58    let n = a.size();
59    let view = a.inner.view().into_dyn();
60    let reshaped = view
61        .to_shape(ndarray::IxDyn(&[n]))
62        .expect("1-D reshape always succeeds for a size-preserving target");
63    let standard = reshaped.as_standard_layout().into_owned();
64    let one_d = standard
65        .into_dimensionality::<ndarray::Ix1>()
66        .expect("reshape result has ndim == 1 by construction");
67    Ok(Array::from_ndarray(one_d))
68}
69
70/// Return a flattened (1-D) copy of the array.
71///
72/// Identical to `ravel()` — analogous to `ndarray.flatten()`.
73pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
74    ravel(a)
75}
76
77/// Remove axes of length 1 from the shape.
78///
79/// If `axis` is `None`, all length-1 axes are removed.
80/// If `axis` is `Some(ax)`, only that axis is removed (errors if it is not length 1).
81///
82/// Analogous to `numpy.squeeze()`.
83///
84/// # Errors
85/// Returns `FerrayError::AxisOutOfBounds` if the axis is invalid, or
86/// `FerrayError::InvalidValue` if the specified axis has size != 1.
87pub fn squeeze<T: Element, D: Dimension>(
88    a: &Array<T, D>,
89    axis: Option<usize>,
90) -> FerrayResult<Array<T, IxDyn>> {
91    let shape = a.shape();
92    if let Some(ax) = axis {
93        if ax >= shape.len() {
94            return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
95        }
96        if shape[ax] != 1 {
97            return Err(FerrayError::invalid_value(format!(
98                "cannot select axis {} with size {} for squeeze (must be 1)",
99                ax, shape[ax],
100            )));
101        }
102        let new_shape: Vec<usize> = shape
103            .iter()
104            .enumerate()
105            .filter(|&(i, _)| i != ax)
106            .map(|(_, &s)| s)
107            .collect();
108        let data: Vec<T> = a.iter().cloned().collect();
109        Array::from_vec(IxDyn::new(&new_shape), data)
110    } else {
111        let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
112        // If all dims are 1, the result is a scalar (0-D is tricky), so
113        // make it at least 1-D with a single element.
114        let new_shape = if new_shape.is_empty() && !shape.is_empty() {
115            vec![1]
116        } else if new_shape.is_empty() {
117            vec![]
118        } else {
119            new_shape
120        };
121        let data: Vec<T> = a.iter().cloned().collect();
122        Array::from_vec(IxDyn::new(&new_shape), data)
123    }
124}
125
126/// Insert a new axis of length 1 at the given position.
127///
128/// Analogous to `numpy.expand_dims()`.
129///
130/// # Errors
131/// Returns `FerrayError::AxisOutOfBounds` if `axis > ndim`.
132pub fn expand_dims<T: Element, D: Dimension>(
133    a: &Array<T, D>,
134    axis: usize,
135) -> FerrayResult<Array<T, IxDyn>> {
136    let ndim = a.ndim();
137    if axis > ndim {
138        return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
139    }
140    let mut new_shape: Vec<usize> = a.shape().to_vec();
141    new_shape.insert(axis, 1);
142    let data: Vec<T> = a.iter().cloned().collect();
143    Array::from_vec(IxDyn::new(&new_shape), data)
144}
145
146/// Broadcast an array to a new shape (returns a new owned array).
147///
148/// The array is replicated along size-1 dimensions to match the target shape.
149///
150/// Analogous to `numpy.broadcast_to()`. Uses ndarray's `broadcast` to
151/// produce a zero-copy stride-0 view and then materializes it with a
152/// single `to_owned` — the prior implementation allocated a source
153/// buffer via `iter().cloned().collect()` and walked every output
154/// index in a hand-rolled nested loop (issue #81).
155///
156/// # Errors
157/// Returns `FerrayError::BroadcastFailure` if the shapes are incompatible.
158pub fn broadcast_to<T: Element, D: Dimension>(
159    a: &Array<T, D>,
160    new_shape: &[usize],
161) -> FerrayResult<Array<T, IxDyn>> {
162    let src_shape = a.shape();
163    let dyn_view = a.inner.view().into_dyn();
164    let broadcast_view = dyn_view
165        .broadcast(ndarray::IxDyn(new_shape))
166        .ok_or_else(|| FerrayError::BroadcastFailure {
167            shape_a: src_shape.to_vec(),
168            shape_b: new_shape.to_vec(),
169        })?;
170    // `as_standard_layout` turns the stride-0 broadcast view into a
171    // proper C-contiguous owned buffer in one pass; the previous
172    // hand-rolled loop walked every destination index separately.
173    Ok(Array::from_ndarray(
174        broadcast_view.as_standard_layout().into_owned(),
175    ))
176}
177
178// ============================================================================
179// REQ-21: Join/split
180// ============================================================================
181
182/// Join a sequence of arrays along an existing axis.
183///
184/// Analogous to `numpy.concatenate()`.
185///
186/// # Errors
187/// Returns `FerrayError::InvalidValue` if the array list is empty.
188/// Returns `FerrayError::ShapeMismatch` if shapes differ on non-concatenation axes.
189/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
190pub fn concatenate<T: Element>(
191    arrays: &[Array<T, IxDyn>],
192    axis: usize,
193) -> FerrayResult<Array<T, IxDyn>> {
194    if arrays.is_empty() {
195        return Err(FerrayError::invalid_value(
196            "concatenate: need at least one array",
197        ));
198    }
199    let ndim = arrays[0].ndim();
200    if axis >= ndim {
201        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
202    }
203    let base_shape = arrays[0].shape();
204
205    // Validate all arrays have same ndim and matching shapes on non-concat axes
206    let mut total_along_axis = 0usize;
207    for arr in arrays {
208        if arr.ndim() != ndim {
209            return Err(FerrayError::shape_mismatch(format!(
210                "all arrays must have same ndim; got {} and {}",
211                ndim,
212                arr.ndim(),
213            )));
214        }
215        for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
216            if i != axis && s != base {
217                return Err(FerrayError::shape_mismatch(format!(
218                    "shape mismatch on axis {i}: {s} vs {base}",
219                )));
220            }
221        }
222        total_along_axis += arr.shape()[axis];
223    }
224
225    // Build new shape
226    let mut new_shape = base_shape.to_vec();
227    new_shape[axis] = total_along_axis;
228    let total: usize = new_shape.iter().product();
229    let mut data = Vec::with_capacity(total);
230
231    // Pre-collect each source array into a contiguous Vec to avoid O(n) iter().nth() per element
232    let src_vecs: Vec<Vec<T>> = arrays.iter().map(|a| a.iter().cloned().collect()).collect();
233
234    // Compute strides for the output array (C-order)
235    let mut out_strides = vec![1usize; ndim];
236    for i in (0..ndim - 1).rev() {
237        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
238    }
239
240    // For each position in the output, figure out which source array and offset
241    for flat_idx in 0..total {
242        // Convert flat index to nd-index
243        let mut rem = flat_idx;
244        let mut nd_idx = vec![0usize; ndim];
245        for i in 0..ndim {
246            nd_idx[i] = rem / out_strides[i];
247            rem %= out_strides[i];
248        }
249
250        // Find which source array this position belongs to
251        let concat_idx = nd_idx[axis];
252        let mut offset = 0;
253        let mut src_arr_idx = 0;
254        for (k, arr) in arrays.iter().enumerate() {
255            let len_along = arr.shape()[axis];
256            if concat_idx < offset + len_along {
257                src_arr_idx = k;
258                break;
259            }
260            offset += len_along;
261        }
262        let local_concat_idx = concat_idx - offset;
263
264        // Build source flat index (C-order)
265        let src_shape = arrays[src_arr_idx].shape();
266        let mut src_flat = 0usize;
267        let mut src_mul = 1usize;
268        for i in (0..ndim).rev() {
269            let idx = if i == axis {
270                local_concat_idx
271            } else {
272                nd_idx[i]
273            };
274            src_flat += idx * src_mul;
275            src_mul *= src_shape[i];
276        }
277
278        let elem = src_vecs[src_arr_idx].get(src_flat).ok_or_else(|| {
279            FerrayError::invalid_value(format!(
280                "concatenate: internal index {} out of range for source array of length {}",
281                src_flat,
282                src_vecs[src_arr_idx].len(),
283            ))
284        })?;
285        data.push(elem.clone());
286    }
287
288    Array::from_vec(IxDyn::new(&new_shape), data)
289}
290
291/// Join a sequence of arrays along a **new** axis.
292///
293/// All arrays must have the same shape. The result has one more dimension
294/// than the inputs.
295///
296/// Analogous to `numpy.stack()`.
297///
298/// # Errors
299/// Returns `FerrayError::InvalidValue` if the array list is empty.
300/// Returns `FerrayError::ShapeMismatch` if shapes differ.
301/// Returns `FerrayError::AxisOutOfBounds` if axis > ndim.
302pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
303    if arrays.is_empty() {
304        return Err(FerrayError::invalid_value("stack: need at least one array"));
305    }
306    let base_shape = arrays[0].shape();
307    let ndim = base_shape.len();
308
309    if axis > ndim {
310        return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
311    }
312
313    for arr in &arrays[1..] {
314        if arr.shape() != base_shape {
315            return Err(FerrayError::shape_mismatch(format!(
316                "all input arrays must have the same shape; got {:?} and {:?}",
317                base_shape,
318                arr.shape(),
319            )));
320        }
321    }
322
323    // Expand each array along the new axis, then concatenate
324    let mut expanded = Vec::with_capacity(arrays.len());
325    for arr in arrays {
326        expanded.push(expand_dims(arr, axis)?);
327    }
328    concatenate(&expanded, axis)
329}
330
331/// Stack arrays vertically (row-wise). Equivalent to `concatenate` along axis 0
332/// for 2-D+ arrays, or equivalent to stacking 1-D arrays as rows.
333///
334/// Analogous to `numpy.vstack()`.
335pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
336    if arrays.is_empty() {
337        return Err(FerrayError::invalid_value(
338            "vstack: need at least one array",
339        ));
340    }
341    // For 1-D arrays, reshape to (1, N) then concatenate along axis 0
342    let ndim = arrays[0].ndim();
343    if ndim == 1 {
344        let mut reshaped = Vec::with_capacity(arrays.len());
345        for arr in arrays {
346            let n = arr.shape()[0];
347            reshaped.push(reshape(arr, &[1, n])?);
348        }
349        concatenate(&reshaped, 0)
350    } else {
351        concatenate(arrays, 0)
352    }
353}
354
355/// Stack arrays horizontally (column-wise). Equivalent to `concatenate` along
356/// axis 1 for 2-D+ arrays, or along axis 0 for 1-D arrays.
357///
358/// Analogous to `numpy.hstack()`.
359pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
360    if arrays.is_empty() {
361        return Err(FerrayError::invalid_value(
362            "hstack: need at least one array",
363        ));
364    }
365    let ndim = arrays[0].ndim();
366    if ndim == 1 {
367        concatenate(arrays, 0)
368    } else {
369        concatenate(arrays, 1)
370    }
371}
372
373/// Stack arrays along the third axis (depth-wise).
374///
375/// For 1-D arrays of shape `(N,)`, reshapes to `(1, N, 1)`.
376/// For 2-D arrays of shape `(M, N)`, reshapes to `(M, N, 1)`.
377/// Then concatenates along axis 2.
378///
379/// Analogous to `numpy.dstack()`.
380pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
381    if arrays.is_empty() {
382        return Err(FerrayError::invalid_value(
383            "dstack: need at least one array",
384        ));
385    }
386    let mut expanded = Vec::with_capacity(arrays.len());
387    for arr in arrays {
388        let shape = arr.shape();
389        match shape.len() {
390            1 => {
391                let n = shape[0];
392                expanded.push(reshape(arr, &[1, n, 1])?);
393            }
394            2 => {
395                let (m, n) = (shape[0], shape[1]);
396                expanded.push(reshape(arr, &[m, n, 1])?);
397            }
398            _ => {
399                // Already 3-D+, just use as-is
400                let data: Vec<T> = arr.iter().cloned().collect();
401                expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
402            }
403        }
404    }
405    concatenate(&expanded, 2)
406}
407
408/// Stack 1-D arrays as columns into a 2-D array.
409///
410/// Each input becomes one column of the output. For 2-D+ inputs, this is
411/// equivalent to [`hstack`].
412///
413/// Analogous to `numpy.column_stack()`.
414///
415/// # Errors
416/// Returns `FerrayError::InvalidValue` if the input is empty,
417/// or `FerrayError::ShapeMismatch` if 1-D inputs have different lengths.
418pub fn column_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
419    if arrays.is_empty() {
420        return Err(FerrayError::invalid_value(
421            "column_stack: need at least one array",
422        ));
423    }
424    let first_ndim = arrays[0].ndim();
425    if first_ndim == 1 {
426        // Convert each 1-D array of length N into a (N, 1) column, then hstack.
427        let n = arrays[0].shape()[0];
428        let mut reshaped = Vec::with_capacity(arrays.len());
429        for arr in arrays {
430            if arr.ndim() != 1 {
431                return Err(FerrayError::shape_mismatch(
432                    "column_stack: all inputs must have the same ndim",
433                ));
434            }
435            if arr.shape()[0] != n {
436                return Err(FerrayError::shape_mismatch(format!(
437                    "column_stack: 1-D inputs must have the same length; got {} and {}",
438                    n,
439                    arr.shape()[0],
440                )));
441            }
442            reshaped.push(reshape(arr, &[n, 1])?);
443        }
444        concatenate(&reshaped, 1)
445    } else {
446        // 2-D+: same as hstack
447        hstack(arrays)
448    }
449}
450
451/// Stack arrays in sequence vertically (row-wise). Alias for [`vstack`].
452///
453/// Analogous to `numpy.row_stack()` (deprecated alias for `vstack` in `NumPy` 2.0
454/// but still widely used).
455pub fn row_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
456    vstack(arrays)
457}
458
459/// Assemble an array from nested blocks.
460///
461/// Simplified version: takes a 2-D grid of arrays (as Vec<Vec<...>>)
462/// and assembles them by stacking rows horizontally, then all rows vertically.
463///
464/// Analogous to `numpy.block()`.
465///
466/// # Errors
467/// Returns errors on shape mismatches.
468pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
469    if blocks.is_empty() {
470        return Err(FerrayError::invalid_value("block: empty input"));
471    }
472    let mut rows = Vec::with_capacity(blocks.len());
473    for row in blocks {
474        if row.is_empty() {
475            return Err(FerrayError::invalid_value("block: empty row"));
476        }
477        // Concatenate along axis 1 (columns within each row)
478        let row_arr = if row.len() == 1 {
479            let data: Vec<T> = row[0].iter().cloned().collect();
480            Array::from_vec(IxDyn::new(row[0].shape()), data)?
481        } else {
482            hstack(row)?
483        };
484        rows.push(row_arr);
485    }
486    if rows.len() == 1 {
487        // SAFETY: just checked len() == 1, so pop() always returns Some
488        Ok(rows.pop().unwrap_or_else(|| unreachable!()))
489    } else {
490        vstack(&rows)
491    }
492}
493
494/// Split an array into equal-sized sub-arrays.
495///
496/// `n_sections` must evenly divide the size along `axis`.
497///
498/// Analogous to `numpy.split()`.
499///
500/// # Errors
501/// Returns `FerrayError::InvalidValue` if the axis cannot be evenly split.
502pub fn split<T: Element>(
503    a: &Array<T, IxDyn>,
504    n_sections: usize,
505    axis: usize,
506) -> FerrayResult<Vec<Array<T, IxDyn>>> {
507    let shape = a.shape();
508    if axis >= shape.len() {
509        return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
510    }
511    let axis_len = shape[axis];
512    if n_sections == 0 {
513        return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
514    }
515    if axis_len % n_sections != 0 {
516        return Err(FerrayError::invalid_value(format!(
517            "array of size {axis_len} along axis {axis} cannot be evenly split into {n_sections} sections",
518        )));
519    }
520    let chunk_size = axis_len / n_sections;
521    let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
522    array_split(a, &indices, axis)
523}
524
525/// Split an array into sub-arrays at the given indices along `axis`.
526///
527/// Unlike `split()`, this does not require even division.
528///
529/// Analogous to `numpy.array_split()` (with explicit split points).
530///
531/// # Errors
532/// Returns `FerrayError::AxisOutOfBounds` if axis is invalid.
533pub fn array_split<T: Element>(
534    a: &Array<T, IxDyn>,
535    indices: &[usize],
536    axis: usize,
537) -> FerrayResult<Vec<Array<T, IxDyn>>> {
538    let shape = a.shape();
539    let ndim = shape.len();
540    if axis >= ndim {
541        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
542    }
543    let axis_len = shape[axis];
544    let src_data: Vec<T> = a.iter().cloned().collect();
545
546    // Build split points including 0 and axis_len
547    let mut splits = Vec::with_capacity(indices.len() + 2);
548    splits.push(0);
549    for &idx in indices {
550        splits.push(idx.min(axis_len));
551    }
552    splits.push(axis_len);
553
554    // Compute source strides (C-order)
555    let mut src_strides = vec![1usize; ndim];
556    for i in (0..ndim - 1).rev() {
557        src_strides[i] = src_strides[i + 1] * shape[i + 1];
558    }
559
560    let mut result = Vec::with_capacity(splits.len() - 1);
561    for w in splits.windows(2) {
562        let start = w[0];
563        let end = w[1];
564        let chunk_len = end - start;
565
566        let mut sub_shape = shape.to_vec();
567        sub_shape[axis] = chunk_len;
568        let sub_total: usize = sub_shape.iter().product();
569
570        // Compute sub strides
571        let mut sub_strides = vec![1usize; ndim];
572        for i in (0..ndim - 1).rev() {
573            sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
574        }
575
576        let mut sub_data = Vec::with_capacity(sub_total);
577        for flat in 0..sub_total {
578            // Convert to nd-index in sub array
579            let mut rem = flat;
580            let mut src_flat = 0usize;
581            for i in 0..ndim {
582                let idx = rem / sub_strides[i];
583                rem %= sub_strides[i];
584                let src_idx = if i == axis { idx + start } else { idx };
585                src_flat += src_idx * src_strides[i];
586            }
587            sub_data.push(src_data[src_flat].clone());
588        }
589        result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
590    }
591
592    Ok(result)
593}
594
595/// Split an array into `n` sub-arrays along `axis`, allowing uneven sections.
596///
597/// Unlike [`split`], this never errors on uneven division: the first
598/// `axis_len % n` sections have one extra element. This matches `NumPy`'s
599/// `numpy.array_split(ary, n, axis)` (integer-section variant).
600///
601/// # Errors
602/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
603/// Returns `FerrayError::InvalidValue` if `n == 0`.
604pub fn array_split_n<T: Element>(
605    a: &Array<T, IxDyn>,
606    n: usize,
607    axis: usize,
608) -> FerrayResult<Vec<Array<T, IxDyn>>> {
609    if n == 0 {
610        return Err(FerrayError::invalid_value("array_split_n: n must be > 0"));
611    }
612    let shape = a.shape();
613    if axis >= shape.len() {
614        return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
615    }
616    let axis_len = shape[axis];
617
618    // Build split indices following NumPy's array_split:
619    // First (axis_len % n) sections get (axis_len / n + 1) elements,
620    // remaining sections get (axis_len / n) elements.
621    let base = axis_len / n;
622    let extra = axis_len % n;
623    let mut indices = Vec::with_capacity(n.saturating_sub(1));
624    let mut cum = 0usize;
625    for i in 0..n - 1 {
626        cum += if i < extra { base + 1 } else { base };
627        indices.push(cum);
628    }
629    array_split(a, &indices, axis)
630}
631
632/// Split array along axis 0 (vertical split). Equivalent to `split(a, n, 0)`.
633///
634/// Analogous to `numpy.vsplit()`.
635pub fn vsplit<T: Element>(
636    a: &Array<T, IxDyn>,
637    n_sections: usize,
638) -> FerrayResult<Vec<Array<T, IxDyn>>> {
639    split(a, n_sections, 0)
640}
641
642/// Split array along axis 1 (horizontal split). Equivalent to `split(a, n, 1)`.
643///
644/// Analogous to `numpy.hsplit()`.
645pub fn hsplit<T: Element>(
646    a: &Array<T, IxDyn>,
647    n_sections: usize,
648) -> FerrayResult<Vec<Array<T, IxDyn>>> {
649    split(a, n_sections, 1)
650}
651
652/// Split array along axis 2 (depth split). Equivalent to `split(a, n, 2)`.
653///
654/// Analogous to `numpy.dsplit()`.
655pub fn dsplit<T: Element>(
656    a: &Array<T, IxDyn>,
657    n_sections: usize,
658) -> FerrayResult<Vec<Array<T, IxDyn>>> {
659    split(a, n_sections, 2)
660}
661
662// ============================================================================
663// REQ-22: Transpose / reorder
664// ============================================================================
665
666/// Permute the axes of an array.
667///
668/// `axes` specifies the new ordering. For a 2-D array, `[1, 0]` transposes.
669/// If `axes` is `None`, reverses the order of all axes.
670///
671/// Analogous to `numpy.transpose()`.
672///
673/// # Errors
674/// Returns `FerrayError::InvalidValue` if `axes` is the wrong length or
675/// contains invalid/duplicate axis indices.
676pub fn transpose<T: Element, D: Dimension>(
677    a: &Array<T, D>,
678    axes: Option<&[usize]>,
679) -> FerrayResult<Array<T, IxDyn>> {
680    let ndim = a.ndim();
681    let perm: Vec<usize> = match axes {
682        Some(ax) => {
683            if ax.len() != ndim {
684                return Err(FerrayError::invalid_value(format!(
685                    "axes must have length {} but got {}",
686                    ndim,
687                    ax.len(),
688                )));
689            }
690            // Validate: each axis appears exactly once
691            let mut seen = vec![false; ndim];
692            for &a in ax {
693                if a >= ndim {
694                    return Err(FerrayError::axis_out_of_bounds(a, ndim));
695                }
696                if seen[a] {
697                    return Err(FerrayError::invalid_value(format!(
698                        "duplicate axis {a} in transpose",
699                    )));
700                }
701                seen[a] = true;
702            }
703            ax.to_vec()
704        }
705        None => (0..ndim).rev().collect(),
706    };
707
708    // Permute axes is a zero-copy stride rearrangement in ndarray; we
709    // then call `as_standard_layout` which returns a borrowed view when
710    // already C-contiguous or materializes a single standard-layout
711    // owned buffer otherwise. The old implementation walked every
712    // output index in a hand-rolled scatter loop (issue #82). Plain
713    // `to_owned()` would preserve the F-contiguous stride pattern from
714    // the permutation, which downstream callers don't want because
715    // `as_slice()` only succeeds on standard layouts.
716    let permuted = a
717        .inner
718        .view()
719        .into_dyn()
720        .permuted_axes(ndarray::IxDyn(&perm));
721    Ok(Array::from_ndarray(
722        permuted.as_standard_layout().into_owned(),
723    ))
724}
725
726/// Swap two axes of an array.
727///
728/// Analogous to `numpy.swapaxes()`.
729///
730/// # Errors
731/// Returns `FerrayError::AxisOutOfBounds` if either axis is out of bounds.
732pub fn swapaxes<T: Element, D: Dimension>(
733    a: &Array<T, D>,
734    axis1: usize,
735    axis2: usize,
736) -> FerrayResult<Array<T, IxDyn>> {
737    let ndim = a.ndim();
738    if axis1 >= ndim {
739        return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
740    }
741    if axis2 >= ndim {
742        return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
743    }
744    let mut perm: Vec<usize> = (0..ndim).collect();
745    perm.swap(axis1, axis2);
746    transpose(a, Some(&perm))
747}
748
749/// Move an axis to a new position.
750///
751/// Analogous to `numpy.moveaxis()`.
752///
753/// # Errors
754/// Returns `FerrayError::AxisOutOfBounds` if either axis is out of bounds.
755pub fn moveaxis<T: Element, D: Dimension>(
756    a: &Array<T, D>,
757    source: usize,
758    destination: usize,
759) -> FerrayResult<Array<T, IxDyn>> {
760    let ndim = a.ndim();
761    if source >= ndim {
762        return Err(FerrayError::axis_out_of_bounds(source, ndim));
763    }
764    if destination >= ndim {
765        return Err(FerrayError::axis_out_of_bounds(destination, ndim));
766    }
767    // Build permutation by removing source and inserting at destination
768    let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
769    order.insert(destination, source);
770    transpose(a, Some(&order))
771}
772
773/// Roll an axis to a new position (similar to moveaxis).
774///
775/// Analogous to `numpy.rollaxis()`.
776///
777/// # Errors
778/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim` or `start > ndim`.
779pub fn rollaxis<T: Element, D: Dimension>(
780    a: &Array<T, D>,
781    axis: usize,
782    start: usize,
783) -> FerrayResult<Array<T, IxDyn>> {
784    let ndim = a.ndim();
785    if axis >= ndim {
786        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
787    }
788    if start > ndim {
789        return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
790    }
791    let dst = if start > axis { start - 1 } else { start };
792    if axis == dst {
793        // No-op: return a copy
794        let data: Vec<T> = a.iter().cloned().collect();
795        return Array::from_vec(IxDyn::new(a.shape()), data);
796    }
797    moveaxis(a, axis, dst)
798}
799
800/// Reverse the order of elements along the given axis.
801///
802/// Analogous to `numpy.flip()`.
803///
804/// # Errors
805/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
806pub fn flip<T: Element, D: Dimension>(
807    a: &Array<T, D>,
808    axis: usize,
809) -> FerrayResult<Array<T, IxDyn>> {
810    let shape = a.shape();
811    let ndim = shape.len();
812    if axis >= ndim {
813        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
814    }
815    let src_data: Vec<T> = a.iter().cloned().collect();
816    let total = src_data.len();
817
818    // Compute strides (C-order)
819    let mut strides = vec![1usize; ndim];
820    for i in (0..ndim.saturating_sub(1)).rev() {
821        strides[i] = strides[i + 1] * shape[i + 1];
822    }
823
824    let mut data = Vec::with_capacity(total);
825    for flat in 0..total {
826        let mut rem = flat;
827        let mut src_flat = 0usize;
828        for i in 0..ndim {
829            let idx = rem / strides[i];
830            rem %= strides[i];
831            let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
832            src_flat += src_idx * strides[i];
833        }
834        data.push(src_data[src_flat].clone());
835    }
836    Array::from_vec(IxDyn::new(shape), data)
837}
838
839/// Flip array left-right (reverse axis 1).
840///
841/// Analogous to `numpy.fliplr()`.
842///
843/// # Errors
844/// Returns `FerrayError::InvalidValue` if the array has fewer than 2 dimensions.
845pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
846    if a.ndim() < 2 {
847        return Err(FerrayError::invalid_value(
848            "fliplr: array must be at least 2-D",
849        ));
850    }
851    flip(a, 1)
852}
853
854/// Flip array up-down (reverse axis 0).
855///
856/// Analogous to `numpy.flipud()`.
857///
858/// # Errors
859/// Returns `FerrayError::InvalidValue` if the array has 0 dimensions.
860pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
861    if a.ndim() < 1 {
862        return Err(FerrayError::invalid_value(
863            "flipud: array must be at least 1-D",
864        ));
865    }
866    flip(a, 0)
867}
868
869/// Rotate array 90 degrees counterclockwise in the plane defined by axes (0, 1).
870///
871/// `k` specifies the number of 90-degree rotations (can be negative).
872///
873/// Analogous to `numpy.rot90()`.
874///
875/// # Errors
876/// Returns `FerrayError::InvalidValue` if the array has fewer than 2 dimensions.
877pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
878    if a.ndim() < 2 {
879        return Err(FerrayError::invalid_value(
880            "rot90: array must be at least 2-D",
881        ));
882    }
883    // Normalize k to [0, 4)
884    let k = k.rem_euclid(4);
885    let shape = a.shape();
886    let data: Vec<T> = a.iter().cloned().collect();
887
888    // We work with the IxDyn representation
889    let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
890
891    match k {
892        0 => Ok(as_dyn),
893        1 => {
894            // rot90 once: flip axis 1, then transpose axes 0,1
895            let flipped = flip(&as_dyn, 1)?;
896            swapaxes(&flipped, 0, 1)
897        }
898        2 => {
899            // rot180: flip both axes
900            let f1 = flip(&as_dyn, 0)?;
901            flip(&f1, 1)
902        }
903        3 => {
904            // rot270: transpose, then flip axis 1
905            let transposed = swapaxes(&as_dyn, 0, 1)?;
906            flip(&transposed, 1)
907        }
908        _ => unreachable!(),
909    }
910}
911
912/// Roll elements along an axis. Elements that roll past the end
913/// are re-introduced at the beginning.
914///
915/// If `axis` is `None`, the array is flattened first, then rolled.
916///
917/// Analogous to `numpy.roll()`.
918///
919/// # Errors
920/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
921pub fn roll<T: Element, D: Dimension>(
922    a: &Array<T, D>,
923    shift: isize,
924    axis: Option<usize>,
925) -> FerrayResult<Array<T, IxDyn>> {
926    match axis {
927        None => {
928            // Flatten, roll, reshape back
929            let data: Vec<T> = a.iter().cloned().collect();
930            let n = data.len();
931            if n == 0 {
932                return Array::from_vec(IxDyn::new(a.shape()), data);
933            }
934            let shift = ((shift % n as isize) + n as isize) as usize % n;
935            let mut rolled = Vec::with_capacity(n);
936            for i in 0..n {
937                rolled.push(data[(n + i - shift) % n].clone());
938            }
939            Array::from_vec(IxDyn::new(a.shape()), rolled)
940        }
941        Some(ax) => {
942            let shape = a.shape();
943            let ndim = shape.len();
944            if ax >= ndim {
945                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
946            }
947            let axis_len = shape[ax];
948            if axis_len == 0 {
949                let data: Vec<T> = a.iter().cloned().collect();
950                return Array::from_vec(IxDyn::new(shape), data);
951            }
952            let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
953            let src_data: Vec<T> = a.iter().cloned().collect();
954            let total = src_data.len();
955
956            // Compute strides (C-order)
957            let mut strides = vec![1usize; ndim];
958            for i in (0..ndim.saturating_sub(1)).rev() {
959                strides[i] = strides[i + 1] * shape[i + 1];
960            }
961
962            let mut data = Vec::with_capacity(total);
963            for flat in 0..total {
964                let mut rem = flat;
965                let mut src_flat = 0usize;
966                #[allow(clippy::needless_range_loop)]
967                for i in 0..ndim {
968                    let idx = rem / strides[i];
969                    rem %= strides[i];
970                    let src_idx = if i == ax {
971                        (axis_len + idx - shift) % axis_len
972                    } else {
973                        idx
974                    };
975                    src_flat += src_idx * strides[i];
976                }
977                data.push(src_data[src_flat].clone());
978            }
979            Array::from_vec(IxDyn::new(shape), data)
980        }
981    }
982}
983
984// ============================================================================
985// Tests
986// ============================================================================
987
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamic-rank `f64` array from a shape and C-ordered data.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // -- REQ-20 --

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    #[test]
    fn test_broadcast_to_from_non_contiguous_source() {
        // Issue #133: broadcast_to's source may itself be non-contiguous
        // (e.g., transposed). Since broadcast_to now delegates to
        // ndarray's `broadcast` which handles any stride pattern, the
        // result should materialize correctly.
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Transpose gives us a 3x2 logical view — then broadcast to (2, 3, 2).
        let t = transpose(&a, None).unwrap();
        let b = broadcast_to(&t, &[2, 3, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        // Both outer slices should be identical.
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(&data[..6], &data[6..12]);
    }

    // -- REQ-21 --

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_concatenate_differing_concat_axis_lengths() {
        // Inputs may differ in length along the concatenation axis (2 vs 3
        // rows here) as long as every other axis matches (3 columns each).
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err()); // 5 not divisible by 3
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]); // [1,2]
        assert_eq!(parts[1].shape(), &[2]); // [3,4]
        assert_eq!(parts[2].shape(), &[1]); // [5]
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    // -- REQ-22 --

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err()); // wrong length
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_flipud_1d() {
        // flipud only needs one axis, so 1-D input is accepted.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        // [[1, 2], [3, 4]] -> [[2, 4], [1, 3]]
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_three_times_is_clockwise() {
        // [[1, 2], [3, 4]] -> [[3, 1], [4, 2]]
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 3).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 4.0, 2.0]);
    }

    #[test]
    fn test_rot90_negative_k_matches_three() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let neg = rot90(&a, -1).unwrap();
        let three = rot90(&a, 3).unwrap();
        assert_eq!(neg.shape(), three.shape());
        assert_eq!(
            neg.iter().copied().collect::<Vec<_>>(),
            three.iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_rot90_non_square() {
        // [[1, 2, 3], [4, 5, 6]] -> [[3, 6], [2, 5], [1, 4]]
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 6.0, 2.0, 5.0, 1.0, 4.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_rot90_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(rot90(&a, 1).is_err());
    }

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_shift_longer_than_axis_wraps() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 7, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
        let c = roll(&a, -7, None).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_flat_keeps_shape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, None).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![6.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis0_negative() {
        // [[1, 2, 3], [4, 5, 6]] rolled by -1 along rows swaps the rows.
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, -1, Some(0)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_middle_axis_3d() {
        // Each 3x2 slab is rolled by one row independently.
        let a = dyn_arr(
            &[2, 3, 2],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
        );
        let b = roll(&a, 1, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 6.0, 7.0, 8.0, 9.0]
        );
    }

    #[test]
    fn test_roll_empty() {
        let a = dyn_arr(&[2, 0], vec![]);
        for axis in [None, Some(0), Some(1)] {
            let b = roll(&a, 1, axis).unwrap();
            assert_eq!(b.shape(), &[2, 0]);
            assert_eq!(b.size(), 0);
        }
    }

    #[test]
    fn test_roll_axis_out_of_bounds() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(roll(&a, 1, Some(2)).is_err());
    }

    // -----------------------------------------------------------------------
    // column_stack / row_stack / array_split_n / to_dyn (issue #362)
    // -----------------------------------------------------------------------

    #[test]
    fn test_column_stack_1d() {
        // 3 1-D arrays of length 4 -> (4, 3)
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[4], vec![10.0, 20.0, 30.0, 40.0]);
        let c = dyn_arr(&[4], vec![100.0, 200.0, 300.0, 400.0]);
        let result = column_stack(&[a, b, c]).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 10.0, 100.0, // row 0
                2.0, 20.0, 200.0, // row 1
                3.0, 30.0, 300.0, // row 2
                4.0, 40.0, 400.0, // row 3
            ]
        );
    }

    #[test]
    fn test_column_stack_2d_same_as_hstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = column_stack(&[a, b]).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_column_stack_length_mismatch() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(column_stack(&[a, b]).is_err());
    }

    #[test]
    fn test_row_stack_is_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let row = row_stack(&[a.clone(), b.clone()]).unwrap();
        let v = vstack(&[a, b]).unwrap();
        assert_eq!(row.shape(), v.shape());
        assert_eq!(
            row.iter().copied().collect::<Vec<_>>(),
            v.iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_array_split_n_uneven() {
        // 7 elements split into 3 sections -> [3, 2, 2]
        let a = dyn_arr(&[7], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![6.0, 7.0]);
    }

    #[test]
    fn test_array_split_n_even() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        for (i, expected) in [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
            .iter()
            .enumerate()
        {
            assert_eq!(&parts[i].iter().copied().collect::<Vec<_>>(), expected);
        }
    }

    #[test]
    fn test_array_split_n_more_sections_than_elements() {
        // NumPy's behavior: 3 elements split into 5 sections gives 5 parts,
        // first 3 have 1 element each, last 2 are empty.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let parts = array_split_n(&a, 5, 0).unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0]);
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![2.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![3.0]);
        assert_eq!(
            parts[3].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
        assert_eq!(
            parts[4].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
    }

    #[test]
    fn test_to_dyn_from_typed() {
        use crate::Array;
        use crate::dimension::Ix2;
        let typed =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let dy = typed.to_dyn();
        assert_eq!(dy.shape(), &[2, 3]);
        assert_eq!(
            dy.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_concatenate_typed_via_to_dyn() {
        // Demonstrates the typical end-user flow: have Array<T, Ix2>, want to
        // concatenate, route through to_dyn().
        use crate::Array;
        use crate::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let result = concatenate(&[a.to_dyn(), b.to_dyn()], 0).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }
}