use std::fmt::{self, Display, Formatter};
use furiosa_mapping::{Atom, Ident, M};
use furiosa_opt_macro::primitive;
use smart_default::SmartDefault;
use crate::scalar::Opt;
use crate::tensor::Tensor;
use super::scalar::VeScalar;
/// Selects how [`apply_branch_config`] derives the `u8` tag written for each tensor element.
#[primitive(ve::TagMode)]
#[derive(Debug, Clone, SmartDefault)]
pub enum TagMode {
/// Tag every element with `0`. Marked `#[default]`, so this is the default mode.
#[default]
Zero,
/// Tag each element with the parity of its index along `axis`, stored in bit 3
/// (`(index % 2) << 3`). If no axis with this symbol is found, the tag is `0`.
AxisToggle {
/// Symbol identifying the axis whose index parity is used.
axis: Ident,
},
/// Not implemented yet: `apply_branch_config` reaches `todo!()` for this mode.
ValidCount,
/// Evaluate four comparisons against each initialized element; comparison `i`
/// sets bit `i` of the tag when it matches. Uninitialized elements stay uninitialized.
Comparison([InputCmp; 4]),
/// Not implemented yet: `apply_branch_config` reaches `todo!()` for this mode. The
/// placeholder message describes it as loading execution IDs from VRF.
Vrf,
}
/// Renders a `TagMode` as `TagMode::<Variant>` followed by its payload, if any.
///
/// `Comparison` lists its comparisons separated by `", "` inside parentheses.
impl Display for TagMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zero => f.write_str("TagMode::Zero"),
            Self::ValidCount => f.write_str("TagMode::ValidCount"),
            Self::Vrf => f.write_str("TagMode::Vrf"),
            Self::AxisToggle { axis } => {
                write!(f, "TagMode::AxisToggle {{ axis: {axis} }}")
            }
            Self::Comparison(input_cmps) => {
                f.write_str("TagMode::Comparison(")?;
                // The separator is empty before the first entry and ", " afterwards.
                let mut separator = "";
                for cmp in input_cmps {
                    f.write_str(separator)?;
                    write!(f, "{cmp}")?;
                    separator = ", ";
                }
                f.write_str(")")
            }
        }
    }
}
/// A comparison predicate tagged with the scalar type it is evaluated against.
///
/// `InputCmp::matches` panics when the variant does not agree with the data type.
#[derive(Debug, Clone)]
pub enum InputCmp {
/// Comparison evaluated on `i32` values.
I32(InputCmpI32),
/// Comparison evaluated on `f32` values.
F32(InputCmpF32),
}
/// Displays the wrapped comparison exactly as the inner type displays itself;
/// the `I32`/`F32` tag is not part of the output.
impl Display for InputCmp {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let inner: &dyn Display = match self {
            Self::I32(cmp) => cmp,
            Self::F32(cmp) => cmp,
        };
        write!(f, "{inner}")
    }
}
/// Comparison of an `i32` input against a fixed boundary (see `InputCmpI32::matches`).
#[derive(Debug, Clone)]
pub enum InputCmpI32 {
/// Matches when the input equals `boundary`.
Equal {
boundary: i32,
},
/// Matches when the input is less than `boundary` (signed comparison).
Less {
boundary: i32,
},
/// Matches when the input is greater than `boundary` (signed comparison).
Greater {
boundary: i32,
},
/// Matches when the input is less than `boundary` after both are cast to `u32`.
LessUnsigned {
boundary: i32,
},
/// Matches when the input is greater than `boundary` after both are cast to `u32`.
GreaterUnsigned {
boundary: i32,
},
/// Matches every input.
True,
/// Matches no input.
False,
}
/// Compact textual form: an operator (`=`, `<`, `>`, `<u`, `>u`) immediately
/// followed by the boundary, or the literal `true` / `false`.
impl Display for InputCmpI32 {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // The constant predicates carry no boundary, so they are emitted directly.
        let (operator, boundary) = match self {
            Self::True => return f.write_str("true"),
            Self::False => return f.write_str("false"),
            Self::Equal { boundary } => ("=", boundary),
            Self::Less { boundary } => ("<", boundary),
            Self::Greater { boundary } => (">", boundary),
            Self::LessUnsigned { boundary } => ("<u", boundary),
            Self::GreaterUnsigned { boundary } => (">u", boundary),
        };
        write!(f, "{operator}{boundary}")
    }
}
/// Comparison of an `f32` input against a fixed boundary (see `InputCmpF32::matches`).
#[derive(Debug, Clone)]
pub enum InputCmpF32 {
/// Matches when the input compares equal to `boundary` using `f32` `==`.
Equal {
boundary: f32,
},
/// Matches when the input is less than `boundary` using `f32` `<`.
Less {
boundary: f32,
},
/// Matches when the input is greater than `boundary` using `f32` `>`.
Greater {
boundary: f32,
},
/// Matches when the input's raw bit pattern (`to_bits`) is less than the boundary's,
/// compared as `u32`.
LessUnsigned {
boundary: f32,
},
/// Matches when the input's raw bit pattern (`to_bits`) is greater than the boundary's,
/// compared as `u32`.
GreaterUnsigned {
boundary: f32,
},
/// Matches every input.
True,
/// Matches no input.
False,
}
/// Compact textual form: an operator (`=`, `<`, `>`, `<u`, `>u`) immediately
/// followed by the boundary, or the literal `true` / `false`.
impl Display for InputCmpF32 {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // The constant predicates carry no boundary, so they are emitted directly.
        let (operator, boundary) = match self {
            Self::True => return f.write_str("true"),
            Self::False => return f.write_str("false"),
            Self::Equal { boundary } => ("=", boundary),
            Self::Less { boundary } => ("<", boundary),
            Self::Greater { boundary } => (">", boundary),
            Self::LessUnsigned { boundary } => ("<u", boundary),
            Self::GreaterUnsigned { boundary } => (">u", boundary),
        };
        write!(f, "{operator}{boundary}")
    }
}
impl InputCmpI32 {
pub fn matches(&self, x: i32) -> bool {
match self {
InputCmpI32::Equal { boundary } => x == *boundary,
InputCmpI32::Less { boundary } => x < *boundary,
InputCmpI32::Greater { boundary } => x > *boundary,
InputCmpI32::LessUnsigned { boundary } => (x as u32) < (*boundary as u32),
InputCmpI32::GreaterUnsigned { boundary } => (x as u32) > (*boundary as u32),
InputCmpI32::True => true,
InputCmpI32::False => false,
}
}
}
impl InputCmpF32 {
pub fn matches(&self, x: f32) -> bool {
match self {
InputCmpF32::Equal { boundary } => x == *boundary,
InputCmpF32::Less { boundary } => x < *boundary,
InputCmpF32::Greater { boundary } => x > *boundary,
InputCmpF32::LessUnsigned { boundary } => {
let x_bits = x.to_bits();
let boundary_bits = boundary.to_bits();
x_bits < boundary_bits
}
InputCmpF32::GreaterUnsigned { boundary } => {
let x_bits = x.to_bits();
let boundary_bits = boundary.to_bits();
x_bits > boundary_bits
}
InputCmpF32::True => true,
InputCmpF32::False => false,
}
}
}
impl InputCmp {
    /// Evaluates this comparison against `x`, dispatching on the runtime type of `D`.
    ///
    /// The `I32` variant accepts only `D = i32` and the `F32` variant only
    /// `D = f32`; the value is recovered through a checked `dyn Any` downcast,
    /// so no `unsafe` reinterpretation is needed.
    ///
    /// # Panics
    ///
    /// Panics with a "Type mismatch" message when `D` is not the scalar type
    /// the variant expects.
    pub fn matches<D: VeScalar>(&self, x: D) -> bool {
        use std::any::Any;

        // A checked downcast replaces the former `TypeId` test + `transmute_copy`.
        let value: &dyn Any = &x;
        match self {
            InputCmp::I32(cmp) => match value.downcast_ref::<i32>() {
                Some(&x_i32) => cmp.matches(x_i32),
                None => panic!("Type mismatch: InputCmp::I32 used with f32 data"),
            },
            InputCmp::F32(cmp) => match value.downcast_ref::<f32>() {
                Some(&x_f32) => cmp.matches(x_f32),
                None => panic!("Type mismatch: InputCmp::F32 used with i32 data"),
            },
        }
    }
}
/// Identifier of one of two groups; `GroupId::bit_value` maps it to a single bit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupId {
/// Group whose bit value is `0`.
Zero,
/// Group whose bit value is `1`.
One,
}
impl GroupId {
pub fn bit_value(&self) -> u8 {
match self {
GroupId::Zero => 0,
GroupId::One => 1,
}
}
}
/// Predicate over per-element tags (see `TagFilter::matches`).
#[primitive(ve::TagFilter)]
#[derive(Debug, Clone, Default)]
pub enum TagFilter {
/// Accepts initialized tags whose bit 3 equals the bit value of `id`.
Group {
/// Group that bit 3 of the tag is compared against.
id: GroupId,
},
/// Accepts every initialized tag. This is the default filter.
#[default]
All,
}
impl TagFilter {
    /// Reports whether the tag `exec_id` passes this filter.
    ///
    /// An uninitialized tag never passes. `All` accepts any initialized tag;
    /// `Group` accepts a tag when its bit 3 equals the group's bit value.
    pub fn matches(&self, exec_id: Opt<u8>) -> bool {
        let tag = match exec_id {
            Opt::Init(tag) => tag,
            Opt::Uninit => return false,
        };
        match self {
            Self::All => true,
            Self::Group { id } => {
                let group_bit = (tag >> 3) & 1;
                group_bit == id.bit_value()
            }
        }
    }
}
impl From<GroupId> for TagFilter {
fn from(id: GroupId) -> Self {
TagFilter::Group { id }
}
}
/// Builds the `u8` tag tensor for `data` according to `config`.
///
/// * `TagMode::Zero` tags every element with `0`.
/// * `TagMode::AxisToggle` tags each element with `(index % 2) << 3`, where
///   `index` is the element's index along the axis whose symbol equals `axis`;
///   when no such axis is found the tag is `0`.
/// * `TagMode::Comparison` sets bit `i` of the tag when comparison `i` matches
///   the element; uninitialized elements stay uninitialized.
///
/// # Panics
///
/// `TagMode::ValidCount` and `TagMode::Vrf` are not implemented and hit
/// `todo!()`. `TagMode::Comparison` panics if a comparison's variant does not
/// agree with the scalar type `D` (see `InputCmp::matches`).
pub fn apply_branch_config<D: VeScalar, Mapping: M>(
    data: &Tensor<D, Mapping>,
    config: &TagMode,
) -> Tensor<u8, Mapping> {
    match config {
        TagMode::Zero => data.map(|_| Opt::Init(0u8)),
        TagMode::ValidCount => todo!(),
        TagMode::Vrf => todo!("TagMode::Vrf: load execution IDs from VRF (GenBranch::WithLog)"),
        TagMode::AxisToggle { axis } => Tensor::from_fn(|axes, idx| {
            // Locate the requested axis among the mapping's terms by symbol.
            let located = axes.iter().position(|term| {
                matches!(&term.inner, Atom::Symbol { symbol, .. } if symbol == axis)
            });
            match located {
                // The parity of the index along the axis goes into bit 3.
                Some(pos) => Opt::Init(((idx[pos] % 2) as u8) << 3),
                None => Opt::Init(0u8),
            }
        }),
        TagMode::Comparison(cmps) => data.map(|elem| match elem {
            Opt::Uninit => Opt::Uninit,
            Opt::Init(value) => {
                // Pack one bit per comparison; comparison `i` lands in bit `i`.
                let exec_id = cmps.iter().enumerate().fold(0u8, |acc, (bit_pos, cmp)| {
                    acc | (u8::from(cmp.matches(*value)) << bit_pos)
                });
                Opt::Init(exec_id)
            }
        }),
    }
}