1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
use crate::matrix::{apply_matrix, vec_add};
use crate::mds::MDSMatrices;
use crate::quintic_s_box;
use ff::{Field, ScalarEngine};

/// Compress round constants so that they need only be added after each S-box.
///
/// Constants are pushed back through the linear (MDS) layers and through the identity
/// components of the partial rounds. Because `M(x + M^-1(k)) = M(x) + k`, a constant `k`
/// added at the start of a round can instead be added as `M^-1(k)` just before the preceding
/// MDS multiplication.
///
/// With `h = full_rounds / 2`, `f = h + partial_rounds` (first round of the second set of
/// full rounds) and `u = partial_rounds - partial_preprocessed`, the returned vector is:
/// 1. `round_keys(0)`, unchanged (added before the first S-box layer);
/// 2. `M^-1(round_keys(r))` for `r` in `1..=h`, or `1..h` when `u == 0` (that last row is
///    instead folded into the accumulated row of step 4);
/// 3. raw `round_keys(r)` for `r` in `h + 1..h + u` (unpreprocessed partial rounds);
/// 4. `M^-1(round_acc)`, the row accumulated by folding the preprocessed partial rounds;
/// 5. one scalar per preprocessed partial round (`partial_preprocessed` elements), with the
///    last one produced by the fold placed first;
/// 6. `M^-1(round_keys(r))` for `r` in `f + 1..f + h`.
///
/// # Panics
///
/// - If `partial_preprocessed > partial_rounds`.
/// - If every partial round is preprocessed but `full_rounds < 2`.
/// - If `round_constants` holds fewer than `width` constants for any round indexed above.
/// - In debug builds, if the single-step self-check fails (only run when
///   `partial_preprocessed == 1`).
pub(crate) fn compress_round_constants<E: ScalarEngine>(
    width: usize,
    full_rounds: usize,
    partial_rounds: usize,
    round_constants: &[E::Fr],
    mds_matrices: &MDSMatrices<E>,
    partial_preprocessed: usize,
) -> Vec<E::Fr> {
    assert!(
        partial_preprocessed <= partial_rounds,
        "cannot preprocess {} partial rounds: only {} exist",
        partial_preprocessed,
        partial_rounds
    );

    let inverse_matrix = &mds_matrices.m_inv;

    let round_keys = |r: usize| &round_constants[r * width..(r + 1) * width];

    // Not half-full rounds; half of the full rounds (one side of the partial rounds).
    let half_full_rounds = full_rounds / 2;
    let unpreprocessed = partial_rounds - partial_preprocessed;

    let mut res = Vec::new();

    // First round constants are unchanged: they are added before any S-box.
    res.extend(round_keys(0));

    // Post-S-box adds for the first set of full rounds are the next round's keys 'inverted'
    // through the MDS layer. When fully preprocessing, the final one is skipped because that
    // value must be obtained from the result of preprocessing the partial rounds.
    let end = if unpreprocessed > 0 {
        half_full_rounds
    } else {
        half_full_rounds
            .checked_sub(1)
            .expect("fully preprocessing the partial rounds requires at least two full rounds")
    };
    for r in 1..=end {
        res.extend(apply_matrix::<E>(inverse_matrix, round_keys(r)));
    }

    // Work backwards from the first round of the second set of full rounds. At each step:
    // invert the accumulated row, save its first constant (the one S-box of that partial
    // round), zero it, and add the remainder onto the previous round's keys.
    let final_round = half_full_rounds + partial_rounds;

    // `partial_keys` accumulates the single post-S-box constant for each preprocessed
    // partial round, in reverse order.
    let mut partial_keys: Vec<E::Fr> = Vec::with_capacity(partial_preprocessed);

    // `round_acc` holds the accumulated result of inverting and adding subsequent round
    // constants (in reverse).
    let round_acc = (0..partial_preprocessed)
        .map(|i| round_keys(final_round - i - 1))
        .fold(round_keys(final_round).to_vec(), |acc, previous_round_keys| {
            let mut inverted = apply_matrix::<E>(inverse_matrix, &acc);

            partial_keys.push(inverted[0]);
            inverted[0] = E::Fr::zero();

            vec_add::<E>(previous_round_keys, &inverted)
        });

    // Development self-check of the one-step case. It can only fail if the algebra above is
    // wrong, so it is not worth paying for in release builds.
    if cfg!(debug_assertions) && partial_preprocessed == 1 {
        verify_single_preprocessing_step::<E>(
            width,
            mds_matrices,
            round_keys(final_round - 1),
            round_keys(final_round),
            &round_acc,
            &partial_keys,
        );
    }

    // Raw keys for the unpreprocessed partial rounds after the first one (the first one's
    // keys were already inverted in the loop over the first set of full rounds).
    for i in 1..unpreprocessed {
        res.extend(round_keys(half_full_rounds + i));
    }
    res.extend(apply_matrix::<E>(inverse_matrix, &round_acc));

    // The last single constant produced should be the first applied.
    res.extend(partial_keys.into_iter().rev());

    // Post-S-box adds for the second set of full rounds are 'inverted' from the next round.
    // The first round of this set was consumed as the starting point of `round_acc`.
    for i in 1..half_full_rounds {
        res.extend(apply_matrix::<E>(inverse_matrix, round_keys(final_round + i)));
    }

    res
}

/// Self-check for [`compress_round_constants`] when exactly one partial round is preprocessed.
///
/// Recomputes the single fold step by hand and asserts that it matches `round_acc` and
/// `partial_keys`. It then runs one partial round from the same arbitrary state twice, once
/// with the original constants and once with the preprocessed ones, and asserts that the
/// outputs agree.
///
/// `initial_round_keys` (I) and `terminal_round_keys` (T) are the two consecutive rounds of
/// original constants that the single fold step compressed.
///
/// # Panics
///
/// Panics if any of the checks fail.
fn verify_single_preprocessing_step<E: ScalarEngine>(
    width: usize,
    mds_matrices: &MDSMatrices<E>,
    initial_round_keys: &[E::Fr],
    terminal_round_keys: &[E::Fr],
    round_acc: &[E::Fr],
    partial_keys: &[E::Fr],
) {
    let mds_matrix = &mds_matrices.m;
    let inverse_matrix = &mds_matrices.m_inv;

    // M^-1(T), with its first component split off as `pk`.
    let mut inv = apply_matrix::<E>(inverse_matrix, terminal_round_keys);
    let pk = inv[0];
    inv[0] = E::Fr::zero();

    // I + (M^-1(T) with a zeroed first component): what the fold should have produced.
    let result_key = vec_add::<E>(initial_round_keys, &inv);

    assert_eq!(result_key.as_slice(), round_acc, "Acc assumption failed.");
    assert_eq!(pk, partial_keys[0], "Partial-key assumption failed.");
    assert_eq!(
        1,
        partial_keys.len(),
        "Partial-keys length assumption failed."
    );

    // Arbitrary initial state (X) representing the output of a previous round's S-box layer.
    let initial_state = vec![E::Fr::one(); width];

    // Plain path: add I, S-box the first component only, mix with the MDS matrix, add T.
    let mut q_state = vec_add::<E>(initial_round_keys, &initial_state);
    quintic_s_box::<E>(&mut q_state[0], None, None);
    let mixed = apply_matrix::<E>(mds_matrix, &q_state);
    let plain_result = vec_add::<E>(terminal_round_keys, &mixed);

    // Preprocessed path: add the accumulated key, S-box the first component with `pk` as the
    // third argument (presumably a post-S-box add — confirm against `quintic_s_box`), then mix.
    let mut p_state = vec_add::<E>(&result_key, &initial_state);

    // For the S-box results to agree, both paths must feed it the same input, so the first
    // accumulated constant must equal the first original constant.
    assert_eq!(
        &result_key[0], &initial_round_keys[0],
        "S-box inputs did not match."
    );

    quintic_s_box::<E>(&mut p_state[0], None, Some(&pk));
    let preprocessed_result = apply_matrix::<E>(mds_matrix, &p_state);

    assert_eq!(
        plain_result, preprocessed_result,
        "Single preprocessing step couldn't be verified."
    );
}