Skip to main content

cryptography/modes/
ocb.rs

1//! OCB authenticated encryption (RFC 7253, OCB3).
2//!
3//! This implementation targets 128-bit block ciphers and the default 128-bit
4//! authentication tag profile.
5
6use crate::BlockCipher;
7
/// XOR two 128-bit blocks, treating each as one big-endian integer.
#[inline]
fn xor_block(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
    let combined = u128::from_be_bytes(*a) ^ u128::from_be_bytes(*b);
    combined.to_be_bytes()
}
16
/// GF(2^128) doubling as defined by RFC 7253: shift the big-endian block left
/// by one bit and, if the bit shifted out was set, reduce by x^128 + x^7 + x^2 + x + 1
/// (XOR 0x87 into the low byte).
#[inline]
fn dbl_block(block: [u8; 16]) -> [u8; 16] {
    let value = u128::from_be_bytes(block);
    let shifted = value << 1;
    let reduced = if value >> 127 == 1 { shifted ^ 0x87 } else { shifted };
    reduced.to_be_bytes()
}
30
/// RFC 7253 `ntz(i)`: the number of trailing zero bits of `i`, which selects
/// `L_{ntz(i)}` when advancing the offset for block `i`.
#[inline]
fn ntz(i: usize) -> usize {
    let zeros: u32 = i.trailing_zeros();
    zeros as usize
}
35
/// Split `data` into its longest prefix made of whole 16-byte blocks and the
/// trailing remainder (0..=15 bytes).
#[inline]
fn split_blocks(data: &[u8]) -> (&[u8], &[u8]) {
    // Clearing the low four bits rounds the length down to a multiple of 16.
    data.split_at(data.len() & !15usize)
}
41
/// Build the 128-bit formatted nonce of RFC 7253 §4.2:
/// `num2str(TAGLEN mod 128, 7) || zeros(120 - bitlen(N)) || 1 || N`.
///
/// # Panics
///
/// Panics if `nonce` is longer than 15 bytes (120 bits).
fn nonce_block_from_bytes(tag_len_bits: usize, nonce: &[u8]) -> [u8; 16] {
    let len = nonce.len();
    assert!(len <= 15, "OCB nonce must be at most 120 bits");

    let mut block = [0u8; 16];
    // The top seven bits of the first byte hold TAGLEN mod 128.
    block[0] = ((tag_len_bits % 128) as u8) << 1;
    // A single 1 bit sits immediately before the nonce bytes; for a
    // 15-byte nonce this lands in the low bit of byte 0, below the tag bits.
    block[15 - len] |= 1;
    block[16 - len..].copy_from_slice(nonce);
    block
}
54
/// Form `Stretch = Ktop || (Ktop[1..64] xor Ktop[9..72])` (RFC 7253 §4.2):
/// the 16 Ktop bytes followed by the XOR of each of its first eight bytes with
/// the byte after it.
fn stretch_from_ktop(ktop: [u8; 16]) -> [u8; 24] {
    let mut stretch = [0u8; 24];
    let (head, tail) = stretch.split_at_mut(16);
    head.copy_from_slice(&ktop);
    for (slot, pair) in tail.iter_mut().zip(ktop.windows(2)) {
        *slot = pair[0] ^ pair[1];
    }
    stretch
}
63
/// Extract `Offset_0 = Stretch[1+bottom..128+bottom]` (RFC 7253 §4.2): the
/// 128 bits of `stretch` starting `bottom` bits from its most significant end.
///
/// Callers pass `bottom` in `0..64`. Bytes past the end of `stretch` read as
/// zero in the unaligned path; the byte-aligned path slices directly.
fn offset_from_stretch(stretch: &[u8; 24], bottom: u8) -> [u8; 16] {
    let byte_shift = usize::from(bottom >> 3);
    let bit_shift = u32::from(bottom & 7);
    let mut offset = [0u8; 16];

    if bit_shift == 0 {
        // Byte-aligned: a plain 16-byte window.
        offset.copy_from_slice(&stretch[byte_shift..byte_shift + 16]);
    } else {
        let byte_at = |idx: usize| stretch.get(idx).copied().unwrap_or(0);
        for (pos, slot) in offset.iter_mut().enumerate() {
            let hi = byte_at(byte_shift + pos);
            let lo = byte_at(byte_shift + pos + 1);
            *slot = (hi << bit_shift) | (lo >> (8 - bit_shift));
        }
    }
    offset
}
81
82fn hash_associated_data<C: BlockCipher>(
83    cipher: &C,
84    l_star: [u8; 16],
85    l_dollar: [u8; 16],
86    aad: &[u8],
87) -> [u8; 16] {
88    let mut l_table = vec![dbl_block(l_dollar)];
89    let mut sum = [0u8; 16];
90    let mut offset = [0u8; 16];
91
92    let (full, partial) = split_blocks(aad);
93    for (idx, block) in full.chunks_exact(16).enumerate() {
94        // RFC 7253 uses L_{ntz(i)} to advance offsets for full associated-data blocks.
95        let i = idx + 1;
96        let tz = ntz(i);
97        while l_table.len() <= tz {
98            let next = dbl_block(*l_table.last().expect("L table non-empty"));
99            l_table.push(next);
100        }
101        offset = xor_block(&offset, &l_table[tz]);
102
103        let mut x = [0u8; 16];
104        x.copy_from_slice(block);
105        x = xor_block(&x, &offset);
106        cipher.encrypt(&mut x);
107        sum = xor_block(&sum, &x);
108    }
109
110    if !partial.is_empty() {
111        // Final partial AD block uses Offset xor L_* and 10* padding.
112        offset = xor_block(&offset, &l_star);
113        let mut cipher_input = [0u8; 16];
114        cipher_input[..partial.len()].copy_from_slice(partial);
115        cipher_input[partial.len()] = 0x80;
116        cipher_input = xor_block(&cipher_input, &offset);
117        cipher.encrypt(&mut cipher_input);
118        sum = xor_block(&sum, &cipher_input);
119    }
120
121    sum
122}
123
/// OCB3 (RFC 7253) authenticated encryption producing a detached 16-byte tag.
pub struct Ocb<C> {
    /// The wrapped block cipher; the mode's operations require a 128-bit block.
    cipher: C,
}

impl<C> Ocb<C> {
    /// Create an OCB instance around `cipher`, which should be a 128-bit
    /// block cipher.
    pub fn new(cipher: C) -> Self {
        Ocb { cipher }
    }

    /// Return a shared reference to the wrapped cipher.
    pub fn cipher(&self) -> &C {
        let Ocb { cipher } = self;
        cipher
    }
}
140
141impl<C: BlockCipher> Ocb<C> {
142    fn compute_offsets(&self, nonce: &[u8]) -> ([u8; 16], [u8; 16], [u8; 16], Vec<[u8; 16]>) {
143        assert_eq!(C::BLOCK_LEN, 16, "OCB requires a 128-bit block cipher");
144        // L_* = E_K(0^128), L_$ = dbl(L_*), L_0 = dbl(L_$) per RFC 7253.
145        let mut l_star = [0u8; 16];
146        self.cipher.encrypt(&mut l_star);
147        let l_dollar = dbl_block(l_star);
148        let l0 = dbl_block(l_dollar);
149
150        let nonce_block = nonce_block_from_bytes(128, nonce);
151        let bottom = nonce_block[15] & 0x3f;
152        let mut ktop_input = nonce_block;
153        ktop_input[15] &= 0xC0;
154        self.cipher.encrypt(&mut ktop_input);
155        // Nonce-dependent Offset_0 is derived from Ktop||Stretch and the
156        // bottom six nonce bits (RFC 7253 §4.2).
157        let stretch = stretch_from_ktop(ktop_input);
158        let offset0 = offset_from_stretch(&stretch, bottom);
159
160        (l_star, l_dollar, offset0, vec![l0])
161    }
162
163    /// Encrypt `data` in place and return a detached 16-byte tag.
164    pub fn encrypt(&self, nonce: &[u8], aad: &[u8], data: &mut [u8]) -> [u8; 16] {
165        let (l_star, l_dollar, mut offset, mut l_table) = self.compute_offsets(nonce);
166        let aad_hash = hash_associated_data(&self.cipher, l_star, l_dollar, aad);
167
168        let (full_len, partial_len) = (data.len() / 16 * 16, data.len() % 16);
169        let mut checksum = [0u8; 16];
170
171        for (idx, block) in data[..full_len].chunks_exact_mut(16).enumerate() {
172            // RFC 7253 §4.2: Offset_i = Offset_{i-1} xor L_{ntz(i)}.
173            let i = idx + 1;
174            let tz = ntz(i);
175            while l_table.len() <= tz {
176                let next = dbl_block(*l_table.last().expect("L table non-empty"));
177                l_table.push(next);
178            }
179            offset = xor_block(&offset, &l_table[tz]);
180
181            let mut p = [0u8; 16];
182            p.copy_from_slice(block);
183            checksum = xor_block(&checksum, &p);
184
185            p = xor_block(&p, &offset);
186            self.cipher.encrypt(&mut p);
187            p = xor_block(&p, &offset);
188            block.copy_from_slice(&p);
189        }
190
191        if partial_len != 0 {
192            // RFC 7253 §4.2 final partial block: Offset_* = Offset_m xor L_*.
193            offset = xor_block(&offset, &l_star);
194            let mut pad = offset;
195            self.cipher.encrypt(&mut pad);
196
197            let partial = &mut data[full_len..];
198            let mut partial_plain = [0u8; 16];
199            partial_plain[..partial.len()].copy_from_slice(partial);
200            for i in 0..partial.len() {
201                partial[i] ^= pad[i];
202            }
203
204            partial_plain[partial.len()] = 0x80;
205            checksum = xor_block(&checksum, &partial_plain);
206        }
207
208        let mut tag_input = xor_block(&checksum, &offset);
209        tag_input = xor_block(&tag_input, &l_dollar);
210        self.cipher.encrypt(&mut tag_input);
211        xor_block(&tag_input, &aad_hash)
212    }
213
214    /// Verify `tag` and decrypt `data` in place on success.
215    pub fn decrypt(&self, nonce: &[u8], aad: &[u8], data: &mut [u8], tag: &[u8; 16]) -> bool {
216        let (l_star, l_dollar, mut offset, mut l_table) = self.compute_offsets(nonce);
217        let aad_hash = hash_associated_data(&self.cipher, l_star, l_dollar, aad);
218
219        let (full_len, partial_len) = (data.len() / 16 * 16, data.len() % 16);
220        let mut checksum = [0u8; 16];
221
222        let mut plaintext = data.to_vec();
223        for (idx, block) in plaintext[..full_len].chunks_exact_mut(16).enumerate() {
224            // RFC 7253 §4.2: Offset_i = Offset_{i-1} xor L_{ntz(i)}.
225            let i = idx + 1;
226            let tz = ntz(i);
227            while l_table.len() <= tz {
228                let next = dbl_block(*l_table.last().expect("L table non-empty"));
229                l_table.push(next);
230            }
231            offset = xor_block(&offset, &l_table[tz]);
232
233            let mut c = [0u8; 16];
234            c.copy_from_slice(block);
235            c = xor_block(&c, &offset);
236            self.cipher.decrypt(&mut c);
237            c = xor_block(&c, &offset);
238            checksum = xor_block(&checksum, &c);
239            block.copy_from_slice(&c);
240        }
241
242        if partial_len != 0 {
243            // RFC 7253 §4.2 final partial block: Offset_* = Offset_m xor L_*.
244            offset = xor_block(&offset, &l_star);
245            let mut pad = offset;
246            self.cipher.encrypt(&mut pad);
247            let partial = &mut plaintext[full_len..];
248            for i in 0..partial.len() {
249                partial[i] ^= pad[i];
250            }
251            let mut padded_p = [0u8; 16];
252            padded_p[..partial.len()].copy_from_slice(partial);
253            padded_p[partial.len()] = 0x80;
254            checksum = xor_block(&checksum, &padded_p);
255        }
256
257        let mut tag_input = xor_block(&checksum, &offset);
258        tag_input = xor_block(&tag_input, &l_dollar);
259        self.cipher.encrypt(&mut tag_input);
260        let expected = xor_block(&tag_input, &aad_hash);
261        if crate::ct::constant_time_eq_mask(&expected, tag) != u8::MAX {
262            crate::ct::zeroize_slice(&mut plaintext);
263            return false;
264        }
265
266        data.copy_from_slice(&plaintext);
267        true
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::Ocb;
    use crate::Aes128;

    /// Decode a hex string, ignoring any whitespace between digits.
    fn unhex_ws(input: &str) -> Vec<u8> {
        let compact: String = input.chars().filter(|c| !c.is_whitespace()).collect();
        let mut out = Vec::with_capacity(compact.len() / 2);
        let bytes = compact.as_bytes();
        let mut i = 0usize;
        while i + 1 < bytes.len() {
            let hi = (bytes[i] as char).to_digit(16).expect("hex") as u8;
            let lo = (bytes[i + 1] as char).to_digit(16).expect("hex") as u8;
            out.push((hi << 4) | lo);
            i += 2;
        }
        out
    }

    /// The key shared by every RFC 7253 Appendix A sample vector.
    fn rfc_key() -> [u8; 16] {
        <[u8; 16]>::try_from(unhex_ws("000102030405060708090A0B0C0D0E0F")).expect("key")
    }

    /// Encrypt a copy of `pt` and return `ciphertext || tag`, the form RFC 7253
    /// uses for its `C` values.
    fn seal(ocb: &Ocb<Aes128>, nonce: &[u8], aad: &[u8], pt: &[u8]) -> Vec<u8> {
        let mut out = pt.to_vec();
        let tag = ocb.encrypt(nonce, aad, &mut out);
        out.extend_from_slice(&tag);
        out
    }

    #[test]
    fn rfc7253_sample_vector_1_empty() {
        let key = <[u8; 16]>::try_from(unhex_ws("000102030405060708090A0B0C0D0E0F")).expect("key");
        let nonce = unhex_ws("BBAA99887766554433221100");
        let aad = [];
        let mut pt = vec![];
        let expected = unhex_ws("785407BFFFC8AD9EDCC5520AC9111EE6");

        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        assert_eq!(pt, Vec::<u8>::new());
        assert_eq!(tag.as_slice(), expected.as_slice());
    }

    #[test]
    fn rfc7253_sample_vector_2_short_aad_and_pt() {
        let key = <[u8; 16]>::try_from(unhex_ws("000102030405060708090A0B0C0D0E0F")).expect("key");
        let nonce = unhex_ws("BBAA99887766554433221101");
        let aad = unhex_ws("0001020304050607");
        let mut pt = unhex_ws("0001020304050607");
        let expected = unhex_ws("6820B3657B6F615A5725BDA0D3B4EB3A257C9AF1F8F03009");

        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        let mut out = pt.clone();
        out.extend_from_slice(&tag);
        assert_eq!(out, expected);

        assert!(ocb.decrypt(&nonce, &aad, &mut pt, &tag));
        assert_eq!(pt, unhex_ws("0001020304050607"));
    }

    #[test]
    fn rfc7253_sample_vector_4_short_pt_no_aad() {
        let key = <[u8; 16]>::try_from(unhex_ws("000102030405060708090A0B0C0D0E0F")).expect("key");
        let nonce = unhex_ws("BBAA99887766554433221103");
        let aad = [];
        let mut pt = unhex_ws("0001020304050607");
        let expected = unhex_ws("45DD69F8F5AAE72414054CD1F35D82760B2CD00D2F99BFA9");

        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        let mut out = pt.clone();
        out.extend_from_slice(&tag);
        assert_eq!(out, expected);
    }

    /// First sample vector with full 16-byte blocks in both A and P, so the
    /// full-block encrypt/decrypt and AAD hashing paths are exercised.
    #[test]
    fn rfc7253_sample_vector_5_full_block_aad_and_pt() {
        let nonce = unhex_ws("BBAA99887766554433221104");
        let aad = unhex_ws("000102030405060708090A0B0C0D0E0F");
        let pt = unhex_ws("000102030405060708090A0B0C0D0E0F");
        let expected =
            unhex_ws("571D535B60B277188BE5147170A9A22C3AD7A4FF3835B8C5701C1CCEC8FC3358");

        let ocb = Ocb::new(Aes128::new(&rfc_key()));
        let sealed = seal(&ocb, &nonce, &aad, &pt);
        assert_eq!(sealed, expected);

        let (ct, tag) = sealed.split_at(pt.len());
        let tag = <[u8; 16]>::try_from(tag).expect("tag");
        let mut buf = ct.to_vec();
        assert!(ocb.decrypt(&nonce, &aad, &mut buf, &tag));
        assert_eq!(buf, pt);
    }

    /// RFC 7253 Appendix A iterated check for AES-128 with TAGLEN = 128.
    ///
    /// It covers messages up to 127 bytes and a final ~22 KB of associated
    /// data, which exercises deep `L_j` table growth.
    #[test]
    fn rfc7253_iterated_check_aes128_taglen128() {
        // K = zeros(KEYLEN - 8) || num2str(TAGLEN, 8)
        let mut key = [0u8; 16];
        key[15] = 128;
        let ocb = Ocb::new(Aes128::new(&key));

        // N = num2str(x, 96)
        let nonce_for = |x: u64| {
            let mut nonce = [0u8; 12];
            nonce[4..].copy_from_slice(&x.to_be_bytes());
            nonce
        };

        let mut transcript = Vec::new();
        for i in 0..128u64 {
            let s = vec![0u8; i as usize];
            transcript.extend_from_slice(&seal(&ocb, &nonce_for(3 * i + 1), &s, &s));
            transcript.extend_from_slice(&seal(&ocb, &nonce_for(3 * i + 2), &[], &s));
            transcript.extend_from_slice(&seal(&ocb, &nonce_for(3 * i + 3), &s, &[]));
        }

        let tag = ocb.encrypt(&nonce_for(385), &transcript, &mut [0u8; 0]);
        assert_eq!(tag.as_slice(), unhex_ws("67E944D23256C5E0B6C61FA22FDF1EA2").as_slice());
    }

    #[test]
    fn ocb_round_trips_multi_block_messages() {
        let ocb = Ocb::new(Aes128::new(&[0x42u8; 16]));
        let nonce = [0x24u8; 12];
        let aad: Vec<u8> = (0u8..70).collect();
        for &len in &[0usize, 1, 15, 16, 17, 31, 32, 33, 64, 100] {
            let msg: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let mut buf = msg.clone();
            let tag = ocb.encrypt(&nonce, &aad, &mut buf);
            assert!(ocb.decrypt(&nonce, &aad, &mut buf, &tag), "len {}", len);
            assert_eq!(buf, msg, "len {}", len);
        }
    }

    #[test]
    fn ocb_rejects_tampered_ciphertext_and_aad() {
        let ocb = Ocb::new(Aes128::new(&[0x33u8; 16]));
        let nonce = [0x44u8; 12];
        let aad = b"header";
        let mut ct = vec![0x5Au8; 40];
        let tag = ocb.encrypt(&nonce, aad, &mut ct);

        let mut flipped = ct.clone();
        flipped[20] ^= 0x01;
        let snapshot = flipped.clone();
        assert!(!ocb.decrypt(&nonce, aad, &mut flipped, &tag));
        assert_eq!(flipped, snapshot);

        let mut intact = ct.clone();
        assert!(!ocb.decrypt(&nonce, b"headex", &mut intact, &tag));
        assert_eq!(intact, ct);
    }

    #[test]
    fn ocb_rejects_tampered_tag() {
        let key = [0x11u8; 16];
        let nonce = [0x22u8; 12];
        let aad = b"aad";
        let mut msg = b"ocb message".to_vec();
        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, aad, &mut msg);

        let mut tampered_tag = tag;
        tampered_tag[0] ^= 1;
        let snapshot = msg.clone();
        assert!(!ocb.decrypt(&nonce, aad, &mut msg, &tampered_tag));
        assert_eq!(msg, snapshot);
    }
}