Skip to main content

ferray_core/manipulation/
extended.rs

1// ferray-core: Extended manipulation functions (REQ-22a)
2//
3// pad, tile, repeat, delete, insert, append, resize, trim_zeros
4//
5// ## REQ status (extended manipulation, NumPy parity)
6//  - REQ-22a (extended manipulation surface) — SHIPPED: `pad`/`pad_1d` with all
7//    five `PadMode` variants (`Constant`/`Edge`/`Reflect`/`Symmetric`/`Wrap`),
8//    `tile`/`repeat`/`delete`/`insert`/`append`/`resize`/`trim_zeros`, and the
9//    `atleast_1d`/`atleast_2d`/`atleast_3d` shape-promotion helpers (all this
10//    file). `PadMode` (this file) is the mode enum consumed by `pad`/`pad_1d`,
11//    mirroring `numpy.pad`'s `mode` argument.
12
13use crate::array::owned::Array;
14use crate::dimension::{Dimension, Ix1, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18// ============================================================================
19// Pad modes
20// ============================================================================
21
22/// Padding mode for [`pad`].
23#[derive(Debug, Clone)]
24pub enum PadMode<T: Element> {
25    /// Pad with a constant value.
26    Constant(T),
27    /// Pad with the edge values of the array.
28    Edge,
29    /// Pad with the reflection of the array mirrored on the first and last
30    /// values (does not repeat the edge).
31    Reflect,
32    /// Pad with the reflection of the array mirrored along the edge.
33    Symmetric,
34    /// Pad by wrapping the array around.
35    Wrap,
36}
37
38/// Pad a 1-D array.
39///
40/// `pad_width` is `(before, after)` for the single axis.
41///
42/// Analogous to `numpy.pad()` for 1-D.
43///
44/// # Errors
45/// Returns `FerrayError::InvalidValue` if the array is empty and the mode
46/// requires elements (Edge, Reflect, Symmetric, Wrap).
47pub fn pad_1d<T: Element>(
48    a: &Array<T, Ix1>,
49    pad_width: (usize, usize),
50    mode: &PadMode<T>,
51) -> FerrayResult<Array<T, Ix1>> {
52    let n = a.shape()[0];
53    let (before, after) = pad_width;
54    let new_len = before + n + after;
55    let src: Vec<T> = a.iter().cloned().collect();
56
57    if n == 0 && !matches!(mode, PadMode::Constant(_)) {
58        return Err(FerrayError::invalid_value(
59            "pad: cannot use Edge/Reflect/Symmetric/Wrap mode on empty array",
60        ));
61    }
62
63    let mut data = Vec::with_capacity(new_len);
64
65    // Fill 'before' padding. We pass negative logical indices into
66    // `reflect_index` / `symmetric_index` so the reflection formulae
67    // actually reach past the start of the array — the previous
68    // expression `before - 1 - i` collapsed into the range
69    // `[0, before)` which meant ferray's Reflect mode was producing
70    // the same clamped pattern as Symmetric and disagreed with NumPy
71    // on arrays smaller than `2 * before` (issue #137).
72    for i in 0..before {
73        let logical_idx = -((before - i) as isize);
74        let val = match mode {
75            PadMode::Constant(c) => c.clone(),
76            PadMode::Edge => src[0].clone(),
77            PadMode::Reflect => src[reflect_index(logical_idx, n)].clone(),
78            PadMode::Symmetric => src[symmetric_index(logical_idx, n)].clone(),
79            PadMode::Wrap => {
80                // Wrap: negative index modulo n picks the mirror from
81                // the other end.
82                let m = n as isize;
83                let idx = ((logical_idx % m) + m) % m;
84                src[idx as usize].clone()
85            }
86        };
87        data.push(val);
88    }
89
90    // Copy original data
91    data.extend_from_slice(&src);
92
93    // Fill 'after' padding
94    for i in 0..after {
95        let val = match mode {
96            PadMode::Constant(c) => c.clone(),
97            PadMode::Edge => src[n - 1].clone(),
98            PadMode::Reflect => {
99                let idx = reflect_index(n as isize + i as isize, n);
100                src[idx].clone()
101            }
102            PadMode::Symmetric => {
103                let idx = symmetric_index(n as isize + i as isize, n);
104                src[idx].clone()
105            }
106            PadMode::Wrap => {
107                let idx = i % n;
108                src[idx].clone()
109            }
110        };
111        data.push(val);
112    }
113
114    Array::from_vec(Ix1::new([new_len]), data)
115}
116
/// Fold `idx` into `[0, n)` using NumPy's `reflect` rule.
///
/// The data is mirrored about its first and last elements without repeating
/// them, so the pattern has period `2 * (n - 1)`. For `n = 3`, index `-1`
/// maps to `1` and index `3` maps to `1`. Lengths 0 and 1 have nothing to
/// reflect and always map to 0.
const fn reflect_index(idx: isize, n: usize) -> usize {
    if n < 2 {
        return 0;
    }
    let last = (n - 1) as isize;
    // Position within one full period, always non-negative.
    let folded = idx.rem_euclid(2 * last);
    // The second half of the period walks back down toward index 0.
    let mirrored = if folded > last { 2 * last - folded } else { folded };
    mirrored as usize
}
133
/// Fold `idx` into `[0, n)` using NumPy's `symmetric` rule.
///
/// The data is mirrored about the half-way points just outside each end, so
/// the edge elements are repeated and the pattern has period `2 * n`. For
/// `n = 3`, index `-1` maps to `0` and index `3` maps to `2`. Lengths 0 and 1
/// always map to 0.
fn symmetric_index(idx: isize, n: usize) -> usize {
    if n <= 1 {
        return 0;
    }
    let len = n as isize;
    // Position within one full period, always non-negative.
    let folded = idx.rem_euclid(2 * len);
    // The second half of the period is the data read back to front.
    let mirrored = if folded < len { folded } else { 2 * len - 1 - folded };
    mirrored as usize
}
153
154/// Pad an N-D array.
155///
156/// `pad_width` is a slice of `(before, after)` pairs, one per axis.
157/// If it has fewer entries than the array's ndim, the last entry is repeated.
158///
159/// Analogous to `numpy.pad()`.
160///
161/// # Errors
162/// Returns `FerrayError::InvalidValue` if `pad_width` is empty.
163pub fn pad<T: Element, D: Dimension>(
164    a: &Array<T, D>,
165    pad_width: &[(usize, usize)],
166    mode: &PadMode<T>,
167) -> FerrayResult<Array<T, IxDyn>> {
168    if pad_width.is_empty() {
169        return Err(FerrayError::invalid_value("pad: pad_width cannot be empty"));
170    }
171
172    let shape = a.shape();
173    let ndim = shape.len();
174
175    // Expand pad_width to ndim entries
176    let pads: Vec<(usize, usize)> = (0..ndim)
177        .map(|i| {
178            if i < pad_width.len() {
179                pad_width[i]
180            } else {
181                // pad_width is non-empty (checked above), so last() is always Some
182                *pad_width.last().unwrap_or_else(|| unreachable!())
183            }
184        })
185        .collect();
186
187    // For multi-dimensional, we pad axis-by-axis starting from the last axis.
188    // Convert to IxDyn first.
189    let mut current_data: Vec<T> = a.iter().cloned().collect();
190    let mut current_shape: Vec<usize> = shape.to_vec();
191
192    for ax in (0..ndim).rev() {
193        let (before, after) = pads[ax];
194        if before == 0 && after == 0 {
195            continue;
196        }
197        let axis_len = current_shape[ax];
198        let new_axis_len = before + axis_len + after;
199
200        // Compute strides
201        let outer: usize = current_shape[..ax].iter().product();
202        let inner: usize = current_shape[ax + 1..].iter().product();
203
204        let new_total = outer * new_axis_len * inner;
205        let mut new_data = Vec::with_capacity(new_total);
206
207        for o in 0..outer {
208            for j in 0..new_axis_len {
209                for k in 0..inner {
210                    let val = if j < before {
211                        // Before padding
212                        match mode {
213                            PadMode::Constant(c) => c.clone(),
214                            PadMode::Edge => {
215                                let src_j = 0;
216                                current_data[o * axis_len * inner + src_j * inner + k].clone()
217                            }
218                            PadMode::Reflect => {
219                                let src_j =
220                                    reflect_index(before as isize - 1 - j as isize, axis_len);
221                                current_data[o * axis_len * inner + src_j * inner + k].clone()
222                            }
223                            PadMode::Symmetric => {
224                                let src_j =
225                                    symmetric_index(before as isize - 1 - j as isize, axis_len);
226                                current_data[o * axis_len * inner + src_j * inner + k].clone()
227                            }
228                            PadMode::Wrap => {
229                                let src_j = ((axis_len as isize
230                                    - (before as isize - j as isize) % axis_len as isize)
231                                    % axis_len as isize)
232                                    as usize;
233                                current_data[o * axis_len * inner + src_j * inner + k].clone()
234                            }
235                        }
236                    } else if j < before + axis_len {
237                        // Original data
238                        let src_j = j - before;
239                        current_data[o * axis_len * inner + src_j * inner + k].clone()
240                    } else {
241                        // After padding
242                        let after_idx = j - before - axis_len;
243                        match mode {
244                            PadMode::Constant(c) => c.clone(),
245                            PadMode::Edge => {
246                                let src_j = axis_len - 1;
247                                current_data[o * axis_len * inner + src_j * inner + k].clone()
248                            }
249                            PadMode::Reflect => {
250                                let src_j = reflect_index(
251                                    (axis_len as isize) + after_idx as isize,
252                                    axis_len,
253                                );
254                                current_data[o * axis_len * inner + src_j * inner + k].clone()
255                            }
256                            PadMode::Symmetric => {
257                                let src_j = symmetric_index(
258                                    (axis_len as isize) + after_idx as isize,
259                                    axis_len,
260                                );
261                                current_data[o * axis_len * inner + src_j * inner + k].clone()
262                            }
263                            PadMode::Wrap => {
264                                let src_j = after_idx % axis_len;
265                                current_data[o * axis_len * inner + src_j * inner + k].clone()
266                            }
267                        }
268                    };
269                    new_data.push(val);
270                }
271            }
272        }
273
274        current_data = new_data;
275        current_shape[ax] = new_axis_len;
276    }
277
278    Array::from_vec(IxDyn::new(&current_shape), current_data)
279}
280
281/// Construct an array by repeating `a` the number of times given by `reps`.
282///
283/// If `reps` has fewer entries than `a.ndim()`, it is prepended with 1s.
284/// If `reps` has more entries, `a`'s shape is prepended with 1s.
285///
286/// Analogous to `numpy.tile()`.
287///
288/// # Errors
289/// Returns `FerrayError::InvalidValue` if `reps` is empty.
290pub fn tile<T: Element, D: Dimension>(
291    a: &Array<T, D>,
292    reps: &[usize],
293) -> FerrayResult<Array<T, IxDyn>> {
294    if reps.is_empty() {
295        return Err(FerrayError::invalid_value("tile: reps cannot be empty"));
296    }
297
298    let src_shape = a.shape();
299    let src_ndim = src_shape.len();
300    let reps_ndim = reps.len();
301    let out_ndim = src_ndim.max(reps_ndim);
302
303    // Pad shapes to out_ndim
304    let mut padded_shape = vec![1usize; out_ndim];
305    for i in 0..src_ndim {
306        padded_shape[out_ndim - src_ndim + i] = src_shape[i];
307    }
308    let mut padded_reps = vec![1usize; out_ndim];
309    for i in 0..reps_ndim {
310        padded_reps[out_ndim - reps_ndim + i] = reps[i];
311    }
312
313    let out_shape: Vec<usize> = padded_shape
314        .iter()
315        .zip(padded_reps.iter())
316        .map(|(&s, &r)| s * r)
317        .collect();
318    let total: usize = out_shape.iter().product();
319
320    let src_data: Vec<T> = a.iter().cloned().collect();
321    let mut data = Vec::with_capacity(total);
322
323    // Compute strides for output and padded source
324    let mut out_strides = vec![1usize; out_ndim];
325    for i in (0..out_ndim.saturating_sub(1)).rev() {
326        out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
327    }
328
329    let mut src_strides = vec![1usize; out_ndim];
330    for i in (0..out_ndim.saturating_sub(1)).rev() {
331        src_strides[i] = src_strides[i + 1] * padded_shape[i + 1];
332    }
333
334    for flat in 0..total {
335        let mut rem = flat;
336        let mut src_flat = 0usize;
337        for i in 0..out_ndim {
338            let idx = rem / out_strides[i];
339            rem %= out_strides[i];
340            let src_idx = idx % padded_shape[i];
341            src_flat += src_idx * src_strides[i];
342        }
343        // Map src_flat to the original (non-padded) source
344        // Since we padded with 1s, the strides handle it.
345        if src_flat < src_data.len() {
346            data.push(src_data[src_flat].clone());
347        } else {
348            // This shouldn't happen if the math is right
349            data.push(T::zero());
350        }
351    }
352
353    Array::from_vec(IxDyn::new(&out_shape), data)
354}
355
356/// Repeat elements of an array.
357///
358/// If `axis` is `None`, the array is flattened first, then each element
359/// is repeated `repeats` times.
360///
361/// Analogous to `numpy.repeat()`.
362///
363/// # Errors
364/// Returns `FerrayError::AxisOutOfBounds` if the axis is out of bounds.
365pub fn repeat<T: Element, D: Dimension>(
366    a: &Array<T, D>,
367    repeats: usize,
368    axis: Option<usize>,
369) -> FerrayResult<Array<T, IxDyn>> {
370    match axis {
371        None => {
372            // Flatten and repeat each element
373            let src: Vec<T> = a.iter().cloned().collect();
374            let mut data = Vec::with_capacity(src.len() * repeats);
375            for val in &src {
376                for _ in 0..repeats {
377                    data.push(val.clone());
378                }
379            }
380            let n = data.len();
381            Array::from_vec(IxDyn::new(&[n]), data)
382        }
383        Some(ax) => {
384            let shape = a.shape();
385            let ndim = shape.len();
386            if ax >= ndim {
387                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
388            }
389
390            let mut new_shape = shape.to_vec();
391            new_shape[ax] *= repeats;
392            let total: usize = new_shape.iter().product();
393            let src_data: Vec<T> = a.iter().cloned().collect();
394
395            // Compute source strides (C-order)
396            let mut src_strides = vec![1usize; ndim];
397            for i in (0..ndim.saturating_sub(1)).rev() {
398                src_strides[i] = src_strides[i + 1] * shape[i + 1];
399            }
400
401            // Compute output strides
402            let mut out_strides = vec![1usize; ndim];
403            for i in (0..ndim.saturating_sub(1)).rev() {
404                out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
405            }
406
407            let mut data = Vec::with_capacity(total);
408            for flat in 0..total {
409                let mut rem = flat;
410                let mut src_flat = 0usize;
411                for i in 0..ndim {
412                    let idx = rem / out_strides[i];
413                    rem %= out_strides[i];
414                    let src_idx = if i == ax { idx / repeats } else { idx };
415                    src_flat += src_idx * src_strides[i];
416                }
417                data.push(src_data[src_flat].clone());
418            }
419
420            Array::from_vec(IxDyn::new(&new_shape), data)
421        }
422    }
423}
424
425/// Delete sub-arrays along an axis.
426///
427/// `indices` specifies which indices along `axis` to remove.
428///
429/// Analogous to `numpy.delete()`.
430///
431/// # Errors
432/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
433/// Returns `FerrayError::IndexOutOfBounds` if any index is out of range.
434pub fn delete<T: Element, D: Dimension>(
435    a: &Array<T, D>,
436    indices: &[usize],
437    axis: usize,
438) -> FerrayResult<Array<T, IxDyn>> {
439    let shape = a.shape();
440    let ndim = shape.len();
441    if axis >= ndim {
442        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
443    }
444    let axis_len = shape[axis];
445
446    // Validate indices
447    for &idx in indices {
448        if idx >= axis_len {
449            return Err(FerrayError::IndexOutOfBounds {
450                index: idx as isize,
451                axis,
452                size: axis_len,
453            });
454        }
455    }
456
457    let to_remove: std::collections::HashSet<usize> = indices.iter().copied().collect();
458    let kept: Vec<usize> = (0..axis_len).filter(|i| !to_remove.contains(i)).collect();
459    let new_axis_len = kept.len();
460
461    let mut new_shape = shape.to_vec();
462    new_shape[axis] = new_axis_len;
463    let total: usize = new_shape.iter().product();
464    let src_data: Vec<T> = a.iter().cloned().collect();
465
466    // Compute source strides (C-order)
467    let mut src_strides = vec![1usize; ndim];
468    for i in (0..ndim.saturating_sub(1)).rev() {
469        src_strides[i] = src_strides[i + 1] * shape[i + 1];
470    }
471
472    // Compute output strides
473    let mut out_strides = vec![1usize; ndim];
474    for i in (0..ndim.saturating_sub(1)).rev() {
475        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
476    }
477
478    let mut data = Vec::with_capacity(total);
479    for flat in 0..total {
480        let mut rem = flat;
481        let mut src_flat = 0usize;
482        for i in 0..ndim {
483            let idx = rem / out_strides[i];
484            rem %= out_strides[i];
485            let src_idx = if i == axis { kept[idx] } else { idx };
486            src_flat += src_idx * src_strides[i];
487        }
488        data.push(src_data[src_flat].clone());
489    }
490
491    Array::from_vec(IxDyn::new(&new_shape), data)
492}
493
494/// Insert values along an axis before a given index.
495///
496/// `index` is the position before which to insert. `values` is a 1-D array
497/// of values to insert (its length determines how many slices are added).
498///
499/// Analogous to `numpy.insert()`.
500///
501/// # Errors
502/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
503/// Returns `FerrayError::IndexOutOfBounds` if `index > axis_len`.
504pub fn insert<T: Element, D: Dimension>(
505    a: &Array<T, D>,
506    index: usize,
507    values: &Array<T, IxDyn>,
508    axis: usize,
509) -> FerrayResult<Array<T, IxDyn>> {
510    let shape = a.shape();
511    let ndim = shape.len();
512    if axis >= ndim {
513        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
514    }
515    let axis_len = shape[axis];
516    if index > axis_len {
517        return Err(FerrayError::IndexOutOfBounds {
518            index: index as isize,
519            axis,
520            size: axis_len + 1,
521        });
522    }
523
524    // A scalar `index` (obj) inserts a single contiguous block of
525    // sub-arrays along `axis`, with `values` broadcast into that block.
526    // numpy _function_base_impl.py:5544-5556: for a 0-d index,
527    //   values = array(values, ndmin=arr.ndim)   # prepend leading 1-axes
528    //   values = np.moveaxis(values, 0, axis)    # axis-0 -> axis
529    //   numnew = values.shape[axis]              # count of inserted slices
530    //   new[index:index+numnew, ...] = values    # broadcast assignment
531    // So the inserted count is `values.shape[axis]` AFTER prepending
532    // leading length-1 axes to reach `ndim` dims, then moving the leading
533    // axis to `axis`. We replicate that shape arithmetic, then fill the
534    // inserted region by broadcasting `values` (in its moved layout) into
535    // the destination region shape.
536    let vals: Vec<T> = values.iter().cloned().collect();
537
538    // `values` reshaped to ndmin = ndim by prepending leading 1-axes.
539    let mut vshape: Vec<usize> = values.shape().to_vec();
540    while vshape.len() < ndim {
541        vshape.insert(0, 1);
542    }
543    if vshape.len() > ndim {
544        return Err(FerrayError::shape_mismatch(format!(
545            "insert values have {} dims, cannot exceed array ndim {}",
546            vshape.len(),
547            ndim,
548        )));
549    }
550    // moveaxis(values, 0, axis): pull leading axis out, reinsert at `axis`.
551    let lead = vshape.remove(0);
552    vshape.insert(axis, lead);
553    let moved_shape = vshape; // logical shape of the moved `values`
554    let n_insert = moved_shape[axis];
555
556    // Destination region for the inserted block: array shape with the axis
557    // length set to `n_insert`. `values` (moved) must broadcast to it.
558    let mut region_shape = shape.to_vec();
559    region_shape[axis] = n_insert;
560    for i in 0..ndim {
561        if moved_shape[i] != region_shape[i] && moved_shape[i] != 1 {
562            return Err(FerrayError::shape_mismatch(format!(
563                "could not broadcast insert values shape {:?} into region {:?}",
564                moved_shape, region_shape,
565            )));
566        }
567    }
568    // Row-major strides over the moved `values` (length == vals.len()).
569    let mut moved_strides = vec![1usize; ndim];
570    for i in (0..ndim.saturating_sub(1)).rev() {
571        moved_strides[i] = moved_strides[i + 1] * moved_shape[i + 1];
572    }
573
574    let mut new_shape = shape.to_vec();
575    new_shape[axis] = axis_len + n_insert;
576    let total: usize = new_shape.iter().product();
577    let src_data: Vec<T> = a.iter().cloned().collect();
578
579    // Compute strides
580    let mut src_strides = vec![1usize; ndim];
581    for i in (0..ndim.saturating_sub(1)).rev() {
582        src_strides[i] = src_strides[i + 1] * shape[i + 1];
583    }
584
585    let mut out_strides = vec![1usize; ndim];
586    for i in (0..ndim.saturating_sub(1)).rev() {
587        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
588    }
589
590    let mut data = Vec::with_capacity(total);
591    for flat in 0..total {
592        let mut rem = flat;
593        let mut nd_idx = vec![0usize; ndim];
594        for i in 0..ndim {
595            nd_idx[i] = rem / out_strides[i];
596            rem %= out_strides[i];
597        }
598
599        let ax_idx = nd_idx[axis];
600        if ax_idx >= index && ax_idx < index + n_insert {
601            // Inserted block: position within the region is `nd_idx` with the
602            // axis coordinate rebased to `ax_idx - index`. Broadcast the
603            // moved `values` (size-1 axes repeat) into this region.
604            let mut val_flat = 0usize;
605            for i in 0..ndim {
606                let region_coord = if i == axis { ax_idx - index } else { nd_idx[i] };
607                let v_coord = if moved_shape[i] == 1 { 0 } else { region_coord };
608                val_flat += v_coord * moved_strides[i];
609            }
610            data.push(vals[val_flat].clone());
611        } else {
612            // Original data
613            let src_ax_idx = if ax_idx >= index + n_insert {
614                ax_idx - n_insert
615            } else {
616                ax_idx
617            };
618            let mut src_flat = 0usize;
619            for i in 0..ndim {
620                let idx = if i == axis { src_ax_idx } else { nd_idx[i] };
621                src_flat += idx * src_strides[i];
622            }
623            data.push(src_data[src_flat].clone());
624        }
625    }
626
627    Array::from_vec(IxDyn::new(&new_shape), data)
628}
629
630/// Append values to the end of an array along an axis.
631///
632/// If `axis` is `None`, both arrays are flattened first.
633///
634/// Analogous to `numpy.append()`.
635pub fn append<T: Element, D: Dimension>(
636    a: &Array<T, D>,
637    values: &Array<T, IxDyn>,
638    axis: Option<usize>,
639) -> FerrayResult<Array<T, IxDyn>> {
640    match axis {
641        None => {
642            let mut data: Vec<T> = a.iter().cloned().collect();
643            data.extend(values.iter().cloned());
644            let n = data.len();
645            Array::from_vec(IxDyn::new(&[n]), data)
646        }
647        Some(ax) => {
648            let a_dyn = {
649                let data: Vec<T> = a.iter().cloned().collect();
650                Array::from_vec(IxDyn::new(a.shape()), data)?
651            };
652            let vals_dyn = {
653                let data: Vec<T> = values.iter().cloned().collect();
654                Array::from_vec(IxDyn::new(values.shape()), data)?
655            };
656            super::concatenate(&[a_dyn, vals_dyn], ax)
657        }
658    }
659}
660
661/// Resize an array to a new shape.
662///
663/// If the new size is larger, the array is filled by repeating its elements.
664/// If smaller, the array is truncated.
665///
666/// Analogous to `numpy.resize()`.
667pub fn resize<T: Element, D: Dimension>(
668    a: &Array<T, D>,
669    new_shape: &[usize],
670) -> FerrayResult<Array<T, IxDyn>> {
671    let src: Vec<T> = a.iter().cloned().collect();
672    let new_size: usize = new_shape.iter().product();
673
674    if src.is_empty() {
675        // Fill with zeros
676        let data = vec![T::zero(); new_size];
677        return Array::from_vec(IxDyn::new(new_shape), data);
678    }
679
680    let mut data = Vec::with_capacity(new_size);
681    for i in 0..new_size {
682        data.push(src[i % src.len()].clone());
683    }
684    Array::from_vec(IxDyn::new(new_shape), data)
685}
686
687/// Trim leading and/or trailing zeros from a 1-D array.
688///
689/// `trim` can be `"f"` (front), `"b"` (back), or `"fb"` (both, default).
690///
691/// Analogous to `numpy.trim_zeros()`.
692///
693/// # Errors
694/// Returns `FerrayError::InvalidValue` if `trim` contains invalid characters.
695pub fn trim_zeros<T: Element + PartialEq>(
696    a: &Array<T, Ix1>,
697    trim: &str,
698) -> FerrayResult<Array<T, Ix1>> {
699    let data: Vec<T> = a.iter().cloned().collect();
700    let zero = T::zero();
701
702    let trim_front = trim.contains('f');
703    let trim_back = trim.contains('b');
704
705    if !trim.chars().all(|c| c == 'f' || c == 'b') {
706        return Err(FerrayError::invalid_value(
707            "trim_zeros: trim must contain only 'f' and/or 'b'",
708        ));
709    }
710
711    let start = if trim_front {
712        data.iter().position(|v| *v != zero).unwrap_or(data.len())
713    } else {
714        0
715    };
716
717    let end = if trim_back {
718        data.iter()
719            .rposition(|v| *v != zero)
720            .map_or(start, |i| i + 1)
721    } else {
722        data.len()
723    };
724
725    let end = end.max(start);
726    let trimmed: Vec<T> = data[start..end].to_vec();
727    let n = trimmed.len();
728    Array::from_vec(Ix1::new([n]), trimmed)
729}
730
731// ============================================================================
732// REQ: atleast_1d / atleast_2d / atleast_3d
733// ============================================================================
734
735/// View inputs as arrays with at least one dimension.
736///
737/// Scalars (0-D arrays) are reshaped to (1,). Arrays already 1-D or higher
738/// are returned with shape unchanged.
739///
740/// Analogous to `numpy.atleast_1d()` for a single input.
741pub fn atleast_1d<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
742    let shape = a.shape();
743    let data: Vec<T> = a.iter().cloned().collect();
744    let new_shape: Vec<usize> = if shape.is_empty() {
745        vec![1]
746    } else {
747        shape.to_vec()
748    };
749    Array::from_vec(IxDyn::new(&new_shape), data)
750}
751
752/// View inputs as arrays with at least two dimensions.
753///
754/// 0-D scalars become shape (1, 1); 1-D arrays of shape (N,) become (1, N);
755/// 2-D and higher arrays are returned unchanged.
756///
757/// Analogous to `numpy.atleast_2d()` for a single input.
758pub fn atleast_2d<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
759    let shape = a.shape();
760    let data: Vec<T> = a.iter().cloned().collect();
761    let new_shape: Vec<usize> = match shape.len() {
762        0 => vec![1, 1],
763        1 => vec![1, shape[0]],
764        _ => shape.to_vec(),
765    };
766    Array::from_vec(IxDyn::new(&new_shape), data)
767}
768
769/// View inputs as arrays with at least three dimensions.
770///
771/// 0-D scalars become shape (1, 1, 1); 1-D arrays of shape (N,) become
772/// (1, N, 1); 2-D arrays (M, N) become (M, N, 1); 3-D and higher are
773/// returned unchanged.
774///
775/// Analogous to `numpy.atleast_3d()` for a single input.
776pub fn atleast_3d<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
777    let shape = a.shape();
778    let data: Vec<T> = a.iter().cloned().collect();
779    let new_shape: Vec<usize> = match shape.len() {
780        0 => vec![1, 1, 1],
781        1 => vec![1, shape[0], 1],
782        2 => vec![shape[0], shape[1], 1],
783        _ => shape.to_vec(),
784    };
785    Array::from_vec(IxDyn::new(&new_shape), data)
786}
787
788// ============================================================================
789// Tests
790// ============================================================================
791
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `f64` array of arbitrary rank from a shape and row-major data.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    /// Build a 1-D `f64` array whose length is taken from `data`.
    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // -- pad --

    #[test]
    fn test_pad_1d_constant() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 3), &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[8]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_pad_1d_edge() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Edge).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_pad_1d_wrap() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Wrap).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_empty_requires_elements() {
        // Modes that read values from the source array cannot pad an empty
        // array; `pad_1d` documents an error for exactly these modes.
        let a = arr1d(vec![]);
        for mode in [
            PadMode::Edge,
            PadMode::Reflect,
            PadMode::Symmetric,
            PadMode::Wrap,
        ] {
            assert!(
                pad_1d(&a, (1, 1), &mode).is_err(),
                "{mode:?} should reject an empty array"
            );
        }
    }

    // ----- Reflect / Symmetric edge cases on small arrays (#137) -----

    #[test]
    fn test_pad_1d_reflect_three_element_array() {
        // Reflect mirrors about the edge element without repeating it.
        // For [1, 2, 3], the before-pad reads source indices 2, 1 (mirrored
        // about index 0), giving [3, 2]. The after-pad reads indices 1, 0
        // (mirrored about index 2), giving [2, 1].
        // Matches `numpy.pad([1, 2, 3], (2, 2), mode="reflect")`.
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Reflect).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_pad_1d_reflect_two_element_array() {
        // [1, 2] is the smallest array for which Reflect is well-defined.
        // Without repeating the edge, source indices fold with period
        // 2 * (n - 1) = 2, so a pad wider than the array must keep
        // alternating rather than clamping.
        // Matches `numpy.pad([1, 2], (2, 2), mode="reflect")` == [1, 2, 1, 2, 1, 2].
        let a = arr1d(vec![1.0, 2.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Reflect).unwrap();
        assert_eq!(b.shape(), &[6]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_symmetric_three_element_array() {
        // Symmetric includes the edge, so before-pad of [1, 2, 3] is
        // [2, 1] (not [3, 2]).
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Symmetric).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_symmetric_single_element() {
        // [5]: any pad amount with Symmetric must just replicate the
        // single value.
        let a = arr1d(vec![5.0]);
        let b = pad_1d(&a, (3, 2), &PadMode::Symmetric).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![5.0; 6]);
    }

    #[test]
    fn test_pad_nd_constant() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = pad(&a, &[(1, 1), (1, 1)], &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0
            ]
        );
    }

    // -- tile --

    #[test]
    fn test_tile_1d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = tile(&a, &[3]).unwrap();
        assert_eq!(b.shape(), &[9]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tile_2d() {
        // A 1-D input is promoted to (1, 2) and tiled (2, 3) times, so each
        // of the two rows is [1, 2] repeated three times.
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        let b = tile(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 6]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, [1.0, 2.0].repeat(6));
    }

    // -- repeat --

    #[test]
    fn test_repeat_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = repeat(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_repeat_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = repeat(&a, 2, Some(0)).unwrap();
        assert_eq!(b.shape(), &[4, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
    }

    // -- delete --

    #[test]
    fn test_delete() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = delete(&a, &[1, 3], 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_delete_2d() {
        let a = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = delete(&a, &[1], 0).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 5.0, 6.0]);
    }

    // -- insert --

    #[test]
    fn test_insert() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![10.0, 20.0]);
        let b = insert(&a, 1, &vals, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 10.0, 20.0, 2.0, 3.0]);
    }

    // -- append --

    #[test]
    fn test_append_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![4.0, 5.0]);
        let b = append(&a, &vals, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_append_axis() {
        // Appending a (2, 1) column along axis 1 extends each row by one.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let vals = dyn_arr(&[2, 1], vec![5.0, 6.0]);
        let b = append(&a, &vals, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
    }

    // -- resize --

    #[test]
    fn test_resize_larger() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = resize(&a, &[5]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_resize_smaller() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = resize(&a, &[3]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_resize_2d() {
        // Like the 1-D case, the flattened source is repeated cyclically
        // in row-major order to fill the new shape.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = resize(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0]);
    }

    // -- trim_zeros --

    #[test]
    fn test_trim_zeros_both() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_trim_zeros_front() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 0.0]);
        let b = trim_zeros(&a, "f").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 0.0]);
    }

    #[test]
    fn test_trim_zeros_back() {
        let a = arr1d(vec![0.0, 1.0, 2.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "b").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_trim_zeros_all_zeros() {
        let a = arr1d(vec![0.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        assert_eq!(b.shape(), &[0]);
    }

    #[test]
    fn test_trim_zeros_bad_mode() {
        let a = arr1d(vec![1.0, 2.0]);
        assert!(trim_zeros(&a, "x").is_err());
    }

    // -- atleast_1d / atleast_2d / atleast_3d --

    #[test]
    fn test_atleast_1d_from_scalar() {
        let a = Array::from_vec(IxDyn::new(&[]), vec![42.0]).unwrap();
        let b = atleast_1d(&a).unwrap();
        assert_eq!(b.shape(), &[1]);
        assert_eq!(b.iter().copied().collect::<Vec<_>>(), vec![42.0]);
    }

    #[test]
    fn test_atleast_1d_passthrough_1d() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = atleast_1d(&a).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_atleast_1d_passthrough_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = atleast_1d(&a).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_atleast_2d_from_scalar() {
        let a = Array::from_vec(IxDyn::new(&[]), vec![7.0]).unwrap();
        let b = atleast_2d(&a).unwrap();
        assert_eq!(b.shape(), &[1, 1]);
    }

    #[test]
    fn test_atleast_2d_from_1d() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = atleast_2d(&a).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
    }

    #[test]
    fn test_atleast_2d_passthrough_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = atleast_2d(&a).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_atleast_3d_from_scalar() {
        let a = Array::from_vec(IxDyn::new(&[]), vec![7.0]).unwrap();
        let b = atleast_3d(&a).unwrap();
        assert_eq!(b.shape(), &[1, 1, 1]);
    }

    #[test]
    fn test_atleast_3d_from_1d() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = atleast_3d(&a).unwrap();
        assert_eq!(b.shape(), &[1, 3, 1]);
    }

    #[test]
    fn test_atleast_3d_from_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = atleast_3d(&a).unwrap();
        assert_eq!(b.shape(), &[2, 3, 1]);
    }

    #[test]
    fn test_atleast_3d_passthrough_3d() {
        let a = dyn_arr(&[2, 2, 2], (0..8u8).map(f64::from).collect());
        let b = atleast_3d(&a).unwrap();
        assert_eq!(b.shape(), &[2, 2, 2]);
    }
}