use furiosa_mapping::{Index, M, Term};
use crate::engine::vector::operand::OperandTag;
use crate::engine::vector::scalar::VeScalar;
use crate::scalar::{Opt, Scalar};
use crate::tensor::raw::{RawTensor, finalize_coords, gen_axes, shape_from_axes};
/// Dense tensor storage: a list of axis terms plus one flat vector of elements.
///
/// Elements are addressed row-major with respect to the shape that
/// `shape_from_axes` derives from `axes` (the last axis varies fastest).
#[doc(hidden)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BufRawTensor<D: Scalar> {
    /// Axis terms describing the logical layout of the tensor.
    pub(crate) axes: Vec<Term>,
    /// Flat element storage; its length is expected to equal the product of the
    /// shape derived from `axes` (asserted when built via `from_buf`).
    pub(crate) data: Vec<D>,
}
impl<D: Scalar> RawTensor<D> for BufRawTensor<D> {
    /// Returns the axis terms this tensor was created with.
    fn axes(&self) -> &[Term] {
        self.axes.as_slice()
    }

    /// Allocates a tensor for `axes` whose every element is an all-zero bit pattern.
    ///
    /// The buffer stores no per-element initialisation state, so these placeholder
    /// elements are later reported as `Opt::Init(..)` by `read_index`.
    fn uninit_from_axes(axes: Vec<Term>) -> Self {
        let element_count: usize = shape_from_axes(&axes).iter().product();
        let data: Vec<D> = (0..element_count)
            // SAFETY: NOTE(review) — sound only if every `Scalar` type is valid for the
            // all-zero bit pattern (true for plain integer/float types). Nothing in this
            // file enforces that through the `Scalar` bound — confirm.
            .map(|_| unsafe { std::mem::zeroed::<D>() })
            .collect();
        Self { axes, data }
    }

    /// Reads the element addressed by `index`.
    ///
    /// Returns `Opt::Uninit` when `index` does not resolve to a stored element.
    /// Panics if the resolved offset lies outside `data`.
    fn read_index(&self, index: Index) -> Opt<D> {
        self.linear_index(index)
            .map_or(Opt::Uninit, |linear| Opt::Init(self.data[linear]))
    }

    /// Stores `value` at `index`.
    ///
    /// The write is silently dropped when `value` is `Opt::Uninit` (the buffer has no
    /// way to mark an element uninitialised) or when `index` does not resolve to a
    /// stored element.
    fn write_index(&mut self, index: Index, value: Opt<D>) {
        if let Opt::Init(v) = value {
            if let Some(linear) = self.linear_index(index) {
                self.data[linear] = v;
            }
        }
    }

    /// Builds a tensor laid out by `Mapping` from a flat element sequence.
    ///
    /// # Panics
    /// If the number of elements differs from the product of the mapping's shape.
    fn from_buf<Mapping: M>(data: impl IntoIterator<Item = D>) -> Self {
        let axes = gen_axes::<Mapping>();
        let mut elements = Vec::new();
        elements.extend(data);
        let expected = shape_from_axes(&axes).iter().product::<usize>();
        assert_eq!(expected, elements.len(), "shape mismatch");
        Self {
            axes,
            data: elements,
        }
    }

    /// Returns a copy of the flat element storage.
    ///
    /// NOTE(review): `Mapping` is not consulted; elements come back in storage order.
    /// Presumably callers pass the mapping the tensor was built with — confirm.
    fn to_buf<Mapping: M>(&self) -> Vec<D> {
        self.data.to_vec()
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn map<D2: Scalar, Output: RawTensor<D2>, F>(&self, _f: F) -> Output
    where
        F: FnMut(&Opt<D>) -> Opt<D2>,
    {
        todo!("BufRawTensor::map: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn reduce<Src: M, Dst: M, Reduce>(&self, _reduce_fn: Reduce, _identity: Opt<D>) -> Self
    where
        Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>,
    {
        todo!("BufRawTensor::reduce: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn reduce_then_broadcast<Src: M, Dst: M, Reduce>(&self, _reduce_fn: Reduce, _identity: Opt<D>) -> Self
    where
        Reduce: Fn(Opt<D>, Opt<D>) -> Opt<D>,
    {
        todo!("BufRawTensor::reduce_then_broadcast: buffer semantics not implemented yet")
    }

    /// Reinterprets the buffer under `Mapping2` without moving any elements.
    ///
    /// # Panics
    /// If `Mapping::SIZE != Mapping2::SIZE`.
    ///
    /// NOTE(review): `data.len()` is not re-checked against the new axes' shape.
    fn reshape<Mapping: M, Mapping2: M>(self) -> Self {
        assert_eq!(Mapping::SIZE, Mapping2::SIZE);
        let Self { data, .. } = self;
        Self {
            axes: gen_axes::<Mapping2>(),
            data,
        }
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn write_transpose<Src: M, Dst: M>(
        &mut self,
        _src: &Self,
        _src_offset: &Index,
        _dst_offset: &Index,
        _allow_broadcast: bool,
    ) {
        todo!("BufRawTensor::write_transpose: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn zip_with<D2, D3, Other, Output, F>(&self, _rhs: &Other, _f: F) -> Output
    where
        D2: Scalar,
        D3: Scalar,
        Other: RawTensor<D2>,
        Output: RawTensor<D3>,
        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
    {
        todo!("BufRawTensor::zip_with: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn write_scatter<Src, Key, Dst, Idx, IdxRaw>(&self, _dst: &mut Self, _index: &IdxRaw, _scaled: bool)
    where
        Src: M,
        Key: M,
        Dst: M,
        Idx: M,
        IdxRaw: RawTensor<i32>,
    {
        todo!("BufRawTensor::write_scatter: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn write_gather<Src, Dst, Idx, IdxRaw>(&self, _dst: &mut Self, _index: &IdxRaw, _scaled: bool)
    where
        Src: M,
        Dst: M,
        Idx: M,
        IdxRaw: RawTensor<i32>,
    {
        todo!("BufRawTensor::write_gather: buffer semantics not implemented yet")
    }

    /// Not yet supported by the buffer backend; panics via `todo!`.
    fn apply_branch_operands<Mapping, Operand, TagRaw, F>(
        &self,
        _tag: &TagRaw,
        _operands: &[Operand],
        _update: F,
    ) -> Self
    where
        D: VeScalar,
        Mapping: M,
        TagRaw: RawTensor<u8>,
        Operand: OperandTag<D, Mapping>,
        F: FnMut(&Index, &Operand, &mut Self),
    {
        todo!("BufRawTensor::apply_branch_operands: buffer semantics not implemented yet")
    }
}
impl<D: Scalar> BufRawTensor<D> {
    /// Converts a logical `index` into an offset within `data`.
    ///
    /// Returns `None` when `finalize_coords` cannot resolve `index` against `axes`.
    /// Coordinates are combined row-major (last axis fastest). The result is not
    /// bounds-checked against `data.len()`; callers index `data` with it directly.
    fn linear_index(&self, index: Index) -> Option<usize> {
        finalize_coords(&self.axes, index).map(|per_axis| {
            let extents = shape_from_axes(&self.axes);
            // NOTE(review): `zip` stops at the shorter sequence; presumably
            // `finalize_coords` yields exactly one coordinate per axis — confirm.
            let mut offset = 0usize;
            for (coord, &extent) in per_axis.iter().zip(extents.iter()) {
                offset = offset * extent + coord;
            }
            offset
        })
    }
}