use rand::Rng;
use rand::distr::StandardUniform;
use std::fmt::{self, Display, Formatter};
use std::marker::PhantomData;
use furiosa_mapping::*;
use furiosa_opt_lower::{RelaxedDivision, config_divide_relaxed};
use furiosa_opt_macro::primitive;
use crate::context::*;
use crate::engine::vector::scalar::VeScalar;
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::raw::gen_axes;
use crate::tensor::{BufferConvertError, *};
/// Address value carried by the device tensor handles in this file
/// (`HbmTensor`, its views, `DmTensor` and `VrfTensor`).
pub type Address = u64;
/// Alignment, in bytes, that the `to_dm` / `to_dm_view` transfers in this file pass to
/// `assert_dma_layout` as `min_align`.
// NOTE(review): the name suggests this is the SRAM write width — confirm against the hardware spec.
const DMA_SRAM_WRITE_WIDTH: usize = 8;
/// Validates that a DMA from layout `Src` into layout `Dst` honours a `min_align`-byte alignment
/// for scalars of type `D`.
///
/// The two mappings are divided with `config_divide_relaxed`; the tail check runs first and its
/// result (the contiguous tail length, in elements) is then fed to the stride check.
///
/// # Panics
///
/// Panics if `min_align` is zero, or if either `check_dma_tail` or `check_dma_address_stride`
/// rejects the layout.
pub(crate) fn assert_dma_layout<D: Scalar, Src: M, Dst: M>(min_align: usize) {
    assert!(min_align > 0, "min_align must be positive");

    let (src, dst) = (Src::to_value(), Dst::to_value());
    let division = config_divide_relaxed(&src, &dst);

    let packet_end = check_dma_tail::<D>(&division, &src, &dst, min_align);
    check_dma_address_stride::<D>(&division, &src, &dst, min_align, packet_end);
}
/// Checks that the contiguous tail of `division` covers a whole number of `min_align`-byte units.
///
/// The tail length is taken from `division.contiguous_tail` (in elements) and converted to bytes
/// with `D::size_in_bytes_from_length`. `src` and `dst` are only used to enrich the panic message.
///
/// Returns the tail length in elements.
///
/// # Panics
///
/// Panics if the tail's byte size is not a multiple of `min_align`.
fn check_dma_tail<D: Scalar>(division: &RelaxedDivision, src: &Mapping, dst: &Mapping, min_align: usize) -> usize {
    let tail_len = division.contiguous_tail;
    let tail_bytes = D::size_in_bytes_from_length(tail_len);
    assert!(
        tail_bytes.is_multiple_of(min_align),
        "DMA tail alignment violation: reachable destination tail \
         end is not aligned to {min_align} bytes.\n \
         reachable destination tail end (elements) = {reachable_end}\n \
         reachable destination tail end (bytes) = {reachable_end_bytes}\n \
         src mapping = {src:?}\n \
         dst mapping = {dst:?}",
        reachable_end = tail_len,
        reachable_end_bytes = tail_bytes,
    );
    tail_len
}
/// Checks the destination stride of every matched term that starts at or past `packet_end`.
///
/// Each entry of `division.matched` whose `divisor_stride` is at least `packet_end` must have a
/// stride whose byte size (per `D::size_in_bytes_from_length`) is a multiple of `min_align`.
/// Entries with a smaller stride are not inspected. `src` and `dst` are only used in the panic
/// message.
///
/// # Panics
///
/// Panics on the first inspected term whose stride is not `min_align`-aligned.
fn check_dma_address_stride<D: Scalar>(
    division: &RelaxedDivision,
    src: &Mapping,
    dst: &Mapping,
    min_align: usize,
    packet_end: usize,
) {
    for bound in division.matched.iter() {
        // Terms that begin before the end of the contiguous tail are exempt from this check.
        if bound.divisor_stride < packet_end {
            continue;
        }
        let stride_bytes = D::size_in_bytes_from_length(bound.divisor_stride);
        assert!(
            stride_bytes.is_multiple_of(min_align),
            "DMA address stride alignment violation: matched term {term:?} beginning at \
             or past the contiguous tail has dst stride {dst_bytes} bytes, not aligned to \
             {min_align}-byte granularity.\n \
             reachable packet end (elements) = {packet_end}\n \
             src mapping = {src:?}\n \
             dst mapping = {dst:?}",
            term = bound,
            dst_bytes = stride_bytes,
        );
    }
}
/// Selector stored in a `TrfTensor` as its address: one of two halves, or the whole range.
///
/// `capacity` reports 32_768 for either half and 65_536 for `Full`.
// NOTE(review): presumably "TRF" is a register file and the capacities are byte counts — confirm.
#[primitive(TrfAddress)]
#[derive(Copy, Clone, Debug)]
pub enum TrfAddress {
/// The first half.
FirstHalf,
/// The second half.
SecondHalf,
/// The whole range (twice the capacity of a half).
Full,
}
impl TrfAddress {
    /// Returns the capacity associated with this selector: 32_768 for either half and
    /// 65_536 for [`TrfAddress::Full`].
    // NOTE(review): the unit (bytes vs. elements) is not visible here — confirm.
    pub fn capacity(&self) -> usize {
        match *self {
            Self::FirstHalf | Self::SecondHalf => 32_768,
            Self::Full => 65_536,
        }
    }
}
impl Display for TrfAddress {
    /// Writes the fully qualified variant name, e.g. `TrfAddress::Full`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::FirstHalf => "TrfAddress::FirstHalf",
            Self::SecondHalf => "TrfAddress::SecondHalf",
            Self::Full => "TrfAddress::Full",
        };
        f.write_str(label)
    }
}
/// Host-side tensor handle: a thin wrapper around a `Tensor<D, Element, B>`.
///
/// `D` is the scalar type, `Element` the mapping describing the layout, and `B` the backend
/// (defaulting to `CurrentBackend`).
#[primitive(HostTensor)]
#[derive(Debug, Clone)]
pub struct HostTensor<D: Scalar, Element: M, B: Backend = CurrentBackend> {
/// The wrapped tensor.
inner: Tensor<D, Element, B>,
}
impl<D: Scalar, Element: M, B: Backend> From<Tensor<D, Element, B>> for HostTensor<D, Element, B> {
    /// Wraps `tensor` as a host tensor without touching its contents.
    fn from(tensor: Tensor<D, Element, B>) -> Self {
        Self { inner: tensor }
    }
}
impl<D: Scalar, Element: M, B: Backend> HostTensor<D, Element, B> {
    /// The mapping describing this tensor's layout (the `Element` parameter).
    pub type Mapping = Element;

    /// Borrows the wrapped tensor.
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Element, B> {
        &self.inner
    }

    /// Builds a host tensor from `data` via `Tensor::from_buf`.
    // NOTE(review): presumably panics when `data` does not fit the mapping, given the fallible
    // sibling `try_from_buf` — confirm in `Tensor`.
    pub fn from_buf(data: impl IntoIterator<Item = D>) -> Self {
        Self {
            inner: Tensor::from_buf(data),
        }
    }

    /// Fallible counterpart of [`Self::from_buf`], built on `Tensor::try_from_buf`.
    ///
    /// # Errors
    ///
    /// Forwards the `BufferConvertError` reported by `Tensor::try_from_buf`.
    pub fn try_from_buf(data: impl IntoIterator<Item = D>) -> Result<Self, BufferConvertError> {
        Tensor::try_from_buf(data).map(|inner| Self { inner })
    }

    /// Transfers this tensor to HBM at `address` by awaiting `B::to_hbm`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a PCIe DMA context).
    pub async fn to_hbm<Chip: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Pcie }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2, B> {
        B::to_hbm(self, address).await
    }

    /// Unwraps the handle, returning the wrapped tensor.
    pub fn into_inner(self) -> Tensor<D, Self::Mapping, B> {
        let unwrapped = self.inner;
        unwrapped
    }

    /// Unwraps the handle down to the backend's raw tensor.
    pub fn into_raw(self) -> B::RawTensor<D> {
        let unwrapped = self.inner;
        unwrapped.into_raw()
    }

    /// Returns the tensor contents as a `Vec`, as produced by the wrapped tensor's `to_buf`.
    pub fn to_buf(&self) -> Vec<D> {
        self.inner_tensor().to_buf()
    }
}
impl<D: Scalar, Element: M, B: Backend> HostTensor<D, Element, B> {
    /// Builds a host tensor from `Tensor::zero()`.
    pub fn zero() -> Self
    where
        D: num_traits::Zero,
    {
        Self { inner: Tensor::zero() }
    }

    /// Builds a host tensor from `Tensor::rand(rng)`.
    #[primitive(HostTensor::rand)]
    pub fn rand(rng: &mut impl Rng) -> Self
    where
        StandardUniform: rand::distr::Distribution<D>,
    {
        Self {
            inner: Tensor::rand(rng),
        }
    }

    /// Builds a host tensor from `Tensor::uninit()`.
    pub fn uninit() -> Self {
        Self {
            inner: Tensor::uninit(),
        }
    }

    /// Builds a host tensor from a safetensors view.
    ///
    /// The view's shape must equal the sizes of the non-`Pair` nodes of `Element`'s mapping value,
    /// taken left to right, and its byte length must equal `Element::SIZE * (D::BITS / 8)`. The
    /// payload is then decoded `D::BITS / 8` bytes at a time with `D::from_le_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `SafeTensorError::TensorInvalidInfo` when the shape or the byte length disagrees.
    // NOTE(review): the view's dtype is never compared against `D`; a payload of another dtype
    // with the same byte length would be reinterpreted — confirm this is intended.
    pub fn from_safetensors(view: &safetensors::tensor::TensorView<'_>) -> Result<Self, safetensors::SafeTensorError>
    where
        D: ScalarBytes,
    {
        /// Appends the size of every non-`Pair` node of `mapping` to `out`, left subtree first.
        fn flat_shape(mapping: &Mapping, out: &mut Vec<usize>) {
            if let Mapping::Pair { left, right } = mapping {
                flat_shape(left, out);
                flat_shape(right, out);
            } else {
                out.push(mapping.size());
            }
        }

        let mut leaf_sizes = Vec::new();
        flat_shape(&Element::to_value(), &mut leaf_sizes);
        if view.shape() != leaf_sizes.as_slice() {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }

        let bytes_per_scalar = D::BITS / 8;
        let payload = view.data();
        if payload.len() != Element::SIZE * bytes_per_scalar {
            return Err(safetensors::SafeTensorError::TensorInvalidInfo);
        }

        let scalars = payload.chunks_exact(bytes_per_scalar).map(D::from_le_bytes);
        Ok(Self {
            inner: Tensor::from_buf(scalars),
        })
    }
}
impl<D: Scalar, Element: M, B: Backend> HostTensor<D, Element, B>
where
    B::RawTensor<D>: RawTensorOpt<D>,
{
    /// Builds a host tensor from `Opt<D>` items via `Tensor::from_opt_buf`.
    pub fn from_opt_buf(data: impl IntoIterator<Item = Opt<D>>) -> Self {
        Self {
            inner: Tensor::from_opt_buf(data),
        }
    }

    /// Fallible counterpart of [`Self::from_opt_buf`], built on `Tensor::try_from_opt_buf`.
    ///
    /// # Errors
    ///
    /// Forwards the `BufferConvertError` reported by `Tensor::try_from_opt_buf`.
    pub fn try_from_opt_buf(data: impl IntoIterator<Item = Opt<D>>) -> Result<Self, BufferConvertError> {
        Tensor::try_from_opt_buf(data).map(|inner| Self { inner })
    }

    /// Returns the tensor contents as `Opt<D>` items, as produced by the wrapped tensor's
    /// `to_buf_opt`.
    pub fn to_buf_opt(&self) -> Vec<Opt<D>> {
        let tensor = &self.inner;
        tensor.to_buf_opt()
    }
}
/// HBM tensor handle: a `Tensor` laid out as `Pair<Chip, Element>` together with the
/// `Address` it was created with.
#[primitive(HbmTensor)]
#[derive(Debug)]
pub struct HbmTensor<D: Scalar, Chip: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped tensor, mapped as `Pair<Chip, Element>`.
inner: Tensor<D, Pair<Chip, Element>, B>,
/// Address supplied at construction; returned by `address()` and copied into views.
address: Address,
}
// Marker implementations: `HbmTensor` (owned, `&` and `&mut`) and both HBM view types implement
// `crate::runtime::DeviceSend`. The impl bodies are empty.
// NOTE(review): presumably this trait gates which values may be handed to device-side code —
// confirm against `crate::runtime::DeviceSend`.
impl<D: Scalar, Chip: M, Element: M, B: Backend> crate::runtime::DeviceSend for HbmTensor<D, Chip, Element, B> {}
impl<D: Scalar, Chip: M, Element: M, B: Backend> crate::runtime::DeviceSend for &HbmTensor<D, Chip, Element, B> {}
impl<D: Scalar, Chip: M, Element: M, B: Backend> crate::runtime::DeviceSend for &mut HbmTensor<D, Chip, Element, B> {}
impl<D: Scalar, Chip: M, Element: M, B: Backend> crate::runtime::DeviceSend for HbmTensorView<'_, D, Chip, Element, B> {}
impl<D: Scalar, Chip: M, Element: M, B: Backend> crate::runtime::DeviceSend
for HbmTensorViewMut<'_, D, Chip, Element, B>
{
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
    /// The full mapping of this tensor: `Chip` followed by `Element`.
    pub type Mapping = m![{ Chip }, { Element }];

    /// Pairs an existing tensor with the address it is considered to live at.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping, B>, address: Address) -> Self {
        Self { address, inner }
    }

    /// Borrows the wrapped tensor.
    pub(crate) fn inner_tensor(&self) -> &Tensor<D, Self::Mapping, B> {
        &self.inner
    }

    /// Returns the address this handle was created with.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Returns `Pair::<Chip, Element>::SIZE` multiplied by `size_of::<D>()`.
    // NOTE(review): other code in this file sizes scalars with `D::size_in_bytes_from_length` /
    // `D::BITS`; if `size_of::<D>()` can differ from the device-side scalar width, this value
    // would disagree with those — confirm which one callers expect.
    pub fn size() -> usize {
        let element_count = Pair::<Chip, Element>::SIZE;
        element_count * std::mem::size_of::<D>()
    }

    /// Transfers this tensor back to the host by awaiting `B::from_hbm`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a PCIe DMA context).
    pub async fn to_host<Element2: M>(&self, _dma: &mut DmaContext<{ Dma::Pcie }>) -> HostTensor<D, Element2, B> {
        B::from_hbm(self).await
    }

    /// Returns the tensor contents as a `Vec`, as produced by the wrapped tensor's `to_buf`.
    pub fn to_buf(&self) -> Vec<D> {
        self.inner_tensor().to_buf()
    }
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
/// Returns the tensor contents as a `Vec` by delegating to the wrapped tensor's
/// `to_buf_or_default`.
// NOTE(review): presumably substitutes a default for elements that hold no value — confirm in `Tensor`.
pub fn to_buf_or_default(&self) -> Vec<D> {
self.inner.to_buf_or_default()
}
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
    /// Creates a handle at `address` backed by a raw tensor from `uninit_from_axes`, using the
    /// axes generated for `Pair<Chip, Element>`.
    ///
    /// # Safety
    ///
    /// No data is written by this function.
    // NOTE(review): the exact contract the caller must uphold (e.g. that `address` designates
    // valid, suitably sized memory) is not visible in this file — confirm and document it.
    #[primitive(HbmTensor::from_addr)]
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(
            Tensor::from_inner(B::RawTensor::uninit_from_axes(gen_axes::<Pair<Chip, Element>>())),
            address,
        )
    }
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
    /// Borrows this tensor as a read-only view carrying the same address.
    #[primitive(HbmTensor::view)]
    pub fn view<'l>(&'l self) -> HbmTensorView<'l, D, Chip, Element, B> {
        let inner = self.inner.view();
        HbmTensorView {
            inner,
            address: self.address,
        }
    }

    /// Borrows this tensor as a mutable view carrying the same address.
    #[primitive(HbmTensor::view_mut)]
    pub fn view_mut<'l>(&'l mut self) -> HbmTensorViewMut<'l, D, Chip, Element, B> {
        let inner = self.inner.view_mut();
        HbmTensorViewMut {
            inner,
            address: self.address,
        }
    }

    /// Produces a new HBM tensor at `address` with element mapping `Element2`, obtained by
    /// calling `transpose(true)` on the wrapped tensor.
    ///
    /// `_dma` is not used by the body; it only has to be supplied.
    #[primitive(HbmTensor::to_hbm)]
    pub fn to_hbm<const DMA: Dma, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ DMA }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2, B> {
        let relaid = self.inner.transpose(true);
        HbmTensor::new(relaid, address)
    }

    /// Gathers from this tensor into a fresh `DmTensor` at `address`, using `index` and the
    /// `scaled` flag as forwarded to `write_gather`.
    // NOTE(review): the meaning of `scaled` is defined by `write_gather`, which is not visible
    // here — confirm before documenting it further.
    #[primitive(HbmTensor::dma_gather)]
    pub fn dma_gather<Cluster2: M, Slice2: M, Element2: M, Element3: M>(
        &self,
        index: &HbmTensor<i32, Chip, Element3, B>,
        address: Address,
        scaled: bool,
    ) -> DmTensor<D, Chip, Cluster2, Slice2, Element2, B> {
        // `from_addr` starts from an uninitialized raw tensor; the gather below receives it as
        // its destination.
        let mut gathered: DmTensor<D, Chip, Cluster2, Slice2, Element2, B> = unsafe { DmTensor::from_addr(address) };
        self.inner
            .write_gather::<_, _>(&mut gathered.inner, &index.inner, scaled);
        gathered
    }
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
    /// Produces a `DmTensor` at `address` by calling `transpose(true)` on the wrapped tensor.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    ///
    /// # Panics
    ///
    /// Panics if `assert_dma_layout` rejects the layout pair (`Chip`, `Element`) → `Element2`
    /// for `DMA_SRAM_WRITE_WIDTH`.
    #[primitive(HbmTensor::to_dm)]
    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice, Element2, B> {
        assert_dma_layout::<D, m![{ Chip }, { Element }], Element2>(DMA_SRAM_WRITE_WIDTH);

        let relaid = self.inner.transpose(true);
        DmTensor::new(relaid, address)
    }
}
impl<D: Scalar, Chip: M, Element: M, B: Backend> HbmTensor<D, Chip, Element, B> {
/// Not implemented: calling this always panics through `todo!`.
///
/// As the panic message explains, `HbmTensor` has only `Chip` and `Element` axes, so there is
/// no `Cluster` axis to shuffle here. The open design question is whether the API should take
/// the `Element` sub-axis that encodes the cluster explicitly, or whether the operation belongs
/// on `DmTensorView::dm_cluster_shuffle` instead. Both parameters are currently unused.
pub fn hbm_cluster_shuffle<const DMA: Dma>(
&self,
_dma: &mut DmaContext<{ DMA }>,
_shuffle_pattern: &[usize],
) -> HbmTensor<D, Chip, Element, B> {
todo!(
"hbm_cluster_shuffle is Under Construction. HbmTensor has no Cluster axis \
(only Chip + Element); Cluster distribution is decided at .to_dm() time. \
No current callers. Either the Element axis is meant to encode a Cluster \
sub-axis (API needs to take that axis explicitly) or the operation belongs \
on DmTensorView::dm_cluster_shuffle. Pending design review; see the doc \
comment on hbm_cluster_shuffle."
)
}
}
/// Read-only view of an `HbmTensor`: a `TensorView` mapped as `Pair<Chip, Element>` plus the
/// address copied from the tensor it was created from.
#[primitive(HbmTensorView)]
#[derive(Debug, Clone)]
pub struct HbmTensorView<'l, D: Scalar, Chip: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped view.
inner: TensorView<'l, D, Pair<Chip, Element>, B>,
/// Address copied from the originating handle; tiling does not change it.
address: Address,
}
impl<'l, D: Scalar, Chip: M, Element: M, B: Backend> HbmTensorView<'l, D, Chip, Element, B> {
    /// The full mapping of the viewed tensor: `Chip` followed by `Element`.
    pub type Mapping = m![{ Chip }, { Element }];

    /// Returns the address carried by this view.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Writes this view into `dst` through `write_transpose(.., true)`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied.
    #[primitive(HbmTensorView::to_hbm_view)]
    pub fn to_hbm_view<const DMA: Dma, Element2: M>(
        self,
        _dma: &mut DmaContext<{ DMA }>,
        mut dst: HbmTensorViewMut<'l, D, Chip, Element2, B>,
    ) {
        let src = self.inner;
        dst.inner.write_transpose(src, true);
    }

    /// Writes this view into the DM view `dst` through `write_transpose(.., true)`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    ///
    /// # Panics
    ///
    /// Panics if `assert_dma_layout` rejects the layout pair (`Chip`, `Element`) → `Element2`
    /// for `DMA_SRAM_WRITE_WIDTH`.
    pub fn to_dm_view<Cluster: M, Slice: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2, B>,
    ) {
        assert_dma_layout::<D, m![{ Chip }, { Element }], Element2>(DMA_SRAM_WRITE_WIDTH);

        let src = self.inner;
        dst.inner.write_transpose(src, true);
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Chip2` in place of `Chip`. The address is carried over unchanged.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        &self,
        start: usize,
    ) -> HbmTensorView<'l, D, Chip2, Element, B> {
        HbmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
            address: self.address,
        }
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Element2` in place of `Element`. The address is carried over unchanged.
    #[primitive(HbmTensorView::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        &self,
        start: usize,
    ) -> HbmTensorView<'l, D, Chip, Element2, B> {
        HbmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
            address: self.address,
        }
    }

    /// Reads a clone of the wrapped view and returns the result's `to_buf()`.
    pub fn to_buf(&self) -> Vec<D> {
        let snapshot = self.inner.clone().read();
        snapshot.to_buf()
    }
}
impl<'l, D: Scalar, Chip: M, Element: M, B: Backend> HbmTensorView<'l, D, Chip, Element, B> {
/// Reads a clone of the wrapped view and returns the result's `to_buf_or_default()`.
// NOTE(review): presumably substitutes a default for elements that hold no value — confirm in `Tensor`.
pub fn to_buf_or_default(&self) -> Vec<D> {
self.inner.clone().read().to_buf_or_default()
}
}
impl<'l, D: Scalar, Chip: M, Element: M, B: Backend> HbmTensorView<'l, D, Chip, Element, B> {
    /// Reads this view and produces a `DmTensor` at `address` from `transpose(true)` of the result.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    ///
    /// # Panics
    ///
    /// Panics if `assert_dma_layout` rejects the layout pair (`Chip`, `Element`) → `Element2`
    /// for `DMA_SRAM_WRITE_WIDTH`.
    #[primitive(HbmTensorView::to_dm)]
    pub fn to_dm<Cluster: M, Slice: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice, Element2, B> {
        assert_dma_layout::<D, m![{ Chip }, { Element }], Element2>(DMA_SRAM_WRITE_WIDTH);

        let relaid = self.inner.read().transpose(true);
        DmTensor::new(relaid, address)
    }

    /// Builds a new HBM tensor whose chip tile `i` is a copy of this view's chip tile
    /// `shuffle_pattern[i]`.
    ///
    /// The result is created with `HbmTensor::from_addr(0)` and filled one single-chip tile at
    /// a time through `to_hbm_view`.
    // NOTE(review): the output always uses address 0, and `shuffle_pattern` entries are not
    // range-checked here (any check would happen inside `tile`) — confirm both are intended.
    pub fn hbm_chip_shuffle<const CHIP_DIM: usize, const DMA: Dma>(
        self,
        dma: &mut DmaContext<{ DMA }>,
        shuffle_pattern: &[usize; CHIP_DIM],
    ) -> HbmTensor<D, Chip, Element, B> {
        let mut shuffled: HbmTensor<D, Chip, Element, B> = unsafe { HbmTensor::from_addr(0) };
        for (target_chip_idx, &source_chip_idx) in shuffle_pattern.iter().enumerate() {
            let src_tile = self.chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(source_chip_idx);
            let dst_tile = shuffled
                .view_mut()
                .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(target_chip_idx);
            src_tile.to_hbm_view(dma, dst_tile);
        }
        shuffled
    }
}
/// Mutable view of an `HbmTensor`: a `TensorViewMut` mapped as `Pair<Chip, Element>` plus the
/// address copied from the tensor it was created from.
#[primitive(HbmTensorViewMut)]
#[derive(Debug)]
pub struct HbmTensorViewMut<'l, D: Scalar, Chip: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped mutable view.
inner: TensorViewMut<'l, D, Pair<Chip, Element>, B>,
/// Address copied from the originating handle; tiling does not change it.
address: Address,
}
impl<'l, D: Scalar, Chip: M, Element: M, B: Backend> HbmTensorViewMut<'l, D, Chip, Element, B> {
    /// Returns the address carried by this view.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Consumes the view and tiles it along `Index` with length `LEN` starting at `start`; the
    /// result is typed with `Chip2` in place of `Chip`. The address is carried over unchanged.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        self,
        start: usize,
    ) -> HbmTensorViewMut<'l, D, Chip2, Element, B> {
        HbmTensorViewMut {
            inner: self.inner.tile::<Index, _, LEN>(start),
            address: self.address,
        }
    }

    /// Consumes the view and tiles it along `Index` with length `LEN` starting at `start`; the
    /// result is typed with `Element2` in place of `Element`. The address is carried over
    /// unchanged.
    #[primitive(HbmTensorViewMut::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        self,
        start: usize,
    ) -> HbmTensorViewMut<'l, D, Chip, Element2, B> {
        HbmTensorViewMut {
            inner: self.inner.tile::<Index, _, LEN>(start),
            address: self.address,
        }
    }
}
/// DM tensor handle: a `Tensor` mapped as `Chip`, `Cluster`, `Slice`, `Element` (nested `Pair`s)
/// together with the `Address` it was created with.
#[primitive(DmTensor)]
#[derive(Debug)]
pub struct DmTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped tensor.
inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>, B>,
/// Address supplied at construction; preserved by `reshape`.
address: Address,
/// Zero-sized marker naming all type parameters.
_marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend> DmTensor<D, Chip, Cluster, Slice, Element, B> {
    /// The full mapping of this tensor: `Chip`, `Cluster`, `Slice`, then `Element`.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Pairs an existing tensor with the address it is considered to live at.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping, B>, address: Address) -> Self {
        let _marker = PhantomData;
        Self {
            inner,
            address,
            _marker,
        }
    }
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend> DmTensor<D, Chip, Cluster, Slice, Element, B> {
    /// Creates a handle at `address` backed by a raw tensor from `uninit_from_axes`, using the
    /// axes generated for the full `Chip`/`Cluster`/`Slice`/`Element` mapping.
    ///
    /// # Safety
    ///
    /// No data is written by this function.
    // NOTE(review): the exact contract the caller must uphold (e.g. that `address` designates
    // valid, suitably sized memory) is not visible in this file — confirm and document it.
    #[primitive(DmTensor::from_addr)]
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(
            Tensor::from_inner(B::RawTensor::uninit_from_axes(gen_axes::<
                Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>,
            >())),
            address,
        )
    }
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend> DmTensor<D, Chip, Cluster, Slice, Element, B> {
    /// Borrows this tensor as a read-only view.
    #[primitive(DmTensor::view)]
    pub fn view<'l>(&'l self) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element, B> {
        let inner = self.inner.view();
        DmTensorView { inner }
    }

    /// Borrows this tensor as a mutable view.
    #[primitive(DmTensor::view_mut)]
    pub fn view_mut<'l>(&'l mut self) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element, B> {
        let inner = self.inner.view_mut();
        DmTensorViewMut { inner }
    }

    /// Produces an `HbmTensor` at `address` by calling `transpose(true)` on the wrapped tensor.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    #[primitive(DmTensor::to_hbm)]
    pub fn to_hbm<Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> HbmTensor<D, Chip, Element2, B> {
        let relaid = self.inner.transpose(true);
        HbmTensor::new(relaid, address)
    }

    /// Scatters this tensor into `output` along `Key`, using `index` and the `scaled` flag as
    /// forwarded to `write_scatter`.
    ///
    /// # Panics
    ///
    /// Panics if `sequence(&[&key], &[&src], SequencerMode::Read)` fails, where `src` is the
    /// mapping value of `Pair<Slice, Element>` and `key` that of `Key`.
    // NOTE(review): the meaning of `scaled` is defined by `write_scatter`, which is not visible
    // here — confirm before documenting it further.
    #[primitive(DmTensor::dma_scatter)]
    pub fn dma_scatter<Key: M, Element2: M, Element3: M>(
        &self,
        index: &HbmTensor<i32, Chip, Element3, B>,
        output: &mut HbmTensor<D, Chip, Element2, B>,
        scaled: bool,
    ) {
        let src = Pair::<Slice, Element>::to_value();
        let key = Key::to_value();
        let key_fits_in_source = sequence(&[&key], &[&src], SequencerMode::Read).is_ok();
        assert!(
            key_fits_in_source,
            "scatter key `{key}` must be fully contained in source `{src}`. \
             If the key axis is split across Chip and Element, indirect DMA cannot address it.",
        );

        self.inner
            .write_scatter::<Key, _, _>(&mut output.inner, &index.inner, scaled);
    }

    /// Produces another `DmTensor` at `address` with `Slice2`/`Element2` in place of
    /// `Slice`/`Element`, by calling `transpose(true)` on the wrapped tensor.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    ///
    /// # Panics
    ///
    /// Panics if `assert_dma_layout` rejects the layout pair (`Cluster`, `Slice`, `Element`) →
    /// `Element2` for `DMA_SRAM_WRITE_WIDTH`.
    pub fn to_dm<Slice2: M, Element2: M>(
        &self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        address: Address,
    ) -> DmTensor<D, Chip, Cluster, Slice2, Element2, B> {
        assert_dma_layout::<D, m![{ Cluster }, { Slice }, { Element }], Element2>(DMA_SRAM_WRITE_WIDTH);

        let relaid = self.inner.transpose(true);
        DmTensor::new(relaid, address)
    }

    /// Copies this tensor into `dst` by delegating to `DmTensorView::to_dm_view_pcopy` on views
    /// of both tensors.
    pub fn to_dm_pcopy<Slice2: M, Element2: M>(
        &self,
        sub: &mut TuContext<{ Tu::Sub }>,
        dst: &mut DmTensor<D, Chip, Cluster, Slice2, Element2, B>,
    ) {
        let src_view = self.view();
        let dst_view = dst.view_mut();
        src_view.to_dm_view_pcopy(sub, dst_view);
    }

    /// Reinterprets this tensor with a different set of mappings, keeping the same address.
    ///
    /// # Panics
    ///
    /// Panics unless each of `Chip`, `Cluster`, `Slice` and `Element` has the same `SIZE` as its
    /// replacement.
    ///
    /// # Safety
    ///
    /// Forwards to the wrapped tensor's `unsafe` `reshape`.
    // NOTE(review): the precise contract is that of `Tensor::reshape`, which is not visible in
    // this file — confirm and restate it here.
    #[primitive(DmTensor::reshape)]
    pub unsafe fn reshape<Chip2: M, Cluster2: M, Slice2: M, Element2: M>(
        self,
    ) -> DmTensor<D, Chip2, Cluster2, Slice2, Element2, B> {
        assert_eq!(Chip::SIZE, Chip2::SIZE);
        assert_eq!(Cluster::SIZE, Cluster2::SIZE);
        assert_eq!(Slice::SIZE, Slice2::SIZE);
        assert_eq!(Element::SIZE, Element2::SIZE);

        // The per-axis sizes were checked above; everything else `reshape` requires is the
        // responsibility of this function's caller.
        DmTensor::new(
            unsafe {
                self.inner
                    .reshape::<m![{ Chip2 }, { Cluster2 }, { Slice2 }, { Element2 }]>()
            },
            self.address,
        )
    }
}
/// Mutable view of a `DmTensor`: wraps a `TensorViewMut` over the full
/// `Chip`/`Cluster`/`Slice`/`Element` mapping. Unlike the HBM views it carries no address.
#[primitive(DmTensorViewMut)]
#[derive(Debug)]
pub struct DmTensorViewMut<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped mutable view.
pub(crate) inner: TensorViewMut<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>, B>,
}
/// Read-only view of a `DmTensor`: wraps a `TensorView` over the full
/// `Chip`/`Cluster`/`Slice`/`Element` mapping. Unlike the HBM views it carries no address.
#[primitive(DmTensorView)]
#[derive(Debug, Clone)]
pub struct DmTensorView<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped view.
pub(crate) inner: TensorView<'l, D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>, B>,
}
impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
    From<DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element, B>>
    for DmTensorView<'l, D, Chip, Cluster, Slice, Element, B>
{
    /// Downgrades a mutable DM view to a read-only one by converting the wrapped view.
    fn from(view: DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element, B>) -> Self {
        let mutable_inner = view.inner;
        Self {
            inner: mutable_inner.into(),
        }
    }
}
impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
    DmTensorView<'l, D, Chip, Cluster, Slice, Element, B>
{
    /// The full mapping of the viewed tensor: `Chip`, `Cluster`, `Slice`, then `Element`.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Writes this view into the HBM view `dst` through `write_transpose(.., true)`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    #[primitive(DmTensorView::to_hbm_view)]
    pub fn to_hbm_view<Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: HbmTensorViewMut<'l, D, Chip, Element2, B>,
    ) {
        let src = self.inner;
        dst.inner.write_transpose(src, true);
    }

    /// Writes this view into the DM view `dst` through `write_transpose(.., true)`.
    ///
    /// `_dma` is not used by the body; it only has to be supplied (a tensor DMA context).
    ///
    /// # Panics
    ///
    /// Panics if `assert_dma_layout` rejects the layout pair (`Cluster`, `Slice`, `Element`) →
    /// `Element2` for `DMA_SRAM_WRITE_WIDTH`.
    #[primitive(DmTensorView::to_dm_view)]
    pub fn to_dm_view<Slice2: M, Element2: M>(
        self,
        _dma: &mut DmaContext<{ Dma::Tensor }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2, B>,
    ) {
        assert_dma_layout::<D, m![{ Cluster }, { Slice }, { Element }], Element2>(DMA_SRAM_WRITE_WIDTH);

        let src = self.inner;
        dst.inner.write_transpose(src, true);
    }

    /// Writes this view into the DM view `dst` through `write_transpose(.., true)`, without the
    /// DMA layout assertion performed by [`Self::to_dm_view`].
    ///
    /// `_sub` is not used by the body; it only has to be supplied (a `Tu::Sub` context).
    pub fn to_dm_view_pcopy<Slice2: M, Element2: M>(
        self,
        _sub: &mut TuContext<{ Tu::Sub }>,
        mut dst: DmTensorViewMut<'l, D, Chip, Cluster, Slice2, Element2, B>,
    ) {
        let src = self.inner;
        dst.inner.write_transpose(src, true);
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Chip2` in place of `Chip`.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip2, Cluster, Slice, Element, B> {
        DmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Cluster2` in place of `Cluster`.
    pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster2, Slice, Element, B> {
        DmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Slice2` in place of `Slice`.
    #[primitive(DmTensorView::slice_tile)]
    pub fn slice_tile<Index: M, const LEN: usize, Slice2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice2, Element, B> {
        DmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Tiles the wrapped view along `Index` with length `LEN` starting at `start`; the result
    /// is typed with `Element2` in place of `Element`.
    #[primitive(DmTensorView::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        &self,
        start: usize,
    ) -> DmTensorView<'l, D, Chip, Cluster, Slice, Element2, B> {
        DmTensorView {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Builds a new DM tensor whose cluster tile `i` is a copy of this view's cluster tile
    /// `shuffle_pattern[i]`.
    ///
    /// The result is created with `DmTensor::from_addr(0)` and filled one single-cluster tile at
    /// a time through `to_dm_view`.
    // NOTE(review): unlike `dm_chip_shuffle`, the pattern is a plain slice whose length is never
    // compared with `CLUSTER_DIM`; a shorter pattern leaves the remaining clusters of the
    // uninitialized output untouched. Confirm whether callers rely on that.
    #[primitive(DmTensorView::dm_cluster_shuffle)]
    pub fn dm_cluster_shuffle<const CLUSTER_DIM: usize>(
        self,
        dma: &mut DmaContext<{ Dma::Tensor }>,
        shuffle_pattern: &[usize],
    ) -> DmTensor<D, Chip, Cluster, Slice, Element, B> {
        let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element, B> = unsafe { DmTensor::from_addr(0) };
        for (target_cluster_idx, &source_cluster_idx) in shuffle_pattern.iter().enumerate() {
            let src_tile = self.cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(source_cluster_idx);
            let dst_tile = shuffled
                .view_mut()
                .cluster_tile::<Cluster, 1, Padding<Identity, CLUSTER_DIM>>(target_cluster_idx);
            src_tile.to_dm_view(dma, dst_tile);
        }
        shuffled
    }

    /// Builds a new DM tensor whose chip tile `i` is a copy of this view's chip tile
    /// `shuffle_pattern[i]`.
    ///
    /// The result is created with `DmTensor::from_addr(0)` and filled one single-chip tile at a
    /// time through `to_dm_view`.
    // NOTE(review): the output always uses address 0 — confirm this is intended.
    #[primitive(DmTensorView::dm_chip_shuffle)]
    pub fn dm_chip_shuffle<const CHIP_DIM: usize>(
        self,
        dma: &mut DmaContext<{ Dma::Tensor }>,
        shuffle_pattern: &[usize; CHIP_DIM],
    ) -> DmTensor<D, Chip, Cluster, Slice, Element, B> {
        let mut shuffled: DmTensor<D, Chip, Cluster, Slice, Element, B> = unsafe { DmTensor::from_addr(0) };
        for (target_chip_idx, &source_chip_idx) in shuffle_pattern.iter().enumerate() {
            let src_tile = self.chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(source_chip_idx);
            let dst_tile = shuffled
                .view_mut()
                .chip_tile::<Chip, 1, Padding<Identity, CHIP_DIM>>(target_chip_idx);
            src_tile.to_dm_view(dma, dst_tile);
        }
        shuffled
    }
}
impl<'l, D: Scalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
    DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element, B>
{
    /// Consumes the view and tiles it along `Index` with length `LEN` starting at `start`; the
    /// result is typed with `Chip2` in place of `Chip`.
    pub fn chip_tile<Index: M, const LEN: usize, Chip2: M>(
        self,
        start: usize,
    ) -> DmTensorViewMut<'l, D, Chip2, Cluster, Slice, Element, B> {
        DmTensorViewMut {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Consumes the view and tiles it along `Index` with length `LEN` starting at `start`; the
    /// result is typed with `Cluster2` in place of `Cluster`.
    pub fn cluster_tile<Index: M, const LEN: usize, Cluster2: M>(
        self,
        start: usize,
    ) -> DmTensorViewMut<'l, D, Chip, Cluster2, Slice, Element, B> {
        DmTensorViewMut {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }

    /// Consumes the view and tiles it along `Index` with length `LEN` starting at `start`; the
    /// result is typed with `Element2` in place of `Element`.
    #[primitive(DmTensorViewMut::tile)]
    pub fn tile<Index: M, const LEN: usize, Element2: M>(
        self,
        start: usize,
    ) -> DmTensorViewMut<'l, D, Chip, Cluster, Slice, Element2, B> {
        DmTensorViewMut {
            inner: self.inner.tile::<Index, _, LEN>(start),
        }
    }
}
/// TRF tensor handle: a `Tensor` mapped as `Chip`, `Cluster`, `Slice`, `Lane`, `Element`
/// (nested `Pair`s) together with a `TrfAddress`.
#[primitive(TrfTensor)]
#[derive(Debug)]
pub struct TrfTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped tensor.
pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Lane, Element>>>>, B>,
/// Selector supplied at construction; stored but not read anywhere, hence the `expect`.
#[expect(dead_code)]
address: TrfAddress,
/// Zero-sized marker naming all type parameters.
_marker: PhantomData<(D, Chip, Cluster, Slice, Lane, Element)>,
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Element: M, B: Backend>
    TrfTensor<D, Chip, Cluster, Slice, Lane, Element, B>
{
    /// The full mapping of this tensor: `Chip`, `Cluster`, `Slice`, `Lane`, then `Element`.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Lane }, { Element }];

    /// Pairs an existing tensor with the `TrfAddress` it is considered to occupy.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping, B>, address: TrfAddress) -> Self {
        let _marker = PhantomData;
        Self {
            inner,
            address,
            _marker,
        }
    }
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Element: M, B: Backend>
    TrfTensor<D, Chip, Cluster, Slice, Lane, Element, B>
{
    /// Creates a handle for `address` backed by a raw tensor from `uninit_from_axes`, using the
    /// axes generated for the full `Chip`/`Cluster`/`Slice`/`Lane`/`Element` mapping.
    ///
    /// # Safety
    ///
    /// No data is written by this function.
    // NOTE(review): the exact contract the caller must uphold is not visible in this file —
    // confirm and document it.
    pub unsafe fn from_addr(address: TrfAddress) -> Self {
        Self::new(
            Tensor::from_inner(B::RawTensor::uninit_from_axes(gen_axes::<
                Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Lane, Element>>>>,
            >())),
            address,
        )
    }
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Element: M, B: Backend>
TrfTensor<D, Chip, Cluster, Slice, Lane, Element, B>
{
/// Borrows the wrapped tensor as a mutable `TensorViewMut` over the full mapping.
pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping, B> {
self.inner.view_mut()
}
/// Borrows the wrapped tensor as a read-only `TensorView` over the full mapping.
pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping, B> {
self.inner.view()
}
}
/// VRF tensor handle: a `Tensor` mapped as `Chip`, `Cluster`, `Slice`, `Element` (nested `Pair`s)
/// together with an `Address`. The scalar type must implement `VeScalar`.
#[primitive(VrfTensor)]
#[derive(Debug, Clone)]
pub struct VrfTensor<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend = CurrentBackend> {
/// The wrapped tensor.
pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>, B>,
/// Address supplied at construction; stored but not read anywhere, hence the `expect`.
#[expect(dead_code)]
address: Address,
/// Zero-sized marker naming all type parameters.
_marker: PhantomData<(D, Chip, Cluster, Slice, Element)>,
}
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
    VrfTensor<D, Chip, Cluster, Slice, Element, B>
{
    /// The full mapping of this tensor: `Chip`, `Cluster`, `Slice`, then `Element`.
    pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Element }];

    /// Pairs an existing tensor with the address it is considered to live at.
    pub(crate) fn new(inner: Tensor<D, Self::Mapping, B>, address: Address) -> Self {
        let _marker = PhantomData;
        Self {
            inner,
            address,
            _marker,
        }
    }
}
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
    VrfTensor<D, Chip, Cluster, Slice, Element, B>
{
    /// Creates a handle at `address` backed by a raw tensor from `uninit_from_axes`, using the
    /// axes generated for the full `Chip`/`Cluster`/`Slice`/`Element` mapping.
    ///
    /// # Safety
    ///
    /// No data is written by this function.
    // NOTE(review): the exact contract the caller must uphold is not visible in this file —
    // confirm and document it.
    pub unsafe fn from_addr(address: Address) -> Self {
        Self::new(
            Tensor::from_inner(B::RawTensor::uninit_from_axes(gen_axes::<
                Pair<Chip, Pair<Cluster, Pair<Slice, Element>>>,
            >())),
            address,
        )
    }
}
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, B: Backend>
VrfTensor<D, Chip, Cluster, Slice, Element, B>
{
/// Borrows the wrapped tensor as a mutable `TensorViewMut` over the full mapping.
pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping, B> {
self.inner.view_mut()
}
/// Borrows the wrapped tensor as a read-only `TensorView` over the full mapping.
pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping, B> {
self.inner.view()
}
}
/// DPE tensor handle: a `Tensor` mapped as `Chip`, `Cluster`, `Slice`, `Time`, `Lane`, `Packet`
/// (nested `Pair`s). It carries no address, and no constructor is defined in this part of the file.
#[derive(Debug)]
pub struct DpeTensor<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Lane: M, Packet: M, B: Backend = CurrentBackend>
{
/// The wrapped tensor.
inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Time, Pair<Lane, Packet>>>>>, B>,
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Lane: M, Packet: M, B: Backend>
DpeTensor<D, Chip, Cluster, Slice, Time, Lane, Packet, B>
{
/// The full mapping of this tensor: `Chip`, `Cluster`, `Slice`, `Time`, `Lane`, then `Packet`.
pub type Mapping = m![{ Chip }, { Cluster }, { Slice }, { Time }, { Lane }, { Packet }];
}
impl<D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Lane: M, Packet: M, B: Backend>
DpeTensor<D, Chip, Cluster, Slice, Time, Lane, Packet, B>
{
/// Borrows the wrapped tensor as a mutable `TensorViewMut` over the full mapping.
pub fn view_mut<'l>(&'l mut self) -> TensorViewMut<'l, D, Self::Mapping, B> {
self.inner.view_mut()
}
/// Borrows the wrapped tensor as a read-only `TensorView` over the full mapping.
pub fn view<'l>(&'l self) -> TensorView<'l, D, Self::Mapping, B> {
self.inner.view()
}
}
// Unit tests for the DMA layout checks (`assert_dma_layout` and the contiguous tail it relies on).
#[cfg(test)]
mod tests {
use super::*;
use crate::scalar::Scalar;
// Helper: contiguous tail (in elements) reported by `config_divide_relaxed` for `Src` over `Dst`.
fn reachable_end<Src: M, Dst: M>() -> usize {
config_divide_relaxed(&Src::to_value(), &Dst::to_value()).contiguous_tail
}
// With `B` (size 3) written as `B # 8` on the destination, the tail is expected to be 8 elements.
#[test]
fn unittest_extents_reachable_end_with_dst_padding_absorb() {
axes![A = 8, B = 3];
assert_eq!(reachable_end::<m![A, B], m![A, B # 8]>(), 8);
}
// Prepending `Cl` and `Sl` axes to the source must not change the tail.
#[test]
fn unittest_extents_reachable_end_invariant_under_outer_cluster_slice() {
axes![Cl = 2, Sl = 4, A = 3];
assert_eq!(
reachable_end::<m![A], m![A # 16]>(),
reachable_end::<m![Cl, Sl, A], m![A # 16]>(),
);
}
// A single `i32` element gives a 4-byte tail, which is not a multiple of `DMA_SRAM_WRITE_WIDTH`.
#[test]
fn unittest_extents_reachable_end_single_element_underflows_alignment() {
axes![A = 1];
let end = reachable_end::<m![A], m![A]>();
assert_eq!(end, 1);
assert_eq!(<i32 as Scalar>::size_in_bytes_from_length(end), 4);
assert!(!<i32 as Scalar>::size_in_bytes_from_length(end).is_multiple_of(DMA_SRAM_WRITE_WIDTH));
}
// A source with outer `Cl`/`Sl` axes and destination `m![A, B]` must pass the layout assertion.
#[test]
fn unittest_assert_dma_layout_canonical_cluster_slice_passes() {
axes![Cl = 2, Sl = 4, A = 8, B = 4];
assert_dma_layout::<i32, m![Cl, Sl, A, B], m![A, B]>(DMA_SRAM_WRITE_WIDTH);
}
// The `B # 8` destination layout from the first test must also pass the full assertion.
#[test]
fn unittest_assert_dma_layout_dst_padding_absorbed() {
axes![A = 8, B = 3];
assert_dma_layout::<i32, m![A, B], m![A, B # 8]>(DMA_SRAM_WRITE_WIDTH);
}
// With `min_align == 1` every byte count is a multiple, so otherwise-unaligned layouts must pass.
#[test]
fn unittest_assert_dma_layout_min_align_one_is_noop() {
axes![A = 1];
assert_dma_layout::<i32, m![A], m![A]>(1);
axes![Cl = 2, Sl = 4, B = 3];
assert_dma_layout::<i32, m![Cl, Sl, B], m![B # 7]>(1);
}
}