1use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
41use aes::{Aes256, Block};
42
43use crate::{VckError, VckResult};
44
/// Number of 16-byte AES blocks the sector routines hand to `cipher_1` in a single
/// `encrypt_blocks` / `decrypt_blocks` call; blocks left over after the last full
/// group are processed one at a time.
// NOTE(review): presumably sized so the AES backend can pipeline several blocks per
// call — confirm against the benchmark that chose 8.
const BATCH: usize = 8;
48
/// In-place, sector-addressed encryption of volume data.
///
/// Implementors are `Send + Sync`, so one instance can be shared by several
/// threads (and used as `dyn VolumeCipher`).
pub trait VolumeCipher: Send + Sync {
    /// Encrypts `sector` in place as sector number `rel_sector`.
    fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]);

    /// Encrypts `buf` in place as consecutive sectors of `sector_size` bytes,
    /// the first of which is sector number `first_rel_sector`.
    fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64);

    /// Decrypts `sector` in place as sector number `rel_sector`; inverse of
    /// [`VolumeCipher::encrypt_sector`].
    fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]);

    /// Decrypts `buf` in place as consecutive sectors of `sector_size` bytes,
    /// the first of which is sector number `first_rel_sector`; inverse of
    /// [`VolumeCipher::encrypt_area`].
    fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64);
}
66
67pub struct XtsVolumeCipher {
68 cipher_1: Aes256,
70 cipher_2: Aes256,
72}
73
74impl VolumeCipher for XtsVolumeCipher {
75 fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
76 XtsVolumeCipher::encrypt_sector(self, rel_sector, sector)
77 }
78 fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
79 XtsVolumeCipher::decrypt_sector(self, rel_sector, sector)
80 }
81 fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
82 XtsVolumeCipher::encrypt_area(self, buf, sector_size, first_rel_sector)
83 }
84 fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
85 XtsVolumeCipher::decrypt_area(self, buf, sector_size, first_rel_sector)
86 }
87}
88
89#[inline(always)]
92fn gf128_mul(t: Block) -> Block {
93 let lo = u64::from_le_bytes(t[..8].try_into().unwrap());
94 let hi = u64::from_le_bytes(t[8..].try_into().unwrap());
95 let carry = if hi >> 63 != 0 { 0x87u64 } else { 0u64 };
96 let mut out = Block::default();
97 out[..8].copy_from_slice(&((lo << 1) ^ carry).to_le_bytes());
98 out[8..].copy_from_slice(&((hi << 1) | (lo >> 63)).to_le_bytes());
99 out
100}
101
102impl XtsVolumeCipher {
103 pub fn new(key1: &[u8; 32], key2: &[u8; 32]) -> VckResult<Self> {
104 let cipher_1 =
105 Aes256::new_from_slice(key1).map_err(|_| VckError::CryptoFailed("invalid XTS key1"))?;
106 let cipher_2 =
107 Aes256::new_from_slice(key2).map_err(|_| VckError::CryptoFailed("invalid XTS key2"))?;
108 Ok(Self { cipher_1, cipher_2 })
109 }
110
111 pub fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
113 self.encrypt_sector_inner(rel_sector, sector);
114 }
115
116 pub fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
118 self.decrypt_sector_inner(rel_sector, sector);
119 }
120
121 pub fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
124 for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
125 self.encrypt_sector_inner(first_rel_sector + si as u64, sector);
126 }
127 }
128
129 pub fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
131 for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
132 self.decrypt_sector_inner(first_rel_sector + si as u64, sector);
133 }
134 }
135
136 #[inline(never)]
139 fn encrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
140 let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
142 self.cipher_2.encrypt_block(&mut tw);
143
144 let n = sector.len() / 16;
145 let mut off = 0;
146
147 while off + BATCH <= n {
150 let mut ts = [Block::default(); BATCH];
151 ts[0] = tw;
152 for i in 1..BATCH {
153 ts[i] = gf128_mul(ts[i - 1]);
154 }
155 tw = gf128_mul(ts[BATCH - 1]);
156
157 let mut batch = [Block::default(); BATCH];
158 for i in 0..BATCH {
159 let src = §or[(off + i) * 16..(off + i + 1) * 16];
160 for j in 0..16 {
161 batch[i][j] = src[j] ^ ts[i][j];
162 }
163 }
164 self.cipher_1.encrypt_blocks(&mut batch);
165 for i in 0..BATCH {
166 let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
167 for j in 0..16 {
168 dst[j] = batch[i][j] ^ ts[i][j];
169 }
170 }
171 off += BATCH;
172 }
173
174 while off < n {
176 let block = &mut sector[off * 16..(off + 1) * 16];
177 for j in 0..16 {
178 block[j] ^= tw[j];
179 }
180 let mut ga: Block = Block::try_from(&block[..]).unwrap();
181 self.cipher_1.encrypt_block(&mut ga);
182 block.copy_from_slice(&ga);
183 for j in 0..16 {
184 block[j] ^= tw[j];
185 }
186 tw = gf128_mul(tw);
187 off += 1;
188 }
189 }
190
191 #[inline(never)]
192 fn decrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
193 let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
195 self.cipher_2.encrypt_block(&mut tw);
196
197 let n = sector.len() / 16;
198 let mut off = 0;
199
200 while off + BATCH <= n {
201 let mut ts = [Block::default(); BATCH];
202 ts[0] = tw;
203 for i in 1..BATCH {
204 ts[i] = gf128_mul(ts[i - 1]);
205 }
206 tw = gf128_mul(ts[BATCH - 1]);
207
208 let mut batch = [Block::default(); BATCH];
209 for i in 0..BATCH {
210 let src = §or[(off + i) * 16..(off + i + 1) * 16];
211 for j in 0..16 {
212 batch[i][j] = src[j] ^ ts[i][j];
213 }
214 }
215 self.cipher_1.decrypt_blocks(&mut batch);
216 for i in 0..BATCH {
217 let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
218 for j in 0..16 {
219 dst[j] = batch[i][j] ^ ts[i][j];
220 }
221 }
222 off += BATCH;
223 }
224
225 while off < n {
226 let block = &mut sector[off * 16..(off + 1) * 16];
227 for j in 0..16 {
228 block[j] ^= tw[j];
229 }
230 let mut ga: Block = Block::try_from(&block[..]).unwrap();
231 self.cipher_1.decrypt_block(&mut ga);
232 block.copy_from_slice(&ga);
233 for j in 0..16 {
234 block[j] ^= tw[j];
235 }
236 tw = gf128_mul(tw);
237 off += 1;
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use xts_mode::{get_tweak_default, Xts128};

    const KEY1: [u8; 32] = [0x11; 32];
    const KEY2: [u8; 32] = [0x22; 32];

    /// Independent XTS-AES-256 implementation (`xts-mode`) keyed like the cipher under test.
    fn reference() -> Xts128<Aes256> {
        let c1 = Aes256::new_from_slice(&KEY1).unwrap();
        let c2 = Aes256::new_from_slice(&KEY2).unwrap();
        Xts128::new(c1, c2)
    }

    /// Checks `XtsVolumeCipher::encrypt_area` against the `xts-mode` reference.
    ///
    /// Encrypts `count` sectors of `sector_size` bytes, numbered from `first`, with
    /// both implementations and asserts byte-identical ciphertext. It then asserts
    /// that `decrypt_area` restores the plaintext.
    fn assert_matches_reference(sector_size: usize, count: usize, first: u64) {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let plain: Vec<u8> = (0..sector_size * count).map(|i| (i * 7) as u8).collect();

        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);

        let mut refer = plain.clone();
        for (s, sector) in refer.chunks_mut(sector_size).enumerate() {
            xts.encrypt_sector(sector, get_tweak_default((first + s as u64) as u128));
        }
        assert_eq!(
            ours, refer,
            "parallel XTS must match xts-mode reference (sector_size = {})",
            sector_size
        );

        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(
            ours, plain,
            "decrypt_area must invert encrypt_area (sector_size = {})",
            sector_size
        );
    }

    #[test]
    fn sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_sector(42, &mut buf);
        assert_ne!(buf, plain, "ciphertext must differ from plaintext");
        c.decrypt_sector(42, &mut buf);
        assert_eq!(buf, plain);
    }

    #[test]
    fn tweak_depends_on_sector() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain = [0xABu8; 512];
        let mut a = plain;
        let mut b = plain;
        c.encrypt_sector(0, &mut a);
        c.encrypt_sector(1, &mut b);
        assert_ne!(a, b, "same plaintext at different sectors must differ");
    }

    /// 512-byte sectors are 32 blocks, i.e. a whole number of `BATCH`-block
    /// groups, so this exercises the batched path.
    #[test]
    fn matches_xts_mode_reference() {
        assert_matches_reference(512, 3, 7);
    }

    /// Sector sizes that leave blocks over after the batched groups.
    ///
    /// 16 and 64 bytes take the tail path only; 208 bytes (13 blocks) runs one
    /// full batch followed by a tail. A roundtrip alone cannot catch a tweak error
    /// shared by encryption and decryption, so these are compared against the
    /// reference.
    #[test]
    fn tail_path_matches_xts_mode_reference() {
        for &sector_size in &[16usize, 64, 208] {
            assert_matches_reference(sector_size, 4, 3);
        }
    }

    #[test]
    fn small_sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 64usize;
        let plain: Vec<u8> = (0..sector_size * 5).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_area(&mut buf, sector_size, 0);
        assert_ne!(buf, plain);
        c.decrypt_area(&mut buf, sector_size, 0);
        assert_eq!(buf, plain);
    }
}