Skip to main content

vck_common/
xts.rs

1// SPDX-FileCopyrightText: 2026 JC-Lab <joseph@jc-lab.net>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5//! Shared AES-256-XTS volume sector cipher used by both the kernel driver and
6//! the UEFI loader, so their on-disk crypto agrees by construction.
7//!
8//! **Tweak convention (authoritative):** the XTS tweak for a sector is its
9//! **data-region-relative** sector number, i.e. `rel = absolute_lba - offset_sector`,
10//! where `rel == 0` is the first encryptable sector. This matches the
11//! `EncryptedOffset` semantics (also data-region relative). Callers MUST map
12//! absolute LBAs to `rel` before invoking these methods, and MUST NOT call them
13//! for sectors inside header/footer metadata regions (those pass through in
14//! plaintext).
15//!
16//! Keys are two independent 256-bit halves (`key1` = data key, `key2` = tweak
17//! key), giving AES-256-XTS.
18//!
19//! # Performance
20//!
21//! All sectors are processed through an 8-block parallel XTS path that keeps 8
22//! independent AES operations in flight simultaneously. On x86-64 with AES-NI
23//! (detected at runtime by the `aes` crate) this saturates the throughput
24//! pipeline (~1 cycle per 16-byte block) instead of being latency-bound
25//! (~7 cycles per block for AES-256). Sectors are always a multiple of 16 bytes
26//! (512, 4096, …) so ciphertext stealing never applies.
27//!
28//! # Kernel stack safety
29//!
30//! The driver runs this crypto on a constrained kernel stack (a system-thread
31//! stack of ~24 KiB, and an IOCTL callout stack of 32 KiB). The crate is built
32//! for the driver WITHOUT `-C target-feature=+aes`, so the `aes` crate's
33//! fully-unrolled AES-NI `encrypt8`/`decrypt8` stay behind a runtime-dispatch
34//! call boundary instead of being inlined into (and ballooning the frames of)
35//! the deep storage/metadata call chain. The per-sector entry points are
36//! additionally marked `#[inline(never)]` so their AES frames can never combine
37//! with a caller's frame. (A prior build with `+aes` inlined the unrolled AES
38//! into the IOCTL path and double-faulted on stack overflow.)
39
40use alloc::boxed::Box;
41
42use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
43use aes::{Aes256, Block};
44
45use crate::{
46    types::{VolumeCipher, VolumeCipherSupplier},
47    VckError, VckResult,
48};
49
/// Number of 16-byte AES blocks handed to the data cipher in one
/// `encrypt_blocks`/`decrypt_blocks` call by the parallel XTS path.
/// Matches the AES-NI backend's `ParBlocks = 8`, filling the 7-cycle pipeline.
const BATCH: usize = 8;
53
/// AES-256-XTS sector cipher: a pair of AES-256 instances, one for the sector
/// payload and one for deriving the per-sector tweak.
///
/// Construct it with [`XtsVolumeCipher::new`]; sector numbers passed to its
/// methods are data-region relative (see the module documentation).
pub struct XtsVolumeCipher {
    /// Data cipher (keyed with `key1`) for the AES-XTS payload blocks.
    cipher_1: Aes256,
    /// Tweak cipher (keyed with `key2`): initial tweak = AES_K2(sector_number).
    /// It is only ever used in the encrypt direction, including when a sector
    /// is being decrypted.
    cipher_2: Aes256,
}
60
61impl VolumeCipher for XtsVolumeCipher {
62    fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
63        XtsVolumeCipher::encrypt_sector(self, rel_sector, sector)
64    }
65    fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
66        XtsVolumeCipher::decrypt_sector(self, rel_sector, sector)
67    }
68    fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
69        XtsVolumeCipher::encrypt_area(self, buf, sector_size, first_rel_sector)
70    }
71    fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
72        XtsVolumeCipher::decrypt_area(self, buf, sector_size, first_rel_sector)
73    }
74    // `destroy` uses the default no-op: `Aes256` does not expose its key-schedule
75    // bytes for manual zeroization without the `zeroize` feature. Callers that
76    // need guaranteed zeroization should use a custom `VolumeCipherSupplier` that
77    // wraps key material in a `Zeroizing<[u8; 32]>` and re-derives per burst.
78}
79
/// Default [`VolumeCipherSupplier`] for volumes that store AES-256-XTS key
/// material in ordinary (non-protected) memory.
///
/// Reconstructs the cipher on each
/// [`get_cipher`](VolumeCipherSupplier::get_cipher) call. Each call performs
/// two AES-256 key expansions (one for the data key, one for the tweak key)
/// and one `Box` allocation — negligible compared with the I/O itself and
/// bounded to once per [`MAX_IRP_BURST`] IRPs.
///
/// For RAM-encryption, implement [`VolumeCipherSupplier`] directly: derive the
/// key from protected storage on each call and override [`VolumeCipher::destroy`]
/// to zeroize it.
pub struct StaticCipherSupplier {
    /// Data key: first 256-bit half of the XTS key.
    key1: [u8; 32],
    /// Tweak key: second 256-bit half of the XTS key.
    key2: [u8; 32],
}
95
96impl StaticCipherSupplier {
97    pub fn new(key1: [u8; 32], key2: [u8; 32]) -> Self {
98        Self { key1, key2 }
99    }
100}
101
102impl VolumeCipherSupplier for StaticCipherSupplier {
103    fn get_cipher(&self) -> Option<Box<dyn VolumeCipher>> {
104        XtsVolumeCipher::new(&self.key1, &self.key2)
105            .ok()
106            .map(|c| Box::new(c) as Box<dyn VolumeCipher>)
107    }
108}
109
110/// GF(2^128) multiplication by the primitive element alpha in the XTS field
111/// (little-endian byte order, primitive polynomial x^128 + x^7 + x^2 + x + 1).
112#[inline(always)]
113fn gf128_mul(t: Block) -> Block {
114    let lo = u64::from_le_bytes(t[..8].try_into().unwrap());
115    let hi = u64::from_le_bytes(t[8..].try_into().unwrap());
116    let carry = if hi >> 63 != 0 { 0x87u64 } else { 0u64 };
117    let mut out = Block::default();
118    out[..8].copy_from_slice(&((lo << 1) ^ carry).to_le_bytes());
119    out[8..].copy_from_slice(&((hi << 1) | (lo >> 63)).to_le_bytes());
120    out
121}
122
123impl XtsVolumeCipher {
124    pub fn new(key1: &[u8; 32], key2: &[u8; 32]) -> VckResult<Self> {
125        let cipher_1 =
126            Aes256::new_from_slice(key1).map_err(|_| VckError::CryptoFailed("invalid XTS key1"))?;
127        let cipher_2 =
128            Aes256::new_from_slice(key2).map_err(|_| VckError::CryptoFailed("invalid XTS key2"))?;
129        Ok(Self { cipher_1, cipher_2 })
130    }
131
132    /// Encrypt one sector in place. `rel_sector` is data-region relative.
133    pub fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
134        self.encrypt_sector_inner(rel_sector, sector);
135    }
136
137    /// Decrypt one sector in place. `rel_sector` is data-region relative.
138    pub fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
139        self.decrypt_sector_inner(rel_sector, sector);
140    }
141
142    /// Encrypt a contiguous buffer of `sector_size`-byte sectors starting at
143    /// data-region-relative sector `first_rel_sector`.
144    pub fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
145        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
146            self.encrypt_sector_inner(first_rel_sector + si as u64, sector);
147        }
148    }
149
150    /// Decrypt a contiguous buffer (inverse of [`encrypt_area`]).
151    pub fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
152        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
153            self.decrypt_sector_inner(first_rel_sector + si as u64, sector);
154        }
155    }
156
157    /// `#[inline(never)]` bounds this function's (AES-heavy) stack frame so it
158    /// cannot merge with a deep caller's frame on the kernel stack.
159    #[inline(never)]
160    fn encrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
161        // T_0 = AES_K2(sector_number as 128-bit little-endian)
162        let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
163        self.cipher_2.encrypt_block(&mut tw);
164
165        let n = sector.len() / 16;
166        let mut off = 0;
167
168        // 8-block parallel path: all 8 AES operations are independent so the
169        // CPU can keep the AES-NI units fully pipelined.
170        while off + BATCH <= n {
171            let mut ts = [Block::default(); BATCH];
172            ts[0] = tw;
173            for i in 1..BATCH {
174                ts[i] = gf128_mul(ts[i - 1]);
175            }
176            tw = gf128_mul(ts[BATCH - 1]);
177
178            let mut batch = [Block::default(); BATCH];
179            for i in 0..BATCH {
180                let src = &sector[(off + i) * 16..(off + i + 1) * 16];
181                for j in 0..16 {
182                    batch[i][j] = src[j] ^ ts[i][j];
183                }
184            }
185            self.cipher_1.encrypt_blocks(&mut batch);
186            for i in 0..BATCH {
187                let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
188                for j in 0..16 {
189                    dst[j] = batch[i][j] ^ ts[i][j];
190                }
191            }
192            off += BATCH;
193        }
194
195        // Scalar tail for sectors whose block count is not a multiple of BATCH.
196        while off < n {
197            let block = &mut sector[off * 16..(off + 1) * 16];
198            for j in 0..16 {
199                block[j] ^= tw[j];
200            }
201            let mut ga: Block = Block::try_from(&block[..]).unwrap();
202            self.cipher_1.encrypt_block(&mut ga);
203            block.copy_from_slice(&ga);
204            for j in 0..16 {
205                block[j] ^= tw[j];
206            }
207            tw = gf128_mul(tw);
208            off += 1;
209        }
210    }
211
212    #[inline(never)]
213    fn decrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
214        // Tweak is always encrypted with K2 (even during decryption).
215        let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
216        self.cipher_2.encrypt_block(&mut tw);
217
218        let n = sector.len() / 16;
219        let mut off = 0;
220
221        while off + BATCH <= n {
222            let mut ts = [Block::default(); BATCH];
223            ts[0] = tw;
224            for i in 1..BATCH {
225                ts[i] = gf128_mul(ts[i - 1]);
226            }
227            tw = gf128_mul(ts[BATCH - 1]);
228
229            let mut batch = [Block::default(); BATCH];
230            for i in 0..BATCH {
231                let src = &sector[(off + i) * 16..(off + i + 1) * 16];
232                for j in 0..16 {
233                    batch[i][j] = src[j] ^ ts[i][j];
234                }
235            }
236            self.cipher_1.decrypt_blocks(&mut batch);
237            for i in 0..BATCH {
238                let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
239                for j in 0..16 {
240                    dst[j] = batch[i][j] ^ ts[i][j];
241                }
242            }
243            off += BATCH;
244        }
245
246        while off < n {
247            let block = &mut sector[off * 16..(off + 1) * 16];
248            for j in 0..16 {
249                block[j] ^= tw[j];
250            }
251            let mut ga: Block = Block::try_from(&block[..]).unwrap();
252            self.cipher_1.decrypt_block(&mut ga);
253            block.copy_from_slice(&ga);
254            for j in 0..16 {
255                block[j] ^= tw[j];
256            }
257            tw = gf128_mul(tw);
258            off += 1;
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    // Reference implementation for cross-checking standards compliance.
    use xts_mode::{get_tweak_default, Xts128};

    const KEY1: [u8; 32] = [0x11; 32];
    const KEY2: [u8; 32] = [0x22; 32];

    /// Build the `xts-mode` reference cipher for the same key pair.
    fn reference() -> Xts128<Aes256> {
        let c1 = Aes256::new_from_slice(&KEY1).unwrap();
        let c2 = Aes256::new_from_slice(&KEY2).unwrap();
        Xts128::new(c1, c2)
    }

    /// Encrypting then decrypting a 512-byte sector restores the plaintext.
    #[test]
    fn sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_sector(42, &mut buf);
        assert_ne!(buf, plain, "ciphertext must differ from plaintext");
        c.decrypt_sector(42, &mut buf);
        assert_eq!(buf, plain);
    }

    /// The sector number feeds the tweak, so equal plaintext at different
    /// sectors must not produce equal ciphertext.
    #[test]
    fn tweak_depends_on_sector() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain = [0xABu8; 512];
        let mut a = plain;
        let mut b = plain;
        c.encrypt_sector(0, &mut a);
        c.encrypt_sector(1, &mut b);
        assert_ne!(a, b, "same plaintext at different sectors must differ");
    }

    /// Our parallel path must produce byte-identical ciphertext to the standard
    /// `xts-mode` implementation (data-region-relative sector as the tweak).
    #[test]
    fn matches_xts_mode_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let sector_size = 512usize;
        let first = 7u64;
        let plain: Vec<u8> = (0..sector_size * 3).map(|i| (i * 7) as u8).collect();

        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);

        let mut refer = plain.clone();
        for s in 0..3u64 {
            let start = s as usize * sector_size;
            xts.encrypt_sector(
                &mut refer[start..start + sector_size],
                get_tweak_default((first + s) as u128),
            );
        }
        assert_eq!(ours, refer, "parallel XTS must match xts-mode reference");

        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }

    /// Round-trip with a sector size whose block count is not a multiple of
    /// BATCH (64 bytes = 4 blocks < 8) exercises the scalar tail.
    #[test]
    fn small_sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 64usize;
        let plain: Vec<u8> = (0..sector_size * 5).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_area(&mut buf, sector_size, 0);
        assert_ne!(buf, plain);
        c.decrypt_area(&mut buf, sector_size, 0);
        assert_eq!(buf, plain);
    }

    /// A 176-byte sector is 11 blocks: one full 8-block batch followed by a
    /// 3-block scalar tail. The tweak handed from the batch path to the tail
    /// must continue the same sequence, so the ciphertext has to match the
    /// `xts-mode` reference byte for byte.
    #[test]
    fn batch_plus_tail_matches_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let sector_size = (BATCH + 3) * 16;
        let first = 3u64;
        let plain: Vec<u8> = (0..sector_size * 2).map(|i| (i * 13) as u8).collect();

        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);

        let mut refer = plain.clone();
        for (s, sector) in refer.chunks_mut(sector_size).enumerate() {
            xts.encrypt_sector(sector, get_tweak_default(u128::from(first) + s as u128));
        }
        assert_eq!(ours, refer, "batch + tail must match xts-mode reference");

        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }
}
342}