use furiosa_mapping::*;
use furiosa_opt_macro::primitive;
use crate::cast::Cast;
use crate::context::*;
use crate::engine::{CanApplyCommitCast, CanApplyCommitTrim, CanApplyCommitValidCountPack};
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::tu::{Position, TuTensor};
pub(super) use furiosa_opt_lower::COMMIT_VALID_PACKET_SIZES;
/// Activation selector accepted by `TuTensor::commit_cast`.
///
/// `Default::default()` yields [`Activation::None`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Activation {
/// No activation; this is the default variant.
#[default]
None,
/// ReLU activation.
///
/// NOTE(review): `commit_cast` currently ignores its activation argument, so
/// selecting this variant has no visible effect in this file yet.
Relu,
}
/// Zero-sized [`Position`] marker used as the position parameter of the tensor
/// returned by `TuTensor::commit_trim`.
#[derive(Debug)]
pub struct PositionCommitTrim;
impl Position for PositionCommitTrim {}
/// A [`TuTensor`] whose position parameter is fixed to [`PositionCommitTrim`].
///
/// All other parameters are forwarded unchanged; the backend `B` defaults to
/// [`CurrentBackend`].
pub type CommitTrimTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionCommitTrim, D, Chip, Cluster, Slice, Time, Packet, B>;
/// Zero-sized [`Position`] marker used as the position parameter of the tensor
/// returned by `TuTensor::commit_cast`.
#[derive(Debug)]
pub struct PositionCommitCast;
impl Position for PositionCommitCast {}
/// A [`TuTensor`] whose position parameter is fixed to [`PositionCommitCast`].
///
/// All other parameters are forwarded unchanged; the backend `B` defaults to
/// [`CurrentBackend`].
pub type CommitCastTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionCommitCast, D, Chip, Cluster, Slice, Time, Packet, B>;
/// Zero-sized [`Position`] marker used as the position parameter of the tensor
/// returned by `TuTensor::commit_valid_count_pack`.
#[derive(Debug)]
pub struct PositionCommitValidCountPack;
impl Position for PositionCommitValidCountPack {}
/// A [`TuTensor`] whose position parameter is fixed to
/// [`PositionCommitValidCountPack`].
///
/// All other parameters are forwarded unchanged; the backend `B` defaults to
/// [`CurrentBackend`].
pub type CommitValidCountPackTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionCommitValidCountPack, D, Chip, Cluster, Slice, Time, Packet, B>;
// `commit_trim` is available on any tensor whose position `P` implements
// `CanApplyCommitTrim`.
impl<'l, const T: Tu, P: CanApplyCommitTrim, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Re-types the packet mapping from `Packet` to `OutPacket`, producing a
/// [`CommitTrimTensor`].
///
/// The element type `D` and the `Chip`/`Cluster`/`Slice`/`Time` parameters are
/// carried over unchanged; only the packet parameter of the result differs.
/// The context is reused and the data is forwarded as
/// `self.inner.transpose(false)`.
///
/// NOTE(review): the meaning of the `false` argument is defined by the inner
/// tensor type, which is not visible in this file.
///
/// # Panics
///
/// Panics if `verify_commit_trim::<D, Packet, OutPacket>()` rejects the
/// combination: the output packet byte size is not listed in
/// `COMMIT_VALID_PACKET_SIZES`, or `OutPacket` is not a resize of `Packet`.
#[primitive(TuTensor::commit_trim)]
pub fn commit_trim<OutPacket: M>(self) -> CommitTrimTensor<'l, T, D, Chip, Cluster, Slice, Time, OutPacket, B> {
verify_commit_trim::<D, Packet, OutPacket>();
CommitTrimTensor::new(self.ctx, self.inner.transpose(false))
}
}
// `commit_cast` is available on any tensor whose position `P` implements
// `CanApplyCommitCast`.
impl<'l, const T: Tu, P: CanApplyCommitCast, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Converts every element from `D` to `OutD` via [`Cast`], producing a
/// [`CommitCastTensor`] with the same mapping parameters.
///
/// `_activation` is accepted but not used by the current body.
///
/// # Panics
///
/// Currently always panics: `verify_commit_cast` is a `todo!()` stub, so the
/// cast below it is never reached.
#[primitive(TuTensor::commit_cast)]
pub fn commit_cast<OutD: Scalar>(
self,
_activation: Activation,
) -> CommitCastTensor<'l, T, OutD, Chip, Cluster, Slice, Time, Packet, B>
where
D: Cast<OutD>,
{
verify_commit_cast::<D, OutD>();
// Two nested `map`s: the outer one walks the tensor, the inner one maps each
// stored element wrapper before casting the value itself.
// NOTE(review): the wrapper type is not visible here — confirm against `inner`.
CommitCastTensor::new(self.ctx, self.inner.map(|v| v.map(|v| v.cast())))
}
}
// `commit_valid_count_pack` is available on any tensor whose position `P`
// implements `CanApplyCommitValidCountPack`.
impl<
'l,
const T: Tu,
P: CanApplyCommitValidCountPack,
D: Scalar,
Chip: M,
Cluster: M,
Slice: M,
Time: M,
Packet: M,
B: Backend,
> TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Produces a [`CommitValidCountPackTensor`] with the same element type and
/// mapping parameters as `self`.
///
/// `_valid_count` is accepted but not used by the current body. The data is
/// forwarded as `self.inner.transpose(false)`.
///
/// # Panics
///
/// Currently always panics: `verify_commit_valid_count_pack` is a `todo!()`
/// stub, so the construction below it is never reached.
#[primitive(TuTensor::commit_valid_count_pack)]
pub fn commit_valid_count_pack(
self,
_valid_count: usize,
) -> CommitValidCountPackTensor<'l, T, D, Chip, Cluster, Slice, Time, Packet, B> {
verify_commit_valid_count_pack::<D, Time, Packet>();
CommitValidCountPackTensor::new(self.ctx, self.inner.transpose(false))
}
}
/// Checks that trimming a packet mapped as `Packet` down to `OutPacket` is
/// allowed for element type `D`.
///
/// Two conditions are asserted, in this order:
/// 1. the output packet's byte size, `D::size_in_bytes_from_length(OutPacket::SIZE)`,
///    is one of `COMMIT_VALID_PACKET_SIZES`;
/// 2. the value of `OutPacket` reports `is_resize_of` the value of `Packet`.
///
/// # Panics
///
/// Panics with a descriptive message as soon as one of the conditions fails.
pub(crate) fn verify_commit_trim<D: Scalar, Packet: M, OutPacket: M>() {
    // Size check comes first so an unsupported width is reported before any
    // mapping comparison is attempted.
    let out_packet_bytes = D::size_in_bytes_from_length(OutPacket::SIZE);
    let size_is_supported = COMMIT_VALID_PACKET_SIZES.contains(&out_packet_bytes);
    assert!(
        size_is_supported,
        "commit_trim output packet must be one of {COMMIT_VALID_PACKET_SIZES:?} bytes, got {out_packet_bytes}",
    );

    // The output mapping has to be the input mapping itself or a resize of it.
    let (out_packet, expected_packet) = (OutPacket::to_value(), Packet::to_value());
    let is_resize_of_input = out_packet.is_resize_of(&expected_packet);
    assert!(
        is_resize_of_input,
        "commit_trim packet mismatch. Expected {expected_packet} or a trimming of it, got {out_packet}",
    );
}
/// Placeholder validation for `commit_cast` from `D` to `OutD`.
///
/// # Panics
///
/// Always panics via `todo!()`; no validation is implemented yet.
// The type parameters are unused while the body is a stub, hence the lint allowance.
#[allow(clippy::extra_unused_type_parameters)]
fn verify_commit_cast<D: Scalar, OutD: Scalar>() {
todo!("commit_cast is not yet implemented")
}
/// Placeholder validation for `commit_valid_count_pack` over element type `D`
/// and the `Time`/`Packet` mappings.
///
/// # Panics
///
/// Always panics via `todo!()`; no validation is implemented yet.
// The type parameters are unused while the body is a stub, hence the lint allowance.
#[allow(clippy::extra_unused_type_parameters)]
fn verify_commit_valid_count_pack<D: Scalar, Time: M, Packet: M>() {
todo!("commit_valid_count_pack is not yet implemented")
}
// Unit tests for `verify_commit_trim`. Each test passes simply by not panicking.
#[cfg(test)]
mod tests {
use furiosa_mapping::*;
use super::verify_commit_trim;
use crate::scalar::bf16;
// Combinations that `verify_commit_trim` must accept.
mod valid {
use super::*;
// Declares axis `N` with size 8 for the mappings below.
axes![N = 8];
// NOTE(review): `m![N # 32]` presumably denotes `N` padded to 32 elements —
// confirm against the `furiosa_mapping` macro documentation.
// Trims `N # 32` all the way down to plain `N`.
#[test]
fn full_trim() {
verify_commit_trim::<i8, m![N # 32], m![N]>();
}
// Trims `N # 32` down to `N # 16`.
#[test]
fn partial_trim() {
verify_commit_trim::<i8, m![N # 32], m![N # 16]>();
}
// Input and output packet mappings are identical.
#[test]
fn no_trim() {
verify_commit_trim::<i8, m![N # 32], m![N # 32]>();
}
// Same shape of check with the project `bf16` element type.
#[test]
fn bf16() {
verify_commit_trim::<bf16, m![N # 16], m![N]>();
}
// Same shape of check with `f32` elements.
#[test]
fn f32() {
verify_commit_trim::<f32, m![N # 8], m![N]>();
}
// Output `N # 8` for an `N` of size 8.
// NOTE(review): the test name suggests a single-step case; the link to time
// steps is not visible in this file.
#[test]
fn single_time_step() {
verify_commit_trim::<i8, m![N # 32], m![N # 8]>();
}
// Output uses `m![N = 4]` rather than the `#` form used by the other cases.
#[test]
fn non_padding_resize() {
verify_commit_trim::<bf16, m![N # 16], m![N = 4]>();
}
}
// NOTE(review): no rejecting cases are defined in this view; the axes below are
// declared but never used by a test.
mod invalid {
use super::*;
axes![N = 8, X = 8];
}
}