//! Residual corrections for approximate chunk reconstruction: each chunk
//! stores the smallest patch needed to turn a lossy approximation back into
//! the exact original bytes, plus a truncated hash for verification.

use embeddenator_vsa::Trit;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;

/// How a chunk's approximation is patched back to the original bytes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CorrectionType {
    /// The approximation already matches the original exactly.
    None,
    /// Sparse byte differences, stored as `(position, XOR mask)` pairs.
    BitFlips(Vec<(u64, u8)>),
    /// Sparse trit differences, stored as `(position, was, should_be)` triples.
    TritFlips(Vec<(u64, Trit, Trit)>),
    /// A contiguous run of original bytes that replaces a corrupted region.
    BlockReplace { offset: u64, original: Vec<u8> },
    /// The full original chunk, used when patching would cost more than the data itself.
    Verbatim(Vec<u8>),
}

/// The correction record for a single chunk.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChunkCorrection {
    /// Identifier of the chunk this correction applies to.
    pub chunk_id: u64,
    /// The patch needed to recover the original bytes.
    pub correction: CorrectionType,
    /// First 8 bytes of the SHA-256 of the original chunk, used for verification.
    pub hash: [u8; 8],
    /// Ternary parity of the original chunk (byte sum mod 3).
    pub parity: Trit,
}

impl ChunkCorrection {
    /// Build a correction by diffing the original chunk against its approximation.
    pub fn new(chunk_id: u64, original: &[u8], approximation: &[u8]) -> Self {
        let hash = compute_hash(original);
        let parity = compute_data_parity(original);

        let correction = compute_correction(original, approximation);

        ChunkCorrection {
            chunk_id,
            correction,
            hash,
            parity,
        }
    }

    /// Returns true if the approximation differs from the original.
    pub fn needs_correction(&self) -> bool {
        !matches!(self.correction, CorrectionType::None)
    }

    /// Apply this correction to an approximation, producing the reconstructed chunk.
    pub fn apply(&self, approximation: &[u8]) -> Vec<u8> {
        match &self.correction {
            CorrectionType::None => approximation.to_vec(),

            CorrectionType::BitFlips(flips) => {
                let mut result = approximation.to_vec();
                for &(pos, mask) in flips {
                    if (pos as usize) < result.len() {
                        result[pos as usize] ^= mask;
                    }
                }
                result
            }

            CorrectionType::TritFlips(flips) => {
                let mut result = approximation.to_vec();
                // Trits are stored two bits each, four trits per byte, so a trit
                // position maps to byte `pos / 4` and bit offset `(pos % 4) * 2`.
                for &(pos, _was, should_be) in flips {
                    let byte_pos = (pos / 4) as usize;
                    if byte_pos < result.len() {
                        let trit_in_byte = (pos % 4) as u8;
                        let shift = trit_in_byte * 2;
                        let mask = !(0b11u8 << shift);
                        let trit_bits: u8 = match should_be {
                            Trit::N => 0b10,
                            Trit::Z => 0b00,
                            Trit::P => 0b01,
                        };
                        result[byte_pos] = (result[byte_pos] & mask) | (trit_bits << shift);
                    }
                }
                result
            }

            CorrectionType::BlockReplace { offset, original } => {
                let mut result = approximation.to_vec();
                let start = *offset as usize;
                let end = std::cmp::min(start + original.len(), result.len());
                if start < result.len() {
                    result[start..end].copy_from_slice(&original[..end - start]);
                }
                result
            }

            CorrectionType::Verbatim(data) => data.clone(),
        }
    }

    /// Check a reconstructed chunk against the stored truncated hash.
    pub fn verify(&self, result: &[u8]) -> bool {
        compute_hash(result) == self.hash
    }

    /// Approximate number of bytes this correction costs to store.
    pub fn storage_size(&self) -> usize {
        match &self.correction {
            CorrectionType::None => 0,
            // (u64 position, u8 mask) per flip.
            CorrectionType::BitFlips(flips) => flips.len() * 9,
            // (u64 position, Trit, Trit) per flip.
            CorrectionType::TritFlips(flips) => flips.len() * 10,
            CorrectionType::BlockReplace { original, .. } => 8 + original.len(),
            CorrectionType::Verbatim(data) => data.len(),
        }
    }
}

/// First 8 bytes of the SHA-256 digest of `data`.
fn compute_hash(data: &[u8]) -> [u8; 8] {
    let mut hasher = Sha256::new();
    hasher.update(data);
    let result = hasher.finalize();
    let mut hash = [0u8; 8];
    hash.copy_from_slice(&result[..8]);
    hash
}

/// Ternary parity of `data`: the byte sum reduced modulo 3, mapped onto a trit.
fn compute_data_parity(data: &[u8]) -> Trit {
    // The sum of unsigned bytes is non-negative, so `sum % 3` is 0, 1, or 2.
    let sum: i64 = data.iter().map(|&b| b as i64).sum();
    match sum % 3 {
        0 => Trit::Z,
        1 => Trit::P,
        _ => Trit::N,
    }
}

/// Choose the cheapest correction that turns `approximation` into `original`.
fn compute_correction(original: &[u8], approximation: &[u8]) -> CorrectionType {
    if original == approximation {
        return CorrectionType::None;
    }

    // Bit flips and block replacement both preserve length, so a length
    // mismatch can only be repaired by storing the original verbatim.
    if original.len() != approximation.len() {
        return CorrectionType::Verbatim(original.to_vec());
    }

    // Collect every byte position where the two buffers disagree.
    let mut diff_positions: Vec<(u64, u8, u8)> = Vec::new();
    let max_len = std::cmp::max(original.len(), approximation.len());

    for i in 0..max_len {
        let orig_byte = original.get(i).copied().unwrap_or(0);
        let approx_byte = approximation.get(i).copied().unwrap_or(0);

        if orig_byte != approx_byte {
            diff_positions.push((i as u64, orig_byte, approx_byte));
        }
    }

    let diff_count = diff_positions.len();

    if diff_count == 0 {
        return CorrectionType::None;
    }

    // More than half the bytes differ: storing the original outright is cheaper.
    if diff_count > original.len() / 2 {
        return CorrectionType::Verbatim(original.to_vec());
    }

    // Many differences clustered in one region: replace the whole span, which
    // costs 8 bytes of offset plus the span, versus ~9 bytes per bit flip.
    if diff_count > 10 {
        let first_diff = diff_positions.first().map(|p| p.0).unwrap_or(0);
        let last_diff = diff_positions.last().map(|p| p.0).unwrap_or(0);
        let span = (last_diff - first_diff + 1) as usize;

        if span < diff_count * 9 {
            let start = first_diff as usize;
            let end = std::cmp::min(start + span, original.len());
            return CorrectionType::BlockReplace {
                offset: first_diff,
                original: original[start..end].to_vec(),
            };
        }
    }

    // Otherwise store sparse XOR masks, one per differing byte.
    let bit_flips: Vec<(u64, u8)> = diff_positions
        .iter()
        .map(|&(pos, orig, approx)| (pos, orig ^ approx))
        .collect();

    CorrectionType::BitFlips(bit_flips)
}

/// Collection of per-chunk corrections plus running statistics.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CorrectionStore {
    /// Corrections keyed by chunk id.
    corrections: HashMap<u64, ChunkCorrection>,

    /// Total bytes spent on corrections.
    total_correction_bytes: u64,

    /// Total bytes of original data seen.
    total_original_bytes: u64,

    /// Chunks whose approximation matched the original exactly.
    perfect_chunks: u64,

    /// Chunks that required a correction.
    corrected_chunks: u64,
}

impl CorrectionStore {
    /// Create an empty store.
    pub fn new() -> Self {
        CorrectionStore::default()
    }

    /// Diff a chunk against its approximation and record the resulting correction.
    pub fn add(&mut self, chunk_id: u64, original: &[u8], approximation: &[u8]) {
        let correction = ChunkCorrection::new(chunk_id, original, approximation);

        self.total_original_bytes += original.len() as u64;

        if correction.needs_correction() {
            self.total_correction_bytes += correction.storage_size() as u64;
            self.corrected_chunks += 1;
        } else {
            self.perfect_chunks += 1;
        }

        self.corrections.insert(chunk_id, correction);
    }

    /// Look up the correction for a chunk, if one was recorded.
    pub fn get(&self, chunk_id: u64) -> Option<&ChunkCorrection> {
        self.corrections.get(&chunk_id)
    }

    /// Apply the stored correction for `chunk_id`, returning the result only if
    /// it verifies against the stored hash.
    pub fn apply(&self, chunk_id: u64, approximation: &[u8]) -> Option<Vec<u8>> {
        let correction = self.corrections.get(&chunk_id)?;
        let result = correction.apply(approximation);

        if correction.verify(&result) {
            Some(result)
        } else {
            None
        }
    }

    /// Summary statistics over everything added so far.
    pub fn stats(&self) -> CorrectionStats {
        CorrectionStats {
            total_chunks: self.perfect_chunks + self.corrected_chunks,
            perfect_chunks: self.perfect_chunks,
            corrected_chunks: self.corrected_chunks,
            correction_bytes: self.total_correction_bytes,
            original_bytes: self.total_original_bytes,
            correction_ratio: if self.total_original_bytes > 0 {
                self.total_correction_bytes as f64 / self.total_original_bytes as f64
            } else {
                0.0
            },
            perfect_ratio: if self.perfect_chunks + self.corrected_chunks > 0 {
                self.perfect_chunks as f64 / (self.perfect_chunks + self.corrected_chunks) as f64
            } else {
                1.0
            },
        }
    }

    /// Serialize the store with bincode; returns an empty vector on failure.
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserialize a store previously produced by `to_bytes`.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}

/// Aggregate statistics reported by `CorrectionStore::stats`.
#[derive(Clone, Debug)]
pub struct CorrectionStats {
    pub total_chunks: u64,
    pub perfect_chunks: u64,
    pub corrected_chunks: u64,
    pub correction_bytes: u64,
    pub original_bytes: u64,
    pub correction_ratio: f64,
    pub perfect_ratio: f64,
}

impl std::fmt::Display for CorrectionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Corrections: {}/{} chunks perfect ({:.1}%), \
             {:.2}% overhead ({} bytes corrections / {} bytes original)",
            self.perfect_chunks,
            self.total_chunks,
            self.perfect_ratio * 100.0,
            self.correction_ratio * 100.0,
            self.correction_bytes,
            self.original_bytes,
        )
    }
}

/// Verifies reconstructed chunks against hashes captured from the originals.
pub struct ReconstructionVerifier {
    /// Expected truncated hash per chunk id.
    expected_hashes: HashMap<u64, [u8; 8]>,
}

impl ReconstructionVerifier {
    /// Capture the expected hash of every original chunk.
    pub fn from_chunks(chunks: impl Iterator<Item = (u64, Vec<u8>)>) -> Self {
        let expected_hashes: HashMap<u64, [u8; 8]> =
            chunks.map(|(id, data)| (id, compute_hash(&data))).collect();

        ReconstructionVerifier { expected_hashes }
    }

    /// Check one reconstructed chunk; unknown chunk ids are treated as failures.
    pub fn verify_chunk(&self, chunk_id: u64, data: &[u8]) -> bool {
        match self.expected_hashes.get(&chunk_id) {
            Some(expected) => compute_hash(data) == *expected,
            None => false,
        }
    }

    /// Verify a whole set of reconstructed chunks and summarize the outcome.
    pub fn verify_all(&self, chunks: impl Iterator<Item = (u64, Vec<u8>)>) -> VerificationResult {
        let mut verified = 0u64;
        let mut failed = 0u64;
        let mut failed_ids = Vec::new();

        for (id, data) in chunks {
            if self.verify_chunk(id, &data) {
                verified += 1;
            } else {
                failed += 1;
                failed_ids.push(id);
            }
        }

        // Expected chunks that never appeared in the input; saturate so that
        // duplicate or unexpected ids cannot underflow the count.
        let missing = (self.expected_hashes.len() as u64).saturating_sub(verified + failed);

        VerificationResult {
            verified,
            failed,
            missing,
            failed_ids,
            perfect: failed == 0 && missing == 0,
        }
    }
}

/// Outcome of `ReconstructionVerifier::verify_all`.
#[derive(Clone, Debug)]
pub struct VerificationResult {
    pub verified: u64,
    pub failed: u64,
    pub missing: u64,
    pub failed_ids: Vec<u64>,
    pub perfect: bool,
}

impl std::fmt::Display for VerificationResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.perfect {
            write!(
                f,
                "✓ Perfect reconstruction: {} chunks verified",
                self.verified
            )
        } else {
            write!(
                f,
                "✗ Reconstruction issues: {} verified, {} failed, {} missing",
                self.verified, self.failed, self.missing
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_correction_needed() {
        let original = b"hello world";
        let approximation = b"hello world";

        let correction = ChunkCorrection::new(0, original, approximation);

        assert!(!correction.needs_correction());
        assert_eq!(correction.storage_size(), 0);
    }

    #[test]
    fn test_bit_flip_correction() {
        let original = b"hello world";
        let mut approximation = original.to_vec();
        // Flip a single bit in the first byte.
        approximation[0] ^= 0x01;

        let correction = ChunkCorrection::new(0, original, &approximation);

        assert!(correction.needs_correction());

        let recovered = correction.apply(&approximation);
        assert_eq!(recovered, original);
        assert!(correction.verify(&recovered));
    }
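
    // A sketch of an additional test for the `TritFlips` path, which
    // `compute_correction` never emits on its own; the correction is built by
    // hand and assumes the two-bit, four-trits-per-byte packing used in
    // `apply` (Z = 0b00, P = 0b01, N = 0b10).
    #[test]
    fn test_trit_flip_correction() {
        // Original packs trit 1 as P (bits 2..4 = 0b01); the approximation has it as Z.
        let original = [0b0000_0100u8];
        let approximation = [0b0000_0000u8];

        let correction = ChunkCorrection {
            chunk_id: 0,
            correction: CorrectionType::TritFlips(vec![(1, Trit::Z, Trit::P)]),
            hash: compute_hash(&original),
            parity: compute_data_parity(&original),
        };

        let recovered = correction.apply(&approximation);
        assert_eq!(recovered, original);
        assert!(correction.verify(&recovered));
    }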

    #[test]
    fn test_verbatim_correction() {
        let original = b"completely different data here";
        let approximation = b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

        let correction = ChunkCorrection::new(0, original, approximation);

        assert!(correction.needs_correction());

        let recovered = correction.apply(approximation);
        assert_eq!(recovered, original);
    }
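
    // A sketch of a test for the clustered-error path: more than ten adjacent
    // byte differences should be folded into a single `BlockReplace` rather
    // than individual bit flips (assuming the thresholds in `compute_correction`).
    #[test]
    fn test_block_replace_correction() {
        let original: Vec<u8> = (0u8..64).collect();
        let mut approximation = original.clone();
        // Corrupt a contiguous run of 12 bytes.
        for byte in &mut approximation[10..22] {
            *byte ^= 0xFF;
        }

        let correction = ChunkCorrection::new(0, &original, &approximation);
        assert!(matches!(
            correction.correction,
            CorrectionType::BlockReplace { .. }
        ));

        let recovered = correction.apply(&approximation);
        assert_eq!(recovered, original);
        assert!(correction.verify(&recovered));
    }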

    #[test]
    fn test_correction_store() {
        let mut store = CorrectionStore::new();

        // Two chunks that match their approximation exactly.
        store.add(0, b"chunk0", b"chunk0");
        store.add(1, b"chunk1", b"chunk1");

        // One chunk whose approximation differs in a single byte.
        store.add(2, b"chunk2", b"chunkX");

        let stats = store.stats();
        assert_eq!(stats.perfect_chunks, 2);
        assert_eq!(stats.corrected_chunks, 1);

        let recovered = store.apply(2, b"chunkX").unwrap();
        assert_eq!(recovered, b"chunk2");
    }
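
    // A sketch of a serialization round trip, assuming bincode 1.x (the API
    // `to_bytes`/`from_bytes` already call) and that `Trit` implements
    // Serialize/Deserialize, as the derives above already require.
    #[test]
    fn test_store_round_trip() {
        let mut store = CorrectionStore::new();
        store.add(0, b"chunk0", b"chunk0");
        store.add(1, b"chunk1", b"chunkX");

        let bytes = store.to_bytes();
        assert!(!bytes.is_empty());

        let restored = CorrectionStore::from_bytes(&bytes).unwrap();
        assert_eq!(restored.stats().total_chunks, 2);
        assert_eq!(restored.apply(1, b"chunkX").unwrap(), b"chunk1");
    }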

    #[test]
    fn test_reconstruction_verifier() {
        let chunks = vec![
            (0u64, b"chunk0".to_vec()),
            (1u64, b"chunk1".to_vec()),
            (2u64, b"chunk2".to_vec()),
        ];

        let verifier = ReconstructionVerifier::from_chunks(chunks.clone().into_iter());

        assert!(verifier.verify_chunk(0, b"chunk0"));
        assert!(verifier.verify_chunk(1, b"chunk1"));

        assert!(!verifier.verify_chunk(0, b"wrong"));

        let result = verifier.verify_all(chunks.into_iter());
        assert!(result.perfect);
        assert_eq!(result.verified, 3);
    }
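
    // A sketch of the non-perfect path of `verify_all`: one intact chunk, one
    // corrupted chunk, and one expected chunk that is never supplied.
    #[test]
    fn test_verifier_missing_and_failed() {
        let chunks = vec![
            (0u64, b"chunk0".to_vec()),
            (1u64, b"chunk1".to_vec()),
            (2u64, b"chunk2".to_vec()),
        ];

        let verifier = ReconstructionVerifier::from_chunks(chunks.into_iter());

        // Chunk 0 is intact, chunk 1 is corrupted, chunk 2 is missing entirely.
        let reconstructed = vec![(0u64, b"chunk0".to_vec()), (1u64, b"broken".to_vec())];
        let result = verifier.verify_all(reconstructed.into_iter());

        assert!(!result.perfect);
        assert_eq!(result.verified, 1);
        assert_eq!(result.failed, 1);
        assert_eq!(result.missing, 1);
        assert_eq!(result.failed_ids, vec![1]);
    }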

    #[test]
    fn test_hash_stability() {
        let data = b"test data for hashing";
        let hash1 = compute_hash(data);
        let hash2 = compute_hash(data);
        assert_eq!(hash1, hash2);

        let hash3 = compute_hash(b"different data");
        assert_ne!(hash1, hash3);
    }
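
    // A sketch checking the byte-sum-mod-3 mapping used by `compute_data_parity`.
    #[test]
    fn test_data_parity_mapping() {
        assert!(matches!(compute_data_parity(&[3]), Trit::Z)); // 3 % 3 == 0
        assert!(matches!(compute_data_parity(&[1]), Trit::P)); // 1 % 3 == 1
        assert!(matches!(compute_data_parity(&[2]), Trit::N)); // 2 % 3 == 2
        assert!(matches!(compute_data_parity(&[]), Trit::Z)); // empty input sums to 0
    }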
}