pub mod lane;
pub mod outer;
pub mod packet;
pub mod time;
pub use lane::LaneMode;
pub use outer::ContractOuterTensor;
use furiosa_mapping::*;
use crate::context::*;
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::Tensor;
use crate::tensor::tu::{Position, TuTensor};
/// Column count associated with the temporal accumulator.
/// NOTE(review): meaning inferred from the name only; confirm the unit and origin of `32`.
pub(crate) const TEMPORAL_ACCUMULATOR_COLS: usize = 32;
/// Element count associated with one contract-lane output packet.
/// NOTE(review): meaning inferred from the name only; confirm the unit and origin of `8`.
pub(crate) const CONTRACT_LANE_OUT_PACKET_ELEMENTS: usize = 8;
/// Field-less marker type that implements [`Position`].
///
/// It is the position parameter that the [`ContractTensor`] alias in this
/// module plugs into [`TuTensor`].
#[derive(Debug)]
pub struct PositionContraction;
// The impl body is empty: nothing is defined or overridden for this marker here.
impl Position for PositionContraction {}
/// Tensor handle that pairs an exclusive borrow of a [`TuContext`] with the
/// tensor data it operates on.
///
/// The wrapped [`Tensor`] has element type `D`, backend `B`, and a mapping that
/// is the right-nested pair `Chip, Cluster, Slice, Lane, Time, Packet`
/// (outermost to innermost).
///
/// NOTE(review): by its name this looks like the packet-level stage of a
/// contraction; confirm against the `packet` submodule.
#[derive(Debug)]
pub struct ContractPacketTensor<
    'l,
    const T: Tu,
    D,
    Chip,
    Cluster,
    Slice,
    Lane,
    Time,
    Packet,
    B = CurrentBackend,
> where
    D: Scalar,
    Chip: M,
    Cluster: M,
    Slice: M,
    Lane: M,
    Time: M,
    Packet: M,
    B: Backend,
{
    /// Exclusive borrow of the context for tensor unit `T`, held for `'l`.
    pub(crate) ctx: &'l mut TuContext<{ T }>,
    /// Underlying tensor carrying the six-level nested mapping.
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Lane, Pair<Time, Packet>>>>>, B>,
}
/// Tensor handle that pairs an exclusive borrow of a [`TuContext`] with the
/// tensor data it operates on, plus a runtime [`Mapping`] value.
///
/// The wrapped [`Tensor`] has element type `D`, backend `B`, and a mapping that
/// is the right-nested pair `Chip, Cluster, Slice, Lane, Time, Packet`
/// (outermost to innermost).
///
/// NOTE(review): by its name this looks like the time-level stage of a
/// contraction; confirm against the `time` submodule.
#[derive(Debug)]
pub struct ContractTimeTensor<
    'l,
    const T: Tu,
    D,
    Chip,
    Cluster,
    Slice,
    Lane,
    Time,
    Packet,
    B = CurrentBackend,
> where
    D: Scalar,
    Chip: M,
    Cluster: M,
    Slice: M,
    Lane: M,
    Time: M,
    Packet: M,
    B: Backend,
{
    /// Exclusive borrow of the context for tensor unit `T`, held for `'l`.
    pub(crate) ctx: &'l mut TuContext<{ T }>,
    /// Underlying tensor carrying the six-level nested mapping.
    pub(crate) inner: Tensor<D, Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Lane, Pair<Time, Packet>>>>>, B>,
    /// Runtime mapping stored alongside the tensor.
    /// NOTE(review): presumably the time mapping as it was before reduction; confirm with the producer.
    pub(crate) pre_reduce_time: Mapping,
}
/// Builds a table keyed by stride whose values are the axis extents of `m`.
///
/// The mapping is normalized first and then flattened with
/// `collect_axis_extents`, which lists the right-hand side of every pair before
/// the left-hand side. Walking that list, each extent is stored under the
/// running stride (starting at 1), and the stride is then multiplied by the
/// extent.
///
/// An axis of extent 1 leaves the stride unchanged, so the next axis is stored
/// under the same key and replaces it.
pub(crate) fn padding_per_stride(m: &Mapping) -> std::collections::HashMap<usize, usize> {
    let mut axis_sizes = Vec::new();
    collect_axis_extents(&m.normalize(), &mut axis_sizes);

    let mut running_stride = 1;
    axis_sizes
        .into_iter()
        .map(|size| {
            // Record the stride in effect *before* this axis widens it.
            let key = running_stride;
            running_stride *= size;
            (key, size)
        })
        .collect()
}
/// Appends the `size()` of every non-`Pair` node reachable from `m` to `out`.
///
/// For a `Pair`, the right child is visited completely before the left child,
/// so the right-most leaf ends up first in `out`.
fn collect_axis_extents(m: &Mapping, out: &mut Vec<usize>) {
    let Mapping::Pair { left, right } = m else {
        // Any non-pair node is treated as a single axis.
        out.push(m.size());
        return;
    };
    collect_axis_extents(right, out);
    collect_axis_extents(left, out);
}
/// [`TuTensor`] with [`PositionContraction`] fixed as its position parameter.
///
/// All other parameters (`'l`, `T`, `D`, `Chip`, `Cluster`, `Slice`, `Time`,
/// `Packet`, `B`) are forwarded unchanged; `B` defaults to [`CurrentBackend`].
pub type ContractTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionContraction, D, Chip, Cluster, Slice, Time, Packet, B>;