use core::fmt::Debug;
use sophus_autodiff::{
linalg::{
NumberCategory,
SMat,
SVec,
},
prelude::*,
};
/// A tensor element whose scalar type, rank and row/column counts are fixed at
/// compile time through the trait's generic parameters.
///
/// * `Scalar` - scalar type stored in the tensor.
/// * `SRANK`  - static rank, i.e. the length of an index into the tensor.
/// * `ROWS`   - static number of rows.
/// * `COLS`   - static number of columns.
pub trait IsStaticTensor<Scalar, const SRANK: usize, const ROWS: usize, const COLS: usize>:
    Clone + Debug + num_traits::Zero
where
    Scalar: IsCoreScalar + 'static,
{
    /// Builds the tensor from a slice of scalars.
    fn from_slice(slice: &[Scalar]) -> Self;

    /// Returns a reference to the scalar addressed by the multi-index `idx`.
    fn scalar(&self, idx: [usize; SRANK]) -> &Scalar;

    /// Returns the static dimensions, one entry per rank.
    fn sdims() -> [usize; SRANK];

    /// Returns the strides, one entry per rank.
    fn get_strides() -> [usize; SRANK];

    /// Returns the total number of scalars, `ROWS * COLS`.
    fn num_scalars() -> usize {
        ROWS * COLS
    }

    /// Returns the static rank `SRANK`.
    fn rank(&self) -> usize {
        SRANK
    }

    /// Returns the static number of rows `ROWS`.
    fn num_rows(&self) -> usize {
        ROWS
    }

    /// Returns the static number of columns `COLS`.
    fn num_cols(&self) -> usize {
        COLS
    }
}
/// Rank-0 case: a bare scalar is treated as a 1x1 static tensor.
impl<Scalar> IsStaticTensor<Scalar, 0, 1, 1> for Scalar
where
    Scalar: IsCoreScalar + 'static,
{
    /// Takes the first element of `slice`; any further elements are not read.
    ///
    /// # Panics
    ///
    /// Panics if `slice` is empty.
    fn from_slice(slice: &[Scalar]) -> Self {
        let first = &slice[0];
        first.clone()
    }

    /// The scalar is its own only element; the empty index is ignored.
    fn scalar(&self, _idx: [usize; 0]) -> &Scalar {
        self
    }

    /// A rank-0 tensor has no dimensions.
    fn sdims() -> [usize; 0] {
        [0usize; 0]
    }

    /// A rank-0 tensor has no strides.
    fn get_strides() -> [usize; 0] {
        [0usize; 0]
    }
}
/// Rank-1 case: a static vector with `ROWS` entries and a single column.
impl<Scalar, const ROWS: usize> IsStaticTensor<Scalar, 1, ROWS, 1> for SVec<Scalar, ROWS>
where
    Scalar: IsCoreScalar + 'static,
{
    /// Fills the vector from clones of the scalars in `slice`, in slice order.
    fn from_slice(slice: &[Scalar]) -> Self {
        let scalars = slice.iter().cloned();
        Self::from_iterator(scalars)
    }

    /// Returns a reference to the entry at row `idx[0]`.
    fn scalar(&self, idx: [usize; 1]) -> &Scalar {
        let [row] = idx;
        &self[row]
    }

    /// The single dimension is the number of rows.
    fn sdims() -> [usize; 1] {
        [ROWS]
    }

    /// The single stride is `1`.
    fn get_strides() -> [usize; 1] {
        [1]
    }
}
/// Rank-2 case: a static `ROWS` x `COLS` matrix.
impl<Scalar, const ROWS: usize, const COLS: usize> IsStaticTensor<Scalar, 2, ROWS, COLS>
    for SMat<Scalar, ROWS, COLS>
where
    Scalar: IsCoreScalar + 'static,
{
    /// Fills the matrix from clones of the scalars in `slice`, in the order
    /// in which `from_iterator` consumes them.
    fn from_slice(slice: &[Scalar]) -> Self {
        let scalars = slice.iter().cloned();
        Self::from_iterator(scalars)
    }

    /// Returns a reference to the entry at row `idx[0]`, column `idx[1]`.
    fn scalar(&self, idx: [usize; 2]) -> &Scalar {
        let [row, col] = idx;
        &self[(row, col)]
    }

    /// Dimensions are `[ROWS, COLS]`.
    fn sdims() -> [usize; 2] {
        [ROWS, COLS]
    }

    /// Strides are `[1, ROWS]`: unit stride along rows and `ROWS` along
    /// columns, which matches a column-major layout.
    fn get_strides() -> [usize; 2] {
        [1, ROWS]
    }
}
/// Runtime description of a static tensor's format: scalar category, scalar
/// size, batch size and row/column counts.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct STensorFormat {
/// Number category of the scalar type, as reported by `number_category()`.
pub number_category: NumberCategory,
/// Size of a single scalar in bytes.
pub num_bytes_per_scalar: usize,
/// Batch size.
pub batch_size: usize,
/// Number of rows.
pub num_rows: usize,
/// Number of columns.
pub num_cols: usize,
}
impl STensorFormat {
    /// Builds the format descriptor for scalar type `Scalar` with `ROWS` rows,
    /// `COLS` columns and batch size `BATCH`.
    ///
    /// The scalar size is taken from `core::mem::size_of::<Scalar>()`.
    pub fn new<Scalar, const ROWS: usize, const COLS: usize, const BATCH: usize>() -> Self
    where
        Scalar: IsCoreScalar + 'static,
    {
        Self {
            number_category: Scalar::number_category(),
            num_bytes_per_scalar: core::mem::size_of::<Scalar>(),
            batch_size: BATCH,
            num_rows: ROWS,
            num_cols: COLS,
        }
    }

    /// Returns `num_rows * num_cols * num_bytes_per_scalar`.
    ///
    /// `batch_size` is not a factor of this product.
    /// NOTE(review): presumably batch lanes are already covered by
    /// `num_bytes_per_scalar` when `Scalar` is a batch type - confirm.
    pub fn num_bytes(&self) -> usize {
        let num_scalars = self.num_rows * self.num_cols;
        num_scalars * self.num_bytes_per_scalar
    }
}
/// Checks number categories of primitive scalars and the `IsStaticTensor`
/// accessors on static vectors and matrices (plus batch vectors with `simd`).
#[test]
fn test_elements() {
    use approx::assert_abs_diff_eq;
    #[cfg(feature = "simd")]
    use sophus_autodiff::linalg::{
        BatchScalar,
        BatchScalarF64,
        BatchVecF64,
        IsScalar,
    };
    use sophus_autodiff::linalg::{
        NumberCategory,
        VecF32,
    };

    // Number categories of the primitive scalar types.
    assert_eq!(f32::number_category(), NumberCategory::Real);
    assert_eq!(u32::number_category(), NumberCategory::Unsigned);
    assert_eq!(i32::number_category(), NumberCategory::Signed);
    #[cfg(feature = "simd")]
    assert_eq!(
        BatchScalar::<f64, 4>::number_category(),
        NumberCategory::Real
    );

    // A vector built from an all-zero slice holds only zeros.
    let zeros_vec = <VecF32<4> as IsStaticTensor<f32, 1, 4, 1>>::from_slice(&[0.0f32; 4]);
    zeros_vec.iter().for_each(|elem| assert_eq!(*elem, 0.0));

    let vec = SVec::<f32, 3>::new(1.0, 2.0, 3.0);
    assert_abs_diff_eq!(vec, SVec::<f32, 3>::new(1.0, 2.0, 3.0));

    // Element access on a 2x2 matrix, addressed as [row, col].
    let mat = SMat::<f32, 2, 2>::new(1.0, 2.0, 3.0, 4.0);
    let expected_entries: [([usize; 2], f32); 4] = [
        ([0, 0], 1.0),
        ([0, 1], 2.0),
        ([1, 0], 3.0),
        ([1, 1], 4.0),
    ];
    for (idx, expected) in expected_entries {
        assert_eq!(mat.scalar(idx), &expected);
    }
    assert_abs_diff_eq!(mat, SMat::<f32, 2, 2>::new(1.0, 2.0, 3.0, 4.0));

    #[cfg(feature = "simd")]
    {
        // Every row holds the same batch scalar with lanes (1.0, 2.0).
        let lanes: [f64; 2] = [1.0, 2.0];
        let batch_vec: BatchVecF64<2, 2> =
            BatchVecF64::from_element(BatchScalarF64::from_real_array(lanes));
        assert_eq!(batch_vec.scalar([0]).extract_single(0), lanes[0]);
        assert_eq!(batch_vec.scalar([1]).extract_single(1), lanes[1]);
    }
}