use furiosa_mapping::*;
use furiosa_opt_lower::{DivideTerm, config_divide_exact};
use furiosa_opt_macro::primitive;
use crate::context::*;
use crate::engine::contraction::{ContractPacketTensor, ContractTimeTensor, padding_per_stride};
use crate::runtime::Backend;
use crate::scalar::*;
impl<'l, const T: Tu, D: Scalar, Chip: M, Cluster: M, Slice: M, Lane: M, Time: M, Packet: M, B: Backend>
    ContractPacketTensor<'l, T, D, Chip, Cluster, Slice, Lane, Time, Packet, B>
{
    /// Contracts the `Time` dimension of this packet tensor down to `OutTime`.
    ///
    /// The (`Time`, `OutTime`) pair is first validated with `verify_contract_time`.
    /// The inner tensor is then reduced with `reduce_add`. The pre-reduction `Time`
    /// mapping is kept in the result's `pre_reduce_time` field.
    ///
    /// # Panics
    ///
    /// Panics if `verify_contract_time` rejects `OutTime` as a contraction of `Time`.
    #[primitive(ContractPacketTensor::contract_time)]
    pub fn contract_time<OutTime: M>(
        self,
    ) -> ContractTimeTensor<'l, T, D, Chip, Cluster, Slice, Lane, OutTime, Packet, B> {
        // Evaluate the source mapping before the target mapping, matching argument order.
        let (time, out_time) = (Time::to_value(), OutTime::to_value());
        verify_contract_time(time, out_time);

        let ctx = self.ctx;
        let inner = self.inner.reduce_add();
        ContractTimeTensor {
            ctx,
            inner,
            pre_reduce_time: Time::to_value(),
        }
    }
}
/// Validates that `out_time` is an acceptable result of contracting the `time` mapping.
///
/// The checks run in this order:
/// 1. `config_divide_exact(&time, &out_time)` must succeed. A failure is reported as
///    `out_time` containing axes that `time` does not have.
/// 2. The resulting division terms must have strictly decreasing `divisor_stride`. In other
///    words, the axes kept in `out_time` follow the same order as in `time`.
/// 3. The term with the smallest `divisor_stride` must have `divisor_stride == 1`. Otherwise
///    `out_time` starts with padding that `time` does not have.
/// 4. For each term in ascending `divisor_stride` order, `divisor_stride * extent` must equal
///    the next term's `divisor_stride`, or `out_time.size()` for the last term. Here `extent`
///    is the padded size recorded by `padding_per_stride(&time)` for the term's
///    `dividend_stride`, falling back to the term's `resize`.
///
/// # Panics
///
/// Panics with a message prefixed by `contract_time:` if any of the checks above fails.
pub(crate) fn verify_contract_time(time: Mapping, out_time: Mapping) {
    let division_terms = config_divide_exact(&time, &out_time).unwrap_or_else(|_| {
        // A valid contraction may drop Time axes; it may never introduce axes that Time lacks.
        panic!(
            "contract_time: OutTime mismatch. Some axes present in OutTime are not present in Time: {time}, {out_time}"
        )
    });
    assert!(
        division_terms
            .windows(2)
            .all(|w| w[0].divisor_stride > w[1].divisor_stride),
        "contract_time: OutTime axes must follow the same order as the Time axes"
    );
    let time_padding_per_stride = padding_per_stride(&time);

    // The assertion above guarantees strictly decreasing `divisor_stride`, so reversing the
    // terms yields ascending order directly; no sort is required.
    let ascending: Vec<&DivideTerm> = division_terms.iter().rev().collect();

    if let Some(innermost) = ascending.first() {
        assert_eq!(
            innermost.divisor_stride, 1,
            "contract_time: Padding mismatch. \
             OutTime {out_time} has unexpected leading padding not present in Time {time}"
        );
    }

    for (pos, term) in ascending.iter().enumerate() {
        // Padded extent of this axis in `time`; unpadded axes fall back to their `resize`.
        let extent = time_padding_per_stride
            .get(&term.dividend_stride)
            .copied()
            .unwrap_or(term.resize);
        let expected_end = term.divisor_stride * extent;
        // Each kept axis must reach exactly to the next outer axis (or the end of OutTime).
        let end = ascending
            .get(pos + 1)
            .map_or(out_time.size(), |next| next.divisor_stride);
        assert_eq!(
            expected_end, end,
            "contract_time: Padding mismatch. \
             Non-reduced axes in OutTime {out_time} do not preserve padding from Time {time}"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Axis declarations shared by the test submodules below.
    axes![A = 4, B = 2, C = 4, D = 32, K = 64, M = 4, N = 8, O = 2, P = 8];

    /// `verify_contract_time` accepts an `OutTime` that keeps a subset of the `Time` axes
    /// in their original order.
    mod contract_time_subset {
        use super::*;
        // NOTE(review): the `M` axis declared by `axes!` presumably shadows the
        // `furiosa_mapping::M` trait name; this anonymous import keeps `to_value()` callable.
        use furiosa_mapping::M as _;

        /// Keeping every axis of `Time` is accepted.
        #[test]
        fn valid_identity() {
            let time = <m![A, B]>::to_value();
            let out_time = <m![A, B]>::to_value();
            verify_contract_time(time, out_time);
        }

        /// Dropping the inner axis `B` is accepted.
        #[test]
        fn valid_reduce_inner() {
            let time = <m![A, B]>::to_value();
            let out_time = <m![A]>::to_value();
            verify_contract_time(time, out_time);
        }

        /// Dropping the outer axis `A` is accepted.
        #[test]
        fn valid_reduce_outer() {
            let time = <m![A, B]>::to_value();
            let out_time = <m![B]>::to_value();
            verify_contract_time(time, out_time);
        }
    }
}