// embeddenator_vsa/reversible_encoding.rs
1//! Reversible Position-Aware VSA Encoding
2//!
3//! This module implements true holographic storage where data can be
4//! reconstructed from the VSA vector. With chunked encoding, typical
5//! accuracy is 90-95% before correction layer application.
6//!
7//! Chunk size is configurable: smaller chunks (8-64 bytes) provide higher
8//! accuracy per chunk but more overhead; embeddenator-fs uses 64 bytes
9//! as a balance between accuracy and efficiency.
10//!
11//! # Architecture
12//!
13//! ```text
14//! For each byte at position i:
15//!   encoded[i] = bind(position_vector[i], byte_vector[data[i]])
16//!
17//! Memory = bundle(encoded[0], encoded[1], ..., encoded[n])
18//!
19//! To retrieve byte at position i:
20//!   query = bind(position_vector[i], Memory)
21//!   byte = argmax_b(cosine(query, byte_vector[b]))
22//! ```
23//!
24//! This uses the fundamental VSA operations:
25//! - **Bind**: Creates a composite that's dissimilar to both inputs but can be unbound
26//! - **Bundle**: Superimposes vectors while preserving retrievability
27//! - **Unbind**: Reverses bind to retrieve the bound content
28//!
29//! # Why This Works
30//!
31//! - Each (position, byte) pair has a unique representation
32//! - Bundle creates holographic superposition of all pairs
33//! - Unbind + similarity search retrieves the original byte
34//! - No information loss from collisions
35
36use crate::vsa::{SparseVec, DIM};
37use rayon::prelude::*;
38use sha2::{Digest, Sha256};
39
/// Maximum number of distinct position basis vectors (65536).
///
/// With 64-byte chunks this bounds position-exact encoding at 4MB; positions
/// past the allocated table wrap via modulo in the read-only accessors.
pub const MAX_POSITIONS: usize = 65536;
42
/// Reversible encoder using position-aware VSA binding.
///
/// Basis vectors are derived deterministically from SHA-256 seeds, so the
/// same data always encodes to the same vectors across runs and processes.
pub struct ReversibleVSAEncoder {
    /// Basis vectors for each byte value (0-255), indexed by byte value
    byte_vectors: Vec<SparseVec>,
    /// Basis vectors for each position (0 to MAX_POSITIONS-1); grown lazily
    position_vectors: Vec<SparseVec>,
    /// Dimensionality of every basis vector
    dim: usize,
}
52
53impl ReversibleVSAEncoder {
54    /// Create a new reversible encoder
55    pub fn new() -> Self {
56        Self::with_dim(DIM)
57    }
58
59    /// Create a new reversible encoder with custom dimensionality
60    pub fn with_dim(dim: usize) -> Self {
61        let mut encoder = Self {
62            byte_vectors: Vec::with_capacity(256),
63            position_vectors: Vec::with_capacity(MAX_POSITIONS),
64            dim,
65        };
66        encoder.initialize_basis_vectors();
67        encoder
68    }
69
70    /// Initialize basis vectors for bytes and positions
71    fn initialize_basis_vectors(&mut self) {
72        // Create deterministic byte vectors
73        for byte_val in 0u8..=255 {
74            let seed = Self::hash_to_seed(b"byte", &[byte_val]);
75            self.byte_vectors
76                .push(SparseVec::from_seed(&seed, self.dim));
77        }
78
79        // Create deterministic position vectors (lazy - create on demand up to max)
80        // For now, pre-create a reasonable number
81        for pos in 0..4096 {
82            let seed = Self::hash_to_seed(b"position", &(pos as u64).to_le_bytes());
83            self.position_vectors
84                .push(SparseVec::from_seed(&seed, self.dim));
85        }
86    }
87
88    /// Hash a prefix and value to a 32-byte seed
89    fn hash_to_seed(prefix: &[u8], value: &[u8]) -> [u8; 32] {
90        let mut hasher = Sha256::new();
91        hasher.update(b"embeddenator:reversible:v1:");
92        hasher.update(prefix);
93        hasher.update(b":");
94        hasher.update(value);
95        hasher.finalize().into()
96    }
97
98    /// Get or create position vector for a given position
99    ///
100    /// # Panics
101    ///
102    /// Panics if `pos >= MAX_POSITIONS` (65536). Use `ensure_positions` with
103    /// proper bounds checking before calling this method with untrusted input.
104    fn get_position_vector(&mut self, pos: usize) -> &SparseVec {
105        assert!(
106            pos < MAX_POSITIONS,
107            "Position {} exceeds MAX_POSITIONS ({})",
108            pos,
109            MAX_POSITIONS
110        );
111
112        while pos >= self.position_vectors.len() {
113            let new_pos = self.position_vectors.len();
114            let seed = Self::hash_to_seed(b"position", &(new_pos as u64).to_le_bytes());
115            self.position_vectors
116                .push(SparseVec::from_seed(&seed, self.dim));
117        }
118        &self.position_vectors[pos]
119    }
120
121    /// Encode a byte at a specific position
122    ///
123    /// Returns bind(position_vector, byte_vector)
124    fn encode_byte_at_position(&self, byte: u8, position: usize) -> SparseVec {
125        let byte_vec = &self.byte_vectors[byte as usize];
126        let pos_vec = &self.position_vectors[position % self.position_vectors.len()];
127        byte_vec.bind(pos_vec)
128    }
129
130    /// Encode data into a holographic representation
131    ///
132    /// Returns a single SparseVec that contains all the data holographically.
133    pub fn encode(&mut self, data: &[u8]) -> SparseVec {
134        if data.is_empty() {
135            return SparseVec::new();
136        }
137
138        // Ensure we have enough position vectors
139        let _ = self.get_position_vector(data.len().saturating_sub(1));
140
141        // Encode each byte and bundle them together
142        let mut result = self.encode_byte_at_position(data[0], 0);
143
144        for (pos, &byte) in data.iter().enumerate().skip(1) {
145            let encoded_byte = self.encode_byte_at_position(byte, pos);
146            result = result.bundle(&encoded_byte);
147        }
148
149        result
150    }
151
152    /// Encode data in chunks for better accuracy with large files
153    ///
154    /// Returns a vector of SparseVecs, one per chunk.
155    ///
156    /// # Panics
157    /// Panics if `chunk_size` is 0.
158    pub fn encode_chunked(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
159        assert!(chunk_size > 0, "chunk_size must be > 0");
160        data.chunks(chunk_size)
161            .enumerate()
162            .map(|(chunk_idx, chunk)| {
163                let offset = chunk_idx * chunk_size;
164                self.encode_with_offset(chunk, offset)
165            })
166            .collect()
167    }
168
169    /// Encode data with a position offset
170    fn encode_with_offset(&mut self, data: &[u8], offset: usize) -> SparseVec {
171        if data.is_empty() {
172            return SparseVec::new();
173        }
174
175        // Ensure we have enough position vectors
176        let _ = self.get_position_vector(offset + data.len().saturating_sub(1));
177
178        // Encode each byte and bundle
179        let mut result = self.encode_byte_at_position(data[0], offset);
180
181        for (i, &byte) in data.iter().enumerate().skip(1) {
182            let encoded_byte = self.encode_byte_at_position(byte, offset + i);
183            result = result.bundle(&encoded_byte);
184        }
185
186        result
187    }
188
189    /// Decode data from a holographic representation
190    ///
191    /// Uses bind + similarity search to retrieve each byte.
192    pub fn decode(&self, encoded: &SparseVec, length: usize) -> Vec<u8> {
193        let mut result = Vec::with_capacity(length);
194
195        for pos in 0..length {
196            let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
197
198            // Unbind position to get query for byte
199            let query = encoded.bind(pos_vec);
200
201            // Find best matching byte vector
202            let byte = self.find_best_byte_match(&query);
203            result.push(byte);
204        }
205
206        result
207    }
208
209    /// Decode chunked data
210    ///
211    /// # Panics
212    /// Panics if `chunk_size` is 0.
213    pub fn decode_chunked(
214        &self,
215        chunks: &[SparseVec],
216        chunk_size: usize,
217        total_length: usize,
218    ) -> Vec<u8> {
219        assert!(chunk_size > 0, "chunk_size must be > 0");
220        let mut result = Vec::with_capacity(total_length);
221
222        for (chunk_idx, chunk_vec) in chunks.iter().enumerate() {
223            let offset = chunk_idx * chunk_size;
224            let remaining = total_length.saturating_sub(offset);
225            let this_chunk_size = remaining.min(chunk_size);
226
227            for i in 0..this_chunk_size {
228                let pos = offset + i;
229                let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
230                let query = chunk_vec.bind(pos_vec);
231                let byte = self.find_best_byte_match(&query);
232                result.push(byte);
233            }
234        }
235
236        result
237    }
238
239    /// Find the byte value with highest similarity to query
240    fn find_best_byte_match(&self, query: &SparseVec) -> u8 {
241        let mut best_byte = 0u8;
242        let mut best_sim = f64::NEG_INFINITY;
243
244        for (byte_val, byte_vec) in self.byte_vectors.iter().enumerate() {
245            let sim = query.cosine(byte_vec);
246            if sim > best_sim {
247                best_sim = sim;
248                best_byte = byte_val as u8;
249            }
250        }
251
252        best_byte
253    }
254
255    /// Get reference to byte vectors (for GPU acceleration)
256    ///
257    /// Returns slice of 256 basis vectors, one per byte value.
258    pub fn get_byte_vectors(&self) -> &[SparseVec] {
259        &self.byte_vectors
260    }
261
262    /// Get position vector for a given position (for GPU acceleration)
263    ///
264    /// Returns reference to position basis vector. Uses modulo if position
265    /// exceeds pre-allocated vectors.
266    ///
267    /// Note: Call `ensure_positions(max_pos)` first if you need exact position
268    /// vectors without modulo wrapping.
269    pub fn get_position_vector_ref(&self, pos: usize) -> &SparseVec {
270        &self.position_vectors[pos % self.position_vectors.len()]
271    }
272
273    /// Ensure position vectors exist up to (and including) the given position
274    ///
275    /// # Panics
276    ///
277    /// Panics if `max_pos >= MAX_POSITIONS` (65536).
278    ///
279    /// # Example
280    ///
281    /// ```rust,ignore
282    /// let mut encoder = ReversibleVSAEncoder::new();
283    /// encoder.ensure_positions(1000); // OK
284    /// // encoder.ensure_positions(100000); // Would panic
285    /// ```
286    pub fn ensure_positions(&mut self, max_pos: usize) {
287        assert!(
288            max_pos < MAX_POSITIONS,
289            "max_pos {} exceeds MAX_POSITIONS ({})",
290            max_pos,
291            MAX_POSITIONS
292        );
293        let _ = self.get_position_vector(max_pos);
294    }
295
296    /// Ensure position vectors exist up to the given position, returning error on invalid input
297    ///
298    /// This is a non-panicking alternative to `ensure_positions`.
299    ///
300    /// # Returns
301    ///
302    /// Returns `Ok(())` if positions are allocated successfully, or `Err` if
303    /// `max_pos >= MAX_POSITIONS`.
304    pub fn try_ensure_positions(&mut self, max_pos: usize) -> Result<(), String> {
305        if max_pos >= MAX_POSITIONS {
306            return Err(format!(
307                "max_pos {} exceeds MAX_POSITIONS ({})",
308                max_pos, MAX_POSITIONS
309            ));
310        }
311        let _ = self.get_position_vector(max_pos);
312        Ok(())
313    }
314
315    /// Compute reconstruction accuracy (for testing)
316    pub fn test_accuracy(&mut self, data: &[u8]) -> f64 {
317        let encoded = self.encode(data);
318        let decoded = self.decode(&encoded, data.len());
319
320        let matches = data
321            .iter()
322            .zip(decoded.iter())
323            .filter(|(a, b)| a == b)
324            .count();
325
326        matches as f64 / data.len() as f64
327    }
328
329    /// Compute chunked reconstruction accuracy
330    pub fn test_accuracy_chunked(&mut self, data: &[u8], chunk_size: usize) -> f64 {
331        let chunks = self.encode_chunked(data, chunk_size);
332        let decoded = self.decode_chunked(&chunks, chunk_size, data.len());
333
334        let matches = data
335            .iter()
336            .zip(decoded.iter())
337            .filter(|(a, b)| a == b)
338            .count();
339
340        matches as f64 / data.len() as f64
341    }
342
343    /// Batch encode data in parallel using rayon
344    ///
345    /// This method achieves higher throughput by processing chunks in parallel.
346    /// Position vectors are pre-allocated before parallel processing.
347    ///
348    /// # Arguments
349    /// * `data` - Raw bytes to encode
350    /// * `chunk_size` - Size of each chunk (64 bytes recommended)
351    ///
352    /// # Returns
353    /// Vector of encoded chunks, one `SparseVec` per chunk
354    ///
355    /// # Panics
356    /// Panics if `chunk_size` is 0.
357    ///
358    /// # Example
359    /// ```rust,ignore
360    /// let mut encoder = ReversibleVSAEncoder::new();
361    /// let chunks = encoder.batch_encode(&large_data, 64);
362    /// ```
363    pub fn batch_encode(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
364        assert!(chunk_size > 0, "chunk_size must be > 0");
365
366        if data.is_empty() {
367            return Vec::new();
368        }
369
370        // Ensure at least one position vector exists to avoid modulo by zero
371        let _ = self.get_position_vector(0);
372
373        // Pre-ensure all position vectors exist (sequential, but fast)
374        let max_pos = data.len().saturating_sub(1);
375        if max_pos < MAX_POSITIONS {
376            let _ = self.get_position_vector(max_pos);
377        }
378
379        // Use shared references for parallel access (no cloning/allocation)
380        // SparseVec is Sync+Send for read access, so shared refs work with Rayon
381        let byte_vectors = &self.byte_vectors;
382        let position_vectors = &self.position_vectors;
383
384        // Process chunks in parallel
385        let chunks: Vec<(usize, &[u8])> = data.chunks(chunk_size).enumerate().collect();
386
387        chunks
388            .par_iter()
389            .map(|(chunk_idx, chunk)| {
390                let offset = chunk_idx * chunk_size;
391                Self::encode_chunk_parallel(chunk, offset, byte_vectors, position_vectors)
392            })
393            .collect()
394    }
395
396    /// Encode a single chunk (used by parallel encoding)
397    fn encode_chunk_parallel(
398        data: &[u8],
399        offset: usize,
400        byte_vectors: &[SparseVec],
401        position_vectors: &[SparseVec],
402    ) -> SparseVec {
403        if data.is_empty() {
404            return SparseVec::new();
405        }
406
407        // Encode first byte
408        let byte_vec = &byte_vectors[data[0] as usize];
409        let pos_vec = &position_vectors[offset % position_vectors.len()];
410        let mut result = byte_vec.bind(pos_vec);
411
412        // Encode remaining bytes and bundle
413        for (i, &byte) in data.iter().enumerate().skip(1) {
414            let byte_vec = &byte_vectors[byte as usize];
415            let pos_vec = &position_vectors[(offset + i) % position_vectors.len()];
416            let encoded_byte = byte_vec.bind(pos_vec);
417            result = result.bundle(&encoded_byte);
418        }
419
420        result
421    }
422
423    /// Batch decode chunks in parallel using rayon
424    ///
425    /// This method achieves higher throughput by decoding chunks in parallel.
426    ///
427    /// # Arguments
428    /// * `chunks` - Encoded chunks to decode
429    /// * `chunk_size` - Size of each chunk (must match encoding chunk_size)
430    /// * `total_length` - Total expected output length
431    ///
432    /// # Returns
433    /// Reconstructed bytes
434    ///
435    /// # Panics
436    /// Panics if `chunk_size` is 0.
437    pub fn batch_decode(
438        &self,
439        chunks: &[SparseVec],
440        chunk_size: usize,
441        total_length: usize,
442    ) -> Vec<u8> {
443        assert!(chunk_size > 0, "chunk_size must be > 0");
444
445        if chunks.is_empty() || total_length == 0 {
446            return Vec::new();
447        }
448
449        // Use shared references for parallel access (no cloning/allocation)
450        // SparseVec is Sync+Send for read access, so shared refs work with Rayon
451        let byte_vectors = &self.byte_vectors;
452        let position_vectors = &self.position_vectors;
453
454        // Decode chunks in parallel
455        let decoded_chunks: Vec<Vec<u8>> = chunks
456            .par_iter()
457            .enumerate()
458            .map(|(chunk_idx, chunk_vec)| {
459                let offset = chunk_idx * chunk_size;
460                let remaining = total_length.saturating_sub(offset);
461                let this_chunk_size = remaining.min(chunk_size);
462
463                Self::decode_chunk_parallel(
464                    chunk_vec,
465                    offset,
466                    this_chunk_size,
467                    byte_vectors,
468                    position_vectors,
469                )
470            })
471            .collect();
472
473        // Flatten results
474        decoded_chunks.into_iter().flatten().collect()
475    }
476
477    /// Decode a single chunk (used by parallel decoding)
478    fn decode_chunk_parallel(
479        chunk_vec: &SparseVec,
480        offset: usize,
481        chunk_size: usize,
482        byte_vectors: &[SparseVec],
483        position_vectors: &[SparseVec],
484    ) -> Vec<u8> {
485        let mut result = Vec::with_capacity(chunk_size);
486
487        for i in 0..chunk_size {
488            let pos = offset + i;
489            let pos_vec = &position_vectors[pos % position_vectors.len()];
490            let query = chunk_vec.bind(pos_vec);
491
492            // Find best matching byte
493            let mut best_byte = 0u8;
494            let mut best_sim = f64::NEG_INFINITY;
495
496            for (byte_val, byte_vec) in byte_vectors.iter().enumerate() {
497                let sim = query.cosine(byte_vec);
498                if sim > best_sim {
499                    best_sim = sim;
500                    best_byte = byte_val as u8;
501                }
502            }
503
504            result.push(best_byte);
505        }
506
507        result
508    }
509
510    /// Get throughput statistics for batch encoding
511    ///
512    /// Returns approximate throughput in MB/s based on data size and elapsed time.
513    /// Useful for benchmarking and optimization.
514    ///
515    /// Returns `f64::INFINITY` if elapsed_secs is zero or negative.
516    pub fn estimate_throughput(data_size: usize, elapsed_secs: f64) -> f64 {
517        if elapsed_secs <= 0.0 {
518            return f64::INFINITY;
519        }
520        let mb = data_size as f64 / (1024.0 * 1024.0);
521        mb / elapsed_secs
522    }
523}
524
525impl Default for ReversibleVSAEncoder {
526    fn default() -> Self {
527        Self::new()
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// A single byte at position 0 must always decode back exactly.
    #[test]
    fn test_single_byte_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();

        for byte in [0u8, 42, 127, 255] {
            let encoded = encoder.encode(&[byte]);
            let decoded = encoder.decode(&encoded, 1);
            assert_eq!(decoded[0], byte, "Failed roundtrip for byte {}", byte);
        }
    }

    /// Unchunked encoding of a short string; bundling noise means only a
    /// loose >= 50% byte-accuracy bound is asserted here.
    #[test]
    fn test_short_string_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"Hello";

        let encoded = encoder.encode(data);
        let decoded = encoder.decode(&encoded, data.len());

        // Check accuracy
        let accuracy = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count() as f64
            / data.len() as f64;

        assert!(accuracy >= 0.5, "Accuracy too low: {}", accuracy);
    }

    /// Chunked encoding should recover most bytes of a longer input.
    #[test]
    fn test_chunked_encoding() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"This is a test of chunked encoding for longer data.";

        let accuracy = encoder.test_accuracy_chunked(data, 8);
        println!("Chunked accuracy: {:.2}%", accuracy * 100.0);

        // Chunked encoding typically achieves 90-95% accuracy, but there's natural
        // variance depending on data content (some byte patterns are more distinct
        // in the VSA representation). Use 80% threshold to accommodate variance
        // while still catching regressions. The correction layer handles any
        // remaining errors in production use.
        assert!(
            accuracy >= 0.80,
            "Chunked accuracy {:.1}% is below expected threshold 80%",
            accuracy * 100.0
        );
    }

    /// Positions strictly below MAX_POSITIONS are always allocatable.
    #[test]
    fn test_try_ensure_positions_within_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        // Should succeed within limit
        assert!(encoder.try_ensure_positions(1000).is_ok());
        assert!(encoder.try_ensure_positions(MAX_POSITIONS - 1).is_ok());
    }

    /// Positions at or above MAX_POSITIONS are rejected with an Err.
    #[test]
    fn test_try_ensure_positions_exceeds_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        // Should fail when exceeding limit
        assert!(encoder.try_ensure_positions(MAX_POSITIONS).is_err());
        assert!(encoder.try_ensure_positions(MAX_POSITIONS + 1).is_err());
    }

    /// The panicking variant must abort with a message naming MAX_POSITIONS.
    #[test]
    #[should_panic(expected = "MAX_POSITIONS")]
    fn test_ensure_positions_panics_on_overflow() {
        let mut encoder = ReversibleVSAEncoder::new();
        encoder.ensure_positions(MAX_POSITIONS);
    }

    /// Parallel batch encoding must be equivalent to sequential chunked
    /// encoding chunk-for-chunk.
    ///
    /// NOTE(review): `decode` here uses positions 0..len rather than the
    /// chunk's encoding offset, so for chunks after the first the decoded
    /// bytes are position-mismatched on BOTH sides. The equality check is
    /// still a valid sequential-vs-parallel comparison, but it does not
    /// verify roundtrip accuracy — confirm whether offset-aware decoding was
    /// intended.
    #[test]
    fn test_batch_encode_matches_sequential() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"This is a test of parallel batch encoding for higher throughput.";
        let chunk_size = 16;

        // Sequential encoding
        let sequential_chunks = encoder.encode_chunked(data, chunk_size);

        // Parallel batch encoding
        let parallel_chunks = encoder.batch_encode(data, chunk_size);

        // Should produce same number of chunks
        assert_eq!(sequential_chunks.len(), parallel_chunks.len());

        // Each chunk should produce same decoded output
        for (i, (seq_chunk, par_chunk)) in sequential_chunks
            .iter()
            .zip(parallel_chunks.iter())
            .enumerate()
        {
            let seq_decoded =
                encoder.decode(seq_chunk, chunk_size.min(data.len() - i * chunk_size));
            let par_decoded =
                encoder.decode(par_chunk, chunk_size.min(data.len() - i * chunk_size));
            assert_eq!(seq_decoded, par_decoded, "Chunk {} decoded differently", i);
        }
    }

    /// Parallel batch decode must produce exactly the sequential result.
    #[test]
    fn test_batch_decode_matches_sequential() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"Testing parallel batch decode for higher throughput on multi-core systems.";
        let chunk_size = 16;

        // Encode
        let chunks = encoder.batch_encode(data, chunk_size);

        // Sequential decode
        let sequential_decoded = encoder.decode_chunked(&chunks, chunk_size, data.len());

        // Parallel batch decode
        let parallel_decoded = encoder.batch_decode(&chunks, chunk_size, data.len());

        // Should match exactly
        assert_eq!(sequential_decoded, parallel_decoded);
    }

    /// Empty input produces no chunks.
    #[test]
    fn test_batch_encode_empty() {
        let mut encoder = ReversibleVSAEncoder::new();
        let chunks = encoder.batch_encode(&[], 64);
        assert!(chunks.is_empty());
    }

    /// Empty chunk list produces no output bytes.
    #[test]
    fn test_batch_decode_empty() {
        let encoder = ReversibleVSAEncoder::new();
        let decoded = encoder.batch_decode(&[], 64, 0);
        assert!(decoded.is_empty());
    }

    /// Full parallel encode/decode roundtrip should recover most bytes.
    #[test]
    fn test_batch_encode_accuracy() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"The quick brown fox jumps over the lazy dog. 0123456789!";
        let chunk_size = 16;

        let chunks = encoder.batch_encode(data, chunk_size);
        let decoded = encoder.batch_decode(&chunks, chunk_size, data.len());

        let matches = data
            .iter()
            .zip(decoded.iter())
            .filter(|(a, b)| a == b)
            .count();
        let accuracy = matches as f64 / data.len() as f64;

        println!("Batch encode/decode accuracy: {:.2}%", accuracy * 100.0);
        assert!(
            accuracy >= 0.80,
            "Batch accuracy {:.1}% below expected 80%",
            accuracy * 100.0
        );
    }
}