Skip to main content

ferray_core/manipulation/
extended.rs

// ferray-core: Extended manipulation functions (REQ-22a)
//
// pad, tile, repeat, delete, insert, append, resize, trim_zeros
4
5use crate::array::owned::Array;
6use crate::dimension::{Dimension, Ix1, IxDyn};
7use crate::dtype::Element;
8use crate::error::{FerrayError, FerrayResult};
9
// ============================================================================
// Pad modes
// ============================================================================
13
14/// Padding mode for [`pad`].
15#[derive(Debug, Clone)]
16pub enum PadMode<T: Element> {
17    /// Pad with a constant value.
18    Constant(T),
19    /// Pad with the edge values of the array.
20    Edge,
21    /// Pad with the reflection of the array mirrored on the first and last
22    /// values (does not repeat the edge).
23    Reflect,
24    /// Pad with the reflection of the array mirrored along the edge.
25    Symmetric,
26    /// Pad by wrapping the array around.
27    Wrap,
28}
29
30/// Pad a 1-D array.
31///
32/// `pad_width` is `(before, after)` for the single axis.
33///
34/// Analogous to `numpy.pad()` for 1-D.
35///
36/// # Errors
37/// Returns `FerrayError::InvalidValue` if the array is empty and the mode
38/// requires elements (Edge, Reflect, Symmetric, Wrap).
39pub fn pad_1d<T: Element>(
40    a: &Array<T, Ix1>,
41    pad_width: (usize, usize),
42    mode: &PadMode<T>,
43) -> FerrayResult<Array<T, Ix1>> {
44    let n = a.shape()[0];
45    let (before, after) = pad_width;
46    let new_len = before + n + after;
47    let src: Vec<T> = a.iter().cloned().collect();
48
49    if n == 0 && !matches!(mode, PadMode::Constant(_)) {
50        return Err(FerrayError::invalid_value(
51            "pad: cannot use Edge/Reflect/Symmetric/Wrap mode on empty array",
52        ));
53    }
54
55    let mut data = Vec::with_capacity(new_len);
56
57    // Fill 'before' padding
58    for i in 0..before {
59        let val = match mode {
60            PadMode::Constant(c) => c.clone(),
61            PadMode::Edge => src[0].clone(),
62            PadMode::Reflect => {
63                // Reflect: index = before - 1 - i mapped into [1..n-1] reflected
64                let idx = reflect_index(before as isize - 1 - i as isize, n);
65                src[idx].clone()
66            }
67            PadMode::Symmetric => {
68                // Symmetric: similar but includes edge
69                let idx = symmetric_index(before as isize - 1 - i as isize, n);
70                src[idx].clone()
71            }
72            PadMode::Wrap => {
73                let idx = ((n as isize - (before as isize - i as isize) % n as isize) % n as isize)
74                    as usize;
75                src[idx].clone()
76            }
77        };
78        data.push(val);
79    }
80
81    // Copy original data
82    data.extend_from_slice(&src);
83
84    // Fill 'after' padding
85    for i in 0..after {
86        let val = match mode {
87            PadMode::Constant(c) => c.clone(),
88            PadMode::Edge => src[n - 1].clone(),
89            PadMode::Reflect => {
90                let idx = reflect_index(n as isize + i as isize, n);
91                src[idx].clone()
92            }
93            PadMode::Symmetric => {
94                let idx = symmetric_index(n as isize + i as isize, n);
95                src[idx].clone()
96            }
97            PadMode::Wrap => {
98                let idx = i % n;
99                src[idx].clone()
100            }
101        };
102        data.push(val);
103    }
104
105    Array::from_vec(Ix1::new([new_len]), data)
106}
107
/// Reflect index: folds any signed position onto `[0, n)` by mirroring about
/// positions `0` and `n - 1`, so the edge values are not repeated.
///
/// The mapping is periodic with period `2 * (n - 1)`. Returns `0` when
/// `n <= 1`.
fn reflect_index(idx: isize, n: usize) -> usize {
    if n <= 1 {
        return 0;
    }
    let len = n as isize;
    let period = 2 * (len - 1);
    // Euclidean remainder lands in [0, period) for negative inputs as well.
    let folded = idx.rem_euclid(period);
    let mirrored = if folded < len { folded } else { period - folded };
    mirrored as usize
}
124
125/// Symmetric index: maps indices outside [0, n) by reflecting at boundaries 0 and n-1.
126/// The edge values are repeated.
127fn symmetric_index(idx: isize, n: usize) -> usize {
128    if n == 0 {
129        return 0;
130    }
131    if n == 1 {
132        return 0;
133    }
134    let period = n as isize * 2;
135    let mut i = idx % period;
136    if i < 0 {
137        i += period;
138    }
139    if i >= n as isize {
140        i = period - 1 - i;
141    }
142    i.max(0) as usize
143}
144
145/// Pad an N-D array.
146///
147/// `pad_width` is a slice of `(before, after)` pairs, one per axis.
148/// If it has fewer entries than the array's ndim, the last entry is repeated.
149///
150/// Analogous to `numpy.pad()`.
151///
152/// # Errors
153/// Returns `FerrayError::InvalidValue` if `pad_width` is empty.
154pub fn pad<T: Element, D: Dimension>(
155    a: &Array<T, D>,
156    pad_width: &[(usize, usize)],
157    mode: &PadMode<T>,
158) -> FerrayResult<Array<T, IxDyn>> {
159    if pad_width.is_empty() {
160        return Err(FerrayError::invalid_value("pad: pad_width cannot be empty"));
161    }
162
163    let shape = a.shape();
164    let ndim = shape.len();
165
166    // Expand pad_width to ndim entries
167    let pads: Vec<(usize, usize)> = (0..ndim)
168        .map(|i| {
169            if i < pad_width.len() {
170                pad_width[i]
171            } else {
172                // pad_width is non-empty (checked above), so last() is always Some
173                *pad_width.last().unwrap_or_else(|| unreachable!())
174            }
175        })
176        .collect();
177
178    // For multi-dimensional, we pad axis-by-axis starting from the last axis.
179    // Convert to IxDyn first.
180    let mut current_data: Vec<T> = a.iter().cloned().collect();
181    let mut current_shape: Vec<usize> = shape.to_vec();
182
183    for ax in (0..ndim).rev() {
184        let (before, after) = pads[ax];
185        if before == 0 && after == 0 {
186            continue;
187        }
188        let axis_len = current_shape[ax];
189        let new_axis_len = before + axis_len + after;
190
191        // Compute strides
192        let outer: usize = current_shape[..ax].iter().product();
193        let inner: usize = current_shape[ax + 1..].iter().product();
194
195        let new_total = outer * new_axis_len * inner;
196        let mut new_data = Vec::with_capacity(new_total);
197
198        for o in 0..outer {
199            for j in 0..new_axis_len {
200                for k in 0..inner {
201                    let val = if j < before {
202                        // Before padding
203                        match mode {
204                            PadMode::Constant(c) => c.clone(),
205                            PadMode::Edge => {
206                                let src_j = 0;
207                                current_data[o * axis_len * inner + src_j * inner + k].clone()
208                            }
209                            PadMode::Reflect => {
210                                let src_j =
211                                    reflect_index(before as isize - 1 - j as isize, axis_len);
212                                current_data[o * axis_len * inner + src_j * inner + k].clone()
213                            }
214                            PadMode::Symmetric => {
215                                let src_j =
216                                    symmetric_index(before as isize - 1 - j as isize, axis_len);
217                                current_data[o * axis_len * inner + src_j * inner + k].clone()
218                            }
219                            PadMode::Wrap => {
220                                let src_j = ((axis_len as isize
221                                    - (before as isize - j as isize) % axis_len as isize)
222                                    % axis_len as isize)
223                                    as usize;
224                                current_data[o * axis_len * inner + src_j * inner + k].clone()
225                            }
226                        }
227                    } else if j < before + axis_len {
228                        // Original data
229                        let src_j = j - before;
230                        current_data[o * axis_len * inner + src_j * inner + k].clone()
231                    } else {
232                        // After padding
233                        let after_idx = j - before - axis_len;
234                        match mode {
235                            PadMode::Constant(c) => c.clone(),
236                            PadMode::Edge => {
237                                let src_j = axis_len - 1;
238                                current_data[o * axis_len * inner + src_j * inner + k].clone()
239                            }
240                            PadMode::Reflect => {
241                                let src_j = reflect_index(
242                                    (axis_len as isize) + after_idx as isize,
243                                    axis_len,
244                                );
245                                current_data[o * axis_len * inner + src_j * inner + k].clone()
246                            }
247                            PadMode::Symmetric => {
248                                let src_j = symmetric_index(
249                                    (axis_len as isize) + after_idx as isize,
250                                    axis_len,
251                                );
252                                current_data[o * axis_len * inner + src_j * inner + k].clone()
253                            }
254                            PadMode::Wrap => {
255                                let src_j = after_idx % axis_len;
256                                current_data[o * axis_len * inner + src_j * inner + k].clone()
257                            }
258                        }
259                    };
260                    new_data.push(val);
261                }
262            }
263        }
264
265        current_data = new_data;
266        current_shape[ax] = new_axis_len;
267    }
268
269    Array::from_vec(IxDyn::new(&current_shape), current_data)
270}
271
272/// Construct an array by repeating `a` the number of times given by `reps`.
273///
274/// If `reps` has fewer entries than `a.ndim()`, it is prepended with 1s.
275/// If `reps` has more entries, `a`'s shape is prepended with 1s.
276///
277/// Analogous to `numpy.tile()`.
278///
279/// # Errors
280/// Returns `FerrayError::InvalidValue` if `reps` is empty.
281pub fn tile<T: Element, D: Dimension>(
282    a: &Array<T, D>,
283    reps: &[usize],
284) -> FerrayResult<Array<T, IxDyn>> {
285    if reps.is_empty() {
286        return Err(FerrayError::invalid_value("tile: reps cannot be empty"));
287    }
288
289    let src_shape = a.shape();
290    let src_ndim = src_shape.len();
291    let reps_ndim = reps.len();
292    let out_ndim = src_ndim.max(reps_ndim);
293
294    // Pad shapes to out_ndim
295    let mut padded_shape = vec![1usize; out_ndim];
296    for i in 0..src_ndim {
297        padded_shape[out_ndim - src_ndim + i] = src_shape[i];
298    }
299    let mut padded_reps = vec![1usize; out_ndim];
300    for i in 0..reps_ndim {
301        padded_reps[out_ndim - reps_ndim + i] = reps[i];
302    }
303
304    let out_shape: Vec<usize> = padded_shape
305        .iter()
306        .zip(padded_reps.iter())
307        .map(|(&s, &r)| s * r)
308        .collect();
309    let total: usize = out_shape.iter().product();
310
311    let src_data: Vec<T> = a.iter().cloned().collect();
312    let mut data = Vec::with_capacity(total);
313
314    // Compute strides for output and padded source
315    let mut out_strides = vec![1usize; out_ndim];
316    for i in (0..out_ndim.saturating_sub(1)).rev() {
317        out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
318    }
319
320    let mut src_strides = vec![1usize; out_ndim];
321    for i in (0..out_ndim.saturating_sub(1)).rev() {
322        src_strides[i] = src_strides[i + 1] * padded_shape[i + 1];
323    }
324
325    for flat in 0..total {
326        let mut rem = flat;
327        let mut src_flat = 0usize;
328        for i in 0..out_ndim {
329            let idx = rem / out_strides[i];
330            rem %= out_strides[i];
331            let src_idx = idx % padded_shape[i];
332            src_flat += src_idx * src_strides[i];
333        }
334        // Map src_flat to the original (non-padded) source
335        // Since we padded with 1s, the strides handle it.
336        if src_flat < src_data.len() {
337            data.push(src_data[src_flat].clone());
338        } else {
339            // This shouldn't happen if the math is right
340            data.push(T::zero());
341        }
342    }
343
344    Array::from_vec(IxDyn::new(&out_shape), data)
345}
346
347/// Repeat elements of an array.
348///
349/// If `axis` is `None`, the array is flattened first, then each element
350/// is repeated `repeats` times.
351///
352/// Analogous to `numpy.repeat()`.
353///
354/// # Errors
355/// Returns `FerrayError::AxisOutOfBounds` if the axis is out of bounds.
356pub fn repeat<T: Element, D: Dimension>(
357    a: &Array<T, D>,
358    repeats: usize,
359    axis: Option<usize>,
360) -> FerrayResult<Array<T, IxDyn>> {
361    match axis {
362        None => {
363            // Flatten and repeat each element
364            let src: Vec<T> = a.iter().cloned().collect();
365            let mut data = Vec::with_capacity(src.len() * repeats);
366            for val in &src {
367                for _ in 0..repeats {
368                    data.push(val.clone());
369                }
370            }
371            let n = data.len();
372            Array::from_vec(IxDyn::new(&[n]), data)
373        }
374        Some(ax) => {
375            let shape = a.shape();
376            let ndim = shape.len();
377            if ax >= ndim {
378                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
379            }
380
381            let mut new_shape = shape.to_vec();
382            new_shape[ax] *= repeats;
383            let total: usize = new_shape.iter().product();
384            let src_data: Vec<T> = a.iter().cloned().collect();
385
386            // Compute source strides (C-order)
387            let mut src_strides = vec![1usize; ndim];
388            for i in (0..ndim.saturating_sub(1)).rev() {
389                src_strides[i] = src_strides[i + 1] * shape[i + 1];
390            }
391
392            // Compute output strides
393            let mut out_strides = vec![1usize; ndim];
394            for i in (0..ndim.saturating_sub(1)).rev() {
395                out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
396            }
397
398            let mut data = Vec::with_capacity(total);
399            for flat in 0..total {
400                let mut rem = flat;
401                let mut src_flat = 0usize;
402                for i in 0..ndim {
403                    let idx = rem / out_strides[i];
404                    rem %= out_strides[i];
405                    let src_idx = if i == ax { idx / repeats } else { idx };
406                    src_flat += src_idx * src_strides[i];
407                }
408                data.push(src_data[src_flat].clone());
409            }
410
411            Array::from_vec(IxDyn::new(&new_shape), data)
412        }
413    }
414}
415
416/// Delete sub-arrays along an axis.
417///
418/// `indices` specifies which indices along `axis` to remove.
419///
420/// Analogous to `numpy.delete()`.
421///
422/// # Errors
423/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
424/// Returns `FerrayError::IndexOutOfBounds` if any index is out of range.
425pub fn delete<T: Element, D: Dimension>(
426    a: &Array<T, D>,
427    indices: &[usize],
428    axis: usize,
429) -> FerrayResult<Array<T, IxDyn>> {
430    let shape = a.shape();
431    let ndim = shape.len();
432    if axis >= ndim {
433        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
434    }
435    let axis_len = shape[axis];
436
437    // Validate indices
438    for &idx in indices {
439        if idx >= axis_len {
440            return Err(FerrayError::IndexOutOfBounds {
441                index: idx as isize,
442                axis,
443                size: axis_len,
444            });
445        }
446    }
447
448    let to_remove: std::collections::HashSet<usize> = indices.iter().copied().collect();
449    let kept: Vec<usize> = (0..axis_len).filter(|i| !to_remove.contains(i)).collect();
450    let new_axis_len = kept.len();
451
452    let mut new_shape = shape.to_vec();
453    new_shape[axis] = new_axis_len;
454    let total: usize = new_shape.iter().product();
455    let src_data: Vec<T> = a.iter().cloned().collect();
456
457    // Compute source strides (C-order)
458    let mut src_strides = vec![1usize; ndim];
459    for i in (0..ndim.saturating_sub(1)).rev() {
460        src_strides[i] = src_strides[i + 1] * shape[i + 1];
461    }
462
463    // Compute output strides
464    let mut out_strides = vec![1usize; ndim];
465    for i in (0..ndim.saturating_sub(1)).rev() {
466        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
467    }
468
469    let mut data = Vec::with_capacity(total);
470    for flat in 0..total {
471        let mut rem = flat;
472        let mut src_flat = 0usize;
473        for i in 0..ndim {
474            let idx = rem / out_strides[i];
475            rem %= out_strides[i];
476            let src_idx = if i == axis { kept[idx] } else { idx };
477            src_flat += src_idx * src_strides[i];
478        }
479        data.push(src_data[src_flat].clone());
480    }
481
482    Array::from_vec(IxDyn::new(&new_shape), data)
483}
484
485/// Insert values along an axis before a given index.
486///
487/// `index` is the position before which to insert. `values` is a 1-D array
488/// of values to insert (its length determines how many slices are added).
489///
490/// Analogous to `numpy.insert()`.
491///
492/// # Errors
493/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
494/// Returns `FerrayError::IndexOutOfBounds` if `index > axis_len`.
495pub fn insert<T: Element, D: Dimension>(
496    a: &Array<T, D>,
497    index: usize,
498    values: &Array<T, IxDyn>,
499    axis: usize,
500) -> FerrayResult<Array<T, IxDyn>> {
501    let shape = a.shape();
502    let ndim = shape.len();
503    if axis >= ndim {
504        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
505    }
506    let axis_len = shape[axis];
507    if index > axis_len {
508        return Err(FerrayError::IndexOutOfBounds {
509            index: index as isize,
510            axis,
511            size: axis_len + 1,
512        });
513    }
514
515    let n_insert = values.size();
516    let vals: Vec<T> = values.iter().cloned().collect();
517
518    let mut new_shape = shape.to_vec();
519    new_shape[axis] = axis_len + n_insert;
520    let total: usize = new_shape.iter().product();
521    let src_data: Vec<T> = a.iter().cloned().collect();
522
523    // Compute strides
524    let mut src_strides = vec![1usize; ndim];
525    for i in (0..ndim.saturating_sub(1)).rev() {
526        src_strides[i] = src_strides[i + 1] * shape[i + 1];
527    }
528
529    let mut out_strides = vec![1usize; ndim];
530    for i in (0..ndim.saturating_sub(1)).rev() {
531        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
532    }
533
534    // Compute the "inner" size (product of dims after axis)
535    let inner: usize = shape[axis + 1..].iter().product();
536
537    let mut data = Vec::with_capacity(total);
538    for flat in 0..total {
539        let mut rem = flat;
540        let mut nd_idx = vec![0usize; ndim];
541        for i in 0..ndim {
542            nd_idx[i] = rem / out_strides[i];
543            rem %= out_strides[i];
544        }
545
546        let ax_idx = nd_idx[axis];
547        if ax_idx >= index && ax_idx < index + n_insert {
548            // This is an inserted value
549            let insert_idx = ax_idx - index;
550            // For multi-D insert, we tile the values along the inner dims
551            let val_idx = (insert_idx * inner + nd_idx.get(axis + 1).copied().unwrap_or(0))
552                % vals.len().max(1);
553            data.push(vals[val_idx].clone());
554        } else {
555            // Original data
556            let src_ax_idx = if ax_idx >= index + n_insert {
557                ax_idx - n_insert
558            } else {
559                ax_idx
560            };
561            let mut src_flat = 0usize;
562            for i in 0..ndim {
563                let idx = if i == axis { src_ax_idx } else { nd_idx[i] };
564                src_flat += idx * src_strides[i];
565            }
566            data.push(src_data[src_flat].clone());
567        }
568    }
569
570    Array::from_vec(IxDyn::new(&new_shape), data)
571}
572
573/// Append values to the end of an array along an axis.
574///
575/// If `axis` is `None`, both arrays are flattened first.
576///
577/// Analogous to `numpy.append()`.
578pub fn append<T: Element, D: Dimension>(
579    a: &Array<T, D>,
580    values: &Array<T, IxDyn>,
581    axis: Option<usize>,
582) -> FerrayResult<Array<T, IxDyn>> {
583    match axis {
584        None => {
585            let mut data: Vec<T> = a.iter().cloned().collect();
586            data.extend(values.iter().cloned());
587            let n = data.len();
588            Array::from_vec(IxDyn::new(&[n]), data)
589        }
590        Some(ax) => {
591            let a_dyn = {
592                let data: Vec<T> = a.iter().cloned().collect();
593                Array::from_vec(IxDyn::new(a.shape()), data)?
594            };
595            let vals_dyn = {
596                let data: Vec<T> = values.iter().cloned().collect();
597                Array::from_vec(IxDyn::new(values.shape()), data)?
598            };
599            super::concatenate(&[a_dyn, vals_dyn], ax)
600        }
601    }
602}
603
604/// Resize an array to a new shape.
605///
606/// If the new size is larger, the array is filled by repeating its elements.
607/// If smaller, the array is truncated.
608///
609/// Analogous to `numpy.resize()`.
610pub fn resize<T: Element, D: Dimension>(
611    a: &Array<T, D>,
612    new_shape: &[usize],
613) -> FerrayResult<Array<T, IxDyn>> {
614    let src: Vec<T> = a.iter().cloned().collect();
615    let new_size: usize = new_shape.iter().product();
616
617    if src.is_empty() {
618        // Fill with zeros
619        let data = vec![T::zero(); new_size];
620        return Array::from_vec(IxDyn::new(new_shape), data);
621    }
622
623    let mut data = Vec::with_capacity(new_size);
624    for i in 0..new_size {
625        data.push(src[i % src.len()].clone());
626    }
627    Array::from_vec(IxDyn::new(new_shape), data)
628}
629
630/// Trim leading and/or trailing zeros from a 1-D array.
631///
632/// `trim` can be `"f"` (front), `"b"` (back), or `"fb"` (both, default).
633///
634/// Analogous to `numpy.trim_zeros()`.
635///
636/// # Errors
637/// Returns `FerrayError::InvalidValue` if `trim` contains invalid characters.
638pub fn trim_zeros<T: Element + PartialEq>(
639    a: &Array<T, Ix1>,
640    trim: &str,
641) -> FerrayResult<Array<T, Ix1>> {
642    let data: Vec<T> = a.iter().cloned().collect();
643    let zero = T::zero();
644
645    let trim_front = trim.contains('f');
646    let trim_back = trim.contains('b');
647
648    if !trim.chars().all(|c| c == 'f' || c == 'b') {
649        return Err(FerrayError::invalid_value(
650            "trim_zeros: trim must contain only 'f' and/or 'b'",
651        ));
652    }
653
654    let start = if trim_front {
655        data.iter().position(|v| *v != zero).unwrap_or(data.len())
656    } else {
657        0
658    };
659
660    let end = if trim_back {
661        data.iter()
662            .rposition(|v| *v != zero)
663            .map(|i| i + 1)
664            .unwrap_or(start)
665    } else {
666        data.len()
667    };
668
669    let end = end.max(start);
670    let trimmed: Vec<T> = data[start..end].to_vec();
671    let n = trimmed.len();
672    Array::from_vec(Ix1::new([n]), trimmed)
673}
674
// ============================================================================
// Tests
// ============================================================================
678
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dynamic-rank `f64` array from a shape and its data.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Collects the elements of an array in iteration order.
    fn flat<D: Dimension>(arr: &Array<f64, D>) -> Vec<f64> {
        arr.iter().copied().collect()
    }

    // -- pad --

    #[test]
    fn test_pad_1d_constant() {
        let padded = pad_1d(&arr1d(vec![1.0, 2.0, 3.0]), (2, 3), &PadMode::Constant(0.0)).unwrap();
        assert_eq!(padded.shape(), &[8]);
        assert_eq!(flat(&padded), [0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_pad_1d_edge() {
        let padded = pad_1d(&arr1d(vec![1.0, 2.0, 3.0]), (2, 2), &PadMode::Edge).unwrap();
        assert_eq!(flat(&padded), [1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_pad_1d_wrap() {
        let padded = pad_1d(&arr1d(vec![1.0, 2.0, 3.0]), (2, 2), &PadMode::Wrap).unwrap();
        assert_eq!(flat(&padded), [2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad_nd_constant() {
        let source = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let padded = pad(&source, &[(1, 1), (1, 1)], &PadMode::Constant(0.0)).unwrap();
        assert_eq!(padded.shape(), &[4, 4]);
        #[rustfmt::skip]
        let expected = [
            0.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 2.0, 0.0,
            0.0, 3.0, 4.0, 0.0,
            0.0, 0.0, 0.0, 0.0,
        ];
        assert_eq!(flat(&padded), expected);
    }

    // -- tile --

    #[test]
    fn test_tile_1d() {
        let tiled = tile(&dyn_arr(&[3], vec![1.0, 2.0, 3.0]), &[3]).unwrap();
        assert_eq!(tiled.shape(), &[9]);
        assert_eq!(
            flat(&tiled),
            [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_tile_2d() {
        let tiled = tile(&dyn_arr(&[2], vec![1.0, 2.0]), &[2, 3]).unwrap();
        assert_eq!(tiled.shape(), &[2, 6]);
    }

    // -- repeat --

    #[test]
    fn test_repeat_flat() {
        let repeated = repeat(&dyn_arr(&[3], vec![1.0, 2.0, 3.0]), 2, None).unwrap();
        assert_eq!(flat(&repeated), [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_repeat_axis() {
        let source = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let repeated = repeat(&source, 2, Some(0)).unwrap();
        assert_eq!(repeated.shape(), &[4, 2]);
        assert_eq!(flat(&repeated), [1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
    }

    // -- delete --

    #[test]
    fn test_delete() {
        let source = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let remaining = delete(&source, &[1, 3], 0).unwrap();
        assert_eq!(flat(&remaining), [1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_delete_2d() {
        let source = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let remaining = delete(&source, &[1], 0).unwrap();
        assert_eq!(remaining.shape(), &[2, 2]);
        assert_eq!(flat(&remaining), [1.0, 2.0, 5.0, 6.0]);
    }

    // -- insert --

    #[test]
    fn test_insert() {
        let source = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let extra = dyn_arr(&[2], vec![10.0, 20.0]);
        let combined = insert(&source, 1, &extra, 0).unwrap();
        assert_eq!(flat(&combined), [1.0, 10.0, 20.0, 2.0, 3.0]);
    }

    // -- append --

    #[test]
    fn test_append_flat() {
        let source = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let extra = dyn_arr(&[2], vec![4.0, 5.0]);
        let combined = append(&source, &extra, None).unwrap();
        assert_eq!(flat(&combined), [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_append_axis() {
        let source = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let extra = dyn_arr(&[2, 1], vec![5.0, 6.0]);
        let combined = append(&source, &extra, Some(1)).unwrap();
        assert_eq!(combined.shape(), &[2, 3]);
    }

    // -- resize --

    #[test]
    fn test_resize_larger() {
        let resized = resize(&dyn_arr(&[3], vec![1.0, 2.0, 3.0]), &[5]).unwrap();
        assert_eq!(flat(&resized), [1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_resize_smaller() {
        let source = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let resized = resize(&source, &[3]).unwrap();
        assert_eq!(flat(&resized), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_resize_2d() {
        let source = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let resized = resize(&source, &[3, 3]).unwrap();
        assert_eq!(resized.shape(), &[3, 3]);
    }

    // -- trim_zeros --

    #[test]
    fn test_trim_zeros_both() {
        let source = arr1d(vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
        let trimmed = trim_zeros(&source, "fb").unwrap();
        assert_eq!(flat(&trimmed), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_trim_zeros_front() {
        let source = arr1d(vec![0.0, 0.0, 1.0, 2.0, 0.0]);
        let trimmed = trim_zeros(&source, "f").unwrap();
        assert_eq!(flat(&trimmed), [1.0, 2.0, 0.0]);
    }

    #[test]
    fn test_trim_zeros_back() {
        let source = arr1d(vec![0.0, 1.0, 2.0, 0.0, 0.0]);
        let trimmed = trim_zeros(&source, "b").unwrap();
        assert_eq!(flat(&trimmed), [0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_trim_zeros_all_zeros() {
        let trimmed = trim_zeros(&arr1d(vec![0.0, 0.0, 0.0]), "fb").unwrap();
        assert_eq!(trimmed.shape(), &[0]);
    }

    #[test]
    fn test_trim_zeros_bad_mode() {
        let source = arr1d(vec![1.0, 2.0]);
        assert!(trim_zeros(&source, "x").is_err());
    }
}