// ferray_core/manipulation/extended.rs
// ferray-core: Extended manipulation functions (REQ-22a)
//
// pad, tile, repeat, delete, insert, append, resize, trim_zeros

use crate::array::owned::Array;
use crate::dimension::{Dimension, Ix1, IxDyn};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};

// ============================================================================
// Pad modes
// ============================================================================

14/// Padding mode for [`pad`].
15#[derive(Debug, Clone)]
16pub enum PadMode<T: Element> {
17    /// Pad with a constant value.
18    Constant(T),
19    /// Pad with the edge values of the array.
20    Edge,
21    /// Pad with the reflection of the array mirrored on the first and last
22    /// values (does not repeat the edge).
23    Reflect,
24    /// Pad with the reflection of the array mirrored along the edge.
25    Symmetric,
26    /// Pad by wrapping the array around.
27    Wrap,
28}
29
30/// Pad a 1-D array.
31///
32/// `pad_width` is `(before, after)` for the single axis.
33///
34/// Analogous to `numpy.pad()` for 1-D.
35///
36/// # Errors
37/// Returns `FerrayError::InvalidValue` if the array is empty and the mode
38/// requires elements (Edge, Reflect, Symmetric, Wrap).
39pub fn pad_1d<T: Element>(
40    a: &Array<T, Ix1>,
41    pad_width: (usize, usize),
42    mode: &PadMode<T>,
43) -> FerrayResult<Array<T, Ix1>> {
44    let n = a.shape()[0];
45    let (before, after) = pad_width;
46    let new_len = before + n + after;
47    let src: Vec<T> = a.iter().cloned().collect();
48
49    if n == 0 && !matches!(mode, PadMode::Constant(_)) {
50        return Err(FerrayError::invalid_value(
51            "pad: cannot use Edge/Reflect/Symmetric/Wrap mode on empty array",
52        ));
53    }
54
55    let mut data = Vec::with_capacity(new_len);
56
57    // Fill 'before' padding
58    for i in 0..before {
59        let val = match mode {
60            PadMode::Constant(c) => c.clone(),
61            PadMode::Edge => src[0].clone(),
62            PadMode::Reflect => {
63                // Reflect: index = before - 1 - i mapped into [1..n-1] reflected
64                let idx = reflect_index(before as isize - 1 - i as isize, n);
65                src[idx].clone()
66            }
67            PadMode::Symmetric => {
68                // Symmetric: similar but includes edge
69                let idx = symmetric_index(before as isize - 1 - i as isize, n);
70                src[idx].clone()
71            }
72            PadMode::Wrap => {
73                let idx = ((n as isize - (before as isize - i as isize) % n as isize) % n as isize)
74                    as usize;
75                src[idx].clone()
76            }
77        };
78        data.push(val);
79    }
80
81    // Copy original data
82    data.extend_from_slice(&src);
83
84    // Fill 'after' padding
85    for i in 0..after {
86        let val = match mode {
87            PadMode::Constant(c) => c.clone(),
88            PadMode::Edge => src[n - 1].clone(),
89            PadMode::Reflect => {
90                let idx = reflect_index(n as isize + i as isize, n);
91                src[idx].clone()
92            }
93            PadMode::Symmetric => {
94                let idx = symmetric_index(n as isize + i as isize, n);
95                src[idx].clone()
96            }
97            PadMode::Wrap => {
98                let idx = i % n;
99                src[idx].clone()
100            }
101        };
102        data.push(val);
103    }
104
105    Array::from_vec(Ix1::new([new_len]), data)
106}
107
/// Map a (possibly out-of-range) position onto `0..n` by mirroring about the
/// first and last elements, so the edge values are not repeated.
///
/// The mirrored sequence has period `2 * (n - 1)`; for `n = 3`, positions
/// `-2, -1, 0, 1, 2, 3, 4` map to `2, 1, 0, 1, 2, 1, 0`. Arrays with fewer
/// than two elements always map to 0.
fn reflect_index(idx: isize, n: usize) -> usize {
    if n < 2 {
        return 0;
    }
    let last = n as isize - 1;
    let phase = idx.rem_euclid(2 * last);
    let mirrored = if phase > last { 2 * last - phase } else { phase };
    mirrored as usize
}

/// Map a (possibly out-of-range) position onto `0..n` by mirroring about the
/// array boundaries, so the edge values are repeated.
///
/// The mirrored sequence has period `2 * n`; for `n = 3`, positions
/// `-2, -1, 0, 1, 2, 3, 4` map to `1, 0, 0, 1, 2, 2, 1`. Empty and
/// single-element arrays always map to 0.
fn symmetric_index(idx: isize, n: usize) -> usize {
    if n < 2 {
        return 0;
    }
    let len = n as isize;
    let phase = idx.rem_euclid(2 * len);
    if phase < len {
        phase as usize
    } else {
        (2 * len - 1 - phase) as usize
    }
}

145/// Pad an N-D array.
146///
147/// `pad_width` is a slice of `(before, after)` pairs, one per axis.
148/// If it has fewer entries than the array's ndim, the last entry is repeated.
149///
150/// Analogous to `numpy.pad()`.
151///
152/// # Errors
153/// Returns `FerrayError::InvalidValue` if `pad_width` is empty.
154pub fn pad<T: Element, D: Dimension>(
155    a: &Array<T, D>,
156    pad_width: &[(usize, usize)],
157    mode: &PadMode<T>,
158) -> FerrayResult<Array<T, IxDyn>> {
159    if pad_width.is_empty() {
160        return Err(FerrayError::invalid_value("pad: pad_width cannot be empty"));
161    }
162
163    let shape = a.shape();
164    let ndim = shape.len();
165
166    // Expand pad_width to ndim entries
167    let pads: Vec<(usize, usize)> = (0..ndim)
168        .map(|i| {
169            if i < pad_width.len() {
170                pad_width[i]
171            } else {
172                *pad_width.last().unwrap()
173            }
174        })
175        .collect();
176
177    // For multi-dimensional, we pad axis-by-axis starting from the last axis.
178    // Convert to IxDyn first.
179    let mut current_data: Vec<T> = a.iter().cloned().collect();
180    let mut current_shape: Vec<usize> = shape.to_vec();
181
182    for ax in (0..ndim).rev() {
183        let (before, after) = pads[ax];
184        if before == 0 && after == 0 {
185            continue;
186        }
187        let axis_len = current_shape[ax];
188        let new_axis_len = before + axis_len + after;
189
190        // Compute strides
191        let outer: usize = current_shape[..ax].iter().product();
192        let inner: usize = current_shape[ax + 1..].iter().product();
193
194        let new_total = outer * new_axis_len * inner;
195        let mut new_data = Vec::with_capacity(new_total);
196
197        for o in 0..outer {
198            for j in 0..new_axis_len {
199                for k in 0..inner {
200                    let val = if j < before {
201                        // Before padding
202                        match mode {
203                            PadMode::Constant(c) => c.clone(),
204                            PadMode::Edge => {
205                                let src_j = 0;
206                                current_data[o * axis_len * inner + src_j * inner + k].clone()
207                            }
208                            PadMode::Reflect => {
209                                let src_j =
210                                    reflect_index(before as isize - 1 - j as isize, axis_len);
211                                current_data[o * axis_len * inner + src_j * inner + k].clone()
212                            }
213                            PadMode::Symmetric => {
214                                let src_j =
215                                    symmetric_index(before as isize - 1 - j as isize, axis_len);
216                                current_data[o * axis_len * inner + src_j * inner + k].clone()
217                            }
218                            PadMode::Wrap => {
219                                let src_j = ((axis_len as isize
220                                    - (before as isize - j as isize) % axis_len as isize)
221                                    % axis_len as isize)
222                                    as usize;
223                                current_data[o * axis_len * inner + src_j * inner + k].clone()
224                            }
225                        }
226                    } else if j < before + axis_len {
227                        // Original data
228                        let src_j = j - before;
229                        current_data[o * axis_len * inner + src_j * inner + k].clone()
230                    } else {
231                        // After padding
232                        let after_idx = j - before - axis_len;
233                        match mode {
234                            PadMode::Constant(c) => c.clone(),
235                            PadMode::Edge => {
236                                let src_j = axis_len - 1;
237                                current_data[o * axis_len * inner + src_j * inner + k].clone()
238                            }
239                            PadMode::Reflect => {
240                                let src_j = reflect_index(
241                                    (axis_len as isize) + after_idx as isize,
242                                    axis_len,
243                                );
244                                current_data[o * axis_len * inner + src_j * inner + k].clone()
245                            }
246                            PadMode::Symmetric => {
247                                let src_j = symmetric_index(
248                                    (axis_len as isize) + after_idx as isize,
249                                    axis_len,
250                                );
251                                current_data[o * axis_len * inner + src_j * inner + k].clone()
252                            }
253                            PadMode::Wrap => {
254                                let src_j = after_idx % axis_len;
255                                current_data[o * axis_len * inner + src_j * inner + k].clone()
256                            }
257                        }
258                    };
259                    new_data.push(val);
260                }
261            }
262        }
263
264        current_data = new_data;
265        current_shape[ax] = new_axis_len;
266    }
267
268    Array::from_vec(IxDyn::new(&current_shape), current_data)
269}
270
271/// Construct an array by repeating `a` the number of times given by `reps`.
272///
273/// If `reps` has fewer entries than `a.ndim()`, it is prepended with 1s.
274/// If `reps` has more entries, `a`'s shape is prepended with 1s.
275///
276/// Analogous to `numpy.tile()`.
277///
278/// # Errors
279/// Returns `FerrayError::InvalidValue` if `reps` is empty.
280pub fn tile<T: Element, D: Dimension>(
281    a: &Array<T, D>,
282    reps: &[usize],
283) -> FerrayResult<Array<T, IxDyn>> {
284    if reps.is_empty() {
285        return Err(FerrayError::invalid_value("tile: reps cannot be empty"));
286    }
287
288    let src_shape = a.shape();
289    let src_ndim = src_shape.len();
290    let reps_ndim = reps.len();
291    let out_ndim = src_ndim.max(reps_ndim);
292
293    // Pad shapes to out_ndim
294    let mut padded_shape = vec![1usize; out_ndim];
295    for i in 0..src_ndim {
296        padded_shape[out_ndim - src_ndim + i] = src_shape[i];
297    }
298    let mut padded_reps = vec![1usize; out_ndim];
299    for i in 0..reps_ndim {
300        padded_reps[out_ndim - reps_ndim + i] = reps[i];
301    }
302
303    let out_shape: Vec<usize> = padded_shape
304        .iter()
305        .zip(padded_reps.iter())
306        .map(|(&s, &r)| s * r)
307        .collect();
308    let total: usize = out_shape.iter().product();
309
310    let src_data: Vec<T> = a.iter().cloned().collect();
311    let mut data = Vec::with_capacity(total);
312
313    // Compute strides for output and padded source
314    let mut out_strides = vec![1usize; out_ndim];
315    for i in (0..out_ndim.saturating_sub(1)).rev() {
316        out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
317    }
318
319    let mut src_strides = vec![1usize; out_ndim];
320    for i in (0..out_ndim.saturating_sub(1)).rev() {
321        src_strides[i] = src_strides[i + 1] * padded_shape[i + 1];
322    }
323
324    for flat in 0..total {
325        let mut rem = flat;
326        let mut src_flat = 0usize;
327        for i in 0..out_ndim {
328            let idx = rem / out_strides[i];
329            rem %= out_strides[i];
330            let src_idx = idx % padded_shape[i];
331            src_flat += src_idx * src_strides[i];
332        }
333        // Map src_flat to the original (non-padded) source
334        // Since we padded with 1s, the strides handle it.
335        if src_flat < src_data.len() {
336            data.push(src_data[src_flat].clone());
337        } else {
338            // This shouldn't happen if the math is right
339            data.push(T::zero());
340        }
341    }
342
343    Array::from_vec(IxDyn::new(&out_shape), data)
344}
345
346/// Repeat elements of an array.
347///
348/// If `axis` is `None`, the array is flattened first, then each element
349/// is repeated `repeats` times.
350///
351/// Analogous to `numpy.repeat()`.
352///
353/// # Errors
354/// Returns `FerrayError::AxisOutOfBounds` if the axis is out of bounds.
355pub fn repeat<T: Element, D: Dimension>(
356    a: &Array<T, D>,
357    repeats: usize,
358    axis: Option<usize>,
359) -> FerrayResult<Array<T, IxDyn>> {
360    match axis {
361        None => {
362            // Flatten and repeat each element
363            let src: Vec<T> = a.iter().cloned().collect();
364            let mut data = Vec::with_capacity(src.len() * repeats);
365            for val in &src {
366                for _ in 0..repeats {
367                    data.push(val.clone());
368                }
369            }
370            let n = data.len();
371            Array::from_vec(IxDyn::new(&[n]), data)
372        }
373        Some(ax) => {
374            let shape = a.shape();
375            let ndim = shape.len();
376            if ax >= ndim {
377                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
378            }
379
380            let mut new_shape = shape.to_vec();
381            new_shape[ax] *= repeats;
382            let total: usize = new_shape.iter().product();
383            let src_data: Vec<T> = a.iter().cloned().collect();
384
385            // Compute source strides (C-order)
386            let mut src_strides = vec![1usize; ndim];
387            for i in (0..ndim.saturating_sub(1)).rev() {
388                src_strides[i] = src_strides[i + 1] * shape[i + 1];
389            }
390
391            // Compute output strides
392            let mut out_strides = vec![1usize; ndim];
393            for i in (0..ndim.saturating_sub(1)).rev() {
394                out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
395            }
396
397            let mut data = Vec::with_capacity(total);
398            for flat in 0..total {
399                let mut rem = flat;
400                let mut src_flat = 0usize;
401                for i in 0..ndim {
402                    let idx = rem / out_strides[i];
403                    rem %= out_strides[i];
404                    let src_idx = if i == ax { idx / repeats } else { idx };
405                    src_flat += src_idx * src_strides[i];
406                }
407                data.push(src_data[src_flat].clone());
408            }
409
410            Array::from_vec(IxDyn::new(&new_shape), data)
411        }
412    }
413}
414
415/// Delete sub-arrays along an axis.
416///
417/// `indices` specifies which indices along `axis` to remove.
418///
419/// Analogous to `numpy.delete()`.
420///
421/// # Errors
422/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
423/// Returns `FerrayError::IndexOutOfBounds` if any index is out of range.
424pub fn delete<T: Element, D: Dimension>(
425    a: &Array<T, D>,
426    indices: &[usize],
427    axis: usize,
428) -> FerrayResult<Array<T, IxDyn>> {
429    let shape = a.shape();
430    let ndim = shape.len();
431    if axis >= ndim {
432        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
433    }
434    let axis_len = shape[axis];
435
436    // Validate indices
437    for &idx in indices {
438        if idx >= axis_len {
439            return Err(FerrayError::IndexOutOfBounds {
440                index: idx as isize,
441                axis,
442                size: axis_len,
443            });
444        }
445    }
446
447    let to_remove: std::collections::HashSet<usize> = indices.iter().copied().collect();
448    let kept: Vec<usize> = (0..axis_len).filter(|i| !to_remove.contains(i)).collect();
449    let new_axis_len = kept.len();
450
451    let mut new_shape = shape.to_vec();
452    new_shape[axis] = new_axis_len;
453    let total: usize = new_shape.iter().product();
454    let src_data: Vec<T> = a.iter().cloned().collect();
455
456    // Compute source strides (C-order)
457    let mut src_strides = vec![1usize; ndim];
458    for i in (0..ndim.saturating_sub(1)).rev() {
459        src_strides[i] = src_strides[i + 1] * shape[i + 1];
460    }
461
462    // Compute output strides
463    let mut out_strides = vec![1usize; ndim];
464    for i in (0..ndim.saturating_sub(1)).rev() {
465        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
466    }
467
468    let mut data = Vec::with_capacity(total);
469    for flat in 0..total {
470        let mut rem = flat;
471        let mut src_flat = 0usize;
472        for i in 0..ndim {
473            let idx = rem / out_strides[i];
474            rem %= out_strides[i];
475            let src_idx = if i == axis { kept[idx] } else { idx };
476            src_flat += src_idx * src_strides[i];
477        }
478        data.push(src_data[src_flat].clone());
479    }
480
481    Array::from_vec(IxDyn::new(&new_shape), data)
482}
483
484/// Insert values along an axis before a given index.
485///
486/// `index` is the position before which to insert. `values` is a 1-D array
487/// of values to insert (its length determines how many slices are added).
488///
489/// Analogous to `numpy.insert()`.
490///
491/// # Errors
492/// Returns `FerrayError::AxisOutOfBounds` if axis is out of bounds.
493/// Returns `FerrayError::IndexOutOfBounds` if `index > axis_len`.
494pub fn insert<T: Element, D: Dimension>(
495    a: &Array<T, D>,
496    index: usize,
497    values: &Array<T, IxDyn>,
498    axis: usize,
499) -> FerrayResult<Array<T, IxDyn>> {
500    let shape = a.shape();
501    let ndim = shape.len();
502    if axis >= ndim {
503        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
504    }
505    let axis_len = shape[axis];
506    if index > axis_len {
507        return Err(FerrayError::IndexOutOfBounds {
508            index: index as isize,
509            axis,
510            size: axis_len + 1,
511        });
512    }
513
514    let n_insert = values.size();
515    let vals: Vec<T> = values.iter().cloned().collect();
516
517    let mut new_shape = shape.to_vec();
518    new_shape[axis] = axis_len + n_insert;
519    let total: usize = new_shape.iter().product();
520    let src_data: Vec<T> = a.iter().cloned().collect();
521
522    // Compute strides
523    let mut src_strides = vec![1usize; ndim];
524    for i in (0..ndim.saturating_sub(1)).rev() {
525        src_strides[i] = src_strides[i + 1] * shape[i + 1];
526    }
527
528    let mut out_strides = vec![1usize; ndim];
529    for i in (0..ndim.saturating_sub(1)).rev() {
530        out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
531    }
532
533    // Compute the "inner" size (product of dims after axis)
534    let inner: usize = shape[axis + 1..].iter().product();
535
536    let mut data = Vec::with_capacity(total);
537    for flat in 0..total {
538        let mut rem = flat;
539        let mut nd_idx = vec![0usize; ndim];
540        for i in 0..ndim {
541            nd_idx[i] = rem / out_strides[i];
542            rem %= out_strides[i];
543        }
544
545        let ax_idx = nd_idx[axis];
546        if ax_idx >= index && ax_idx < index + n_insert {
547            // This is an inserted value
548            let insert_idx = ax_idx - index;
549            // For multi-D insert, we tile the values along the inner dims
550            let val_idx = (insert_idx * inner + nd_idx.get(axis + 1).copied().unwrap_or(0))
551                % vals.len().max(1);
552            data.push(vals[val_idx].clone());
553        } else {
554            // Original data
555            let src_ax_idx = if ax_idx >= index + n_insert {
556                ax_idx - n_insert
557            } else {
558                ax_idx
559            };
560            let mut src_flat = 0usize;
561            for i in 0..ndim {
562                let idx = if i == axis { src_ax_idx } else { nd_idx[i] };
563                src_flat += idx * src_strides[i];
564            }
565            data.push(src_data[src_flat].clone());
566        }
567    }
568
569    Array::from_vec(IxDyn::new(&new_shape), data)
570}
571
572/// Append values to the end of an array along an axis.
573///
574/// If `axis` is `None`, both arrays are flattened first.
575///
576/// Analogous to `numpy.append()`.
577pub fn append<T: Element, D: Dimension>(
578    a: &Array<T, D>,
579    values: &Array<T, IxDyn>,
580    axis: Option<usize>,
581) -> FerrayResult<Array<T, IxDyn>> {
582    match axis {
583        None => {
584            let mut data: Vec<T> = a.iter().cloned().collect();
585            data.extend(values.iter().cloned());
586            let n = data.len();
587            Array::from_vec(IxDyn::new(&[n]), data)
588        }
589        Some(ax) => {
590            let a_dyn = {
591                let data: Vec<T> = a.iter().cloned().collect();
592                Array::from_vec(IxDyn::new(a.shape()), data)?
593            };
594            let vals_dyn = {
595                let data: Vec<T> = values.iter().cloned().collect();
596                Array::from_vec(IxDyn::new(values.shape()), data)?
597            };
598            super::concatenate(&[a_dyn, vals_dyn], ax)
599        }
600    }
601}
602
603/// Resize an array to a new shape.
604///
605/// If the new size is larger, the array is filled by repeating its elements.
606/// If smaller, the array is truncated.
607///
608/// Analogous to `numpy.resize()`.
609pub fn resize<T: Element, D: Dimension>(
610    a: &Array<T, D>,
611    new_shape: &[usize],
612) -> FerrayResult<Array<T, IxDyn>> {
613    let src: Vec<T> = a.iter().cloned().collect();
614    let new_size: usize = new_shape.iter().product();
615
616    if src.is_empty() {
617        // Fill with zeros
618        let data = vec![T::zero(); new_size];
619        return Array::from_vec(IxDyn::new(new_shape), data);
620    }
621
622    let mut data = Vec::with_capacity(new_size);
623    for i in 0..new_size {
624        data.push(src[i % src.len()].clone());
625    }
626    Array::from_vec(IxDyn::new(new_shape), data)
627}
628
629/// Trim leading and/or trailing zeros from a 1-D array.
630///
631/// `trim` can be `"f"` (front), `"b"` (back), or `"fb"` (both, default).
632///
633/// Analogous to `numpy.trim_zeros()`.
634///
635/// # Errors
636/// Returns `FerrayError::InvalidValue` if `trim` contains invalid characters.
637pub fn trim_zeros<T: Element + PartialEq>(
638    a: &Array<T, Ix1>,
639    trim: &str,
640) -> FerrayResult<Array<T, Ix1>> {
641    let data: Vec<T> = a.iter().cloned().collect();
642    let zero = T::zero();
643
644    let trim_front = trim.contains('f');
645    let trim_back = trim.contains('b');
646
647    if !trim.chars().all(|c| c == 'f' || c == 'b') {
648        return Err(FerrayError::invalid_value(
649            "trim_zeros: trim must contain only 'f' and/or 'b'",
650        ));
651    }
652
653    let start = if trim_front {
654        data.iter().position(|v| *v != zero).unwrap_or(data.len())
655    } else {
656        0
657    };
658
659    let end = if trim_back {
660        data.iter()
661            .rposition(|v| *v != zero)
662            .map(|i| i + 1)
663            .unwrap_or(start)
664    } else {
665        data.len()
666    };
667
668    let end = end.max(start);
669    let trimmed: Vec<T> = data[start..end].to_vec();
670    let n = trimmed.len();
671    Array::from_vec(Ix1::new([n]), trimmed)
672}
673
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamic-rank `f64` array from a shape and C-order data.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    /// Build a 1-D `f64` array from its data.
    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // -- pad --

    #[test]
    fn test_pad_1d_constant() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 3), &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[8]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_pad_1d_edge() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Edge).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_pad_1d_wrap() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Wrap).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_reflect_after() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (0, 2), &PadMode::Reflect).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_pad_1d_symmetric() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Symmetric).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_empty_array() {
        let a = arr1d(vec![]);
        assert!(pad_1d(&a, (1, 1), &PadMode::Edge).is_err());
        let b = pad_1d(&a, (1, 1), &PadMode::Constant(7.0)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![7.0, 7.0]);
    }

    #[test]
    fn test_pad_nd_constant() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = pad(&a, &[(1, 1), (1, 1)], &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0
            ]
        );
    }

    #[test]
    fn test_pad_nd_edge_repeats_last_pair() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        // A single pair is reused for every axis.
        let b = pad(&a, &[(1, 1)], &PadMode::Edge).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0
            ]
        );
    }

    #[test]
    fn test_pad_empty_pad_width() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        assert!(pad(&a, &[], &PadMode::Constant(0.0)).is_err());
    }

    // -- tile --

    #[test]
    fn test_tile_1d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = tile(&a, &[3]).unwrap();
        assert_eq!(b.shape(), &[9]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tile_2d() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        let b = tile(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 6]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
        );
    }

    #[test]
    fn test_tile_empty_reps() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        assert!(tile(&a, &[]).is_err());
    }

    // -- repeat --

    #[test]
    fn test_repeat_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = repeat(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_repeat_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = repeat(&a, 2, Some(0)).unwrap();
        assert_eq!(b.shape(), &[4, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
    }

    #[test]
    fn test_repeat_last_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = repeat(&a, 2, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
    }

    #[test]
    fn test_repeat_bad_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(repeat(&a, 2, Some(2)).is_err());
    }

    // -- delete --

    #[test]
    fn test_delete() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = delete(&a, &[1, 3], 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_delete_2d() {
        let a = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = delete(&a, &[1], 0).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn test_delete_out_of_range() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(delete(&a, &[5], 0).is_err());
        assert!(delete(&a, &[0], 1).is_err());
    }

    // -- insert --

    #[test]
    fn test_insert() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![10.0, 20.0]);
        let b = insert(&a, 1, &vals, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 10.0, 20.0, 2.0, 3.0]);
    }

    #[test]
    fn test_insert_2d_last_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let vals = dyn_arr(&[2], vec![10.0, 20.0]);
        let b = insert(&a, 1, &vals, 1).unwrap();
        assert_eq!(b.shape(), &[2, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 10.0, 20.0, 2.0, 3.0, 10.0, 20.0, 4.0]);
    }

    #[test]
    fn test_insert_index_too_large() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[1], vec![9.0]);
        assert!(insert(&a, 4, &vals, 0).is_err());
    }

    // -- append --

    #[test]
    fn test_append_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![4.0, 5.0]);
        let b = append(&a, &vals, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_append_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let vals = dyn_arr(&[2, 1], vec![5.0, 6.0]);
        let b = append(&a, &vals, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    // -- resize --

    #[test]
    fn test_resize_larger() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = resize(&a, &[5]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_resize_smaller() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = resize(&a, &[3]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_resize_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = resize(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
    }

    #[test]
    fn test_resize_empty_source_fills_zeros() {
        let a = dyn_arr(&[0], vec![]);
        let b = resize(&a, &[2, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0; 4]);
    }

    // -- trim_zeros --

    #[test]
    fn test_trim_zeros_both() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_trim_zeros_front() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 0.0]);
        let b = trim_zeros(&a, "f").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 0.0]);
    }

    #[test]
    fn test_trim_zeros_back() {
        let a = arr1d(vec![0.0, 1.0, 2.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "b").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_trim_zeros_all_zeros() {
        let a = arr1d(vec![0.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        assert_eq!(b.shape(), &[0]);
    }

    #[test]
    fn test_trim_zeros_empty_spec_is_noop() {
        let a = arr1d(vec![0.0, 1.0, 0.0]);
        let b = trim_zeros(&a, "").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_trim_zeros_bad_mode() {
        let a = arr1d(vec![1.0, 2.0]);
        assert!(trim_zeros(&a, "x").is_err());
    }
}