use crate::dtype::NumericDataType;
use crate::ndarray::flags::NdArrayFlags;
use crate::ndarray::NdArray;
use crate::util::flatten::Flatten;
use crate::util::nested::Nested;
use crate::util::shape::Shape;
use crate::util::to_vec::ToVec;
use crate::{FloatDataType, RawDataType};
use num::NumCast;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
/// Computes the strides of a contiguous, row-major array with the given `shape`.
///
/// Strides are expressed in elements: the last dimension has stride 1 and every
/// earlier dimension's stride is the product of all dimension sizes after it.
/// An empty `shape` yields an empty stride vector.
pub(crate) fn stride_from_shape(shape: &[usize]) -> Vec<usize> {
    // Walk the dimensions from innermost to outermost, carrying the running
    // product of the sizes seen so far; that product is the current stride.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let current = *running;
            *running *= dim;
            Some(current)
        })
        .collect();
    // The scan produced strides innermost-first; flip them to match `shape`'s order.
    strides.reverse();
    strides
}
impl<'a, T: RawDataType> NdArray<'a, T> {
    /// Builds an owned, contiguous, writeable array that takes over the allocation of `data`.
    ///
    /// The strides are derived from `shape` with [`stride_from_shape`], and the
    /// `Owned`, `Contiguous`, `UniformStride` and `Writeable` flags are set.
    ///
    /// # Safety
    ///
    /// Neither `shape` nor `data` is validated here.
    /// NOTE(review): callers presumably must guarantee that the product of `shape`
    /// equals `data.len()` — confirm against the code that indexes through `ptr`.
    pub(crate) unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
        let flags = NdArrayFlags::Owned
            | NdArrayFlags::Contiguous
            | NdArrayFlags::UniformStride
            | NdArrayFlags::Writeable;

        // Suppress the Vec's destructor: its pointer, length and capacity are stored
        // in the array, and the `Drop` impl rebuilds the Vec to release the buffer.
        let mut data = ManuallyDrop::new(data);
        let stride = stride_from_shape(&shape);

        Self {
            // SAFETY: `Vec::as_mut_ptr` never returns a null pointer.
            ptr: NonNull::new_unchecked(data.as_mut_ptr()),
            len: data.len(),
            capacity: data.capacity(),
            shape,
            stride,
            flags,
            _marker: Default::default(),
        }
    }

    /// Creates an array from nested data, copying its elements into a flat owned buffer.
    ///
    /// The shape is taken from `data.shape()` and the elements from `data.flatten()`.
    ///
    /// # Panics
    ///
    /// Panics if `data.check_homogenous()` returns `false`, or if the flattened
    /// data contains no elements.
    pub fn from<const D: usize>(data: impl Flatten<T> + Shape + Nested<{ D }>) -> Self {
        assert!(data.check_homogenous(), "Tensor::from() failed, found inhomogeneous dimensions");

        let shape = data.shape();
        let data = data.flatten();
        assert!(!data.is_empty(), "Tensor::from() failed, cannot create data buffer from empty data");

        // SAFETY: NOTE(review): relies on `Shape::shape` describing exactly the elements
        // produced by `Flatten::flatten` for homogeneous input — confirm in those traits.
        unsafe { NdArray::from_contiguous_owned_buffer(shape, data) }
    }

    /// Creates an array of the given `shape` in which every element is `n`.
    ///
    /// An empty `shape` produces a single-element array (the empty product is 1).
    ///
    /// # Panics
    ///
    /// Panics if any dimension of `shape` is zero, or if the total number of
    /// elements does not fit in a `usize`.
    pub fn full(n: T, shape: impl ToVec<usize>) -> Self {
        let shape = shape.to_vec();
        assert!(!shape.contains(&0), "Cannot create an empty tensor!");

        // Use a checked product: a plain `product()` wraps silently in release builds,
        // which would pair a small buffer with a shape claiming far more elements.
        let len = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .expect("Cannot create tensor: number of elements overflows usize");

        let data = vec![n; len];
        // SAFETY: `data` holds exactly `len` elements, the overflow-checked product of `shape`.
        unsafe { NdArray::from_contiguous_owned_buffer(shape, data) }
    }

    /// Creates an array of the given `shape` filled with `T::from(false)`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`NdArray::full`].
    pub fn zeros(shape: impl ToVec<usize>) -> Self
    where
        T: From<bool>,
    {
        Self::full(false.into(), shape)
    }

    /// Creates an array of the given `shape` filled with `T::from(true)`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`NdArray::full`].
    pub fn ones(shape: impl ToVec<usize>) -> Self
    where
        T: From<bool>,
    {
        Self::full(true.into(), shape)
    }

    /// Creates a zero-dimensional array (empty shape) holding the single value `n`.
    pub fn scalar(n: T) -> Self {
        NdArray::full(n, [])
    }
}
impl<T: NumericDataType> NdArray<'_, T> {
    /// Creates a one-dimensional array of values starting at `start` and advancing
    /// by `T::one()`, with `stop` as the exclusive upper bound.
    ///
    /// Equivalent to [`NdArray::arange_with_step`] with a step of `T::one()`.
    pub fn arange(start: T, stop: T) -> NdArray<'static, T> {
        let unit_step = T::one();
        Self::arange_with_step(start, stop, unit_step)
    }

    /// Creates a one-dimensional array whose `i`-th element is `i * step + start`.
    ///
    /// The number of elements is `ceil((stop - start) / step)`, evaluated in
    /// floating point.
    ///
    /// # Panics
    ///
    /// Panics if the computed element count cannot be converted to `usize`
    /// (presumably when it is negative or not finite, e.g. a zero `step`), or if
    /// an element index cannot be represented as `T`.
    pub fn arange_with_step(start: T, stop: T, step: T) -> NdArray<'static, T> {
        let span = (stop - start).to_float();
        let count: usize = NumCast::from((span / step.to_float()).ceil()).unwrap();

        let values: Vec<T> = (0..count)
            .map(|index| <T as NumCast>::from(index).unwrap() * step + start)
            .collect();

        let shape = vec![values.len()];
        // SAFETY: the one-dimensional shape is exactly the length of `values`.
        unsafe { NdArray::from_contiguous_owned_buffer(shape, values) }
    }
}
impl<T: FloatDataType> NdArray<'_, T> {
    /// Returns a one-dimensional array of exactly `num` evenly spaced samples
    /// starting at `start`, with the last sample landing on `stop`.
    ///
    /// Sample `i` is computed as `i * step + start` with
    /// `step = (stop - start) / (num - 1)`, so the final sample may differ from
    /// `stop` by floating-point rounding. When `num == 1` the result is `[start]`.
    ///
    /// # Panics
    ///
    /// Panics if `num` is zero.
    pub fn linspace(start: T, stop: T, num: usize) -> NdArray<'static, T> {
        assert!(num > 0);
        if num == 1 {
            // SAFETY: a shape of `[1]` matches the single-element buffer.
            return unsafe { NdArray::from_contiguous_owned_buffer(vec![1], vec![start]) };
        }

        let step = (stop - start) / (<T as NumCast>::from(num).unwrap() - T::one());
        Self::linspace_from_step(start, step, num)
    }

    /// Returns a one-dimensional array of exactly `num` evenly spaced samples
    /// starting at `start` and stopping one step short of `stop`.
    ///
    /// Sample `i` is computed as `i * step + start` with
    /// `step = (stop - start) / num`. When `num == 1` the result is `[start]`.
    ///
    /// # Panics
    ///
    /// Panics if `num` is zero.
    pub fn linspace_exclusive(start: T, stop: T, num: usize) -> NdArray<'static, T> {
        assert!(num > 0);
        if num == 1 {
            // SAFETY: a shape of `[1]` matches the single-element buffer.
            return unsafe { NdArray::from_contiguous_owned_buffer(vec![1], vec![start]) };
        }

        let step = (stop - start) / <T as NumCast>::from(num).unwrap();
        Self::linspace_from_step(start, step, num)
    }

    /// Builds the array `[start, start + step, ..., start + (num - 1) * step]`.
    ///
    /// The samples are generated by index instead of re-deriving the length from
    /// `ceil(span / step)`: that division is subject to rounding, which can yield
    /// one element too many or too few, and it is NaN when `start == stop`.
    fn linspace_from_step(start: T, step: T, num: usize) -> NdArray<'static, T> {
        let data: Vec<T> = (0..num)
            .map(|i| <T as NumCast>::from(i).unwrap() * step + start)
            .collect();

        // SAFETY: `data` was collected from a range of exactly `num` items.
        unsafe { NdArray::from_contiguous_owned_buffer(vec![num], data) }
    }
}
impl<T: RawDataType> Drop for NdArray<'_, T> {
    /// Releases the data buffer if this array owns it, then zeroes `len` and `capacity`.
    ///
    /// Arrays without the `Owned` flag leave the buffer untouched.
    fn drop(&mut self) {
        let owns_buffer = self.flags.contains(NdArrayFlags::Owned);

        if owns_buffer {
            // SAFETY: relies on the invariant that an `Owned` array stores the pointer,
            // length and capacity of a `Vec<T>` whose destructor was suppressed (as
            // `from_contiguous_owned_buffer` sets them up), so the Vec can be rebuilt.
            let buffer = unsafe { Vec::from_raw_parts(self.mut_ptr(), self.len, self.capacity) };
            // Dropping the rebuilt Vec drops the elements and frees the allocation.
            drop(buffer);
        }

        self.capacity = 0;
        self.len = 0;
    }
}