use std::fmt;
use std::marker::PhantomData;
use furiosa_mapping::*;
use ndarray::IxDyn;
use num_traits::Zero;
use rand::Rng;
use rand::distr::StandardUniform;
use self::view::*;
use crate::scalar::*;
use crate::tensor::raw::gen_axes;
use crate::engine::vector::operand::OperandTag;
use crate::engine::vector::scalar::VeScalar;
use crate::runtime::{Backend, CurrentBackend};
pub(crate) mod memory;
pub mod pseudo;
pub(crate) mod raw;
pub(crate) mod tu;
pub(crate) mod view;
pub use raw::{BufRawTensor, MathRawTensor, PhantomRawTensor, RawTensor, RawTensorOpt};
/// Internal classification of the ways a buffer conversion can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BufferConvertErrorKind {
    /// The buffer held `actual` elements where `expected` were required.
    LengthMismatch { expected: usize, actual: usize },
}

/// Error returned by the fallible buffer-to-tensor constructors.
///
/// The failure reason is kept private; use the `Display` output for a
/// human-readable description.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferConvertError {
    kind: BufferConvertErrorKind,
}

impl BufferConvertError {
    /// Creates the error for a buffer holding `actual` elements when
    /// `expected` were required.
    pub(crate) fn length_mismatch(expected: usize, actual: usize) -> Self {
        let kind = BufferConvertErrorKind::LengthMismatch { expected, actual };
        Self { kind }
    }
}

impl fmt::Display for BufferConvertError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The kind enum has a single variant, so the pattern is irrefutable;
        // adding a variant later turns this line into a compile error.
        let BufferConvertErrorKind::LengthMismatch { expected, actual } = self.kind;
        write!(f, "buffer length mismatch: expected {expected} elements, got {actual}")
    }
}

impl std::error::Error for BufferConvertError {}
/// A tensor of `D` scalars tagged with a compile-time `Mapping`, stored in backend `B`.
///
/// `Mapping` exists only at the type level (through `PhantomData`); the value itself is the
/// backend-selected `B::RawTensor<D>`. `B` defaults to `CurrentBackend`.
pub struct Tensor<D: Scalar, Mapping: M, B: Backend = CurrentBackend> {
    /// Backend-specific raw tensor that every operation delegates to.
    inner: B::RawTensor<D>,
    /// Carries `Mapping` and `B` in the type without storing values of them.
    _marker: PhantomData<(Mapping, B)>,
}
impl<D: Scalar, Mapping: M, B: Backend> std::fmt::Debug for Tensor<D, Mapping, B>
where
B::RawTensor<D>: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tensor").field("inner", &self.inner).finish()
}
}
/// A tensor is cloneable whenever its backend raw tensor is.
impl<D: Scalar, Mapping: M, B: Backend> Clone for Tensor<D, Mapping, B>
where
    B::RawTensor<D>: Clone,
{
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        Self {
            inner,
            _marker: PhantomData,
        }
    }
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
    /// Wraps an existing backend raw tensor, tagging it with `Mapping`.
    ///
    /// No check is made here that `inner` actually matches `Mapping`.
    pub(crate) fn from_inner(inner: B::RawTensor<D>) -> Self {
        Self {
            inner,
            _marker: PhantomData,
        }
    }
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
    /// Consumes the tensor and returns the backend raw tensor, dropping the `Mapping` tag.
    pub fn into_raw(self) -> B::RawTensor<D> {
        self.inner
    }
    /// Reads the element addressed by `index` by delegating to the raw tensor.
    ///
    /// The result is an `Opt<D>`. NOTE(review): presumably a non-`Init` value signals an
    /// element that was never written — confirm against the raw tensor implementation.
    pub(crate) fn read_index(&self, index: Index) -> Opt<D> {
        self.inner.read_index(index)
    }
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
    /// Builds a tensor from a flat sequence of elements interpreted under `Mapping`.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::try_from_buf`] returns an error.
    pub fn from_buf(data: impl IntoIterator<Item = D>) -> Self {
        match Self::try_from_buf(data) {
            Ok(tensor) => tensor,
            Err(err) => panic!("failed to convert buffer for backend storage: {err}"),
        }
    }

    /// Fallible counterpart of [`Self::from_buf`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`BufferConvertError`] the backend raw tensor reports for `data`.
    /// Whether a length mismatch is rejected is backend-specific.
    pub fn try_from_buf(data: impl IntoIterator<Item = D>) -> Result<Self, BufferConvertError> {
        let converted = B::RawTensor::try_from_buf::<Mapping>(data);
        converted.map(Self::from_inner)
    }

    /// Reads the tensor back as a flat `Vec` by delegating to the raw tensor's `to_buf`.
    pub fn to_buf(&self) -> Vec<D> {
        self.inner.to_buf::<Mapping>()
    }

    /// Reads the tensor back as a flat `Vec` via the raw tensor's `to_buf_or_default`.
    ///
    /// NOTE(review): presumably substitutes `D`'s default for elements that hold no value —
    /// confirm against the raw tensor implementation.
    pub fn to_buf_or_default(&self) -> Vec<D> {
        self.inner.to_buf_or_default::<Mapping>()
    }
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
    /// Builds a tensor by invoking `f` for the elements of the axes generated from `Mapping`.
    ///
    /// `f` receives the axis terms and a dynamic index and returns the element to store.
    pub fn from_fn<F>(f: F) -> Self
    where
        F: FnMut(&Vec<Term>, &IxDyn) -> Opt<D>,
    {
        let axes = gen_axes::<Mapping>();
        Self::from_inner(B::RawTensor::from_fn(axes, f))
    }

    /// Applies `f` to every element, producing a tensor with the same mapping and a possibly
    /// different scalar type.
    pub fn map<D2: Scalar, F>(&self, f: F) -> Tensor<D2, Mapping, B>
    where
        F: FnMut(&Opt<D>) -> Opt<D2>,
    {
        let mapped: B::RawTensor<D2> = self.inner.map::<D2, _, F>(f);
        Tensor::from_inner(mapped)
    }

    /// Combines this tensor with `other` (same mapping) element by element through `f`.
    pub fn zip_with<D2: Scalar, D3: Scalar, F>(&self, other: &Tensor<D2, Mapping, B>, f: F) -> Tensor<D3, Mapping, B>
    where
        F: Fn(Opt<D>, Opt<D2>) -> Opt<D3>,
    {
        let zipped: B::RawTensor<D3> = self.inner.zip_with::<D2, D3, _, _, F>(&other.inner, f);
        Tensor::from_inner(zipped)
    }

    /// Sum-reduces from `Mapping` into `Dst`; shorthand for [`Self::reduce`] with addition
    /// and `Opt::zero()` as the identity.
    pub fn reduce_add<Dst: M>(&self) -> Tensor<D, Dst, B> {
        self.reduce::<Dst>(|acc, item| acc + item, Opt::zero())
    }

    /// Reduces from `Mapping` into `Dst` using `reduce_fn`, with `identity` as the identity
    /// value handed to the raw tensor's reduction.
    pub fn reduce<Dst: M>(&self, reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>, identity: Opt<D>) -> Tensor<D, Dst, B> {
        let reduced: B::RawTensor<D> = self.inner.reduce::<Mapping, Dst, _>(reduce_fn, identity);
        Tensor::from_inner(reduced)
    }

    /// Creates a tensor in which every element is initialized to `D::zero()`.
    pub fn zero() -> Self
    where
        D: Zero,
    {
        Self::from_fn(|_terms, _index| Opt::Init(D::zero()))
    }

    /// Creates a tensor in which every element is an independent sample drawn from `rng`
    /// using the `StandardUniform` distribution.
    pub fn rand(rng: &mut impl Rng) -> Self
    where
        StandardUniform: rand::distr::Distribution<D>,
    {
        Self::from_fn(|_terms, _index| {
            let sample: D = rng.random::<D>();
            Opt::Init(sample)
        })
    }

    /// Sum-based shorthand for [`Self::reduce_then_broadcast_with`] (addition with
    /// `Opt::zero()` as the identity).
    pub fn reduce_then_broadcast<Dst: M>(&self) -> Tensor<D, Dst, B> {
        self.reduce_then_broadcast_with::<Dst>(|acc, item| acc + item, Opt::zero())
    }

    /// Delegates to the raw tensor's `reduce_then_broadcast` from `Mapping` into `Dst`, using
    /// `reduce_fn` and `identity` for the reduction step.
    pub fn reduce_then_broadcast_with<Dst: M>(
        &self,
        reduce_fn: impl Fn(Opt<D>, Opt<D>) -> Opt<D>,
        identity: Opt<D>,
    ) -> Tensor<D, Dst, B> {
        let combined: B::RawTensor<D> = self
            .inner
            .reduce_then_broadcast::<Mapping, Dst, _>(reduce_fn, identity);
        Tensor::from_inner(combined)
    }

    /// Contracts `lhs` with `rhs` into a tensor mapped by `Mapping`.
    ///
    /// Both operands are transposed (broadcast allowed) into the shared `Union` mapping,
    /// multiplied element by element, and the product is sum-reduced into `Mapping`.
    pub fn contraction<Union: M, Lhs: M, Rhs: M>(lhs: &Tensor<D, Lhs, B>, rhs: &Tensor<D, Rhs, B>) -> Self {
        let lhs_wide: Tensor<D, Union, B> = lhs.transpose(true);
        let rhs_wide: Tensor<D, Union, B> = rhs.transpose(true);
        let products: Tensor<D, Union, B> = lhs_wide.zip_with(&rhs_wide, |a, b| a * b);
        products.reduce_add()
    }
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
pub fn uninit() -> Self {
Self::from_inner(B::RawTensor::uninit_from_axes(gen_axes::<Mapping>()))
}
pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Mapping, B> {
TensorViewMut::new(&mut self.inner)
}
pub fn view<'l>(&'l self) -> TensorView<'l, D, Mapping, B> {
TensorView::new(&self.inner)
}
pub unsafe fn transmute<Mapping2: M>(self) -> Tensor<D, Mapping2, B> {
Tensor {
inner: self.inner,
_marker: PhantomData,
}
}
}
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B> {
    /// Reshapes the tensor from `Mapping` to `Mapping2` via the raw tensor's `reshape`.
    ///
    /// # Panics
    ///
    /// Panics if `Mapping::SIZE != Mapping2::SIZE`.
    ///
    /// # Safety
    ///
    /// Only the total size is checked. NOTE(review): any further requirement the raw
    /// `reshape` places on the two mappings is not visible here and must be upheld by the
    /// caller — confirm against the raw tensor implementation.
    pub unsafe fn reshape<Mapping2: M>(self) -> Tensor<D, Mapping2, B> {
        assert_eq!(Mapping::SIZE, Mapping2::SIZE);
        let reshaped: B::RawTensor<D> = self.inner.reshape::<Mapping, Mapping2>();
        Tensor::from_inner(reshaped)
    }

    /// Returns a new tensor mapped by `Dst`, filled by transposing this tensor into it.
    ///
    /// `allow_broadcast` is forwarded unchanged to the view's `write_transpose`.
    pub fn transpose<Dst: M>(&self, allow_broadcast: bool) -> Tensor<D, Dst, B> {
        let mut out: Tensor<D, Dst, B> = Tensor::uninit();
        out.view_mut().write_transpose(self.view(), allow_broadcast);
        out
    }

    /// Scatters this tensor into `dst` at the positions given by `index`, delegating to the
    /// raw tensor's `write_scatter`. `scaled` is forwarded unchanged.
    pub fn write_scatter<Key: M, Dst: M, Idx: M>(
        &self,
        dst: &mut Tensor<D, Dst, B>,
        index: &Tensor<i32, Idx, B>,
        scaled: bool,
    ) {
        let target = &mut dst.inner;
        let positions = &index.inner;
        self.inner
            .write_scatter::<Mapping, Key, Dst, Idx, _>(target, positions, scaled);
    }

    /// Gathers entries of this tensor selected by `index` into `dst`, delegating to the raw
    /// tensor's `write_gather`. `scaled` is forwarded unchanged.
    pub fn write_gather<Dst: M, Idx: M>(&self, dst: &mut Tensor<D, Dst, B>, index: &Tensor<i32, Idx, B>, scaled: bool) {
        let target = &mut dst.inner;
        let positions = &index.inner;
        self.inner
            .write_gather::<Mapping, Dst, Idx, _>(target, positions, scaled);
    }

    /// Produces a new tensor by forwarding `tag`, `operands` and `update` to the raw
    /// tensor's `apply_branch_operands`.
    ///
    /// NOTE(review): presumably `tag` selects which operand applies to each element and
    /// `update` writes the result — confirm against the raw tensor implementation.
    pub(crate) fn apply_branch_operands<Operand, F>(
        &self,
        tag: &Tensor<u8, Mapping, B>,
        operands: &[Operand],
        update: F,
    ) -> Self
    where
        D: VeScalar,
        Operand: OperandTag<D, Mapping>,
        F: FnMut(&Index, &Operand, &mut B::RawTensor<D>),
    {
        let branched: B::RawTensor<D> = self
            .inner
            .apply_branch_operands::<Mapping, Operand, _, F>(&tag.inner, operands, update);
        Self::from_inner(branched)
    }
}
/// Buffer conversions that keep per-element `Opt` state; available only when the backend's
/// raw tensor implements `RawTensorOpt`.
impl<D: Scalar, Mapping: M, B: Backend> Tensor<D, Mapping, B>
where
    B::RawTensor<D>: RawTensorOpt<D>,
{
    /// Builds a tensor from a flat sequence of `Opt<D>` elements interpreted under `Mapping`.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::try_from_opt_buf`] returns an error.
    pub fn from_opt_buf(data: impl IntoIterator<Item = Opt<D>>) -> Self {
        match Self::try_from_opt_buf(data) {
            Ok(tensor) => tensor,
            Err(err) => panic!("failed to convert logical buffer for backend storage: {err}"),
        }
    }

    /// Fallible counterpart of [`Self::from_opt_buf`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`BufferConvertError`] the backend raw tensor reports for `data`.
    pub fn try_from_opt_buf(data: impl IntoIterator<Item = Opt<D>>) -> Result<Self, BufferConvertError> {
        let converted = B::RawTensor::try_from_opt_buf::<Mapping>(data);
        converted.map(Self::from_inner)
    }

    /// Reads the tensor back as a flat `Vec<Opt<D>>` via the raw tensor's `to_opt_buf`.
    pub fn to_buf_opt(&self) -> Vec<Opt<D>> {
        self.inner.to_opt_buf::<Mapping>()
    }
}
// Tests for the `Tensor` wrapper across the `Simulation`, `Typecheck` and `Npu` backends.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{Npu, Simulation, Typecheck};
    // Reducing `[K, M]` into `[K / 2, M]` must sum only the part of `K` that is dropped.
    #[test]
    fn simulation_reduce_keeps_partial_axis_when_only_a_sub_factor_remains() {
        axes![K = 8, M = 4];
        // source[k, m] = k + 100 * m, laid out with `m` varying fastest.
        let source_buf: Vec<i32> = (0..8).flat_map(|k| (0..4).map(move |m| k + 100 * m)).collect();
        let source = Tensor::<i32, m![K, M], Simulation>::from_buf(source_buf);
        let reduced: Tensor<i32, m![K / 2, M], Simulation> = source.reduce_add();
        // Each output sums k = 2j and k = 2j + 1: (2j) + (2j + 1) + 2 * 100 * m = 4j + 1 + 200m.
        let expected: Vec<i32> = (0..4).flat_map(|j| (0..4).map(move |m| 4 * j + 1 + 200 * m)).collect();
        assert_eq!(reduced.to_buf(), expected);
    }
    // `reduce_then_broadcast` must both reduce part of `A` and replicate across the new `B` axis.
    #[test]
    fn simulation_reduce_then_broadcast_keeps_partial_axis_and_broadcasts_extra_dst_axis() {
        axes![A = 4, B = 2];
        let source = Tensor::<i32, m![A], Simulation>::from_buf(vec![0, 1, 2, 3]);
        let result: Tensor<i32, m![A / 2, B], Simulation> = source.reduce_then_broadcast();
        // Pair sums are 0 + 1 = 1 and 2 + 3 = 5, each repeated for the two `B` positions.
        assert_eq!(result.to_buf(), vec![1, 1, 5, 5]);
    }
    // Values written through `from_buf` read back unchanged, and as `Opt::Init` in the `Opt` view.
    #[test]
    fn simulation_from_buf_round_trips_through_opt_storage() {
        axes![A = 2];
        let tensor = Tensor::<i32, m![A], Simulation>::from_buf(vec![1, 2]);
        assert_eq!(tensor.to_buf(), vec![1, 2]);
        assert_eq!(tensor.to_buf_opt(), vec![Opt::Init(1), Opt::Init(2)]);
    }
    // On `Typecheck`, a one-element buffer for a two-element mapping is accepted and reads back empty.
    #[test]
    fn typecheck_try_from_buf_ignores_length_mismatch() {
        axes![A = 2];
        let tensor = Tensor::<i32, m![A], Typecheck>::try_from_buf(vec![1]).unwrap();
        assert!(tensor.to_buf().is_empty());
        assert!(tensor.to_buf_opt().is_empty());
    }
    // NOTE(review): `Tensor::empty` is not defined in this part of the file — presumably it is
    // provided elsewhere for the `Typecheck` backend; confirm.
    #[test]
    fn typecheck_to_buf_is_empty() {
        axes![A = 2];
        let tensor = Tensor::<i32, m![A], Typecheck>::empty();
        assert!(tensor.to_buf().is_empty());
        assert!(tensor.to_buf_opt().is_empty());
    }
    // Gather with `scaled = true`: indices 0 / 16 / 8 pick table rows 0 / 2 / 1.
    // NOTE(review): this is consistent with byte offsets into rows of two `i32` (8 bytes each) —
    // confirm the scaling rule against the raw gather implementation.
    #[test]
    fn simulation_write_gather_roundtrip_scaled() {
        axes![W = 3, V = 2, K = 4];
        let table = Tensor::<i32, m![W, V], Simulation>::from_buf(vec![10, 11, 20, 21, 30, 31]);
        let index = Tensor::<i32, m![K], Simulation>::from_buf(vec![0, 16, 8, 0]);
        let mut output = Tensor::<i32, m![K, V], Simulation>::uninit();
        table.write_gather::<_, _>(&mut output, &index, true);
        assert_eq!(output.to_buf(), vec![10, 11, 30, 31, 20, 21, 10, 11]);
    }
    // Gather with `scaled = false`: indices are row numbers 0 / 2 / 1 / 0 directly.
    #[test]
    fn simulation_write_gather_roundtrip_unscaled() {
        axes![W = 3, V = 2, K = 4];
        let table = Tensor::<i32, m![W, V], Simulation>::from_buf(vec![10, 11, 20, 21, 30, 31]);
        let index = Tensor::<i32, m![K], Simulation>::from_buf(vec![0, 2, 1, 0]);
        let mut output = Tensor::<i32, m![K, V], Simulation>::uninit();
        table.write_gather::<_, _>(&mut output, &index, false);
        assert_eq!(output.to_buf(), vec![10, 11, 30, 31, 20, 21, 10, 11]);
    }
    // Smoke test: the gather call must complete on `Typecheck`; no output values are inspected.
    #[test]
    fn typecheck_write_gather_runs_assertion_only() {
        axes![W = 3, V = 2, K = 4];
        let table = Tensor::<i32, m![W, V], Typecheck>::empty();
        let index = Tensor::<i32, m![K], Typecheck>::empty();
        let mut output = Tensor::<i32, m![K, V], Typecheck>::empty();
        table.write_gather::<_, _>(&mut output, &index, true);
    }
    // On `Npu`, `to_buf` returns the values that were written through `from_buf`.
    #[test]
    fn npu_to_buf_returns_plain_values() {
        axes![A = 2];
        let tensor = Tensor::<i32, m![A], Npu>::from_buf(vec![1, 2]);
        assert_eq!(tensor.to_buf(), vec![1, 2]);
    }
}