use crate::prelude::*;
use crate::StaticVecUnion;
use paste::paste;
use std::{
marker::PhantomData,
mem::{transmute, transmute_copy},
ops::DerefMut,
};
/// Generates `reshape_unchecked_ref` (invoked with no argument) or
/// `reshape_unchecked_ref_mut` (invoked with `mut`).
///
/// The generated method borrows `self` (shared or mutable, matching the argument) as a
/// `[T; LEN]` array and wraps it in a `Tensor` with the given shape and backend.
macro_rules! impl_reshape_unchecked_ref {
    ($($mut: tt)?) => {
        paste!{
            /// Views `self` as a tensor of `shape`, backed by a borrowed `[T; LEN]`.
            ///
            /// # Safety
            /// Unlike `reshape`, this does not assert that `shape.volume() == LEN`; the caller
            /// must guarantee it. The element pointer of `self` must address `LEN` valid
            /// elements for the lifetime `'a`.
            unsafe fn [<reshape_unchecked_ref $(_$mut)?>]<
                'a,
                B: crate::backends::Backend<T>,
                S: crate::tensor::Shape<NDIM>,
                const NDIM: usize,
            >(
                &'a $($mut)? self,
                shape: S,
                backend: B,
            ) -> crate::tensor::Tensor<T, & $($mut)? [T; LEN], B, NDIM, LEN, S>
            where
                Self: Sized,
            {
                // Reinterpret the element pointer as a pointer to the whole array and reborrow
                // it with the same mutability as the receiver (no `transmute` needed).
                let data: & $($mut)? [T; LEN] =
                    & $($mut)? *self.[<as $(_$mut)? _ptr>]().cast::<[T; LEN]>();
                crate::tensor::Tensor {
                    data: crate::backends::WithStaticBackend::from_static_vec(data, backend),
                    shape,
                }
            }
        }
    };
}
/// Generates the checked conversions `matrix*` and `reshape*`.
///
/// `impl_reshape!()` generates `matrix` / `reshape` taking `self` by value.
/// `impl_reshape!(pub _ref &'a)` (or `pub _mut_ref &'a mut`) generates `pub` methods named
/// `matrix_ref` / `reshape_ref` (resp. `matrix_mut_ref` / `reshape_mut_ref`); the trailing
/// tokens prefix both the receiver and the `Self` stored in the returned tensor.
macro_rules! impl_reshape {
    ($($pub: ident $name_suffix: ident $($t:tt)*)?) => {paste!{
        /// Views `self` as a matrix with shape `MatrixShape<M, K>`, using `B::default()` as
        /// the backend.
        ///
        /// # Panics
        /// Panics if `M * K != LEN`.
        $($pub)? fn [<matrix $($name_suffix)?>]<B: crate::backends::Backend<T>, const M: usize, const K: usize>(
            $($($t)*)? self,
        ) -> crate::tensor::Matrix<T, $($($t)*)? Self, B, LEN, false, MatrixShape<M, K>>
        where
            Self: Sized,
        {
            assert_eq!(
                M * K,
                LEN,
                "Cannot reshape vector of {} elements as matrix of {}",
                LEN,
                M * K
            );
            crate::tensor::Tensor {
                data: crate::backends::WithStaticBackend::from_static_vec(self, B::default()),
                shape: crate::tensor::MatrixShape::<M, K>,
            }
            .into()
        }

        /// Views `self` as a tensor with the given `shape`, stored with `backend`.
        ///
        /// # Panics
        /// Panics if `shape.volume() != LEN`.
        $($pub)? fn [<reshape $($name_suffix)?>]<
            B: crate::backends::Backend<T>,
            S: crate::tensor::Shape<NDIM>,
            const NDIM: usize,
        >(
            $($($t)*)? self,
            shape: S,
            backend: B,
        ) -> crate::tensor::Tensor<T, $($($t)*)? Self, B, NDIM, LEN, S>
        where
            Self: Sized,
        {
            assert_eq!(
                shape.volume(),
                LEN,
                "Cannot reshape vector with length {} as {:?}",
                LEN,
                shape.slice()
            );
            crate::tensor::Tensor {
                data: crate::backends::WithStaticBackend::from_static_vec(self, backend),
                shape,
            }
        }
    }};
}
pub trait StaticVec<T, const LEN: usize> {
unsafe fn as_ptr(&self) -> *const T;
unsafe fn as_mut_ptr(&mut self) -> *mut T {
transmute(self.as_ptr())
}
fn moo_ref<'a>(&'a self) -> StaticVecRef<'a, T, LEN>
where
T: Copy,
{
unsafe { &*(self.as_ptr() as *const StaticVecUnion<T, LEN>) }
}
fn mut_moo_ref<'a>(&'a mut self) -> MutStaticVecRef<'a, T, LEN>
where
T: Copy,
{
unsafe { &mut *(self.as_mut_ptr() as *mut StaticVecUnion<T, LEN>) }
}
fn moo<'a>(&'a self) -> StaticCowVec<'a, T, LEN>
where
T: Copy,
{
unsafe { StaticCowVec::from_ptr(self.as_ptr()) }
}
unsafe fn get_unchecked<'a>(&'a self, i: usize) -> &'a T {
&*self.as_ptr().add(i)
}
unsafe fn get_unchecked_mut<'a>(&'a mut self, i: usize) -> &'a mut T
where
T: Copy,
{
&mut *self.as_mut_ptr().add(i)
}
unsafe fn static_slice_unchecked<'a, const SLEN: usize>(&'a self, i: usize) -> &'a [T; SLEN] {
&*(self.as_ptr().add(i) as *const [T; SLEN])
}
unsafe fn mut_static_slice_unchecked<'a, const SLEN: usize>(
&'a mut self,
i: usize,
) -> &'a mut [T; SLEN] {
&mut *(self.as_ptr().add(i) as *mut [T; SLEN])
}
fn moo_owned(&self) -> StaticVecUnion<'static, T, LEN>
where
T: Copy,
Self: Sized,
{
unsafe { transmute_copy(self) }
}
fn static_backend<B: Backend<T> + Default>(
self,
) -> crate::backends::WithStaticBackend<T, Self, B, LEN>
where
Self: Sized,
{
crate::backends::WithStaticBackend {
data: self,
backend: B::default(),
_pd: PhantomData::<T>,
}
}
impl_reshape!();
impl_reshape_unchecked_ref!(mut);
impl_reshape_unchecked_ref!();
}
impl<'a, T: Copy, const LEN: usize> StaticVecUnion<'a, T, LEN> {
    // Mutably borrowing conversions: `pub fn matrix_mut_ref` / `pub fn reshape_mut_ref`,
    // both taking `&'a mut self`.
    impl_reshape!(pub _mut_ref &'a mut);
    // Shared borrowing conversions: `pub fn matrix_ref` / `pub fn reshape_ref`,
    // both taking `&'a self`.
    impl_reshape!(pub _ref &'a);
}
/// A `StaticVecUnion` exposes its elements through its `owned` array field.
impl<'a, T: Copy, const LEN: usize> StaticVec<T, LEN> for StaticVecUnion<'a, T, LEN> {
    unsafe fn as_ptr(&self) -> *const T {
        // Address of the `owned` array, reinterpreted as a pointer to its first element.
        let array: *const [T; LEN] = &self.owned;
        array.cast::<T>()
    }
}
/// Plain arrays store their `LEN` elements inline, so the element pointer is the array's own
/// address.
impl<T, const LEN: usize> StaticVec<T, LEN> for [T; LEN] {
    unsafe fn as_ptr(&self) -> *const T {
        self as *const T
    }

    /// Derives the pointer directly from `&mut self`, so writes through it are permitted.
    /// The trait's default would instead cast a pointer obtained from a shared reborrow.
    unsafe fn as_mut_ptr(&mut self) -> *mut T {
        self as *mut T
    }
}
/// Implements `StaticVec` for the reference-like wrappers.
///
/// With no argument it covers `&[T; LEN]`, `StaticVecRef` and `&StaticCowVec`. With `mut` it
/// covers `&mut [T; LEN]`, `MutStaticVecRef` and `&mut StaticCowVec`.
///
/// `as_ptr` always forwards to the pointee. `as_mut_ptr` forwards only for the `mut`
/// wrappers; for the shared wrappers it panics with a hint. The choice is made while the
/// macro expands, through the internal `@as_mut_ptr` rules, rather than by comparing
/// `stringify!` output at run time.
macro_rules! impl_vec_for_refs {
    // Internal rule: a mutable wrapper forwards `as_mut_ptr` to its pointee.
    (@as_mut_ptr [mut] $msg:literal) => {
        unsafe fn as_mut_ptr(&mut self) -> *mut T {
            (*self).as_mut_ptr()
        }
    };
    // Internal rule: a shared wrapper cannot hand out a mutable pointer.
    (@as_mut_ptr [] $msg:literal) => {
        unsafe fn as_mut_ptr(&mut self) -> *mut T {
            panic!($msg)
        }
    };
    ($($mut: tt)?) => {
        impl<T, const LEN: usize> StaticVec<T, LEN> for & $($mut)? [T; LEN] {
            unsafe fn as_ptr(&self) -> *const T {
                (**self).as_ptr()
            }
            impl_vec_for_refs!(@as_mut_ptr [$($mut)?]
                "Cannot get mutable pointer from &[T; LEN]. Maybe try &mut [T; LEN] instead.");
        }
        impl<'a, T: Copy, const LEN: usize> StaticVec<T, LEN> for paste!([<$($mut:camel)? StaticVecRef>]<'a, T, LEN>) {
            unsafe fn as_ptr(&self) -> *const T {
                (**self).as_ptr()
            }
            impl_vec_for_refs!(@as_mut_ptr [$($mut)?]
                "Cannot get mutable pointer from StaticVecRef<'a, T, LEN>. Maybe try MutStaticVecRef<'a, T, LEN> instead.");
        }
        impl<'a, T: Copy, const LEN: usize> StaticVec<T, LEN> for & $($mut)? StaticCowVec<'a, T, LEN> {
            unsafe fn as_ptr(&self) -> *const T {
                (**self).as_ptr()
            }
            impl_vec_for_refs!(@as_mut_ptr [$($mut)?]
                "Cannot get mutable pointer from &StaticCowVec<'a, T, LEN>. Maybe try &mut StaticCowVec<'a, T, LEN> instead.");
        }
    };
}
// `StaticVec` for `&mut [T; LEN]`, `MutStaticVecRef` and `&mut StaticCowVec`.
impl_vec_for_refs!(mut);
// `StaticVec` for `&[T; LEN]`, `StaticVecRef` and `&StaticCowVec`.
impl_vec_for_refs!();
/// `StaticCowVec` reads its elements from `data`: through `data`'s own `as_ptr` when
/// `is_owned` is set, otherwise from the `borrowed` reference.
impl<'a, T: Copy, const LEN: usize> StaticVec<T, LEN> for StaticCowVec<'a, T, LEN> {
    unsafe fn as_ptr(&self) -> *const T {
        if self.is_owned {
            self.data.as_ptr()
        } else {
            self.data.borrowed as *const T
        }
    }

    /// In the borrowed case the pointer is taken from `mut_moo_ref`, i.e. from this type's
    /// `DerefMut` implementation, never from the shared `borrowed` reference.
    unsafe fn as_mut_ptr(&mut self) -> *mut T {
        if self.is_owned {
            self.data.as_mut_ptr()
        } else {
            // A reference-to-pointer cast; `transmute` is unnecessary here.
            self.mut_moo_ref() as *mut _ as *mut T
        }
    }

    /// Mutable view obtained through `DerefMut::deref_mut`.
    fn mut_moo_ref<'b>(&'b mut self) -> MutStaticVecRef<'b, T, LEN>
    where
        T: Copy,
    {
        // NOTE(review): `transmute` reinterprets `deref_mut`'s result as a `MutStaticVecRef`.
        // This assumes `DerefMut::Target` is layout-compatible with `StaticVecUnion<T, LEN>`;
        // confirm against the `DerefMut` impl.
        unsafe { transmute(self.deref_mut()) }
    }
}