1use crate::array::owned::Array;
6use crate::dimension::{Dimension, Ix1, IxDyn};
7use crate::dtype::Element;
8use crate::error::{FerrayError, FerrayResult};
9
/// How [`pad`] and [`pad_1d`] fill the elements they add around the original data.
#[derive(Debug, Clone)]
pub enum PadMode<T: Element> {
    /// Fill every padded position with the given value.
    Constant(T),
    /// Repeat the first element before the data and the last element after it.
    Edge,
    /// Mirror the data about its first/last element without repeating that element
    /// (index mapping done by `reflect_index`). Needs a non-empty source axis.
    Reflect,
    /// Mirror the data about its boundary, repeating the edge element
    /// (index mapping done by `symmetric_index`). Needs a non-empty source axis.
    Symmetric,
    /// Treat the data as periodic: positions past the end wrap around to the start and
    /// positions before the start wrap around to the end. Needs a non-empty source axis.
    Wrap,
}
29
30pub fn pad_1d<T: Element>(
40 a: &Array<T, Ix1>,
41 pad_width: (usize, usize),
42 mode: &PadMode<T>,
43) -> FerrayResult<Array<T, Ix1>> {
44 let n = a.shape()[0];
45 let (before, after) = pad_width;
46 let new_len = before + n + after;
47 let src: Vec<T> = a.iter().cloned().collect();
48
49 if n == 0 && !matches!(mode, PadMode::Constant(_)) {
50 return Err(FerrayError::invalid_value(
51 "pad: cannot use Edge/Reflect/Symmetric/Wrap mode on empty array",
52 ));
53 }
54
55 let mut data = Vec::with_capacity(new_len);
56
57 for i in 0..before {
59 let val = match mode {
60 PadMode::Constant(c) => c.clone(),
61 PadMode::Edge => src[0].clone(),
62 PadMode::Reflect => {
63 let idx = reflect_index(before as isize - 1 - i as isize, n);
65 src[idx].clone()
66 }
67 PadMode::Symmetric => {
68 let idx = symmetric_index(before as isize - 1 - i as isize, n);
70 src[idx].clone()
71 }
72 PadMode::Wrap => {
73 let idx = ((n as isize - (before as isize - i as isize) % n as isize) % n as isize)
74 as usize;
75 src[idx].clone()
76 }
77 };
78 data.push(val);
79 }
80
81 data.extend_from_slice(&src);
83
84 for i in 0..after {
86 let val = match mode {
87 PadMode::Constant(c) => c.clone(),
88 PadMode::Edge => src[n - 1].clone(),
89 PadMode::Reflect => {
90 let idx = reflect_index(n as isize + i as isize, n);
91 src[idx].clone()
92 }
93 PadMode::Symmetric => {
94 let idx = symmetric_index(n as isize + i as isize, n);
95 src[idx].clone()
96 }
97 PadMode::Wrap => {
98 let idx = i % n;
99 src[idx].clone()
100 }
101 };
102 data.push(val);
103 }
104
105 Array::from_vec(Ix1::new([new_len]), data)
106}
107
/// Fold an arbitrary (possibly negative or out-of-range) position onto `0..n` by
/// mirroring about the first and last elements without repeating them, so for `n = 3`
/// the positions `0, 1, 2, 3, 4, 5` map to `0, 1, 2, 1, 0, 1`.
///
/// Axes of length 0 or 1 always map to 0.
fn reflect_index(idx: isize, n: usize) -> usize {
    if n < 2 {
        return 0;
    }
    let last = (n - 1) as isize;
    // One full bounce (forward then back) spans 2 * (n - 1) positions.
    let period = 2 * last;
    let folded = idx.rem_euclid(period);
    let mirrored = if folded > last { period - folded } else { folded };
    mirrored as usize
}
124
/// Fold an arbitrary (possibly negative or out-of-range) position onto `0..n` by
/// mirroring about the boundaries *including* the edge element, so for `n = 3` the
/// positions `0, 1, 2, 3, 4, 5, 6` map to `0, 1, 2, 2, 1, 0, 0`.
///
/// Axes of length 0 or 1 always map to 0.
fn symmetric_index(idx: isize, n: usize) -> usize {
    if n < 2 {
        return 0;
    }
    let len = n as isize;
    // Forward pass plus mirrored pass: 2 * n positions per period.
    let folded = idx.rem_euclid(2 * len);
    let mirrored = if folded < len { folded } else { 2 * len - 1 - folded };
    mirrored as usize
}
144
145pub fn pad<T: Element, D: Dimension>(
155 a: &Array<T, D>,
156 pad_width: &[(usize, usize)],
157 mode: &PadMode<T>,
158) -> FerrayResult<Array<T, IxDyn>> {
159 if pad_width.is_empty() {
160 return Err(FerrayError::invalid_value("pad: pad_width cannot be empty"));
161 }
162
163 let shape = a.shape();
164 let ndim = shape.len();
165
166 let pads: Vec<(usize, usize)> = (0..ndim)
168 .map(|i| {
169 if i < pad_width.len() {
170 pad_width[i]
171 } else {
172 *pad_width.last().unwrap_or_else(|| unreachable!())
174 }
175 })
176 .collect();
177
178 let mut current_data: Vec<T> = a.iter().cloned().collect();
181 let mut current_shape: Vec<usize> = shape.to_vec();
182
183 for ax in (0..ndim).rev() {
184 let (before, after) = pads[ax];
185 if before == 0 && after == 0 {
186 continue;
187 }
188 let axis_len = current_shape[ax];
189 let new_axis_len = before + axis_len + after;
190
191 let outer: usize = current_shape[..ax].iter().product();
193 let inner: usize = current_shape[ax + 1..].iter().product();
194
195 let new_total = outer * new_axis_len * inner;
196 let mut new_data = Vec::with_capacity(new_total);
197
198 for o in 0..outer {
199 for j in 0..new_axis_len {
200 for k in 0..inner {
201 let val = if j < before {
202 match mode {
204 PadMode::Constant(c) => c.clone(),
205 PadMode::Edge => {
206 let src_j = 0;
207 current_data[o * axis_len * inner + src_j * inner + k].clone()
208 }
209 PadMode::Reflect => {
210 let src_j =
211 reflect_index(before as isize - 1 - j as isize, axis_len);
212 current_data[o * axis_len * inner + src_j * inner + k].clone()
213 }
214 PadMode::Symmetric => {
215 let src_j =
216 symmetric_index(before as isize - 1 - j as isize, axis_len);
217 current_data[o * axis_len * inner + src_j * inner + k].clone()
218 }
219 PadMode::Wrap => {
220 let src_j = ((axis_len as isize
221 - (before as isize - j as isize) % axis_len as isize)
222 % axis_len as isize)
223 as usize;
224 current_data[o * axis_len * inner + src_j * inner + k].clone()
225 }
226 }
227 } else if j < before + axis_len {
228 let src_j = j - before;
230 current_data[o * axis_len * inner + src_j * inner + k].clone()
231 } else {
232 let after_idx = j - before - axis_len;
234 match mode {
235 PadMode::Constant(c) => c.clone(),
236 PadMode::Edge => {
237 let src_j = axis_len - 1;
238 current_data[o * axis_len * inner + src_j * inner + k].clone()
239 }
240 PadMode::Reflect => {
241 let src_j = reflect_index(
242 (axis_len as isize) + after_idx as isize,
243 axis_len,
244 );
245 current_data[o * axis_len * inner + src_j * inner + k].clone()
246 }
247 PadMode::Symmetric => {
248 let src_j = symmetric_index(
249 (axis_len as isize) + after_idx as isize,
250 axis_len,
251 );
252 current_data[o * axis_len * inner + src_j * inner + k].clone()
253 }
254 PadMode::Wrap => {
255 let src_j = after_idx % axis_len;
256 current_data[o * axis_len * inner + src_j * inner + k].clone()
257 }
258 }
259 };
260 new_data.push(val);
261 }
262 }
263 }
264
265 current_data = new_data;
266 current_shape[ax] = new_axis_len;
267 }
268
269 Array::from_vec(IxDyn::new(¤t_shape), current_data)
270}
271
272pub fn tile<T: Element, D: Dimension>(
282 a: &Array<T, D>,
283 reps: &[usize],
284) -> FerrayResult<Array<T, IxDyn>> {
285 if reps.is_empty() {
286 return Err(FerrayError::invalid_value("tile: reps cannot be empty"));
287 }
288
289 let src_shape = a.shape();
290 let src_ndim = src_shape.len();
291 let reps_ndim = reps.len();
292 let out_ndim = src_ndim.max(reps_ndim);
293
294 let mut padded_shape = vec![1usize; out_ndim];
296 for i in 0..src_ndim {
297 padded_shape[out_ndim - src_ndim + i] = src_shape[i];
298 }
299 let mut padded_reps = vec![1usize; out_ndim];
300 for i in 0..reps_ndim {
301 padded_reps[out_ndim - reps_ndim + i] = reps[i];
302 }
303
304 let out_shape: Vec<usize> = padded_shape
305 .iter()
306 .zip(padded_reps.iter())
307 .map(|(&s, &r)| s * r)
308 .collect();
309 let total: usize = out_shape.iter().product();
310
311 let src_data: Vec<T> = a.iter().cloned().collect();
312 let mut data = Vec::with_capacity(total);
313
314 let mut out_strides = vec![1usize; out_ndim];
316 for i in (0..out_ndim.saturating_sub(1)).rev() {
317 out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
318 }
319
320 let mut src_strides = vec![1usize; out_ndim];
321 for i in (0..out_ndim.saturating_sub(1)).rev() {
322 src_strides[i] = src_strides[i + 1] * padded_shape[i + 1];
323 }
324
325 for flat in 0..total {
326 let mut rem = flat;
327 let mut src_flat = 0usize;
328 for i in 0..out_ndim {
329 let idx = rem / out_strides[i];
330 rem %= out_strides[i];
331 let src_idx = idx % padded_shape[i];
332 src_flat += src_idx * src_strides[i];
333 }
334 if src_flat < src_data.len() {
337 data.push(src_data[src_flat].clone());
338 } else {
339 data.push(T::zero());
341 }
342 }
343
344 Array::from_vec(IxDyn::new(&out_shape), data)
345}
346
347pub fn repeat<T: Element, D: Dimension>(
357 a: &Array<T, D>,
358 repeats: usize,
359 axis: Option<usize>,
360) -> FerrayResult<Array<T, IxDyn>> {
361 match axis {
362 None => {
363 let src: Vec<T> = a.iter().cloned().collect();
365 let mut data = Vec::with_capacity(src.len() * repeats);
366 for val in &src {
367 for _ in 0..repeats {
368 data.push(val.clone());
369 }
370 }
371 let n = data.len();
372 Array::from_vec(IxDyn::new(&[n]), data)
373 }
374 Some(ax) => {
375 let shape = a.shape();
376 let ndim = shape.len();
377 if ax >= ndim {
378 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
379 }
380
381 let mut new_shape = shape.to_vec();
382 new_shape[ax] *= repeats;
383 let total: usize = new_shape.iter().product();
384 let src_data: Vec<T> = a.iter().cloned().collect();
385
386 let mut src_strides = vec![1usize; ndim];
388 for i in (0..ndim.saturating_sub(1)).rev() {
389 src_strides[i] = src_strides[i + 1] * shape[i + 1];
390 }
391
392 let mut out_strides = vec![1usize; ndim];
394 for i in (0..ndim.saturating_sub(1)).rev() {
395 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
396 }
397
398 let mut data = Vec::with_capacity(total);
399 for flat in 0..total {
400 let mut rem = flat;
401 let mut src_flat = 0usize;
402 for i in 0..ndim {
403 let idx = rem / out_strides[i];
404 rem %= out_strides[i];
405 let src_idx = if i == ax { idx / repeats } else { idx };
406 src_flat += src_idx * src_strides[i];
407 }
408 data.push(src_data[src_flat].clone());
409 }
410
411 Array::from_vec(IxDyn::new(&new_shape), data)
412 }
413 }
414}
415
416pub fn delete<T: Element, D: Dimension>(
426 a: &Array<T, D>,
427 indices: &[usize],
428 axis: usize,
429) -> FerrayResult<Array<T, IxDyn>> {
430 let shape = a.shape();
431 let ndim = shape.len();
432 if axis >= ndim {
433 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
434 }
435 let axis_len = shape[axis];
436
437 for &idx in indices {
439 if idx >= axis_len {
440 return Err(FerrayError::IndexOutOfBounds {
441 index: idx as isize,
442 axis,
443 size: axis_len,
444 });
445 }
446 }
447
448 let to_remove: std::collections::HashSet<usize> = indices.iter().copied().collect();
449 let kept: Vec<usize> = (0..axis_len).filter(|i| !to_remove.contains(i)).collect();
450 let new_axis_len = kept.len();
451
452 let mut new_shape = shape.to_vec();
453 new_shape[axis] = new_axis_len;
454 let total: usize = new_shape.iter().product();
455 let src_data: Vec<T> = a.iter().cloned().collect();
456
457 let mut src_strides = vec![1usize; ndim];
459 for i in (0..ndim.saturating_sub(1)).rev() {
460 src_strides[i] = src_strides[i + 1] * shape[i + 1];
461 }
462
463 let mut out_strides = vec![1usize; ndim];
465 for i in (0..ndim.saturating_sub(1)).rev() {
466 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
467 }
468
469 let mut data = Vec::with_capacity(total);
470 for flat in 0..total {
471 let mut rem = flat;
472 let mut src_flat = 0usize;
473 for i in 0..ndim {
474 let idx = rem / out_strides[i];
475 rem %= out_strides[i];
476 let src_idx = if i == axis { kept[idx] } else { idx };
477 src_flat += src_idx * src_strides[i];
478 }
479 data.push(src_data[src_flat].clone());
480 }
481
482 Array::from_vec(IxDyn::new(&new_shape), data)
483}
484
485pub fn insert<T: Element, D: Dimension>(
496 a: &Array<T, D>,
497 index: usize,
498 values: &Array<T, IxDyn>,
499 axis: usize,
500) -> FerrayResult<Array<T, IxDyn>> {
501 let shape = a.shape();
502 let ndim = shape.len();
503 if axis >= ndim {
504 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
505 }
506 let axis_len = shape[axis];
507 if index > axis_len {
508 return Err(FerrayError::IndexOutOfBounds {
509 index: index as isize,
510 axis,
511 size: axis_len + 1,
512 });
513 }
514
515 let n_insert = values.size();
516 let vals: Vec<T> = values.iter().cloned().collect();
517
518 let mut new_shape = shape.to_vec();
519 new_shape[axis] = axis_len + n_insert;
520 let total: usize = new_shape.iter().product();
521 let src_data: Vec<T> = a.iter().cloned().collect();
522
523 let mut src_strides = vec![1usize; ndim];
525 for i in (0..ndim.saturating_sub(1)).rev() {
526 src_strides[i] = src_strides[i + 1] * shape[i + 1];
527 }
528
529 let mut out_strides = vec![1usize; ndim];
530 for i in (0..ndim.saturating_sub(1)).rev() {
531 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
532 }
533
534 let inner: usize = shape[axis + 1..].iter().product();
536
537 let mut data = Vec::with_capacity(total);
538 for flat in 0..total {
539 let mut rem = flat;
540 let mut nd_idx = vec![0usize; ndim];
541 for i in 0..ndim {
542 nd_idx[i] = rem / out_strides[i];
543 rem %= out_strides[i];
544 }
545
546 let ax_idx = nd_idx[axis];
547 if ax_idx >= index && ax_idx < index + n_insert {
548 let insert_idx = ax_idx - index;
550 let val_idx = (insert_idx * inner + nd_idx.get(axis + 1).copied().unwrap_or(0))
552 % vals.len().max(1);
553 data.push(vals[val_idx].clone());
554 } else {
555 let src_ax_idx = if ax_idx >= index + n_insert {
557 ax_idx - n_insert
558 } else {
559 ax_idx
560 };
561 let mut src_flat = 0usize;
562 for i in 0..ndim {
563 let idx = if i == axis { src_ax_idx } else { nd_idx[i] };
564 src_flat += idx * src_strides[i];
565 }
566 data.push(src_data[src_flat].clone());
567 }
568 }
569
570 Array::from_vec(IxDyn::new(&new_shape), data)
571}
572
573pub fn append<T: Element, D: Dimension>(
579 a: &Array<T, D>,
580 values: &Array<T, IxDyn>,
581 axis: Option<usize>,
582) -> FerrayResult<Array<T, IxDyn>> {
583 match axis {
584 None => {
585 let mut data: Vec<T> = a.iter().cloned().collect();
586 data.extend(values.iter().cloned());
587 let n = data.len();
588 Array::from_vec(IxDyn::new(&[n]), data)
589 }
590 Some(ax) => {
591 let a_dyn = {
592 let data: Vec<T> = a.iter().cloned().collect();
593 Array::from_vec(IxDyn::new(a.shape()), data)?
594 };
595 let vals_dyn = {
596 let data: Vec<T> = values.iter().cloned().collect();
597 Array::from_vec(IxDyn::new(values.shape()), data)?
598 };
599 super::concatenate(&[a_dyn, vals_dyn], ax)
600 }
601 }
602}
603
604pub fn resize<T: Element, D: Dimension>(
611 a: &Array<T, D>,
612 new_shape: &[usize],
613) -> FerrayResult<Array<T, IxDyn>> {
614 let src: Vec<T> = a.iter().cloned().collect();
615 let new_size: usize = new_shape.iter().product();
616
617 if src.is_empty() {
618 let data = vec![T::zero(); new_size];
620 return Array::from_vec(IxDyn::new(new_shape), data);
621 }
622
623 let mut data = Vec::with_capacity(new_size);
624 for i in 0..new_size {
625 data.push(src[i % src.len()].clone());
626 }
627 Array::from_vec(IxDyn::new(new_shape), data)
628}
629
630pub fn trim_zeros<T: Element + PartialEq>(
639 a: &Array<T, Ix1>,
640 trim: &str,
641) -> FerrayResult<Array<T, Ix1>> {
642 let data: Vec<T> = a.iter().cloned().collect();
643 let zero = T::zero();
644
645 let trim_front = trim.contains('f');
646 let trim_back = trim.contains('b');
647
648 if !trim.chars().all(|c| c == 'f' || c == 'b') {
649 return Err(FerrayError::invalid_value(
650 "trim_zeros: trim must contain only 'f' and/or 'b'",
651 ));
652 }
653
654 let start = if trim_front {
655 data.iter().position(|v| *v != zero).unwrap_or(data.len())
656 } else {
657 0
658 };
659
660 let end = if trim_back {
661 data.iter()
662 .rposition(|v| *v != zero)
663 .map(|i| i + 1)
664 .unwrap_or(start)
665 } else {
666 data.len()
667 };
668
669 let end = end.max(start);
670 let trimmed: Vec<T> = data[start..end].to_vec();
671 let n = trimmed.len();
672 Array::from_vec(Ix1::new([n]), trimmed)
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamic-rank `f64` array; panics on a shape/data mismatch.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    /// Build a 1-D `f64` array from its elements.
    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // ---- pad_1d / pad ----

    #[test]
    fn test_pad_1d_constant() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 3), &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[8]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_pad_1d_edge() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Edge).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_pad_1d_wrap() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Wrap).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_symmetric() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let b = pad_1d(&a, (2, 2), &PadMode::Symmetric).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0]);
    }

    #[test]
    fn test_pad_1d_empty_constant() {
        let a = arr1d(vec![]);
        let b = pad_1d(&a, (1, 2), &PadMode::Constant(7.0)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![7.0, 7.0, 7.0]);
    }

    #[test]
    fn test_pad_1d_empty_non_constant_errors() {
        let a = arr1d(vec![]);
        assert!(pad_1d(&a, (1, 1), &PadMode::Edge).is_err());
    }

    #[test]
    fn test_pad_nd_constant() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = pad(&a, &[(1, 1), (1, 1)], &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0
            ]
        );
    }

    #[test]
    fn test_pad_nd_edge() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = pad(&a, &[(1, 1), (1, 1)], &PadMode::Edge).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0
            ]
        );
    }

    #[test]
    fn test_pad_reuses_last_width_for_remaining_axes() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = pad(&a, &[(1, 1)], &PadMode::Constant(0.0)).unwrap();
        assert_eq!(b.shape(), &[4, 4]);
    }

    #[test]
    fn test_pad_empty_width_errors() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        assert!(pad(&a, &[], &PadMode::Constant(0.0)).is_err());
    }

    // ---- tile ----

    #[test]
    fn test_tile_1d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = tile(&a, &[3]).unwrap();
        assert_eq!(b.shape(), &[9]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tile_2d() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        let b = tile(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 6]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
        );
    }

    #[test]
    fn test_tile_empty_reps_errors() {
        let a = dyn_arr(&[2], vec![1.0, 2.0]);
        assert!(tile(&a, &[]).is_err());
    }

    // ---- repeat ----

    #[test]
    fn test_repeat_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = repeat(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_repeat_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = repeat(&a, 2, Some(0)).unwrap();
        assert_eq!(b.shape(), &[4, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
    }

    #[test]
    fn test_repeat_last_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = repeat(&a, 2, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 4]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
    }

    #[test]
    fn test_repeat_axis_out_of_bounds() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(repeat(&a, 2, Some(1)).is_err());
    }

    // ---- delete ----

    #[test]
    fn test_delete() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = delete(&a, &[1, 3], 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_delete_2d() {
        let a = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = delete(&a, &[1], 0).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn test_delete_last_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = delete(&a, &[0, 2], 1).unwrap();
        assert_eq!(b.shape(), &[2, 1]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 5.0]);
    }

    #[test]
    fn test_delete_index_out_of_bounds() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(delete(&a, &[5], 0).is_err());
    }

    // ---- insert ----

    #[test]
    fn test_insert() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![10.0, 20.0]);
        let b = insert(&a, 1, &vals, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 10.0, 20.0, 2.0, 3.0]);
    }

    #[test]
    fn test_insert_at_end() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[1], vec![4.0]);
        let b = insert(&a, 3, &vals, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_insert_last_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let vals = dyn_arr(&[1], vec![9.0]);
        let b = insert(&a, 1, &vals, 1).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 9.0, 2.0, 3.0, 9.0, 4.0]);
    }

    #[test]
    fn test_insert_index_out_of_bounds() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[1], vec![4.0]);
        assert!(insert(&a, 4, &vals, 0).is_err());
    }

    // ---- append ----

    #[test]
    fn test_append_flat() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let vals = dyn_arr(&[2], vec![4.0, 5.0]);
        let b = append(&a, &vals, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_append_axis() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let vals = dyn_arr(&[2, 1], vec![5.0, 6.0]);
        let b = append(&a, &vals, Some(1)).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    // ---- resize ----

    #[test]
    fn test_resize_larger() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = resize(&a, &[5]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_resize_smaller() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = resize(&a, &[3]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_resize_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = resize(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
    }

    #[test]
    fn test_resize_empty_source_zero_fills() {
        let a = dyn_arr(&[0], vec![]);
        let b = resize(&a, &[3]).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 0.0]);
    }

    // ---- trim_zeros ----

    #[test]
    fn test_trim_zeros_both() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_trim_zeros_front() {
        let a = arr1d(vec![0.0, 0.0, 1.0, 2.0, 0.0]);
        let b = trim_zeros(&a, "f").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 0.0]);
    }

    #[test]
    fn test_trim_zeros_back() {
        let a = arr1d(vec![0.0, 1.0, 2.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "b").unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_trim_zeros_all_zeros() {
        let a = arr1d(vec![0.0, 0.0, 0.0]);
        let b = trim_zeros(&a, "fb").unwrap();
        assert_eq!(b.shape(), &[0]);
    }

    #[test]
    fn test_trim_zeros_bad_mode() {
        let a = arr1d(vec![1.0, 2.0]);
        assert!(trim_zeros(&a, "x").is_err());
    }
}