embeddenator_fs/fs/correction.rs

//! Algebraic Correction Layer - Guaranteeing 100% Bitwise Reconstruction
//!
//! The fundamental challenge: VSA operations (bundle, bind) are inherently
//! approximate when superposing multiple vectors. This module provides the
//! mathematical machinery to guarantee bit-perfect reconstruction.
//!
//! # The Problem
//!
//! When you bundle N vectors: R = V₁ ⊕ V₂ ⊕ ... ⊕ Vₙ
//!
//! And then query: Q = R ⊙ Vᵢ⁻¹ (unbind to retrieve Vᵢ)
//!
//! You get: Q ≈ Vᵢ + noise (crosstalk from the other vectors)
//!
//! The similarity cos(Q, Vᵢ) decreases as N increases (more crosstalk).
//!
//! # The Solution: Multi-Layer Correction
//!
//! 1. **Codebook Lookup** (not similarity): if a pattern is in the codebook,
//!    retrieve the EXACT original, not an approximate match.
//!
//! 2. **Residual Storage**: for anything not in the codebook, store the exact
//!    difference between the approximation and the original.
//!
//! 3. **Semantic Markers**: high-entropy regions that cannot be approximated
//!    well are stored verbatim with markers.
//!
//! 4. **Parity Verification**: detect when the approximation has errors,
//!    triggering residual application.
//!
//! # Mathematical Guarantee
//!
//! Let D = original data, E = encoded approximation, R = residual.
//!
//! Invariant: D = decode(E) + R (always, by construction)
//!
//! If decode(E) = D, then R = 0 (no storage needed).
//! If decode(E) ≠ D, then R = D - decode(E) (exact correction stored).
//!
//! Either way, D is perfectly recoverable.
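//!
//! # Example
//!
//! A minimal sketch of the intended round trip using the types in this module
//! (illustrative only; chunking and the VSA encode/decode pipeline are assumed
//! to live elsewhere):
//!
//! ```ignore
//! // `original` is the source chunk; `approximation` is what decoding the
//! // engram produced for that chunk.
//! let original = b"hello world";
//! let mut approximation = original.to_vec();
//! approximation[0] ^= 0x01; // decoding introduced a one-bit error
//!
//! // Record the exact residual needed to repair the approximation...
//! let correction = ChunkCorrection::new(0, original, &approximation);
//!
//! // ...and later reapply it to recover the original bit-for-bit.
//! let recovered = correction.apply(&approximation);
//! assert_eq!(recovered, original.to_vec());
//! assert!(correction.verify(&recovered));
//! ```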

use embeddenator_vsa::Trit;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;

/// Correction type for different error scenarios
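///
/// For example, `BitFlips` stores `(byte_position, xor_mask)` pairs; applying
/// one XORs the mask into the byte at that position (a sketch of the
/// convention used by [`ChunkCorrection::apply`]):
///
/// ```ignore
/// // One corrupted byte at offset 0: original 'h' (0x68) decoded as 'i' (0x69).
/// let correction = CorrectionType::BitFlips(vec![(0, 0x68 ^ 0x69)]);
/// // Applying it performs `approximation[0] ^= 0x01`, restoring the byte.
/// ```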
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CorrectionType {
    /// No correction needed - exact match
    None,
    /// Byte-level XOR flips at specific positions: (position, xor_mask)
    BitFlips(Vec<(u64, u8)>),
    /// Trit flip at specific positions
    TritFlips(Vec<(u64, Trit, Trit)>), // position, was, should_be
    /// Block replacement
    BlockReplace { offset: u64, original: Vec<u8> },
    /// Full data (for high-entropy regions)
    Verbatim(Vec<u8>),
}

/// A correction record for a data chunk
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkCorrection {
    /// Chunk identifier
    pub chunk_id: u64,
    /// Type of correction needed
    pub correction: CorrectionType,
    /// Verification hash (first 8 bytes of SHA256)
    pub hash: [u8; 8],
    /// Parity trit for the chunk
    pub parity: Trit,
}

impl ChunkCorrection {
    /// Create a correction record
    pub fn new(chunk_id: u64, original: &[u8], approximation: &[u8]) -> Self {
        let hash = compute_hash(original);
        let parity = compute_data_parity(original);

        let correction = compute_correction(original, approximation);

        ChunkCorrection {
            chunk_id,
            correction,
            hash,
            parity,
        }
    }

    /// Check if correction is needed
    pub fn needs_correction(&self) -> bool {
        !matches!(self.correction, CorrectionType::None)
    }

    /// Apply correction to approximation to get original
    pub fn apply(&self, approximation: &[u8]) -> Vec<u8> {
        match &self.correction {
            CorrectionType::None => approximation.to_vec(),

            CorrectionType::BitFlips(flips) => {
                let mut result = approximation.to_vec();
                for &(pos, mask) in flips {
                    if (pos as usize) < result.len() {
                        result[pos as usize] ^= mask;
                    }
                }
                result
            }

            CorrectionType::TritFlips(flips) => {
                // Apply trit corrections on the packed byte representation.
                // Simplified packing: 4 trits per byte at 2 bits per trit
                // (a real impl would unpack/repack the trit encoding properly).
                let mut result = approximation.to_vec();
                for &(pos, _was, should_be) in flips {
                    // Trit position to byte position
                    let byte_pos = (pos / 4) as usize;
                    if byte_pos < result.len() {
                        let trit_in_byte = (pos % 4) as u8;
                        let shift = trit_in_byte * 2;
                        let mask = !(0b11u8 << shift);
                        let trit_bits: u8 = match should_be {
                            Trit::N => 0b10,
                            Trit::Z => 0b00,
                            Trit::P => 0b01,
                        };
                        result[byte_pos] = (result[byte_pos] & mask) | (trit_bits << shift);
                    }
                }
                result
            }

            CorrectionType::BlockReplace { offset, original } => {
                let mut result = approximation.to_vec();
                let start = *offset as usize;
                let end = std::cmp::min(start + original.len(), result.len());
                if start < result.len() {
                    result[start..end].copy_from_slice(&original[..end - start]);
                }
                result
            }

            CorrectionType::Verbatim(data) => data.clone(),
        }
    }

    /// Verify the correction produces the expected hash
    pub fn verify(&self, result: &[u8]) -> bool {
        compute_hash(result) == self.hash
    }

    /// Storage size of this correction
    pub fn storage_size(&self) -> usize {
        match &self.correction {
            CorrectionType::None => 0,
            CorrectionType::BitFlips(flips) => flips.len() * 9, // pos(8) + mask(1)
            CorrectionType::TritFlips(flips) => flips.len() * 10, // pos(8) + 2 trits(2)
            CorrectionType::BlockReplace { original, .. } => 8 + original.len(),
            CorrectionType::Verbatim(data) => data.len(),
        }
    }
}

/// Compute verification hash (first 8 bytes of SHA256)
fn compute_hash(data: &[u8]) -> [u8; 8] {
    let mut hasher = Sha256::new();
    hasher.update(data);
    let result = hasher.finalize();
    let mut hash = [0u8; 8];
    hash.copy_from_slice(&result[..8]);
    hash
}

/// Compute parity trit for data
fn compute_data_parity(data: &[u8]) -> Trit {
    let sum: i64 = data.iter().map(|&b| b as i64).sum();
    match (sum % 3) as i8 {
        0 => Trit::Z,
        1 | -2 => Trit::P,
        2 | -1 => Trit::N,
        _ => Trit::Z,
    }
}

/// Compute the minimal correction to transform approximation into original
fn compute_correction(original: &[u8], approximation: &[u8]) -> CorrectionType {
    // If identical, no correction
    if original == approximation {
        return CorrectionType::None;
    }

    // Length mismatches cannot be repaired by flips or block replacement
    // (apply never changes the output length), so store the chunk verbatim.
    if original.len() != approximation.len() {
        return CorrectionType::Verbatim(original.to_vec());
    }

    // Count differences
    let mut diff_positions: Vec<(u64, u8, u8)> = Vec::new();
    let max_len = std::cmp::max(original.len(), approximation.len());

    for i in 0..max_len {
        let orig_byte = original.get(i).copied().unwrap_or(0);
        let approx_byte = approximation.get(i).copied().unwrap_or(0);

        if orig_byte != approx_byte {
            diff_positions.push((i as u64, orig_byte, approx_byte));
        }
    }

    // Choose correction strategy based on number of differences
    let diff_count = diff_positions.len();

    if diff_count == 0 {
        return CorrectionType::None;
    }

    // If most bytes are different, store verbatim
    if diff_count > original.len() / 2 {
        return CorrectionType::Verbatim(original.to_vec());
    }

    // If differences are clustered, use block replace
    if diff_count > 10 {
        let first_diff = diff_positions.first().map(|p| p.0).unwrap_or(0);
        let last_diff = diff_positions.last().map(|p| p.0).unwrap_or(0);
        let span = (last_diff - first_diff + 1) as usize;

        // If span is small compared to storing individual corrections
        if span < diff_count * 9 {
            let start = first_diff as usize;
            let end = std::cmp::min(start + span, original.len());
            return CorrectionType::BlockReplace {
                offset: first_diff,
                original: original[start..end].to_vec(),
            };
        }
    }

    // Use bit flips for sparse differences
    let bit_flips: Vec<(u64, u8)> = diff_positions
        .iter()
        .map(|&(pos, orig, approx)| (pos, orig ^ approx))
        .collect();

    CorrectionType::BitFlips(bit_flips)
}

/// Correction store - manages all corrections for an engram
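///
/// A minimal usage sketch (illustrative; chunking and VSA decode are assumed
/// to happen elsewhere):
///
/// ```ignore
/// let mut store = CorrectionStore::new();
///
/// // Record each chunk against its decoded approximation.
/// store.add(0, b"original chunk", b"original chunk"); // perfect, zero overhead
/// store.add(1, b"original chunk", b"originXl chunk"); // one byte off
///
/// // At read time, repair the approximation and verify it against the hash.
/// let repaired = store.apply(1, b"originXl chunk").expect("verified repair");
/// assert_eq!(repaired, b"original chunk".to_vec());
/// println!("{}", store.stats());
/// ```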
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CorrectionStore {
    /// Corrections indexed by chunk ID
    corrections: HashMap<u64, ChunkCorrection>,

    /// Total storage used by corrections
    total_correction_bytes: u64,

    /// Total original data size
    total_original_bytes: u64,

    /// Chunks that needed no correction
    perfect_chunks: u64,

    /// Chunks that needed correction
    corrected_chunks: u64,
}

impl CorrectionStore {
    /// Create a new correction store
    pub fn new() -> Self {
        CorrectionStore::default()
    }

    /// Add a correction for a chunk
    pub fn add(&mut self, chunk_id: u64, original: &[u8], approximation: &[u8]) {
        let correction = ChunkCorrection::new(chunk_id, original, approximation);

        self.total_original_bytes += original.len() as u64;

        if correction.needs_correction() {
            self.total_correction_bytes += correction.storage_size() as u64;
            self.corrected_chunks += 1;
        } else {
            self.perfect_chunks += 1;
        }

        self.corrections.insert(chunk_id, correction);
    }

    /// Get correction for a chunk
    pub fn get(&self, chunk_id: u64) -> Option<&ChunkCorrection> {
        self.corrections.get(&chunk_id)
    }

    /// Apply correction to approximation
    pub fn apply(&self, chunk_id: u64, approximation: &[u8]) -> Option<Vec<u8>> {
        let correction = self.corrections.get(&chunk_id)?;
        let result = correction.apply(approximation);

        // Verify correction worked
        if correction.verify(&result) {
            Some(result)
        } else {
            None // Correction failed verification
        }
    }

    /// Get correction statistics
    pub fn stats(&self) -> CorrectionStats {
        CorrectionStats {
            total_chunks: self.perfect_chunks + self.corrected_chunks,
            perfect_chunks: self.perfect_chunks,
            corrected_chunks: self.corrected_chunks,
            correction_bytes: self.total_correction_bytes,
            original_bytes: self.total_original_bytes,
            correction_ratio: if self.total_original_bytes > 0 {
                self.total_correction_bytes as f64 / self.total_original_bytes as f64
            } else {
                0.0
            },
            perfect_ratio: if self.perfect_chunks + self.corrected_chunks > 0 {
                self.perfect_chunks as f64 / (self.perfect_chunks + self.corrected_chunks) as f64
            } else {
                1.0
            },
        }
    }

    /// Serialize to bytes
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserialize from bytes
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}

/// Statistics about corrections
#[derive(Clone, Debug)]
pub struct CorrectionStats {
    pub total_chunks: u64,
    pub perfect_chunks: u64,
    pub corrected_chunks: u64,
    pub correction_bytes: u64,
    pub original_bytes: u64,
    pub correction_ratio: f64,
    pub perfect_ratio: f64,
}

impl std::fmt::Display for CorrectionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Corrections: {}/{} chunks perfect ({:.1}%), \
             {:.2}% overhead ({} bytes corrections / {} bytes original)",
            self.perfect_chunks,
            self.total_chunks,
            self.perfect_ratio * 100.0,
            self.correction_ratio * 100.0,
            self.correction_bytes,
            self.original_bytes,
        )
    }
}

/// Reconstruction verifier
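///
/// Usage sketch (illustrative):
///
/// ```ignore
/// // Hash the original chunks up front...
/// let chunks = vec![(0u64, b"chunk0".to_vec()), (1u64, b"chunk1".to_vec())];
/// let verifier = ReconstructionVerifier::from_chunks(chunks.clone().into_iter());
///
/// // ...then check reconstructed chunks against those hashes.
/// assert!(verifier.verify_chunk(0, b"chunk0"));
/// let result = verifier.verify_all(chunks.into_iter());
/// assert!(result.perfect);
/// ```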
pub struct ReconstructionVerifier {
    /// Expected hashes for all chunks
    expected_hashes: HashMap<u64, [u8; 8]>,
}

impl ReconstructionVerifier {
    /// Create a new verifier from original data
    pub fn from_chunks(chunks: impl Iterator<Item = (u64, Vec<u8>)>) -> Self {
        let expected_hashes: HashMap<u64, [u8; 8]> =
            chunks.map(|(id, data)| (id, compute_hash(&data))).collect();

        ReconstructionVerifier { expected_hashes }
    }

    /// Verify a reconstructed chunk
    pub fn verify_chunk(&self, chunk_id: u64, data: &[u8]) -> bool {
        match self.expected_hashes.get(&chunk_id) {
            Some(expected) => compute_hash(data) == *expected,
            None => false, // Unknown chunk
        }
    }

    /// Verify all chunks
    pub fn verify_all(&self, chunks: impl Iterator<Item = (u64, Vec<u8>)>) -> VerificationResult {
        let mut verified = 0u64;
        let mut failed = 0u64;
        let mut failed_ids = Vec::new();
        let mut seen = std::collections::HashSet::new();

        for (id, data) in chunks {
            if self.expected_hashes.contains_key(&id) {
                seen.insert(id);
            }
            if self.verify_chunk(id, &data) {
                verified += 1;
            } else {
                failed += 1;
                failed_ids.push(id);
            }
        }

        // Expected chunks that were never supplied at all (counting via the set
        // of seen IDs avoids underflow when unknown or duplicate IDs appear).
        let missing = (self.expected_hashes.len() - seen.len()) as u64;

        VerificationResult {
            verified,
            failed,
            missing,
            failed_ids,
            perfect: failed == 0 && missing == 0,
        }
    }
}

/// Result of verification
#[derive(Clone, Debug)]
pub struct VerificationResult {
    pub verified: u64,
    pub failed: u64,
    pub missing: u64,
    pub failed_ids: Vec<u64>,
    pub perfect: bool,
}

impl std::fmt::Display for VerificationResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.perfect {
            write!(
                f,
                "✓ Perfect reconstruction: {} chunks verified",
                self.verified
            )
        } else {
            write!(
                f,
                "✗ Reconstruction issues: {} verified, {} failed, {} missing",
                self.verified, self.failed, self.missing
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_correction_needed() {
        let original = b"hello world";
        let approximation = b"hello world";

        let correction = ChunkCorrection::new(0, original, approximation);

        assert!(!correction.needs_correction());
        assert_eq!(correction.storage_size(), 0);
    }

    #[test]
    fn test_bit_flip_correction() {
        let original = b"hello world";
        let mut approximation = original.to_vec();
        approximation[0] ^= 0x01; // Flip one bit

        let correction = ChunkCorrection::new(0, original, &approximation);

        assert!(correction.needs_correction());

        let recovered = correction.apply(&approximation);
        assert_eq!(recovered, original);
        assert!(correction.verify(&recovered));
    }

    #[test]
    fn test_verbatim_correction() {
        let original = b"completely different data here";
        let approximation = b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

        let correction = ChunkCorrection::new(0, original, approximation);

        assert!(correction.needs_correction());

        let recovered = correction.apply(approximation);
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_correction_store() {
        let mut store = CorrectionStore::new();

        // Add some perfect chunks
        store.add(0, b"chunk0", b"chunk0");
        store.add(1, b"chunk1", b"chunk1");

        // Add a chunk needing correction
        store.add(2, b"chunk2", b"chunkX");

        let stats = store.stats();
        assert_eq!(stats.perfect_chunks, 2);
        assert_eq!(stats.corrected_chunks, 1);

        // Verify correction works
        let recovered = store.apply(2, b"chunkX").unwrap();
        assert_eq!(recovered, b"chunk2");
    }

    #[test]
    fn test_reconstruction_verifier() {
        let chunks = vec![
            (0u64, b"chunk0".to_vec()),
            (1u64, b"chunk1".to_vec()),
            (2u64, b"chunk2".to_vec()),
        ];

        let verifier = ReconstructionVerifier::from_chunks(chunks.clone().into_iter());

        // Verify correct chunks
        assert!(verifier.verify_chunk(0, b"chunk0"));
        assert!(verifier.verify_chunk(1, b"chunk1"));

        // Verify incorrect chunk fails
        assert!(!verifier.verify_chunk(0, b"wrong"));

        // Verify all
        let result = verifier.verify_all(chunks.into_iter());
        assert!(result.perfect);
        assert_eq!(result.verified, 3);
    }

    #[test]
    fn test_hash_stability() {
        // Ensure hash function is deterministic
        let data = b"test data for hashing";
        let hash1 = compute_hash(data);
        let hash2 = compute_hash(data);
        assert_eq!(hash1, hash2);

        // Different data should produce different hash
        let hash3 = compute_hash(b"different data");
        assert_ne!(hash1, hash3);
    }
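
    // Illustrative additions: exercise the strategy selection in
    // `compute_correction` described above (a clustered run of differences
    // should be encoded as a single `BlockReplace`) and the ternary parity helper.

    #[test]
    fn test_block_replace_correction() {
        // 64-byte chunk with a 12-byte corrupted run: more than 10 clustered
        // differences spanning fewer bytes than individual flips would cost.
        let original = vec![b'a'; 64];
        let mut approximation = original.clone();
        for byte in &mut approximation[10..22] {
            *byte = b'b';
        }

        let correction = ChunkCorrection::new(0, &original, &approximation);
        assert!(matches!(
            correction.correction,
            CorrectionType::BlockReplace { .. }
        ));

        let recovered = correction.apply(&approximation);
        assert_eq!(recovered, original);
        assert!(correction.verify(&recovered));
    }

    #[test]
    fn test_parity_is_deterministic() {
        // Parity is the byte sum modulo 3, mapped onto the three trit values.
        assert!(matches!(compute_data_parity(&[1, 1, 1]), Trit::Z)); // 3 % 3 == 0
        assert!(matches!(compute_data_parity(&[1]), Trit::P)); // 1 % 3 == 1
        assert!(matches!(compute_data_parity(&[2]), Trit::N)); // 2 % 3 == 2
    }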
}