use crate::torch::ExclusiveTensor;
use ndarray::{Array1, Array2, ArrayView, ArrayViewMut, Ix1, Ix2};
use num_traits::{One, Zero};
use tch::{kind::Element, Tensor};
/// A one-dimensional numeric array that can be created pre-filled and accessed as a flat slice.
pub trait NumArray1D {
    /// Type of a single array element.
    type Elem;

    /// Constructs an array of length `size` filled with zeros.
    fn zeros(size: usize) -> Self;

    /// Constructs an array of length `size` filled with ones.
    fn ones(size: usize) -> Self;

    /// Borrows the array contents as an immutable slice of elements.
    fn as_slice(&self) -> &[Self::Elem];

    /// Borrows the array contents as a mutable slice of elements.
    fn as_slice_mut(&mut self) -> &mut [Self::Elem];
}
/// A two-dimensional numeric array that can be created pre-filled and exposed as `ndarray` views.
pub trait NumArray2D {
    /// Type of a single array element.
    type Elem;

    /// Constructs an array with the two-axis shape `size`, filled with zeros.
    fn zeros(size: (usize, usize)) -> Self;

    /// Constructs an array with the two-axis shape `size`, filled with ones.
    fn ones(size: (usize, usize)) -> Self;

    /// Borrows the array as a read-only two-dimensional view.
    fn view(&self) -> ArrayView<'_, Self::Elem, Ix2>;

    /// Borrows the array as a mutable two-dimensional view.
    fn view_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Ix2>;
}
/// A type that can be built from a one-dimensional numeric array.
///
/// The `From<Self::Array>` supertrait supplies the actual conversion from the
/// array type into the implementing type.
pub trait BuildFromArray1D: From<Self::Array> {
    /// The one-dimensional array type the implementor is converted from.
    type Array: NumArray1D;
}
/// Every one-dimensional numeric array is buildable from itself.
impl<T> BuildFromArray1D for T
where
    T: NumArray1D,
{
    type Array = T;
}
/// A type that can be built from a two-dimensional numeric array.
///
/// The `From<Self::Array>` supertrait supplies the actual conversion from the
/// array type into the implementing type.
pub trait BuildFromArray2D: From<Self::Array> {
    /// The two-dimensional array type the implementor is converted from.
    type Array: NumArray2D;
}
/// Every two-dimensional numeric array is buildable from itself.
impl<T> BuildFromArray2D for T
where
    T: NumArray2D,
{
    type Array = T;
}
/// [`NumArray1D`] for owned `ndarray` vectors.
///
/// Each method forwards to the `Array1` method of the same name.
impl<A: Clone + Zero + One> NumArray1D for Array1<A> {
    type Elem = A;

    /// Constructs a vector of length `size` filled with `A::zero()`.
    #[inline]
    fn zeros(size: usize) -> Self {
        Self::zeros(size)
    }

    /// Constructs a vector of length `size` filled with `A::one()`.
    #[inline]
    fn ones(size: usize) -> Self {
        Self::ones(size)
    }

    /// Borrows the vector's elements as a slice.
    ///
    /// # Panics
    ///
    /// Panics if the array is not contiguous in standard order (for example
    /// after a strided `slice_move`), because such data cannot be exposed as
    /// a single slice.
    #[inline]
    fn as_slice(&self) -> &[Self::Elem] {
        self.as_slice()
            .expect("Array1 must be contiguous and in standard order to be borrowed as a slice")
    }

    /// Borrows the vector's elements as a mutable slice.
    ///
    /// # Panics
    ///
    /// Panics if the array is not contiguous in standard order (for example
    /// after a strided `slice_move`), because such data cannot be exposed as
    /// a single slice.
    #[inline]
    fn as_slice_mut(&mut self) -> &mut [Self::Elem] {
        self.as_slice_mut().expect(
            "Array1 must be contiguous and in standard order to be borrowed as a mutable slice",
        )
    }
}
impl<A: Clone + Zero + One> NumArray2D for Array2<A> {
type Elem = A;
#[inline]
fn zeros(size: (usize, usize)) -> Self {
Self::zeros(size)
}
#[inline]
fn ones(size: (usize, usize)) -> Self {
Self::ones(size)
}
#[inline]
fn view(&self) -> ArrayView<Self::Elem, Ix2> {
self.view()
}
#[inline]
fn view_mut(&mut self) -> ArrayViewMut<Self::Elem, Ix2> {
self.view_mut()
}
}
impl<A: Element> NumArray1D for ExclusiveTensor<A, Ix1> {
type Elem = A;
#[inline]
fn zeros(size: usize) -> Self {
Self::zeros(size)
}
#[inline]
fn ones(size: usize) -> Self {
Self::ones(size)
}
#[inline]
fn as_slice(&self) -> &[Self::Elem] {
self.as_slice()
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [Self::Elem] {
self.as_slice_mut()
}
}
impl<A: Element> NumArray2D for ExclusiveTensor<A, Ix2> {
type Elem = A;
#[inline]
fn zeros(size: (usize, usize)) -> Self {
Self::zeros(size)
}
#[inline]
fn ones(size: (usize, usize)) -> Self {
Self::ones(size)
}
#[inline]
fn view(&self) -> ArrayView<Self::Elem, Ix2> {
self.array_view()
}
#[inline]
fn view_mut(&mut self) -> ArrayViewMut<Self::Elem, Ix2> {
self.array_view_mut()
}
}
/// A [`Tensor`] is built from a one-dimensional `f32` [`ExclusiveTensor`].
///
/// The supertrait bound on [`BuildFromArray1D`] requires
/// `Tensor: From<ExclusiveTensor<f32, Ix1>>`; that conversion is defined
/// outside this block.
impl BuildFromArray1D for Tensor {
    type Array = ExclusiveTensor<f32, Ix1>;
}
/// A [`Tensor`] is built from a two-dimensional `f32` [`ExclusiveTensor`].
///
/// The supertrait bound on [`BuildFromArray2D`] requires
/// `Tensor: From<ExclusiveTensor<f32, Ix2>>`; that conversion is defined
/// outside this block.
impl BuildFromArray2D for Tensor {
    type Array = ExclusiveTensor<f32, Ix2>;
}