use ndarray::{Array, IxDyn};
use crate::error::{Error, ErrorKind, Result};
use crate::tensor::BorrowedTensor;
impl<'a> BorrowedTensor<'a> {
    /// Borrows a standard-layout `f32` ndarray as a tensor, without copying its elements.
    ///
    /// The tensor is built from `array`'s flat element slice and its shape, and
    /// borrows that slice for `'a`.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::TensorCreate` error if `array` is not in standard
    /// (row-major, contiguous) layout — e.g. a transposed array. Any error from
    /// `from_f32` is propagated unchanged.
    pub fn from_ndarray_f32(array: &'a Array<f32, IxDyn>) -> Result<Self> {
        let slice = Self::standard_layout_slice(array)?;
        Self::from_f32(slice, array.shape())
    }

    /// Borrows a standard-layout `i32` ndarray as a tensor, without copying its elements.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::TensorCreate` error if `array` is not in standard
    /// (row-major, contiguous) layout. Any error from `from_i32` is propagated
    /// unchanged.
    pub fn from_ndarray_i32(array: &'a Array<i32, IxDyn>) -> Result<Self> {
        let slice = Self::standard_layout_slice(array)?;
        Self::from_i32(slice, array.shape())
    }

    /// Borrows a standard-layout `f64` ndarray as a tensor, without copying its elements.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::TensorCreate` error if `array` is not in standard
    /// (row-major, contiguous) layout. Any error from `from_f64` is propagated
    /// unchanged.
    pub fn from_ndarray_f64(array: &'a Array<f64, IxDyn>) -> Result<Self> {
        let slice = Self::standard_layout_slice(array)?;
        Self::from_f64(slice, array.shape())
    }

    /// Returns `array`'s elements as one contiguous slice borrowed for `'a`.
    ///
    /// `Array::as_slice` only succeeds for standard-layout arrays, so anything
    /// else (transposed, column-major, ...) is reported as a `TensorCreate`
    /// error that tells the caller how to obtain a compatible copy.
    fn standard_layout_slice<T>(array: &'a Array<T, IxDyn>) -> Result<&'a [T]> {
        array.as_slice().ok_or_else(|| {
            Error::new(
                ErrorKind::TensorCreate,
                "ndarray is not contiguous (standard layout required); \
                 call .as_standard_layout().into_owned() first",
            )
        })
    }
}
/// Extracts named prediction outputs as dynamically-shaped `ndarray` arrays.
pub trait PredictionNdarray {
    /// Returns the `f32` output called `name`, reshaped to the shape reported for it.
    ///
    /// # Errors
    ///
    /// Propagates any error from fetching the raw output, and returns an
    /// `ErrorKind::Prediction` error if the data cannot be arranged into the
    /// reported shape.
    fn get_ndarray_f32(&self, name: &str) -> Result<Array<f32, IxDyn>>;

    /// Returns the `i32` output called `name`, reshaped to the shape reported for it.
    ///
    /// # Errors
    ///
    /// Same conditions as [`PredictionNdarray::get_ndarray_f32`].
    fn get_ndarray_i32(&self, name: &str) -> Result<Array<i32, IxDyn>>;

    /// Returns the `f64` output called `name`, reshaped to the shape reported for it.
    ///
    /// # Errors
    ///
    /// Same conditions as [`PredictionNdarray::get_ndarray_f32`].
    fn get_ndarray_f64(&self, name: &str) -> Result<Array<f64, IxDyn>>;
}

/// Rebuilds an array from a flat element buffer (read in standard, row-major
/// order by `Array::from_shape_vec`) and its shape.
///
/// A data/shape mismatch becomes an `ErrorKind::Prediction` error that names
/// the offending output, so callers can tell which output was malformed.
fn array_from_flat<T>(name: &str, data: Vec<T>, shape: &[usize]) -> Result<Array<T, IxDyn>> {
    Array::from_shape_vec(IxDyn(shape), data).map_err(|e| {
        Error::new(
            ErrorKind::Prediction,
            format!("ndarray shape reconstruction failed for '{name}': {e}"),
        )
    })
}

impl PredictionNdarray for crate::Prediction {
    fn get_ndarray_f32(&self, name: &str) -> Result<Array<f32, IxDyn>> {
        let (data, shape) = self.get_f32(name)?;
        array_from_flat(name, data, &shape)
    }

    fn get_ndarray_i32(&self, name: &str) -> Result<Array<i32, IxDyn>> {
        let (data, shape) = self.get_i32(name)?;
        array_from_flat(name, data, &shape)
    }

    fn get_ndarray_f64(&self, name: &str) -> Result<Array<f64, IxDyn>> {
        let (data, shape) = self.get_f64(name)?;
        array_from_flat(name, data, &shape)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn ndarray_f32_shape_preserved() {
        let arr = array![[1.0f32, 2.0], [3.0, 4.0]].into_dyn();
        assert_eq!(arr.shape(), &[2, 2]);
        let slice = arr.as_slice().expect("standard layout should be contiguous");
        assert_eq!(slice, &[1.0f32, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn ndarray_i32_shape_preserved() {
        let arr = array![[1i32, 2], [3, 4], [5, 6]].into_dyn();
        assert_eq!(arr.shape(), &[3, 2]);
        let slice = arr.as_slice().expect("standard layout should be contiguous");
        assert_eq!(slice, &[1i32, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn ndarray_f64_shape_preserved() {
        let arr = array![1.0f64, 2.0, 3.0].into_dyn();
        assert_eq!(arr.shape(), &[3]);
        let slice = arr.as_slice().expect("standard layout should be contiguous");
        assert_eq!(slice, &[1.0f64, 2.0, 3.0]);
    }

    #[test]
    fn standard_layout_is_contiguous() {
        let arr = Array::<f32, _>::zeros(ndarray::IxDyn(&[4, 5, 6]));
        assert!(arr.as_slice().is_some());
    }

    #[test]
    fn transposed_then_owned_preserves_shape() {
        let arr = array![[1.0f32, 2.0], [3.0, 4.0]].into_dyn();
        let transposed_owned = arr.t().as_standard_layout().into_owned().into_dyn();
        assert!(transposed_owned.as_slice().is_some());
        assert_eq!(transposed_owned.shape(), &[2, 2]);
    }

    #[test]
    fn raw_transposed_is_not_contiguous() {
        let arr = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn();
        let view = arr.t();
        assert!(!view.is_standard_layout());
        let contiguous = view.as_standard_layout().into_owned().into_dyn();
        let slice = contiguous.as_slice().unwrap();
        assert_eq!(slice, &[1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    // Checks the ndarray behavior the conversion relies on: a strided view has
    // no contiguous slice, while an owned copy of it does. The conversion's
    // error path itself is exercised in the `apple` module below.
    #[test]
    fn strided_view_is_not_contiguous_but_owned_copy_is() {
        use ndarray::s;
        let arr = Array::<f32, _>::from_iter((0..10).map(|x| x as f32)).into_dyn();
        let strided = arr.slice(s![..;2]).to_owned().into_dyn();
        let strided_view = arr.slice(s![..;2]);
        assert!(strided_view.as_slice().is_none());
        assert!(strided.as_slice().is_some());
    }

    #[cfg(target_vendor = "apple")]
    mod apple {
        use super::*;
        use crate::tensor::DataType;
        use ndarray::array;

        #[test]
        fn borrowed_tensor_from_ndarray_f32_1d() {
            let arr = array![1.0f32, 2.0, 3.0, 4.0].into_dyn();
            let tensor = BorrowedTensor::from_ndarray_f32(&arr).unwrap();
            assert_eq!(tensor.shape(), &[4]);
            assert_eq!(tensor.data_type(), DataType::Float32);
            assert_eq!(tensor.element_count(), 4);
        }

        #[test]
        fn borrowed_tensor_from_ndarray_f32_2d() {
            let arr = array![[1.0f32, 2.0], [3.0, 4.0]].into_dyn();
            let tensor = BorrowedTensor::from_ndarray_f32(&arr).unwrap();
            assert_eq!(tensor.shape(), &[2, 2]);
            assert_eq!(tensor.data_type(), DataType::Float32);
            assert_eq!(tensor.element_count(), 4);
        }

        #[test]
        fn borrowed_tensor_from_ndarray_f32_3d() {
            let arr = Array::<f32, _>::zeros(ndarray::IxDyn(&[2, 3, 4]));
            let tensor = BorrowedTensor::from_ndarray_f32(&arr).unwrap();
            assert_eq!(tensor.shape(), &[2, 3, 4]);
            assert_eq!(tensor.element_count(), 24);
        }

        #[test]
        fn borrowed_tensor_from_ndarray_i32() {
            let arr = array![[0i32, 1], [2, 3]].into_dyn();
            let tensor = BorrowedTensor::from_ndarray_i32(&arr).unwrap();
            assert_eq!(tensor.shape(), &[2, 2]);
            assert_eq!(tensor.data_type(), DataType::Int32);
        }

        #[test]
        fn borrowed_tensor_from_ndarray_f64() {
            let arr = array![0.5f64, 1.5, 2.5].into_dyn();
            let tensor = BorrowedTensor::from_ndarray_f64(&arr).unwrap();
            assert_eq!(tensor.shape(), &[3]);
            assert_eq!(tensor.data_type(), DataType::Float64);
        }

        // An owned copy of a strided slice is re-packed into standard layout,
        // so the conversion must accept it.
        #[test]
        fn borrowed_tensor_from_owned_strided_copy_succeeds() {
            use ndarray::s;
            let base = Array::<f32, _>::from_iter((0..12).map(|x| x as f32))
                .into_shape_with_order(ndarray::IxDyn(&[3, 4]))
                .unwrap();
            let strided = base.slice(s![.., ..;2]).to_owned().into_dyn();
            assert!(BorrowedTensor::from_ndarray_f32(&strided).is_ok());
        }

        // `reversed_axes` on an owned array yields column-major strides, so the
        // array owns its buffer yet is not in standard layout: the conversion
        // must reject it instead of silently reinterpreting element order.
        #[test]
        fn borrowed_tensor_from_non_standard_layout_f32_errors() {
            let column_major = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]
                .reversed_axes()
                .into_dyn();
            assert!(column_major.as_slice().is_none());
            assert!(BorrowedTensor::from_ndarray_f32(&column_major).is_err());
        }

        #[test]
        fn borrowed_tensor_from_non_standard_layout_i32_errors() {
            let column_major = array![[1i32, 2], [3, 4]].reversed_axes().into_dyn();
            assert!(BorrowedTensor::from_ndarray_i32(&column_major).is_err());
        }

        #[test]
        fn borrowed_tensor_from_non_standard_layout_f64_errors() {
            let column_major = array![[1.0f64, 2.0], [3.0, 4.0]].reversed_axes().into_dyn();
            assert!(BorrowedTensor::from_ndarray_f64(&column_major).is_err());
        }
    }
}