thermogram 1.1.0

Plastic memory capsule with 4-temperature tensor states (hot/warm/cool/cold), bidirectional transitions, and hash-chained auditability
Documentation
//! Ternary Weight type for embedded SNN plasticity
//!
//! `TernaryWeight` is used by the embedded SNN for discrete STDP state transitions.
//! For bulk signal storage, use `PackedSignal` from `ternary-signal` instead.

use serde::{Deserialize, Serialize};
use std::fmt;

/// Ternary weight: +1, 0, or -1
///
/// This is the fundamental unit of our ternary neural network weights.
/// Uses `repr(i8)` for efficient memory layout and arithmetic: the type is a
/// single byte, and because each variant has an explicit discriminant equal
/// to its numeric value, `weight as i8` yields -1, 0, or +1 directly.
///
/// NOTE(review): `Serialize`/`Deserialize` are derived. Some serde formats
/// encode unit variants by position, so reordering the variants below could
/// change the wire format — confirm before reordering.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(i8)]
pub enum TernaryWeight {
    /// Negative connection (-1)
    Neg = -1,
    /// No connection (0)
    Zero = 0,
    /// Positive connection (+1)
    Pos = 1,
}

impl TernaryWeight {
    /// Convert to `f32` for computation: `Neg` → -1.0, `Zero` → 0.0, `Pos` → 1.0.
    #[inline]
    pub fn to_f32(self) -> f32 {
        // `f32: From<i8>` is lossless, so no chained `as` cast is needed.
        f32::from(self.to_i8())
    }

    /// Convert to the numeric value of the weight: -1, 0, or +1.
    #[inline]
    pub fn to_i8(self) -> i8 {
        // The enum's explicit discriminants (-1 / 0 / 1) make this cast exact.
        self as i8
    }

    /// Quantize an `f32` to a ternary weight using a symmetric threshold.
    ///
    /// - `value >= threshold` → `Pos`
    /// - `value <= -threshold` → `Neg`
    /// - otherwise (i.e. `|value| < threshold`, or `value` is NaN) → `Zero`
    ///
    /// Both boundaries are inclusive: `value == threshold` gives `Pos` and
    /// `value == -threshold` gives `Neg`.
    ///
    /// `threshold` is expected to be non-negative. With a negative threshold
    /// the two bands overlap. The `Pos` test runs first, so there is then no
    /// `Zero` band except for NaN. A NaN `threshold` maps everything to `Zero`.
    #[inline]
    pub fn from_f32(value: f32, threshold: f32) -> Self {
        if value >= threshold {
            TernaryWeight::Pos
        } else if value <= -threshold {
            TernaryWeight::Neg
        } else {
            TernaryWeight::Zero
        }
    }

    /// Strengthen the weight by one STDP-like step.
    ///
    /// Transitions: `Neg` → `Zero`, `Zero` → `Pos`. `Pos` saturates and is
    /// returned unchanged.
    #[inline]
    pub fn strengthen(self) -> Self {
        match self {
            TernaryWeight::Neg => TernaryWeight::Zero,
            TernaryWeight::Zero => TernaryWeight::Pos,
            TernaryWeight::Pos => TernaryWeight::Pos, // Already max
        }
    }

    /// Weaken the weight by one STDP-like step (the opposite of [`Self::strengthen`]).
    ///
    /// Transitions: `Pos` → `Zero`, `Zero` → `Neg`. `Neg` saturates and is
    /// returned unchanged.
    #[inline]
    pub fn weaken(self) -> Self {
        match self {
            TernaryWeight::Pos => TernaryWeight::Zero,
            TernaryWeight::Zero => TernaryWeight::Neg,
            TernaryWeight::Neg => TernaryWeight::Neg, // Already min
        }
    }

    /// Flip the sign: `Pos` ↔ `Neg`, with `Zero` staying `Zero`.
    ///
    /// Applying `flip` twice returns the original weight.
    #[inline]
    pub fn flip(self) -> Self {
        match self {
            TernaryWeight::Pos => TernaryWeight::Neg,
            TernaryWeight::Neg => TernaryWeight::Pos,
            TernaryWeight::Zero => TernaryWeight::Zero,
        }
    }

    /// Returns `true` if this weight is active (`Pos` or `Neg`), i.e. non-zero.
    #[inline]
    pub fn is_active(self) -> bool {
        self != TernaryWeight::Zero
    }

    /// Encode as a 2-bit code for packing.
    ///
    /// The codes are `Zero` = `0b00`, `Pos` = `0b01` and `Neg` = `0b10`.
    /// `0b11` is never produced.
    // NOTE(review): nothing in this module calls the 2-bit codec (bulk packing
    // is delegated to `PackedSignal`), so without this allow the build emits a
    // `dead_code` warning. Delete both helpers if no packer is meant to use them.
    #[allow(dead_code)]
    #[inline]
    fn to_2bit(self) -> u8 {
        match self {
            TernaryWeight::Zero => 0b00,
            TernaryWeight::Pos => 0b01,
            TernaryWeight::Neg => 0b10,
        }
    }

    /// Decode a 2-bit code produced by `to_2bit`.
    ///
    /// Only the low two bits of `bits` are inspected. The reserved code `0b11`
    /// decodes to `Zero` rather than failing.
    // NOTE(review): unused in this module; see the note on `to_2bit`.
    #[allow(dead_code)]
    #[inline]
    fn from_2bit(bits: u8) -> Self {
        match bits & 0b11 {
            0b00 => TernaryWeight::Zero,
            0b01 => TernaryWeight::Pos,
            0b10 => TernaryWeight::Neg,
            _ => TernaryWeight::Zero, // Reserved → Zero
        }
    }
}

impl Default for TernaryWeight {
    fn default() -> Self {
        TernaryWeight::Zero
    }
}

impl fmt::Debug for TernaryWeight {
    /// Compact single-character rendering: `-` for `Neg`, `0` for `Zero`, and
    /// `+` for `Pos`. Formatter flags such as width are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let glyph = match *self {
            TernaryWeight::Pos => "+",
            TernaryWeight::Zero => "0",
            TernaryWeight::Neg => "-",
        };
        f.write_str(glyph)
    }
}

impl fmt::Display for TernaryWeight {
    /// Two-character numeric rendering: `-1`, ` 0` (leading space), or `+1`.
    ///
    /// Every value has the same width, presumably so printed weights line up
    /// in columns. Formatter flags such as width are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            TernaryWeight::Pos => "+1",
            TernaryWeight::Zero => " 0",
            TernaryWeight::Neg => "-1",
        })
    }
}

impl From<i8> for TernaryWeight {
    /// Map an `i8` to the weight with the same sign.
    ///
    /// Negative values give `Neg`, zero gives `Zero`, and positive values give
    /// `Pos`. Magnitude is discarded (e.g. `-128` → `Neg`, `42` → `Pos`), so
    /// the conversion is total and never panics.
    fn from(v: i8) -> Self {
        use std::cmp::Ordering;

        // A three-way comparison against zero is exhaustive. Matching on
        // `signum()` instead returns an `i8` and forces an unreachable `_` arm.
        match v.cmp(&0) {
            Ordering::Less => TernaryWeight::Neg,
            Ordering::Equal => TernaryWeight::Zero,
            Ordering::Greater => TernaryWeight::Pos,
        }
    }
}

impl From<TernaryWeight> for i8 {
    fn from(w: TernaryWeight) -> i8 {
        w as i8
    }
}

// =============================================================================
// TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for table-driven checks.
    const ALL: [TernaryWeight; 3] = [TernaryWeight::Neg, TernaryWeight::Zero, TernaryWeight::Pos];

    #[test]
    fn test_ternary_weight_conversion() {
        assert_eq!(TernaryWeight::Pos.to_f32(), 1.0);
        assert_eq!(TernaryWeight::Zero.to_f32(), 0.0);
        assert_eq!(TernaryWeight::Neg.to_f32(), -1.0);
    }

    #[test]
    fn test_i8_round_trip() {
        for w in ALL {
            assert_eq!(TernaryWeight::from(i8::from(w)), w);
            assert_eq!(w.to_i8(), i8::from(w));
            assert_eq!(w.to_f32(), f32::from(w.to_i8()));
        }
        // Only the sign of the input matters, including at the extremes.
        assert_eq!(TernaryWeight::from(i8::MIN), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::from(i8::MAX), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from(0i8), TernaryWeight::Zero);
    }

    #[test]
    fn test_ternary_quantization() {
        let threshold = 0.5;

        assert_eq!(TernaryWeight::from_f32(0.8, threshold), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from_f32(0.3, threshold), TernaryWeight::Zero);
        assert_eq!(
            TernaryWeight::from_f32(-0.3, threshold),
            TernaryWeight::Zero
        );
        assert_eq!(TernaryWeight::from_f32(-0.8, threshold), TernaryWeight::Neg);

        // Boundaries are inclusive on both sides.
        assert_eq!(TernaryWeight::from_f32(0.5, threshold), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::from_f32(-0.5, threshold), TernaryWeight::Neg);

        // NaN fails every comparison and therefore quantizes to Zero.
        assert_eq!(
            TernaryWeight::from_f32(f32::NAN, threshold),
            TernaryWeight::Zero
        );
    }

    #[test]
    fn test_ternary_strengthen_weaken() {
        assert_eq!(TernaryWeight::Zero.strengthen(), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::Neg.strengthen(), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::Pos.strengthen(), TernaryWeight::Pos);

        assert_eq!(TernaryWeight::Pos.weaken(), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::Zero.weaken(), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::Neg.weaken(), TernaryWeight::Neg);

        // Away from saturation, each step is undone by the opposite step.
        for w in ALL {
            if w != TernaryWeight::Pos {
                assert_eq!(w.strengthen().weaken(), w);
            }
            if w != TernaryWeight::Neg {
                assert_eq!(w.weaken().strengthen(), w);
            }
        }
    }

    #[test]
    fn test_flip_and_is_active() {
        assert_eq!(TernaryWeight::Pos.flip(), TernaryWeight::Neg);
        assert_eq!(TernaryWeight::Neg.flip(), TernaryWeight::Pos);
        assert_eq!(TernaryWeight::Zero.flip(), TernaryWeight::Zero);
        for w in ALL {
            assert_eq!(w.flip().flip(), w);
            assert_eq!(w.flip().is_active(), w.is_active());
        }

        assert!(TernaryWeight::Pos.is_active());
        assert!(TernaryWeight::Neg.is_active());
        assert!(!TernaryWeight::Zero.is_active());
    }

    #[test]
    fn test_2bit_round_trip() {
        for w in ALL {
            assert_eq!(TernaryWeight::from_2bit(w.to_2bit()), w);
        }
        assert_eq!(TernaryWeight::Zero.to_2bit(), 0b00);
        assert_eq!(TernaryWeight::Pos.to_2bit(), 0b01);
        assert_eq!(TernaryWeight::Neg.to_2bit(), 0b10);

        // The reserved code decodes to Zero; bits above the low two are ignored.
        assert_eq!(TernaryWeight::from_2bit(0b11), TernaryWeight::Zero);
        assert_eq!(TernaryWeight::from_2bit(0b1110), TernaryWeight::Neg);
    }

    #[test]
    fn test_formatting() {
        assert_eq!(format!("{:?}", TernaryWeight::Neg), "-");
        assert_eq!(format!("{:?}", TernaryWeight::Zero), "0");
        assert_eq!(format!("{:?}", TernaryWeight::Pos), "+");

        assert_eq!(TernaryWeight::Neg.to_string(), "-1");
        assert_eq!(TernaryWeight::Zero.to_string(), " 0");
        assert_eq!(TernaryWeight::Pos.to_string(), "+1");
    }

    #[test]
    fn test_default_is_zero() {
        assert_eq!(TernaryWeight::default(), TernaryWeight::Zero);
    }
}