use std::marker::PhantomData;
use furiosa_mapping::{Index, M, MappingExt, Term};
use ndarray::IxDyn;
use crate::engine::vector::operand::OperandTag;
use crate::engine::vector::scalar::VeScalar;
use crate::runtime::op_prep::{assert_zip, gather_params, reduce_broadcast, scatter_params, transpose_broadcast};
use crate::scalar::{Opt, Scalar};
use crate::tensor::BufferConvertError;
use crate::tensor::raw::{RawTensor, RawTensorOpt, gen_axes};
/// Tensor backend that records only its axis terms and holds no element data.
///
/// Reads through the `RawTensor` impl always yield `Opt::Uninit`, and writes are
/// discarded. The element type `D` is carried purely at the type level via
/// `PhantomData`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[doc(hidden)]
pub struct PhantomRawTensor<D: Scalar> {
    // Axis terms of this tensor; this is the only runtime state the type keeps.
    axes: Vec<Term>,
    // Binds the struct to `D` without storing any `D` values.
    _phantom: PhantomData<D>,
}
impl<D: Scalar> RawTensor<D> for PhantomRawTensor<D> {
fn axes(&self) -> &[Term] {
&self.axes
}
fn uninit_from_axes(axes: Vec<Term>) -> Self {
Self {
axes,
_phantom: PhantomData,
}
}
fn read_index(&self, _index: Index) -> Opt<D> {
Opt::Uninit
}
fn write_index(&mut self, _index: Index, _value: Opt<D>) {}
fn from_buf<Mapping: M>(_data: impl IntoIterator<Item = D>) -> Self {
Self::uninit_from_axes(gen_axes::<Mapping>())
}
fn try_from_buf<Mapping: M>(data: impl IntoIterator<Item = D>) -> Result<Self, BufferConvertError> {
Ok(Self::from_buf::<Mapping>(data))
}
fn to_buf<Mapping: M>(&self) -> Vec<D> {
Vec::new()
}
fn from_fn<F>(axes: Vec<Term>, _f: F) -> Self
where
F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
{
Self::uninit_from_axes(axes)
}
fn map<D2: Scalar, Output: RawTensor<D2>, F>(&self, _f: F) -> Output
where
F: FnMut(&Opt<D>) -> Opt<D2>,
{
Output::uninit_from_axes(self.axes.clone())
}
fn reduce<Src: M, Dst: M, Reduce>(&self, _reduce_fn: Reduce, _identity: Opt<D>) -> Self
where
Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>,
{
let _ = Src::to_value().carve(&Dst::to_value());
Self::uninit_from_axes(gen_axes::<Dst>())
}
fn reduce_then_broadcast<Src: M, Dst: M, Reduce>(&self, _reduce_fn: Reduce, _identity: Opt<D>) -> Self
where
Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>,
{
let dst_axes = gen_axes::<Dst>();
let reduced_axes: Vec<Term> = self
.axes
.iter()
.filter(|axis| dst_axes.contains(axis))
.cloned()
.collect();
let _ = reduce_broadcast(&reduced_axes, &dst_axes);
Self::uninit_from_axes(dst_axes)
}
fn reshape<Mapping: M, Mapping2: M>(self) -> Self {
assert_eq!(Mapping::SIZE, Mapping2::SIZE);
if Mapping::to_value() == Mapping2::to_value() {
self
} else {
Self::uninit_from_axes(gen_axes::<Mapping2>())
}
}
fn write_transpose<Src: M, Dst: M>(
&mut self,
_src: &Self,
_src_offset: &Index,
_dst_offset: &Index,
allow_broadcast: bool,
) {
let _ = transpose_broadcast::<Src, Dst>(allow_broadcast);
}
fn zip_with<D2, D3, Other, Output, F>(&self, rhs: &Other, _f: F) -> Output
where
D2: Scalar,
D3: Scalar,
Other: RawTensor<D2>,
Output: RawTensor<D3>,
F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
{
assert_zip(self.axes(), rhs.axes());
Output::uninit_from_axes(self.axes.clone())
}
fn write_scatter<Src, Key, Dst, Idx, IdxRaw>(&self, _dst: &mut Self, _index: &IdxRaw, _scaled: bool)
where
Src: M,
Key: M,
Dst: M,
Idx: M,
IdxRaw: RawTensor<i32>,
{
let _ = scatter_params(&Src::to_value(), &Dst::to_value(), &Key::to_value());
}
fn write_gather<Src, Dst, Idx, IdxRaw>(&self, _dst: &mut Self, _index: &IdxRaw, _scaled: bool)
where
Src: M,
Dst: M,
Idx: M,
IdxRaw: RawTensor<i32>,
{
let _ = gather_params(&Src::to_value(), &Dst::to_value(), &Idx::to_value());
}
fn apply_branch_operands<Mapping, Operand, TagRaw, F>(
&self,
_tag: &TagRaw,
_operands: &[Operand],
_update: F,
) -> Self
where
D: VeScalar,
Mapping: M,
TagRaw: RawTensor<u8>,
Operand: OperandTag<D, Mapping>,
F: FnMut(&Index, &Operand, &mut Self),
{
self.clone()
}
}
impl<D: Scalar> RawTensorOpt<D> for PhantomRawTensor<D> {
fn from_opt_buf<Mapping: M>(_data: impl IntoIterator<Item = Opt<D>>) -> Self {
Self::uninit_from_axes(gen_axes::<Mapping>())
}
fn try_from_opt_buf<Mapping: M>(data: impl IntoIterator<Item = Opt<D>>) -> Result<Self, BufferConvertError> {
Ok(Self::from_opt_buf::<Mapping>(data))
}
fn to_opt_buf<Mapping: M>(&self) -> Vec<Opt<D>> {
Vec::new()
}
}