use std::ffi::c_int;
use smol_str::format_smolstr;
use crate::error::{Error, OutOfRangePayload, Result};
mod sealed {
    /// Private supertrait of `IntoShape`.
    ///
    /// This module is not `pub`, so code outside the defining module cannot
    /// name `Sealed` and therefore cannot add new `IntoShape` implementations.
    pub trait Sealed {}

    /// Emits an empty `impl Sealed for T {}` for every listed type.
    macro_rules! seal {
        ($($t:ty),+ $(,)?) => {
            $(impl Sealed for $t {})+
        };
    }

    // Const-generic arrays cannot go through `seal!`, so they are spelled out.
    impl<const N: usize> Sealed for [i32; N] {}

    seal!(
        &[i32],
        &[usize],
        Vec<i32>,
        Vec<usize>,
        (usize,),
        (usize, usize),
        (usize, usize, usize),
        (usize, usize, usize, usize),
        (usize, usize, usize, usize, usize),
        (usize, usize, usize, usize, usize, usize),
        (usize, usize, usize, usize, usize, usize, usize),
        (usize, usize, usize, usize, usize, usize, usize, usize),
    );
}
/// A value that can be presented as a list of C `int` dimensions.
///
/// Implemented for `&[i32]`, `[i32; N]`, `Vec<i32>`, `&[usize]`, `Vec<usize>`,
/// and `usize` tuples of arity 1 through 8. The `sealed::Sealed` supertrait
/// lives in a private module, so code outside this module cannot add further
/// implementations.
pub trait IntoShape: sealed::Sealed {
/// Presents `self` as `&[c_int]`, calls `f` with it, and returns `f`'s result.
///
/// The `i32`-based implementations forward their contents unchanged, with no
/// validation (negative entries included). Callers presumably check with
/// `validate_dims` where that matters; TODO confirm at call sites. The
/// `usize`-based implementations narrow each dimension first and fail before
/// calling `f` if any value does not fit in `i32`.
///
/// # Errors
///
/// Returns the narrowing error described above for `usize`-based inputs, and
/// otherwise propagates whatever `f` returns.
fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R>;
}
pub fn validate_dims(s: &[c_int]) -> Result<()> {
for (i, &d) in s.iter().enumerate() {
if d < 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"shape::validate_dims: dim",
"must be non-negative",
format_smolstr!("dim[{i}]={d}"),
)));
}
}
Ok(())
}
/// Immutable zero that [`dim_ptr`] points at when the shape is empty.
static EMPTY_DIM_SENTINEL: c_int = 0;

/// Returns a pointer to the first dimension of `s`.
///
/// For an empty `s`, it instead returns a pointer to a private static zero. For
/// an empty slice, `as_ptr` yields a dangling (non-null, but not
/// dereferenceable) pointer. Presumably the foreign callee must not receive
/// such a pointer even when the rank is 0; confirm against the C API.
#[inline]
pub(crate) fn dim_ptr(s: &[c_int]) -> *const c_int {
    let sentinel: *const c_int = &EMPTY_DIM_SENTINEL;
    match s {
        [] => sentinel,
        _ => s.as_ptr(),
    }
}
/// Immutable zero that [`stride_ptr`] points at when the stride list is empty.
static EMPTY_STRIDE_SENTINEL: i64 = 0;

/// Returns a pointer to the first stride in `s`.
///
/// For an empty `s`, it instead returns a pointer to a private static zero
/// rather than the dangling pointer that `as_ptr` produces for an empty slice.
/// This mirrors `dim_ptr`.
#[inline]
pub(crate) fn stride_ptr(s: &[i64]) -> *const i64 {
    let sentinel: *const i64 = &EMPTY_STRIDE_SENTINEL;
    match s {
        [] => sentinel,
        _ => s.as_ptr(),
    }
}
impl IntoShape for &[i32] {
    /// Hands the borrowed dimensions straight to `f` without copying or
    /// validating them; negative values are forwarded as-is.
    fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
        let dims: &[c_int] = self;
        f(dims)
    }
}
impl<const N: usize> IntoShape for [i32; N] {
    /// Views the fixed-size array as a slice and hands it to `f` unchanged.
    fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
        f(self.as_slice())
    }
}
fn convert_dim(d: usize) -> Result<c_int> {
i32::try_from(d).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"shape::convert_dim",
"must fit in i32",
format_smolstr!("{d}"),
))
})
}
impl IntoShape for &[usize] {
fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
if self.len() <= 8 {
let mut buf = [0i32; 8];
for (i, &d) in self.iter().enumerate() {
buf[i] = convert_dim(d)?;
}
f(&buf[..self.len()])
} else {
let v: Vec<c_int> = self
.iter()
.map(|&d| convert_dim(d))
.collect::<Result<_>>()?;
f(&v)
}
}
}
impl IntoShape for Vec<i32> {
fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
self.as_slice().with_shape(f)
}
}
impl IntoShape for Vec<usize> {
fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
self.as_slice().with_shape(f)
}
}
/// Implements `IntoShape` for a tuple.
///
/// Each entry is written as `index: type`, so every tuple field index is paired
/// with its element type at the call site and the two lists cannot drift apart.
/// The generated `with_shape` narrows each field with `convert_dim`, left to
/// right, stopping at the first failure. It builds a stack array from the
/// results and hands that array to `f`.
macro_rules! tuple_into_shape {
    ($($idx:tt : $T:ty),+ $(,)?) => {
        impl IntoShape for ($($T,)+) {
            fn with_shape<R>(&self, f: impl FnOnce(&[c_int]) -> Result<R>) -> Result<R> {
                f(&[$(convert_dim(self.$idx)?),+])
            }
        }
    };
}

tuple_into_shape!(0: usize);
tuple_into_shape!(0: usize, 1: usize);
tuple_into_shape!(0: usize, 1: usize, 2: usize);
tuple_into_shape!(0: usize, 1: usize, 2: usize, 3: usize);
tuple_into_shape!(0: usize, 1: usize, 2: usize, 3: usize, 4: usize);
tuple_into_shape!(0: usize, 1: usize, 2: usize, 3: usize, 4: usize, 5: usize);
tuple_into_shape!(0: usize, 1: usize, 2: usize, 3: usize, 4: usize, 5: usize, 6: usize);
tuple_into_shape!(
    0: usize, 1: usize, 2: usize, 3: usize, 4: usize, 5: usize, 6: usize, 7: usize
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `with_shape` with an identity callback and returns the dims it saw.
    fn collect(s: &impl IntoShape) -> Vec<c_int> {
        s.with_shape(|dims| Ok(dims.to_vec())).expect("with_shape")
    }

    #[test]
    fn tuple_ranks_1_through_8() {
        assert_eq!(collect(&(1usize,)), vec![1]);
        assert_eq!(collect(&(1usize, 2)), vec![1, 2]);
        assert_eq!(collect(&(1usize, 2, 3)), vec![1, 2, 3]);
        assert_eq!(collect(&(1usize, 2, 3, 4)), vec![1, 2, 3, 4]);
        assert_eq!(collect(&(1usize, 2, 3, 4, 5)), vec![1, 2, 3, 4, 5]);
        assert_eq!(collect(&(1usize, 2, 3, 4, 5, 6)), vec![1, 2, 3, 4, 5, 6]);
        assert_eq!(
            collect(&(1usize, 2, 3, 4, 5, 6, 7)),
            vec![1, 2, 3, 4, 5, 6, 7]
        );
        assert_eq!(
            collect(&(1usize, 2, 3, 4, 5, 6, 7, 8)),
            vec![1, 2, 3, 4, 5, 6, 7, 8]
        );
    }

    #[test]
    fn vec_usize_and_i32_match_slices() {
        let vu: Vec<usize> = vec![2, 3, 4];
        assert_eq!(collect(&vu), vec![2, 3, 4]);
        let vi: Vec<i32> = vec![5, 6];
        assert_eq!(collect(&vi), vec![5, 6]);
        assert_eq!(collect(&Vec::<usize>::new()), Vec::<c_int>::new());
        assert_eq!(collect(&Vec::<i32>::new()), Vec::<c_int>::new());
    }

    #[test]
    fn vec_usize_matches_equivalent_slice() {
        let vu: Vec<usize> = vec![7, 8, 9, 10, 11];
        let su: &[usize] = &[7, 8, 9, 10, 11];
        assert_eq!(collect(&vu), collect(&su));
    }

    #[test]
    fn slice_usize_rank_above_8_spills_to_heap_path() {
        let dims: Vec<usize> = (1..=9).collect();
        assert_eq!(collect(&dims), (1..=9).collect::<Vec<c_int>>());
    }

    #[test]
    fn tuple_dim_overflowing_i32_is_rejected() {
        let big = (i32::MAX as usize) + 1;
        let err = (big,).with_shape(|_| Ok(())).unwrap_err();
        assert!(matches!(err, Error::OutOfRange(_)), "got {err:?}");
    }

    #[test]
    fn vec_dim_overflowing_i32_is_rejected() {
        let big = (i32::MAX as usize) + 1;
        let v: Vec<usize> = vec![1, big];
        let err = v.with_shape(|_| Ok(())).unwrap_err();
        assert!(matches!(err, Error::OutOfRange(_)), "got {err:?}");
    }

    #[test]
    fn validate_dims_rejects_negative() {
        assert!(validate_dims(&[1, 2, 3]).is_ok());
        let err = validate_dims(&[1, -2, 3]).unwrap_err();
        assert!(matches!(err, Error::OutOfRange(_)), "got {err:?}");
    }

    // The i32-based impls are pure passthroughs; pin that for slices and
    // fixed-size arrays, including the empty array.
    #[test]
    fn i32_slice_and_array_pass_through_unchanged() {
        let arr: [i32; 3] = [4, 5, 6];
        assert_eq!(collect(&arr), vec![4, 5, 6]);
        let s: &[i32] = &[7, 8];
        assert_eq!(collect(&s), vec![7, 8]);
        assert_eq!(collect(&[0i32; 0]), Vec::<c_int>::new());
    }

    // Rank 8 is the largest shape that stays on the inline (non-heap) path.
    #[test]
    fn slice_usize_rank_8_boundary() {
        let eight: &[usize] = &[1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(collect(&eight), vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    // Overflow must be detected on both the inline (<= 8) and heap (> 8)
    // branches of the `&[usize]` implementation, even in the last position.
    #[test]
    fn slice_dim_overflowing_i32_is_rejected_below_and_above_inline_rank() {
        let big = (i32::MAX as usize) + 1;

        let small: &[usize] = &[1, big];
        let err = small.with_shape(|_| Ok(())).unwrap_err();
        assert!(matches!(err, Error::OutOfRange(_)), "got {err:?}");

        let mut wide: Vec<usize> = (1..=9).collect();
        wide[8] = big;
        let err = wide.with_shape(|_| Ok(())).unwrap_err();
        assert!(matches!(err, Error::OutOfRange(_)), "got {err:?}");
    }

    #[test]
    fn validate_dims_accepts_empty_and_zero() {
        assert!(validate_dims(&[]).is_ok());
        assert!(validate_dims(&[0, 0]).is_ok());
    }

    // Empty shapes/strides must map to a stable, dereferenceable sentinel
    // rather than the dangling pointer of an empty slice.
    #[test]
    fn empty_shapes_map_to_non_null_sentinels() {
        let dims: [c_int; 2] = [3, 4];
        assert_eq!(dim_ptr(&dims), dims.as_ptr());
        let no_dims: [c_int; 0] = [];
        let dp = dim_ptr(&no_dims);
        assert!(!dp.is_null());
        assert_eq!(dp, dim_ptr(&dims[..0]));
        // SAFETY: for empty input `dim_ptr` returns the address of a live static.
        assert_eq!(unsafe { *dp }, 0);

        let strides: [i64; 2] = [4, 1];
        assert_eq!(stride_ptr(&strides), strides.as_ptr());
        let no_strides: [i64; 0] = [];
        let sp = stride_ptr(&no_strides);
        assert!(!sp.is_null());
        assert_eq!(sp, stride_ptr(&strides[..0]));
        // SAFETY: for empty input `stride_ptr` returns the address of a live static.
        assert_eq!(unsafe { *sp }, 0);
    }
}