use furiosa_mapping::*;
use furiosa_opt_lower::{DivideTerm, config_divide_exact};
use furiosa_opt_macro::primitive;
use crate::context::*;
use crate::engine::align_up;
use crate::engine::contraction::{
CONTRACT_LANE_OUT_PACKET_ELEMENTS, ContractTensor, ContractTimeTensor, TEMPORAL_ACCUMULATOR_COLS,
padding_per_stride,
};
use crate::runtime::Backend;
use crate::scalar::*;
/// Selects how `ContractTimeTensor::contract_lane` places the `Lane` axis in its output.
///
/// The layout rules for each mode are enforced by `verify_contract_lane`:
///
/// - `Interleaved`: `OutPacket` must be `Lane` with its padding replaced up to
///   `CONTRACT_LANE_OUT_PACKET_ELEMENTS`, and `OutTime` must be `[Time, Packet]`
///   (with `Packet`'s padding removed).
/// - `Sequential`: `Packet` is padded up to a multiple of `CONTRACT_LANE_OUT_PACKET_ELEMENTS`
///   and split into an outer chunk and an inner `CONTRACT_LANE_OUT_PACKET_ELEMENTS`-sized
///   chunk. `OutPacket` must be the inner chunk, and `OutTime` must be
///   `[Time, Lane, <outer Packet chunk>]`.
#[primitive(LaneMode)]
#[derive(Clone, Debug)]
pub enum LaneMode {
Interleaved,
Sequential,
}
impl std::fmt::Display for LaneMode {
    /// Writes the variant name: `"Interleaved"` or `"Sequential"`.
    ///
    /// Uses [`std::fmt::Formatter::pad`] rather than `write!` so that width, fill,
    /// alignment and precision flags (e.g. `{:>12}`) are honoured, just as they are
    /// for a plain `&str`. With no flags the output is identical to the bare name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            LaneMode::Interleaved => "Interleaved",
            LaneMode::Sequential => "Sequential",
        };
        f.pad(name)
    }
}
impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Time: M, Packet: M, B: Backend>
ContractTimeTensor<'l, T, D, Chip, Cluster, Slice, Lane, Time, Packet, B>
{
/// Converts this tensor into a `ContractTensor` whose `Lane` axis is folded into the
/// requested `OutTime` / `OutPacket` mappings according to `mode`.
///
/// The requested output layout is validated against this tensor's `Lane`, `Time`,
/// `Packet` and pre-reduction time mapping by `verify_contract_lane` before the
/// result is built.
///
/// # Panics
/// Panics if `verify_contract_lane` rejects the `OutTime` / `OutPacket` mappings
/// for the given `mode`.
#[primitive(ContractTimeTensor::contract_lane)]
pub fn contract_lane<OutTime: M, OutPacket: M>(
self,
mode: LaneMode,
) -> ContractTensor<'l, T, D, Chip, Cluster, Slice, OutTime, OutPacket, B> {
// Layout check only; panics on an inconsistent OutTime/OutPacket request.
verify_contract_lane(
Lane::to_value(),
Time::to_value(),
Packet::to_value(),
OutTime::to_value(),
OutPacket::to_value(),
self.pre_reduce_time,
mode,
);
// NOTE(review): the meaning of the `false` argument to `transpose` is defined
// outside this file; confirm it yields the layout validated above.
ContractTensor::new(self.ctx, self.inner.transpose(false))
}
}
/// Checks that the mappings requested for `contract_lane` are consistent with the
/// source tensor's layout, panicking with a descriptive message otherwise.
///
/// # Arguments
/// - `lane`, `time`, `packet`: mappings of the input tensor's `Lane`, `Time` and `Packet`.
/// - `out_time`, `out_packet`: requested mappings of the output `Time` and `Packet`.
/// - `pre_reduce_time`: the `Time` mapping before reduction. It must divide exactly into
///   `time` (`config_divide_exact`) and is used to size the buffer check.
/// - `kind`: how `Lane` is folded into the output (see `LaneMode`).
///
/// # Panics
/// - `packet.size()` exceeds `TEMPORAL_ACCUMULATOR_COLS`, or `out_packet.size()` is not
///   exactly `CONTRACT_LANE_OUT_PACKET_ELEMENTS`.
/// - `Interleaved`: `out_packet` is not `lane` with its padding replaced up to
///   `CONTRACT_LANE_OUT_PACKET_ELEMENTS`, or `out_time` cannot be split into an outer part
///   followed by `packet` (padding removed).
/// - `Sequential`: `packet` (padding removed) is padded up to a multiple of
///   `CONTRACT_LANE_OUT_PACKET_ELEMENTS` and split at that size. The check fails if
///   `out_packet` is not the inner part, or if the inner part of `out_time` is not
///   `lane` paired with the remaining outer packet part.
/// - In both modes, the remaining outer part of `out_time` is not equal to `time`.
/// - `config_divide_exact(&pre_reduce_time, &time)` returns an error.
/// - The mode-specific buffer size exceeds `1024 / CONTRACT_LANE_OUT_PACKET_ELEMENTS`
///   (`Interleaved`) or `1024 / TEMPORAL_ACCUMULATOR_COLS` (`Sequential`).
pub(crate) fn verify_contract_lane(
lane: Mapping,
time: Mapping,
packet: Mapping,
out_time: Mapping,
out_packet: Mapping,
pre_reduce_time: Mapping,
kind: LaneMode,
) {
// Mode-independent size limits.
assert!(
packet.size() <= TEMPORAL_ACCUMULATOR_COLS,
"contract_lane: Packet::SIZE must be at most {TEMPORAL_ACCUMULATOR_COLS}, got {}",
packet.size()
);
assert_eq!(
out_packet.size(),
CONTRACT_LANE_OUT_PACKET_ELEMENTS,
"contract_lane: OutPacket::SIZE must be {CONTRACT_LANE_OUT_PACKET_ELEMENTS}, got {}",
out_packet.size()
);
// Captured up front: the Sequential buffer check after the match needs it.
let lane_size = lane.size();
// All structural comparisons below use the packet with its padding stripped.
let packet = packet.remove_padding();
// Each mode checks its own OutPacket/OutTime shape and yields the part of OutTime
// that must equal Time, plus the size of the outer packet chunk (1 when unused).
let (outer_time, packet_outer_size) = match kind {
LaneMode::Interleaved => {
// OutPacket must be Lane itself, re-padded to the fixed output packet width.
let expected_out_packet = lane.replace_padding(CONTRACT_LANE_OUT_PACKET_ELEMENTS).normalize();
let out_packet = out_packet.normalize();
assert_eq!(
out_packet, expected_out_packet,
"contract_lane ({kind}): OutPacket mismatch. Expected: {expected_out_packet}, got: {out_packet}"
);
// Try inner split sizes, bounded by OutTime, Packet and the accumulator width, and
// take the first one where the inner part of OutTime equals Packet.
let outer_time = (1..=out_time.size().min(packet.size()).min(TEMPORAL_ACCUMULATOR_COLS))
.filter(|&split| {
out_time.size().is_multiple_of(split)
// A split of 1 is only accepted when Packet itself has size 1.
&& (split > 1 || packet.size() == 1)
// The remaining outer part must not exceed Time.
&& out_time.size() / split <= time.size()
})
.find_map(|split| {
let (outer_time, sliced_packet) = out_time.split_at(split);
if sliced_packet.normalize() != packet.clone().normalize() {
return None;
}
Some(outer_time)
})
.unwrap_or_else(|| {
panic!(
"contract_lane ({kind}): OutTime mismatch. \
Could not decompose OutTime {out_time} into [Time, Packet (truncated)] \
where Time is {time} and Packet is a truncation of {packet}"
)
});
(outer_time, 1)
}
LaneMode::Sequential => {
// Pad Packet up to a multiple of the output packet width so it splits evenly
// into [packet_outer, packet_inner].
let padded = packet
.clone()
.replace_padding(align_up(packet.size(), CONTRACT_LANE_OUT_PACKET_ELEMENTS));
let (packet_outer, packet_inner) = padded.split_at(CONTRACT_LANE_OUT_PACKET_ELEMENTS);
let packet_outer_size = packet_outer.size();
// OutPacket holds only the inner packet chunk.
assert_eq!(
packet_inner.clone().normalize(),
out_packet.normalize(),
"contract_lane ({kind}): OutPacket mismatch. Expected: {packet_inner}, got: {out_packet}"
);
// Lane and the outer packet chunk must form the inner part of OutTime.
let lane_packet = lane.pair(packet_outer);
let (outer_time, inner_time) = out_time.split_at(lane_packet.size());
assert_eq!(
inner_time.normalize(),
lane_packet.clone().normalize(),
"contract_lane ({kind}): OutTime mismatch. Expected {lane_packet}, got {inner_time}"
);
(outer_time, packet_outer_size)
}
};
// Both modes: whatever remains of OutTime after the mode-specific split must be Time.
assert_eq!(
outer_time.normalize(),
time.clone().normalize(),
"contract_lane ({kind}): OutTime mismatch. Outer portion of OutTime must equal Time: \
expected {time}, got {outer_time}"
);
// Relate the pre-reduction Time layout to the post-reduction one.
let division_terms = config_divide_exact(&pre_reduce_time, &time).unwrap_or_else(|_| {
panic!("contract_lane ({kind}): inconsistent pre/post reduce Time: {pre_reduce_time}, {time}")
});
let time_padding_per_stride = padding_per_stride(&pre_reduce_time);
// Padded end of a division term: its dividend stride times the padded extent recorded
// for that stride, falling back to the term's `resize` when none is recorded.
let padding_end = |d: &DivideTerm| {
d.dividend_stride
* time_padding_per_stride
.get(&d.dividend_stride)
.copied()
.unwrap_or(d.resize)
};
// Amount of Time the accumulation buffer must cover:
// - no division terms -> 1;
// - the first term ends before the end of pre_reduce_time -> all of Time;
// - otherwise, the divisor stride of the first adjacent pair of terms whose padded end
//   does not meet the previous term's dividend stride (1 if every pair meets).
// NOTE(review): presumably the innermost contiguous extent of Time — confirm against
// the `DivideTerm` semantics in furiosa_opt_lower.
let inner_time = if division_terms.is_empty() {
1
} else if padding_end(&division_terms[0]) < pre_reduce_time.size() {
time.size()
} else {
division_terms
.windows(2)
.find(|w| padding_end(&w[1]) != w[0].dividend_stride)
.map_or(1, |w| w[0].divisor_stride)
};
// NOTE(review): 1024 is an unnamed limit (presumably a buffer capacity in elements);
// consider promoting it to a named constant shared by both arms.
match kind {
LaneMode::Interleaved => {
let buffer = inner_time * packet.size();
assert!(
buffer <= 1024 / CONTRACT_LANE_OUT_PACKET_ELEMENTS,
"contract_lane ({}): axes inner to reduce must be <= {} in size, got {}",
kind,
1024 / CONTRACT_LANE_OUT_PACKET_ELEMENTS,
buffer
);
}
LaneMode::Sequential => {
let buffer = inner_time * lane_size * packet_outer_size;
assert!(
buffer <= 1024 / TEMPORAL_ACCUMULATOR_COLS,
"contract_lane ({}): axes inner to reduce must be <= {} in size, got {}",
kind,
1024 / TEMPORAL_ACCUMULATOR_COLS,
buffer
);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];

    /// Checks on the fixed `OutPacket` size.
    mod out_packet_size {
        use super::*;
        use furiosa_mapping::M as _;

        #[test]
        fn valid() {
            verify_contract_lane(
                <m![1]>::to_value(),
                <m![A]>::to_value(),
                <m![1]>::to_value(),
                <m![A]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![A]>::to_value(),
                LaneMode::Interleaved,
            );
        }
    }

    /// `LaneMode::Interleaved`: `OutPacket` is the (re-padded) lane and `OutTime` is `[Time, Packet]`.
    mod interleaved {
        use super::*;
        use furiosa_mapping::M as _;

        #[test]
        fn valid() {
            verify_contract_lane(
                <m![1]>::to_value(),
                <m![B]>::to_value(),
                <m![1]>::to_value(),
                <m![B]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![B]>::to_value(),
                LaneMode::Interleaved,
            );
        }

        #[test]
        fn valid_padding() {
            verify_contract_lane(
                <m![1]>::to_value(),
                <m![B # 4]>::to_value(),
                <m![1]>::to_value(),
                <m![B # 4]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![B # 4]>::to_value(),
                LaneMode::Interleaved,
            );
        }

        #[test]
        fn valid_no_reduction_with_padding() {
            verify_contract_lane(
                <m![1]>::to_value(),
                <m![A # 8, B]>::to_value(),
                <m![D]>::to_value(),
                <m![A # 8, B, D]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![A # 8, B]>::to_value(),
                LaneMode::Interleaved,
            );
        }

        #[test]
        fn valid_non_outermost() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![C, B]>::to_value(),
                <m![1]>::to_value(),
                <m![C, B]>::to_value(),
                <m![N]>::to_value(),
                <m![C, B]>::to_value(),
                LaneMode::Interleaved,
            );
        }

        #[test]
        fn valid_four_rows() {
            verify_contract_lane(
                <m![M]>::to_value(),
                <m![C, B]>::to_value(),
                <m![1]>::to_value(),
                <m![C, B]>::to_value(),
                <m![M # 8]>::to_value(),
                <m![C, B]>::to_value(),
                LaneMode::Interleaved,
            );
        }

        #[test]
        fn valid_all_time_reduced() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![1]>::to_value(),
                <m![1]>::to_value(),
                <m![1]>::to_value(),
                <m![N]>::to_value(),
                <m![1]>::to_value(),
                LaneMode::Interleaved,
            );
        }
    }

    /// `LaneMode::Sequential`: the lane moves into `OutTime` as `[Time, Lane, <outer packet chunk>]`.
    mod sequential {
        use super::*;
        use furiosa_mapping::M as _;

        #[test]
        fn valid() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![B]>::to_value(),
                <m![1]>::to_value(),
                <m![B, N]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![B]>::to_value(),
                LaneMode::Sequential,
            );
        }

        #[test]
        fn valid_padded_row() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![B]>::to_value(),
                <m![1]>::to_value(),
                <m![B, N # 8]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![B]>::to_value(),
                LaneMode::Sequential,
            );
        }

        #[test]
        fn valid_all_time_reduced() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![1]>::to_value(),
                <m![1]>::to_value(),
                <m![N]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![1]>::to_value(),
                LaneMode::Sequential,
            );
        }

        #[test]
        fn valid_no_reduction_with_padding() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![A # 8, B]>::to_value(),
                <m![1]>::to_value(),
                <m![A # 8, B, N]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![A # 8, B]>::to_value(),
                LaneMode::Sequential,
            );
        }

        #[test]
        fn valid_padded_packet() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![M]>::to_value(),
                <m![B]>::to_value(),
                <m![M, N]>::to_value(),
                <m![B # 8]>::to_value(),
                <m![M]>::to_value(),
                LaneMode::Sequential,
            );
        }

        #[test]
        fn valid_full_temporal_reduction() {
            verify_contract_lane(
                <m![N]>::to_value(),
                <m![1]>::to_value(),
                <m![D]>::to_value(),
                <m![N, D / 8]>::to_value(),
                <m![D % 8]>::to_value(),
                <m![1]>::to_value(),
                LaneMode::Sequential,
            );
        }

        /// The lane spans two axes (`O`, `M`; 2 * 4 = 8 elements, the same total as `N`).
        /// Both axes must appear, in order, as the inner part of `OutTime`.
        /// (Previously this test repeated `valid` exactly, with a single-axis lane `N`.)
        #[test]
        fn valid_multi_axis_reduction() {
            verify_contract_lane(
                <m![O, M]>::to_value(),
                <m![B]>::to_value(),
                <m![1]>::to_value(),
                <m![B, O, M]>::to_value(),
                <m![1 # 8]>::to_value(),
                <m![B]>::to_value(),
                LaneMode::Sequential,
            );
        }
    }
}