Skip to main content

nodedb_codec/vector_quant/ternary/
packing.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Trit packing/unpacking and FP32→ternary quantization.
4//!
5//! Cold (5 trits/byte, base-3) is the disk-friendly format.
6//! Hot (4 trits/byte, 2-bpw) is the SIMD-friendly format.
7
/// Map a signed trit onto the unsigned base-3 digit alphabet.
///
/// `-1` becomes `0`, `+1` becomes `2`, and every other input (including the
/// valid trit `0` and any out-of-range value) becomes `1`.
#[inline(always)]
fn trit_to_u8(t: i8) -> u8 {
    if t == -1 {
        0
    } else if t == 1 {
        2
    } else {
        // `0` and anything outside `{-1, 0, +1}` share the "zero" digit.
        1
    }
}
18
/// Map an unsigned base-3 digit back onto a signed trit.
///
/// `0` becomes `-1`, `2` becomes `+1`, and every other input (including the
/// digit `1` and any value above `2`) becomes `0`.
#[inline(always)]
fn u8_to_trit(v: u8) -> i8 {
    if v == 0 {
        -1
    } else if v == 2 {
        1
    } else {
        // Digit `1` and any unexpected value decode to the zero trit.
        0
    }
}
28
29/// Pack trits (`i8 ∈ {-1, 0, +1}`) into cold 5-trits-per-byte format.
30pub fn pack_cold(trits: &[i8]) -> Vec<u8> {
31    let out_len = trits.len().div_ceil(5);
32    let mut out = vec![0u8; out_len];
33    for (chunk_idx, chunk) in trits.chunks(5).enumerate() {
34        let mut byte = 0u8;
35        let mut mul = 1u8;
36        for &t in chunk {
37            byte = byte.wrapping_add(trit_to_u8(t).wrapping_mul(mul));
38            mul = mul.wrapping_mul(3);
39        }
40        out[chunk_idx] = byte;
41    }
42    out
43}
44
45/// Unpack cold 5-trits-per-byte format back to `i8` trits.
46pub fn unpack_cold(cold: &[u8], dim: usize) -> Vec<i8> {
47    let mut out = Vec::with_capacity(dim);
48    'outer: for &byte in cold {
49        let mut v = byte;
50        for _ in 0..5 {
51            if out.len() >= dim {
52                break 'outer;
53            }
54            out.push(u8_to_trit(v % 3));
55            v /= 3;
56        }
57    }
58    out
59}
60
/// Pack trits into the hot 2-bits-per-trit format (four trits per byte).
///
/// Bit encoding: `0b00 = −1`, `0b01 = 0`, `0b10 = +1`. Trits fill each byte
/// LSB-first, so the first trit of a group occupies bits 0–1. Any value
/// outside `{-1, +1}` is encoded as `0b01` (zero). The result has
/// `ceil(trits.len() / 4)` bytes; unused slots in the last byte stay `0b00`.
pub fn pack_hot(trits: &[i8]) -> Vec<u8> {
    trits
        .chunks(4)
        .map(|group| {
            group.iter().enumerate().fold(0u8, |acc, (slot, &t)| {
                let code: u8 = match t {
                    -1 => 0b00,
                    1 => 0b10,
                    _ => 0b01,
                };
                // Slot `k` lives at bit offset `2k`.
                acc | (code << (slot * 2))
            })
        })
        .collect()
}
79
/// Unpack hot 2-bpw format back to `i8` trits.
///
/// Each byte holds four 2-bit slots, read LSB-first. `0b00` decodes to `-1`,
/// `0b10` to `+1`, and both `0b01` and the unused pattern `0b11` to `0`.
///
/// Returns at most `dim` trits. If `hot` holds fewer than `dim` trits
/// (`hot.len() * 4 < dim`), every available slot is decoded and the result is
/// shorter than `dim`.
///
/// The output buffer is sized from the smaller of `dim` and the number of
/// slots `hot` actually contains, so an oversized or corrupt `dim` cannot
/// trigger a huge allocation or a capacity-overflow panic.
pub fn unpack_hot(hot: &[u8], dim: usize) -> Vec<i8> {
    // Upper bound on what can be produced: never more than requested, never
    // more than the input contains.
    let wanted = dim.min(hot.len().saturating_mul(4));
    let mut out = Vec::with_capacity(wanted);
    'outer: for &byte in hot {
        for slot in 0..4 {
            if out.len() >= wanted {
                break 'outer;
            }
            let bits = (byte >> (slot * 2)) & 0b11;
            out.push(match bits {
                0b00 => -1,
                0b10 => 1,
                _ => 0,
            });
        }
    }
    out
}
98
99/// Convert cold-packed trits to hot-packed trits.
100pub fn cold_to_hot(cold: &[u8], dim: usize) -> Vec<u8> {
101    let trits = unpack_cold(cold, dim);
102    pack_hot(&trits)
103}
104
/// Quantize an FP32 vector to ternary trits using BitNet absmean scaling.
///
/// Returns `(trits, scale)` where `scale = mean(|v_i|)`. Each trit is
/// `round(v_i / scale)` clamped to `[-1, 1]`; `round` rounds halves away from
/// zero, so `|v_i| >= scale / 2` maps to `±1`.
///
/// Special cases:
/// - An empty input returns `(vec![], 0.0)`.
/// - If the mean is exactly `0.0`, all trits are `0` and `scale` is `0.0`.
///
/// The absolute values are summed in `f64` so that large but finite inputs
/// cannot overflow the running sum to infinity (which would force every trit
/// to `0`), and so that long vectors lose less precision in the mean.
pub fn quantize(v: &[f32]) -> (Vec<i8>, f32) {
    if v.is_empty() {
        return (Vec::new(), 0.0);
    }
    // Widen before summing: an `f32` accumulator can reach `+inf` even when
    // every element is finite.
    let abs_sum: f64 = v.iter().map(|&x| f64::from(x.abs())).sum();
    // The mean of finite `f32` magnitudes is itself within `f32` range.
    let scale = (abs_sum / v.len() as f64) as f32;
    let trits = if scale == 0.0 {
        vec![0i8; v.len()]
    } else {
        v.iter()
            // After the clamp the value is in [-1, 1], so the cast is exact;
            // a NaN quotient casts to 0.
            .map(|&x| (x / scale).round().clamp(-1.0, 1.0) as i8)
            .collect()
    };
    (trits, scale)
}
122
#[cfg(test)]
mod tests {
    use super::*;

    // --- Cold format (5 trits per byte) ---------------------------------

    /// Seven trits span two cold bytes and survive a pack/unpack round trip.
    #[test]
    fn cold_pack_roundtrip_simple() {
        let trits = [-1i8, 0, 1, -1, 0, 1, -1];
        let packed = pack_cold(&trits);
        let recovered = unpack_cold(&packed, trits.len());
        assert_eq!(recovered, trits);
    }

    /// dim = 7 needs exactly two bytes (5 + 2 trits).
    #[test]
    fn cold_pack_roundtrip_dim_7() {
        let trits = [1i8, -1, 0, 1, -1, 0, 1];
        let packed = pack_cold(&trits);
        assert_eq!(packed.len(), 2);
        assert_eq!(unpack_cold(&packed, 7), trits);
    }

    /// dim = 13 needs exactly three bytes (5 + 5 + 3 trits).
    #[test]
    fn cold_pack_roundtrip_dim_13() {
        let trits = [1i8, 0, -1, 1, 0, -1, 1, 0, -1, 1, 0, -1, 1];
        let packed = pack_cold(&trits);
        assert_eq!(packed.len(), 3);
        assert_eq!(unpack_cold(&packed, 13), trits);
    }

    // --- Hot format (4 trits per byte) ----------------------------------

    /// Twelve trits (three full hot bytes) survive a pack/unpack round trip.
    #[test]
    fn hot_pack_roundtrip() {
        let trits = [-1i8, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1];
        let packed = pack_hot(&trits);
        assert_eq!(unpack_hot(&packed, trits.len()), trits);
    }

    // --- Cold -> hot conversion -----------------------------------------

    /// Converting cold bytes to hot bytes keeps every trit intact.
    #[test]
    fn cold_to_hot_preserves_trits() {
        let trits = [1i8, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1];
        let hot = cold_to_hot(&pack_cold(&trits), trits.len());
        assert_eq!(unpack_hot(&hot, trits.len()), trits);
    }

    /// Conversion also works when dim is not a multiple of 5.
    #[test]
    fn cold_to_hot_dim_not_multiple_of_5() {
        for dim in [7usize, 13, 11, 3] {
            // Repeating +1, -1, 0 pattern, truncated to `dim` entries.
            let trits: Vec<i8> = [1i8, -1, 0].iter().copied().cycle().take(dim).collect();
            let hot = cold_to_hot(&pack_cold(&trits), dim);
            assert_eq!(unpack_hot(&hot, dim), trits, "mismatch for dim={dim}");
        }
    }

    // --- Quantization ---------------------------------------------------

    /// An all-zero vector quantizes to all-zero trits with a zero scale.
    #[test]
    fn quantize_zeros_gives_all_zero_trits() {
        let (trits, scale) = quantize(&[0.0f32; 16]);
        assert_eq!(scale, 0.0);
        assert!(!trits.iter().any(|&t| t != 0));
    }
}