embeddenator_vsa/
reversible_encoding.rs1use crate::vsa::{SparseVec, DIM};
37use sha2::{Digest, Sha256};
38
/// Exclusive upper bound on the byte positions an encoder will build vectors for.
///
/// Requesting a position vector at or beyond this index panics (or returns an
/// error from `try_ensure_positions`), so no single encoded input may be longer
/// than this many bytes.
pub const MAX_POSITIONS: usize = 1 << 16;
41
42pub struct ReversibleVSAEncoder {
44 byte_vectors: Vec<SparseVec>,
46 position_vectors: Vec<SparseVec>,
48 dim: usize,
50}
51
52impl ReversibleVSAEncoder {
53 pub fn new() -> Self {
55 Self::with_dim(DIM)
56 }
57
58 pub fn with_dim(dim: usize) -> Self {
60 let mut encoder = Self {
61 byte_vectors: Vec::with_capacity(256),
62 position_vectors: Vec::with_capacity(MAX_POSITIONS),
63 dim,
64 };
65 encoder.initialize_basis_vectors();
66 encoder
67 }
68
69 fn initialize_basis_vectors(&mut self) {
71 for byte_val in 0u8..=255 {
73 let seed = Self::hash_to_seed(b"byte", &[byte_val]);
74 self.byte_vectors
75 .push(SparseVec::from_seed(&seed, self.dim));
76 }
77
78 for pos in 0..4096 {
81 let seed = Self::hash_to_seed(b"position", &(pos as u64).to_le_bytes());
82 self.position_vectors
83 .push(SparseVec::from_seed(&seed, self.dim));
84 }
85 }
86
87 fn hash_to_seed(prefix: &[u8], value: &[u8]) -> [u8; 32] {
89 let mut hasher = Sha256::new();
90 hasher.update(b"embeddenator:reversible:v1:");
91 hasher.update(prefix);
92 hasher.update(b":");
93 hasher.update(value);
94 hasher.finalize().into()
95 }
96
97 fn get_position_vector(&mut self, pos: usize) -> &SparseVec {
104 assert!(
105 pos < MAX_POSITIONS,
106 "Position {} exceeds MAX_POSITIONS ({})",
107 pos,
108 MAX_POSITIONS
109 );
110
111 while pos >= self.position_vectors.len() {
112 let new_pos = self.position_vectors.len();
113 let seed = Self::hash_to_seed(b"position", &(new_pos as u64).to_le_bytes());
114 self.position_vectors
115 .push(SparseVec::from_seed(&seed, self.dim));
116 }
117 &self.position_vectors[pos]
118 }
119
120 fn encode_byte_at_position(&self, byte: u8, position: usize) -> SparseVec {
124 let byte_vec = &self.byte_vectors[byte as usize];
125 let pos_vec = &self.position_vectors[position % self.position_vectors.len()];
126 byte_vec.bind(pos_vec)
127 }
128
129 pub fn encode(&mut self, data: &[u8]) -> SparseVec {
133 if data.is_empty() {
134 return SparseVec::new();
135 }
136
137 let _ = self.get_position_vector(data.len().saturating_sub(1));
139
140 let mut result = self.encode_byte_at_position(data[0], 0);
142
143 for (pos, &byte) in data.iter().enumerate().skip(1) {
144 let encoded_byte = self.encode_byte_at_position(byte, pos);
145 result = result.bundle(&encoded_byte);
146 }
147
148 result
149 }
150
151 pub fn encode_chunked(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
158 assert!(chunk_size > 0, "chunk_size must be > 0");
159 data.chunks(chunk_size)
160 .enumerate()
161 .map(|(chunk_idx, chunk)| {
162 let offset = chunk_idx * chunk_size;
163 self.encode_with_offset(chunk, offset)
164 })
165 .collect()
166 }
167
168 fn encode_with_offset(&mut self, data: &[u8], offset: usize) -> SparseVec {
170 if data.is_empty() {
171 return SparseVec::new();
172 }
173
174 let _ = self.get_position_vector(offset + data.len().saturating_sub(1));
176
177 let mut result = self.encode_byte_at_position(data[0], offset);
179
180 for (i, &byte) in data.iter().enumerate().skip(1) {
181 let encoded_byte = self.encode_byte_at_position(byte, offset + i);
182 result = result.bundle(&encoded_byte);
183 }
184
185 result
186 }
187
188 pub fn decode(&self, encoded: &SparseVec, length: usize) -> Vec<u8> {
192 let mut result = Vec::with_capacity(length);
193
194 for pos in 0..length {
195 let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
196
197 let query = encoded.bind(pos_vec);
199
200 let byte = self.find_best_byte_match(&query);
202 result.push(byte);
203 }
204
205 result
206 }
207
208 pub fn decode_chunked(
213 &self,
214 chunks: &[SparseVec],
215 chunk_size: usize,
216 total_length: usize,
217 ) -> Vec<u8> {
218 assert!(chunk_size > 0, "chunk_size must be > 0");
219 let mut result = Vec::with_capacity(total_length);
220
221 for (chunk_idx, chunk_vec) in chunks.iter().enumerate() {
222 let offset = chunk_idx * chunk_size;
223 let remaining = total_length.saturating_sub(offset);
224 let this_chunk_size = remaining.min(chunk_size);
225
226 for i in 0..this_chunk_size {
227 let pos = offset + i;
228 let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
229 let query = chunk_vec.bind(pos_vec);
230 let byte = self.find_best_byte_match(&query);
231 result.push(byte);
232 }
233 }
234
235 result
236 }
237
238 fn find_best_byte_match(&self, query: &SparseVec) -> u8 {
240 let mut best_byte = 0u8;
241 let mut best_sim = f64::NEG_INFINITY;
242
243 for (byte_val, byte_vec) in self.byte_vectors.iter().enumerate() {
244 let sim = query.cosine(byte_vec);
245 if sim > best_sim {
246 best_sim = sim;
247 best_byte = byte_val as u8;
248 }
249 }
250
251 best_byte
252 }
253
254 pub fn get_byte_vectors(&self) -> &[SparseVec] {
258 &self.byte_vectors
259 }
260
261 pub fn get_position_vector_ref(&self, pos: usize) -> &SparseVec {
269 &self.position_vectors[pos % self.position_vectors.len()]
270 }
271
272 pub fn ensure_positions(&mut self, max_pos: usize) {
286 assert!(
287 max_pos < MAX_POSITIONS,
288 "max_pos {} exceeds MAX_POSITIONS ({})",
289 max_pos,
290 MAX_POSITIONS
291 );
292 let _ = self.get_position_vector(max_pos);
293 }
294
295 pub fn try_ensure_positions(&mut self, max_pos: usize) -> Result<(), String> {
304 if max_pos >= MAX_POSITIONS {
305 return Err(format!(
306 "max_pos {} exceeds MAX_POSITIONS ({})",
307 max_pos, MAX_POSITIONS
308 ));
309 }
310 let _ = self.get_position_vector(max_pos);
311 Ok(())
312 }
313
314 pub fn test_accuracy(&mut self, data: &[u8]) -> f64 {
316 let encoded = self.encode(data);
317 let decoded = self.decode(&encoded, data.len());
318
319 let matches = data
320 .iter()
321 .zip(decoded.iter())
322 .filter(|(a, b)| a == b)
323 .count();
324
325 matches as f64 / data.len() as f64
326 }
327
328 pub fn test_accuracy_chunked(&mut self, data: &[u8], chunk_size: usize) -> f64 {
330 let chunks = self.encode_chunked(data, chunk_size);
331 let decoded = self.decode_chunked(&chunks, chunk_size, data.len());
332
333 let matches = data
334 .iter()
335 .zip(decoded.iter())
336 .filter(|(a, b)| a == b)
337 .count();
338
339 matches as f64 / data.len() as f64
340 }
341}
342
343impl Default for ReversibleVSAEncoder {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A single byte encodes to a vector that decodes back to that byte.
    #[test]
    fn test_single_byte_roundtrip() {
        let mut enc = ReversibleVSAEncoder::new();

        for &byte in &[0u8, 42, 127, 255] {
            let vector = enc.encode(&[byte]);
            let recovered = enc.decode(&vector, 1);
            assert_eq!(recovered[0], byte, "Failed roundtrip for byte {}", byte);
        }
    }

    /// A short string recovers at least half of its bytes.
    #[test]
    fn test_short_string_roundtrip() {
        let mut enc = ReversibleVSAEncoder::new();
        // `test_accuracy` performs exactly the encode -> decode -> compare cycle.
        let accuracy = enc.test_accuracy(b"Hello");
        assert!(accuracy >= 0.5, "Accuracy too low: {}", accuracy);
    }

    /// Chunked encoding of a longer sentence stays above 80% byte accuracy.
    #[test]
    fn test_chunked_encoding() {
        let mut enc = ReversibleVSAEncoder::new();
        let sentence = b"This is a test of chunked encoding for longer data.";

        let accuracy = enc.test_accuracy_chunked(sentence, 8);
        let percent = accuracy * 100.0;
        println!("Chunked accuracy: {:.2}%", percent);

        assert!(
            accuracy >= 0.80,
            "Chunked accuracy {:.1}% is below expected threshold 80%",
            percent
        );
    }

    /// Positions strictly below MAX_POSITIONS are accepted.
    #[test]
    fn test_try_ensure_positions_within_limit() {
        let mut enc = ReversibleVSAEncoder::new();
        for max_pos in [1000, MAX_POSITIONS - 1] {
            assert!(enc.try_ensure_positions(max_pos).is_ok());
        }
    }

    /// MAX_POSITIONS itself, and anything above it, is rejected.
    #[test]
    fn test_try_ensure_positions_exceeds_limit() {
        let mut enc = ReversibleVSAEncoder::new();
        for max_pos in [MAX_POSITIONS, MAX_POSITIONS + 1] {
            assert!(enc.try_ensure_positions(max_pos).is_err());
        }
    }

    /// The panicking variant reports the limit in its message.
    #[test]
    #[should_panic(expected = "MAX_POSITIONS")]
    fn test_ensure_positions_panics_on_overflow() {
        ReversibleVSAEncoder::new().ensure_positions(MAX_POSITIONS);
    }
}