Skip to main content

ferray_core/manipulation/
mod.rs

1// ferray-core: Shape manipulation functions (REQ-20, REQ-21, REQ-22)
2//
3// Mirrors numpy's shape manipulation routines: reshape, ravel, flatten,
4// concatenate, stack, transpose, flip, etc.
5
6pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13// ============================================================================
14// REQ-20: Shape methods
15// ============================================================================
16
17/// Reshape an array to a new shape (returns a new owned array).
18///
19/// The total number of elements must remain the same.
20///
21/// Analogous to `numpy.reshape()`. Delegates to ndarray's `to_shape`,
22/// which re-uses the existing buffer (zero copy on the data side) for
23/// contiguous inputs and only does a bulk memcpy via `to_owned` — the
24/// old path ran an element-by-element `iter().cloned().collect()`
25/// instead (issue #82).
26///
27/// # Errors
28/// Returns `FerrayError::ShapeMismatch` if the new shape has a different
29/// total number of elements.
30pub fn reshape<T: Element, D: Dimension>(
31    a: &Array<T, D>,
32    new_shape: &[usize],
33) -> FerrayResult<Array<T, IxDyn>> {
34    let old_size = a.size();
35    let new_size: usize = new_shape.iter().product();
36    if old_size != new_size {
37        return Err(FerrayError::shape_mismatch(format!(
38            "cannot reshape array of size {} into shape {:?} (size {})",
39            old_size, new_shape, new_size,
40        )));
41    }
42    let view = a.inner.view().into_dyn();
43    let reshaped = view
44        .to_shape(ndarray::IxDyn(new_shape))
45        .map_err(|e| FerrayError::shape_mismatch(e.to_string()))?;
46    // `as_standard_layout` ensures the final buffer is C-contiguous so
47    // `as_slice()` works for callers, even when the source view was
48    // F-contiguous / strided.
49    Ok(Array::from_ndarray(
50        reshaped.as_standard_layout().into_owned(),
51    ))
52}
53
54/// Return a flattened (1-D) copy of the array.
55///
56/// Analogous to `numpy.ravel()`. Uses ndarray's `to_shape` rather than
57/// the element-wise collect path (issue #82).
58pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
59    let n = a.size();
60    let view = a.inner.view().into_dyn();
61    let reshaped = view
62        .to_shape(ndarray::IxDyn(&[n]))
63        .expect("1-D reshape always succeeds for a size-preserving target");
64    let standard = reshaped.as_standard_layout().into_owned();
65    let one_d = standard
66        .into_dimensionality::<ndarray::Ix1>()
67        .expect("reshape result has ndim == 1 by construction");
68    Ok(Array::from_ndarray(one_d))
69}
70
71/// Return a flattened (1-D) copy of the array.
72///
73/// Identical to `ravel()` — analogous to `ndarray.flatten()`.
74pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
75    ravel(a)
76}
77
78/// Remove axes of length 1 from the shape.
79///
80/// If `axis` is `None`, all length-1 axes are removed.
81/// If `axis` is `Some(ax)`, only that axis is removed (errors if it is not length 1).
82///
83/// Analogous to `numpy.squeeze()`.
84///
85/// # Errors
86/// Returns `FerrayError::AxisOutOfBounds` if the axis is invalid, or
87/// `FerrayError::InvalidValue` if the specified axis has size != 1.
88pub fn squeeze<T: Element, D: Dimension>(
89    a: &Array<T, D>,
90    axis: Option<usize>,
91) -> FerrayResult<Array<T, IxDyn>> {
92    let shape = a.shape();
93    match axis {
94        Some(ax) => {
95            if ax >= shape.len() {
96                return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
97            }
98            if shape[ax] != 1 {
99                return Err(FerrayError::invalid_value(format!(
100                    "cannot select axis {} with size {} for squeeze (must be 1)",
101                    ax, shape[ax],
102                )));
103            }
104            let new_shape: Vec<usize> = shape
105                .iter()
106                .enumerate()
107                .filter(|&(i, _)| i != ax)
108                .map(|(_, &s)| s)
109                .collect();
110            let data: Vec<T> = a.iter().cloned().collect();
111            Array::from_vec(IxDyn::new(&new_shape), data)
112        }
113        None => {
114            let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
115            // If all dims are 1, the result is a scalar (0-D is tricky), so
116            // make it at least 1-D with a single element.
117            let new_shape = if new_shape.is_empty() && !shape.is_empty() {
118                vec![1]
119            } else if new_shape.is_empty() {
120                vec![]
121            } else {
122                new_shape
123            };
124            let data: Vec<T> = a.iter().cloned().collect();
125            Array::from_vec(IxDyn::new(&new_shape), data)
126        }
127    }
128}
129
130/// Insert a new axis of length 1 at the given position.
131///
132/// Analogous to `numpy.expand_dims()`.
133///
134/// # Errors
135/// Returns `FerrayError::AxisOutOfBounds` if `axis > ndim`.
136pub fn expand_dims<T: Element, D: Dimension>(
137    a: &Array<T, D>,
138    axis: usize,
139) -> FerrayResult<Array<T, IxDyn>> {
140    let ndim = a.ndim();
141    if axis > ndim {
142        return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
143    }
144    let mut new_shape: Vec<usize> = a.shape().to_vec();
145    new_shape.insert(axis, 1);
146    let data: Vec<T> = a.iter().cloned().collect();
147    Array::from_vec(IxDyn::new(&new_shape), data)
148}
149
150/// Broadcast an array to a new shape (returns a new owned array).
151///
152/// The array is replicated along size-1 dimensions to match the target shape.
153///
154/// Analogous to `numpy.broadcast_to()`. Uses ndarray's `broadcast` to
155/// produce a zero-copy stride-0 view and then materializes it with a
156/// single `to_owned` — the prior implementation allocated a source
157/// buffer via `iter().cloned().collect()` and walked every output
158/// index in a hand-rolled nested loop (issue #81).
159///
160/// # Errors
161/// Returns `FerrayError::BroadcastFailure` if the shapes are incompatible.
162pub fn broadcast_to<T: Element, D: Dimension>(
163    a: &Array<T, D>,
164    new_shape: &[usize],
165) -> FerrayResult<Array<T, IxDyn>> {
166    let src_shape = a.shape();
167    let dyn_view = a.inner.view().into_dyn();
168    let broadcast_view = dyn_view
169        .broadcast(ndarray::IxDyn(new_shape))
170        .ok_or_else(|| FerrayError::BroadcastFailure {
171            shape_a: src_shape.to_vec(),
172            shape_b: new_shape.to_vec(),
173        })?;
174    // `as_standard_layout` turns the stride-0 broadcast view into a
175    // proper C-contiguous owned buffer in one pass; the previous
176    // hand-rolled loop walked every destination index separately.
177    Ok(Array::from_ndarray(
178        broadcast_view.as_standard_layout().into_owned(),
179    ))
180}
181
182// ============================================================================
183// REQ-21: Join/split
184// ============================================================================
185
186/// Join a sequence of arrays along an existing axis.
187///
188/// Analogous to `numpy.concatenate()`.
189///
190/// # Errors
191/// Returns `FerrayError::InvalidValue` if the array list is empty.
192/// Returns `FerrayError::ShapeMismatch` if shapes differ on non-concatenation axes.
193/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
194pub fn concatenate<T: Element>(
195    arrays: &[Array<T, IxDyn>],
196    axis: usize,
197) -> FerrayResult<Array<T, IxDyn>> {
198    if arrays.is_empty() {
199        return Err(FerrayError::invalid_value(
200            "concatenate: need at least one array",
201        ));
202    }
203    let ndim = arrays[0].ndim();
204    if axis >= ndim {
205        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
206    }
207    let base_shape = arrays[0].shape();
208
209    // Validate all arrays have same ndim and matching shapes on non-concat axes
210    let mut total_along_axis = 0usize;
211    for arr in arrays {
212        if arr.ndim() != ndim {
213            return Err(FerrayError::shape_mismatch(format!(
214                "all arrays must have same ndim; got {} and {}",
215                ndim,
216                arr.ndim(),
217            )));
218        }
219        for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
220            if i != axis && s != base {
221                return Err(FerrayError::shape_mismatch(format!(
222                    "shape mismatch on axis {}: {} vs {}",
223                    i, s, base,
224                )));
225            }
226        }
227        total_along_axis += arr.shape()[axis];
228    }
229
230    // Build new shape
231    let mut new_shape = base_shape.to_vec();
232    new_shape[axis] = total_along_axis;
233    let total: usize = new_shape.iter().product();
234    let mut data = Vec::with_capacity(total);
235
236    // Pre-collect each source array into a contiguous Vec to avoid O(n) iter().nth() per element
237    let src_vecs: Vec<Vec<T>> = arrays.iter().map(|a| a.iter().cloned().collect()).collect();
238
239    // Compute strides for the output array (C-order)
240    let mut out_strides = vec![1usize; ndim];
241    for i in (0..ndim - 1).rev() {
242        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
243    }
244
245    // For each position in the output, figure out which source array and offset
246    for flat_idx in 0..total {
247        // Convert flat index to nd-index
248        let mut rem = flat_idx;
249        let mut nd_idx = vec![0usize; ndim];
250        for i in 0..ndim {
251            nd_idx[i] = rem / out_strides[i];
252            rem %= out_strides[i];
253        }
254
255        // Find which source array this position belongs to
256        let concat_idx = nd_idx[axis];
257        let mut offset = 0;
258        let mut src_arr_idx = 0;
259        for (k, arr) in arrays.iter().enumerate() {
260            let len_along = arr.shape()[axis];
261            if concat_idx < offset + len_along {
262                src_arr_idx = k;
263                break;
264            }
265            offset += len_along;
266        }
267        let local_concat_idx = concat_idx - offset;
268
269        // Build source flat index (C-order)
270        let src_shape = arrays[src_arr_idx].shape();
271        let mut src_flat = 0usize;
272        let mut src_mul = 1usize;
273        for i in (0..ndim).rev() {
274            let idx = if i == axis {
275                local_concat_idx
276            } else {
277                nd_idx[i]
278            };
279            src_flat += idx * src_mul;
280            src_mul *= src_shape[i];
281        }
282
283        let elem = src_vecs[src_arr_idx].get(src_flat).ok_or_else(|| {
284            FerrayError::invalid_value(format!(
285                "concatenate: internal index {} out of range for source array of length {}",
286                src_flat,
287                src_vecs[src_arr_idx].len(),
288            ))
289        })?;
290        data.push(elem.clone());
291    }
292
293    Array::from_vec(IxDyn::new(&new_shape), data)
294}
295
296/// Join a sequence of arrays along a **new** axis.
297///
298/// All arrays must have the same shape. The result has one more dimension
299/// than the inputs.
300///
301/// Analogous to `numpy.stack()`.
302///
303/// # Errors
304/// Returns `FerrayError::InvalidValue` if the array list is empty.
305/// Returns `FerrayError::ShapeMismatch` if shapes differ.
306/// Returns `FerrayError::AxisOutOfBounds` if axis > ndim.
307pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
308    if arrays.is_empty() {
309        return Err(FerrayError::invalid_value("stack: need at least one array"));
310    }
311    let base_shape = arrays[0].shape();
312    let ndim = base_shape.len();
313
314    if axis > ndim {
315        return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
316    }
317
318    for arr in &arrays[1..] {
319        if arr.shape() != base_shape {
320            return Err(FerrayError::shape_mismatch(format!(
321                "all input arrays must have the same shape; got {:?} and {:?}",
322                base_shape,
323                arr.shape(),
324            )));
325        }
326    }
327
328    // Expand each array along the new axis, then concatenate
329    let mut expanded = Vec::with_capacity(arrays.len());
330    for arr in arrays {
331        expanded.push(expand_dims(arr, axis)?);
332    }
333    concatenate(&expanded, axis)
334}
335
336/// Stack arrays vertically (row-wise). Equivalent to `concatenate` along axis 0
337/// for 2-D+ arrays, or equivalent to stacking 1-D arrays as rows.
338///
339/// Analogous to `numpy.vstack()`.
340pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
341    if arrays.is_empty() {
342        return Err(FerrayError::invalid_value(
343            "vstack: need at least one array",
344        ));
345    }
346    // For 1-D arrays, reshape to (1, N) then concatenate along axis 0
347    let ndim = arrays[0].ndim();
348    if ndim == 1 {
349        let mut reshaped = Vec::with_capacity(arrays.len());
350        for arr in arrays {
351            let n = arr.shape()[0];
352            reshaped.push(reshape(arr, &[1, n])?);
353        }
354        concatenate(&reshaped, 0)
355    } else {
356        concatenate(arrays, 0)
357    }
358}
359
360/// Stack arrays horizontally (column-wise). Equivalent to `concatenate` along
361/// axis 1 for 2-D+ arrays, or along axis 0 for 1-D arrays.
362///
363/// Analogous to `numpy.hstack()`.
364pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
365    if arrays.is_empty() {
366        return Err(FerrayError::invalid_value(
367            "hstack: need at least one array",
368        ));
369    }
370    let ndim = arrays[0].ndim();
371    if ndim == 1 {
372        concatenate(arrays, 0)
373    } else {
374        concatenate(arrays, 1)
375    }
376}
377
378/// Stack arrays along the third axis (depth-wise).
379///
380/// For 1-D arrays of shape `(N,)`, reshapes to `(1, N, 1)`.
381/// For 2-D arrays of shape `(M, N)`, reshapes to `(M, N, 1)`.
382/// Then concatenates along axis 2.
383///
384/// Analogous to `numpy.dstack()`.
385pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
386    if arrays.is_empty() {
387        return Err(FerrayError::invalid_value(
388            "dstack: need at least one array",
389        ));
390    }
391    let mut expanded = Vec::with_capacity(arrays.len());
392    for arr in arrays {
393        let shape = arr.shape();
394        match shape.len() {
395            1 => {
396                let n = shape[0];
397                expanded.push(reshape(arr, &[1, n, 1])?);
398            }
399            2 => {
400                let (m, n) = (shape[0], shape[1]);
401                expanded.push(reshape(arr, &[m, n, 1])?);
402            }
403            _ => {
404                // Already 3-D+, just use as-is
405                let data: Vec<T> = arr.iter().cloned().collect();
406                expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
407            }
408        }
409    }
410    concatenate(&expanded, 2)
411}
412
413/// Stack 1-D arrays as columns into a 2-D array.
414///
415/// Each input becomes one column of the output. For 2-D+ inputs, this is
416/// equivalent to [`hstack`].
417///
418/// Analogous to `numpy.column_stack()`.
419///
420/// # Errors
421/// Returns `FerrayError::InvalidValue` if the input is empty,
422/// or `FerrayError::ShapeMismatch` if 1-D inputs have different lengths.
423pub fn column_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
424    if arrays.is_empty() {
425        return Err(FerrayError::invalid_value(
426            "column_stack: need at least one array",
427        ));
428    }
429    let first_ndim = arrays[0].ndim();
430    if first_ndim == 1 {
431        // Convert each 1-D array of length N into a (N, 1) column, then hstack.
432        let n = arrays[0].shape()[0];
433        let mut reshaped = Vec::with_capacity(arrays.len());
434        for arr in arrays {
435            if arr.ndim() != 1 {
436                return Err(FerrayError::shape_mismatch(
437                    "column_stack: all inputs must have the same ndim",
438                ));
439            }
440            if arr.shape()[0] != n {
441                return Err(FerrayError::shape_mismatch(format!(
442                    "column_stack: 1-D inputs must have the same length; got {} and {}",
443                    n,
444                    arr.shape()[0],
445                )));
446            }
447            reshaped.push(reshape(arr, &[n, 1])?);
448        }
449        concatenate(&reshaped, 1)
450    } else {
451        // 2-D+: same as hstack
452        hstack(arrays)
453    }
454}
455
456/// Stack arrays in sequence vertically (row-wise). Alias for [`vstack`].
457///
458/// Analogous to `numpy.row_stack()` (deprecated alias for `vstack` in NumPy 2.0
459/// but still widely used).
460pub fn row_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
461    vstack(arrays)
462}
463
464/// Assemble an array from nested blocks.
465///
466/// Simplified version: takes a 2-D grid of arrays (as Vec<Vec<...>>)
467/// and assembles them by stacking rows horizontally, then all rows vertically.
468///
469/// Analogous to `numpy.block()`.
470///
471/// # Errors
472/// Returns errors on shape mismatches.
473pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
474    if blocks.is_empty() {
475        return Err(FerrayError::invalid_value("block: empty input"));
476    }
477    let mut rows = Vec::with_capacity(blocks.len());
478    for row in blocks {
479        if row.is_empty() {
480            return Err(FerrayError::invalid_value("block: empty row"));
481        }
482        // Concatenate along axis 1 (columns within each row)
483        let row_arr = if row.len() == 1 {
484            let data: Vec<T> = row[0].iter().cloned().collect();
485            Array::from_vec(IxDyn::new(row[0].shape()), data)?
486        } else {
487            hstack(row)?
488        };
489        rows.push(row_arr);
490    }
491    if rows.len() == 1 {
492        // SAFETY: just checked len() == 1, so pop() always returns Some
493        Ok(rows.pop().unwrap_or_else(|| unreachable!()))
494    } else {
495        vstack(&rows)
496    }
497}
498
499/// Split an array into equal-sized sub-arrays.
500///
501/// `n_sections` must evenly divide the size along `axis`.
502///
503/// Analogous to `numpy.split()`.
504///
505/// # Errors
506/// Returns `FerrayError::InvalidValue` if the axis cannot be evenly split.
507pub fn split<T: Element>(
508    a: &Array<T, IxDyn>,
509    n_sections: usize,
510    axis: usize,
511) -> FerrayResult<Vec<Array<T, IxDyn>>> {
512    let shape = a.shape();
513    if axis >= shape.len() {
514        return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
515    }
516    let axis_len = shape[axis];
517    if n_sections == 0 {
518        return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
519    }
520    if axis_len % n_sections != 0 {
521        return Err(FerrayError::invalid_value(format!(
522            "array of size {} along axis {} cannot be evenly split into {} sections",
523            axis_len, axis, n_sections,
524        )));
525    }
526    let chunk_size = axis_len / n_sections;
527    let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
528    array_split(a, &indices, axis)
529}
530
531/// Split an array into sub-arrays at the given indices along `axis`.
532///
533/// Unlike `split()`, this does not require even division.
534///
535/// Analogous to `numpy.array_split()` (with explicit split points).
536///
537/// # Errors
538/// Returns `FerrayError::AxisOutOfBounds` if axis is invalid.
539pub fn array_split<T: Element>(
540    a: &Array<T, IxDyn>,
541    indices: &[usize],
542    axis: usize,
543) -> FerrayResult<Vec<Array<T, IxDyn>>> {
544    let shape = a.shape();
545    let ndim = shape.len();
546    if axis >= ndim {
547        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
548    }
549    let axis_len = shape[axis];
550    let src_data: Vec<T> = a.iter().cloned().collect();
551
552    // Build split points including 0 and axis_len
553    let mut splits = Vec::with_capacity(indices.len() + 2);
554    splits.push(0);
555    for &idx in indices {
556        splits.push(idx.min(axis_len));
557    }
558    splits.push(axis_len);
559
560    // Compute source strides (C-order)
561    let mut src_strides = vec![1usize; ndim];
562    for i in (0..ndim - 1).rev() {
563        src_strides[i] = src_strides[i + 1] * shape[i + 1];
564    }
565
566    let mut result = Vec::with_capacity(splits.len() - 1);
567    for w in splits.windows(2) {
568        let start = w[0];
569        let end = w[1];
570        let chunk_len = end - start;
571
572        let mut sub_shape = shape.to_vec();
573        sub_shape[axis] = chunk_len;
574        let sub_total: usize = sub_shape.iter().product();
575
576        // Compute sub strides
577        let mut sub_strides = vec![1usize; ndim];
578        for i in (0..ndim - 1).rev() {
579            sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
580        }
581
582        let mut sub_data = Vec::with_capacity(sub_total);
583        for flat in 0..sub_total {
584            // Convert to nd-index in sub array
585            let mut rem = flat;
586            let mut src_flat = 0usize;
587            for i in 0..ndim {
588                let idx = rem / sub_strides[i];
589                rem %= sub_strides[i];
590                let src_idx = if i == axis { idx + start } else { idx };
591                src_flat += src_idx * src_strides[i];
592            }
593            sub_data.push(src_data[src_flat].clone());
594        }
595        result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
596    }
597
598    Ok(result)
599}
600
601/// Split an array into `n` sub-arrays along `axis`, allowing uneven sections.
602///
603/// Unlike [`split`], this never errors on uneven division: the first
604/// `axis_len % n` sections have one extra element. This matches NumPy's
605/// `numpy.array_split(ary, n, axis)` (integer-section variant).
606///
607/// # Errors
608/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
609/// Returns `FerrayError::InvalidValue` if `n == 0`.
610pub fn array_split_n<T: Element>(
611    a: &Array<T, IxDyn>,
612    n: usize,
613    axis: usize,
614) -> FerrayResult<Vec<Array<T, IxDyn>>> {
615    if n == 0 {
616        return Err(FerrayError::invalid_value("array_split_n: n must be > 0"));
617    }
618    let shape = a.shape();
619    if axis >= shape.len() {
620        return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
621    }
622    let axis_len = shape[axis];
623
624    // Build split indices following NumPy's array_split:
625    // First (axis_len % n) sections get (axis_len / n + 1) elements,
626    // remaining sections get (axis_len / n) elements.
627    let base = axis_len / n;
628    let extra = axis_len % n;
629    let mut indices = Vec::with_capacity(n.saturating_sub(1));
630    let mut cum = 0usize;
631    for i in 0..n - 1 {
632        cum += if i < extra { base + 1 } else { base };
633        indices.push(cum);
634    }
635    array_split(a, &indices, axis)
636}
637
638/// Split array along axis 0 (vertical split). Equivalent to `split(a, n, 0)`.
639///
640/// Analogous to `numpy.vsplit()`.
641pub fn vsplit<T: Element>(
642    a: &Array<T, IxDyn>,
643    n_sections: usize,
644) -> FerrayResult<Vec<Array<T, IxDyn>>> {
645    split(a, n_sections, 0)
646}
647
648/// Split array along axis 1 (horizontal split). Equivalent to `split(a, n, 1)`.
649///
650/// Analogous to `numpy.hsplit()`.
651pub fn hsplit<T: Element>(
652    a: &Array<T, IxDyn>,
653    n_sections: usize,
654) -> FerrayResult<Vec<Array<T, IxDyn>>> {
655    split(a, n_sections, 1)
656}
657
658/// Split array along axis 2 (depth split). Equivalent to `split(a, n, 2)`.
659///
660/// Analogous to `numpy.dsplit()`.
661pub fn dsplit<T: Element>(
662    a: &Array<T, IxDyn>,
663    n_sections: usize,
664) -> FerrayResult<Vec<Array<T, IxDyn>>> {
665    split(a, n_sections, 2)
666}
667
668// ============================================================================
669// REQ-22: Transpose / reorder
670// ============================================================================
671
672/// Permute the axes of an array.
673///
674/// `axes` specifies the new ordering. For a 2-D array, `[1, 0]` transposes.
675/// If `axes` is `None`, reverses the order of all axes.
676///
677/// Analogous to `numpy.transpose()`.
678///
679/// # Errors
680/// Returns `FerrayError::InvalidValue` if `axes` is the wrong length or
681/// contains invalid/duplicate axis indices.
682pub fn transpose<T: Element, D: Dimension>(
683    a: &Array<T, D>,
684    axes: Option<&[usize]>,
685) -> FerrayResult<Array<T, IxDyn>> {
686    let ndim = a.ndim();
687    let perm: Vec<usize> = match axes {
688        Some(ax) => {
689            if ax.len() != ndim {
690                return Err(FerrayError::invalid_value(format!(
691                    "axes must have length {} but got {}",
692                    ndim,
693                    ax.len(),
694                )));
695            }
696            // Validate: each axis appears exactly once
697            let mut seen = vec![false; ndim];
698            for &a in ax {
699                if a >= ndim {
700                    return Err(FerrayError::axis_out_of_bounds(a, ndim));
701                }
702                if seen[a] {
703                    return Err(FerrayError::invalid_value(format!(
704                        "duplicate axis {} in transpose",
705                        a,
706                    )));
707                }
708                seen[a] = true;
709            }
710            ax.to_vec()
711        }
712        None => (0..ndim).rev().collect(),
713    };
714
715    // Permute axes is a zero-copy stride rearrangement in ndarray; we
716    // then call `as_standard_layout` which returns a borrowed view when
717    // already C-contiguous or materializes a single standard-layout
718    // owned buffer otherwise. The old implementation walked every
719    // output index in a hand-rolled scatter loop (issue #82). Plain
720    // `to_owned()` would preserve the F-contiguous stride pattern from
721    // the permutation, which downstream callers don't want because
722    // `as_slice()` only succeeds on standard layouts.
723    let permuted = a
724        .inner
725        .view()
726        .into_dyn()
727        .permuted_axes(ndarray::IxDyn(&perm));
728    Ok(Array::from_ndarray(
729        permuted.as_standard_layout().into_owned(),
730    ))
731}
732
733/// Swap two axes of an array.
734///
735/// Analogous to `numpy.swapaxes()`.
736///
737/// # Errors
738/// Returns `FerrayError::AxisOutOfBounds` if either axis is out of bounds.
739pub fn swapaxes<T: Element, D: Dimension>(
740    a: &Array<T, D>,
741    axis1: usize,
742    axis2: usize,
743) -> FerrayResult<Array<T, IxDyn>> {
744    let ndim = a.ndim();
745    if axis1 >= ndim {
746        return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
747    }
748    if axis2 >= ndim {
749        return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
750    }
751    let mut perm: Vec<usize> = (0..ndim).collect();
752    perm.swap(axis1, axis2);
753    transpose(a, Some(&perm))
754}
755
756/// Move an axis to a new position.
757///
758/// Analogous to `numpy.moveaxis()`.
759///
760/// # Errors
761/// Returns `FerrayError::AxisOutOfBounds` if either axis is out of bounds.
762pub fn moveaxis<T: Element, D: Dimension>(
763    a: &Array<T, D>,
764    source: usize,
765    destination: usize,
766) -> FerrayResult<Array<T, IxDyn>> {
767    let ndim = a.ndim();
768    if source >= ndim {
769        return Err(FerrayError::axis_out_of_bounds(source, ndim));
770    }
771    if destination >= ndim {
772        return Err(FerrayError::axis_out_of_bounds(destination, ndim));
773    }
774    // Build permutation by removing source and inserting at destination
775    let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
776    order.insert(destination, source);
777    transpose(a, Some(&order))
778}
779
780/// Roll an axis to a new position (similar to moveaxis).
781///
782/// Analogous to `numpy.rollaxis()`.
783///
784/// # Errors
785/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim` or `start > ndim`.
786pub fn rollaxis<T: Element, D: Dimension>(
787    a: &Array<T, D>,
788    axis: usize,
789    start: usize,
790) -> FerrayResult<Array<T, IxDyn>> {
791    let ndim = a.ndim();
792    if axis >= ndim {
793        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
794    }
795    if start > ndim {
796        return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
797    }
798    let dst = if start > axis { start - 1 } else { start };
799    if axis == dst {
800        // No-op: return a copy
801        let data: Vec<T> = a.iter().cloned().collect();
802        return Array::from_vec(IxDyn::new(a.shape()), data);
803    }
804    moveaxis(a, axis, dst)
805}
806
807/// Reverse the order of elements along the given axis.
808///
809/// Analogous to `numpy.flip()`.
810///
811/// # Errors
812/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
813pub fn flip<T: Element, D: Dimension>(
814    a: &Array<T, D>,
815    axis: usize,
816) -> FerrayResult<Array<T, IxDyn>> {
817    let shape = a.shape();
818    let ndim = shape.len();
819    if axis >= ndim {
820        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
821    }
822    let src_data: Vec<T> = a.iter().cloned().collect();
823    let total = src_data.len();
824
825    // Compute strides (C-order)
826    let mut strides = vec![1usize; ndim];
827    for i in (0..ndim.saturating_sub(1)).rev() {
828        strides[i] = strides[i + 1] * shape[i + 1];
829    }
830
831    let mut data = Vec::with_capacity(total);
832    for flat in 0..total {
833        let mut rem = flat;
834        let mut src_flat = 0usize;
835        for i in 0..ndim {
836            let idx = rem / strides[i];
837            rem %= strides[i];
838            let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
839            src_flat += src_idx * strides[i];
840        }
841        data.push(src_data[src_flat].clone());
842    }
843    Array::from_vec(IxDyn::new(shape), data)
844}
845
846/// Flip array left-right (reverse axis 1).
847///
848/// Analogous to `numpy.fliplr()`.
849///
850/// # Errors
851/// Returns `FerrayError::InvalidValue` if the array has fewer than 2 dimensions.
852pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
853    if a.ndim() < 2 {
854        return Err(FerrayError::invalid_value(
855            "fliplr: array must be at least 2-D",
856        ));
857    }
858    flip(a, 1)
859}
860
861/// Flip array up-down (reverse axis 0).
862///
863/// Analogous to `numpy.flipud()`.
864///
865/// # Errors
866/// Returns `FerrayError::InvalidValue` if the array has 0 dimensions.
867pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
868    if a.ndim() < 1 {
869        return Err(FerrayError::invalid_value(
870            "flipud: array must be at least 1-D",
871        ));
872    }
873    flip(a, 0)
874}
875
876/// Rotate array 90 degrees counterclockwise in the plane defined by axes (0, 1).
877///
878/// `k` specifies the number of 90-degree rotations (can be negative).
879///
880/// Analogous to `numpy.rot90()`.
881///
882/// # Errors
883/// Returns `FerrayError::InvalidValue` if the array has fewer than 2 dimensions.
884pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
885    if a.ndim() < 2 {
886        return Err(FerrayError::invalid_value(
887            "rot90: array must be at least 2-D",
888        ));
889    }
890    // Normalize k to [0, 4)
891    let k = k.rem_euclid(4);
892    let shape = a.shape();
893    let data: Vec<T> = a.iter().cloned().collect();
894
895    // We work with the IxDyn representation
896    let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
897
898    match k {
899        0 => Ok(as_dyn),
900        1 => {
901            // rot90 once: flip axis 1, then transpose axes 0,1
902            let flipped = flip(&as_dyn, 1)?;
903            swapaxes(&flipped, 0, 1)
904        }
905        2 => {
906            // rot180: flip both axes
907            let f1 = flip(&as_dyn, 0)?;
908            flip(&f1, 1)
909        }
910        3 => {
911            // rot270: transpose, then flip axis 1
912            let transposed = swapaxes(&as_dyn, 0, 1)?;
913            flip(&transposed, 1)
914        }
915        _ => unreachable!(),
916    }
917}
918
919/// Roll elements along an axis. Elements that roll past the end
920/// are re-introduced at the beginning.
921///
922/// If `axis` is `None`, the array is flattened first, then rolled.
923///
924/// Analogous to `numpy.roll()`.
925///
926/// # Errors
927/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
928pub fn roll<T: Element, D: Dimension>(
929    a: &Array<T, D>,
930    shift: isize,
931    axis: Option<usize>,
932) -> FerrayResult<Array<T, IxDyn>> {
933    match axis {
934        None => {
935            // Flatten, roll, reshape back
936            let data: Vec<T> = a.iter().cloned().collect();
937            let n = data.len();
938            if n == 0 {
939                return Array::from_vec(IxDyn::new(a.shape()), data);
940            }
941            let shift = ((shift % n as isize) + n as isize) as usize % n;
942            let mut rolled = Vec::with_capacity(n);
943            for i in 0..n {
944                rolled.push(data[(n + i - shift) % n].clone());
945            }
946            Array::from_vec(IxDyn::new(a.shape()), rolled)
947        }
948        Some(ax) => {
949            let shape = a.shape();
950            let ndim = shape.len();
951            if ax >= ndim {
952                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
953            }
954            let axis_len = shape[ax];
955            if axis_len == 0 {
956                let data: Vec<T> = a.iter().cloned().collect();
957                return Array::from_vec(IxDyn::new(shape), data);
958            }
959            let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
960            let src_data: Vec<T> = a.iter().cloned().collect();
961            let total = src_data.len();
962
963            // Compute strides (C-order)
964            let mut strides = vec![1usize; ndim];
965            for i in (0..ndim.saturating_sub(1)).rev() {
966                strides[i] = strides[i + 1] * shape[i + 1];
967            }
968
969            let mut data = Vec::with_capacity(total);
970            for flat in 0..total {
971                let mut rem = flat;
972                let mut src_flat = 0usize;
973                #[allow(clippy::needless_range_loop)]
974                for i in 0..ndim {
975                    let idx = rem / strides[i];
976                    rem %= strides[i];
977                    let src_idx = if i == ax {
978                        (axis_len + idx - shift) % axis_len
979                    } else {
980                        idx
981                    };
982                    src_flat += src_idx * strides[i];
983                }
984                data.push(src_data[src_flat].clone());
985            }
986            Array::from_vec(IxDyn::new(shape), data)
987        }
988    }
989}
990
991// ============================================================================
992// Tests
993// ============================================================================
994
#[cfg(test)]
mod tests {
    // Unit tests for the shape-manipulation routines (REQ-20, REQ-21, REQ-22).
    use super::*;

    /// Build a dynamic-rank `f64` array, panicking on a shape/data mismatch.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // -- REQ-20 --

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    #[test]
    fn test_broadcast_to_from_non_contiguous_source() {
        // Issue #133: broadcast_to's source may itself be non-contiguous
        // (e.g., transposed). Since broadcast_to now delegates to
        // ndarray's `broadcast` which handles any stride pattern, the
        // result should materialize correctly.
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Transpose gives us a 3x2 logical view — then broadcast to (2, 3, 2).
        let t = transpose(&a, None).unwrap();
        let b = broadcast_to(&t, &[2, 3, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        // Both outer slices should be identical.
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(&data[..6], &data[6..12]);
    }

    // -- REQ-21 --

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_concatenate_axis0_different_lengths() {
        // Inputs may differ in length along the concatenation axis (here
        // axis 0: 2 vs 3) as long as every other axis matches (axis 1: 3).
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err()); // 5 not divisible by 3
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]); // [1,2]
        assert_eq!(parts[1].shape(), &[2]); // [3,4]
        assert_eq!(parts[2].shape(), &[1]); // [5]
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    // -- REQ-22 --

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err()); // wrong length
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flip_3d_each_axis() {
        // [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
        let a = dyn_arr(&[2, 2, 2], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let along = |axis: usize| -> Vec<f64> {
            let f = flip(&a, axis).unwrap();
            assert_eq!(f.shape(), &[2, 2, 2]);
            f.iter().copied().collect()
        };
        assert_eq!(along(0), vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]);
        assert_eq!(along(1), vec![2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0]);
        assert_eq!(along(2), vec![1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0]);
    }

    #[test]
    fn test_flip_axis_out_of_bounds() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(flip(&a, 2).is_err());
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_flipud_1d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        // [[1, 2], [3, 4]] -> [[2, 4], [1, 3]]
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_three_times_and_negative() {
        // Three counterclockwise turns == one clockwise turn:
        // [[1, 2], [3, 4]] -> [[3, 1], [4, 2]]
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let three: Vec<f64> = rot90(&a, 3).unwrap().iter().copied().collect();
        let minus_one: Vec<f64> = rot90(&a, -1).unwrap().iter().copied().collect();
        assert_eq!(three, vec![3.0, 1.0, 4.0, 2.0]);
        assert_eq!(minus_one, three);
    }

    #[test]
    fn test_rot90_rectangular() {
        // [[1, 2, 3], [4, 5, 6]] -> [[3, 6], [2, 5], [1, 4]]
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 6.0, 2.0, 5.0, 1.0, 4.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_rot90_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(rot90(&a, 1).is_err());
    }

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_wraps_large_shift() {
        // Only `shift mod len` matters: 7 ≡ 2 and -7 ≡ 3 (mod 5).
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let fwd: Vec<f64> = roll(&a, 7, None).unwrap().iter().copied().collect();
        let back: Vec<f64> = roll(&a, -7, None).unwrap().iter().copied().collect();
        assert_eq!(fwd, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
        assert_eq!(back, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis0() {
        // Whole rows move: [[1, 2], [3, 4], [5, 6]] -> [[5, 6], [1, 2], [3, 4]]
        let a = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![5.0, 6.0, 1.0, 2.0, 3.0, 4.0]);

        let c = roll(&a, -1, Some(0)).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 4.0, 5.0, 6.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_empty() {
        let a = dyn_arr(&[2, 0], vec![]);
        for axis in [None, Some(0), Some(1)] {
            let b = roll(&a, 1, axis).unwrap();
            assert_eq!(b.shape(), &[2, 0]);
            assert_eq!(b.size(), 0);
        }
    }

    #[test]
    fn test_roll_axis_out_of_bounds() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(roll(&a, 1, Some(2)).is_err());
    }

    // -----------------------------------------------------------------------
    // column_stack / row_stack / array_split_n / to_dyn (issue #362)
    // -----------------------------------------------------------------------

    #[test]
    fn test_column_stack_1d() {
        // 3 1-D arrays of length 4 -> (4, 3)
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[4], vec![10.0, 20.0, 30.0, 40.0]);
        let c = dyn_arr(&[4], vec![100.0, 200.0, 300.0, 400.0]);
        let result = column_stack(&[a, b, c]).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 10.0, 100.0, // row 0
                2.0, 20.0, 200.0, // row 1
                3.0, 30.0, 300.0, // row 2
                4.0, 40.0, 400.0, // row 3
            ]
        );
    }

    #[test]
    fn test_column_stack_2d_same_as_hstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = column_stack(&[a, b]).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_column_stack_length_mismatch() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(column_stack(&[a, b]).is_err());
    }

    #[test]
    fn test_row_stack_is_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let row = row_stack(&[a.clone(), b.clone()]).unwrap();
        let v = vstack(&[a, b]).unwrap();
        assert_eq!(row.shape(), v.shape());
        assert_eq!(
            row.iter().copied().collect::<Vec<_>>(),
            v.iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_array_split_n_uneven() {
        // 7 elements split into 3 sections -> [3, 2, 2]
        let a = dyn_arr(&[7], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![6.0, 7.0]);
    }

    #[test]
    fn test_array_split_n_even() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        for (i, expected) in [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
            .iter()
            .enumerate()
        {
            assert_eq!(&parts[i].iter().copied().collect::<Vec<_>>(), expected);
        }
    }

    #[test]
    fn test_array_split_n_more_sections_than_elements() {
        // NumPy's behavior: 3 elements split into 5 sections gives 5 parts,
        // first 3 have 1 element each, last 2 are empty.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let parts = array_split_n(&a, 5, 0).unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0]);
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![2.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![3.0]);
        assert_eq!(
            parts[3].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
        assert_eq!(
            parts[4].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
    }

    #[test]
    fn test_to_dyn_from_typed() {
        use crate::Array;
        use crate::dimension::Ix2;
        let typed =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let dy = typed.to_dyn();
        assert_eq!(dy.shape(), &[2, 3]);
        assert_eq!(
            dy.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_concatenate_typed_via_to_dyn() {
        // Demonstrates the typical end-user flow: have Array<T, Ix2>, want to
        // concatenate, route through to_dyn().
        use crate::Array;
        use crate::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let result = concatenate(&[a.to_dyn(), b.to_dyn()], 0).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }
}