poulpy_core/
dist.rs

1use std::io::{Read, Result, Write};
2
/// A discrete distribution over small (binary or ternary) values, selected by variant
/// together with its parameter.
///
/// NOTE(review): presumably used to configure secret/noise sampling elsewhere in the
/// crate — confirm against the samplers that consume it.
#[derive(Clone, Copy, Debug)]
pub enum Distribution {
    /// Ternary with a fixed Hamming weight (the `usize` is presumably that weight).
    TernaryFixed(usize),
    /// Ternary with a probabilistic Hamming weight (the `f64` is presumably the
    /// per-coefficient probability of a non-zero value — TODO confirm).
    TernaryProb(f64),
    /// Binary with a fixed Hamming weight (the `usize` is presumably that weight).
    BinaryFixed(usize),
    /// Binary with a probabilistic Hamming weight (the `f64` is presumably a probability).
    BinaryProb(f64),
    /// Binary, split into blocks of size `2^k`.
    /// NOTE(review): confirm whether the `usize` is `k` or the block size itself.
    BinaryBlock(usize),
    /// Debug mode (presumably yields all-zero values — confirm with the sampler).
    ZERO,
    /// Uninitialized placeholder.
    NONE,
}
13
// Variant tags for the binary encoding of `Distribution`. Each tag occupies the most
// significant byte (bits 56..64) of a serialized little-endian `u64` word; the values
// are part of the on-disk format and must never be renumbered.
const TAG_TERNARY_FIXED: u8 = 0x00;
const TAG_TERNARY_PROB: u8 = 0x01;
const TAG_BINARY_FIXED: u8 = 0x02;
const TAG_BINARY_PROB: u8 = 0x03;
const TAG_BINARY_BLOCK: u8 = 0x04;
const TAG_ZERO: u8 = 0x05;
const TAG_NONE: u8 = 0x06;
21
22use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
23
24impl Distribution {
25    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
26        let word: u64 = match self {
27            Distribution::TernaryFixed(v) => (TAG_TERNARY_FIXED as u64) << 56 | (*v as u64),
28            Distribution::TernaryProb(p) => {
29                let bits = p.to_bits(); // f64 -> u64 bit representation
30                (TAG_TERNARY_PROB as u64) << 56 | (bits & 0x00FF_FFFF_FFFF_FFFF)
31            }
32            Distribution::BinaryFixed(v) => (TAG_BINARY_FIXED as u64) << 56 | (*v as u64),
33            Distribution::BinaryProb(p) => {
34                let bits = p.to_bits();
35                (TAG_BINARY_PROB as u64) << 56 | (bits & 0x00FF_FFFF_FFFF_FFFF)
36            }
37            Distribution::BinaryBlock(v) => (TAG_BINARY_BLOCK as u64) << 56 | (*v as u64),
38            Distribution::ZERO => (TAG_ZERO as u64) << 56,
39            Distribution::NONE => (TAG_NONE as u64) << 56,
40        };
41        writer.write_u64::<LittleEndian>(word)
42    }
43
44    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
45        let word = reader.read_u64::<LittleEndian>()?;
46        let tag = (word >> 56) as u8;
47        let payload = word & 0x00FF_FFFF_FFFF_FFFF;
48
49        let dist = match tag {
50            TAG_TERNARY_FIXED => Distribution::TernaryFixed(payload as usize),
51            TAG_TERNARY_PROB => Distribution::TernaryProb(f64::from_bits(payload)),
52            TAG_BINARY_FIXED => Distribution::BinaryFixed(payload as usize),
53            TAG_BINARY_PROB => Distribution::BinaryProb(f64::from_bits(payload)),
54            TAG_BINARY_BLOCK => Distribution::BinaryBlock(payload as usize),
55            TAG_ZERO => Distribution::ZERO,
56            TAG_NONE => Distribution::NONE,
57            _ => {
58                return Err(std::io::Error::new(
59                    std::io::ErrorKind::InvalidData,
60                    "Invalid tag",
61                ));
62            }
63        };
64        Ok(dist)
65    }
66}
67
68impl PartialEq for Distribution {
69    fn eq(&self, other: &Self) -> bool {
70        use Distribution::*;
71        match (self, other) {
72            (TernaryFixed(a), TernaryFixed(b)) => a == b,
73            (TernaryProb(a), TernaryProb(b)) => a.to_bits() == b.to_bits(),
74            (BinaryFixed(a), BinaryFixed(b)) => a == b,
75            (BinaryProb(a), BinaryProb(b)) => a.to_bits() == b.to_bits(),
76            (BinaryBlock(a), BinaryBlock(b)) => a == b,
77            (ZERO, ZERO) => true,
78            (NONE, NONE) => true,
79            _ => false,
80        }
81    }
82}
83
84impl Eq for Distribution {}