use std::marker::PhantomData;
use furiosa_mapping::*;
use furiosa_opt_lower::config_tile;
use super::Tensor;
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::raw::RawTensor;
/// Mutable, borrowed view into a backend raw tensor (`B::RawTensor<D>`).
///
/// The view exclusively borrows the underlying storage for `'l` and pairs it with
/// an `offset` that [`TensorViewMut::tile`] extends. `Mapping` and `B` are carried
/// only at the type level; `Mapping` is the mapping the view passes to the raw
/// tensor when it is written through [`TensorViewMut::write_transpose`].
pub struct TensorViewMut<'l, D: Scalar, Mapping: M, B: Backend = CurrentBackend> {
    /// Exclusively borrowed storage that `write_transpose` writes into.
    inner: &'l mut B::RawTensor<D>,
    /// Offset handed to the raw tensor alongside `inner`. It starts as `Index::new()`
    /// and gains one `add_mapping` entry per `tile` call.
    offset: Index,
    /// Ties `Mapping` and `B` to the type without storing values of them.
    _marker: PhantomData<(Mapping, B)>,
}
impl<'l, D: Scalar, Mapping: M, B: Backend> std::fmt::Debug for TensorViewMut<'l, D, Mapping, B>
where
    B::RawTensor<D>: std::fmt::Debug,
{
    /// Prints the borrowed storage and the offset.
    ///
    /// The `Mapping`/`B` type markers hold no data and are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, offset, .. } = self;
        let mut builder = f.debug_struct("TensorViewMut");
        builder.field("inner", inner);
        builder.field("offset", offset);
        builder.finish()
    }
}
/// Shared, borrowed view into a backend raw tensor (`B::RawTensor<D>`).
///
/// Like [`TensorViewMut`] but read-only: it holds a shared borrow of the storage
/// plus an `offset` that [`TensorView::tile`] extends. It is the source argument of
/// [`TensorViewMut::write_transpose`], and a `TensorViewMut` converts into it via `From`.
/// `Clone` only copies the reference and clones the offset.
pub struct TensorView<'l, D: Scalar, Mapping: M, B: Backend = CurrentBackend> {
    /// Shared borrow of the storage this view reads from.
    inner: &'l B::RawTensor<D>,
    /// Offset handed to the raw tensor alongside `inner`. It starts as `Index::new()`
    /// and gains one `add_mapping` entry per `tile` call.
    offset: Index,
    /// Ties `Mapping` and `B` to the type without storing values of them.
    _marker: PhantomData<(Mapping, B)>,
}
impl<'l, D: Scalar, Mapping: M, B: Backend> std::fmt::Debug for TensorView<'l, D, Mapping, B>
where
    B::RawTensor<D>: std::fmt::Debug,
{
    /// Prints the borrowed storage and the offset.
    ///
    /// The `Mapping`/`B` type markers hold no data and are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, offset, .. } = self;
        let mut builder = f.debug_struct("TensorView");
        builder.field("inner", inner);
        builder.field("offset", offset);
        builder.finish()
    }
}
impl<'l, D: Scalar, Mapping: M, B: Backend> Clone for TensorView<'l, D, Mapping, B> {
    /// Duplicates the view.
    ///
    /// The shared borrow is copied and only the offset is cloned, so the storage
    /// itself is never copied. The impl is written by hand, presumably so it does
    /// not require `Clone` on `D`, `Mapping` or `B` the way `#[derive(Clone)]` would.
    fn clone(&self) -> Self {
        let offset = self.offset.clone();
        Self {
            _marker: PhantomData,
            offset,
            inner: self.inner,
        }
    }
}
impl<'l, D: Scalar, Mapping: M, B: Backend> From<TensorViewMut<'l, D, Mapping, B>> for TensorView<'l, D, Mapping, B> {
fn from(view: TensorViewMut<'l, D, Mapping, B>) -> Self {
Self {
inner: view.inner,
offset: view.offset,
_marker: PhantomData,
}
}
}
impl<'l, D: Scalar, E: M, B: Backend> TensorViewMut<'l, D, E, B> {
pub(crate) fn new(inner: &'l mut B::RawTensor<D>) -> Self {
Self {
inner,
offset: Index::new(),
_marker: PhantomData,
}
}
pub fn tile<I: M, E2: M, const LEN: usize>(self, start: usize) -> TensorViewMut<'l, D, E2, B> {
config_tile(&I::to_value(), &E::to_value(), &E2::to_value(), LEN).unwrap_or_else(|e| panic!("{e}"));
let mut offset = self.offset;
offset.add_mapping::<I>(start);
TensorViewMut {
inner: self.inner,
offset,
_marker: PhantomData,
}
}
pub fn write_transpose<'lsrc, Src: M>(&mut self, src: TensorView<'lsrc, D, Src, B>, allow_broadcast: bool) {
self.inner
.write_transpose::<Src, E>(src.inner, &src.offset, &self.offset, allow_broadcast);
}
}
impl<'l, D: Scalar, E: M, B: Backend> TensorView<'l, D, E, B> {
    /// Returns a shared view of a tile with mapping `E2`, leaving `self` usable.
    ///
    /// The tiling is validated by `config_tile` using the tile mapping `I`, the
    /// current mapping `E`, the new mapping `E2` and `LEN`. The new view shares
    /// `self`'s storage and gets a cloned offset with `start` recorded for `I`.
    ///
    /// # Panics
    ///
    /// Panics with the error's `Display` text if `config_tile` rejects the tiling.
    pub fn tile<I: M, E2: M, const LEN: usize>(&self, start: usize) -> TensorView<'l, D, E2, B> {
        if let Err(err) = config_tile(&I::to_value(), &E::to_value(), &E2::to_value(), LEN) {
            panic!("{err}");
        }
        let mut tiled_offset = self.offset.clone();
        tiled_offset.add_mapping::<I>(start);
        TensorView {
            inner: self.inner,
            offset: tiled_offset,
            _marker: PhantomData,
        }
    }
}
impl<'l, D: Scalar, E: M, B: Backend> TensorView<'l, D, E, B> {
    /// Copies the viewed data into a newly created, owned [`Tensor`] with the same
    /// mapping `E`.
    ///
    /// The destination starts as `Tensor::uninit()`. It is then filled by a single
    /// `write_transpose` from this view with broadcasting disabled.
    pub fn read(self) -> Tensor<D, E, B> {
        let mut owned: Tensor<D, E, B> = Tensor::uninit();
        {
            // Scope the mutable destination view so its borrow of `owned` ends
            // before `owned` is returned.
            let mut destination = owned.view_mut();
            destination.write_transpose(self, false);
        }
        owned
    }
}
impl<'l, D: Scalar, E: M, B: Backend> TensorView<'l, D, E, B> {
    /// Wraps `inner` in a shared view whose offset is `Index::new()`, i.e. no
    /// `tile` has been applied yet.
    pub(crate) fn new(inner: &'l B::RawTensor<D>) -> Self {
        let offset = Index::new();
        Self { inner, offset, _marker: PhantomData }
    }
}