1use alloc::boxed::Box;
41
42use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
43use aes::{Aes256, Block};
44
45use crate::{
46 types::{VolumeCipher, VolumeCipherSupplier},
47 VckError, VckResult,
48};
49
/// Number of 16-byte blocks handed to the AES backend in one batched
/// `encrypt_blocks`/`decrypt_blocks` call by the XTS sector loops; this many
/// tweaks are precomputed per batch.
const BATCH: usize = 8;
53
/// XTS-AES-256 sector cipher built from two independent AES-256 keys
/// (see `XtsVolumeCipher::new`).
pub struct XtsVolumeCipher {
    /// Data cipher, keyed with `key1`: transforms the sector's 16-byte blocks.
    cipher_1: Aes256,
    /// Tweak cipher, keyed with `key2`: encrypts the sector number to form the
    /// initial tweak for that sector.
    cipher_2: Aes256,
}
60
61impl VolumeCipher for XtsVolumeCipher {
62 fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
63 XtsVolumeCipher::encrypt_sector(self, rel_sector, sector)
64 }
65 fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
66 XtsVolumeCipher::decrypt_sector(self, rel_sector, sector)
67 }
68 fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
69 XtsVolumeCipher::encrypt_area(self, buf, sector_size, first_rel_sector)
70 }
71 fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
72 XtsVolumeCipher::decrypt_area(self, buf, sector_size, first_rel_sector)
73 }
74 }
79
/// A [`VolumeCipherSupplier`] that builds every cipher it hands out from one
/// fixed pair of 256-bit keys.
///
/// NOTE(review): the raw key bytes are stored in plain arrays and no
/// zeroization on drop is visible in this file — consider wiping them.
pub struct StaticCipherSupplier {
    /// Passed as `key1` (data-cipher key) to `XtsVolumeCipher::new`.
    key1: [u8; 32],
    /// Passed as `key2` (tweak-cipher key) to `XtsVolumeCipher::new`.
    key2: [u8; 32],
}
95
96impl StaticCipherSupplier {
97 pub fn new(key1: [u8; 32], key2: [u8; 32]) -> Self {
98 Self { key1, key2 }
99 }
100}
101
102impl VolumeCipherSupplier for StaticCipherSupplier {
103 fn get_cipher(&self) -> Option<Box<dyn VolumeCipher>> {
104 XtsVolumeCipher::new(&self.key1, &self.key2)
105 .ok()
106 .map(|c| Box::new(c) as Box<dyn VolumeCipher>)
107 }
108}
109
110#[inline(always)]
113fn gf128_mul(t: Block) -> Block {
114 let lo = u64::from_le_bytes(t[..8].try_into().unwrap());
115 let hi = u64::from_le_bytes(t[8..].try_into().unwrap());
116 let carry = if hi >> 63 != 0 { 0x87u64 } else { 0u64 };
117 let mut out = Block::default();
118 out[..8].copy_from_slice(&((lo << 1) ^ carry).to_le_bytes());
119 out[8..].copy_from_slice(&((hi << 1) | (lo >> 63)).to_le_bytes());
120 out
121}
122
123impl XtsVolumeCipher {
124 pub fn new(key1: &[u8; 32], key2: &[u8; 32]) -> VckResult<Self> {
125 let cipher_1 =
126 Aes256::new_from_slice(key1).map_err(|_| VckError::CryptoFailed("invalid XTS key1"))?;
127 let cipher_2 =
128 Aes256::new_from_slice(key2).map_err(|_| VckError::CryptoFailed("invalid XTS key2"))?;
129 Ok(Self { cipher_1, cipher_2 })
130 }
131
132 pub fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
134 self.encrypt_sector_inner(rel_sector, sector);
135 }
136
137 pub fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
139 self.decrypt_sector_inner(rel_sector, sector);
140 }
141
142 pub fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
145 for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
146 self.encrypt_sector_inner(first_rel_sector + si as u64, sector);
147 }
148 }
149
150 pub fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
152 for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
153 self.decrypt_sector_inner(first_rel_sector + si as u64, sector);
154 }
155 }
156
157 #[inline(never)]
160 fn encrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
161 let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
163 self.cipher_2.encrypt_block(&mut tw);
164
165 let n = sector.len() / 16;
166 let mut off = 0;
167
168 while off + BATCH <= n {
171 let mut ts = [Block::default(); BATCH];
172 ts[0] = tw;
173 for i in 1..BATCH {
174 ts[i] = gf128_mul(ts[i - 1]);
175 }
176 tw = gf128_mul(ts[BATCH - 1]);
177
178 let mut batch = [Block::default(); BATCH];
179 for i in 0..BATCH {
180 let src = §or[(off + i) * 16..(off + i + 1) * 16];
181 for j in 0..16 {
182 batch[i][j] = src[j] ^ ts[i][j];
183 }
184 }
185 self.cipher_1.encrypt_blocks(&mut batch);
186 for i in 0..BATCH {
187 let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
188 for j in 0..16 {
189 dst[j] = batch[i][j] ^ ts[i][j];
190 }
191 }
192 off += BATCH;
193 }
194
195 while off < n {
197 let block = &mut sector[off * 16..(off + 1) * 16];
198 for j in 0..16 {
199 block[j] ^= tw[j];
200 }
201 let mut ga: Block = Block::try_from(&block[..]).unwrap();
202 self.cipher_1.encrypt_block(&mut ga);
203 block.copy_from_slice(&ga);
204 for j in 0..16 {
205 block[j] ^= tw[j];
206 }
207 tw = gf128_mul(tw);
208 off += 1;
209 }
210 }
211
212 #[inline(never)]
213 fn decrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
214 let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
216 self.cipher_2.encrypt_block(&mut tw);
217
218 let n = sector.len() / 16;
219 let mut off = 0;
220
221 while off + BATCH <= n {
222 let mut ts = [Block::default(); BATCH];
223 ts[0] = tw;
224 for i in 1..BATCH {
225 ts[i] = gf128_mul(ts[i - 1]);
226 }
227 tw = gf128_mul(ts[BATCH - 1]);
228
229 let mut batch = [Block::default(); BATCH];
230 for i in 0..BATCH {
231 let src = §or[(off + i) * 16..(off + i + 1) * 16];
232 for j in 0..16 {
233 batch[i][j] = src[j] ^ ts[i][j];
234 }
235 }
236 self.cipher_1.decrypt_blocks(&mut batch);
237 for i in 0..BATCH {
238 let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
239 for j in 0..16 {
240 dst[j] = batch[i][j] ^ ts[i][j];
241 }
242 }
243 off += BATCH;
244 }
245
246 while off < n {
247 let block = &mut sector[off * 16..(off + 1) * 16];
248 for j in 0..16 {
249 block[j] ^= tw[j];
250 }
251 let mut ga: Block = Block::try_from(&block[..]).unwrap();
252 self.cipher_1.decrypt_block(&mut ga);
253 block.copy_from_slice(&ga);
254 for j in 0..16 {
255 block[j] ^= tw[j];
256 }
257 tw = gf128_mul(tw);
258 off += 1;
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use xts_mode::{get_tweak_default, Xts128};

    const KEY1: [u8; 32] = [0x11; 32];
    const KEY2: [u8; 32] = [0x22; 32];

    /// Independent XTS implementation used as ground truth.
    fn reference() -> Xts128<Aes256> {
        let c1 = Aes256::new_from_slice(&KEY1).unwrap();
        let c2 = Aes256::new_from_slice(&KEY2).unwrap();
        Xts128::new(c1, c2)
    }

    #[test]
    fn sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_sector(42, &mut buf);
        assert_ne!(buf, plain, "ciphertext must differ from plaintext");
        c.decrypt_sector(42, &mut buf);
        assert_eq!(buf, plain);
    }

    #[test]
    fn tweak_depends_on_sector() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain = [0xABu8; 512];
        let mut a = plain;
        let mut b = plain;
        c.encrypt_sector(0, &mut a);
        c.encrypt_sector(1, &mut b);
        assert_ne!(a, b, "same plaintext at different sectors must differ");
    }

    /// Multi-sector area whose sectors are a whole number of batches.
    #[test]
    fn matches_xts_mode_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let sector_size = 512usize;
        let first = 7u64;
        let plain: Vec<u8> = (0..sector_size * 3).map(|i| (i * 7) as u8).collect();

        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);

        let mut refer = plain.clone();
        for s in 0..3u64 {
            let start = s as usize * sector_size;
            xts.encrypt_sector(
                &mut refer[start..start + sector_size],
                get_tweak_default((first + s) as u128),
            );
        }
        assert_eq!(ours, refer, "parallel XTS must match xts-mode reference");

        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }

    /// Sector sizes that are not a whole number of batches, so the single-block
    /// tail loop is also checked against the reference (not just round-tripped).
    #[test]
    fn matches_reference_with_unbatched_tail() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        for blocks in [1usize, BATCH - 1, BATCH + 3, 2 * BATCH + 1] {
            let sector_size = blocks * 16;
            let plain: Vec<u8> = (0..sector_size).map(|i| (i * 13 + 5) as u8).collect();

            let mut ours = plain.clone();
            c.encrypt_sector(99, &mut ours);
            let mut refer = plain.clone();
            xts.encrypt_sector(&mut refer, get_tweak_default(99));
            assert_eq!(ours, refer, "mismatch for {blocks} blocks");

            c.decrypt_sector(99, &mut ours);
            assert_eq!(ours, plain, "roundtrip failed for {blocks} blocks");
        }
    }

    /// Sectors smaller than one batch, processed as a multi-sector area.
    #[test]
    fn small_sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 64usize;
        let plain: Vec<u8> = (0..sector_size * 5).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_area(&mut buf, sector_size, 0);
        assert_ne!(buf, plain);
        c.decrypt_area(&mut buf, sector_size, 0);
        assert_eq!(buf, plain);
    }
}