use furiosa_mapping::*;
use furiosa_opt_macro::primitive;
use crate::context::*;
use crate::engine::vector::scalar::VeScalar;
use crate::engine::{CanApplyCollect, CanApplyToTrf, CanApplyToVrf, FLIT_BYTES, align_up, exact_div};
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::memory::{Address, TrfAddress, TrfTensor, VrfTensor};
use crate::tensor::tu::{Position, TuTensor};
/// Position marker for tensors returned by `TuTensor::collect`.
///
/// It is the `Position` parameter that [`CollectTensor`] plugs into `TuTensor`.
#[derive(Debug)]
pub struct PositionCollect;
// Marker impl: no trait items are supplied here.
impl Position for PositionCollect {}
/// A `TuTensor` whose position parameter is fixed to [`PositionCollect`].
///
/// Every other parameter is forwarded unchanged; the backend `B` defaults to
/// `CurrentBackend` when omitted.
pub type CollectTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionCollect, D, Chip, Cluster, Slice, Time, Packet, B>;
/// `collect` is available on any `TuTensor` whose position `P` implements `CanApplyCollect`.
impl<'l, const T: Tu, P: CanApplyCollect, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Re-expresses this tensor's `Time`/`Packet` mappings as `Time2`/`Packet2` and
/// returns it as a [`CollectTensor`].
///
/// `Time2` and `Packet2` are chosen by the caller; `verify_collect` checks them
/// against `Time` and `Packet` before anything is built. The result keeps
/// `self.ctx` and wraps `self.inner.transpose(false)`.
///
/// # Panics
///
/// Panics when `verify_collect` rejects the requested `Time2`/`Packet2`.
#[primitive(TuTensor::collect)]
pub fn collect<Time2: M, Packet2: M>(self) -> CollectTensor<'l, T, D, Chip, Cluster, Slice, Time2, Packet2, B> {
// Validate the requested output mappings first; this panics on mismatch.
verify_collect::<D, Time, Packet, Time2, Packet2>();
// NOTE(review): what the `false` flag of `transpose` selects is not visible in this file — confirm.
CollectTensor::new(self.ctx, self.inner.transpose(false))
}
}
/// `to_trf` is available on any `TuTensor` whose position `P` implements `CanApplyToTrf`.
impl<'l, const T: Tu, P: CanApplyToTrf, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Consumes this tensor and returns a [`TrfTensor`] at `address`, described by the
/// caller-chosen `Lane` and `Element` mappings.
///
/// `verify_to_trf` checks `Lane`/`Element` against `Time`/`Packet` and against the
/// capacity reported by `address` before the `TrfTensor` is built from
/// `self.inner.transpose(false)`. The returned tensor does not carry `self.ctx`.
///
/// # Panics
///
/// Panics when `verify_to_trf` rejects the requested `Lane`/`Element` or `address`.
#[primitive(TuTensor::to_trf)]
pub fn to_trf<Lane: M, Element: M>(
self,
address: TrfAddress,
) -> TrfTensor<D, Chip, Cluster, Slice, Lane, Element, B> {
// Validate lane count, capacity and mapping equality first; this panics on mismatch.
verify_to_trf::<D, Lane, Time, Packet, Element>(&address);
// NOTE(review): what the `false` flag of `transpose` selects is not visible in this file — confirm.
TrfTensor::new(self.inner.transpose(false), address)
}
}
/// `to_vrf` is available on any `TuTensor` whose position `P` implements `CanApplyToVrf`
/// and whose element type `D` implements `VeScalar`.
impl<'l, const T: Tu, P: CanApplyToVrf, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Consumes this tensor and returns a [`VrfTensor`] at `address`, described by the
/// caller-chosen `Element` mapping.
///
/// The `VrfTensor` is built from `self.inner.transpose(false)`; `self.ctx` is not
/// carried over.
#[primitive(TuTensor::to_vrf)]
pub fn to_vrf<Element: M>(self, address: Address) -> VrfTensor<D, Chip, Cluster, Slice, Element, B> {
// NOTE(review): unlike `collect` and `to_trf`, no `verify_*` check runs here, so `Element`
// is not compared with `Time`/`Packet` in this function — confirm that this is validated
// elsewhere (e.g. inside `VrfTensor::new` or `transpose`).
VrfTensor::new(self.inner.transpose(false), address)
}
}
/// Checks that `(Time2, Packet2)` is the layout a collect of `(Time, Packet)` must produce.
///
/// Three conditions are asserted, in this order:
///
/// 1. `Packet2` occupies exactly `FLIT_BYTES` bytes for element type `D`.
/// 2. `Packet2` equals the inner part of the padded input packet after splitting it at
///    one flit's worth of elements.
/// 3. `Time2` equals `Time` paired with the outer part of that split.
///
/// Mappings are compared after `normalize()`.
///
/// # Panics
///
/// Panics with a descriptive message as soon as one of the conditions above fails.
pub(crate) fn verify_collect<D: Scalar, Time: M, Packet: M, Time2: M, Packet2: M>() {
    // Byte size of the input packet after `align_up` to `FLIT_BYTES`
    // (presumably rounding up to whole flits — confirm against `align_up`).
    let padded_packet_bytes = align_up(D::size_in_bytes_from_length(Packet::SIZE), FLIT_BYTES);
    // Number of `D` elements that make up `FLIT_BYTES` bytes.
    let elements_per_flit = D::length_from_bytes(FLIT_BYTES);

    // Condition 1: the output packet is exactly one flit.
    let out_packet_bytes = D::size_in_bytes_from_length(Packet2::SIZE);
    assert_eq!(
        out_packet_bytes, FLIT_BYTES,
        "Collect output packet must be exactly {FLIT_BYTES} bytes (one flit)."
    );

    // Re-pad the input packet to the aligned element count, then cut it at one flit.
    let padded_packet = Packet::to_value().replace_padding(D::length_from_bytes(padded_packet_bytes));
    let (beyond_flit, within_flit) = padded_packet.split_at(elements_per_flit);

    // Condition 2: the flit-sized part is the output packet.
    let expected_packet = within_flit.normalize();
    let out_packet = Packet2::to_value().normalize();
    assert_eq!(
        expected_packet, out_packet,
        "Collect packet mismatch. Expected: {expected_packet}, got: {out_packet}"
    );

    // Condition 3: whatever did not fit in one flit is appended to the time mapping.
    let expected_time = Time::to_value().pair(beyond_flit).normalize();
    let out_time = Time2::to_value().normalize();
    assert_eq!(
        expected_time, out_time,
        "Collect time mismatch. Expected: {expected_time}, got: {out_time}"
    );
}
/// Checks that `(Lane, Element)` is a valid TRF layout for a tensor mapped as `(Time, Packet)`.
///
/// The following are asserted, in this order:
///
/// 1. `Lane::SIZE` is 1, 2, 4 or 8.
/// 2. `Lane::SIZE * Element::SIZE` elements of `D` fit in `address.capacity()` bytes.
/// 3. `Lane::SIZE` divides `Time::SIZE` (according to `exact_div`).
/// 4. Splitting `Time` at `Time::SIZE / Lane::SIZE` yields an outer part equal to `Lane`.
/// 5. The inner part of that split, paired with `Packet`, equals `Element`.
///
/// Mappings are compared after `normalize()`.
///
/// # Panics
///
/// Panics with a descriptive message as soon as one of the conditions above fails.
pub(crate) fn verify_to_trf<D: Scalar, Lane: M, Time: M, Packet: M, Element: M>(address: &TrfAddress) {
    // Condition 1: only these lane counts are accepted.
    assert!(
        matches!(Lane::SIZE, 1 | 2 | 4 | 8),
        "Lane::SIZE must be 1, 2, 4, or 8, got {}",
        Lane::SIZE
    );

    // Condition 2: the whole payload must fit at the target address.
    let register_bytes = address.capacity();
    let payload_bytes = D::size_in_bytes_from_length(Lane::SIZE * Element::SIZE);
    assert!(
        payload_bytes <= register_bytes,
        "TRF data ({} bytes = {} lanes x {} bytes) exceeds register file capacity ({} bytes for {})",
        payload_bytes,
        Lane::SIZE,
        // Per-lane byte size is only computed when the message is actually rendered.
        D::size_in_bytes_from_length(Element::SIZE),
        register_bytes,
        address,
    );

    // Condition 3: the time mapping must divide evenly across the lanes.
    let time = Time::to_value();
    let Some(per_lane_len) = exact_div(Time::SIZE, Lane::SIZE) else {
        panic!(
            "Lane::SIZE ({}) does not divide Time::SIZE ({})",
            Lane::SIZE,
            Time::SIZE
        )
    };
    let (outer_part, inner_part) = time.split_at(per_lane_len);

    // Condition 4: the outer part of `Time` is the lane mapping.
    let time_outer = outer_part.normalize();
    let lane = Lane::to_value().normalize();
    assert_eq!(
        time_outer, lane,
        "`to_trf` lane mismatch: time_outer != Lane: {time_outer} != {lane}",
    );

    // Condition 5: what is left of `Time`, followed by `Packet`, is the per-lane element mapping.
    let expected_element = inner_part.pair(Packet::to_value()).normalize();
    let element = Element::to_value().normalize();
    assert_eq!(
        expected_element, element,
        "`to_trf` element mismatch: [time_inner, Packet] != Element: {expected_element} != {element}",
    );
}