/// Build a `Tensor` from a literal-style description.
///
/// Three forms are accepted:
///
/// - `tensor!([a, b, c])` expands to `Tensor::from_vec(vec![a, b, c])`.
/// - `tensor!((d0, d1, ...); [a, b, c, ...])` expands to
///   `Tensor::from_data(&[d0, d1, ...], vec![a, b, c, ...])`.
/// - `tensor!(x)` expands to `Tensor::from_scalar(x)`.
///
/// A trailing comma after the last element is allowed in both bracketed
/// element lists.
#[macro_export]
macro_rules! tensor {
    // Vector form. `$(,)?` accepts an optional trailing comma, so no
    // recursive call through the (call-site resolved) macro name is needed.
    ([$($elem:expr),* $(,)?]) => {
        // The type is named via `$crate::` instead of a block-local `use`:
        // `use` items are not hygienic, so an import here would shadow any
        // caller item called `Tensor` referenced by the element expressions.
        $crate::Tensor::from_vec(vec![$($elem),*])
    };
    // Explicit-shape form: `(dims...); [elements...]`.
    (($($dim:expr),+); [$($elem:expr),* $(,)?]) => {
        $crate::Tensor::from_data(&[$($dim),+], vec![$($elem),*])
    };
    // Scalar form. Kept last so that bracketed input is claimed by the
    // vector arm before it can be parsed as an array expression.
    ($elem:expr) => {
        $crate::Tensor::from_scalar($elem)
    };
}
/// Build an `NdTensor` from a literal-style description.
///
/// Three forms are accepted:
///
/// - `ndtensor!([a, b, c])` expands to `NdTensor::from_data([len], data)`,
///   where `data` is `vec![a, b, c]` and `len` is `data.len()`.
/// - `ndtensor!((d0, d1, ...); [a, b, c, ...])` expands to
///   `NdTensor::from_data([d0, d1, ...], vec![a, b, c, ...])`.
/// - `ndtensor!(x)` expands to `NdTensor::from_data([], vec![x])`.
///
/// A trailing comma after the last element is allowed in both bracketed
/// element lists.
#[macro_export]
macro_rules! ndtensor {
    // Vector form. `$(,)?` accepts an optional trailing comma, so no
    // recursive call through the (call-site resolved) macro name is needed.
    ([$($elem:expr),* $(,)?]) => {
        {
            let data = vec![$($elem),*];
            let len = data.len();
            // The type is named via `$crate::` instead of a block-local
            // `use`: `use` items are not hygienic, so an import here would
            // shadow any caller item called `NdTensor` referenced by the
            // element expressions.
            $crate::NdTensor::from_data([len], data)
        }
    };
    // Explicit-shape form: `(dims...); [elements...]`.
    (($($dim:expr),+); [$($elem:expr),* $(,)?]) => {
        {
            // Bound first so the element expressions are evaluated before
            // the shape expressions.
            let data = vec![$($elem),*];
            $crate::NdTensor::from_data([$($dim),+], data)
        }
    };
    // Scalar form: an empty shape with a single element. Kept last so that
    // bracketed input is claimed by the vector arm before it can be parsed
    // as an array expression.
    ($elem:expr) => {
        $crate::NdTensor::from_data([], vec![$elem])
    };
}
// Unit tests for the `tensor!` and `ndtensor!` macros. Each test compares a
// macro invocation against the equivalent explicit constructor call.
#[cfg(test)]
mod tests {
    use crate::{ndtensor, tensor, NdTensor, Tensor};

    #[test]
    fn test_tensor_scalar() {
        let x = tensor!(5.);
        assert_eq!(x, Tensor::from_scalar(5.));
    }

    #[test]
    fn test_tensor_vector() {
        let x = tensor!([1, 2, 3]);
        assert_eq!(x, Tensor::from_vec(vec![1, 2, 3]));

        // A trailing comma after the last element must give the same result.
        let x = tensor!([1, 2, 3,]);
        assert_eq!(x, Tensor::from_vec(vec![1, 2, 3]));
    }

    #[test]
    fn test_tensor_nd() {
        let x = tensor!((1, 2, 2); [1, 2, 3, 4]);
        assert_eq!(x, Tensor::from_data(&[1, 2, 2], vec![1, 2, 3, 4]));

        // Same input with a trailing comma in the element list.
        let x = tensor!((1, 2, 2); [1, 2, 3, 4,]);
        assert_eq!(x, Tensor::from_data(&[1, 2, 2], vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_ndtensor_scalar() {
        let x = ndtensor!(5.);
        assert_eq!(x, NdTensor::from_data([], vec![5.]));
    }

    #[test]
    fn test_ndtensor_vector() {
        let x = ndtensor!([1, 2, 3]);
        assert_eq!(x, NdTensor::from_data([3], vec![1, 2, 3]));

        // A trailing comma after the last element must give the same result.
        let x = ndtensor!([1, 2, 3,]);
        assert_eq!(x, NdTensor::from_data([3], vec![1, 2, 3]));
    }

    #[test]
    fn test_ndtensor_nd() {
        let x = ndtensor!((1, 2, 2); [1, 2, 3, 4]);
        assert_eq!(x, NdTensor::from_data([1, 2, 2], vec![1, 2, 3, 4]));

        // Same input with a trailing comma in the element list.
        let x = ndtensor!((1, 2, 2); [1, 2, 3, 4,]);
        assert_eq!(x, NdTensor::from_data([1, 2, 2], vec![1, 2, 3, 4]));
    }
}