1use alloc::vec;
10#[cfg(not(feature = "std"))]
11use alloc::vec::Vec;
12use meshgrid_impl::Meshgrid;
13#[allow(unused_imports)]
14use std::compile_error;
15use std::mem::{forget, size_of};
16use std::ptr::NonNull;
17
18use crate::{dimension, ArcArray1, ArcArray2};
19use crate::{imp_prelude::*, ArrayPartsSized};
20
/// Create an owned array of one to six dimensions from nested, bracketed
/// element lists.
///
/// Each level of square brackets adds one axis, outermost first; the innermost
/// brackets hold the element expressions. Trailing commas are accepted at every
/// level.
///
/// * One level (e.g. `array![1, 2, 3]`) builds an `Array` from a `Vec` of the
///   elements.
/// * Two to six levels build an `Array2` … `Array6` from a `Vec` of nested
///   fixed-size Rust arrays. All lists at the same level must therefore have the
///   same length; otherwise the nested array types differ and compilation fails.
/// * Seven or more levels are rejected with a `compile_error!`.
///
/// The expansion calls `vec!` unqualified, so `vec!` must be in scope at the
/// call site (for example `use alloc::vec;` in a `no_std` crate).
#[macro_export]
macro_rules! array {
    // Arms are tried top to bottom, deepest nesting first. Because `$x:expr`
    // also matches a bracketed Rust array, a shallower arm would otherwise
    // accept deeper input and build an ndarray *of* Rust arrays. This first arm
    // catches 7+ levels (or such arrays-of-arrays) and reports an error instead.
    ($([$([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
        compile_error!("Arrays of 7 dimensions or more (or ndarrays of Rust arrays) cannot be constructed with the array! macro.");
    }};
    // 6 levels -> Array6 from Vec<[[[[[A; _]; _]; _]; _]; _]>.
    ($([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
        $crate::Array6::from(vec![$([$([$([$([$([$($x,)*],)*],)*],)*],)*],)*])
    }};
    // 5 levels -> Array5.
    ($([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
        $crate::Array5::from(vec![$([$([$([$([$($x,)*],)*],)*],)*],)*])
    }};
    // 4 levels -> Array4.
    ($([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
        $crate::Array4::from(vec![$([$([$([$($x,)*],)*],)*],)*])
    }};
    // 3 levels -> Array3.
    ($([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{
        $crate::Array3::from(vec![$([$([$($x,)*],)*],)*])
    }};
    // 2 levels -> Array2 from Vec<[A; N]>.
    ($([$($x:expr),* $(,)*]),+ $(,)*) => {{
        $crate::Array2::from(vec![$([$($x,)*],)*])
    }};
    // Flat list (possibly empty) -> 1-D Array from Vec<A>.
    ($($x:expr),* $(,)*) => {{
        $crate::Array::from(vec![$($x,)*])
    }};
}
88
89pub fn arr0<A>(x: A) -> Array0<A>
91{
92 unsafe { ArrayBase::from_shape_vec_unchecked((), vec![x]) }
93}
94
95pub fn arr1<A: Clone>(xs: &[A]) -> Array1<A>
97{
98 ArrayBase::from(xs.to_vec())
99}
100
101pub fn rcarr1<A: Clone>(xs: &[A]) -> ArcArray1<A>
103{
104 arr1(xs).into_shared()
105}
106
107pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A>
109{
110 ArrayBase {
111 data: ViewRepr::new(),
112 parts: ArrayPartsSized::new(
113 unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
115 Ix0(),
116 Ix0(),
117 ),
118 }
119}
120
121pub const fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A>
143{
144 if size_of::<A>() == 0 {
145 assert!(
146 xs.len() <= isize::MAX as usize,
147 "Slice length must fit in `isize`.",
148 );
149 }
150 ArrayBase {
151 data: ViewRepr::new(),
152 parts: ArrayPartsSized::new(
153 unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
155 Ix1(xs.len()),
156 Ix1(1),
157 ),
158 }
159}
160
161pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A>
180{
181 let cols = N;
182 let rows = xs.len();
183 if size_of::<A>() == 0 {
184 if let Some(n_elems) = rows.checked_mul(cols) {
185 assert!(
186 rows <= isize::MAX as usize
187 && cols <= isize::MAX as usize
188 && n_elems <= isize::MAX as usize,
189 "Product of non-zero axis lengths must not overflow isize.",
190 );
191 } else {
192 panic!("Overflow in number of elements.");
193 }
194 } else if N == 0 {
195 assert!(
196 rows <= isize::MAX as usize,
197 "Product of non-zero axis lengths must not overflow isize.",
198 );
199 }
200 let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) };
202 let dim = Ix2(rows, cols);
203 let strides = if rows == 0 || cols == 0 {
204 Ix2(0, 0)
205 } else {
206 Ix2(cols, 1)
207 };
208 ArrayBase {
209 data: ViewRepr::new(),
210 parts: ArrayPartsSized::new(ptr, dim, strides),
211 }
212}
213
214pub fn aview_mut1<A>(xs: &mut [A]) -> ArrayViewMut1<'_, A>
227{
228 ArrayViewMut::from(xs)
229}
230
231pub fn aview_mut2<A, const N: usize>(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A>
254{
255 ArrayViewMut2::from(xs)
256}
257
258pub fn arr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> Array2<A>
270{
271 Array2::from(xs.to_vec())
272}
273
// Implements `From<Vec<nested Rust array>>` for owned arrays of dimension 2
// through 6, reusing the vector's heap buffer in place (no element copies).
//
// Arguments: `$arr_type` is the element type of the input `Vec` (e.g.
// `[[A; M]; N]`), `$ix_type` is the output dimension type/constructor, and the
// `$n` idents are the const lengths of the nested arrays, outermost first. The
// resulting shape is `(xs.len(), $n...)`.
//
// Panics if the product of the axis lengths overflows `isize`.
macro_rules! impl_from_nested_vec {
    ($arr_type:ty, $ix_type:tt, $($n:ident),+) => {
        impl<A, $(const $n: usize),+> From<Vec<$arr_type>> for Array<A, $ix_type>
        {
            fn from(mut xs: Vec<$arr_type>) -> Self
            {
                let dim = $ix_type(xs.len(), $($n),+);
                let ptr = xs.as_mut_ptr();
                let cap = xs.capacity();
                // Checked before `forget`, so a panic here still drops `xs`
                // normally instead of leaking it.
                let expand_len = dimension::size_of_shape_checked(&dim)
                    .expect("Product of non-zero axis lengths must not overflow isize.");
                // Give up `xs`'s ownership of the buffer; `v` below takes it over.
                forget(xs);
                unsafe {
                    let v = if size_of::<A>() == 0 {
                        // Zero-sized `A`: the outer element type is zero-sized
                        // too, so no memory backs the buffer; length and
                        // capacity are both set to the element count.
                        Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
                    } else if $($n == 0 ||)+ false {
                        // Some inner length is zero, so `$arr_type` is
                        // zero-sized, `xs` never allocated, and there are no
                        // elements: an empty `Vec` is the correct result.
                        Vec::new()
                    } else {
                        // SAFETY: a `$arr_type` is laid out as (product of
                        // `$n`) contiguous `A`s with `A`'s alignment. A
                        // buffer allocated for `cap` outer elements is
                        // therefore exactly a buffer of `cap * product($n)`
                        // `A`s, of which the first `expand_len` are
                        // initialized.
                        let expand_cap = cap $(* $n)+;
                        Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
                    };
                    // SAFETY: `v.len() == expand_len`, the element count of `dim`.
                    ArrayBase::from_shape_vec_unchecked(dim, v)
                }
            }
        }
    };
}

// Instantiate for 2-D through 6-D inputs. Higher dimensions are not provided,
// matching the 6-dimension limit of `array!`.
impl_from_nested_vec!([A; N], Ix2, N);
impl_from_nested_vec!([[A; M]; N], Ix3, N, M);
impl_from_nested_vec!([[[A; L]; M]; N], Ix4, N, M, L);
impl_from_nested_vec!([[[[A; K]; L]; M]; N], Ix5, N, M, L, K);
impl_from_nested_vec!([[[[[A; J]; K]; L]; M]; N], Ix6, N, M, L, K, J);
307
308pub fn rcarr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> ArcArray2<A>
311{
312 arr2(xs).into_shared()
313}
314
315pub fn arr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Array3<A>
333{
334 Array3::from(xs.to_vec())
335}
336
337pub fn rcarr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> ArcArray<A, Ix3>
339{
340 arr3(xs).into_shared()
341}
342
/// Axis ordering used by [`meshgrid`].
///
/// Every output of `meshgrid` has one axis per input array. This enum chooses
/// which output axis each input is laid out along.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeshIndex
{
    /// Cartesian ordering. Like `IJ`, but the first two axes are swapped: the
    /// first input varies along axis 1 (columns) and the second input varies
    /// along axis 0 (rows). With two inputs of lengths `m` and `n`, the outputs
    /// have shape `(n, m)`.
    XY,
    /// Matrix ordering. The `k`-th input varies along axis `k`, so the outputs
    /// have shape `(len0, len1, ...)` in input order.
    IJ,
}
360
/// Implementation details of `meshgrid`: the `Meshgrid` trait and its impls for
/// tuples of two to six 1-D arrays.
mod meshgrid_impl
{
    use super::MeshIndex;
    use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
    use crate::{
        ArrayBase,
        ArrayRef1,
        ArrayView,
        ArrayView2,
        ArrayView3,
        ArrayView4,
        ArrayView5,
        ArrayView6,
        Axis,
        Data,
        Dim,
        IntoDimension,
        Ix1,
        LayoutRef1,
    };

    /// Build the `N`-dimensional strides for the output view of `arr`, the
    /// `idx`-th input of the tuple.
    ///
    /// Every stride is zero except one, which is set to `arr`'s own stride. The
    /// single input axis therefore advances along that output axis and is
    /// repeated (stride 0) along all the others. That axis is `idx`, except
    /// that under `MeshIndex::XY` the first two inputs trade places: input 0
    /// uses axis 1 and input 1 uses axis 0.
    ///
    /// NOTE(review): a negative input stride is stored via `as usize`. This
    /// appears to rely on the crate keeping signed strides in `usize` form;
    /// `test_meshgrid_neg_stride` exercises this path.
    fn construct_strides<A, const N: usize>(
        arr: &LayoutRef1<A>, idx: usize, indexing: MeshIndex,
    ) -> <[usize; N] as IntoDimension>::Dim
    where [usize; N]: IntoDimension
    {
        let mut ret = [0; N];
        if idx < 2 && indexing == MeshIndex::XY {
            ret[1 - idx] = arr.stride_of(Axis(0)) as usize;
        } else {
            ret[idx] = arr.stride_of(Axis(0)) as usize;
        }
        Dim(ret)
    }

    /// Build the common output shape from the input lengths.
    ///
    /// Axis `k` has the length of the `k`-th input, except that
    /// `MeshIndex::XY` swaps the first two axes. The `swap(0, 1)` requires
    /// `N >= 2`; every impl in this module passes at least two arrays.
    fn construct_shape<A, const N: usize>(
        arrays: [&LayoutRef1<A>; N], indexing: MeshIndex,
    ) -> <[usize; N] as IntoDimension>::Dim
    where [usize; N]: IntoDimension
    {
        let mut ret = arrays.map(|a| a.len());
        if indexing == MeshIndex::XY {
            ret.swap(0, 1);
        }
        Dim(ret)
    }

    /// Tuples of 1-D arrays that `meshgrid` can expand into coordinate-grid
    /// views.
    pub trait Meshgrid
    {
        /// Tuple of views, one per input. Each view has one dimension per input.
        type Output;

        /// Build the grid views for `arrays` using the axis ordering `indexing`.
        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output;
    }

    // Expands to a tuple with one `ArrayView` per `(array, index)` pair;
    // `$count` is the number of inputs (the output dimensionality). All views
    // share the same `shape`, and each gets strides from `construct_strides`.
    // `A` resolves to the enclosing impl's element type parameter.
    //
    // SAFETY (relied on by the `unsafe` below): each view starts at its input's
    // first element (`as_ptr`). Its only non-zero stride is the input's own
    // stride, and it sits on the axis whose length `construct_shape` set to
    // that input's `len()`. Every reachable element is therefore an element of
    // that input, borrowed for the input reference's lifetime.
    macro_rules! meshgrid_body {
        ($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => {
            {
                let shape = construct_shape([$($arr),+], $indexing);
                (
                    $({
                        let strides = construct_strides::<_, $count>($arr, $idx, $indexing);
                        unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) }
                    }),+
                )
            }
        };
    }

    // Core impl for two array references: two 2-D views.
    impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>)
    {
        type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1))
        }
    }

    // Any `Data` storage: dereference each `ArrayBase` to an `ArrayRef1` and
    // delegate to the reference impl. The same pattern repeats for 3 to 6 inputs.
    impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>)
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing)
        }
    }

    // Three inputs: three 3-D views.
    impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>)
    {
        type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2))
        }
    }

    impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid
        for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>)
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
        S3: Data<Elem = A>,
    {
        type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing)
        }
    }

    // Four inputs: four 4-D views.
    impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>)
    {
        type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3))
        }
    }

    impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid
        for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>, &'d ArrayBase<S4, Ix1>)
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
        S3: Data<Elem = A>,
        S4: Data<Elem = A>,
    {
        type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing)
        }
    }

    // Five inputs: five 5-D views.
    impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid
        for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>, &'e ArrayRef1<A>)
    {
        type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4))
        }
    }

    impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid
        for (
            &'a ArrayBase<S1, Ix1>,
            &'b ArrayBase<S2, Ix1>,
            &'c ArrayBase<S3, Ix1>,
            &'d ArrayBase<S4, Ix1>,
            &'e ArrayBase<S5, Ix1>,
        )
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
        S3: Data<Elem = A>,
        S4: Data<Elem = A>,
        S5: Data<Elem = A>,
    {
        type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing)
        }
    }

    // Six inputs: six 6-D views (the largest supported arity).
    impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid
        for (
            &'a ArrayRef1<A>,
            &'b ArrayRef1<A>,
            &'c ArrayRef1<A>,
            &'d ArrayRef1<A>,
            &'e ArrayRef1<A>,
            &'f ArrayRef1<A>,
        )
    {
        type Output = (
            ArrayView6<'a, A>,
            ArrayView6<'b, A>,
            ArrayView6<'c, A>,
            ArrayView6<'d, A>,
            ArrayView6<'e, A>,
            ArrayView6<'f, A>,
        );

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5))
        }
    }

    impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid
        for (
            &'a ArrayBase<S1, Ix1>,
            &'b ArrayBase<S2, Ix1>,
            &'c ArrayBase<S3, Ix1>,
            &'d ArrayBase<S4, Ix1>,
            &'e ArrayBase<S5, Ix1>,
            &'f ArrayBase<S6, Ix1>,
        )
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
        S3: Data<Elem = A>,
        S4: Data<Elem = A>,
        S5: Data<Elem = A>,
        S6: Data<Elem = A>,
    {
        type Output = (
            ArrayView6<'a, A>,
            ArrayView6<'b, A>,
            ArrayView6<'c, A>,
            ArrayView6<'d, A>,
            ArrayView6<'e, A>,
            ArrayView6<'f, A>,
        );

        fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
        {
            Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing)
        }
    }
}
602
603pub fn meshgrid<T: Meshgrid>(arrays: T, indexing: MeshIndex) -> T::Output
653{
654 Meshgrid::meshgrid(arrays, indexing)
655}
656
#[cfg(test)]
mod tests
{
    use super::s;
    use crate::{meshgrid, Axis, MeshIndex};
    #[cfg(not(feature = "std"))]
    use alloc::vec;

    // Two inputs: XY lays the first input along columns, IJ along rows.
    #[test]
    fn test_meshgrid2()
    {
        let x = array![1, 2, 3];
        let y = array![4, 5, 6, 7];
        let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
        assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]);
        assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]);

        let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ);
        assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]);
        assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]);
    }

    // Three inputs: XY swaps only the first two axes; the third stays on axis 2.
    #[test]
    fn test_meshgrid3()
    {
        let x = array![1, 2, 3];
        let y = array![4, 5, 6, 7];
        let z = array![-1, -2];
        let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY);
        assert_eq!(xx, array![
            [[1, 1], [2, 2], [3, 3]],
            [[1, 1], [2, 2], [3, 3]],
            [[1, 1], [2, 2], [3, 3]],
            [[1, 1], [2, 2], [3, 3]],
        ]);
        assert_eq!(yy, array![
            [[4, 4], [4, 4], [4, 4]],
            [[5, 5], [5, 5], [5, 5]],
            [[6, 6], [6, 6], [6, 6]],
            [[7, 7], [7, 7], [7, 7]],
        ]);
        assert_eq!(zz, array![
            [[-1, -2], [-1, -2], [-1, -2]],
            [[-1, -2], [-1, -2], [-1, -2]],
            [[-1, -2], [-1, -2], [-1, -2]],
            [[-1, -2], [-1, -2], [-1, -2]],
        ]);

        let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ);
        assert_eq!(xx, array![
            [[1, 1], [1, 1], [1, 1], [1, 1]],
            [[2, 2], [2, 2], [2, 2], [2, 2]],
            [[3, 3], [3, 3], [3, 3], [3, 3]],
        ]);
        assert_eq!(yy, array![
            [[4, 4], [5, 5], [6, 6], [7, 7]],
            [[4, 4], [5, 5], [6, 6], [7, 7]],
            [[4, 4], [5, 5], [6, 6], [7, 7]],
        ]);
        assert_eq!(zz, array![
            [[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
            [[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
            [[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
        ]);
    }

    // Inputs that do not start at their allocation's first element.
    #[test]
    fn test_meshgrid_from_offset()
    {
        let x = array![1, 2, 3];
        let x = x.slice(s![1..]);
        let y = array![4, 5, 6];
        let y = y.slice(s![1..]);
        let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
        assert_eq!(xx, array![[2, 3], [2, 3]]);
        assert_eq!(yy, array![[5, 5], [6, 6]]);
    }

    // A negatively-strided input must be laid out correctly under both orderings.
    #[test]
    fn test_meshgrid_neg_stride()
    {
        let x = array![1, 2, 3];
        let x = x.slice(s![..;-1]);
        assert!(x.stride_of(Axis(0)) < 0);
        let y = array![4, 5, 6];
        let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
        assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]);
        assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]);

        let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ);
        assert_eq!(xx, array![[3, 3, 3], [2, 2, 2], [1, 1, 1]]);
        assert_eq!(yy, array![[4, 5, 6], [4, 5, 6], [4, 5, 6]]);
    }

    // Nested-Vec conversion: the zero-inner-length branch and the zero-sized-element branch.
    #[test]
    fn test_from_nested_vec_degenerate()
    {
        let empty_rows = crate::Array2::from(vec![[0u8; 0]; 3]);
        assert_eq!(empty_rows.shape(), &[3, 0]);

        let units = crate::Array2::from(vec![[(); 4]; 3]);
        assert_eq!(units.shape(), &[3, 4]);
        assert_eq!(units.len(), 12);
    }

    // aview2 uses row-major strides, collapsing to (0, 0) for an empty view.
    #[test]
    fn test_aview2_strides()
    {
        let rows: [[i32; 3]; 2] = [[1, 2, 3], [4, 5, 6]];
        let v = super::aview2(&rows);
        assert_eq!(v, array![[1, 2, 3], [4, 5, 6]]);
        assert_eq!(v.strides(), &[3, 1]);

        let e = super::aview2::<i32, 2>(&[]);
        assert_eq!(e.shape(), &[0, 2]);
        assert_eq!(e.strides(), &[0, 0]);
    }
}