use crate::tensors::Dimension;
use crate::tensors::views::{DataLayout, TensorRef};
use std::marker::PhantomData;
/// A view that presents the elements of a tensor source as a different element type by passing
/// every element reference through a mapping function.
///
/// `S` is the underlying source. The impls in this file require `S: TensorRef<T, D>` and
/// `F: Fn(&T) -> &U`. Elements of type `T` are looked up in `source` and handed to `f`, which
/// returns the `&U` exposed by this view.
#[derive(Debug)]
pub(crate) struct TensorMap<T, U, S, F, const D: usize> {
    /// The wrapped source the element references are read from.
    source: S,
    /// Maps a reference to a source element to a reference to the exposed element.
    f: F,
    // `T` and `U` only appear in the bounds on `F` in the impls, so these markers are what
    // make them used type parameters of the struct itself.
    _from: PhantomData<T>,
    _to: PhantomData<U>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would also demand `T: Clone` and
// `U: Clone`, although only `source` and `f` hold data. The phantom markers are always
// cloneable, so cloning only needs `S: Clone` and `F: Clone`.
impl<T, U, S, F, const D: usize> Clone for TensorMap<T, U, S, F, D>
where
    S: Clone,
    F: Clone,
{
    fn clone(&self) -> Self {
        TensorMap {
            source: self.source.clone(),
            f: self.f.clone(),
            _from: PhantomData,
            _to: PhantomData,
        }
    }
}
impl<T, U, S, F, const D: usize> TensorMap<T, U, S, F, D>
where
    S: TensorRef<T, D>,
    F: Fn(&T) -> &U,
{
    /// Builds a mapped view over `source`, using `f` to turn each `&T` element reference into
    /// the `&U` that this view exposes.
    ///
    /// `f` is stored as-is and is not called here.
    #[track_caller]
    pub fn from(source: S, f: F) -> TensorMap<T, U, S, F, D> {
        let (_from, _to) = (PhantomData, PhantomData);
        Self {
            source,
            f,
            _from,
            _to,
        }
    }

    /// Consumes the view and returns the wrapped source; the mapping function is dropped.
    // The allowance keeps this accessor available without a warning when nothing in the crate
    // happens to call it.
    #[allow(dead_code)]
    pub fn source(self) -> S {
        let Self { source, .. } = self;
        source
    }

    /// Borrows the wrapped source.
    // See `source` for why dead_code is allowed.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        let Self { source, .. } = self;
        source
    }

    /// Mutably borrows the wrapped source.
    // See `source` for why dead_code is allowed.
    #[allow(dead_code)]
    pub fn source_ref_mut(&mut self) -> &mut S {
        let Self { source, .. } = self;
        source
    }
}
unsafe impl<T, U, S, F, const D: usize> TensorRef<U, D> for TensorMap<T, U, S, F, D>
where
S: TensorRef<T, D>,
F: Fn(&T) -> &U,
{
fn get_reference(&self, indexes: [usize; D]) -> Option<&U> {
Some((self.f)(self.source.get_reference(indexes)?))
}
fn view_shape(&self) -> [(Dimension, usize); D] {
self.source.view_shape()
}
unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &U {
unsafe { (self.f)(self.source.get_reference_unchecked(indexes)) }
}
fn data_layout(&self) -> DataLayout<D> {
self.source.data_layout()
}
}