use furiosa_mapping::*;
use furiosa_opt_macro::primitive;
use crate::cast::Cast;
use crate::context::*;
use crate::engine::vector::scalar::VeScalar;
use crate::engine::{CanApplyCast, FLIT_BYTES};
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::tu::{Position, TuTensor};
/// Position marker carried by tensors returned from [`TuTensor::cast`].
///
/// This is a field-less unit type: it holds no data and only tags the tensor
/// type (it fills the position slot of the `CastTensor` alias).
#[derive(Debug)]
pub struct PositionCast;
// Marker impl with no items. It presumably exists so `PositionCast` satisfies a
// `Position` bound on `TuTensor`'s position parameter. TODO(review): confirm
// against the `TuTensor` definition.
impl Position for PositionCast {}
/// A [`TuTensor`] whose position parameter is [`PositionCast`]; this is the
/// return type of [`TuTensor::cast`].
///
/// All other parameters are forwarded unchanged to [`TuTensor`]. The backend
/// `B` defaults to [`CurrentBackend`] when omitted.
pub type CastTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionCast, D, Chip, Cluster, Slice, Time, Packet, B>;
impl<'l, const T: Tu, P: CanApplyCast, D: VeScalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
    TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
    /// Converts every element from `D` to `OutD` and re-labels the packet
    /// mapping from `Packet` to `OutPacket`.
    ///
    /// The `Chip`, `Cluster`, `Slice` and `Time` mappings, the backend and the
    /// context are carried over unchanged. The result is a [`CastTensor`]
    /// (position [`PositionCast`]).
    ///
    /// # Panics
    ///
    /// Panics if the packet layouts are rejected by `verify_cast`. This happens
    /// when either packet is not exactly one flit, or when `OutPacket` does not
    /// match the layout derived from `Packet`.
    #[primitive(TuTensor::cast)]
    pub fn cast<OutD: Scalar, OutPacket: M>(self) -> CastTensor<'l, T, OutD, Chip, Cluster, Slice, Time, OutPacket, B>
    where
        D: Cast<OutD>,
    {
        // Validate the packet layouts up front, before any element is converted.
        verify_cast::<D, OutD, Packet, OutPacket>();

        // Take the context first, then the data, matching the original
        // left-to-right argument evaluation order.
        let ctx = self.ctx;
        let converted = self
            .inner
            .map(|entry| entry.map(|value| value.cast()))
            // NOTE(review): semantics of the `false` flag are defined on the
            // inner container type; it is kept exactly as in the original.
            .transpose(false);

        CastTensor::new(ctx, converted)
    }
}
/// Checks that casting `D` → `OutD` maps exactly one input flit onto exactly
/// one output flit with a consistent packet layout.
///
/// Three conditions are asserted, in this order:
/// 1. `InPacket::SIZE` elements of `D` occupy exactly [`FLIT_BYTES`] bytes.
/// 2. `OutPacket::SIZE` elements of `OutD` occupy exactly [`FLIT_BYTES`] bytes.
/// 3. `OutPacket`'s normalized mapping equals `InPacket`'s mapping after
///    `replace_padding(n)` and normalization. Here `n` is the number of `OutD`
///    elements that fit in [`FLIT_BYTES`] bytes.
///
/// # Panics
///
/// Panics with a descriptive message if any of the conditions above fails.
fn verify_cast<D: Scalar, OutD: Scalar, InPacket: M, OutPacket: M>() {
    // Compute once; the value is both asserted on and reported in the message.
    let in_bytes = D::size_in_bytes_from_length(InPacket::SIZE);
    assert_eq!(
        in_bytes,
        FLIT_BYTES,
        "Cast input packet must be exactly {FLIT_BYTES} bytes (one flit): \
         {} elements = {} bytes",
        InPacket::SIZE,
        in_bytes,
    );

    // Number of `OutD` elements that fill one flit; this becomes the padding
    // argument when deriving the expected output packet from the input packet.
    let out_flit_elements = OutD::length_from_bytes(FLIT_BYTES);
    let expected_packet = InPacket::to_value().replace_padding(out_flit_elements).normalize();
    let out_packet = OutPacket::to_value().normalize();

    // Report the actual element/byte count too. Without it, a failure here gave
    // no hint of how far off the output packet size was. The original message
    // text is preserved as a prefix.
    let out_bytes = OutD::size_in_bytes_from_length(OutPacket::SIZE);
    assert_eq!(
        out_bytes,
        FLIT_BYTES,
        "Cast output packet must be exactly {FLIT_BYTES} bytes (one flit). \
         Expected: {expected_packet}, got: {out_packet} ({} elements = {} bytes)",
        OutPacket::SIZE,
        out_bytes,
    );

    assert_eq!(
        expected_packet, out_packet,
        "Cast packet mismatch. Expected: {expected_packet}, got: {out_packet}",
    );
}