use std::marker::PhantomData;
use furiosa_mapping::{M, Pair, m};
use furiosa_opt_macro::primitive;
use crate::{
array_vec::ArrayVec,
engine::vector::{MAX_TAGS, scalar::VeScalar},
prelude::{GroupId, TagFilter, VrfTensor},
tensor::Tensor,
};
/// Right-hand-side operand of a vector-engine operation.
///
/// `D` is the scalar element type and `TargetMapping` is the mapping in which
/// tensor data is expressed.
#[primitive(op::VeRhs)]
#[derive(Debug, Clone)]
pub enum VeRhs<D: VeScalar, TargetMapping: M> {
    /// A single constant scalar.
    Const {
        v: D,
    },
    /// Tensor data in `TargetMapping`; `VeRhs::vrf` builds this by transposing a `VrfTensor`.
    Vrf {
        data: Tensor<D, TargetMapping>,
    },
    /// Selects the stash as the operand source; carries no payload.
    Stash,
}
impl<D: VeScalar, TargetMapping: M> VeRhs<D, TargetMapping> {
    /// Builds a constant operand holding `v`.
    #[primitive(op::VeRhs::constant)]
    pub fn constant(v: D) -> Self {
        VeRhs::Const { v }
    }
    /// Builds a tensor operand by transposing the VRF tensor's inner data into
    /// `TargetMapping`. The source tensor is only borrowed.
    #[primitive(op::VeRhs::vrf)]
    pub fn vrf<Chip: M, Cluster: M, Slice: M, Element: M>(vrf: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
        // NOTE(review): the meaning of the `true` flag is defined by `Tensor::transpose`,
        // which is not visible here — confirm and document it.
        let transposed = vrf.inner.transpose::<TargetMapping>(true);
        VeRhs::Vrf { data: transposed }
    }
}
/// Lets a bare `i32` be used wherever an `i32` right-hand operand is expected.
impl<TargetMapping: M> From<i32> for VeRhs<i32, TargetMapping> {
    fn from(value: i32) -> Self {
        VeRhs::Const { v: value }
    }
}

/// Lets a bare `f32` be used wherever an `f32` right-hand operand is expected.
impl<TargetMapping: M> From<f32> for VeRhs<f32, TargetMapping> {
    fn from(value: f32) -> Self {
        VeRhs::Const { v: value }
    }
}
/// The `Stash` marker converts into a stash-sourced operand.
impl<D: VeScalar, TargetMapping: M> From<Stash> for VeRhs<D, TargetMapping> {
    fn from(_marker: Stash) -> Self {
        Self::Stash
    }
}

/// A borrowed `VrfTensor` converts into a tensor operand through [`VeRhs::vrf`].
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, Element: M, TargetMapping: M>
    From<&VrfTensor<D, Chip, Cluster, Slice, Element>> for VeRhs<D, TargetMapping>
{
    fn from(tensor: &VrfTensor<D, Chip, Cluster, Slice, Element>) -> Self {
        VeRhs::vrf(tensor)
    }
}
/// A stash read together with the tag filter that selects where it applies.
///
/// `D` is tracked only at the type level through `PhantomData`; no `D` value is stored.
#[derive(Debug, Clone)]
pub struct StashOperand<D: VeScalar> {
    /// Tag filter for this stash read (e.g. `TagFilter::All` or `TagFilter::Group { .. }`).
    pub(crate) tag_filter: TagFilter,
    _phantom: PhantomData<D>,
}
impl<D: VeScalar> StashOperand<D> {
pub(crate) fn always() -> Self {
Self {
tag_filter: TagFilter::All,
_phantom: PhantomData,
}
}
#[expect(dead_code)]
pub(crate) fn group(id: GroupId) -> Self {
Self {
tag_filter: TagFilter::Group { id },
_phantom: PhantomData,
}
}
}
/// An operand paired with the tag filter that selects where it applies.
///
/// `Operand1` is an extra per-operand value whose type depends on the operation:
/// `()` for [`BinaryOperandTag`] and `f32` for [`TernaryOperandTag`].
#[derive(Debug, Clone)]
pub struct OperandTagValue<D: VeScalar, TargetMapping: M, Operand1: Copy> {
    /// Primary right-hand operand.
    pub operand0: VeRhs<D, TargetMapping>,
    /// Secondary value (`()` when the operation has none).
    pub operand1: Operand1,
    /// Which tags this operand applies to.
    pub tag_filter: TagFilter,
}
/// Operand of a binary operation: no secondary value (`Operand1 = ()`).
#[primitive(op::BinaryOperandTag)]
pub type BinaryOperandTag<D, TargetMapping> = OperandTagValue<D, TargetMapping, ()>;
/// Operand of a ternary operation: an `f32` primary operand plus an `f32` secondary value.
pub type TernaryOperandTag<Mapping> = OperandTagValue<f32, Mapping, f32>;
impl<D: VeScalar, TargetMapping: M> OperandTagValue<D, TargetMapping, ()> {
    /// Binary operand applying under every tag (`TagFilter::All`).
    #[primitive(op::BinaryOperandTag::always)]
    pub fn always(operand0: VeRhs<D, TargetMapping>) -> Self {
        Self {
            operand0,
            operand1: (),
            tag_filter: TagFilter::All,
        }
    }

    /// Binary operand restricted to the tag group `id`.
    pub fn group(operand0: VeRhs<D, TargetMapping>, id: GroupId) -> Self {
        let tag_filter = TagFilter::Group { id };
        Self { operand0, operand1: (), tag_filter }
    }

    /// Stash-sourced binary operand applying under every tag.
    pub fn stash_always() -> Self {
        Self {
            tag_filter: TagFilter::All,
            operand0: VeRhs::Stash,
            operand1: (),
        }
    }

    /// Stash-sourced binary operand restricted to the tag group `id`.
    pub fn stash_group(id: GroupId) -> Self {
        Self::group(VeRhs::Stash, id)
    }

    /// Returns `true` when the primary operand is read from the stash.
    pub fn is_stash(&self) -> bool {
        matches!(&self.operand0, VeRhs::Stash)
    }
}
impl<Mapping: M> OperandTagValue<f32, Mapping, f32> {
    /// Ternary operand `(operand0, operand1)` applying under every tag.
    pub fn always(operand0: VeRhs<f32, Mapping>, operand1: f32) -> Self {
        Self { operand0, operand1, tag_filter: TagFilter::All }
    }

    /// Ternary operand `(operand0, operand1)` restricted to the tag group `id`.
    pub fn group(operand0: VeRhs<f32, Mapping>, operand1: f32, id: GroupId) -> Self {
        Self {
            tag_filter: TagFilter::Group { id },
            ..Self::always(operand0, operand1)
        }
    }
}
/// Read-only access to the parts of a tagged operand.
pub trait OperandTag<D: VeScalar, Mapping: M> {
    /// Type of the secondary value (e.g. `()` for binary operands).
    type Operand1: Copy;
    /// Primary right-hand operand.
    fn operand0(&self) -> &VeRhs<D, Mapping>;
    /// Secondary value, returned by copy.
    fn operand1(&self) -> Self::Operand1;
    /// Tag filter selecting where the operand applies.
    fn tag_filter(&self) -> &TagFilter;
}
/// Field accessors for [`OperandTagValue`].
impl<D: VeScalar, Mapping: M, Operand1: Copy> OperandTag<D, Mapping>
    for OperandTagValue<D, Mapping, Operand1>
{
    type Operand1 = Operand1;

    fn operand0(&self) -> &VeRhs<D, Mapping> {
        let Self { operand0, .. } = self;
        operand0
    }

    fn operand1(&self) -> Operand1 {
        let Self { operand1, .. } = self;
        *operand1
    }

    fn tag_filter(&self) -> &TagFilter {
        let Self { tag_filter, .. } = self;
        tag_filter
    }
}
/// `(rhs, operand1)` becomes a ternary operand applying under every tag.
impl<R, Mapping: M> From<(R, f32)> for TernaryOperandTag<Mapping>
where
    R: Into<VeRhs<f32, Mapping>>,
{
    fn from(pair: (R, f32)) -> Self {
        let (rhs, operand1) = pair;
        Self::always(rhs.into(), operand1)
    }
}

/// `((rhs, operand1), filter)` becomes a ternary operand restricted by `filter`.
impl<R, B, Mapping: M> From<((R, f32), B)> for TernaryOperandTag<Mapping>
where
    R: Into<VeRhs<f32, Mapping>>,
    B: Into<TagFilter>,
{
    fn from(nested: ((R, f32), B)) -> Self {
        let ((rhs, operand1), filter) = nested;
        let operand0: VeRhs<f32, Mapping> = rhs.into();
        let tag_filter: TagFilter = filter.into();
        Self { operand0, operand1, tag_filter }
    }
}
/// Conversion into a list of ternary operands stored in an `ArrayVec<_, MAX_TAGS>`.
pub trait IntoTernaryOperandTags<TargetMapping: M> {
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperandTag<TargetMapping>, MAX_TAGS>;
}
/// Any single value convertible to a [`TernaryOperandTag`] becomes a one-element list.
impl<T, TargetMapping: M> IntoTernaryOperandTags<TargetMapping> for T
where
    T: Into<TernaryOperandTag<TargetMapping>>,
{
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperandTag<TargetMapping>, MAX_TAGS> {
        let single: TernaryOperandTag<TargetMapping> = self.into();
        ArrayVec::new([single])
    }
}
/// A fixed-size array of ternary operands.
///
/// # Panics
///
/// Panics if more than one element uses `TagFilter::All`.
impl<TargetMapping: M, const N: usize> IntoTernaryOperandTags<TargetMapping>
    for [TernaryOperandTag<TargetMapping>; N]
{
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperandTag<TargetMapping>, MAX_TAGS> {
        let mut always_count = 0usize;
        for operand in &self {
            if let TagFilter::All = operand.tag_filter {
                always_count += 1;
            }
        }
        assert!(
            always_count <= 1,
            "Multiple All operands are not allowed (found {always_count})"
        );
        ArrayVec::new(self)
    }
}

/// An already-built operand list is passed through unchanged; no validation happens here.
impl<TargetMapping: M> IntoTernaryOperandTags<TargetMapping>
    for ArrayVec<TernaryOperandTag<TargetMapping>, MAX_TAGS>
{
    fn into_ternary_operands(self) -> ArrayVec<TernaryOperandTag<TargetMapping>, MAX_TAGS> {
        self
    }
}
/// Any value convertible to a [`VeRhs`] becomes a binary operand applying under every tag.
impl<R, D: VeScalar, Mapping: M> From<R> for BinaryOperandTag<D, Mapping>
where
    R: Into<VeRhs<D, Mapping>>,
{
    fn from(rhs: R) -> Self {
        let operand0: VeRhs<D, Mapping> = rhs.into();
        BinaryOperandTag::always(operand0)
    }
}

/// `(rhs, filter)` becomes a binary operand restricted by `filter`.
impl<R, B, D: VeScalar, Mapping: M> From<(R, B)> for BinaryOperandTag<D, Mapping>
where
    R: Into<VeRhs<D, Mapping>>,
    B: Into<TagFilter>,
{
    fn from(pair: (R, B)) -> Self {
        let (rhs, filter) = pair;
        let operand0: VeRhs<D, Mapping> = rhs.into();
        let tag_filter: TagFilter = filter.into();
        Self { operand0, operand1: (), tag_filter }
    }
}
/// Conversion into a list of binary operands stored in an `ArrayVec<_, MAX_TAGS>`.
pub trait IntoOperands<D: VeScalar, TargetMapping: M> {
    fn into_operands(self) -> ArrayVec<BinaryOperandTag<D, TargetMapping>, MAX_TAGS>;
}
/// Any single value convertible to a [`BinaryOperandTag`] becomes a one-element list.
impl<T, D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping> for T
where
    T: Into<BinaryOperandTag<D, TargetMapping>>,
{
    fn into_operands(self) -> ArrayVec<BinaryOperandTag<D, TargetMapping>, MAX_TAGS> {
        let single: BinaryOperandTag<D, TargetMapping> = self.into();
        ArrayVec::new([single])
    }
}
/// An already-built binary operand list is passed through unchanged; no validation happens here.
impl<D: VeScalar, TargetMapping: M> IntoOperands<D, TargetMapping>
    for ArrayVec<BinaryOperandTag<D, TargetMapping>, MAX_TAGS>
{
    fn into_operands(self) -> ArrayVec<BinaryOperandTag<D, TargetMapping>, MAX_TAGS> {
        self
    }
}

/// A fixed-size array of binary operands.
///
/// # Panics
///
/// Panics if more than one element uses `TagFilter::All`.
impl<D: VeScalar, TargetMapping: M, const N: usize> IntoOperands<D, TargetMapping>
    for [BinaryOperandTag<D, TargetMapping>; N]
{
    fn into_operands(self) -> ArrayVec<BinaryOperandTag<D, TargetMapping>, MAX_TAGS> {
        let mut always_count = 0usize;
        for operand in &self {
            if matches!(operand.tag_filter(), TagFilter::All) {
                always_count += 1;
            }
        }
        assert!(
            always_count <= 1,
            "Multiple All operands are not allowed (found {always_count})"
        );
        ArrayVec::new(self)
    }
}
/// Marker value that selects the stash as an operand source.
///
/// It converts into `VeRhs::Stash`, and into a `VeOperand::Stash` whose tag filter is
/// `TagFilter::All`.
#[primitive(op::Stash)]
#[derive(Debug, Clone, Copy)]
pub struct Stash;
/// Operand forms lowered by `VeOperand::into_branch_operands`: a constant, a
/// borrowed `VrfTensor`, or a stash read.
#[derive(Debug)]
pub enum VeOperand<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> {
    /// Constant scalar.
    Const(D),
    /// Borrowed VRF tensor; it is transposed into the target mapping when lowered.
    Vrf(&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>),
    /// Stash read together with its tag filter.
    Stash(StashOperand<D>),
}
impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<i32> for VeOperand<'_, i32, Chip, Cluster, Slice, VrfMapping> {
fn from(v: i32) -> Self {
VeOperand::Const(v)
}
}
impl<Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<f32> for VeOperand<'_, f32, Chip, Cluster, Slice, VrfMapping> {
fn from(v: f32) -> Self {
VeOperand::Const(v)
}
}
impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M>
From<&'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>> for VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping>
{
fn from(vrf: &'a VrfTensor<D, Chip, Cluster, Slice, VrfMapping>) -> Self {
VeOperand::Vrf(vrf)
}
}
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<StashOperand<D>>
for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
{
fn from(stash: StashOperand<D>) -> Self {
VeOperand::Stash(stash)
}
}
impl<D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> From<Stash>
for VeOperand<'_, D, Chip, Cluster, Slice, VrfMapping>
{
fn from(_: Stash) -> Self {
VeOperand::Stash(StashOperand::always())
}
}
impl<'a, D: VeScalar, Chip: M, Cluster: M, Slice: M, VrfMapping: M> VeOperand<'a, D, Chip, Cluster, Slice, VrfMapping> {
pub fn into_branch_operands<Time: M, Packet: M>(
self,
) -> ArrayVec<BinaryOperandTag<D, m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }]>, MAX_TAGS> {
type TargetShape<Chip, Cluster, Slice, Time, Packet> =
m![{ Chip }, { Cluster }, { Slice }, { Time }, { Packet }];
match self {
VeOperand::Const(v) => ArrayVec::new([BinaryOperandTag::always(VeRhs::Const { v })]),
VeOperand::Vrf(vrf) => {
let vrf_operand = VeRhs::<D, TargetShape<Chip, Cluster, Slice, Time, Packet>>::vrf(vrf);
ArrayVec::new([BinaryOperandTag::always(vrf_operand)])
}
VeOperand::Stash(stash) => ArrayVec::new([BinaryOperandTag {
operand0: VeRhs::Stash,
operand1: (),
tag_filter: stash.tag_filter,
}]),
}
}
}
/// Optional binary operand; `()` converts to `None`.
pub type GroupOperand<D, Mapping> = Option<BinaryOperandTag<D, Mapping>>;

/// Conversion into an optional binary operand ([`GroupOperand`]).
pub trait IntoGroupOperand<D: VeScalar, Mapping: M> {
    fn into_group_operand(self) -> GroupOperand<D, Mapping>;
}

/// `()` stands for "no operand".
impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for () {
    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
        Option::None
    }
}

/// An explicit `Option` is passed through unchanged.
impl<D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for Option<BinaryOperandTag<D, Mapping>> {
    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
        self
    }
}

/// Anything convertible to a [`BinaryOperandTag`] becomes `Some(operand)`.
impl<T, D: VeScalar, Mapping: M> IntoGroupOperand<D, Mapping> for T
where
    T: Into<BinaryOperandTag<D, Mapping>>,
{
    fn into_group_operand(self) -> GroupOperand<D, Mapping> {
        let operand: BinaryOperandTag<D, Mapping> = self.into();
        Option::Some(operand)
    }
}
/// Optional ternary operand; `()` converts to `None`.
pub type GroupTernaryOperandTag<Mapping> = Option<TernaryOperandTag<Mapping>>;

/// Conversion into an optional ternary operand ([`GroupTernaryOperandTag`]).
pub trait IntoGroupTernaryOperandTag<Mapping: M> {
    fn into_group_ternary_operand(self) -> GroupTernaryOperandTag<Mapping>;
}

/// `()` stands for "no operand".
impl<Mapping: M> IntoGroupTernaryOperandTag<Mapping> for () {
    fn into_group_ternary_operand(self) -> GroupTernaryOperandTag<Mapping> {
        Option::None
    }
}

/// An explicit `Option` is passed through unchanged.
impl<Mapping: M> IntoGroupTernaryOperandTag<Mapping> for Option<TernaryOperandTag<Mapping>> {
    fn into_group_ternary_operand(self) -> GroupTernaryOperandTag<Mapping> {
        self
    }
}

/// Anything convertible to a [`TernaryOperandTag`] becomes `Some(operand)`.
impl<T, Mapping: M> IntoGroupTernaryOperandTag<Mapping> for T
where
    T: Into<TernaryOperandTag<Mapping>>,
{
    fn into_group_ternary_operand(self) -> GroupTernaryOperandTag<Mapping> {
        let operand: TernaryOperandTag<Mapping> = self.into();
        Option::Some(operand)
    }
}