// clay_codes/decode.rs
//! Decoding and erasure recovery for Clay codes
//!
//! This module implements the layered decoding algorithm from the FAST'18 paper.
//! It handles both full decoding (all chunks available) and erasure recovery
//! (up to m chunks missing).

use std::collections::{BTreeSet, HashMap};

use reed_solomon_erasure::galois_8::{self, add as gf_add, mul as gf_mul};

use crate::coords::get_plane_vector;
use crate::encode::EncodeParams;
use crate::error::ClayError;
use crate::transforms::{
    compute_c_from_u_and_cstar, compute_u_from_c_and_ustar, pft_compute_both, prt_compute_both,
    GAMMA,
};

/// Parameters needed for decoding.
///
/// Currently identical to [`EncodeParams`]; kept as a separate alias so the
/// decode-side parameter type can diverge later without touching call sites.
pub type DecodeParams = EncodeParams;
21
22/// Recover original data from available chunks
23///
24/// # Parameters
25/// - `params`: Code parameters
26/// - `available`: Map from chunk index to chunk data
27/// - `erasures`: Set of erased chunk indices
28///
29/// # Returns
30/// Recovered original data, or error if decoding fails
31pub fn decode(
32    params: &DecodeParams,
33    available: &HashMap<usize, Vec<u8>>,
34    erasures: &[usize],
35) -> Result<Vec<u8>, ClayError> {
36    if available.is_empty() {
37        return Ok(Vec::new());
38    }
39
40    // Validate erasure count
41    if erasures.len() > params.m {
42        return Err(ClayError::TooManyErasures {
43            max: params.m,
44            actual: erasures.len(),
45        });
46    }
47
48    // Get chunk size from first available chunk and validate all chunks match
49    let mut iter = available.iter();
50    let (_, first_chunk) = iter.next().unwrap();
51    let chunk_size = first_chunk.len();
52
53    // Validate chunk_size is divisible by sub_chunk_no
54    if chunk_size == 0 || chunk_size % params.sub_chunk_no != 0 {
55        return Err(ClayError::InvalidChunkSize {
56            expected: params.sub_chunk_no,
57            actual: chunk_size,
58        });
59    }
60
61    // Validate all chunks have same size
62    for (&idx, chunk) in iter {
63        if chunk.len() != chunk_size {
64            return Err(ClayError::InconsistentChunkSizes {
65                first_size: chunk_size,
66                mismatched_idx: idx,
67                mismatched_size: chunk.len(),
68            });
69        }
70    }
71
72    // Validate chunk indices are in valid range
73    for &idx in available.keys() {
74        if idx >= params.n {
75            return Err(ClayError::InvalidParameters(format!(
76                "Chunk index {} out of range [0, {})",
77                idx, params.n
78            )));
79        }
80    }
81    for &e in erasures {
82        if e >= params.n {
83            return Err(ClayError::InvalidParameters(format!(
84                "Erasure index {} out of range [0, {})",
85                e, params.n
86            )));
87        }
88    }
89
90    // Validate consistency between available and erasures
91    // 1. Erasures and available keys must be disjoint
92    for &e in erasures {
93        if available.contains_key(&e) {
94            return Err(ClayError::InvalidParameters(format!(
95                "Node {} is both in available chunks and marked as erased",
96                e
97            )));
98        }
99    }
100
101    // 2. We need exactly n - erasures.len() available chunks
102    let expected_available = params.n - erasures.len();
103    if available.len() != expected_available {
104        return Err(ClayError::InvalidParameters(format!(
105            "Expected {} available chunks (n={} - erasures={}), but got {}",
106            expected_available,
107            params.n,
108            erasures.len(),
109            available.len()
110        )));
111    }
112
113    // 3. All non-erased nodes must be present in available
114    for node in 0..params.n {
115        if !erasures.contains(&node) && !available.contains_key(&node) {
116            return Err(ClayError::InvalidParameters(format!(
117                "Node {} is neither erased nor provided in available chunks",
118                node
119            )));
120        }
121    }
122
123    let sub_chunk_size = chunk_size / params.sub_chunk_no;
124    let total_nodes = params.q * params.t;
125
126    // Build full chunks array with proper node indices
127    let mut chunks: Vec<Vec<u8>> = vec![vec![0u8; chunk_size]; total_nodes];
128
129    // Copy available chunks, mapping from external (k data + m parity) to internal indices
130    for (&idx, data) in available.iter() {
131        let internal_idx = if idx < params.k { idx } else { idx + params.nu };
132        chunks[internal_idx] = data.clone();
133    }
134
135    // Build erasure set with internal indices
136    // Note: shortened nodes are NOT erasures - they are known zeros
137    let mut erased_set: BTreeSet<usize> = BTreeSet::new();
138    for &e in erasures {
139        let internal_idx = if e < params.k { e } else { e + params.nu };
140        erased_set.insert(internal_idx);
141    }
142
143    // Shortened nodes are KNOWN zeros, already set in chunks array
144    // They should NOT be added to erased_set
145
146    // Decode
147    decode_layered(params, &erased_set, &mut chunks, sub_chunk_size)?;
148
149    // Extract original data from first k chunks
150    let mut result = Vec::with_capacity(params.k * chunk_size);
151    for i in 0..params.k {
152        result.extend_from_slice(&chunks[i]);
153    }
154
155    Ok(result)
156}
157
158/// Main layered decoding algorithm
159///
160/// Processes layers in order of increasing intersection score, applying
161/// PRT/PFT transforms and RS decoding as needed.
162pub fn decode_layered(
163    params: &DecodeParams,
164    erased_chunks: &BTreeSet<usize>,
165    chunks: &mut Vec<Vec<u8>>,
166    sub_chunk_size: usize,
167) -> Result<(), ClayError> {
168    let total_nodes = params.q * params.t;
169
170    // Initialize U buffers
171    let chunk_size = chunks[0].len();
172    let mut u_buf: Vec<Vec<u8>> = vec![vec![0u8; chunk_size]; total_nodes];
173
174    // Track which U values have been computed (for using across iterations)
175    let mut u_computed: Vec<Vec<bool>> = vec![vec![false; params.sub_chunk_no]; total_nodes];
176
177    // Compute layer order by intersection score
178    let mut order: Vec<usize> = vec![0; params.sub_chunk_no];
179    set_planes_sequential_decoding_order(params, &mut order, erased_chunks);
180
181    let max_iscore = get_max_iscore(params, erased_chunks);
182
183    // Process layers in order of increasing intersection score
184    for iscore in 0..=max_iscore {
185        // First pass: decode erasures for layers with this iscore
186        for z in 0..params.sub_chunk_no {
187            if order[z] == iscore {
188                decode_layered_with_tracking(
189                    params,
190                    erased_chunks,
191                    z,
192                    chunks,
193                    &mut u_buf,
194                    &mut u_computed,
195                    sub_chunk_size,
196                )?;
197            }
198        }
199
200        // Second pass: recover C values from U values
201        for z in 0..params.sub_chunk_no {
202            if order[z] == iscore {
203                let z_vec = get_plane_vector(z, params.t, params.q);
204
205                for &node_xy in erased_chunks {
206                    let x = node_xy % params.q;
207                    let y = node_xy / params.q;
208                    let z_y = z_vec[y];
209                    let node_sw = y * params.q + z_y;
210                    let z_sw = get_companion_layer(params, z, x, y, z_y);
211
212                    if z_y != x {
213                        if !erased_chunks.contains(&node_sw) {
214                            // Type 1: companion is not erased
215                            recover_type1_erasure(
216                                params,
217                                chunks,
218                                &u_buf,
219                                x,
220                                y,
221                                z,
222                                z_y,
223                                z_sw,
224                                sub_chunk_size,
225                            );
226                        } else if z_y < x {
227                            // Both erased, process once (when z_y < x)
228                            get_coupled_from_uncoupled(
229                                params, chunks, &u_buf, x, y, z, z_y, z_sw, sub_chunk_size,
230                            );
231                        }
232                    } else {
233                        // Red vertex: C = U
234                        let offset = z * sub_chunk_size;
235                        chunks[node_xy][offset..offset + sub_chunk_size]
236                            .copy_from_slice(&u_buf[node_xy][offset..offset + sub_chunk_size]);
237                    }
238                }
239            }
240        }
241    }
242
243    Ok(())
244}
245
/// Decode erasures for a single layer (plane `z`), tracking which U values
/// are known.
///
/// For every non-erased node, derives the uncoupled value U from the coupled
/// data where possible: directly for red vertices, via PRT when the coupling
/// companion also survives, or from a previously computed companion U*.
/// Nodes whose U cannot be derived locally are handed to the per-layer RS
/// MDS decode together with the erased nodes.
///
/// # Parameters
/// - `z`: plane index in `[0, sub_chunk_no)`
/// - `chunks`: coupled data for all nodes (read-only here)
/// - `u_buf`: uncoupled data; this plane's cells are filled in
/// - `u_computed`: `u_computed[node][plane]` marks U cells already known;
///   read for cross-layer reuse and updated for everything computed here
fn decode_layered_with_tracking(
    params: &DecodeParams,
    erased_chunks: &BTreeSet<usize>,
    z: usize,
    chunks: &[Vec<u8>],
    u_buf: &mut [Vec<u8>],
    u_computed: &mut [Vec<bool>],
    sub_chunk_size: usize,
) -> Result<(), ClayError> {
    let z_vec = get_plane_vector(z, params.t, params.q);

    // Track nodes that need MDS recovery for this layer: all erased nodes,
    // plus any surviving node whose U cannot be derived locally below.
    let mut needs_mds: BTreeSet<usize> = erased_chunks.clone();

    // Compute U values for non-erased nodes
    for x in 0..params.q {
        for y in 0..params.t {
            let node_xy = params.q * y + x;
            let z_y = z_vec[y];
            // Coupling companion: same row y, column z_y, in plane z_sw.
            let node_sw = params.q * y + z_y;
            let z_sw = get_companion_layer(params, z, x, y, z_y);

            if !erased_chunks.contains(&node_xy) {
                if z_y == x {
                    // Red vertex: node is its own companion, so U = C.
                    let offset = z * sub_chunk_size;
                    u_buf[node_xy][offset..offset + sub_chunk_size]
                        .copy_from_slice(&chunks[node_xy][offset..offset + sub_chunk_size]);
                    u_computed[node_xy][z] = true;
                } else if !erased_chunks.contains(&node_sw) {
                    // Both nodes of the pair available - apply PRT. PRT fills
                    // both halves at once, so process the pair only once
                    // (when z_y < x) to avoid duplicate work.
                    if z_y < x {
                        get_uncoupled_from_coupled(
                            params, chunks, u_buf, x, y, z, z_y, z_sw, sub_chunk_size,
                        );
                        u_computed[node_xy][z] = true;
                        u_computed[node_sw][z_sw] = true;
                    }
                } else {
                    // Companion is erased - check if companion's U* is available
                    // from a previous iteration (lower intersection score layer)
                    if u_computed[node_sw][z_sw] {
                        // Use U = det*C + γ*U* to compute U from C and known U*
                        let offset_z = z * sub_chunk_size;
                        let offset_zsw = z_sw * sub_chunk_size;
                        let c_xy = &chunks[node_xy][offset_z..offset_z + sub_chunk_size];
                        let u_sw = &u_buf[node_sw][offset_zsw..offset_zsw + sub_chunk_size];
                        let u_xy = compute_u_from_c_and_ustar(c_xy, u_sw);
                        u_buf[node_xy][offset_z..offset_z + sub_chunk_size].copy_from_slice(&u_xy);
                        u_computed[node_xy][z] = true;
                    } else {
                        // Companion's U not available yet - mark for MDS
                        needs_mds.insert(node_xy);
                    }
                }
            }
        }
    }

    // Decode uncoupled layer using MDS
    decode_uncoupled_layer(params, &needs_mds, z, sub_chunk_size, u_buf)?;

    // Mark reconstructed nodes as computed
    for &node in &needs_mds {
        u_computed[node][z] = true;
    }

    Ok(())
}
316
317/// Decode uncoupled layer using RS MDS code
318pub fn decode_uncoupled_layer(
319    params: &DecodeParams,
320    erased_chunks: &BTreeSet<usize>,
321    z: usize,
322    sub_chunk_size: usize,
323    u_buf: &mut [Vec<u8>],
324) -> Result<(), ClayError> {
325    let total_nodes = params.q * params.t;
326    let offset = z * sub_chunk_size;
327    let parity_start = params.original_count; // k + nu
328
329    // Check if we have too many erasures for this layer
330    if erased_chunks.len() > params.m {
331        return Err(ClayError::TooManyErasures {
332            max: params.m,
333            actual: erased_chunks.len(),
334        });
335    }
336
337    // If no erasures, nothing to do
338    if erased_chunks.is_empty() {
339        return Ok(());
340    }
341
342    // Check if we have erased originals or parities
343    let has_erased_originals = erased_chunks.iter().any(|&i| i < parity_start);
344    let has_erased_parities = erased_chunks.iter().any(|&i| i >= parity_start);
345
346    // Create RS codec for this layer
347    let rs = reed_solomon_erasure::ReedSolomon::<galois_8::Field>::new(
348        params.original_count,
349        params.recovery_count,
350    )
351    .map_err(|e| ClayError::ReconstructionFailed(format!("Layer {} RS init failed: {:?}", z, e)))?;
352
353    if has_erased_originals {
354        // Build shards as Option<Vec<u8>> for reconstruction
355        let mut shards: Vec<Option<Vec<u8>>> = Vec::with_capacity(total_nodes);
356
357        for i in 0..total_nodes {
358            if erased_chunks.contains(&i) {
359                shards.push(None);
360            } else {
361                shards.push(Some(u_buf[i][offset..offset + sub_chunk_size].to_vec()));
362            }
363        }
364
365        // Reconstruct missing shards
366        rs.reconstruct(&mut shards).map_err(|e| {
367            ClayError::ReconstructionFailed(format!("Layer {} RS reconstruct failed: {:?}", z, e))
368        })?;
369
370        // Copy restored shards back
371        for i in 0..total_nodes {
372            if erased_chunks.contains(&i) {
373                if let Some(ref data) = shards[i] {
374                    u_buf[i][offset..offset + sub_chunk_size].copy_from_slice(data);
375                }
376            }
377        }
378    } else if has_erased_parities {
379        // Only parity shards erased - just re-encode
380        let mut shards: Vec<Vec<u8>> = Vec::with_capacity(total_nodes);
381
382        for i in 0..total_nodes {
383            shards.push(u_buf[i][offset..offset + sub_chunk_size].to_vec());
384        }
385
386        // Encode to regenerate parity shards
387        rs.encode(&mut shards).map_err(|e| {
388            ClayError::ReconstructionFailed(format!("Layer {} RS encode failed: {:?}", z, e))
389        })?;
390
391        // Copy regenerated parity shards back
392        for i in parity_start..total_nodes {
393            if erased_chunks.contains(&i) {
394                u_buf[i][offset..offset + sub_chunk_size].copy_from_slice(&shards[i]);
395            }
396        }
397    }
398
399    Ok(())
400}
401
402/// Get companion layer index with proper modular arithmetic
403///
404/// z_sw = (z + (x - z_y) * q^(t-1-y)) mod α
405pub fn get_companion_layer(params: &DecodeParams, z: usize, x: usize, y: usize, z_y: usize) -> usize {
406    debug_assert!(y < params.t, "y={} must be < t={}", y, params.t);
407    debug_assert!(x < params.q, "x={} must be < q={}", x, params.q);
408    debug_assert!(z_y < params.q, "z_y={} must be < q={}", z_y, params.q);
409    debug_assert!(
410        z < params.sub_chunk_no,
411        "z={} must be < α={}",
412        z,
413        params.sub_chunk_no
414    );
415
416    let alpha = params.sub_chunk_no as isize;
417    let multiplier = params.q.pow((params.t - 1 - y) as u32) as isize;
418    let diff = x as isize - z_y as isize;
419    let z_sw = ((z as isize) + diff * multiplier).rem_euclid(alpha) as usize;
420    debug_assert!(
421        z_sw < params.sub_chunk_no,
422        "z_sw out of bounds: {} >= {}",
423        z_sw,
424        params.sub_chunk_no
425    );
426    z_sw
427}
428
429/// Get uncoupled values from coupled values using PRT
430fn get_uncoupled_from_coupled(
431    params: &DecodeParams,
432    chunks: &[Vec<u8>],
433    u_buf: &mut [Vec<u8>],
434    x: usize,
435    y: usize,
436    z: usize,
437    z_y: usize,
438    z_sw: usize,
439    sub_chunk_size: usize,
440) {
441    let node_xy = y * params.q + x;
442    let node_sw = y * params.q + z_y;
443
444    let offset_z = z * sub_chunk_size;
445    let offset_zsw = z_sw * sub_chunk_size;
446
447    let c_xy = &chunks[node_xy][offset_z..offset_z + sub_chunk_size];
448    let c_sw = &chunks[node_sw][offset_zsw..offset_zsw + sub_chunk_size];
449
450    // Determine which is C and which is C* based on x vs z_y
451    let (u_xy, u_sw) = if x < z_y {
452        prt_compute_both(c_xy, c_sw)
453    } else {
454        let (u_sw, u_xy) = prt_compute_both(c_sw, c_xy);
455        (u_xy, u_sw)
456    };
457
458    u_buf[node_xy][offset_z..offset_z + sub_chunk_size].copy_from_slice(&u_xy);
459    u_buf[node_sw][offset_zsw..offset_zsw + sub_chunk_size].copy_from_slice(&u_sw);
460}
461
462/// Recover type 1 erasure (companion not erased)
463fn recover_type1_erasure(
464    params: &DecodeParams,
465    chunks: &mut [Vec<u8>],
466    u_buf: &[Vec<u8>],
467    x: usize,
468    y: usize,
469    z: usize,
470    z_y: usize,
471    z_sw: usize,
472    sub_chunk_size: usize,
473) {
474    let node_xy = y * params.q + x;
475    let node_sw = y * params.q + z_y;
476
477    let offset_z = z * sub_chunk_size;
478    let offset_zsw = z_sw * sub_chunk_size;
479
480    let c_sw = &chunks[node_sw][offset_zsw..offset_zsw + sub_chunk_size];
481    let u_xy = &u_buf[node_xy][offset_z..offset_z + sub_chunk_size];
482
483    // Compute C from U and C*
484    let c_xy = compute_c_from_u_and_cstar(u_xy, c_sw);
485
486    chunks[node_xy][offset_z..offset_z + sub_chunk_size].copy_from_slice(&c_xy);
487}
488
489/// Get coupled values from uncoupled values using PFT
490fn get_coupled_from_uncoupled(
491    params: &DecodeParams,
492    chunks: &mut [Vec<u8>],
493    u_buf: &[Vec<u8>],
494    x: usize,
495    y: usize,
496    z: usize,
497    z_y: usize,
498    z_sw: usize,
499    sub_chunk_size: usize,
500) {
501    let node_xy = y * params.q + x;
502    let node_sw = y * params.q + z_y;
503
504    let offset_z = z * sub_chunk_size;
505    let offset_zsw = z_sw * sub_chunk_size;
506
507    let u_xy = &u_buf[node_xy][offset_z..offset_z + sub_chunk_size];
508    let u_sw = &u_buf[node_sw][offset_zsw..offset_zsw + sub_chunk_size];
509
510    // PFT: compute C from U pair
511    let (c_xy, c_sw) = if x < z_y {
512        pft_compute_both(u_xy, u_sw)
513    } else {
514        let (c_sw, c_xy) = pft_compute_both(u_sw, u_xy);
515        (c_xy, c_sw)
516    };
517
518    chunks[node_xy][offset_z..offset_z + sub_chunk_size].copy_from_slice(&c_xy);
519    chunks[node_sw][offset_zsw..offset_zsw + sub_chunk_size].copy_from_slice(&c_sw);
520}
521
522/// Set decoding order based on intersection scores
523fn set_planes_sequential_decoding_order(
524    params: &DecodeParams,
525    order: &mut [usize],
526    erasures: &BTreeSet<usize>,
527) {
528    for z in 0..params.sub_chunk_no {
529        let z_vec = get_plane_vector(z, params.t, params.q);
530        order[z] = 0;
531        for &i in erasures {
532            if i % params.q == z_vec[i / params.q] {
533                order[z] += 1;
534            }
535        }
536    }
537}
538
539/// Get maximum intersection score
540fn get_max_iscore(params: &DecodeParams, erased_chunks: &BTreeSet<usize>) -> usize {
541    let mut weight_vec = vec![false; params.t];
542    let mut iscore = 0;
543
544    for &i in erased_chunks {
545        let y = i / params.q;
546        if !weight_vec[y] {
547            weight_vec[y] = true;
548            iscore += 1;
549        }
550    }
551
552    iscore
553}
554
555/// Compute C* from C and U (for repair)
556///
557/// companion_value = (U + C) / γ
558pub fn compute_cstar_from_c_and_u(c_helper: &[u8], u_helper: &[u8]) -> Vec<u8> {
559    let len = c_helper.len();
560    let mut companion_c = vec![0u8; len];
561    let gamma_inv = crate::transforms::gf_inv(GAMMA);
562
563    for i in 0..len {
564        companion_c[i] = gf_mul(gf_add(u_helper[i], c_helper[i]), gamma_inv);
565    }
566
567    companion_c
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical (k=4, m=2) Clay layout used by the unit tests below.
    fn test_params() -> DecodeParams {
        DecodeParams {
            k: 4,
            m: 2,
            n: 6,
            q: 2,
            t: 3,
            nu: 0,
            sub_chunk_no: 8,
            original_count: 4,
            recovery_count: 2,
        }
    }

    #[test]
    fn test_companion_layer_valid_range() {
        let params = test_params();

        // Exhaustively sweep every (z, y, x) coordinate and confirm the
        // companion layer index stays inside [0, α).
        for z in 0..params.sub_chunk_no {
            let z_vec = get_plane_vector(z, params.t, params.q);
            for y in 0..params.t {
                for x in 0..params.q {
                    let z_sw = get_companion_layer(&params, z, x, y, z_vec[y]);
                    assert!(
                        z_sw < params.sub_chunk_no,
                        "z_sw {} out of range for z={}, x={}, y={}",
                        z_sw,
                        z,
                        x,
                        y
                    );
                }
            }
        }
    }

    #[test]
    fn test_get_max_iscore() {
        let params = test_params();
        let erased = |ids: &[usize]| ids.iter().copied().collect::<BTreeSet<usize>>();

        // No erasures
        assert_eq!(get_max_iscore(&params, &erased(&[])), 0);

        // One erasure
        assert_eq!(get_max_iscore(&params, &erased(&[0])), 1);

        // Two erasures in the same y-section still intersect one row.
        assert_eq!(get_max_iscore(&params, &erased(&[0, 1])), 1);

        // Two erasures in different y-sections intersect two rows.
        assert_eq!(get_max_iscore(&params, &erased(&[0, 2])), 2);
    }
}
635}