1use crate::BlockCipher;
7
/// Returns the byte-wise XOR of two 16-byte blocks.
///
/// Neither input is modified; the result is a fresh array.
#[inline]
fn xor_block(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
    // Start from a copy of `a` and fold `b` into it in place.
    let mut mixed = *a;
    for (dst, src) in mixed.iter_mut().zip(b.iter()) {
        *dst ^= *src;
    }
    mixed
}
16
/// Applies the OCB `double` operation to a 128-bit block.
///
/// The block is interpreted as a big-endian 128-bit integer and shifted left
/// by one bit. If the bit shifted out of the top was set, `0x87` is XORed into
/// the least-significant byte.
///
/// The inputs to this function are derived from the key, so the reduction is
/// applied through a mask rather than an `if` to avoid a data-dependent
/// branch in the source.
#[inline]
fn dbl_block(block: [u8; 16]) -> [u8; 16] {
    let value = u128::from_be_bytes(block);
    // `value >> 127` is 0 or 1; subtracting it from zero yields an all-zeros
    // or all-ones mask that selects whether the reduction constant applies.
    let reduce = 0u128.wrapping_sub(value >> 127);
    ((value << 1) ^ (reduce & 0x87)).to_be_bytes()
}
30
/// Returns the number of trailing zero bits in `i`.
///
/// For `i == 0` this is the bit width of `usize`, as reported by
/// `usize::trailing_zeros`.
#[inline]
fn ntz(i: usize) -> usize {
    let zeros: u32 = i.trailing_zeros();
    zeros as usize
}
35
/// Splits `data` into its leading run of whole 16-byte blocks and the
/// remaining tail of fewer than 16 bytes.
///
/// Either half may be empty.
#[inline]
fn split_blocks(data: &[u8]) -> (&[u8], &[u8]) {
    let tail_len = data.len() % 16;
    data.split_at(data.len() - tail_len)
}
41
/// Builds the 16-byte formatted nonce block used to derive the initial offset.
///
/// Layout, most-significant bit first:
/// * 7 bits holding `tag_len_bits % 128`,
/// * zero padding,
/// * a single `1` bit immediately above the nonce,
/// * the nonce itself, right-aligned.
///
/// # Panics
///
/// Panics if `nonce` is longer than 15 bytes.
fn nonce_block_from_bytes(tag_len_bits: usize, nonce: &[u8]) -> [u8; 16] {
    assert!(nonce.len() <= 15, "OCB nonce must be at most 120 bits");

    let mut block = [0u8; 16];
    let start = 16 - nonce.len();
    block[start..].copy_from_slice(nonce);

    // Marker bit sits in the lowest bit of the byte just before the nonce;
    // `start >= 1` is guaranteed by the length check above.
    block[start - 1] |= 0x01;

    // The tag length occupies the top seven bits of the first byte.
    block[0] |= ((tag_len_bits % 128) as u8) << 1;
    block
}
54
/// Expands the 16-byte `ktop` into the 24-byte stretch value.
///
/// The first 16 bytes are `ktop` itself; each of the remaining 8 bytes is the
/// XOR of two adjacent `ktop` bytes (`ktop[i] ^ ktop[i + 1]` for `i` in `0..8`).
fn stretch_from_ktop(ktop: [u8; 16]) -> [u8; 24] {
    let mut stretch = [0u8; 24];
    let (head, tail) = stretch.split_at_mut(16);
    head.copy_from_slice(&ktop);
    // `tail` has 8 slots, so the zip consumes only the first 8 adjacent pairs.
    for (slot, pair) in tail.iter_mut().zip(ktop.windows(2)) {
        *slot = pair[0] ^ pair[1];
    }
    stretch
}
63
/// Extracts the 128-bit window of `stretch` that starts `bottom` bits from
/// its most-significant end.
///
/// The only caller in this file masks `bottom` to `0..=63`, for which every
/// read stays inside the 24-byte stretch. On the unaligned path, reads past
/// the end of `stretch` are treated as zero bytes; on the byte-aligned path an
/// out-of-range window panics on the slice index.
fn offset_from_stretch(stretch: &[u8; 24], bottom: u8) -> [u8; 16] {
    let skip = usize::from(bottom >> 3);
    let shift = u32::from(bottom & 7);
    let mut offset = [0u8; 16];

    if shift != 0 {
        let byte_at = |idx: usize| stretch.get(idx).copied().unwrap_or(0);
        for (i, slot) in offset.iter_mut().enumerate() {
            // Each output byte straddles two neighbouring stretch bytes.
            let upper = byte_at(skip + i) << shift;
            let lower = byte_at(skip + i + 1) >> (8 - shift);
            *slot = upper | lower;
        }
    } else {
        // Byte-aligned window: a straight copy suffices (and a shift by 8
        // would not be valid for `u8`).
        offset.copy_from_slice(&stretch[skip..skip + 16]);
    }
    offset
}
81
82fn hash_associated_data<C: BlockCipher>(
83 cipher: &C,
84 l_star: [u8; 16],
85 l_dollar: [u8; 16],
86 aad: &[u8],
87) -> [u8; 16] {
88 let mut l_table = vec![dbl_block(l_dollar)];
89 let mut sum = [0u8; 16];
90 let mut offset = [0u8; 16];
91
92 let (full, partial) = split_blocks(aad);
93 for (idx, block) in full.chunks_exact(16).enumerate() {
94 let i = idx + 1;
96 let tz = ntz(i);
97 while l_table.len() <= tz {
98 let next = dbl_block(*l_table.last().expect("L table non-empty"));
99 l_table.push(next);
100 }
101 offset = xor_block(&offset, &l_table[tz]);
102
103 let mut x = [0u8; 16];
104 x.copy_from_slice(block);
105 x = xor_block(&x, &offset);
106 cipher.encrypt(&mut x);
107 sum = xor_block(&sum, &x);
108 }
109
110 if !partial.is_empty() {
111 offset = xor_block(&offset, &l_star);
113 let mut cipher_input = [0u8; 16];
114 cipher_input[..partial.len()].copy_from_slice(partial);
115 cipher_input[partial.len()] = 0x80;
116 cipher_input = xor_block(&cipher_input, &offset);
117 cipher.encrypt(&mut cipher_input);
118 sum = xor_block(&sum, &cipher_input);
119 }
120
121 sum
122}
123
/// OCB authenticated-encryption mode wrapped around a block cipher `C`.
///
/// Encryption and decryption are provided by the `impl` for
/// `C: BlockCipher`.
pub struct Ocb<C> {
    /// The underlying block cipher instance.
    cipher: C,
}
128
129impl<C> Ocb<C> {
130 pub fn new(cipher: C) -> Self {
132 Self { cipher }
133 }
134
135 pub fn cipher(&self) -> &C {
137 &self.cipher
138 }
139}
140
141impl<C: BlockCipher> Ocb<C> {
142 fn compute_offsets(&self, nonce: &[u8]) -> ([u8; 16], [u8; 16], [u8; 16], Vec<[u8; 16]>) {
143 assert_eq!(C::BLOCK_LEN, 16, "OCB requires a 128-bit block cipher");
144 let mut l_star = [0u8; 16];
146 self.cipher.encrypt(&mut l_star);
147 let l_dollar = dbl_block(l_star);
148 let l0 = dbl_block(l_dollar);
149
150 let nonce_block = nonce_block_from_bytes(128, nonce);
151 let bottom = nonce_block[15] & 0x3f;
152 let mut ktop_input = nonce_block;
153 ktop_input[15] &= 0xC0;
154 self.cipher.encrypt(&mut ktop_input);
155 let stretch = stretch_from_ktop(ktop_input);
158 let offset0 = offset_from_stretch(&stretch, bottom);
159
160 (l_star, l_dollar, offset0, vec![l0])
161 }
162
163 pub fn encrypt(&self, nonce: &[u8], aad: &[u8], data: &mut [u8]) -> [u8; 16] {
165 let (l_star, l_dollar, mut offset, mut l_table) = self.compute_offsets(nonce);
166 let aad_hash = hash_associated_data(&self.cipher, l_star, l_dollar, aad);
167
168 let (full_len, partial_len) = (data.len() / 16 * 16, data.len() % 16);
169 let mut checksum = [0u8; 16];
170
171 for (idx, block) in data[..full_len].chunks_exact_mut(16).enumerate() {
172 let i = idx + 1;
174 let tz = ntz(i);
175 while l_table.len() <= tz {
176 let next = dbl_block(*l_table.last().expect("L table non-empty"));
177 l_table.push(next);
178 }
179 offset = xor_block(&offset, &l_table[tz]);
180
181 let mut p = [0u8; 16];
182 p.copy_from_slice(block);
183 checksum = xor_block(&checksum, &p);
184
185 p = xor_block(&p, &offset);
186 self.cipher.encrypt(&mut p);
187 p = xor_block(&p, &offset);
188 block.copy_from_slice(&p);
189 }
190
191 if partial_len != 0 {
192 offset = xor_block(&offset, &l_star);
194 let mut pad = offset;
195 self.cipher.encrypt(&mut pad);
196
197 let partial = &mut data[full_len..];
198 let mut partial_plain = [0u8; 16];
199 partial_plain[..partial.len()].copy_from_slice(partial);
200 for i in 0..partial.len() {
201 partial[i] ^= pad[i];
202 }
203
204 partial_plain[partial.len()] = 0x80;
205 checksum = xor_block(&checksum, &partial_plain);
206 }
207
208 let mut tag_input = xor_block(&checksum, &offset);
209 tag_input = xor_block(&tag_input, &l_dollar);
210 self.cipher.encrypt(&mut tag_input);
211 xor_block(&tag_input, &aad_hash)
212 }
213
214 pub fn decrypt(&self, nonce: &[u8], aad: &[u8], data: &mut [u8], tag: &[u8; 16]) -> bool {
216 let (l_star, l_dollar, mut offset, mut l_table) = self.compute_offsets(nonce);
217 let aad_hash = hash_associated_data(&self.cipher, l_star, l_dollar, aad);
218
219 let (full_len, partial_len) = (data.len() / 16 * 16, data.len() % 16);
220 let mut checksum = [0u8; 16];
221
222 let mut plaintext = data.to_vec();
223 for (idx, block) in plaintext[..full_len].chunks_exact_mut(16).enumerate() {
224 let i = idx + 1;
226 let tz = ntz(i);
227 while l_table.len() <= tz {
228 let next = dbl_block(*l_table.last().expect("L table non-empty"));
229 l_table.push(next);
230 }
231 offset = xor_block(&offset, &l_table[tz]);
232
233 let mut c = [0u8; 16];
234 c.copy_from_slice(block);
235 c = xor_block(&c, &offset);
236 self.cipher.decrypt(&mut c);
237 c = xor_block(&c, &offset);
238 checksum = xor_block(&checksum, &c);
239 block.copy_from_slice(&c);
240 }
241
242 if partial_len != 0 {
243 offset = xor_block(&offset, &l_star);
245 let mut pad = offset;
246 self.cipher.encrypt(&mut pad);
247 let partial = &mut plaintext[full_len..];
248 for i in 0..partial.len() {
249 partial[i] ^= pad[i];
250 }
251 let mut padded_p = [0u8; 16];
252 padded_p[..partial.len()].copy_from_slice(partial);
253 padded_p[partial.len()] = 0x80;
254 checksum = xor_block(&checksum, &padded_p);
255 }
256
257 let mut tag_input = xor_block(&checksum, &offset);
258 tag_input = xor_block(&tag_input, &l_dollar);
259 self.cipher.encrypt(&mut tag_input);
260 let expected = xor_block(&tag_input, &aad_hash);
261 if crate::ct::constant_time_eq_mask(&expected, tag) != u8::MAX {
262 crate::ct::zeroize_slice(&mut plaintext);
263 return false;
264 }
265
266 data.copy_from_slice(&plaintext);
267 true
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::Ocb;
    use crate::Aes128;

    /// Key shared by the RFC 7253 sample vectors used below.
    const RFC_KEY: &str = "000102030405060708090A0B0C0D0E0F";

    /// Decodes a hex string, ignoring any whitespace.
    ///
    /// Panics on a non-hex character or an odd number of hex digits, so a
    /// mistyped vector fails loudly instead of being silently truncated.
    fn unhex_ws(input: &str) -> Vec<u8> {
        let nibbles: Vec<u8> = input
            .chars()
            .filter(|c| !c.is_whitespace())
            .map(|c| c.to_digit(16).expect("hex") as u8)
            .collect();
        assert_eq!(nibbles.len() % 2, 0, "hex input must have an even number of digits");
        nibbles
            .chunks_exact(2)
            .map(|pair| (pair[0] << 4) | pair[1])
            .collect()
    }

    /// OCB over AES-128 keyed with the RFC sample key.
    fn rfc_ocb() -> Ocb<Aes128> {
        let key = <[u8; 16]>::try_from(unhex_ws(RFC_KEY)).expect("key");
        Ocb::new(Aes128::new(&key))
    }

    #[test]
    fn rfc7253_sample_vector_1_empty() {
        let nonce = unhex_ws("BBAA99887766554433221100");
        let aad = [];
        let mut pt = vec![];
        let expected = unhex_ws("785407BFFFC8AD9EDCC5520AC9111EE6");

        let ocb = rfc_ocb();
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        assert_eq!(pt, Vec::<u8>::new());
        assert_eq!(tag.as_slice(), expected.as_slice());
    }

    #[test]
    fn rfc7253_sample_vector_2_short_aad_and_pt() {
        let nonce = unhex_ws("BBAA99887766554433221101");
        let aad = unhex_ws("0001020304050607");
        let mut pt = unhex_ws("0001020304050607");
        let expected = unhex_ws("6820B3657B6F615A5725BDA0D3B4EB3A257C9AF1F8F03009");

        let ocb = rfc_ocb();
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        let mut out = pt.clone();
        out.extend_from_slice(&tag);
        assert_eq!(out, expected);

        assert!(ocb.decrypt(&nonce, &aad, &mut pt, &tag));
        assert_eq!(pt, unhex_ws("0001020304050607"));
    }

    #[test]
    fn rfc7253_sample_vector_4_short_pt_no_aad() {
        let nonce = unhex_ws("BBAA99887766554433221103");
        let aad = [];
        let mut pt = unhex_ws("0001020304050607");
        let expected = unhex_ws("45DD69F8F5AAE72414054CD1F35D82760B2CD00D2F99BFA9");

        let ocb = rfc_ocb();
        let tag = ocb.encrypt(&nonce, &aad, &mut pt);
        let mut out = pt.clone();
        out.extend_from_slice(&tag);
        assert_eq!(out, expected);
    }

    #[test]
    fn ocb_rejects_tampered_tag() {
        let key = [0x11u8; 16];
        let nonce = [0x22u8; 12];
        let aad = b"aad";
        let mut msg = b"ocb message".to_vec();
        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, aad, &mut msg);

        let mut tampered_tag = tag;
        tampered_tag[0] ^= 1;
        let snapshot = msg.clone();
        assert!(!ocb.decrypt(&nonce, aad, &mut msg, &tampered_tag));
        assert_eq!(msg, snapshot);
    }

    /// Exercises the full-block path (including the block cipher's decrypt
    /// direction) and the partial-block path for both message and AAD.
    #[test]
    fn ocb_round_trips_across_block_boundaries() {
        let key = [0x42u8; 16];
        let nonce = [0x24u8; 12];
        let ocb = Ocb::new(Aes128::new(&key));

        for &len in &[0usize, 1, 15, 16, 17, 31, 32, 33, 48, 100] {
            let plain: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let aad: Vec<u8> = (0..len).map(|i| (i as u8) ^ 0x5A).collect();

            let mut buf = plain.clone();
            let tag = ocb.encrypt(&nonce, &aad, &mut buf);
            assert_eq!(buf.len(), plain.len());

            assert!(ocb.decrypt(&nonce, &aad, &mut buf, &tag), "len {}", len);
            assert_eq!(buf, plain, "len {}", len);
        }
    }

    #[test]
    fn ocb_rejects_tampered_ciphertext_and_aad() {
        let key = [0x11u8; 16];
        let nonce = [0x22u8; 12];
        let aad = b"associated data spanning more than one block".to_vec();
        let mut msg: Vec<u8> = (0..40u8).collect();
        let ocb = Ocb::new(Aes128::new(&key));
        let tag = ocb.encrypt(&nonce, &aad, &mut msg);

        // Flip one bit inside a full block of the ciphertext.
        let mut bad_ct = msg.clone();
        bad_ct[3] ^= 0x80;
        let snapshot = bad_ct.clone();
        assert!(!ocb.decrypt(&nonce, &aad, &mut bad_ct, &tag));
        assert_eq!(bad_ct, snapshot);

        // Flip one bit of the associated data.
        let mut bad_aad = aad.clone();
        bad_aad[0] ^= 0x01;
        let snapshot = msg.clone();
        assert!(!ocb.decrypt(&nonce, &bad_aad, &mut msg, &tag));
        assert_eq!(msg, snapshot);
    }
}