#![allow(dead_code)]
use std::marker::PhantomData;
use furiosa_mapping::*;
use crate::scalar::Scalar;
use crate::tensor::Tensor;
/// Tensor held in a buffer, indexed by the mapping `Buf`.
///
/// Thin wrapper over [`Tensor`]. [`BufTensor::read`] produces a [`StreamTensor`]
/// from it, and [`BufTensor::write`] stores a stream back into it.
#[derive(Debug)]
pub struct BufTensor<D, Buf>
where
    D: Scalar,
    Buf: M,
{
    /// Underlying element storage, indexed by `Buf`.
    inner: Tensor<D, Buf>,
}
/// Tensor indexed by the pair mapping `Pair<Time, Packet>`.
///
/// The lifetime `'l` ties a stream to the borrow of the [`BufTensor`] it was
/// read from (see [`BufTensor::read`]). The `inner` tensor itself is stored by value.
#[derive(Debug)]
pub struct StreamTensor<'l, D, Time, Packet>
where
    D: Scalar,
    Time: M,
    Packet: M,
{
    /// Elements indexed under `Pair<Time, Packet>`.
    inner: Tensor<D, Pair<Time, Packet>>,
    /// Zero-sized marker that only carries the lifetime `'l`; it holds no data.
    _marker: PhantomData<&'l ()>,
}
impl<D: Scalar, Buf: M> BufTensor<D, Buf> {
pub fn read<'l, Time: M, Packet: M>(&'l self) -> StreamTensor<'l, D, Time, Packet> {
StreamTensor {
inner: self.inner.transpose(true),
_marker: PhantomData,
}
}
pub fn write<'l, Time: M, Packet: M>(&mut self, stream: StreamTensor<'l, D, Time, Packet>) {
self.inner = stream.inner.transpose(false);
}
}
impl<D: Scalar, Buf: M> BufTensor<D, Buf> {
pub fn from_buf(data: impl IntoIterator<Item = D>) -> Self {
Self {
inner: Tensor::from_buf(data),
}
}
pub fn to_buf(&self) -> Vec<D> {
self.inner.to_buf()
}
}
impl<D, Time, Packet> StreamTensor<'_, D, Time, Packet>
where
    D: Scalar,
    Time: M,
    Packet: M,
{
    /// Returns the stream's elements as a flat `Vec`, as produced by `Tensor::to_buf`.
    pub fn to_buf(&self) -> Vec<D> {
        let Self { inner, .. } = self;
        inner.to_buf()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Streaming in the buffer's own axis order leaves the elements in place.
    #[test]
    fn read_identity_preserves_order() {
        axes![A = 2, B = 3];
        let source = BufTensor::<i32, m![A, B]>::from_buf([10, 11, 12, 20, 21, 22]);
        let streamed = source.read::<m![A], m![B]>();
        let expected = Tensor::<i32, m![A, B]>::from_buf(vec![10, 11, 12, 20, 21, 22]).to_buf();
        assert_eq!(streamed.to_buf(), expected);
    }

    // Streaming with the axes swapped changes the flat element order.
    #[test]
    fn read_reorders_axes() {
        axes![A = 2, B = 3];
        let source = BufTensor::<i32, m![A, B]>::from_buf([10, 11, 12, 20, 21, 22]);
        let streamed = source.read::<m![B], m![A]>();
        let expected = Tensor::<i32, m![A, B]>::from_buf(vec![10, 20, 11, 21, 12, 22]).to_buf();
        assert_eq!(streamed.to_buf(), expected);
    }

    // Writing a stream back into a buffer recovers the data that was read.
    #[test]
    fn write_inverts_read() {
        axes![A = 2, B = 3];
        let original = vec![10, 11, 12, 20, 21, 22];
        let source = BufTensor::<i32, m![A, B]>::from_buf(original.iter().copied());
        let streamed = source.read::<m![B], m![A]>();
        let mut sink = BufTensor::<i32, m![A, B]>::from_buf([0; 6]);
        sink.write(streamed);
        let expected = Tensor::<i32, m![A, B]>::from_buf(original).to_buf();
        assert_eq!(sink.to_buf(), expected);
    }

    // A single axis can be split across the time and packet mappings.
    #[test]
    fn read_splits_axis() {
        axes![A = 4];
        let source = BufTensor::<i32, m![A]>::from_buf([0, 1, 2, 3]);
        let streamed = source.read::<m![A % 2], m![A / 2]>();
        let expected = Tensor::<i32, m![A]>::from_buf(vec![0, 2, 1, 3]).to_buf();
        assert_eq!(streamed.to_buf(), expected);
    }
}