Skip to main content

embeddenator_vsa/
reversible_encoding.rs

1//! Reversible Position-Aware VSA Encoding
2//!
3//! This module implements true holographic storage where data can be
4//! reconstructed from the VSA vector. With chunked encoding, typical
5//! accuracy is 90-95% before correction layer application.
6//!
7//! Chunk size is configurable: smaller chunks (8-64 bytes) provide higher
8//! accuracy per chunk but more overhead; embeddenator-fs uses 64 bytes
9//! as a balance between accuracy and efficiency.
10//!
11//! # Architecture
12//!
13//! ```text
14//! For each byte at position i:
15//!   encoded[i] = bind(position_vector[i], byte_vector[data[i]])
16//!
17//! Memory = bundle(encoded[0], encoded[1], ..., encoded[n])
18//!
19//! To retrieve byte at position i:
20//!   query = bind(position_vector[i], Memory)
21//!   byte = argmax_b(cosine(query, byte_vector[b]))
22//! ```
23//!
24//! This uses the fundamental VSA operations:
25//! - **Bind**: Creates a composite that's dissimilar to both inputs but can be unbound
26//! - **Bundle**: Superimposes vectors while preserving retrievability
27//! - **Unbind**: Reverses bind to retrieve the bound content
28//!
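//! # Why This Works
//!
//! - Each (position, byte) pair is bound into its own composite vector
//! - Bundle creates a holographic superposition of all pairs
//! - Unbind + similarity search retrieves the best-matching byte
//! - Distinct positions use distinct basis vectors, so bytes do not collide
//!   with one another; reconstruction is still approximate (see the accuracy
//!   note above)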
35
36use crate::vsa::{SparseVec, DIM};
37use sha2::{Digest, Sha256};
38
/// Exclusive upper bound on the position indices the encoder will create.
///
/// Position vectors are indexed by absolute byte offset, so `encode` and
/// `encode_chunked` can address at most 65,536 bytes (64 KiB) of input.
/// NOTE(review): a 4 MB capacity (65,536 x 64-byte chunks) would require
/// callers to restart positions for every chunk — confirm against
/// embeddenator-fs.
pub const MAX_POSITIONS: usize = 1 << 16;
41
42/// Reversible encoder using position-aware VSA binding
43pub struct ReversibleVSAEncoder {
44    /// Basis vectors for each byte value (0-255)
45    byte_vectors: Vec<SparseVec>,
46    /// Basis vectors for each position (0 to MAX_POSITIONS-1)
47    position_vectors: Vec<SparseVec>,
48    /// Dimensionality
49    dim: usize,
50}
51
52impl ReversibleVSAEncoder {
53    /// Create a new reversible encoder
54    pub fn new() -> Self {
55        Self::with_dim(DIM)
56    }
57
58    /// Create a new reversible encoder with custom dimensionality
59    pub fn with_dim(dim: usize) -> Self {
60        let mut encoder = Self {
61            byte_vectors: Vec::with_capacity(256),
62            position_vectors: Vec::with_capacity(MAX_POSITIONS),
63            dim,
64        };
65        encoder.initialize_basis_vectors();
66        encoder
67    }
68
69    /// Initialize basis vectors for bytes and positions
70    fn initialize_basis_vectors(&mut self) {
71        // Create deterministic byte vectors
72        for byte_val in 0u8..=255 {
73            let seed = Self::hash_to_seed(b"byte", &[byte_val]);
74            self.byte_vectors
75                .push(SparseVec::from_seed(&seed, self.dim));
76        }
77
78        // Create deterministic position vectors (lazy - create on demand up to max)
79        // For now, pre-create a reasonable number
80        for pos in 0..4096 {
81            let seed = Self::hash_to_seed(b"position", &(pos as u64).to_le_bytes());
82            self.position_vectors
83                .push(SparseVec::from_seed(&seed, self.dim));
84        }
85    }
86
87    /// Hash a prefix and value to a 32-byte seed
88    fn hash_to_seed(prefix: &[u8], value: &[u8]) -> [u8; 32] {
89        let mut hasher = Sha256::new();
90        hasher.update(b"embeddenator:reversible:v1:");
91        hasher.update(prefix);
92        hasher.update(b":");
93        hasher.update(value);
94        hasher.finalize().into()
95    }
96
97    /// Get or create position vector for a given position
98    ///
99    /// # Panics
100    ///
101    /// Panics if `pos >= MAX_POSITIONS` (65536). Use `ensure_positions` with
102    /// proper bounds checking before calling this method with untrusted input.
103    fn get_position_vector(&mut self, pos: usize) -> &SparseVec {
104        assert!(
105            pos < MAX_POSITIONS,
106            "Position {} exceeds MAX_POSITIONS ({})",
107            pos,
108            MAX_POSITIONS
109        );
110
111        while pos >= self.position_vectors.len() {
112            let new_pos = self.position_vectors.len();
113            let seed = Self::hash_to_seed(b"position", &(new_pos as u64).to_le_bytes());
114            self.position_vectors
115                .push(SparseVec::from_seed(&seed, self.dim));
116        }
117        &self.position_vectors[pos]
118    }
119
120    /// Encode a byte at a specific position
121    ///
122    /// Returns bind(position_vector, byte_vector)
123    fn encode_byte_at_position(&self, byte: u8, position: usize) -> SparseVec {
124        let byte_vec = &self.byte_vectors[byte as usize];
125        let pos_vec = &self.position_vectors[position % self.position_vectors.len()];
126        byte_vec.bind(pos_vec)
127    }
128
129    /// Encode data into a holographic representation
130    ///
131    /// Returns a single SparseVec that contains all the data holographically.
132    pub fn encode(&mut self, data: &[u8]) -> SparseVec {
133        if data.is_empty() {
134            return SparseVec::new();
135        }
136
137        // Ensure we have enough position vectors
138        let _ = self.get_position_vector(data.len().saturating_sub(1));
139
140        // Encode each byte and bundle them together
141        let mut result = self.encode_byte_at_position(data[0], 0);
142
143        for (pos, &byte) in data.iter().enumerate().skip(1) {
144            let encoded_byte = self.encode_byte_at_position(byte, pos);
145            result = result.bundle(&encoded_byte);
146        }
147
148        result
149    }
150
151    /// Encode data in chunks for better accuracy with large files
152    ///
153    /// Returns a vector of SparseVecs, one per chunk.
154    ///
155    /// # Panics
156    /// Panics if `chunk_size` is 0.
157    pub fn encode_chunked(&mut self, data: &[u8], chunk_size: usize) -> Vec<SparseVec> {
158        assert!(chunk_size > 0, "chunk_size must be > 0");
159        data.chunks(chunk_size)
160            .enumerate()
161            .map(|(chunk_idx, chunk)| {
162                let offset = chunk_idx * chunk_size;
163                self.encode_with_offset(chunk, offset)
164            })
165            .collect()
166    }
167
168    /// Encode data with a position offset
169    fn encode_with_offset(&mut self, data: &[u8], offset: usize) -> SparseVec {
170        if data.is_empty() {
171            return SparseVec::new();
172        }
173
174        // Ensure we have enough position vectors
175        let _ = self.get_position_vector(offset + data.len().saturating_sub(1));
176
177        // Encode each byte and bundle
178        let mut result = self.encode_byte_at_position(data[0], offset);
179
180        for (i, &byte) in data.iter().enumerate().skip(1) {
181            let encoded_byte = self.encode_byte_at_position(byte, offset + i);
182            result = result.bundle(&encoded_byte);
183        }
184
185        result
186    }
187
188    /// Decode data from a holographic representation
189    ///
190    /// Uses bind + similarity search to retrieve each byte.
191    pub fn decode(&self, encoded: &SparseVec, length: usize) -> Vec<u8> {
192        let mut result = Vec::with_capacity(length);
193
194        for pos in 0..length {
195            let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
196
197            // Unbind position to get query for byte
198            let query = encoded.bind(pos_vec);
199
200            // Find best matching byte vector
201            let byte = self.find_best_byte_match(&query);
202            result.push(byte);
203        }
204
205        result
206    }
207
208    /// Decode chunked data
209    ///
210    /// # Panics
211    /// Panics if `chunk_size` is 0.
212    pub fn decode_chunked(
213        &self,
214        chunks: &[SparseVec],
215        chunk_size: usize,
216        total_length: usize,
217    ) -> Vec<u8> {
218        assert!(chunk_size > 0, "chunk_size must be > 0");
219        let mut result = Vec::with_capacity(total_length);
220
221        for (chunk_idx, chunk_vec) in chunks.iter().enumerate() {
222            let offset = chunk_idx * chunk_size;
223            let remaining = total_length.saturating_sub(offset);
224            let this_chunk_size = remaining.min(chunk_size);
225
226            for i in 0..this_chunk_size {
227                let pos = offset + i;
228                let pos_vec = &self.position_vectors[pos % self.position_vectors.len()];
229                let query = chunk_vec.bind(pos_vec);
230                let byte = self.find_best_byte_match(&query);
231                result.push(byte);
232            }
233        }
234
235        result
236    }
237
238    /// Find the byte value with highest similarity to query
239    fn find_best_byte_match(&self, query: &SparseVec) -> u8 {
240        let mut best_byte = 0u8;
241        let mut best_sim = f64::NEG_INFINITY;
242
243        for (byte_val, byte_vec) in self.byte_vectors.iter().enumerate() {
244            let sim = query.cosine(byte_vec);
245            if sim > best_sim {
246                best_sim = sim;
247                best_byte = byte_val as u8;
248            }
249        }
250
251        best_byte
252    }
253
254    /// Get reference to byte vectors (for GPU acceleration)
255    ///
256    /// Returns slice of 256 basis vectors, one per byte value.
257    pub fn get_byte_vectors(&self) -> &[SparseVec] {
258        &self.byte_vectors
259    }
260
261    /// Get position vector for a given position (for GPU acceleration)
262    ///
263    /// Returns reference to position basis vector. Uses modulo if position
264    /// exceeds pre-allocated vectors.
265    ///
266    /// Note: Call `ensure_positions(max_pos)` first if you need exact position
267    /// vectors without modulo wrapping.
268    pub fn get_position_vector_ref(&self, pos: usize) -> &SparseVec {
269        &self.position_vectors[pos % self.position_vectors.len()]
270    }
271
272    /// Ensure position vectors exist up to (and including) the given position
273    ///
274    /// # Panics
275    ///
276    /// Panics if `max_pos >= MAX_POSITIONS` (65536).
277    ///
278    /// # Example
279    ///
280    /// ```rust,ignore
281    /// let mut encoder = ReversibleVSAEncoder::new();
282    /// encoder.ensure_positions(1000); // OK
283    /// // encoder.ensure_positions(100000); // Would panic
284    /// ```
285    pub fn ensure_positions(&mut self, max_pos: usize) {
286        assert!(
287            max_pos < MAX_POSITIONS,
288            "max_pos {} exceeds MAX_POSITIONS ({})",
289            max_pos,
290            MAX_POSITIONS
291        );
292        let _ = self.get_position_vector(max_pos);
293    }
294
295    /// Ensure position vectors exist up to the given position, returning error on invalid input
296    ///
297    /// This is a non-panicking alternative to `ensure_positions`.
298    ///
299    /// # Returns
300    ///
301    /// Returns `Ok(())` if positions are allocated successfully, or `Err` if
302    /// `max_pos >= MAX_POSITIONS`.
303    pub fn try_ensure_positions(&mut self, max_pos: usize) -> Result<(), String> {
304        if max_pos >= MAX_POSITIONS {
305            return Err(format!(
306                "max_pos {} exceeds MAX_POSITIONS ({})",
307                max_pos, MAX_POSITIONS
308            ));
309        }
310        let _ = self.get_position_vector(max_pos);
311        Ok(())
312    }
313
314    /// Compute reconstruction accuracy (for testing)
315    pub fn test_accuracy(&mut self, data: &[u8]) -> f64 {
316        let encoded = self.encode(data);
317        let decoded = self.decode(&encoded, data.len());
318
319        let matches = data
320            .iter()
321            .zip(decoded.iter())
322            .filter(|(a, b)| a == b)
323            .count();
324
325        matches as f64 / data.len() as f64
326    }
327
328    /// Compute chunked reconstruction accuracy
329    pub fn test_accuracy_chunked(&mut self, data: &[u8], chunk_size: usize) -> f64 {
330        let chunks = self.encode_chunked(data, chunk_size);
331        let decoded = self.decode_chunked(&chunks, chunk_size, data.len());
332
333        let matches = data
334            .iter()
335            .zip(decoded.iter())
336            .filter(|(a, b)| a == b)
337            .count();
338
339        matches as f64 / data.len() as f64
340    }
341}
342
343impl Default for ReversibleVSAEncoder {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Fraction of positions at which `expected` and `actual` hold the same
    /// byte, relative to the length of `expected`.
    fn matching_fraction(expected: &[u8], actual: &[u8]) -> f64 {
        let hits = expected
            .iter()
            .zip(actual.iter())
            .filter(|(want, got)| want == got)
            .count();
        hits as f64 / expected.len() as f64
    }

    /// A single byte must survive an encode/decode round trip exactly.
    #[test]
    fn test_single_byte_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();

        for &byte in &[0u8, 42, 127, 255] {
            let encoded = encoder.encode(&[byte]);
            let decoded = encoder.decode(&encoded, 1);
            assert_eq!(decoded[0], byte, "Failed roundtrip for byte {}", byte);
        }
    }

    /// A short string must decode with at least half of its bytes correct.
    #[test]
    fn test_short_string_roundtrip() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"Hello";

        let encoded = encoder.encode(data);
        let decoded = encoder.decode(&encoded, data.len());

        let accuracy = matching_fraction(data, &decoded);
        assert!(accuracy >= 0.5, "Accuracy too low: {}", accuracy);
    }

    /// Chunked encoding with 8-byte chunks must stay above the regression
    /// threshold.
    #[test]
    fn test_chunked_encoding() {
        let mut encoder = ReversibleVSAEncoder::new();
        let data = b"This is a test of chunked encoding for longer data.";

        let accuracy = encoder.test_accuracy_chunked(data, 8);
        println!("Chunked accuracy: {:.2}%", accuracy * 100.0);

        // Chunked encoding typically reaches 90-95% accuracy, with natural
        // variance depending on the data content. The 80% threshold leaves
        // room for that variance while still catching regressions; the
        // correction layer handles the remaining errors in production use.
        assert!(
            accuracy >= 0.80,
            "Chunked accuracy {:.1}% is below expected threshold 80%",
            accuracy * 100.0
        );
    }

    /// Positions below `MAX_POSITIONS` are accepted.
    #[test]
    fn test_try_ensure_positions_within_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        for &pos in &[1000, MAX_POSITIONS - 1] {
            assert!(encoder.try_ensure_positions(pos).is_ok());
        }
    }

    /// Positions at or above `MAX_POSITIONS` are rejected with an error.
    #[test]
    fn test_try_ensure_positions_exceeds_limit() {
        let mut encoder = ReversibleVSAEncoder::new();
        for &pos in &[MAX_POSITIONS, MAX_POSITIONS + 1] {
            assert!(encoder.try_ensure_positions(pos).is_err());
        }
    }

    /// The panicking variant must panic for the first out-of-range position.
    #[test]
    #[should_panic(expected = "MAX_POSITIONS")]
    fn test_ensure_positions_panics_on_overflow() {
        let mut encoder = ReversibleVSAEncoder::new();
        encoder.ensure_positions(MAX_POSITIONS);
    }
}