cyfs_raptorq/
pi_solver.rs

1use crate::arraymap::{BoolArrayMap, UndirectedGraph};
2use crate::arraymap::{U16ArrayMap, U32VecMap};
3use crate::matrix::BinaryMatrix;
4use crate::octet::Octet;
5use crate::octet_matrix::DenseOctetMatrix;
6use crate::operation_vector::SymbolOps;
7use crate::symbol::Symbol;
8use crate::systematic_constants::num_hdpc_symbols;
9use crate::systematic_constants::num_intermediate_symbols;
10use crate::systematic_constants::num_ldpc_symbols;
11use crate::systematic_constants::num_pi_symbols;
12use crate::util::get_both_indices;
13use std::mem::size_of;
14
// Incrementally maintained per-row statistics used to pick a row during the first phase.
// All counts refer to the number of ones a row has inside the tracked column window
// [start_col, end_col).
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
struct FirstPhaseRowSelectionStats {
    // Snapshot of `ones_per_row` taken once at construction, i.e. the degree of each row
    // before any processing. Kept in step with row swaps.
    original_degree: U16ArrayMap,
    // Current number of ones of each row within the tracked column window
    ones_per_row: U16ArrayMap,
    // Number of tracked rows for every ones count (indexed by the ones count)
    ones_histogram: U32VecMap,
    // First column (inclusive) of the tracked window
    start_col: usize,
    // Last column (exclusive) of the tracked window
    end_col: usize,
    // First row that is still tracked; rows before it have been retired by `resize`
    start_row: usize,
    // Rows whose current ones count is exactly one; used as a fast path when r == 1
    rows_with_single_one: Vec<usize>,
}
25
26impl FirstPhaseRowSelectionStats {
27    #[inline(never)]
28    #[allow(non_snake_case)]
29    pub fn new<T: BinaryMatrix>(matrix: &T, end_col: usize) -> FirstPhaseRowSelectionStats {
30        let mut result = FirstPhaseRowSelectionStats {
31            original_degree: U16ArrayMap::new(0, 0),
32            ones_per_row: U16ArrayMap::new(0, matrix.height()),
33            ones_histogram: U32VecMap::new(0),
34            start_col: 0,
35            end_col,
36            start_row: 0,
37            rows_with_single_one: vec![],
38        };
39
40        for row in 0..matrix.height() {
41            let ones = matrix.count_ones(row, 0, end_col);
42            result.ones_per_row.insert(row, ones as u16);
43            result.ones_histogram.increment(ones);
44            if ones == 1 {
45                result.rows_with_single_one.push(row);
46            }
47        }
48        // Original degree is the degree of each row before processing begins
49        result.original_degree = result.ones_per_row.clone();
50
51        result
52    }
53
54    #[allow(dead_code)]
55    pub fn size_in_bytes(&self) -> usize {
56        let mut bytes = size_of::<Self>();
57
58        bytes += self.original_degree.size_in_bytes();
59        bytes += self.ones_per_row.size_in_bytes();
60        bytes += self.ones_histogram.size_in_bytes();
61
62        bytes
63    }
64
65    pub fn swap_rows(&mut self, i: usize, j: usize) {
66        self.ones_per_row.swap(i, j);
67        self.original_degree.swap(i, j);
68        for row in self.rows_with_single_one.iter_mut() {
69            if *row == i {
70                *row = j;
71            } else if *row == j {
72                *row = i;
73            }
74        }
75    }
76
77    // Recompute all stored statistics for the given row
78    pub fn recompute_row<T: BinaryMatrix>(&mut self, row: usize, matrix: &T) {
79        let ones = matrix.count_ones(row, self.start_col, self.end_col);
80        self.rows_with_single_one.retain(|x| *x != row);
81        if ones == 1 {
82            self.rows_with_single_one.push(row);
83        }
84        self.ones_histogram
85            .decrement(self.ones_per_row.get(row) as usize);
86        self.ones_histogram.increment(ones);
87        self.ones_per_row.insert(row, ones as u16);
88    }
89
90    pub fn eliminate_leading_value(&mut self, row: usize, value: &Octet) {
91        debug_assert_ne!(*value, Octet::zero());
92        debug_assert_eq!(*value, Octet::one());
93        self.ones_per_row.decrement(row);
94        let ones = self.ones_per_row.get(row);
95        if ones == 0 {
96            self.rows_with_single_one.retain(|x| *x != row);
97        } else if ones == 1 {
98            self.rows_with_single_one.push(row);
99        }
100        self.ones_histogram.decrement((ones + 1) as usize);
101        self.ones_histogram.increment(ones as usize);
102    }
103
104    // Set the valid columns, and recalculate statistics
105    // All values in column "start_col - 1" in rows start_row..end_row must be zero
106    #[inline(never)]
107    pub fn resize<T: BinaryMatrix>(
108        &mut self,
109        start_row: usize,
110        end_row: usize,
111        start_col: usize,
112        end_col: usize,
113        matrix: &T,
114    ) {
115        // Only shrinking is supported
116        assert!(end_col <= self.end_col);
117        assert_eq!(self.start_row, start_row - 1);
118        assert_eq!(self.start_col, start_col - 1);
119
120        self.ones_histogram
121            .decrement(self.ones_per_row.get(self.start_row) as usize);
122        self.rows_with_single_one.retain(|x| *x != start_row - 1);
123
124        for col in end_col..self.end_col {
125            for row in matrix.get_ones_in_column(col, start_row, end_row) {
126                let row = row as usize;
127                self.ones_per_row.decrement(row);
128                let ones = self.ones_per_row.get(row);
129                if ones == 0 {
130                    self.rows_with_single_one.retain(|x| *x != row);
131                } else if ones == 1 {
132                    self.rows_with_single_one.push(row);
133                }
134                self.ones_histogram.decrement((ones + 1) as usize);
135                self.ones_histogram.increment(ones as usize);
136            }
137        }
138
139        self.start_col = start_col;
140        self.end_col = end_col;
141        self.start_row = start_row;
142    }
143
144    #[inline(never)]
145    fn first_phase_graph_substep_build_adjacency<T: BinaryMatrix>(
146        &self,
147        start_row: usize,
148        end_row: usize,
149        matrix: &T,
150    ) -> UndirectedGraph {
151        let mut graph = UndirectedGraph::with_capacity(
152            self.start_col as u16,
153            self.end_col as u16,
154            end_row - start_row,
155        );
156
157        for row in start_row..end_row {
158            if self.ones_per_row.get(row) != 2 {
159                continue;
160            }
161            let mut ones = [0; 2];
162            let mut found = 0;
163            for (col, value) in matrix.get_row_iter(row, self.start_col, self.end_col) {
164                // "The following graph defined by the structure of V is used in determining which
165                // row of A is chosen. The columns that intersect V are the nodes in the graph,
166                // and the rows that have exactly 2 nonzero entries in V and are not HDPC rows
167                // are the edges of the graph that connect the two columns (nodes) in the positions
168                // of the two ones."
169                // This part of the matrix is over GF(2), so "nonzero entries" is equivalent to "ones"
170                if value == Octet::one() {
171                    ones[found] = col;
172                    found += 1;
173                }
174                if found == 2 {
175                    break;
176                }
177            }
178            assert_eq!(found, 2);
179            graph.add_edge(ones[0] as u16, ones[1] as u16);
180        }
181        graph.build();
182        return graph;
183    }
184
185    #[inline(never)]
186    fn first_phase_graph_substep<T: BinaryMatrix>(
187        &self,
188        start_row: usize,
189        end_row: usize,
190        matrix: &T,
191    ) -> usize {
192        let graph = self.first_phase_graph_substep_build_adjacency(start_row, end_row, matrix);
193        let mut visited = BoolArrayMap::new(start_row, end_row);
194
195        let mut examplar_largest_component_node = None;
196        let mut largest_component_size = 0;
197
198        let mut node_queue = Vec::with_capacity(10);
199        for key in graph.nodes() {
200            let mut component_size = 0;
201            // We can choose any edge (row) that connects this col to another in the graph
202            let mut examplar_node = None;
203            // Pick arbitrary node (column) to start
204            node_queue.clear();
205            node_queue.push(key);
206            while !node_queue.is_empty() {
207                let node = node_queue.pop().unwrap();
208                if visited.get(node as usize) {
209                    continue;
210                }
211                visited.insert(node as usize, true);
212                component_size += 1;
213                for next_node in graph.get_adjacent_nodes(node) {
214                    node_queue.push(next_node);
215                    examplar_node = Some(node);
216                }
217            }
218
219            if component_size > largest_component_size {
220                examplar_largest_component_node = examplar_node;
221                largest_component_size = component_size;
222            }
223        }
224
225        let node = examplar_largest_component_node.unwrap();
226        for row in matrix.get_ones_in_column(node as usize, start_row, end_row) {
227            let row = row as usize;
228            if self.ones_per_row.get(row) == 2 {
229                return row;
230            }
231        }
232        unreachable!();
233    }
234
235    #[inline(never)]
236    fn first_phase_original_degree_substep(
237        &self,
238        start_row: usize,
239        end_row: usize,
240        r: usize,
241    ) -> usize {
242        // There's no need for special handling of HDPC rows, since Errata 2 guarantees we won't
243        // select any, and they're excluded in the first_phase solver
244        let mut chosen = None;
245        let mut chosen_original_degree = std::u16::MAX;
246        // Fast path for r=1, since this is super common
247        if r == 1 {
248            assert_ne!(0, self.rows_with_single_one.len());
249            for &row in self.rows_with_single_one.iter() {
250                let ones = self.ones_per_row.get(row);
251                let row_original_degree = self.original_degree.get(row);
252                if ones as usize == r && row_original_degree < chosen_original_degree {
253                    chosen = Some(row);
254                    chosen_original_degree = row_original_degree;
255                }
256            }
257        } else {
258            for row in start_row..end_row {
259                let ones = self.ones_per_row.get(row);
260                let row_original_degree = self.original_degree.get(row);
261                if ones as usize == r && row_original_degree < chosen_original_degree {
262                    chosen = Some(row);
263                    chosen_original_degree = row_original_degree;
264                }
265            }
266        }
267        return chosen.unwrap();
268    }
269
270    // Verify there there are no non-HPDC rows with exactly two non-zero entries, greater than one
271    #[inline(never)]
272    #[cfg(debug_assertions)]
273    fn first_phase_graph_substep_verify(&self, start_row: usize, end_row: usize) {
274        for row in start_row..end_row {
275            if self.ones_per_row.get(row) == 2 {
276                return;
277            }
278        }
279        unreachable!("A row with 2 ones must exist given Errata 8");
280    }
281
282    // Helper method for decoder phase 1
283    // selects from [start_row, end_row) reading [start_col, end_col)
284    // Returns (the chosen row, and "r" number of non-zero values the row has)
285    pub fn first_phase_selection<T: BinaryMatrix>(
286        &self,
287        start_row: usize,
288        end_row: usize,
289        matrix: &T,
290    ) -> (Option<usize>, Option<usize>) {
291        let mut r = None;
292        for i in 1..=(self.end_col - self.start_col) {
293            if self.ones_histogram.get(i) > 0 {
294                r = Some(i);
295                break;
296            }
297        }
298
299        if r == None {
300            return (None, None);
301        }
302
303        if r.unwrap() == 2 {
304            // Paragraph starting "If r = 2 and there is no row with exactly 2 ones in V" can
305            // be ignored due to Errata 8.
306
307            // See paragraph starting "If r = 2 and there is a row with exactly 2 ones in V..."
308            #[cfg(debug_assertions)]
309            self.first_phase_graph_substep_verify(start_row, end_row);
310            let row = self.first_phase_graph_substep(start_row, end_row, matrix);
311            return (Some(row), r);
312        } else {
313            let row = self.first_phase_original_degree_substep(start_row, end_row, r.unwrap());
314            return (Some(row), r);
315        }
316    }
317}
318
// State of the intermediate symbol decoding process, which runs in five phases.
// See section 5.4.2.1
#[allow(non_snake_case)]
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct IntermediateSymbolDecoder<T: BinaryMatrix> {
    // Working copy of the constraint matrix (column access acceleration is enabled on it)
    A: T,
    // If present, these are treated as replacing the last rows of A
    // Errata 3 guarantees that these do not need to be included in X
    A_hdpc_rows: Option<DenseOctetMatrix>,
    // Second copy of the input matrix with the PI columns dropped; receives the same row
    // and column swaps as A during the first phase
    X: T,
    // The symbols being decoded, one per row of the matrix
    D: Vec<Symbol>,
    // Initialised to 0..width; presumably the column permutation - TODO confirm
    c: Vec<usize>,
    // Initialised to 0..number of symbols; used to map a matrix row to an index into D
    d: Vec<usize>,
    // First-phase progress counter, starts at zero (see Figure 6 in first_phase)
    i: usize,
    // Width of the U submatrix, starts at the number of PI symbols
    u: usize,
    // Number of intermediate symbols
    L: usize,
    // Operations on D are deferred to the end of the codec to improve cache hits
    deferred_D_ops: Vec<SymbolOps>,
    num_source_symbols: u32,
    // Running totals of symbol multiplications / additions recorded so far
    debug_symbol_mul_ops: u32,
    debug_symbol_add_ops: u32,
    // Per-phase share of the totals above, one slot for each of the five phases
    debug_symbol_mul_ops_by_phase: Vec<u32>,
    debug_symbol_add_ops_by_phase: Vec<u32>,
}
342
343#[allow(non_snake_case)]
344impl<T: BinaryMatrix> IntermediateSymbolDecoder<T> {
345    pub fn new(
346        matrix: T,
347        hdpc_rows: DenseOctetMatrix,
348        symbols: Vec<Symbol>,
349        num_source_symbols: u32,
350    ) -> IntermediateSymbolDecoder<T> {
351        assert!(matrix.width() <= symbols.len());
352        assert_eq!(matrix.height(), symbols.len());
353        let mut c = Vec::with_capacity(matrix.width());
354        let mut d = Vec::with_capacity(symbols.len());
355        for i in 0..matrix.width() {
356            c.push(i);
357        }
358        for i in 0..symbols.len() {
359            d.push(i);
360        }
361
362        let intermediate_symbols = num_intermediate_symbols(num_source_symbols) as usize;
363
364        let num_rows = matrix.height();
365
366        let pi_symbols = num_pi_symbols(num_source_symbols) as usize;
367        let mut A = matrix.clone();
368        A.enable_column_acccess_acceleration();
369        let mut X = matrix;
370        // Drop the PI symbols, since they will never be accessed in X. X will be resized to
371        // i-by-i in the second phase.
372        X.resize(X.height(), X.width() - pi_symbols);
373
374        let mut temp = IntermediateSymbolDecoder {
375            A,
376            A_hdpc_rows: None,
377            X,
378            D: symbols,
379            c,
380            d,
381            i: 0,
382            u: pi_symbols,
383            L: intermediate_symbols,
384            deferred_D_ops: Vec::with_capacity(70 * intermediate_symbols),
385            num_source_symbols,
386            debug_symbol_mul_ops: 0,
387            debug_symbol_add_ops: 0,
388            debug_symbol_mul_ops_by_phase: vec![0; 5],
389            debug_symbol_add_ops_by_phase: vec![0; 5],
390        };
391
392        // Swap the HDPC rows, so that they're the last in the matrix
393        let S = num_ldpc_symbols(num_source_symbols) as usize;
394        let H = num_hdpc_symbols(num_source_symbols) as usize;
395        // See section 5.3.3.4.2, Figure 5.
396        for i in 0..H {
397            temp.swap_rows(S + i, num_rows - H + i);
398            temp.X.swap_rows(S + i, num_rows - H + i);
399        }
400
401        temp.A_hdpc_rows = Some(hdpc_rows);
402
403        temp
404    }
405
406    #[inline(never)]
407    fn apply_deferred_symbol_ops(&mut self) {
408        for op in self.deferred_D_ops.iter() {
409            match op {
410                SymbolOps::AddAssign { dest, src } => {
411                    let (dest, temp) = get_both_indices(&mut self.D, *dest, *src);
412                    *dest += temp;
413                }
414                SymbolOps::MulAssign { dest, scalar } => {
415                    self.D[*dest].mulassign_scalar(scalar);
416                }
417                SymbolOps::FMA { dest, src, scalar } => {
418                    let (dest, temp) = get_both_indices(&mut self.D, *dest, *src);
419                    dest.fused_addassign_mul_scalar(&temp, scalar);
420                }
421                SymbolOps::Reorder { order: _order } => {}
422            }
423        }
424    }
425
426    // Returns true iff all elements in A between [start_row, end_row)
427    // and [start_column, end_column) are zero
428    #[cfg(debug_assertions)]
429    fn all_zeroes(
430        &self,
431        start_row: usize,
432        end_row: usize,
433        start_column: usize,
434        end_column: usize,
435    ) -> bool {
436        for row in start_row..end_row {
437            for column in start_column..end_column {
438                if self.get_A_value(row, column) != Octet::zero() {
439                    return false;
440                }
441            }
442        }
443        return true;
444    }
445
446    #[cfg(debug_assertions)]
447    fn get_A_value(&self, row: usize, col: usize) -> Octet {
448        if let Some(ref hdpc) = self.A_hdpc_rows {
449            if row >= self.A.height() - hdpc.height() {
450                return hdpc.get(row - (self.A.height() - hdpc.height()), col);
451            }
452        }
453        return self.A.get(row, col);
454    }
455
    // Performs the column swapping substep of first phase, after the row has been chosen
    // (the chosen row is expected to already sit at row self.i).
    // `r` is the number of nonzeros the chosen row has in V. One of them is moved to the
    // diagonal position (column self.i); the remaining r - 1 are moved to the right-most
    // columns of V. Every swap is mirrored in X.
    // Panics if the number of placed columns does not end up equal to `r`.
    #[inline(never)]
    fn first_phase_swap_columns_substep(&mut self, r: usize) {
        // Number of nonzeros of the chosen row that are already in their final position
        let mut swapped_columns = 0;
        // Fast path when r == 1, since this is very common
        if r == 1 {
            // self.i will never reference an HDPC row, so can ignore self.A_hdpc_rows
            // because of Errata 2.
            for (col, value) in self
                .A
                .get_row_iter(self.i, self.i, self.A.width() - self.u)
                .clone()
            {
                if value != Octet::zero() {
                    // No need to swap the first i rows, as they are all zero (see submatrix above V)
                    self.swap_columns(self.i, col, self.i);
                    // Also apply to X
                    self.X.swap_columns(self.i, col, 0);
                    swapped_columns += 1;
                    // The single nonzero is now on the diagonal; nothing else to move
                    break;
                }
            }
        } else {
            // Scan the columns of V left to right looking for the row's nonzeros
            for col in self.i..(self.A.width() - self.u) {
                // self.i will never reference an HDPC row, so can ignore self.A_hdpc_rows
                // because of Errata 2.
                if self.A.get(self.i, col) != Octet::zero() {
                    let mut dest;
                    if swapped_columns == 0 {
                        // The first nonzero found goes to the diagonal
                        dest = self.i;
                    } else {
                        // Later nonzeros fill the right edge of V, moving leftwards
                        dest = self.A.width() - self.u - swapped_columns;
                        // Some of the right most columns may already contain non-zeros
                        // Those are already in place, so count them and skip over them
                        while self.A.get(self.i, dest) != Octet::zero() {
                            dest -= 1;
                            swapped_columns += 1;
                        }
                    }
                    // Skipping in-place columns may have completed the job already
                    if swapped_columns == r {
                        break;
                    }
                    // No need to swap the first i rows, as they are all zero (see submatrix above V)
                    self.swap_columns(dest, col, self.i);
                    // Also apply to X
                    self.X.swap_columns(dest, col, 0);
                    swapped_columns += 1;
                    if swapped_columns == r {
                        break;
                    }
                }
            }
        }
        assert_eq!(r, swapped_columns);
    }
510
    // First phase (section 5.4.2.2)
    // Repeatedly selects a row of V, moves it to position i, rearranges the columns so its
    // leading nonzero sits on the diagonal and the rest are in U, and eliminates that column
    // from all rows below (including the separately stored HDPC rows).
    // Returns false if no row with a nonzero in V can be found, true once i + u reaches L.
    // Panics if the HDPC rows have already been taken out of self.A_hdpc_rows.
    #[allow(non_snake_case)]
    #[inline(never)]
    fn first_phase(&mut self) -> bool {
        // First phase (section 5.4.2.2)

        //    ----------> i                 u <--------
        //  | +-----------+-----------------+---------+
        //  | |           |                 |         |
        //  | |     I     |    All Zeros    |         |
        //  v |           |                 |         |
        //  i +-----------+-----------------+    U    |
        //    |           |                 |         |
        //    |           |                 |         |
        //    | All Zeros |       V         |         |
        //    |           |                 |         |
        //    |           |                 |         |
        //    +-----------+-----------------+---------+
        // Figure 6: Submatrices of A in the First Phase

        let num_hdpc_rows = self.A_hdpc_rows.as_ref().unwrap().height();

        // Tracks the number of ones per row within V so that selection does not need to
        // rescan the matrix on every iteration
        let mut selection_helper =
            FirstPhaseRowSelectionStats::new(&self.A, self.A.width() - self.u);

        while self.i + self.u < self.L {
            // Calculate r
            // "Let r be the minimum integer such that at least one row of A has
            // exactly r nonzeros in V."
            // Exclude the HDPC rows, since Errata 2 guarantees they won't be chosen.
            let (chosen_row, r) = selection_helper.first_phase_selection(
                self.i,
                self.A.height() - num_hdpc_rows,
                &self.A,
            );

            // No row has a nonzero in V: the first phase cannot make progress
            if r == None {
                return false;
            }
            let r = r.unwrap();
            let chosen_row = chosen_row.unwrap();
            assert!(chosen_row >= self.i);

            // See paragraph beginning: "After the row is chosen in this step..."
            // Reorder rows
            let temp = self.i;
            self.swap_rows(temp, chosen_row);
            self.X.swap_rows(temp, chosen_row);
            selection_helper.swap_rows(temp, chosen_row);
            // Reorder columns
            self.first_phase_swap_columns_substep(r);
            // Zero out leading value in following rows
            let temp = self.i;
            // self.i will never reference an HDPC row, so can ignore self.A_hdpc_rows
            // because of Errata 2.
            let temp_value = self.A.get(temp, temp);

            // The r - 1 right-most columns of V are about to become part of U
            for i in 0..(r - 1) {
                self.A
                    .hint_column_dense_and_frozen(self.A.width() - self.u - 1 - i);
            }
            // Shrink the tracked window: drop row i and column i at the front, and the
            // r - 1 columns that moved into U at the back
            selection_helper.resize(
                self.i + 1,
                self.A.height() - self.A_hdpc_rows.as_ref().unwrap().height(),
                self.i + 1,
                self.A.width() - self.u - (r - 1),
                &self.A,
            );

            // Cloning the iterator is safe here, because we don't re-read any of the rows that
            // we add to
            // NOTE(review): no iterator clone is visible at this call site; the comment
            // above may be stale.
            for row in self
                .A
                .get_ones_in_column(temp, self.i + 1, self.A.height() - num_hdpc_rows)
            {
                let row = row as usize;
                assert_eq!(&temp_value, &Octet::one());
                // Addition is equivalent to subtraction.
                self.fma_rows(temp, row, Octet::one());
                if r == 1 {
                    // Hot path for r == 1, since it's very common due to maximum connected
                    // component selection, and recompute_row() is expensive
                    selection_helper.eliminate_leading_value(row, &Octet::one());
                } else {
                    selection_helper.recompute_row(row, &self.A);
                }
            }

            // apply to hdpc rows as well, which are stored separately
            // The part of the pivot row starting at the first column of the enlarged U
            let pi_octets = self
                .A
                .get_sub_row_as_octets(temp, self.A.width() - (self.u + r - 1));
            for row in 0..num_hdpc_rows {
                let leading_value = self.A_hdpc_rows.as_ref().unwrap().get(row, temp);
                if leading_value != Octet::zero() {
                    // Addition is equivalent to subtraction
                    let beta = &leading_value / &temp_value;
                    self.fma_rows_with_pi(
                        temp,
                        row + (self.A.height() - num_hdpc_rows),
                        beta,
                        // self.i is the only non-PI column which can have a nonzero,
                        // since all the rest were column swapped into the PI submatrix.
                        Some(temp),
                        Some(&pi_octets),
                    );
                    // It's safe to skip updating the selection helper, since it will never
                    // select an HDPC row
                }
            }

            // One more row/column joins the identity block; r - 1 columns joined U
            self.i += 1;
            self.u += r - 1;
            #[cfg(debug_assertions)]
            self.first_phase_verify();
        }

        self.record_symbol_ops(0);
        return true;
    }
631
632    // See section 5.4.2.2. Verifies the two all-zeros submatrices and the identity submatrix
633    #[inline(never)]
634    #[cfg(debug_assertions)]
635    fn first_phase_verify(&self) {
636        for row in 0..self.i {
637            for col in 0..self.i {
638                if row == col {
639                    assert_eq!(Octet::one(), self.A.get(row, col));
640                } else {
641                    assert_eq!(Octet::zero(), self.A.get(row, col));
642                }
643            }
644        }
645        assert!(self.all_zeroes(0, self.i, self.i, self.A.width() - self.u));
646        assert!(self.all_zeroes(self.i, self.A.height(), 0, self.i));
647    }
648
649    // Second phase (section 5.4.2.3)
650    #[allow(non_snake_case)]
651    #[inline(never)]
652    fn second_phase(&mut self) -> bool {
653        #[cfg(debug_assertions)]
654        self.second_phase_verify();
655
656        self.X.resize(self.i, self.i);
657
658        // Convert U_lower to row echelon form
659        let temp = self.i;
660        let size = self.u;
661        // HDPC rows can be removed, since they can't have been selected for U_upper
662        let hdpc_rows = self.A_hdpc_rows.take().unwrap();
663        if let Some(submatrix) = self.record_reduce_to_row_echelon(hdpc_rows, temp, temp, size) {
664            // Perform backwards elimination
665            self.backwards_elimination(submatrix, temp, temp, size);
666        } else {
667            return false;
668        }
669
670        self.A.resize(self.L, self.L);
671
672        self.record_symbol_ops(1);
673        return true;
674    }
675
676    // Verifies that X is lower triangular. See section 5.4.2.3
677    #[inline(never)]
678    #[cfg(debug_assertions)]
679    fn second_phase_verify(&self) {
680        for row in 0..self.i {
681            for col in (row + 1)..self.i {
682                assert_eq!(Octet::zero(), self.X.get(row, col));
683            }
684        }
685    }
686
687    // Third phase (section 5.4.2.4)
688    #[allow(non_snake_case)]
689    #[inline(never)]
690    fn third_phase(&mut self) {
691        #[cfg(debug_assertions)]
692        self.third_phase_verify();
693
694        // A[0..i][..] = X * A[0..i][..]
695        self.A.mul_assign_submatrix(&self.X, self.i);
696
697        // Now apply the same operations to D.
698        // Note that X is lower triangular, so the row must be processed last to first
699        for row in (0..self.i).rev() {
700            if self.X.get(row, row) != Octet::one() {
701                self.debug_symbol_mul_ops += 1;
702                self.deferred_D_ops.push(SymbolOps::MulAssign {
703                    dest: self.d[row],
704                    scalar: self.X.get(row, row),
705                });
706            }
707
708            for (col, value) in self.X.get_row_iter(row, 0, row) {
709                if value == Octet::zero() {
710                    continue;
711                }
712                if value == Octet::one() {
713                    self.debug_symbol_add_ops += 1;
714                    self.deferred_D_ops.push(SymbolOps::AddAssign {
715                        dest: self.d[row],
716                        src: self.d[col],
717                    });
718                } else {
719                    self.debug_symbol_mul_ops += 1;
720                    self.debug_symbol_add_ops += 1;
721                    self.deferred_D_ops.push(SymbolOps::FMA {
722                        dest: self.d[row],
723                        src: self.d[col],
724                        scalar: self.X.get(row, col),
725                    });
726                }
727            }
728        }
729
730        self.record_symbol_ops(2);
731
732        #[cfg(debug_assertions)]
733        self.third_phase_verify_end();
734    }
735
736    #[inline(never)]
737    #[cfg(debug_assertions)]
738    fn third_phase_verify(&self) {
739        for row in 0..self.A.height() {
740            for col in 0..self.A.width() {
741                if row < self.i && col >= self.A.width() - self.u {
742                    // element is in U_upper, which can have arbitrary values at this point
743                    continue;
744                }
745                // The rest of A should be identity matrix
746                if row == col {
747                    assert_eq!(Octet::one(), self.A.get(row, col));
748                } else {
749                    assert_eq!(Octet::zero(), self.A.get(row, col));
750                }
751            }
752        }
753    }
754
755    #[inline(never)]
756    #[cfg(debug_assertions)]
757    fn third_phase_verify_end(&self) {
758        for row in 0..self.i {
759            for col in 0..self.i {
760                assert_eq!(self.X.get(row, col), self.A.get(row, col));
761            }
762        }
763    }
764
765    // Fourth phase (section 5.4.2.5)
766    #[allow(non_snake_case)]
767    #[inline(never)]
768    fn fourth_phase(&mut self) {
769        for i in 0..self.i {
770            for j in 0..self.u {
771                let b = self.A.get(i, j + self.i);
772                if b != Octet::zero() {
773                    let temp = self.i;
774                    self.fma_rows(temp + j, i, b);
775                }
776            }
777        }
778
779        self.record_symbol_ops(3);
780
781        #[cfg(debug_assertions)]
782        self.fourth_phase_verify();
783    }
784
785    #[inline(never)]
786    #[cfg(debug_assertions)]
787    fn fourth_phase_verify(&self) {
788        //    ---------> i u <------
789        //  | +-----------+--------+
790        //  | |\          |        |
791        //  | |  \ Zeros  | Zeros  |
792        //  v |     \     |        |
793        //  i |  X     \  |        |
794        //  u +---------- +--------+
795        //  ^ |           |        |
796        //  | | All Zeros |   I    |
797        //  | |           |        |
798        //    +-----------+--------+
799        // Same assertion about X being equal to the upper left of A
800        self.third_phase_verify_end();
801        assert!(self.all_zeroes(0, self.i, self.A.width() - self.u, self.A.width()));
802        assert!(self.all_zeroes(self.A.height() - self.u, self.A.height(), 0, self.i));
803        for row in (self.A.height() - self.u)..self.A.height() {
804            for col in (self.A.width() - self.u)..self.A.width() {
805                if row == col {
806                    assert_eq!(Octet::one(), self.A.get(row, col));
807                } else {
808                    assert_eq!(Octet::zero(), self.A.get(row, col));
809                }
810            }
811        }
812    }
813
814    // Fifth phase (section 5.4.2.6)
815    #[allow(non_snake_case)]
816    #[inline(never)]
817    fn fifth_phase(&mut self) {
818        // "For j from 1 to i". Note that A is 1-indexed in the spec, and ranges are inclusive,
819        // this is means [1, i], which is equal to [0, i)
820        for j in 0..self.i as usize {
821            // Skip normalizing the diagonal, since there can't be non-binary values due to
822            // Errata 7
823
824            // "For l from 1 to j-1". This means the lower triangular columns, not including the
825            // diagonal, which is [0, j)
826            for (l, _) in self.A.get_row_iter(j, 0, j).clone() {
827                let temp = self.A.get(j, l);
828                if temp != Octet::zero() {
829                    self.fma_rows(l, j, temp);
830                }
831            }
832        }
833
834        self.record_symbol_ops(4);
835
836        #[cfg(debug_assertions)]
837        self.fifth_phase_verify();
838    }
839
840    #[inline(never)]
841    #[cfg(debug_assertions)]
842    fn fifth_phase_verify(&self) {
843        assert_eq!(self.L, self.A.height());
844        for row in 0..self.A.height() {
845            assert_eq!(self.L, self.A.width());
846            for col in 0..self.A.width() {
847                if row == col {
848                    assert_eq!(Octet::one(), self.A.get(row, col));
849                } else {
850                    assert_eq!(Octet::zero(), self.A.get(row, col));
851                }
852            }
853        }
854    }
855
856    fn record_symbol_ops(&mut self, phase: usize) {
857        self.debug_symbol_add_ops_by_phase[phase] = self.debug_symbol_add_ops;
858        self.debug_symbol_mul_ops_by_phase[phase] = self.debug_symbol_mul_ops;
859        for i in 0..phase {
860            self.debug_symbol_add_ops_by_phase[phase] -= self.debug_symbol_add_ops_by_phase[i];
861            self.debug_symbol_mul_ops_by_phase[phase] -= self.debug_symbol_mul_ops_by_phase[i];
862        }
863    }
864
865    // Reduces the size x size submatrix, starting at row_offset and col_offset as the upper left
866    // corner, to row echelon form.
867    // Returns the reduced submatrix, which should be written back into this submatrix of A.
868    // The state of this submatrix in A is undefined, after calling this function.
869    #[inline(never)]
870    fn record_reduce_to_row_echelon(
871        &mut self,
872        hdpc_rows: DenseOctetMatrix,
873        row_offset: usize,
874        col_offset: usize,
875        size: usize,
876    ) -> Option<DenseOctetMatrix> {
877        // Copy U_lower into a new matrix and merge it with the HDPC rows
878        let mut submatrix = DenseOctetMatrix::new(self.A.height() - row_offset, size, 0);
879        let first_hdpc_row = self.A.height() - hdpc_rows.height();
880        for row in row_offset..self.A.height() {
881            for col in col_offset..(col_offset + size) {
882                let value = if row < first_hdpc_row {
883                    self.A.get(row, col)
884                } else {
885                    hdpc_rows.get(row - first_hdpc_row, col)
886                };
887                submatrix.set(row - row_offset, col - col_offset, value);
888            }
889        }
890
891        for i in 0..size {
892            // Swap a row with leading coefficient i into place
893            for j in i..submatrix.height() {
894                if submatrix.get(j, i) != Octet::zero() {
895                    submatrix.swap_rows(i, j);
896                    // Record the swap, in addition to swapping in the working submatrix
897                    // TODO: optimize to not perform op on A
898                    self.swap_rows(row_offset + i, j + row_offset);
899                    break;
900                }
901            }
902
903            if submatrix.get(i, i) == Octet::zero() {
904                // If all following rows are zero in this column, then matrix is singular
905                return None;
906            }
907
908            // Scale leading coefficient to 1
909            if submatrix.get(i, i) != Octet::one() {
910                let element_inverse = Octet::one() / submatrix.get(i, i);
911                submatrix.mul_assign_row(i, &element_inverse);
912                // Record the multiplication, in addition to multiplying the working submatrix
913                self.record_mul_row(row_offset + i, element_inverse);
914            }
915
916            // Zero out all following elements in i'th column
917            for j in (i + 1)..submatrix.height() {
918                if submatrix.get(j, i) != Octet::zero() {
919                    let scalar = submatrix.get(j, i);
920                    submatrix.fma_rows(j, i, &scalar);
921                    // Record the FMA, in addition to applying it to the working submatrix
922                    self.record_fma_rows(row_offset + i, row_offset + j, scalar);
923                }
924            }
925        }
926
927        return Some(submatrix);
928    }
929
930    // Performs backwards elimination in a size x size submatrix, starting at
931    // row_offset and col_offset as the upper left corner of the submatrix
932    //
933    // Applies the submatrix to the size-by-size lower right of A, and performs backwards
934    // elimination on it. "submatrix" must be in row echelon form.
935    #[inline(never)]
936    fn backwards_elimination(
937        &mut self,
938        submatrix: DenseOctetMatrix,
939        row_offset: usize,
940        col_offset: usize,
941        size: usize,
942    ) {
943        // Perform backwards elimination
944        for i in (0..size).rev() {
945            // Zero out all preceding elements in i'th column
946            for j in 0..i {
947                if submatrix.get(j, i) != Octet::zero() {
948                    let scalar = submatrix.get(j, i);
949                    // Record the FMA. No need to actually apply it to the submatrix,
950                    // since it will be discarded, and we never read these values
951                    self.record_fma_rows(row_offset + i, row_offset + j, scalar);
952                }
953            }
954        }
955
956        // Write the identity matrix into A, since that's the resulting value of this function
957        for row in row_offset..(row_offset + size) {
958            for col in col_offset..(col_offset + size) {
959                if row == col {
960                    self.A.set(row, col, Octet::one());
961                } else {
962                    self.A.set(row, col, Octet::zero());
963                }
964            }
965        }
966    }
967
968    #[allow(dead_code)]
969    pub fn get_symbol_mul_ops(&self) -> u32 {
970        self.debug_symbol_mul_ops
971    }
972
973    #[allow(dead_code)]
974    pub fn get_symbol_add_ops(&self) -> u32 {
975        self.debug_symbol_add_ops
976    }
977
978    #[allow(dead_code)]
979    pub fn get_symbol_mul_ops_by_phase(&self) -> Vec<u32> {
980        self.debug_symbol_mul_ops_by_phase.clone()
981    }
982
983    #[allow(dead_code)]
984    pub fn get_symbol_add_ops_by_phase(&self) -> Vec<u32> {
985        self.debug_symbol_add_ops_by_phase.clone()
986    }
987
988    #[cfg(feature = "benchmarking")]
989    pub fn get_non_symbol_bytes(&self) -> usize {
990        let mut bytes = size_of::<Self>();
991
992        bytes += self.A.size_in_bytes();
993        if let Some(ref hdpc) = self.A_hdpc_rows {
994            bytes += hdpc.size_in_bytes();
995        }
996        bytes += self.X.size_in_bytes();
997        // Skip self.D, since we're calculating non-Symbol bytes
998        bytes += size_of::<usize>() * self.c.len();
999        bytes += size_of::<usize>() * self.d.len();
1000
1001        bytes
1002    }
1003
1004    // Record operation to apply operations to D.
1005    fn record_mul_row(&mut self, i: usize, beta: Octet) {
1006        self.debug_symbol_mul_ops += 1;
1007        self.deferred_D_ops.push(SymbolOps::MulAssign {
1008            dest: self.d[i],
1009            scalar: beta,
1010        });
1011        assert!(self.A_hdpc_rows.is_none());
1012    }
1013
1014    fn fma_rows(&mut self, i: usize, iprime: usize, beta: Octet) {
1015        self.fma_rows_with_pi(i, iprime, beta, None, None);
1016    }
1017
1018    fn record_fma_rows(&mut self, i: usize, iprime: usize, beta: Octet) {
1019        if beta == Octet::one() {
1020            self.debug_symbol_add_ops += 1;
1021            self.deferred_D_ops.push(SymbolOps::AddAssign {
1022                dest: self.d[iprime],
1023                src: self.d[i],
1024            });
1025        } else {
1026            self.debug_symbol_add_ops += 1;
1027            self.debug_symbol_mul_ops += 1;
1028            self.deferred_D_ops.push(SymbolOps::FMA {
1029                dest: self.d[iprime],
1030                src: self.d[i],
1031                scalar: beta,
1032            });
1033        }
1034    }
1035
1036    fn fma_rows_with_pi(
1037        &mut self,
1038        i: usize,
1039        iprime: usize,
1040        beta: Octet,
1041        only_non_pi_nonzero_column: Option<usize>,
1042        pi_octets: Option<&Vec<u8>>,
1043    ) {
1044        self.record_fma_rows(i, iprime, beta.clone());
1045
1046        if let Some(ref mut hdpc) = self.A_hdpc_rows {
1047            let first_hdpc_row = self.A.height() - hdpc.height();
1048            // Adding HDPC rows to other rows isn't supported, since it should never happen
1049            assert!(i < first_hdpc_row);
1050            if iprime >= first_hdpc_row {
1051                let col = only_non_pi_nonzero_column.unwrap();
1052                let multiplicand = self.A.get(i, col);
1053                let mut value = hdpc.get(iprime - first_hdpc_row, col);
1054                value.fma(&multiplicand, &beta);
1055                hdpc.set(iprime - first_hdpc_row, col, value);
1056
1057                // Handle this part separately, since it's in the dense U part of the matrix
1058                let octets = pi_octets.unwrap();
1059                hdpc.fma_sub_row(
1060                    iprime - first_hdpc_row,
1061                    self.A.width() - octets.len(),
1062                    &beta,
1063                    octets,
1064                );
1065            } else {
1066                assert_eq!(&beta, &Octet::one());
1067                self.A.add_assign_rows(iprime, i);
1068            }
1069        } else {
1070            assert_eq!(&beta, &Octet::one());
1071            self.A.add_assign_rows(iprime, i);
1072        }
1073    }
1074
1075    fn swap_rows(&mut self, i: usize, iprime: usize) {
1076        if let Some(ref hdpc_rows) = self.A_hdpc_rows {
1077            // Can't swap HDPC rows
1078            assert!(i < self.A.height() - hdpc_rows.height());
1079            assert!(iprime < self.A.height() - hdpc_rows.height());
1080        }
1081        self.A.swap_rows(i, iprime);
1082        self.d.swap(i, iprime);
1083    }
1084
1085    fn swap_columns(&mut self, j: usize, jprime: usize, start_row: usize) {
1086        self.A.swap_columns(j, jprime, start_row);
1087        self.A_hdpc_rows
1088            .as_mut()
1089            .unwrap()
1090            .swap_columns(j, jprime, 0);
1091        self.c.swap(j, jprime);
1092    }
1093
1094    #[inline(never)]
1095    pub fn execute(&mut self) -> (u32, Option<Vec<Symbol>>, Option<Vec<SymbolOps>>) {
1096        self.X.disable_column_acccess_acceleration();
1097
1098        let old_i = self.i;
1099        if !self.first_phase() {
1100            let new_i = self.i;
1101            if old_i < new_i {
1102                return (1, None, None);
1103            }else{
1104                return (0, None, None);
1105            }
1106        }
1107
1108        self.A.disable_column_acccess_acceleration();
1109
1110        if !self.second_phase() {
1111            return (0, None, None);
1112        }
1113
1114        self.third_phase();
1115        self.fourth_phase();
1116        self.fifth_phase();
1117
1118        self.apply_deferred_symbol_ops();
1119
1120        // See end of section 5.4.2.1
1121        let mut index_mapping = vec![0; self.L];
1122        for i in 0..self.L {
1123            index_mapping[self.c[i]] = self.d[i];
1124        }
1125
1126        #[allow(non_snake_case)]
1127        let mut removable_D: Vec<Option<Symbol>> = self.D.drain(..).map(Some).collect();
1128
1129        let mut result = Vec::with_capacity(self.L);
1130        #[allow(clippy::needless_range_loop)]
1131        for i in 0..self.L {
1132            // push a None so it can be swapped in
1133            removable_D.push(None);
1134            result.push(removable_D.swap_remove(index_mapping[i]).unwrap());
1135        }
1136
1137        let mut reorder = Vec::with_capacity(self.L);
1138        for i in index_mapping.iter().take(self.L) {
1139            reorder.push(*i);
1140        }
1141
1142        let mut operation_vector = std::mem::replace(&mut self.deferred_D_ops, vec![]);
1143        operation_vector.push(SymbolOps::Reorder { order: reorder });
1144        return (2, Some(result), Some(operation_vector));
1145    }
1146}
1147
1148// Fused implementation for self.inverse().mul_symbols(symbols)
1149// See section 5.4.2.1
1150pub fn fused_inverse_mul_symbols<T: BinaryMatrix>(
1151    matrix: T,
1152    hdpc_rows: DenseOctetMatrix,
1153    symbols: Vec<Symbol>,
1154    num_source_symbols: u32,
1155) -> (u32, Option<Vec<Symbol>>, Option<Vec<SymbolOps>>) {
1156    IntermediateSymbolDecoder::new(matrix, hdpc_rows, symbols, num_source_symbols).execute()
1157}
1158
#[cfg(test)]
mod tests {
    use super::IntermediateSymbolDecoder;
    use crate::constraint_matrix::generate_constraint_matrix;
    use crate::matrix::BinaryMatrix;
    use crate::matrix::DenseBinaryMatrix;
    use crate::symbol::Symbol;
    use crate::systematic_constants::{
        extended_source_block_symbols, num_ldpc_symbols, num_lt_symbols,
        MAX_SOURCE_SYMBOLS_PER_BLOCK,
    };

    // Guards against regressions in the number of symbol operations the decoder records:
    // the per-symbol mul/add counts must stay below fixed upper bounds.
    #[test]
    fn operations_per_symbol() {
        // (source symbols, max mul ops per symbol, max add ops per symbol)
        let cases = [(10, 35.0, 50.0), (100, 16.0, 35.0)];

        for &(elements, expected_mul_ops, expected_add_ops) in cases.iter() {
            let num_symbols = extended_source_block_symbols(elements);
            let indices: Vec<u32> = (0..num_symbols).collect();
            let (a, hdpc) = generate_constraint_matrix::<DenseBinaryMatrix>(num_symbols, &indices);
            let symbols = vec![Symbol::zero(1usize); a.width()];
            let mut decoder = IntermediateSymbolDecoder::new(a, hdpc, symbols, num_symbols);
            decoder.execute();

            let mul_per_symbol = decoder.get_symbol_mul_ops() as f64 / num_symbols as f64;
            let add_per_symbol = decoder.get_symbol_add_ops() as f64 / num_symbols as f64;

            assert!(
                mul_per_symbol < expected_mul_ops,
                "mul ops per symbol = {}",
                mul_per_symbol
            );
            assert!(
                add_per_symbol < expected_add_ops,
                "add ops per symbol = {}",
                add_per_symbol
            );
        }
    }

    #[test]
    fn check_errata_3() {
        // Check that the optimization of excluding HDPC rows from the X matrix during decoding is
        // safe. This is described in RFC6330_ERRATA.md
        for i in 0..=MAX_SOURCE_SYMBOLS_PER_BLOCK {
            assert!(extended_source_block_symbols(i) + num_ldpc_symbols(i) >= num_lt_symbols(i));
        }
    }
}