use std::fmt::Debug;
use abi_stable::std_types::RResult;
use furiosa_mapping::{Atom, Index, IndexExt, M, Mapping, MappingExt, Term};
use ndarray::IxDyn;
use crate::engine::vector::operand::OperandTag;
use crate::engine::vector::scalar::VeScalar;
use crate::scalar::{Opt, Scalar};
use crate::tensor::BufferConvertError;
mod buf;
mod math;
mod phantom;
pub use buf::BufRawTensor;
pub use math::MathRawTensor;
pub use phantom::PhantomRawTensor;
/// Storage-level tensor interface: elements are `Opt<D>` values addressed by an
/// [`Index`] over a list of [`Term`] axes.
///
/// Only `try_from_buf`, `to_buf_or_default` and `from_fn` have default bodies;
/// the semantics of every other method are defined by the implementor.
/// NOTE(review): `BufRawTensor`, `MathRawTensor` and `PhantomRawTensor` are
/// presumably the implementors — confirm in the submodules.
pub trait RawTensor<D: Scalar>: 'static + Clone + Debug {
    /// Returns the axis terms describing this tensor.
    fn axes(&self) -> &[Term];

    /// Creates a tensor over `axes`.
    ///
    /// NOTE(review): the name suggests the elements start out as `Opt::Uninit` —
    /// confirm against the implementors.
    fn uninit_from_axes(axes: Vec<Term>) -> Self;

    /// Reads the element addressed by `index`.
    fn read_index(&self, index: Index) -> Opt<D>;

    /// Stores `value` at the element addressed by `index`.
    fn write_index(&mut self, index: Index, value: Opt<D>);

    /// Builds a tensor from a flat sequence of values for the layout `Mapping`.
    ///
    /// No length validation is visible at this level; use
    /// [`RawTensor::try_from_buf`] for a checked conversion.
    fn from_buf<Mapping: M>(data: impl IntoIterator<Item = D>) -> Self;

    /// Checked counterpart of [`RawTensor::from_buf`].
    ///
    /// # Errors
    ///
    /// Returns `BufferConvertError::length_mismatch(expected, actual)` when `data`
    /// does not yield exactly `Mapping::SIZE` items.
    fn try_from_buf<Mapping: M>(data: impl IntoIterator<Item = D>) -> Result<Self, BufferConvertError>
    where
        Self: Sized,
    {
        // The input has to be materialised first: an arbitrary iterator does not
        // expose its exact length up front.
        let items = data.into_iter().collect::<Vec<D>>();
        let actual = items.len();
        if actual == Mapping::SIZE {
            Ok(Self::from_buf::<Mapping>(items))
        } else {
            Err(BufferConvertError::length_mismatch(Mapping::SIZE, actual))
        }
    }

    /// Flattens the tensor into a `Vec<D>` for the layout `Mapping`.
    fn to_buf<Mapping: M>(&self) -> Vec<D>;

    /// Flattens the tensor into a `Vec<D>`; the default body simply forwards to
    /// [`RawTensor::to_buf`].
    ///
    /// NOTE(review): the name suggests implementors may override this to
    /// substitute a default for uninitialised elements — confirm.
    fn to_buf_or_default<Mapping: M>(&self) -> Vec<D> {
        let flat = self.to_buf::<Mapping>();
        flat
    }

    /// Builds a tensor over `axes` by calling `f` once for every index generated
    /// from those axes.
    ///
    /// `f` receives the axes and the per-axis coordinates of the current index
    /// (as computed by `finalize_coords`), and returns the element to store.
    ///
    /// # Panics
    ///
    /// Panics if a generated index cannot be finalised, or (inside
    /// `finalize_coords`) if an axis is an `Atom::Composite`.
    fn from_fn<F>(axes: Vec<Term>, mut f: F) -> Self
    where
        F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
    {
        // `axes` is still needed for every callback invocation, so the tensor
        // gets its own copy.
        let mut out = Self::uninit_from_axes(axes.clone());
        // `Mapping` here is the imported `furiosa_mapping::Mapping` type, not a
        // generic parameter as in the sibling methods.
        for position in Index::new().gen_indexes(Mapping::from_terms(axes.iter().cloned())) {
            let coords =
                finalize_coords(axes.as_slice(), position.clone()).expect("generated index must be valid");
            let element = f(&axes, &IxDyn(&coords));
            out.write_index(position, element);
        }
        out
    }

    /// Produces an `Output` tensor by applying `f` to elements of `self`.
    fn map<D2: Scalar, Output: RawTensor<D2>, F>(&self, f: F) -> Output
    where
        F: FnMut(&Opt<D>) -> Opt<D2>;

    /// Reduces from layout `Src` to layout `Dst`, combining elements with
    /// `reduce_fn`; `identity` is supplied as the reduction's identity value.
    fn reduce<Src: M, Dst: M, Reduce>(&self, reduce_fn: Reduce, identity: Opt<D>) -> Self
    where
        Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>;

    /// Same inputs as [`RawTensor::reduce`].
    ///
    /// NOTE(review): presumably broadcasts the reduced result back afterwards —
    /// confirm against the implementors.
    fn reduce_then_broadcast<Src: M, Dst: M, Reduce>(&self, reduce_fn: Reduce, identity: Opt<D>) -> Self
    where
        Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>;

    /// Consumes the tensor and returns it re-expressed from layout `Mapping` to
    /// layout `Mapping2`.
    fn reshape<Mapping: M, Mapping2: M>(self) -> Self;

    /// Writes elements of `src` (layout `Src`) into `self` (layout `Dst`), using
    /// the given source and destination offsets.
    ///
    /// NOTE(review): the exact effect of `allow_broadcast` is implementor-defined.
    fn write_transpose<Src: M, Dst: M>(
        &mut self,
        src: &Self,
        src_offset: &Index,
        dst_offset: &Index,
        allow_broadcast: bool,
    );

    /// Combines `self` with `rhs` through `f`, producing an `Output` tensor.
    fn zip_with<D2, D3, Other, Output, F>(&self, rhs: &Other, f: F) -> Output
    where
        D2: Scalar,
        D3: Scalar,
        Other: RawTensor<D2>,
        Output: RawTensor<D3>,
        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>;

    /// Scatters elements of `self` into `dst`, driven by the `i32` index tensor
    /// `index`.
    ///
    /// NOTE(review): the roles of `Key` and `scaled` are implementor-defined.
    fn write_scatter<Src, Key, Dst, Idx, IdxRaw>(&self, dst: &mut Self, index: &IdxRaw, scaled: bool)
    where
        Src: M,
        Key: M,
        Dst: M,
        Idx: M,
        IdxRaw: RawTensor<i32>;

    /// Gathers elements of `self` into `dst`, driven by the `i32` index tensor
    /// `index`.
    ///
    /// NOTE(review): the role of `scaled` is implementor-defined.
    fn write_gather<Src, Dst, Idx, IdxRaw>(&self, dst: &mut Self, index: &IdxRaw, scaled: bool)
    where
        Src: M,
        Dst: M,
        Idx: M,
        IdxRaw: RawTensor<i32>;

    /// Returns a new tensor obtained by running `update` with operands from
    /// `operands`, using the `u8` tensor `tag`.
    ///
    /// NOTE(review): presumably `tag` selects which operand applies at each
    /// index — confirm against the implementors.
    fn apply_branch_operands<Mapping, Operand, TagRaw, F>(&self, tag: &TagRaw, operands: &[Operand], update: F) -> Self
    where
        D: VeScalar,
        Mapping: M,
        TagRaw: RawTensor<u8>,
        Operand: OperandTag<D, Mapping>,
        F: FnMut(&Index, &Operand, &mut Self);
}
/// Extension of [`RawTensor`] for exchanging flat buffers of `Opt<D>` elements,
/// so callers can distinguish initialised from uninitialised slots.
pub trait RawTensorOpt<D: Scalar>: RawTensor<D> {
    /// Builds a tensor from a flat sequence of `Opt<D>` elements for the layout
    /// `Mapping`.
    ///
    /// No length validation is visible at this level; use
    /// [`RawTensorOpt::try_from_opt_buf`] for a checked conversion.
    fn from_opt_buf<Mapping: M>(data: impl IntoIterator<Item = Opt<D>>) -> Self;

    /// Checked counterpart of [`RawTensorOpt::from_opt_buf`].
    ///
    /// # Errors
    ///
    /// Returns `BufferConvertError::length_mismatch(expected, actual)` when `data`
    /// does not yield exactly `Mapping::SIZE` items.
    fn try_from_opt_buf<Mapping: M>(data: impl IntoIterator<Item = Opt<D>>) -> Result<Self, BufferConvertError>
    where
        Self: Sized,
    {
        let slots = data.into_iter().collect::<Vec<Opt<D>>>();
        match slots.len() {
            len if len == Mapping::SIZE => Ok(Self::from_opt_buf::<Mapping>(slots)),
            len => Err(BufferConvertError::length_mismatch(Mapping::SIZE, len)),
        }
    }

    /// Flattens the tensor into a `Vec<Opt<D>>` for the layout `Mapping`.
    fn to_opt_buf<Mapping: M>(&self) -> Vec<Opt<D>>;

    /// Flattens the tensor like [`RawTensorOpt::to_opt_buf`], then replaces every
    /// `Opt::Uninit` slot with `D::zero()` and unwraps every `Opt::Init` slot.
    fn to_buf_or_default_opt<Mapping: M>(&self) -> Vec<D> {
        let slots = self.to_opt_buf::<Mapping>();
        let mut values = Vec::with_capacity(slots.len());
        for slot in slots {
            let value = if let Opt::Init(inner) = slot { inner } else { D::zero() };
            values.push(value);
        }
        values
    }
}
/// Returns the axes of the type-level mapping `Mapping`, obtained from its
/// value-level form (`Mapping::to_value()`).
pub(crate) fn gen_axes<Mapping: M>() -> Vec<Term> {
    let mapping_value = Mapping::to_value();
    mapping_value.axes()
}
/// Converts `index` into one coordinate per entry of `axes`.
///
/// The index is finalised into a per-symbol lookup; each axis coordinate is then
/// `(value / axis.stride) % axis.modulo`, where `value` is the entry for the
/// axis's symbol, or `0` when the symbol is absent from the finalised index.
///
/// Returns `None` when `index.finalize()` does not succeed.
///
/// # Panics
///
/// Panics if any axis is an `Atom::Composite` rather than an `Atom::Symbol`.
pub(crate) fn finalize_coords(axes: &[Term], index: Index) -> Option<Vec<usize>> {
    let resolved = match index.finalize() {
        RResult::ROk(resolved) => resolved,
        _ => return None,
    };

    let mut coords = Vec::with_capacity(axes.len());
    for axis in axes {
        let coord = match axis.inner {
            Atom::Symbol { symbol, .. } => {
                // A symbol missing from the finalised index contributes 0.
                let raw = resolved.get(&symbol).copied().unwrap_or(0);
                (raw / axis.stride) % axis.modulo
            }
            Atom::Composite(_) => panic!("tensor axis must be a resolved symbol, got {axis:?}"),
        };
        coords.push(coord);
    }
    Some(coords)
}
/// Returns the `modulo` of every axis, in axis order.
pub(crate) fn shape_from_axes(axes: &[Term]) -> Vec<usize> {
    let mut shape = Vec::with_capacity(axes.len());
    for axis in axes {
        shape.push(axis.modulo);
    }
    shape
}