/// Builds a `Tensor` of `f32` values from a nested, bracketed array literal.
///
/// One, two and three levels of nesting are accepted and yield tensors of
/// shape `[len]`, `[rows, cols]` and `[mats, rows, cols]` respectively. The
/// elements are flattened in the order they are written (row-major) and passed
/// to `Tensor::from_vec` together with the inferred shape.
///
/// Each element may be any expression whose type implements
/// `num_traits::ToPrimitive`, so integer and floating-point values can be mixed.
/// Every element expression is evaluated exactly once, left to right.
///
/// Empty brackets are accepted where they are unambiguous: `tensor!([])` passes
/// shape `[0]` and `tensor!([[], []])` passes shape `[2, 0]`.
///
/// # Examples
///
/// ```ignore
/// let m = tensor!([[1, 2, 3], [4, 5, 6]]);
/// assert_eq!(m.shape(), &[2, 3]);
///
/// let t = tensor!([[[1, 2, 3], [4, 5, 6]]]);
/// assert_eq!(t.shape(), &[1, 2, 3]);
/// ```
///
/// # Panics
///
/// - If an element cannot be represented as `f32` (`ToPrimitive::to_f32`
///   returns `None`).
/// - If the literal is jagged: every row must have as many elements as the
///   first row, and (for 3-D input) every matrix as many rows as the first.
///
/// Any further validation is up to `Tensor::from_vec`.
///
/// NOTE(review): the expansion names `::num_traits` directly, so every crate
/// that invokes this macro must itself depend on `num_traits`; re-exporting it
/// and using a `$crate::` path would lift that requirement.
#[macro_export]
macro_rules! tensor {
    // Internal rule: the number of comma-separated token trees, as a `usize`.
    // `stringify!` turns each tree into a string literal, so elements are
    // counted without being evaluated; zero trees yield 0.
    (@count $( $x:tt ),*) => {
        <[&str]>::len(&[ $( stringify!($x) ),* ])
    };

    // 3-D: a non-empty list of matrices, each a non-empty list of rows.
    // Requiring at least one row per matrix leaves `[[], []]` to the 2-D arm.
    ( [ $( [ $( [ $( $x:expr ),* ] ),+ ] ),+ ] ) => {{
        // Rows in each matrix, and elements in each row across all matrices.
        let mat_lens: &[usize] = &[ $( $crate::tensor!(@count $( [ $( $x ),* ] ),+) ),+ ];
        let row_lens: &[usize] = &[ $( $( $crate::tensor!(@count $( $x ),*), )+ )+ ];
        // Both slices are non-empty: the pattern demands at least one matrix
        // holding at least one row.
        let rows = mat_lens[0];
        let cols = row_lens[0];
        assert!(
            mat_lens.iter().all(|&n| n == rows),
            "tensor!: every matrix must have the same number of rows, got {:?}",
            mat_lens
        );
        assert!(
            row_lens.iter().all(|&n| n == cols),
            "tensor!: every row must have the same length, got {:?}",
            row_lens
        );
        // Trailing commas keep the expansion valid even when a row is empty.
        let data: Vec<f32> = vec![
            $( $( $(
                ::num_traits::ToPrimitive::to_f32(&$x).expect("Failed to convert to f32"),
            )* )+ )+
        ];
        let shape = vec![mat_lens.len(), rows, cols];
        $crate::tensor::Tensor::from_vec(data, shape)
    }};

    // 2-D: a non-empty list of rows, each possibly empty.
    ( [ $( [ $( $x:expr ),* ] ),+ ] ) => {{
        let row_lens: &[usize] = &[ $( $crate::tensor!(@count $( $x ),*) ),+ ];
        // Non-empty: the pattern demands at least one row.
        let cols = row_lens[0];
        assert!(
            row_lens.iter().all(|&n| n == cols),
            "tensor!: every row must have the same length, got {:?}",
            row_lens
        );
        // Trailing commas keep the expansion valid even when a row is empty.
        let data: Vec<f32> = vec![
            $( $(
                ::num_traits::ToPrimitive::to_f32(&$x).expect("Failed to convert to f32"),
            )* )+
        ];
        let shape = vec![row_lens.len(), cols];
        $crate::tensor::Tensor::from_vec(data, shape)
    }};

    // 1-D: a possibly empty list of elements.
    ( [ $( $x:expr ),* ] ) => {{
        let data: Vec<f32> = vec![
            $( ::num_traits::ToPrimitive::to_f32(&$x).expect("Failed to convert to f32") ),*
        ];
        let shape = vec![data.len()];
        $crate::tensor::Tensor::from_vec(data, shape)
    }};
}
#[cfg(test)]
mod tests {
    // No `use` of `Tensor` is needed: `tensor!` expands to the fully qualified
    // `$crate::tensor::Tensor::from_vec`, and these tests only call methods on
    // the value it returns.

    #[test]
    fn test_tensor_macro_1d() {
        let t = tensor!([1, 2, 3, 4]);
        assert_eq!(t.shape(), &[4]);
        assert_eq!(t.as_slice(), Some(&[1.0, 2.0, 3.0, 4.0][..]));
    }

    #[test]
    fn test_tensor_macro_2d() {
        let t = tensor!([[1, 2, 3], [4, 5, 6]]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.as_slice(), Some(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0][..]));
    }

    #[test]
    fn test_tensor_macro_2d_tall() {
        // More rows than columns, so row and column counts cannot be confused.
        let t = tensor!([[1, 2], [3, 4], [5, 6]]);
        assert_eq!(t.shape(), &[3, 2]);
        assert_eq!(t.as_slice(), Some(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0][..]));
    }

    #[test]
    fn test_tensor_macro_3d() {
        let t = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        assert_eq!(t.shape(), &[2, 2, 2]);
        assert_eq!(
            t.as_slice(),
            Some(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0][..])
        );
    }

    #[test]
    fn test_tensor_macro_mixed_types() {
        let t = tensor!([1, 2.5, 3, 4.2]);
        assert_eq!(t.shape(), &[4]);
        assert_eq!(t.as_slice(), Some(&[1.0, 2.5, 3.0, 4.2][..]));
    }

    #[test]
    fn test_tensor_macro_expression_elements() {
        // Elements are arbitrary expressions, not just literals.
        let x = 2;
        let t = tensor!([x, x * 2, 7 - 1]);
        assert_eq!(t.shape(), &[3]);
        assert_eq!(t.as_slice(), Some(&[2.0, 4.0, 6.0][..]));
    }
}