use crate::ir::Opcode;
/// A cost value packed into a single `u32`.
///
/// The low `DEPTH_BITS` (8) bits hold a depth and the remaining upper bits
/// hold an opcode cost. The all-ones value (`u32::MAX`) is what
/// `Cost::infinity()` returns. `Ord` compares the raw integer, so costs are
/// ordered by opcode cost first and by depth second, with infinity comparing
/// greater than or equal to every other value.
///
/// `Debug` is implemented by hand rather than derived so that the two packed
/// fields are printed separately.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
impl core::fmt::Debug for Cost {
    /// Formats the cost for diagnostics.
    ///
    /// The infinite cost prints as `Cost::Infinite`; any other value prints as
    /// a `Cost::Finite` struct showing the unpacked `op_cost` and `depth`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Handle the common, finite case first and bail out early.
        if *self != Cost::infinity() {
            return f
                .debug_struct("Cost::Finite")
                .field("op_cost", &self.op_cost())
                .field("depth", &self.depth())
                .finish();
        }

        f.write_str("Cost::Infinite")
    }
}
impl Ord for Cost {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.cmp(&other.0)
}
}
impl PartialOrd for Cost {
    /// Delegates to the total order defined by `Ord`, so the result is always
    /// `Some`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Cost {
    /// Number of low bits of the packed value that store the depth.
    const DEPTH_BITS: u8 = 8;
    /// Mask selecting the depth bits of the packed value.
    const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
    /// Mask selecting the opcode-cost bits of the packed value.
    const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
    /// Largest value the opcode-cost field can hold. It is also the
    /// opcode-cost field of the infinite cost, so `new` never produces a
    /// finite cost with this value.
    const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

    /// Returns the infinite cost, represented as all bits set.
    pub(crate) fn infinity() -> Cost {
        Cost(u32::MAX)
    }

    /// Returns the zero cost: opcode cost 0 at depth 0.
    pub(crate) fn zero() -> Cost {
        Cost(0)
    }

    /// Packs an opcode cost and a depth into a `Cost`.
    ///
    /// An `opcode_cost` at or above `MAX_OP_COST` yields the infinite cost,
    /// regardless of `depth`.
    fn new(opcode_cost: u32, depth: u8) -> Cost {
        if opcode_cost < Self::MAX_OP_COST {
            let packed = (opcode_cost << Self::DEPTH_BITS) | u32::from(depth);
            Cost(packed)
        } else {
            Self::infinity()
        }
    }

    /// Extracts the depth stored in the low bits.
    fn depth(&self) -> u8 {
        // The mask keeps only `DEPTH_BITS` (8) bits, so the conversion to
        // `u8` cannot fail.
        u8::try_from(self.0 & Self::DEPTH_MASK).unwrap()
    }

    /// Extracts the opcode cost stored in the high bits.
    fn op_cost(&self) -> u32 {
        // Shifting right discards the depth bits, leaving only the cost field.
        self.0 >> Self::DEPTH_BITS
    }

    /// Computes the cost of applying `op` to operands with the given costs.
    ///
    /// The opcode cost is the base cost of `op` plus the sum of the operand
    /// costs. The depth is one more than the largest operand depth, saturating
    /// at `u8::MAX`. If the summed opcode cost reaches `MAX_OP_COST`, the
    /// result is the infinite cost.
    pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
        let base = pure_op_cost(op);
        let operands: Cost = operand_costs.into_iter().sum();
        let total = base + operands;

        let depth = total.depth().saturating_add(1);
        Cost::new(total.op_cost(), depth)
    }
}
impl std::iter::Sum<Cost> for Cost {
    /// Adds up every cost yielded by `iter`, starting from `Cost::zero()`.
    ///
    /// An empty iterator therefore sums to the zero cost.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}
impl std::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
}
}
impl std::ops::Add<Cost> for Cost {
    type Output = Cost;

    /// Combines two costs.
    ///
    /// Opcode costs are added with saturation, and `Cost::new` turns a sum at
    /// or above `MAX_OP_COST` into the infinite cost. The resulting depth is
    /// the larger of the two depths, not their sum.
    fn add(self, other: Cost) -> Cost {
        let deepest = self.depth().max(other.depth());
        let combined = self.op_cost().saturating_add(other.op_cost());
        Cost::new(combined, deepest)
    }
}
/// Returns the base cost of a single opcode, at depth 0.
///
/// The table assigns 1, 2 or 3 to the explicitly listed opcode groups and 4 to
/// every other opcode.
fn pure_op_cost(op: Opcode) -> Cost {
    let op_cost = match op {
        // Constant opcodes.
        Opcode::Iconst | Opcode::F32const | Opcode::F64const => 1,

        // Extend / reduce / concat / split opcodes.
        Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => 2,

        // Integer add/sub, bitwise, and shift opcodes.
        Opcode::Iadd
        | Opcode::Isub
        | Opcode::Band
        | Opcode::Bor
        | Opcode::Bxor
        | Opcode::Bnot
        | Opcode::Ishl
        | Opcode::Ushr
        | Opcode::Sshr => 3,

        // Any opcode not listed above.
        _ => 4,
    };

    Cost::new(op_cost, 0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding two finite costs sums the opcode costs and keeps the larger
    /// depth, independent of operand order.
    #[test]
    fn add_cost() {
        let (lhs, rhs) = (Cost::new(5, 2), Cost::new(37, 3));
        let expected = Cost::new(42, 3);

        assert_eq!(lhs + rhs, expected);
        assert_eq!(rhs + lhs, expected);
    }

    /// Infinity absorbs any finite cost on either side of the addition.
    #[test]
    fn add_infinity() {
        let finite = Cost::new(5, 2);
        let infinite = Cost::infinity();

        assert_eq!(finite + infinite, Cost::infinity());
        assert_eq!(infinite + finite, Cost::infinity());
    }

    /// A sum whose opcode cost exceeds `MAX_OP_COST` becomes infinite.
    #[test]
    fn op_cost_saturates_to_infinity() {
        let nearly_max = Cost::new(Cost::MAX_OP_COST - 10, 2);
        let small = Cost::new(11, 2);

        assert_eq!(nearly_max + small, Cost::infinity());
        assert_eq!(small + nearly_max, Cost::infinity());
    }

    /// The depth increment in `of_pure_op` saturates at `u8::MAX` instead of
    /// wrapping, whichever operand carries the maximum depth.
    #[test]
    fn depth_saturates_to_max_depth() {
        let deep = Cost::new(10, u8::MAX);
        let shallow = Cost::new(10, 1);
        let expected = Cost::new(21, u8::MAX);

        for operands in [[deep, shallow], [shallow, deep]] {
            assert_eq!(Cost::of_pure_op(Opcode::Iconst, operands), expected);
        }
    }
}