// nodedb_codec/vector_quant/ternary/packing.rs
/// Maps a trit to the base-3 digit used by the cold format:
/// `-1 → 0`, `0 → 1`, `1 → 2`.
///
/// Any value outside `{-1, 0, 1}` is treated as `0` and therefore encodes to digit `1`.
#[inline(always)]
fn trit_to_u8(t: i8) -> u8 {
    if t == -1 {
        0
    } else if t == 1 {
        2
    } else {
        1
    }
}

/// Inverse of `trit_to_u8`: maps a base-3 digit back to a trit
/// (`0 → -1`, `1 → 0`, `2 → 1`).
///
/// Every other value (anything `>= 3`) decodes to `0`.
#[inline(always)]
fn u8_to_trit(v: u8) -> i8 {
    const DIGIT_TO_TRIT: [i8; 3] = [-1, 0, 1];
    DIGIT_TO_TRIT.get(usize::from(v)).copied().unwrap_or(0)
}

29pub fn pack_cold(trits: &[i8]) -> Vec<u8> {
31 let out_len = trits.len().div_ceil(5);
32 let mut out = vec![0u8; out_len];
33 for (chunk_idx, chunk) in trits.chunks(5).enumerate() {
34 let mut byte = 0u8;
35 let mut mul = 1u8;
36 for &t in chunk {
37 byte = byte.wrapping_add(trit_to_u8(t).wrapping_mul(mul));
38 mul = mul.wrapping_mul(3);
39 }
40 out[chunk_idx] = byte;
41 }
42 out
43}
44
45pub fn unpack_cold(cold: &[u8], dim: usize) -> Vec<i8> {
47 let mut out = Vec::with_capacity(dim);
48 'outer: for &byte in cold {
49 let mut v = byte;
50 for _ in 0..5 {
51 if out.len() >= dim {
52 break 'outer;
53 }
54 out.push(u8_to_trit(v % 3));
55 v /= 3;
56 }
57 }
58 out
59}
60
/// Packs trits into the "hot" layout: four trits per byte, two bits each.
///
/// Trit `i` occupies bits `2*(i % 4) .. 2*(i % 4) + 2` of byte `i / 4`, using
/// the codes `-1 → 0b00`, `0 → 0b01`, `1 → 0b10`. Values outside
/// `{-1, 0, 1}` are encoded like `0`. The result has `trits.len().div_ceil(4)`
/// bytes. Unused slots in the final byte are left as `0b00`.
pub fn pack_hot(trits: &[i8]) -> Vec<u8> {
    trits
        .chunks(4)
        .map(|quad| {
            quad.iter().enumerate().fold(0u8, |byte, (slot, &t)| {
                let code: u8 = match t {
                    -1 => 0b00,
                    1 => 0b10,
                    _ => 0b01,
                };
                byte | (code << (2 * slot))
            })
        })
        .collect()
}

/// Decodes up to `dim` trits from the hot (two-bits-per-trit) layout.
///
/// Slots are read from the low bits of each byte upward, mirroring
/// `pack_hot`: `0b00 → -1`, `0b10 → 1`, and both `0b01` and `0b11` → `0`.
/// Decoding stops after `dim` trits. A buffer with fewer than `dim` slots
/// yields a shorter result rather than an error.
pub fn unpack_hot(hot: &[u8], dim: usize) -> Vec<i8> {
    let mut trits = Vec::with_capacity(dim);
    trits.extend(
        hot.iter()
            .flat_map(|&byte| (0..4u32).map(move |slot| (byte >> (2 * slot)) & 0b11))
            .map(|code| match code {
                0b00 => -1i8,
                0b10 => 1,
                _ => 0,
            })
            .take(dim),
    );
    trits
}

99pub fn cold_to_hot(cold: &[u8], dim: usize) -> Vec<u8> {
101 let trits = unpack_cold(cold, dim);
102 pack_hot(&trits)
103}
104
/// Quantizes a float vector to trits, using the mean absolute value as the scale.
///
/// Returns `(trits, scale)`, where `scale = mean(|v_i|)` and each trit is
/// `round(v_i / scale)` clamped to `[-1, 1]`. `f32::round` rounds halves away
/// from zero. An empty input yields `(vec![], 0.0)`. An all-zero input yields
/// all-zero trits with `scale == 0.0` instead of dividing by zero.
///
/// NOTE(review): a NaN input makes `scale` NaN and every trit 0, because a
/// float→int `as` cast maps NaN to 0. Callers presumably pass finite values;
/// confirm.
pub fn quantize(v: &[f32]) -> (Vec<i8>, f32) {
    let n = v.len();
    if n == 0 {
        return (Vec::new(), 0.0);
    }

    let abs_total: f32 = v.iter().map(|x| x.abs()).sum();
    let scale = abs_total / n as f32;

    // Avoid 0/0: an all-zero vector quantizes to all-zero trits.
    if scale == 0.0 {
        return (vec![0i8; n], scale);
    }

    let trits = v
        .iter()
        .map(|&x| (x / scale).round().clamp(-1.0, 1.0) as i8)
        .collect();
    (trits, scale)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cold_pack_roundtrip_simple() {
        let trits: Vec<i8> = vec![-1, 0, 1, -1, 0, 1, -1];
        let cold = pack_cold(&trits);
        let out = unpack_cold(&cold, trits.len());
        assert_eq!(out, trits);
    }

    #[test]
    fn cold_pack_roundtrip_dim_7() {
        let trits: Vec<i8> = vec![1, -1, 0, 1, -1, 0, 1];
        let cold = pack_cold(&trits);
        assert_eq!(cold.len(), 2);
        assert_eq!(unpack_cold(&cold, 7), trits);
    }

    #[test]
    fn cold_pack_roundtrip_dim_13() {
        let trits: Vec<i8> = vec![1, 0, -1, 1, 0, -1, 1, 0, -1, 1, 0, -1, 1];
        let cold = pack_cold(&trits);
        assert_eq!(cold.len(), 3);
        assert_eq!(unpack_cold(&cold, 13), trits);
    }

    /// Every one of the 3^5 = 243 five-trit groups maps to its own byte code
    /// in `0..243` and back.
    #[test]
    fn cold_pack_covers_every_byte_code() {
        for code in 0u8..243 {
            let trits = unpack_cold(&[code], 5);
            assert_eq!(trits.len(), 5);
            assert!(trits.iter().all(|t| (-1..=1).contains(t)));
            assert_eq!(pack_cold(&trits), vec![code], "code={code}");
        }
    }

    #[test]
    fn hot_pack_roundtrip() {
        let trits: Vec<i8> = vec![-1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1];
        let hot = pack_hot(&trits);
        assert_eq!(unpack_hot(&hot, trits.len()), trits);
    }

    /// The final hot byte is only partially filled when dim % 4 != 0.
    #[test]
    fn hot_pack_roundtrip_dim_not_multiple_of_4() {
        for dim in [1usize, 2, 3, 5, 6, 7, 9] {
            let trits: Vec<i8> = (0..dim).map(|i| (i % 3) as i8 - 1).collect();
            let hot = pack_hot(&trits);
            assert_eq!(hot.len(), dim.div_ceil(4), "length for dim={dim}");
            assert_eq!(unpack_hot(&hot, dim), trits, "mismatch for dim={dim}");
        }
    }

    #[test]
    fn cold_to_hot_preserves_trits() {
        let trits: Vec<i8> = vec![1, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1];
        let cold = pack_cold(&trits);
        let hot = cold_to_hot(&cold, trits.len());
        assert_eq!(unpack_hot(&hot, trits.len()), trits);
    }

    #[test]
    fn cold_to_hot_dim_not_multiple_of_5() {
        for dim in [7usize, 13, 11, 3] {
            let trits: Vec<i8> = (0..dim)
                .map(|i| match i % 3 {
                    0 => 1i8,
                    1 => -1,
                    _ => 0,
                })
                .collect();
            let cold = pack_cold(&trits);
            let hot = cold_to_hot(&cold, dim);
            assert_eq!(unpack_hot(&hot, dim), trits, "mismatch for dim={dim}");
        }
    }

    #[test]
    fn empty_inputs_produce_empty_outputs() {
        assert!(pack_cold(&[]).is_empty());
        assert!(unpack_cold(&[], 0).is_empty());
        assert!(pack_hot(&[]).is_empty());
        assert!(unpack_hot(&[], 0).is_empty());
        assert!(cold_to_hot(&[], 0).is_empty());
        let (trits, scale) = quantize(&[]);
        assert!(trits.is_empty());
        assert_eq!(scale, 0.0);
    }

    #[test]
    fn quantize_zeros_gives_all_zero_trits() {
        let v = vec![0.0f32; 16];
        let (trits, scale) = quantize(&v);
        assert_eq!(scale, 0.0);
        assert!(trits.iter().all(|&t| t == 0));
    }

    #[test]
    fn quantize_uses_mean_absolute_value_as_scale() {
        // mean(|v|) = (2 + 0.25 + 1 + 0.75) / 4 = 1.0, so each trit is round(v) clamped.
        let (trits, scale) = quantize(&[-2.0, 0.25, 1.0, 0.75]);
        assert_eq!(scale, 1.0);
        assert_eq!(trits, vec![-1, 0, 1, 1]);

        // mean(|v|) = 2.5. 4 / 2.5 = 1.6 rounds to 2 and clamps to 1;
        // 1 / 2.5 = 0.4 rounds to 0.
        let (trits, scale) = quantize(&[4.0, -4.0, 1.0, -1.0]);
        assert_eq!(scale, 2.5);
        assert_eq!(trits, vec![1, -1, 0, 0]);
    }
}