use crate::vsa::{SparseVec, DIM};
use rayon::prelude::*;
use sha2::{Digest, Sha256};

/// Upper bound on lazily generated position vectors (see `get_position_vector`).
pub const MAX_POSITIONS: usize = 65536;

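/// Reversible hypervector encoder: every byte value and byte position gets a
/// deterministic, seed-derived basis vector; a buffer is encoded by binding
/// each byte vector to its position vector and bundling (superposing) the
/// results into one hypervector. The decode path binds the encoded vector
/// with the same position vector again, which this scheme uses as the
/// approximate inverse of binding.
///
/// Usage sketch (hedged: assumes this crate's `SparseVec` API):
///
/// ```ignore
/// let mut encoder = ReversibleVSAEncoder::new();
/// let hv = encoder.encode(b"VSA");
/// assert_eq!(encoder.decode(&hv, 3).len(), 3);
/// ```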
pub struct ReversibleVSAEncoder {
    /// Deterministic basis vector for each of the 256 byte values.
    byte_vectors: Vec<SparseVec>,
    /// Deterministic basis vector per position, grown lazily up to MAX_POSITIONS.
    position_vectors: Vec<SparseVec>,
    /// Hypervector dimensionality.
    dim: usize,
}

impl ReversibleVSAEncoder {
    pub fn new() -> Self {
        Self::with_dim(DIM)
    }

    pub fn with_dim(dim: usize) -> Self {
        let mut encoder = Self {
            byte_vectors: Vec::with_capacity(256),
            position_vectors: Vec::with_capacity(MAX_POSITIONS),
            dim,
        };
        encoder.initialize_basis_vectors();
        encoder
    }

    /// Seeds the 256 byte vectors and an initial block of 4096 position
    /// vectors; further positions are generated on demand.
    fn initialize_basis_vectors(&mut self) {
        for byte_val in 0u8..=255 {
            let seed = Self::hash_to_seed(b"byte", &[byte_val]);
            self.byte_vectors
                .push(SparseVec::from_seed(&seed, self.dim));
        }

        for pos in 0..4096u64 {
            let seed = Self::hash_to_seed(b"position", &pos.to_le_bytes());
            self.position_vectors
                .push(SparseVec::from_seed(&seed, self.dim));
        }
    }

    /// Derives a 32-byte seed via SHA-256 with domain separation, so byte
    /// seeds and position seeds can never collide.
    fn hash_to_seed(prefix: &[u8], value: &[u8]) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(b"embeddenator:reversible:v1:");
        hasher.update(prefix);
        hasher.update(b":");
        hasher.update(value);
        hasher.finalize().into()
    }

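    /// Returns the basis vector for `pos`, lazily generating any missing
    /// positions up to and including it. Panics if `pos >= MAX_POSITIONS`;
    /// `try_ensure_positions` is the fallible alternative.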
    fn get_position_vector(&mut self, pos: usize) -> &SparseVec {
        assert!(
            pos < MAX_POSITIONS,
            "Position {} exceeds MAX_POSITIONS ({})",
            pos,
            MAX_POSITIONS
        );

        while pos >= self.position_vectors.len() {
            let new_pos = self.position_vectors.len();
            let seed = Self::hash_to_seed(b"position", &(new_pos as u64).to_le_bytes());
            self.position_vectors
                .push(SparseVec::from_seed(&seed, self.dim));
        }
        &self.position_vectors[pos]
    }

    /// Binds a byte vector to a position vector, producing a hypervector
    /// meaning "this byte at this position". The modulo wrap is a safety
    /// net; callers ensure the position vector already exists.
    fn encode_byte_at_position(&self, byte: u8, position: usize) -> SparseVec {
        let byte_vec = &self.byte_vectors[byte as usize];
        let pos_vec = &self.position_vectors[position % self.position_vectors.len()];
        byte_vec.bind(pos_vec)
    }

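    /// Encodes `data` as a single hypervector: each byte is bound to its
    /// position vector and the bound pairs are bundled. Bundling is lossy,
    /// so decode accuracy degrades as inputs grow; prefer `encode_chunked`
    /// for longer data.
    ///
    /// Usage sketch (hedged: assumes this crate's `SparseVec` API):
    ///
    /// ```ignore
    /// let mut encoder = ReversibleVSAEncoder::new();
    /// let hv = encoder.encode(b"Hi");
    /// let bytes = encoder.decode(&hv, 2);
    /// ```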
    pub fn encode(&mut self, data: &[u8]) -> SparseVec {
        if data.is_empty() {
            return SparseVec::new();
        }

        // Pre-generate every position vector this input needs.
        let _ = self.get_position_vector(data.len().saturating_sub(1));

        let mut result = self.encode_byte_at_position(data[0], 0);

        for (pos, &byte) in data.iter().enumerate().skip(1) {
            let encoded_byte = self.encode_byte_at_position(byte, pos);
            result = result.bundle(&encoded_byte);
        }

        result
    }

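    /// Encodes `data` as independent fixed-size chunks, each bundled over
    /// far fewer bytes than a whole-input `encode`, which keeps per-byte
    /// decode accuracy high. Positions are global (chunk offset + index),
    /// so the chunks decode back into one contiguous stream.
    ///
    /// Sketch (hedged: assumes this crate's `SparseVec` API):
    ///
    /// ```ignore
    /// let mut encoder = ReversibleVSAEncoder::new();
    /// let chunks = encoder.encode_chunked(b"hello world", 4);
    /// assert_eq!(chunks.len(), 3); // ceil(11 / 4)
    /// ```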
    pub fn encode_chunked(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
        assert!(chunk_size > 0, "chunk_size must be > 0");
        data.chunks(chunk_size)
            .enumerate()
            .map(|(chunk_idx, chunk)| {
                let offset = chunk_idx * chunk_size;
                self.encode_with_offset(chunk, offset)
            })
            .collect()
    }

    /// Like `encode`, but positions start at `offset` instead of zero.
    fn encode_with_offset(&mut self, data: &[u8], offset: usize) -> SparseVec {
        if data.is_empty() {
            return SparseVec::new();
        }

        let _ = self.get_position_vector(offset + data.len().saturating_sub(1));

        let mut result = self.encode_byte_at_position(data[0], offset);

        for (i, &byte) in data.iter().enumerate().skip(1) {
            let encoded_byte = self.encode_byte_at_position(byte, offset + i);
            result = result.bundle(&encoded_byte);
        }

        result
    }

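    /// Decodes `length` bytes from an encoded hypervector. For each
    /// position, binding the encoded vector with that position vector
    /// approximately isolates the byte component; the result is matched
    /// against all 256 byte vectors by cosine similarity and the best
    /// match wins. Recovery is probabilistic, not exact.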
    pub fn decode(&self, encoded: &SparseVec, length: usize) -> Vec<u8> {
        let mut result = Vec::with_capacity(length);

        for pos in 0..length {
            let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];

            // Unbind the position to expose the byte component.
            let query = encoded.bind(pos_vec);

            let byte = self.find_best_byte_match(&query);
            result.push(byte);
        }

        result
    }

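    /// Decodes chunk vectors produced by `encode_chunked` (or
    /// `batch_encode`) back into `total_length` bytes. The final chunk may
    /// be shorter than `chunk_size`, which is why the caller must supply
    /// the original length.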
    pub fn decode_chunked(
        &self,
        chunks: &[SparseVec],
        chunk_size: usize,
        total_length: usize,
    ) -> Vec<u8> {
        assert!(chunk_size > 0, "chunk_size must be > 0");
        let mut result = Vec::with_capacity(total_length);

        for (chunk_idx, chunk_vec) in chunks.iter().enumerate() {
            let offset = chunk_idx * chunk_size;
            let remaining = total_length.saturating_sub(offset);
            let this_chunk_size = remaining.min(chunk_size);

            for i in 0..this_chunk_size {
                let pos = offset + i;
                let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
                let query = chunk_vec.bind(pos_vec);
                let byte = self.find_best_byte_match(&query);
                result.push(byte);
            }
        }

        result
    }

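    /// Linear scan over all 256 byte basis vectors, returning the byte
    /// whose vector is most cosine-similar to the query.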
    fn find_best_byte_match(&self, query: &SparseVec) -> u8 {
        let mut best_byte = 0u8;
        let mut best_sim = f64::NEG_INFINITY;

        for (byte_val, byte_vec) in self.byte_vectors.iter().enumerate() {
            let sim = query.cosine(byte_vec);
            if sim > best_sim {
                best_sim = sim;
                best_byte = byte_val as u8;
            }
        }

        best_byte
    }

    /// Read-only access to the 256 byte basis vectors.
    pub fn get_byte_vectors(&self) -> &[SparseVec] {
        &self.byte_vectors
    }

    /// Read-only access to a position basis vector. Wraps modulo the number
    /// of positions generated so far; call `ensure_positions` first if the
    /// exact position vector is required.
    pub fn get_position_vector_ref(&self, pos: usize) -> &SparseVec {
        &self.position_vectors[pos % self.position_vectors.len()]
    }

    /// Pre-generates position vectors up to and including `max_pos`.
    /// Panics if `max_pos >= MAX_POSITIONS`.
    pub fn ensure_positions(&mut self, max_pos: usize) {
        assert!(
            max_pos < MAX_POSITIONS,
            "max_pos {} exceeds MAX_POSITIONS ({})",
            max_pos,
            MAX_POSITIONS
        );
        let _ = self.get_position_vector(max_pos);
    }

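    /// Fallible variant of `ensure_positions`: returns `Err` instead of
    /// panicking when `max_pos` is out of range.
    ///
    /// ```ignore
    /// let mut encoder = ReversibleVSAEncoder::new();
    /// assert!(encoder.try_ensure_positions(10_000).is_ok());
    /// ```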
    pub fn try_ensure_positions(&mut self, max_pos: usize) -> Result<(), String> {
        if max_pos >= MAX_POSITIONS {
            return Err(format!(
                "max_pos {} exceeds MAX_POSITIONS ({})",
                max_pos, MAX_POSITIONS
            ));
        }
        let _ = self.get_position_vector(max_pos);
        Ok(())
    }

    /// Round-trips `data` through `encode`/`decode` and returns the
    /// fraction of bytes recovered exactly. Returns 1.0 for empty input
    /// (avoiding a divide-by-zero NaN).
    pub fn test_accuracy(&mut self, data: &[u8]) -> f64 {
        if data.is_empty() {
            return 1.0;
        }
        let encoded = self.encode(data);
        let decoded = self.decode(&encoded, data.len());

        let matches = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count();

        matches as f64 / data.len() as f64
    }

    /// Same as `test_accuracy`, but round-trips through the chunked path.
    pub fn test_accuracy_chunked(&mut self, data: &[u8], chunk_size: usize) -> f64 {
        if data.is_empty() {
            return 1.0;
        }
        let chunks = self.encode_chunked(data, chunk_size);
        let decoded = self.decode_chunked(&chunks, chunk_size, data.len());

        let matches = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count();

        matches as f64 / data.len() as f64
    }

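    /// Parallel version of `encode_chunked`, fanning chunks out across the
    /// rayon thread pool. Position vectors are pre-generated up front so
    /// the workers only need shared `&` access to the basis tables. For
    /// in-range inputs the output matches `encode_chunked` (see the tests).
    ///
    /// Sketch (hedged: assumes this crate's `SparseVec` API):
    ///
    /// ```ignore
    /// let mut encoder = ReversibleVSAEncoder::new();
    /// let chunks = encoder.batch_encode(&data, 64);
    /// let bytes = encoder.batch_decode(&chunks, 64, data.len());
    /// ```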
    pub fn batch_encode(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
        assert!(chunk_size > 0, "chunk_size must be > 0");

        if data.is_empty() {
            return Vec::new();
        }

        // Pre-generate every position vector the workers will need; inputs
        // longer than MAX_POSITIONS reuse position vectors modulo the
        // generated count.
        let _ = self.get_position_vector(0);

        let max_pos = data.len().saturating_sub(1);
        if max_pos < MAX_POSITIONS {
            let _ = self.get_position_vector(max_pos);
        }

        // Shared, immutable views for the parallel workers.
        let byte_vectors = &self.byte_vectors;
        let position_vectors = &self.position_vectors;

        let chunks: Vec<(usize, &[u8])> = data.chunks(chunk_size).enumerate().collect();

        chunks
            .par_iter()
            .map(|&(chunk_idx, chunk)| {
                let offset = chunk_idx * chunk_size;
                Self::encode_chunk_parallel(chunk, offset, byte_vectors, position_vectors)
            })
            .collect()
    }

    /// Stateless per-chunk encoder used by `batch_encode`; mirrors
    /// `encode_with_offset` but reads the basis tables through shared
    /// slices so it can run on worker threads.
    fn encode_chunk_parallel(
        data: &[u8],
        offset: usize,
        byte_vectors: &[SparseVec],
        position_vectors: &[SparseVec],
    ) -> SparseVec {
        if data.is_empty() {
            return SparseVec::new();
        }

        let byte_vec = &byte_vectors[data[0] as usize];
        let pos_vec = &position_vectors[offset % position_vectors.len()];
        let mut result = byte_vec.bind(pos_vec);

        for (i, &byte) in data.iter().enumerate().skip(1) {
            let byte_vec = &byte_vectors[byte as usize];
            let pos_vec = &position_vectors[(offset + i) % position_vectors.len()];
            let encoded_byte = byte_vec.bind(pos_vec);
            result = result.bundle(&encoded_byte);
        }

        result
    }

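    /// Parallel version of `decode_chunked`: each chunk is decoded
    /// independently on the rayon pool, then the per-chunk results are
    /// flattened in order. Output matches `decode_chunked`.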
    pub fn batch_decode(
        &self,
        chunks: &[SparseVec],
        chunk_size: usize,
        total_length: usize,
    ) -> Vec<u8> {
        assert!(chunk_size > 0, "chunk_size must be > 0");

        if chunks.is_empty() || total_length == 0 {
            return Vec::new();
        }

        let byte_vectors = &self.byte_vectors;
        let position_vectors = &self.position_vectors;

        let decoded_chunks: Vec<Vec<u8>> = chunks
            .par_iter()
            .enumerate()
            .map(|(chunk_idx, chunk_vec)| {
                let offset = chunk_idx * chunk_size;
                let remaining = total_length.saturating_sub(offset);
                let this_chunk_size = remaining.min(chunk_size);

                Self::decode_chunk_parallel(
                    chunk_vec,
                    offset,
                    this_chunk_size,
                    byte_vectors,
                    position_vectors,
                )
            })
            .collect();

        decoded_chunks.into_iter().flatten().collect()
    }

    /// Stateless per-chunk decoder used by `batch_decode`; inlines the
    /// nearest-neighbor search so worker threads avoid borrowing `&self`.
    fn decode_chunk_parallel(
        chunk_vec: &SparseVec,
        offset: usize,
        chunk_size: usize,
        byte_vectors: &[SparseVec],
        position_vectors: &[SparseVec],
    ) -> Vec<u8> {
        let mut result = Vec::with_capacity(chunk_size);

        for i in 0..chunk_size {
            let pos = offset + i;
            let pos_vec = &position_vectors[pos % position_vectors.len()];
            let query = chunk_vec.bind(pos_vec);

            let mut best_byte = 0u8;
            let mut best_sim = f64::NEG_INFINITY;

            for (byte_val, byte_vec) in byte_vectors.iter().enumerate() {
                let sim = query.cosine(byte_vec);
                if sim > best_sim {
                    best_sim = sim;
                    best_byte = byte_val as u8;
                }
            }

            result.push(best_byte);
        }

        result
    }

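    /// Converts a byte count and elapsed time into MiB/s. Returns
    /// `f64::INFINITY` when `elapsed_secs` is zero or negative.
    ///
    /// ```ignore
    /// let mib_s = ReversibleVSAEncoder::estimate_throughput(10 * 1024 * 1024, 2.0);
    /// assert!((mib_s - 5.0).abs() < 1e-9);
    /// ```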
    pub fn estimate_throughput(data_size: usize, elapsed_secs: f64) -> f64 {
        if elapsed_secs <= 0.0 {
            return f64::INFINITY;
        }
        let mb = data_size as f64 / (1024.0 * 1024.0);
        mb / elapsed_secs
    }
}

impl Default for ReversibleVSAEncoder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_byte_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();

        for byte in [0u8, 42, 127, 255] {
            let encoded = encoder.encode(&[byte]);
            let decoded = encoder.decode(&encoded, 1);
            assert_eq!(decoded[0], byte, "Failed roundtrip for byte {}", byte);
        }
    }

    #[test]
    fn test_short_string_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"Hello";

        let encoded = encoder.encode(data);
        let decoded = encoder.decode(&encoded, data.len());

        let accuracy = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count() as f64
            / data.len() as f64;

        assert!(accuracy >= 0.5, "Accuracy too low: {}", accuracy);
    }

    #[test]
    fn test_chunked_encoding() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"This is a test of chunked encoding for longer data.";

        let accuracy = encoder.test_accuracy_chunked(data, 8);
        println!("Chunked accuracy: {:.2}%", accuracy * 100.0);

        assert!(
            accuracy >= 0.80,
            "Chunked accuracy {:.1}% is below expected threshold 80%",
            accuracy * 100.0
        );
    }

    #[test]
    fn test_try_ensure_positions_within_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        assert!(encoder.try_ensure_positions(1000).is_ok());
        assert!(encoder.try_ensure_positions(MAX_POSITIONS - 1).is_ok());
    }

    #[test]
    fn test_try_ensure_positions_exceeds_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        assert!(encoder.try_ensure_positions(MAX_POSITIONS).is_err());
        assert!(encoder.try_ensure_positions(MAX_POSITIONS + 1).is_err());
    }

    #[test]
    #[should_panic(expected = "MAX_POSITIONS")]
    fn test_ensure_positions_panics_on_overflow() {
        let mut encoder = ReversibleVSAEncoder::new();
        encoder.ensure_positions(MAX_POSITIONS);
    }

    #[test]
    fn test_batch_encode_matches_sequential() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"This is a test of parallel batch encoding for higher throughput.";
        let chunk_size = 16;

        let sequential_chunks = encoder.encode_chunked(data, chunk_size);
        let parallel_chunks = encoder.batch_encode(data, chunk_size);

        assert_eq!(sequential_chunks.len(), parallel_chunks.len());

        // Compare via decoding, since the chunk vectors themselves may not
        // implement equality.
        for (i, (seq_chunk, par_chunk)) in sequential_chunks
            .iter()
            .zip(parallel_chunks.iter())
            .enumerate()
        {
            let len = chunk_size.min(data.len() - i * chunk_size);
            let seq_decoded = encoder.decode(seq_chunk, len);
            let par_decoded = encoder.decode(par_chunk, len);
            assert_eq!(seq_decoded, par_decoded, "Chunk {} decoded differently", i);
        }
    }

    #[test]
    fn test_batch_decode_matches_sequential() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"Testing parallel batch decode for higher throughput on multi-core systems.";
        let chunk_size = 16;

        let chunks = encoder.batch_encode(data, chunk_size);
        let sequential_decoded = encoder.decode_chunked(&chunks, chunk_size, data.len());
        let parallel_decoded = encoder.batch_decode(&chunks, chunk_size, data.len());

        assert_eq!(sequential_decoded, parallel_decoded);
    }

    #[test]
    fn test_batch_encode_empty() {
        let mut encoder = ReversibleVSAEncoder::new();
        let chunks = encoder.batch_encode(&[], 64);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_batch_decode_empty() {
        let encoder = ReversibleVSAEncoder::new();
        let decoded = encoder.batch_decode(&[], 64, 0);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_batch_encode_accuracy() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"The quick brown fox jumps over the lazy dog. 0123456789!";
        let chunk_size = 16;

        let chunks = encoder.batch_encode(data, chunk_size);
        let decoded = encoder.batch_decode(&chunks, chunk_size, data.len());

        let matches = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count();
        let accuracy = matches as f64 / data.len() as f64;

        println!("Batch encode/decode accuracy: {:.2}%", accuracy * 100.0);
        assert!(
            accuracy >= 0.80,
            "Batch accuracy {:.1}% below expected 80%",
            accuracy * 100.0
        );
    }
}