use furiosa_mapping::*;
use furiosa_opt_lower::config_divide_exact;
use crate::engine::FLIT_BYTES;
use crate::runtime::Backend;
use crate::scalar::Scalar;
use crate::tensor::Tensor;
/// Validates the stream-adapter type parameters and then re-views `inner`
/// with the added `Lane` axis and the `OutTime` / `OutPacket` layout.
///
/// The check is delegated to [`verify_stream_adapter`], which runs before the
/// tensor is touched; the result is whatever `inner.transpose(true)` produces
/// for the declared return mapping.
///
/// # Panics
///
/// Panics whenever [`verify_stream_adapter`] rejects the combination of
/// `D`, `Lane`, `Time`, `Packet`, `OutTime` and `OutPacket`.
pub(super) fn contract_outer<
    D: Scalar,
    Chip: M,
    Cluster: M,
    Slice: M,
    Lane: M,
    Time: M,
    Packet: M,
    OutTime: M,
    OutPacket: M,
    B: Backend,
>(
    inner: Tensor<D, m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }], B>,
) -> Tensor<D, m![{ Chip }, { Cluster }, { Slice }, { Lane }, { OutTime }, { OutPacket }], B> {
    // Reject incompatible mappings up front, before producing the new view.
    verify_stream_adapter::<D, Lane, Time, Packet, OutTime, OutPacket>();

    // NOTE(review): the meaning of the `true` flag is defined by
    // `Tensor::transpose`, which is not visible here — confirm.
    let contracted = inner.transpose(true);
    contracted
}
/// Asserts that the type-level mappings form a valid configuration for
/// `contract_outer`. Returns nothing; every violation is reported by panicking.
///
/// # Panics
///
/// Panics, in this order of checking, when:
///
/// - `Lane::SIZE` is not one of 1, 2, 4 or 8;
/// - `D::size_in_bytes_from_length(OutPacket::SIZE)` is neither 32 nor 64;
/// - the second part of `OutPacket` split at `D::length_from_bytes(FLIT_BYTES)`
///   differs, after normalization, from the normalized `Packet`;
/// - `config_divide_exact` does not succeed in dividing the expected time
///   mapping (`OutTime` paired with the padding-stripped first part of the
///   split `OutPacket`) by `Time`;
/// - the `divisor_stride` values of the returned terms are not strictly
///   decreasing;
/// - any returned term has a `dividend_stride` smaller than the ratio of the
///   expected time size to the `Time` size.
pub(super) fn verify_stream_adapter<D: Scalar, Lane: M, Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    assert!(
        matches!(Lane::SIZE, 1 | 2 | 4 | 8),
        "Lane::SIZE must be 1, 2, 4, or 8, got {}",
        Lane::SIZE
    );

    let out_packet_size = D::size_in_bytes_from_length(OutPacket::SIZE);
    assert!(
        matches!(out_packet_size, 32 | 64),
        "OutPacket must be 32 or 64 bytes (matching PackSize ∈ {{1, 2}}), got {out_packet_size} bytes"
    );

    // Split the output packet at the number of `D` elements corresponding to
    // `FLIT_BYTES`; the second half must agree with the input `Packet`.
    let elems_per_flit = D::length_from_bytes(FLIT_BYTES);
    let (packet_outer, packet_inner) = OutPacket::to_value().split_at(elems_per_flit);
    let out_packet_inner = packet_inner.normalize();
    let expected_packet = Packet::to_value().normalize();
    assert_eq!(
        out_packet_inner, expected_packet,
        "`contract_outer` packet mismatch: inner flit of OutPacket != Packet: {out_packet_inner} != {expected_packet}",
    );

    // The first half of the split packet (padding removed) is folded into the
    // output time mapping before comparing against the input time mapping.
    let full_out_time = OutTime::to_value()
        .pair(packet_outer.remove_padding())
        .normalize();
    let in_time = Time::to_value();

    // Computed before the exact division on purpose: keeps the original
    // evaluation (and therefore panic) order.
    let tile_factor = full_out_time.size() / in_time.size();
    let terms =
        config_divide_exact(&full_out_time, &in_time).expect("`contract_outer`: Time does not divide OutTime");

    // Adjacent terms must have strictly decreasing divisor strides.
    let strides_descend = terms
        .iter()
        .zip(terms.iter().skip(1))
        .all(|(prev, next)| prev.divisor_stride > next.divisor_stride);
    assert!(
        strides_descend,
        "`contract_outer`: Time axes are reordered in OutTime"
    );

    // Every term's dividend stride must be at least the tiling factor.
    let tiling_is_innermost = terms.iter().all(|term| term.dividend_stride >= tile_factor);
    assert!(
        tiling_is_innermost,
        "`contract_outer`: tiling axes must be innermost in OutTime"
    );
}