local_reconstruction_code_gen/
lib.rs

1use g2p::GaloisField;
2use std::iter::FusedIterator;
3use std::mem::size_of;
4
/// Select `groups` evenly sized groups of non-zero values in `GF(2^m)`.
///
/// The `m` bits of a value are split into `groups` disjoint windows of `m / groups` bits each.
/// Group `i` yields, in ascending order, every non-zero bit pattern of window `i`; all bits
/// outside of that window are 0. Consequently no value appears in more than one group.
///
/// # Panics
/// Panics if `groups` is 0, if `m` is not divisible by `groups`, or if `m` is greater than 8
/// (the values would no longer fit into a `u8`).
///
/// # Examples
/// ```
/// use local_reconstruction_code_gen::gf_bitgroups;
///
/// let mut groups = gf_bitgroups(2,4).map(Iterator::collect);
/// assert_eq!(groups.next(), Some(vec![0b0001, 0b0010, 0b0011]));
/// assert_eq!(groups.next(), Some(vec![0b0100, 0b1000, 0b1100]));
/// assert_eq!(groups.next(), None);
/// ```
pub fn gf_bitgroups(
    groups: usize,
    m: usize,
) -> impl Iterator<
    Item = impl Iterator<Item = u8> + DoubleEndedIterator + ExactSizeIterator + FusedIterator,
> + DoubleEndedIterator
       + ExactSizeIterator
       + FusedIterator {
    // Checked first so that the modulo below cannot hit a division by zero.
    assert!(groups > 0, "Parameter `groups` must be greater than zero");
    assert!(m % groups == 0, "Parameter `m` must divide by `groups`");
    assert!(m <= 8, "Parameter `m` must not exceed 8");

    let group_len = m / groups;

    // Largest bit pattern of a single window: `2^group_len - 1`. For `group_len == 8` the shift
    // does not fit into a `u8`, in which case the window covers the whole byte.
    let largest = 1_u8
        .checked_shl(group_len as u32)
        .map_or(u8::MAX, |v| v - 1);

    (0..groups).map(move |group| {
        // Move the pattern into the window that belongs to this group.
        let shift = group * group_len;
        (1..=largest).map(move |v| v << shift)
    })
}
37
/// Errors reported when an encode matrix cannot be generated for the requested parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixGenError {
    /// The requested code parameters are not supported, e.g. `r` is greater than 2.
    UnsupportedInput,
    /// The provided field `GF(p)` is too small for the requested code; at the very least
    /// `p / l - 1 >= k / l` must hold.
    FieldTooSmall,
}

impl std::fmt::Display for MatrixGenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            MatrixGenError::UnsupportedInput => "unsupported code parameters",
            MatrixGenError::FieldTooSmall => {
                "the Galois field is too small for the requested code parameters"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for MatrixGenError {}
45
46/// Generate encode matrix for a `(k,l,r)` Local Reconstruction Code.
47///
48/// The `k` data symbols divided into `l` groups with each having a single local parity symbol and
49/// `r` global parity symbols.
50///
51/// # Examples
52/// ```
53/// use local_reconstruction_code_gen::gen_encode_matrix;
54/// use g2p::g2p;
55///
56/// g2p!(GF16, 4, modulus: 0b10011);
57/// # fn main() {
58/// let encode_matrix = gen_encode_matrix::<GF16>(6, 2, 2).map(Iterator::collect);
59/// assert_eq!(
60///     encode_matrix,
61///     Ok(vec![
62///         0b0001, 0b0000, 0b0000, 0b0000, 0b0000, 0b0000,
63///         0b0000, 0b0001, 0b0000, 0b0000, 0b0000, 0b0000,
64///         0b0000, 0b0000, 0b0001, 0b0000, 0b0000, 0b0000,
65///         0b0000, 0b0000, 0b0000, 0b0001, 0b0000, 0b0000,
66///         0b0000, 0b0000, 0b0000, 0b0000, 0b0001, 0b0000,
67///         0b0000, 0b0000, 0b0000, 0b0000, 0b0000, 0b0001,
68///         0b0001, 0b0001, 0b0001, 0b0000, 0b0000, 0b0000,
69///         0b0000, 0b0000, 0b0000, 0b0001, 0b0001, 0b0001,
70///         0b0001, 0b0010, 0b0011, 0b0100, 0b1000, 0b1100,
71///         0b0001, 0b0100, 0b0101, 0b0011, 0b1100, 0b1111,
72///     ])
73/// );
74/// # }
75/// ```
76pub fn gen_encode_matrix<'a, GF: 'a + From<u8> + Into<u8> + GaloisField>(
77    k: usize,
78    l: usize,
79    r: usize,
80) -> Result<impl Iterator<Item = u8> + DoubleEndedIterator + FusedIterator + 'a, MatrixGenError> {
81    if r > 2 {
82        return Err(MatrixGenError::UnsupportedInput);
83    }
84
85    if GF::SIZE / l - 1 < k / l {
86        return Err(MatrixGenError::FieldTooSmall);
87    }
88
89    // Log(2) to get m
90    let m = size_of::<usize>() * 8 - GF::SIZE.leading_zeros() as usize - 1;
91
92    // The `k x k` identity matrix
93    let data_rows = (0..k).flat_map(move |row| (0..k).map(move |col| (row == col) as u8));
94
95    // The `l x k` rows that encode the local parities
96    let local_parity_rows = (0..l).flat_map(move |row| {
97        (0..k).map(move |col| ((col < (row + 1) * k / l) && (col >= row * k / l)) as u8)
98    });
99
100    // The `p x k` rows that encode the global parities
101    let global_parity_rows = (0..r).flat_map(move |row| {
102        gf_bitgroups(l, m)
103            .flat_map(move |group| group.take(k / l))
104            .map(move |value| (GF::from(value).pow(row + 1)).into())
105    });
106
107    Ok(data_rows.chain(local_parity_rows).chain(global_parity_rows))
108}
109
#[cfg(test)]
pub mod tests {
    use crate::gf_bitgroups;
    g2p::g2p!(GF16, 4, modulus: 0b10011);

    /// Checks every valid `(groups, m)` combination up to 8: the number of groups must match
    /// and every yielded value must be non-zero with no bit set outside of its group's window.
    #[test]
    fn test_gf_bitgroups() {
        for groups in 1..=8 {
            for m in 1..=8 {
                // Only combinations accepted by `gf_bitgroups` are of interest.
                if m % groups != 0 {
                    continue;
                }
                println!("Input: groups = {}, m = {}", groups, m);

                let actual = gf_bitgroups(groups, m);
                assert_eq!(actual.len(), groups);

                let width = m / groups;
                for (index, group) in actual.enumerate() {
                    // All bits outside of the window of this group.
                    let window = (2_usize.pow(width as u32) - 1) << (index * width);
                    let bitmask = (!window) as u8;
                    println!("Expected bitmask: {:b}", bitmask);

                    for val in group {
                        print!("{:0len$b} ", val, len = m);
                        assert_ne!(val, 0);
                        assert_eq!(val & bitmask, 0);
                    }
                    println!();
                }
                println!();
            }
        }
    }
}