//! Fixed-size tensor, matrix and vector types backed by plain arrays.
//!
//! Types are generated with the `tensor!` / `make_tensor!` macros; their shape
//! is given as literals at the definition site and exposed through the
//! `Tensor`, `Matrix`, `Vector` and `RowVector` traits.
/// Compile-time shape information for a fixed-size tensor type.
pub trait Tensor {
    /// Number of axes, i.e. how many entries `dims()` returns.
    const NDIM: usize;
    /// Total number of elements held by the tensor.
    const SIZE: usize;
    /// Returns the length of each axis.
    fn dims(&self) -> Vec<usize>;
}
/// Shape trait implemented by `tensor!(Name rows x cols)` for two-axis tensors.
pub trait Matrix {
    /// Number of rows: the first dimension passed to `tensor!`.
    const ROWS: usize;
    /// Number of columns: the second dimension passed to `tensor!`.
    const COLS: usize;
}
/// Shape trait implemented by `tensor!(Name n)` for one-axis tensors.
pub trait Vector {
    /// Number of elements in the vector.
    ///
    /// NOTE(review): `tensor!(Name n)` sets this to `n`. For a column vector
    /// that length is conventionally a row count; confirm the intended
    /// convention before relying on the name.
    const COLS: usize;
}
/// Shape trait implemented by `tensor!(Name row n)` for row vectors.
pub trait RowVector {
    /// Number of elements in the row vector.
    ///
    /// NOTE(review): `tensor!(Name row n)` sets this to `n`. For a row vector
    /// that length is conventionally a column count; confirm the intended
    /// convention before relying on the name.
    const ROWS: usize;
}
/// Error returned when a tensor cannot be built from the supplied data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorError {
    /// The input has fewer elements than the tensor holds.
    Size,
}

impl std::fmt::Display for TensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            TensorError::Size => f.write_str("not enough elements to fill the tensor"),
        }
    }
}

// Lets callers propagate the error through `Box<dyn Error>` and friends.
impl std::error::Error for TensorError {}
/// Defines a fixed-size tensor type `$name<T>` whose elements live in a single
/// flat array of `dim1 * dim2 * ...` values.
///
/// `make_tensor!(Name 2 x 3 x 4)` expands to `pub struct Name<T>([T; 1 * 2 * 3 * 4])`
/// plus implementations of `Tensor`, `PartialEq`, `Debug`, `Default`,
/// `TryFrom<&[T]>` and `TryFrom<Vec<T>>`. The macro only stores the elements;
/// it does not define an index order for the flat array.
///
/// Crate items are referenced through `$crate::` so the expansion resolves
/// even where the caller has not imported `Tensor` or `TensorError`.
#[macro_export]
macro_rules! make_tensor {
    ($name:ident $($dim:literal)x+ ) => {
        pub struct $name<T>([T; 1 $( * $dim )*]);

        impl<T> $crate::Tensor for $name<T> {
            const SIZE: usize = 1 $( * $dim )*;
            // One array entry per listed axis. Unlike summing `dim / dim`,
            // this does not divide by zero when an axis has length 0.
            const NDIM: usize = [$($dim),*].len();

            fn dims(&self) -> Vec<usize> {
                vec![$($dim),*]
            }
        }

        impl<T: PartialEq> PartialEq for $name<T> {
            fn eq(&self, other: &Self) -> bool {
                // Slice comparison works for arrays of any length.
                self.0[..] == other.0[..]
            }
        }

        impl<T: std::fmt::Debug> std::fmt::Debug for $name<T> {
            // Writes every element followed by one space, e.g. `1 2 3 `.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                for item in self.0.iter() {
                    write!(f, "{:?} ", item)?;
                }
                Ok(())
            }
        }

        impl<T: Default + Copy> Default for $name<T> {
            /// Returns a tensor with every element set to `T::default()`.
            fn default() -> Self {
                $name([T::default(); 1 $( * $dim )*])
            }
        }

        impl<T: Default + Copy> std::convert::TryFrom<&[T]> for $name<T> {
            type Error = $crate::TensorError;

            /// Copies the first `SIZE` elements of `v` into a new tensor.
            ///
            /// Trailing elements beyond `SIZE` are ignored. A slice shorter
            /// than `SIZE` yields `TensorError::Size`.
            fn try_from(v: &[T]) -> Result<Self, Self::Error> {
                let size = <Self as $crate::Tensor>::SIZE;
                let src = v.get(..size).ok_or($crate::TensorError::Size)?;
                let mut a = [T::default(); 1 $( * $dim )*];
                a.copy_from_slice(src);
                Ok($name(a))
            }
        }

        impl<T: Default + Copy> std::convert::TryFrom<Vec<T>> for $name<T> {
            type Error = $crate::TensorError;

            /// Same rules as the `&[T]` conversion: the first `SIZE` elements
            /// are copied and any extra elements are discarded with the vector.
            fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
                <Self as std::convert::TryFrom<&[T]>>::try_from(&v[..])
            }
        }
    };
}
/// Defines a tensor type with `make_tensor!` and, for one- and two-axis
/// shapes, also implements the matching shape trait:
///
/// * `tensor!(V 4)` also implements `Vector` with `COLS = 4`;
/// * `tensor!(R row 4)` also implements `RowVector` with `ROWS = 4`;
/// * `tensor!(M 2 x 3)` also implements `Matrix` with `ROWS = 2`, `COLS = 3`;
/// * any other shape gets only the `make_tensor!` items.
///
/// Arms are tried top to bottom, so a two-axis shape always reaches the
/// `Matrix` arm before the general one. This macro's own references use
/// `$crate::` paths so they resolve regardless of the caller's imports.
#[macro_export]
macro_rules! tensor {
    ($name:ident $dim:literal) => {
        $crate::make_tensor!($name $dim);
        impl<T> $crate::Vector for $name<T> {
            const COLS: usize = $dim;
        }
    };
    ($name:ident row $dim:literal) => {
        $crate::make_tensor!($name $dim);
        impl<T> $crate::RowVector for $name<T> {
            const ROWS: usize = $dim;
        }
    };
    ($name:ident $dim1:literal x $dim2:literal) => {
        $crate::make_tensor!($name $dim1 x $dim2);
        impl<T> $crate::Matrix for $name<T> {
            const ROWS: usize = $dim1;
            const COLS: usize = $dim2;
        }
    };
    ($name:ident $($dim:literal)x+ ) => {
        $crate::make_tensor!($name $($dim)x+);
    };
}
// Unit tests for the shape constants and the trait impls generated by
// `tensor!` / `make_tensor!`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    tensor!(T2345 2 x 3 x 4 x 5);
    tensor!(M23 2 x 3);
    tensor!(V4 4);
    tensor!(V2 row 2);

    #[test]
    fn tensor_dims() {
        assert_eq!(T2345::<u8>::SIZE, 2 * 3 * 4 * 5);
        assert_eq!(T2345::<u8>::NDIM, 4);
        assert_eq!(T2345::<u8>::default().dims(), vec![2, 3, 4, 5]);
    }

    #[test]
    fn matrix_dims() {
        assert_eq!(M23::<u8>::ROWS, 2);
        assert_eq!(M23::<u8>::COLS, 3);
        assert_eq!(M23::<u8>::NDIM, 2);
    }

    #[test]
    fn col_vector_size() {
        assert_eq!(V4::<u8>::COLS, 4);
        assert_eq!(V4::<u8>::SIZE, 4);
    }

    #[test]
    fn row_vector_size() {
        assert_eq!(V2::<u8>::ROWS, 2);
        assert_eq!(V2::<u8>::SIZE, 2);
    }

    #[test]
    fn default_equality_and_debug() {
        let zero = V4::<u8>::default();
        assert_eq!(zero, V4([0, 0, 0, 0]));
        assert!(zero != V4([0, 0, 0, 1]));
        // Debug prints each element followed by a single space.
        assert_eq!(format!("{:?}", V4([1u8, 2, 3, 4])), "1 2 3 4 ");
    }

    #[test]
    fn try_from_slice_and_vec() {
        let data: Vec<u8> = (1..=6).collect();
        let m = M23::<u8>::try_from(&data[..]).unwrap();
        assert_eq!(m, M23([1, 2, 3, 4, 5, 6]));
        assert_eq!(M23::<u8>::try_from(data.clone()).unwrap(), m);
        // Too few elements is rejected for both input kinds.
        assert_eq!(M23::<u8>::try_from(&data[..5]).err(), Some(TensorError::Size));
        assert_eq!(M23::<u8>::try_from(vec![1u8]).err(), Some(TensorError::Size));
    }
}