use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use meshgrid_impl::Meshgrid;
#[allow(unused_imports)]
use std::compile_error;
use std::mem::{forget, size_of};
use std::ptr::NonNull;
use crate::{dimension, ArcArray1, ArcArray2};
use crate::{imp_prelude::*, ArrayPartsSized};
/// Create an `Array` with one to six dimensions from nested bracketed literals.
///
/// The nesting depth of the brackets selects the dimensionality: a flat list
/// of expressions builds a one-dimensional array via `Array::from`, one level
/// of inner brackets builds an `Array2`, and so on up to `Array6`. Trailing
/// commas are accepted at every nesting level.
///
/// Input nested seven levels deep (which includes six levels whose innermost
/// elements are themselves Rust array literals) is rejected with a
/// `compile_error!`.
///
/// ```ignore
/// let a1 = array![1, 2, 3, 4];
/// let a2 = array![[1, 2], [3, 4]];
/// ```
#[macro_export]
macro_rules! array {
// The arms are ordered from the deepest nesting to the shallowest. `$x:expr`
// also matches a bracketed array literal, so a shallower arm listed first
// would capture more deeply nested input.
// Seven levels of brackets: not supported, fail at compile time.
($([$([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
compile_error!("Arrays of 7 dimensions or more (or ndarrays of Rust arrays) cannot be constructed with the array! macro.");
}};
// Six levels of brackets -> `Array6`.
($([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
$crate::Array6::from(vec![$([$([$([$([$([$($x,)*],)*],)*],)*],)*],)*])
}};
// Five levels of brackets -> `Array5`.
($([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
$crate::Array5::from(vec![$([$([$([$([$($x,)*],)*],)*],)*],)*])
}};
// Four levels of brackets -> `Array4`.
($([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{
$crate::Array4::from(vec![$([$([$([$($x,)*],)*],)*],)*])
}};
// Three levels of brackets -> `Array3`.
($([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{
$crate::Array3::from(vec![$([$([$($x,)*],)*],)*])
}};
// Two levels of brackets -> `Array2`.
($([$($x:expr),* $(,)*]),+ $(,)*) => {{
$crate::Array2::from(vec![$([$($x,)*],)*])
}};
// A flat list of expressions -> one-dimensional `Array`.
($($x:expr),* $(,)*) => {{
$crate::Array::from(vec![$($x,)*])
}};
}
/// Create a zero-dimensional array that owns the single element `x`.
pub fn arr0<A>(x: A) -> Array0<A>
{
    let storage = vec![x];
    // SAFETY: the unit shape `()` describes exactly one element, which matches
    // the one-element vector built above.
    unsafe { ArrayBase::from_shape_vec_unchecked((), storage) }
}
/// Create a one-dimensional owned array by cloning the elements of `xs`.
pub fn arr1<A: Clone>(xs: &[A]) -> Array1<A>
{
    // Clone the slice into an owned vector, then convert it into the array.
    let owned = xs.to_vec();
    ArrayBase::from(owned)
}
/// Create a one-dimensional shared (`ArcArray1`) array by cloning the elements of `xs`.
pub fn rcarr1<A: Clone>(xs: &[A]) -> ArcArray1<A>
{
    let owned = arr1(xs);
    owned.into_shared()
}
/// Create a zero-dimensional read-only array view that borrows the single element `x`.
pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A>
{
    let raw = x as *const A as *mut A;
    // SAFETY: `raw` was derived from a reference, and references are never null.
    let ptr = unsafe { NonNull::new_unchecked(raw) };
    ArrayBase {
        data: ViewRepr::new(),
        parts: ArrayPartsSized::new(ptr, Ix0(), Ix0()),
    }
}
/// Create a one-dimensional read-only array view that borrows the slice `xs`.
///
/// The view has `xs.len()` elements and a stride of 1.
///
/// # Panics
///
/// Panics if `A` is zero-sized and the slice length is larger than `isize::MAX`.
pub const fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A>
{
    let len = xs.len();
    // The length is only checked explicitly for zero-sized element types.
    if size_of::<A>() == 0 {
        assert!(
            len <= isize::MAX as usize,
            "Slice length must fit in `isize`.",
        );
    }
    // SAFETY: the pointer of a slice is never null.
    let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) };
    ArrayBase {
        data: ViewRepr::new(),
        parts: ArrayPartsSized::new(ptr, Ix1(len), Ix1(1)),
    }
}
/// Create a two-dimensional read-only array view that borrows a slice of fixed-size rows.
///
/// The view has `xs.len()` rows and `N` columns. Its strides are `(N, 1)`, or
/// `(0, 0)` when either axis has length zero.
///
/// # Panics
///
/// Panics if `A` is zero-sized and the number of rows, the number of columns,
/// or their product does not fit in `isize`. Also panics if `N == 0` and the
/// number of rows does not fit in `isize`.
pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A>
{
    let n_rows = xs.len();
    let n_cols = N;
    if size_of::<A>() == 0 {
        // Zero-sized elements: validate both axis lengths and their product.
        match n_rows.checked_mul(n_cols) {
            Some(n_elems) => {
                assert!(
                    n_rows <= isize::MAX as usize
                        && n_cols <= isize::MAX as usize
                        && n_elems <= isize::MAX as usize,
                    "Product of non-zero axis lengths must not overflow isize.",
                );
            }
            None => panic!("Overflow in number of elements."),
        }
    } else if n_cols == 0 {
        // Rows are empty arrays: only the row count needs to be validated.
        assert!(
            n_rows <= isize::MAX as usize,
            "Product of non-zero axis lengths must not overflow isize.",
        );
    }
    // SAFETY: the pointer of a slice is never null.
    let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) };
    let dim = Ix2(n_rows, n_cols);
    // Row-major strides for a non-empty view; all-zero strides otherwise.
    let strides = if n_rows != 0 && n_cols != 0 {
        Ix2(n_cols, 1)
    } else {
        Ix2(0, 0)
    };
    ArrayBase {
        data: ViewRepr::new(),
        parts: ArrayPartsSized::new(ptr, dim, strides),
    }
}
/// Create a one-dimensional read-write array view that borrows the slice `xs`.
pub fn aview_mut1<A>(xs: &mut [A]) -> ArrayViewMut1<'_, A>
{
    // Delegates to the `From<&mut [A]>` conversion of the mutable view type.
    let view = ArrayViewMut::from(xs);
    view
}
/// Create a two-dimensional read-write array view that borrows a slice of fixed-size rows.
pub fn aview_mut2<A, const N: usize>(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A>
{
    // Delegates to the `From<&mut [[A; N]]>` conversion of the mutable view type.
    let view = ArrayViewMut2::from(xs);
    view
}
/// Create a two-dimensional owned array by cloning the rows of `xs`.
///
/// The result has `xs.len()` rows and `N` columns.
pub fn arr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> Array2<A>
{
    // Clone the rows into an owned vector, then hand it to the nested-`Vec` conversion.
    let rows = xs.to_vec();
    Array2::from(rows)
}
// Implements `From<Vec<$arr_type>>` for `Array<A, $ix_type>`, where `$arr_type`
// is a (possibly nested) fixed-size array of `A` whose lengths are the const
// parameters `$n`.
//
// The conversion does not copy elements: it takes over the buffer of the
// incoming `Vec` and reinterprets it as a flat `Vec<A>`. The resulting shape is
// `(xs.len(), $n...)`.
//
// The generated `from` panics if the product of the non-zero axis lengths does
// not fit in `isize`.
macro_rules! impl_from_nested_vec {
($arr_type:ty, $ix_type:tt, $($n:ident),+) => {
impl<A, $(const $n: usize),+> From<Vec<$arr_type>> for Array<A, $ix_type>
{
fn from(mut xs: Vec<$arr_type>) -> Self
{
let dim = $ix_type(xs.len(), $($n),+);
// Capture the raw parts before giving up ownership of `xs`.
let ptr = xs.as_mut_ptr();
let cap = xs.capacity();
// Total number of `A` elements; validated before `xs` is forgotten.
let expand_len = dimension::size_of_shape_checked(&dim)
.expect("Product of non-zero axis lengths must not overflow isize.");
// Ownership of the buffer is transferred to the `Vec<A>` built below,
// so `xs` must not run its destructor.
forget(xs);
unsafe {
let v = if size_of::<A>() == 0 {
// Zero-sized `A`: length and capacity are both the element count.
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
} else if $($n == 0 ||)+ false {
// Some inner length is zero while `A` is not zero-sized: the array
// has no elements, so a fresh empty `Vec` is used. The elements of
// the forgotten `xs` are zero-sized arrays in this case
// (presumably it owned no heap allocation -- TODO confirm).
Vec::new()
} else {
// Scale the capacity from "nested arrays" to "elements of `A`".
// `Vec` never allocates more than `isize::MAX` bytes and `A` is not
// zero-sized here, so this product does not overflow.
let expand_cap = cap $(* $n)+;
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
};
ArrayBase::from_shape_vec_unchecked(dim, v)
}
}
}
};
}
// Conversions from `Vec`s of nested arrays into `Array2` through `Array6`.
impl_from_nested_vec!([A; N], Ix2, N);
impl_from_nested_vec!([[A; M]; N], Ix3, N, M);
impl_from_nested_vec!([[[A; L]; M]; N], Ix4, N, M, L);
impl_from_nested_vec!([[[[A; K]; L]; M]; N], Ix5, N, M, L, K);
impl_from_nested_vec!([[[[[A; J]; K]; L]; M]; N], Ix6, N, M, L, K, J);
/// Create a two-dimensional shared (`ArcArray2`) array by cloning the rows of `xs`.
pub fn rcarr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> ArcArray2<A>
{
    let owned = arr2(xs);
    owned.into_shared()
}
/// Create a three-dimensional owned array by cloning the nested arrays in `xs`.
///
/// The result has shape `(xs.len(), N, M)`.
pub fn arr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Array3<A>
{
    // Clone the nested arrays into an owned vector, then hand it to the nested-`Vec` conversion.
    let planes = xs.to_vec();
    Array3::from(planes)
}
/// Create a three-dimensional shared (`ArcArray`) array by cloning the nested arrays in `xs`.
pub fn rcarr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> ArcArray<A, Ix3>
{
    let owned = arr3(xs);
    owned.into_shared()
}
/// The indexing order used by `meshgrid`.
///
/// The two variants differ only in how the first two input arrays are mapped
/// onto the first two axes of the output grids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeshIndex
{
/// Cartesian indexing: the first two axes are swapped, so the first input
/// varies along axis 1 and the second input varies along axis 0. For inputs
/// of lengths `(nx, ny, ...)` the output shape is `(ny, nx, ...)`.
XY,
/// Matrix indexing: the `i`-th input varies along axis `i`. For inputs of
/// lengths `(nx, ny, ...)` the output shape is `(nx, ny, ...)`.
IJ,
}
mod meshgrid_impl
{
use super::MeshIndex;
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
use crate::{
ArrayBase,
ArrayRef1,
ArrayView,
ArrayView2,
ArrayView3,
ArrayView4,
ArrayView5,
ArrayView6,
Axis,
Data,
Dim,
IntoDimension,
Ix1,
LayoutRef1,
};
fn construct_strides<A, const N: usize>(
arr: &LayoutRef1<A>, idx: usize, indexing: MeshIndex,
) -> <[usize; N] as IntoDimension>::Dim
where [usize; N]: IntoDimension
{
let mut ret = [0; N];
if idx < 2 && indexing == MeshIndex::XY {
ret[1 - idx] = arr.stride_of(Axis(0)) as usize;
} else {
ret[idx] = arr.stride_of(Axis(0)) as usize;
}
Dim(ret)
}
fn construct_shape<A, const N: usize>(
arrays: [&LayoutRef1<A>; N], indexing: MeshIndex,
) -> <[usize; N] as IntoDimension>::Dim
where [usize; N]: IntoDimension
{
let mut ret = arrays.map(|a| a.len());
if indexing == MeshIndex::XY {
ret.swap(0, 1);
}
Dim(ret)
}
pub trait Meshgrid
{
type Output;
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output;
}
macro_rules! meshgrid_body {
($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => {
{
let shape = construct_shape([$($arr),+], $indexing);
(
$({
let strides = construct_strides::<_, $count>($arr, $idx, $indexing);
unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) }
}),+
)
}
};
}
impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>)
{
type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1))
}
}
impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>)
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
{
type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing)
}
}
impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>)
{
type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2))
}
}
impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid
for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>)
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: Data<Elem = A>,
{
type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing)
}
}
impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>)
{
type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3))
}
}
impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid
for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>, &'d ArrayBase<S4, Ix1>)
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: Data<Elem = A>,
S4: Data<Elem = A>,
{
type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing)
}
}
impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid
for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>, &'e ArrayRef1<A>)
{
type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4))
}
}
impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid
for (
&'a ArrayBase<S1, Ix1>,
&'b ArrayBase<S2, Ix1>,
&'c ArrayBase<S3, Ix1>,
&'d ArrayBase<S4, Ix1>,
&'e ArrayBase<S5, Ix1>,
)
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: Data<Elem = A>,
S4: Data<Elem = A>,
S5: Data<Elem = A>,
{
type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing)
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid
for (
&'a ArrayRef1<A>,
&'b ArrayRef1<A>,
&'c ArrayRef1<A>,
&'d ArrayRef1<A>,
&'e ArrayRef1<A>,
&'f ArrayRef1<A>,
)
{
type Output = (
ArrayView6<'a, A>,
ArrayView6<'b, A>,
ArrayView6<'c, A>,
ArrayView6<'d, A>,
ArrayView6<'e, A>,
ArrayView6<'f, A>,
);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5))
}
}
impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid
for (
&'a ArrayBase<S1, Ix1>,
&'b ArrayBase<S2, Ix1>,
&'c ArrayBase<S3, Ix1>,
&'d ArrayBase<S4, Ix1>,
&'e ArrayBase<S5, Ix1>,
&'f ArrayBase<S6, Ix1>,
)
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: Data<Elem = A>,
S4: Data<Elem = A>,
S5: Data<Elem = A>,
S6: Data<Elem = A>,
{
type Output = (
ArrayView6<'a, A>,
ArrayView6<'b, A>,
ArrayView6<'c, A>,
ArrayView6<'d, A>,
ArrayView6<'e, A>,
ArrayView6<'f, A>,
);
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
{
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing)
}
}
}
/// Create coordinate grids from a tuple of one-dimensional arrays.
///
/// `arrays` is a tuple of two to six references to one-dimensional arrays. The
/// result is a tuple with one view per input; every view has the same shape,
/// and the `i`-th view repeats the `i`-th input along all other axes.
/// `indexing` selects whether the first two axes follow the input order
/// ([`MeshIndex::IJ`]) or are swapped ([`MeshIndex::XY`]).
///
/// The returned views borrow the inputs; no elements are copied.
pub fn meshgrid<T: Meshgrid>(arrays: T, indexing: MeshIndex) -> T::Output
{
    T::meshgrid(arrays, indexing)
}
// Unit tests for `meshgrid`, covering both indexing orders, sliced inputs with
// a non-zero offset, and inputs with a negative stride.
#[cfg(test)]
mod tests
{
use super::s;
use crate::{meshgrid, Axis, MeshIndex};
#[cfg(not(feature = "std"))]
use alloc::vec;
// Two inputs of lengths 3 and 4: XY yields 4x3 grids, IJ yields 3x4 grids.
#[test]
fn test_meshgrid2()
{
let x = array![1, 2, 3];
let y = array![4, 5, 6, 7];
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]);
assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]);
let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ);
assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]);
assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]);
}
// Three inputs of lengths 3, 4 and 2: XY yields 4x3x2 grids, IJ yields 3x4x2
// grids; the third input always varies along the last axis.
#[test]
fn test_meshgrid3()
{
let x = array![1, 2, 3];
let y = array![4, 5, 6, 7];
let z = array![-1, -2];
let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY);
assert_eq!(xx, array![
[[1, 1], [2, 2], [3, 3]],
[[1, 1], [2, 2], [3, 3]],
[[1, 1], [2, 2], [3, 3]],
[[1, 1], [2, 2], [3, 3]],
]);
assert_eq!(yy, array![
[[4, 4], [4, 4], [4, 4]],
[[5, 5], [5, 5], [5, 5]],
[[6, 6], [6, 6], [6, 6]],
[[7, 7], [7, 7], [7, 7]],
]);
assert_eq!(zz, array![
[[-1, -2], [-1, -2], [-1, -2]],
[[-1, -2], [-1, -2], [-1, -2]],
[[-1, -2], [-1, -2], [-1, -2]],
[[-1, -2], [-1, -2], [-1, -2]],
]);
let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ);
assert_eq!(xx, array![
[[1, 1], [1, 1], [1, 1], [1, 1]],
[[2, 2], [2, 2], [2, 2], [2, 2]],
[[3, 3], [3, 3], [3, 3], [3, 3]],
]);
assert_eq!(yy, array![
[[4, 4], [5, 5], [6, 6], [7, 7]],
[[4, 4], [5, 5], [6, 6], [7, 7]],
[[4, 4], [5, 5], [6, 6], [7, 7]],
]);
assert_eq!(zz, array![
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
]);
}
// Inputs are slices that skip the first element, so the grids must start at
// the sliced offset rather than at the start of the underlying arrays.
#[test]
fn test_meshgrid_from_offset()
{
let x = array![1, 2, 3];
let x = x.slice(s![1..]);
let y = array![4, 5, 6];
let y = y.slice(s![1..]);
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
assert_eq!(xx, array![[2, 3], [2, 3]]);
assert_eq!(yy, array![[5, 5], [6, 6]]);
}
// The first input is a reversed slice, so it has a negative stride that the
// grid views must carry over.
#[test]
fn test_meshgrid_neg_stride()
{
let x = array![1, 2, 3];
let x = x.slice(s![..;-1]);
// Precondition: the reversed slice really has a negative stride.
assert!(x.stride_of(Axis(0)) < 0); let y = array![4, 5, 6];
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]);
assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]);
}
}