use crate::frame::mmm::*;
use std::hash::{Hash, Hasher};
use std::ops::Mul;
use tract_data::prelude::f16;
/// Scaling factor that can be applied to floating-point and `i32` values.
///
/// Floating-point operands are multiplied by `scale`; `i32` operands go through the
/// integer `mult` / `shift` pair, with right shifts rounded according to `policy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scaler {
/// The scale expressed as a floating-point factor.
pub scale: f32,
/// Optional integer multiplier. When present, an `i32` operand is multiplied by it
/// (in 64-bit arithmetic) and the shift amount is increased by 31.
pub mult: Option<i32>,
/// Shift applied to `i32` operands: a strictly positive value is a rounding right
/// shift, zero or a negative value is a left shift by `-shift`.
pub shift: isize,
/// Rounding policy used whenever an `i32` operand is shifted right.
pub policy: RoundingPolicy,
}
// `Eq` is declared by hand: `f32` does not implement `Eq`, so it cannot be derived
// for a struct that stores `scale: f32`.
// NOTE(review): a NaN `scale` compares unequal to itself under the derived
// `PartialEq` — confirm that no caller ever builds a `Scaler` from NaN.
impl Eq for Scaler {}
// `Hash` is hand-written (an `f32` field cannot derive it) while `PartialEq` is
// derived, which is exactly what this clippy lint reports.
#[allow(clippy::derive_hash_xor_eq)]
impl Hash for Scaler {
    /// Feeds the bit pattern of `scale` into `state`.
    ///
    /// Only `scale` takes part in the hash; the other fields are ignored, which is
    /// still consistent with equality (equal scalers have equal scales).
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        // `0.0 == -0.0` under the derived `PartialEq`, yet their bit patterns differ.
        // Canonicalise the zero so that scalers comparing equal also hash equal.
        let canonical = if self.scale == 0.0 { 0.0f32 } else { self.scale };
        canonical.to_bits().hash(state)
    }
}
impl Scaler {
    /// Builds a scaler for `scale`, pre-computing the integer multiplier / shift pair
    /// that is used when the scaler is applied to `i32` values.
    pub fn new(scale: f32, policy: RoundingPolicy) -> Self {
        let (mult, shift) = Self::convert_scale_to_mult_shift(scale);
        Self { scale, mult, shift, policy }
    }

    /// Expresses this scaler as a fused operation.
    ///
    /// * with a multiplier: `QScale(shift, policy, multiplier)`;
    /// * without one and a strictly positive shift: `RoundingShiftRight(shift, policy)`;
    /// * otherwise: `ShiftLeft(-shift)` (a zero shift yields `ShiftLeft(0)`).
    pub fn as_fused_spec(&self) -> FusedSpec {
        match self.mult {
            Some(multiplier) => FusedSpec::QScale(self.shift, self.policy, multiplier),
            None if self.shift > 0 => {
                FusedSpec::RoundingShiftRight(self.shift as usize, self.policy)
            }
            None => FusedSpec::ShiftLeft((-self.shift) as usize),
        }
    }

    /// Rebuilds a scaler from the parameters of a `QScale` operation.
    ///
    /// The floating-point scale is recovered as `mult * 2^-(31 + shift)`.
    pub fn from_fuse_params(shift: isize, policy: RoundingPolicy, mult: i32) -> Self {
        let exponent = -(31 + shift as i32);
        let scale = mult as f32 * 2f32.powi(exponent);
        Self { scale, mult: Some(mult), shift, policy }
    }

    /// Splits `scale` into an optional 31-bit fixed-point multiplier and a shift.
    ///
    /// The decomposition reads the raw bits of the `f32`: an 8-bit biased exponent
    /// (bias 127) above a 23-bit mantissa. The sign bit is not inspected.
    ///
    /// * `scale == 0.0` gives `(None, 0)`.
    /// * An all-zero mantissa (a power of two) needs no multiplier: the result is
    ///   `(None, 127 - exponent)`.
    /// * Otherwise the mantissa, with its hidden leading one restored, is moved up
    ///   by 7 bits so that it lies in `[2^30, 2^31)` — i.e. half of the significand
    ///   expressed over 31 fractional bits — and the shift is reduced by one to
    ///   compensate for that halving.
    #[inline]
    fn convert_scale_to_mult_shift(scale: f32) -> (Option<i32>, isize) {
        const MANTISSA_BITS: u32 = 23;
        const MANTISSA_MASK: u32 = 0x007fffff;
        const HIDDEN_BIT: u32 = 0x800000;
        const EXPONENT_MASK: u32 = 0xff;
        const EXPONENT_BIAS: isize = 127;

        if scale == 0.0 {
            return (None, 0);
        }

        let bits = scale.to_bits();
        let biased_exponent = ((bits >> MANTISSA_BITS) & EXPONENT_MASK) as isize;
        let mantissa = bits & MANTISSA_MASK;

        if mantissa == 0 {
            return (None, EXPONENT_BIAS - biased_exponent);
        }

        let multiplier = ((mantissa | HIDDEN_BIT) << 7) as i32;
        (Some(multiplier), EXPONENT_BIAS - biased_exponent - 1)
    }
}
impl Mul<f16> for Scaler {
    type Output = f16;

    /// Multiplies a half-precision value by the scale, after converting the scale
    /// itself to `f16`.
    #[inline]
    fn mul(self, rhs: f16) -> Self::Output {
        let scale = f16::from_f32(self.scale);
        scale * rhs
    }
}
impl Mul<f32> for Scaler {
type Output = f32;
#[inline]
fn mul(self, rhs: f32) -> Self::Output {
self.scale * rhs
}
}
impl Mul<f64> for Scaler {
    type Output = f64;

    /// Multiplies a double-precision value by the scale, widening the scale to
    /// `f64` first.
    #[inline]
    fn mul(self, rhs: f64) -> Self::Output {
        let scale = f64::from(self.scale);
        scale * rhs
    }
}
impl Mul<Scaler> for f16 {
    type Output = f16;

    /// Commuted form of `Scaler * f16`: forwards to that implementation.
    #[inline]
    fn mul(self, rhs: Scaler) -> Self::Output {
        <Scaler as Mul<f16>>::mul(rhs, self)
    }
}
impl Mul<Scaler> for f32 {
    type Output = f32;

    /// Commuted form of `Scaler * f32`: forwards to that implementation.
    #[inline]
    fn mul(self, rhs: Scaler) -> Self::Output {
        <Scaler as Mul<f32>>::mul(rhs, self)
    }
}
impl Mul<Scaler> for f64 {
    type Output = f64;

    /// Commuted form of `Scaler * f64`: forwards to that implementation.
    #[inline]
    fn mul(self, rhs: Scaler) -> Self::Output {
        <Scaler as Mul<f64>>::mul(rhs, self)
    }
}
impl Mul<i32> for Scaler {
    type Output = i32;

    /// Scales an `i32` using integer arithmetic only.
    ///
    /// With a multiplier, the operand is widened to `i64`, multiplied by it, and the
    /// shift becomes `self.shift + 31`; without one the operand is used as is with
    /// `self.shift`. A strictly positive shift is a right shift rounded according to
    /// `self.policy`; otherwise the value is shifted left by `-shift`. The result is
    /// narrowed back with an `as i32` cast.
    ///
    /// # Panics
    ///
    /// Panics when a right shift is needed and `self.policy` is not one of `Zero`,
    /// `MinusInf`, `PlusInf`, `Away`, `Even` or `Odd`.
    #[inline]
    fn mul(self, rhs: i32) -> Self::Output {
        use RoundingPolicy::*;

        let (product, shift) = match self.mult {
            Some(multiplier) => (i64::from(multiplier) * i64::from(rhs), self.shift + 31),
            None => (i64::from(rhs), self.shift),
        };

        if shift <= 0 {
            return (product << -shift) as i32;
        }

        // Rounding works on the magnitude: add half a unit, then pull exact ties
        // back down by one (`nudge == -1`) whenever the policy wants the tie to go
        // towards zero for this value.
        let magnitude = product.abs();
        let half: i64 = 1 << (shift - 1);
        let nudge: i64 = match self.policy {
            Zero => -1,
            MinusInf => {
                if product >= 0 {
                    -1
                } else {
                    0
                }
            }
            PlusInf => {
                if product <= 0 {
                    -1
                } else {
                    0
                }
            }
            Away => 0,
            // Parity of the truncated quotient decides which way a tie goes.
            Even => ((magnitude >> shift) & 0x1) - 1,
            Odd => -((magnitude >> shift) & 0x1),
            _ => panic!(),
        };
        let rounded = (magnitude + half + nudge) >> shift;
        (product.signum() * rounded) as i32
    }
}
impl Mul<Scaler> for i32 {
    type Output = i32;

    /// Commuted form of `Scaler * i32`: forwards to that implementation.
    #[inline]
    fn mul(self, rhs: Scaler) -> Self::Output {
        <Scaler as Mul<i32>>::mul(rhs, self)
    }
}
/// Scaling and power-of-two shifting of a value.
///
/// Implemented in this file for `f64`, `f32`, `f16` and `i32`.
pub trait ScaleShiftAndRound {
/// Applies `scaler` to `self` (implemented as `self * scaler`).
fn q_scale(self, scaler: Scaler) -> Self;
/// Scales `self` up by `shift` binary places: a multiplication by `2^shift` for the
/// floating-point types, a left shift for `i32`.
fn q_shl(self, shift: usize) -> Self;
/// Scales `self` down by `shift` binary places. The `i32` implementation rounds the
/// result according to `rp`; the floating-point implementations ignore `rp`.
fn q_shr(self, shift: usize, rp: RoundingPolicy) -> Self;
}
impl ScaleShiftAndRound for f64 {
    /// Multiplies the value by the scaler.
    fn q_scale(self, scaler: Scaler) -> Self {
        self * scaler
    }

    /// Multiplies the value by `2^shift`.
    fn q_shl(self, shift: usize) -> Self {
        let factor = 2f64.powi(shift as i32);
        self * factor
    }

    /// Multiplies the value by `2^-shift`; no rounding is involved, so the policy
    /// is ignored.
    fn q_shr(self, shift: usize, _rp: RoundingPolicy) -> Self {
        let factor = 2f64.powi(-(shift as i32));
        self * factor
    }
}
impl ScaleShiftAndRound for f32 {
    /// Multiplies the value by the scaler.
    fn q_scale(self, scaler: Scaler) -> Self {
        self * scaler
    }

    /// Multiplies the value by `2^shift`.
    fn q_shl(self, shift: usize) -> Self {
        let factor = 2f32.powi(shift as i32);
        self * factor
    }

    /// Multiplies the value by `2^-shift`; no rounding is involved, so the policy
    /// is ignored.
    fn q_shr(self, shift: usize, _rp: RoundingPolicy) -> Self {
        let factor = 2f32.powi(-(shift as i32));
        self * factor
    }
}
impl ScaleShiftAndRound for f16 {
    /// Multiplies the value by the scaler.
    fn q_scale(self, scaler: Scaler) -> Self {
        self * scaler
    }

    /// Multiplies the value by `2^shift`; the factor is computed in `f32` and then
    /// converted to `f16`.
    fn q_shl(self, shift: usize) -> Self {
        let factor = f16::from_f32(2f32.powi(shift as i32));
        self * factor
    }

    /// Multiplies the value by `2^-shift`; the factor is computed in `f32` and then
    /// converted to `f16`. The rounding policy is ignored.
    fn q_shr(self, shift: usize, _rp: RoundingPolicy) -> Self {
        let factor = f16::from_f32(2f32.powi(-(shift as i32)));
        self * factor
    }
}
impl ScaleShiftAndRound for i32 {
    /// Applies `scaler` through the integer `i32 * Scaler` multiplication.
    fn q_scale(self, scaler: Scaler) -> Self {
        self * scaler
    }

    /// Divides the value by `2^shift`, rounding to the nearest integer and breaking
    /// exact ties according to `rp`.
    ///
    /// A `shift` of zero returns the value unchanged. The rounding is carried out in
    /// 64-bit arithmetic, so operands close to `i32::MAX` / `i32::MIN` do not
    /// overflow while the rounding bias is added. `shift` is expected to be smaller
    /// than 64.
    ///
    /// # Panics
    ///
    /// Panics when `shift` is non-zero and `rp` is not one of `Zero`, `MinusInf`,
    /// `PlusInf`, `Away`, `Even` or `Odd`.
    fn q_shr(self, shift: usize, rp: RoundingPolicy) -> Self {
        use RoundingPolicy::*;

        // Shifting by zero bits is the identity. Bailing out here also keeps
        // `shift - 1` below from underflowing.
        if shift == 0 {
            return self;
        }

        // Widen first: `abs()` of `i32::MIN` and `magnitude + half` both exceed the
        // `i32` range. This mirrors the arithmetic of `Scaler * i32`.
        let val = self as i64;
        let magnitude = val.abs();
        let half: i64 = 1 << (shift - 1);
        // Added to `magnitude + half`: -1 pulls an exact tie back towards zero,
        // 0 lets it round away from zero.
        let nudge: i64 = match rp {
            Zero => -1,
            MinusInf => -((val >= 0) as i64),
            PlusInf => -((val <= 0) as i64),
            Away => 0,
            // Parity of the truncated quotient decides which way a tie goes.
            Even => ((magnitude >> shift) & 0x1) - 1,
            Odd => -((magnitude >> shift) & 0x1),
            _ => panic!(),
        };
        // For any `shift >= 1` the rounded magnitude is at most 2^30, so the
        // narrowing cast is lossless.
        (val.signum() * ((magnitude + half + nudge) >> shift)) as i32
    }

    /// Shifts the value left by `shift` bits.
    fn q_shl(self, shift: usize) -> Self {
        self << shift
    }
}
#[cfg(test)]
mod test {
    use super::RoundingPolicy::*;
    use super::*;

    /// Checks `input.q_shr(shift, rp) == expected` for every `(input, expected)` pair.
    fn check_shr(shift: usize, rp: super::RoundingPolicy, cases: &[(i32, i32)]) {
        for &(input, expected) in cases {
            assert_eq!(input.q_shr(shift, rp), expected);
        }
    }

    /// Checks `input.q_scale(Scaler::new(scale, rp)) == expected` for every
    /// `(input, expected)` pair.
    fn check_scale(scale: f32, rp: super::RoundingPolicy, cases: &[(i32, i32)]) {
        for &(input, expected) in cases {
            assert_eq!(input.q_scale(Scaler::new(scale, rp)), expected);
        }
    }

    #[test]
    fn test_scale_rounding_f32() {
        // Floats are scaled exactly: the rounding policy plays no part.
        let cases: [(f32, f32); 7] =
            [(0.0, 0.0), (1.0, 0.5), (2.0, 1.0), (3.0, 1.5), (-1.0, -0.5), (-2.0, -1.0), (-3.0, -1.5)];
        for &(input, expected) in &cases {
            assert_eq!(input.q_scale(Scaler::new(0.5, Zero)), expected);
        }
    }

    #[test]
    fn test_shift_rounding_zero() {
        check_shr(1, Zero, &[(0, 0), (1, 0), (2, 1), (3, 1)]);
        check_shr(
            2,
            Zero,
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 1),
                (-1, 0),
                (-2, 0),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -1),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_zero() {
        check_scale(0.5, Zero, &[(0, 0), (1, 0), (2, 1), (3, 1), (-1, 0), (-2, -1), (-3, -1)]);
        check_scale(
            0.25,
            Zero,
            &[(2, 0), (3, 1), (4, 1), (5, 1), (6, 1), (-2, 0), (-3, -1), (-4, -1), (-5, -1), (-6, -1)],
        );
    }

    #[test]
    fn test_shift_rounding_away() {
        check_shr(1, Away, &[(0, 0), (1, 1), (2, 1), (3, 2)]);
        check_shr(
            2,
            Away,
            &[
                (0, 0),
                (1, 0),
                (2, 1),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 2),
                (-1, 0),
                (-2, -1),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -2),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_away() {
        check_scale(0.5, Away, &[(0, 0), (1, 1), (2, 1), (3, 2), (-1, -1), (-2, -1), (-3, -2)]);
        check_scale(
            0.25,
            Away,
            &[(2, 1), (3, 1), (4, 1), (5, 1), (6, 2), (-2, -1), (-3, -1), (-4, -1), (-5, -1), (-6, -2)],
        );
    }

    #[test]
    fn test_shift_rounding_plus_inf() {
        check_shr(1, PlusInf, &[(0, 0), (1, 1), (2, 1), (3, 2)]);
        check_shr(
            2,
            PlusInf,
            &[
                (0, 0),
                (1, 0),
                (2, 1),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 2),
                (-1, 0),
                (-2, 0),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -1),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_plus_inf() {
        check_scale(0.5, PlusInf, &[(0, 0), (1, 1), (2, 1), (3, 2), (-1, 0), (-2, -1), (-3, -1)]);
        check_scale(
            0.25,
            PlusInf,
            &[(2, 1), (3, 1), (4, 1), (5, 1), (6, 2), (-2, 0), (-3, -1), (-4, -1), (-5, -1), (-6, -1)],
        );
    }

    #[test]
    fn test_shift_rounding_minus_inf() {
        check_shr(1, MinusInf, &[(0, 0), (1, 0), (2, 1), (3, 1)]);
        check_shr(
            2,
            MinusInf,
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 1),
                (-1, 0),
                (-2, -1),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -2),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_minus_inf() {
        check_scale(0.5, MinusInf, &[(0, 0), (1, 0), (2, 1), (3, 1), (-1, -1), (-2, -1), (-3, -2)]);
        check_scale(
            0.25,
            MinusInf,
            &[(2, 0), (3, 1), (4, 1), (5, 1), (6, 1), (-2, -1), (-3, -1), (-4, -1), (-5, -1), (-6, -2)],
        );
    }

    #[test]
    fn test_shift_rounding_even() {
        check_shr(1, Even, &[(0, 0), (1, 0), (2, 1), (3, 2)]);
        check_shr(
            2,
            Even,
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 2),
                (-1, 0),
                (-2, 0),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -2),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_even() {
        check_scale(0.5, Even, &[(0, 0), (1, 0), (2, 1), (3, 2), (-1, 0), (-2, -1), (-3, -2)]);
        check_scale(
            0.25,
            Even,
            &[(2, 0), (3, 1), (4, 1), (5, 1), (6, 2), (-2, 0), (-3, -1), (-4, -1), (-5, -1), (-6, -2)],
        );
    }

    #[test]
    fn test_shift_rounding_odd() {
        check_shr(1, Odd, &[(0, 0), (1, 1), (2, 1), (3, 1)]);
        check_shr(
            2,
            Odd,
            &[
                (0, 0),
                (1, 0),
                (2, 1),
                (3, 1),
                (4, 1),
                (5, 1),
                (6, 1),
                (-1, 0),
                (-2, -1),
                (-3, -1),
                (-4, -1),
                (-5, -1),
                (-6, -1),
            ],
        );
    }

    #[test]
    fn test_scale_rounding_odd() {
        check_scale(0.5, Odd, &[(0, 0), (1, 1), (2, 1), (3, 1), (-1, -1), (-2, -1), (-3, -1)]);
        check_scale(
            0.25,
            Odd,
            &[(2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (-2, -1), (-3, -1), (-4, -1), (-5, -1), (-6, -1)],
        );
    }
}