use serde::{Deserialize, Serialize};
use std::fmt;
/// A single three-valued weight: `-1`, `0`, or `+1`.
///
/// The `#[repr(i8)]` discriminants are exactly the numeric weight values, so
/// casting a variant with `as i8` yields `-1`, `0`, or `1`.
///
/// Keep the variant order `Neg, Zero, Pos`. serde's derived impls pass both the
/// variant name and its declaration index to the serializer. Reordering could
/// therefore change the encoded form in index-based formats.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[repr(i8)]
pub enum TernaryWeight {
    /// Negative weight, numeric value `-1`.
    Neg = -1,
    /// Zero weight, numeric value `0`.
    Zero = 0,
    /// Positive weight, numeric value `+1`.
    Pos = 1,
}
impl TernaryWeight {
#[inline]
pub fn to_f32(self) -> f32 {
self as i8 as f32
}
#[inline]
pub fn to_i8(self) -> i8 {
self as i8
}
#[inline]
pub fn from_f32(value: f32, threshold: f32) -> Self {
if value >= threshold {
TernaryWeight::Pos
} else if value <= -threshold {
TernaryWeight::Neg
} else {
TernaryWeight::Zero
}
}
#[inline]
pub fn strengthen(self) -> Self {
match self {
TernaryWeight::Neg => TernaryWeight::Zero,
TernaryWeight::Zero => TernaryWeight::Pos,
TernaryWeight::Pos => TernaryWeight::Pos, }
}
#[inline]
pub fn weaken(self) -> Self {
match self {
TernaryWeight::Pos => TernaryWeight::Zero,
TernaryWeight::Zero => TernaryWeight::Neg,
TernaryWeight::Neg => TernaryWeight::Neg, }
}
#[inline]
pub fn flip(self) -> Self {
match self {
TernaryWeight::Pos => TernaryWeight::Neg,
TernaryWeight::Neg => TernaryWeight::Pos,
TernaryWeight::Zero => TernaryWeight::Zero,
}
}
#[inline]
pub fn is_active(self) -> bool {
self != TernaryWeight::Zero
}
#[inline]
fn to_2bit(self) -> u8 {
match self {
TernaryWeight::Zero => 0b00,
TernaryWeight::Pos => 0b01,
TernaryWeight::Neg => 0b10,
}
}
#[inline]
fn from_2bit(bits: u8) -> Self {
match bits & 0b11 {
0b00 => TernaryWeight::Zero,
0b01 => TernaryWeight::Pos,
0b10 => TernaryWeight::Neg,
_ => TernaryWeight::Zero, }
}
}
impl Default for TernaryWeight {
fn default() -> Self {
TernaryWeight::Zero
}
}
impl fmt::Debug for TernaryWeight {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TernaryWeight::Neg => write!(f, "-"),
TernaryWeight::Zero => write!(f, "0"),
TernaryWeight::Pos => write!(f, "+"),
}
}
}
impl fmt::Display for TernaryWeight {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TernaryWeight::Neg => write!(f, "-1"),
TernaryWeight::Zero => write!(f, " 0"),
TernaryWeight::Pos => write!(f, "+1"),
}
}
}
impl From<i8> for TernaryWeight {
    /// Maps an `i8` to a ternary weight by its sign.
    ///
    /// Negative values become `Neg`, `0` becomes `Zero`, and positive values
    /// become `Pos`. The magnitude is discarded. Every `i8` is accepted, so the
    /// conversion is total and never panics.
    fn from(v: i8) -> Self {
        // The three inclusive ranges cover every `i8`, so the compiler proves
        // this match exhaustive without a catch-all arm.
        match v {
            i8::MIN..=-1 => TernaryWeight::Neg,
            0 => TernaryWeight::Zero,
            1..=i8::MAX => TernaryWeight::Pos,
        }
    }
}
impl From<TernaryWeight> for i8 {
fn from(w: TernaryWeight) -> i8 {
w as i8
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order.
    const ALL: [TernaryWeight; 3] = [TernaryWeight::Neg, TernaryWeight::Zero, TernaryWeight::Pos];

    #[test]
    fn test_ternary_weight_conversion() {
        assert_eq!(TernaryWeight::Pos.to_f32(), 1.0);
        assert_eq!(TernaryWeight::Zero.to_f32(), 0.0);
        assert_eq!(TernaryWeight::Neg.to_f32(), -1.0);
    }

    #[test]
    fn test_ternary_quantization() {
        let threshold = 0.5;
        assert_eq!(TernaryWeight::from_f32(0.8, threshold), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from_f32(0.3, threshold), TernaryWeight::Zero);
        assert_eq!(
            TernaryWeight::from_f32(-0.3, threshold),
            TernaryWeight::Zero
        );
        assert_eq!(TernaryWeight::from_f32(-0.8, threshold), TernaryWeight::Neg);
    }

    #[test]
    fn test_ternary_quantization_boundaries() {
        let threshold = 0.5;
        // Both comparisons are inclusive: exactly ±threshold is nonzero.
        assert_eq!(TernaryWeight::from_f32(0.5, threshold), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from_f32(-0.5, threshold), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::from_f32(0.0, threshold), TernaryWeight::Zero);
        // NaN fails both comparisons.
        assert_eq!(
            TernaryWeight::from_f32(f32::NAN, threshold),
            TernaryWeight::Zero
        );
    }

    #[test]
    fn test_ternary_strengthen_weaken() {
        assert_eq!(TernaryWeight::Zero.strengthen(), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::Neg.strengthen(), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::Pos.strengthen(), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::Pos.weaken(), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::Zero.weaken(), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::Neg.weaken(), TernaryWeight::Neg);
    }

    #[test]
    fn test_ternary_flip() {
        assert_eq!(TernaryWeight::Pos.flip(), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::Neg.flip(), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::Zero.flip(), TernaryWeight::Zero);
        for w in ALL.iter().copied() {
            assert_eq!(w.flip().flip(), w);
        }
    }

    #[test]
    fn test_ternary_is_active() {
        assert!(TernaryWeight::Pos.is_active());
        assert!(TernaryWeight::Neg.is_active());
        assert!(!TernaryWeight::Zero.is_active());
    }

    #[test]
    fn test_ternary_default_is_zero() {
        assert_eq!(TernaryWeight::default(), TernaryWeight::Zero);
    }

    #[test]
    fn test_ternary_i8_round_trip() {
        for w in ALL.iter().copied() {
            assert_eq!(i8::from(w), w.to_i8());
            assert_eq!(TernaryWeight::from(w.to_i8()), w);
            assert_eq!(w.to_f32(), f32::from(w.to_i8()));
        }
    }

    #[test]
    fn test_ternary_from_i8_uses_sign() {
        assert_eq!(TernaryWeight::from(i8::MIN), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::from(-7i8), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::from(0i8), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::from(7i8), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from(i8::MAX), TernaryWeight::Pos);
    }

    #[test]
    fn test_ternary_formatting() {
        assert_eq!(TernaryWeight::Neg.to_string(), "-1");
        assert_eq!(TernaryWeight::Zero.to_string(), " 0");
        assert_eq!(TernaryWeight::Pos.to_string(), "+1");
        assert_eq!(format!("{:?}", TernaryWeight::Neg), "-");
        assert_eq!(format!("{:?}", TernaryWeight::Zero), "0");
        assert_eq!(format!("{:?}", TernaryWeight::Pos), "+");
    }

    #[test]
    fn test_ternary_2bit_encoding() {
        for w in ALL.iter().copied() {
            let bits = w.to_2bit();
            // `0b11` is never produced.
            assert!(bits <= 0b10);
            assert_eq!(TernaryWeight::from_2bit(bits), w);
            // Only the low two bits are decoded.
            assert_eq!(TernaryWeight::from_2bit(bits | 0b1111_1100), w);
        }
        assert_eq!(TernaryWeight::from_2bit(0b11), TernaryWeight::Zero);
    }
}