1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use g2p::GaloisField;
use std::iter::FusedIterator;
use std::mem::size_of;

/// Split the nonzero values of `GF(2^m)` into `groups` disjoint, evenly sized bit groups.
///
/// The `m` bits of a field element are divided into `groups` consecutive windows of
/// `m / groups` bits each, starting at the least significant bit. Group `i` yields, in ascending
/// order, every nonzero value whose set bits all lie inside window `i`; all bits outside that
/// window are 0. Values from different groups therefore never share a set bit, and every group
/// contains exactly `2^(m / groups) - 1` values.
///
/// # Panics
///
/// Panics if `groups` is zero, if `m` is not a multiple of `groups`, or if `m > 8`.
///
/// # Examples
/// ```
/// use local_reconstruction_code_gen::gf_bitgroups;
///
/// let mut groups = gf_bitgroups(2,4).map(Iterator::collect);
/// assert_eq!(groups.next(), Some(vec![0b0001, 0b0010, 0b0011]));
/// assert_eq!(groups.next(), Some(vec![0b0100, 0b1000, 0b1100]));
/// assert_eq!(groups.next(), None);
/// ```
pub fn gf_bitgroups(
    groups: usize,
    m: usize,
) -> impl Iterator<
    Item = impl Iterator<Item = u8> + DoubleEndedIterator + ExactSizeIterator + FusedIterator,
> + DoubleEndedIterator
       + ExactSizeIterator
       + FusedIterator {
    // Checked explicitly so `groups == 0` does not surface as a divide-by-zero in `m % groups`.
    assert!(groups > 0, "Parameter `groups` must be nonzero");
    assert!(m % groups == 0, "Parameter `m` must divide by `groups`");
    assert!(m <= 8, "Parameter `m` must be at most 8");

    let group_len = m / groups;
    // Largest value representable in `group_len` bits. `checked_shl` only fails for an 8-bit
    // window (`1_u8 << 8`), whose largest value is `u8::MAX`.
    let max_value = 1_u8
        .checked_shl(group_len as u32)
        .map_or(u8::MAX, |v| v - 1);

    (0..groups).map(move |group| {
        // Window `group` starts at bit `group * group_len`; this is always < 8 because the
        // window ends at or before bit `m <= 8`.
        let shift = group * group_len;
        (1..=max_value).map(move |v| v << shift)
    })
}

/// Errors returned by [`gen_encode_matrix`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixGenError {
    /// The requested `(k, l, r)` parameters are not supported (for example `r > 2`).
    UnsupportedInput,
    /// The field `GF(2^m)` does not have enough distinct elements to build the global parity
    /// coefficients for the requested `(k, l)`.
    FieldTooSmall,
}

impl std::fmt::Display for MatrixGenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            MatrixGenError::UnsupportedInput => {
                "unsupported local reconstruction code parameters"
            }
            MatrixGenError::FieldTooSmall => "Galois field is too small for the requested code",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for MatrixGenError {}

/// Generate the encode matrix for a `(k, l, r)` Local Reconstruction Code over `GF(2^m)`.
///
/// The `k` data symbols are divided into `l` local groups, each protected by a single local
/// parity symbol, and the whole stripe is additionally protected by `r` global parity symbols.
/// Local group `g` covers data columns `g * k / l .. (g + 1) * k / l` (integer division), so
/// uneven splits are spread across the groups.
///
/// The matrix is yielded row-major, one `u8` per entry, with `(k + l + r)` rows of `k` entries:
/// * the `k x k` identity (data rows),
/// * `l` local parity rows, each 1 on the columns of its local group and 0 elsewhere,
/// * `r` global parity rows, where global row `i` (0-based) holds `c^(i + 1)` for the
///   coefficient `c` assigned to each column. The columns of local group `g` take distinct
///   coefficients from bit group `g` of [`gf_bitgroups`]`(l, m)`.
///
/// # Errors
///
/// * [`MatrixGenError::UnsupportedInput`] if `r > 2` or `l == 0`, or, when `r > 0`, if `l`
///   does not divide `m` or `m > 8`.
/// * [`MatrixGenError::FieldTooSmall`] if `r > 0` and some local group has more columns than
///   the `2^(m / l) - 1` nonzero values available in one bit group.
///
/// # Examples
/// ```
/// use local_reconstruction_code_gen::gen_encode_matrix;
/// use g2p::g2p;
///
/// g2p!(GF16, 4, modulus: 0b10011);
/// # fn main() {
/// let encode_matrix = gen_encode_matrix::<GF16>(6, 2, 2).map(Iterator::collect);
/// assert_eq!(
///     encode_matrix,
///     Ok(vec![
///         0b0001, 0b0000, 0b0000, 0b0000, 0b0000, 0b0000,
///         0b0000, 0b0001, 0b0000, 0b0000, 0b0000, 0b0000,
///         0b0000, 0b0000, 0b0001, 0b0000, 0b0000, 0b0000,
///         0b0000, 0b0000, 0b0000, 0b0001, 0b0000, 0b0000,
///         0b0000, 0b0000, 0b0000, 0b0000, 0b0001, 0b0000,
///         0b0000, 0b0000, 0b0000, 0b0000, 0b0000, 0b0001,
///         0b0001, 0b0001, 0b0001, 0b0000, 0b0000, 0b0000,
///         0b0000, 0b0000, 0b0000, 0b0001, 0b0001, 0b0001,
///         0b0001, 0b0010, 0b0011, 0b0100, 0b1000, 0b1100,
///         0b0001, 0b0100, 0b0101, 0b0011, 0b1100, 0b1111,
///     ])
/// );
/// # }
/// ```
pub fn gen_encode_matrix<'a, GF: 'a + From<u8> + Into<u8> + GaloisField>(
    k: usize,
    l: usize,
    r: usize,
) -> Result<impl Iterator<Item = u8> + DoubleEndedIterator + FusedIterator + 'a, MatrixGenError> {
    // `l == 0` would otherwise divide by zero below.
    if r > 2 || l == 0 {
        return Err(MatrixGenError::UnsupportedInput);
    }

    // log2 of the field size gives `m` in `GF(2^m)`.
    let m = size_of::<usize>() * 8 - GF::SIZE.leading_zeros() as usize - 1;

    // First data column of local group `group`; group `group` spans
    // `group_start(group)..group_start(group + 1)`.
    let group_start = move |group: usize| group * k / l;

    if r > 0 {
        // Validate up front what `gf_bitgroups(l, m)` would otherwise assert lazily while the
        // returned iterator is being consumed.
        if m % l != 0 || m > 8 {
            return Err(MatrixGenError::UnsupportedInput);
        }
        // Each bit group offers `2^(m / l) - 1` distinct nonzero values, and every column of a
        // local group needs its own value.
        let values_per_group = (1_usize << (m / l)) - 1;
        let largest_group = (0..l)
            .map(|group| group_start(group + 1) - group_start(group))
            .max()
            .unwrap_or(0);
        if values_per_group < largest_group {
            return Err(MatrixGenError::FieldTooSmall);
        }
    }

    // The `k x k` identity matrix
    let data_rows = (0..k).flat_map(move |row| (0..k).map(move |col| u8::from(row == col)));

    // The `l x k` rows that encode the local parities
    let local_parity_rows = (0..l).flat_map(move |group| {
        (0..k).map(move |col| {
            u8::from((group_start(group)..group_start(group + 1)).contains(&col))
        })
    });

    // The `r x k` rows that encode the global parities. Each local group takes exactly as many
    // coefficients from its bit group as it has columns, so every row has `k` entries.
    let global_parity_rows = (0..r).flat_map(move |row| {
        gf_bitgroups(l, m)
            .enumerate()
            .flat_map(move |(group, values)| {
                values.take(group_start(group + 1) - group_start(group))
            })
            .map(move |value| -> u8 { GF::from(value).pow(row + 1).into() })
    });

    Ok(data_rows.chain(local_parity_rows).chain(global_parity_rows))
}

#[cfg(test)]
pub mod tests {
    use crate::{gen_encode_matrix, gf_bitgroups, MatrixGenError};
    use std::collections::HashSet;
    g2p::g2p!(GF16, 4, modulus: 0b10011);

    /// Every valid `(groups, m)` pair yields `groups` groups of `2^(m / groups) - 1` distinct,
    /// nonzero values, each confined to its own bit window.
    #[test]
    fn test_gf_bitgroups() {
        for groups in 1..=8_usize {
            for m in (1..=8_usize).filter(|m| m % groups == 0) {
                println!("Input: groups = {}, m = {}", groups, m);
                let group_len = m / groups;
                let actual = gf_bitgroups(groups, m);
                assert_eq!(actual.len(), groups);

                for (i, group) in actual.enumerate() {
                    // Bits outside window `i` must never be set.
                    let window = (((1_usize << group_len) - 1) << (i * group_len)) as u8;
                    let bitmask = !window;
                    println!("Expected bitmask: {:b}", bitmask);

                    let values: Vec<u8> = group.collect();
                    assert_eq!(values.len(), (1_usize << group_len) - 1);
                    assert_eq!(
                        values.iter().collect::<HashSet<_>>().len(),
                        values.len(),
                        "values within a group must be distinct"
                    );
                    for val in values {
                        print!("{:0len$b} ", val, len = m);
                        assert_ne!(val, 0);
                        assert_eq!(val & bitmask, 0);
                    }
                    println!();
                }
                println!();
            }
        }
    }

    /// The `(6, 2, 2)` code over `GF(16)` has `k + l + r` rows of `k` entries, starting with
    /// the identity and followed by the local parity rows.
    #[test]
    fn test_gen_encode_matrix_layout() {
        let (k, l, r) = (6, 2, 2);
        let matrix: Vec<u8> = gen_encode_matrix::<GF16>(k, l, r).unwrap().collect();
        assert_eq!(matrix.len(), (k + l + r) * k);

        for (row, values) in matrix.chunks(k).take(k).enumerate() {
            for (col, &value) in values.iter().enumerate() {
                assert_eq!(value, u8::from(row == col));
            }
        }
        for (group, values) in matrix.chunks(k).skip(k).take(l).enumerate() {
            for (col, &value) in values.iter().enumerate() {
                assert_eq!(value, u8::from(col / (k / l) == group));
            }
        }
    }

    /// More than two global parities are rejected.
    #[test]
    fn test_gen_encode_matrix_rejects_large_r() {
        assert_eq!(
            gen_encode_matrix::<GF16>(6, 2, 3).err(),
            Some(MatrixGenError::UnsupportedInput)
        );
    }
}