Skip to main content

key_vault/fragment/
random.rs

1//! [`RandomFragmenter`] — non-contiguous byte scatter.
2//!
//! Where [`StandardFragmenter`](super::StandardFragmenter) splits the key
//! into contiguous chunks and shuffles those chunks, `RandomFragmenter`
//! scatters individual bytes across chunks: the key's byte positions are
//! randomly permuted before being cut into chunks, so each chunk holds
//! bytes drawn from randomly chosen (and usually non-adjacent) positions in
//! the original key. This removes the "long contiguous run of key bytes"
//! cue that an attacker might use to confirm a chunk hit; short runs can
//! still occur by chance.
9//!
10//! # When to use
11//!
12//! Use `RandomFragmenter` when:
13//!
14//! - You suspect an attacker can scan memory linearly and recognize
15//!   structured key formats (DER, PEM, ASCII-armored) even from a partial
16//!   contiguous read.
//! - You are willing to pay for a larger layout (a 4-byte original position
//!   for every key byte, plus a 4-byte size per chunk) in exchange for the
//!   reduced linear-recognition risk.
20//!
21//! For most cases, [`StandardFragmenter`](super::StandardFragmenter)
22//! combined with [`SelfReferenceDecoy`](crate::SelfReferenceDecoy) is
23//! sufficient and faster.
24//!
25//! # Layout encoding
26//!
27//! Each chunk's layout records the **original position** of each byte in
28//! the chunk:
29//!
30//! ```text
31//! layout = [size: u32 LE,
32//!           pos[0]: u32 LE, pos[1]: u32 LE, ..., pos[size-1]: u32 LE,
33//!           size: u32 LE, ...]
34//! ```
35//!
36//! `defragment` walks the layout, places each byte at its recorded
37//! original position, and returns the result. Decoys are not currently
38//! supported by this strategy — combine with
39//! [`LayeredFragmenter`](super::LayeredFragmenter) if you need decoy mixing.
40
41use alloc::borrow::Cow;
42use alloc::vec::Vec;
43
44use super::util::{fisher_yates, sample_range, zero_buffer};
45use super::{FragmentStrategy, Fragments};
46use crate::Result;
47use crate::error::Error;
48use crate::fetcher::RawKey;
49use crate::memory::LockedBytes;
50
/// Default minimum chunk size for [`RandomFragmenter`].
const DEFAULT_MIN_CHUNK: usize = 1;
/// Default maximum chunk size for [`RandomFragmenter`].
const DEFAULT_MAX_CHUNK: usize = 4;

/// Non-contiguous-scatter Layer 3 fragmenter.
///
/// The key's byte positions are shuffled (via `fisher_yates`), and the
/// shuffled sequence is cut into chunks of between `min_chunk` and
/// `max_chunk` bytes. The final chunk may be shorter than `min_chunk`.
/// Each chunk therefore holds bytes from randomly chosen original
/// positions. Adjacency is not ruled out: two neighbouring key bytes can
/// end up side by side in a chunk by chance. Contiguous runs are made
/// unlikely, not impossible.
///
/// The fragmenter itself stores only the chunk-size range. It is `Copy`
/// and carries no key material.
#[derive(Debug, Clone, Copy)]
pub struct RandomFragmenter {
    /// Lower bound for non-final chunk sizes; the constructors keep it `>= 1`.
    min_chunk: usize,
    /// Upper bound for every chunk size; the constructors keep it `>= min_chunk`.
    max_chunk: usize,
}

impl Default for RandomFragmenter {
    /// Same as [`RandomFragmenter::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RandomFragmenter {
    /// Construct with the default chunk-size range (`min = 1`, `max = 4`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_chunk: DEFAULT_MIN_CHUNK,
            max_chunk: DEFAULT_MAX_CHUNK,
        }
    }

    /// Construct with a custom chunk-size range.
    ///
    /// Out-of-range arguments are clamped rather than rejected:
    ///
    /// - `min` is raised to at least 1.
    /// - `max` is raised to at least the clamped `min`.
    ///
    /// For example, `with_chunk_range(0, 0)` behaves like `(1, 1)`, and
    /// `with_chunk_range(5, 2)` behaves like `(5, 5)`.
    #[must_use]
    pub fn with_chunk_range(min: usize, max: usize) -> Self {
        let min = min.max(1);
        let max = max.max(min);
        Self {
            min_chunk: min,
            max_chunk: max,
        }
    }
}
95
96impl FragmentStrategy for RandomFragmenter {
97    // All `as u32` casts in this method are bounded by checked
98    // pre-conditions (`total_len <= u32::MAX`, `size <= max_chunk <= u32`).
99    #[allow(clippy::cast_possible_truncation)]
100    fn fragment(&self, key: &RawKey) -> Result<Fragments> {
101        let bytes = key.as_bytes();
102        let total_len = bytes.len();
103        if total_len == 0 {
104            return Err(Error::Fragment(alloc::string::ToString::to_string(
105                "empty key cannot be fragmented",
106            )));
107        }
108        // Real-byte positions must fit in u32.
109        if total_len > u32::MAX as usize {
110            return Err(Error::Fragment(alloc::string::ToString::to_string(
111                "key too large for fragmentation",
112            )));
113        }
114
115        // Step 1: build a shuffled permutation of all original positions.
116        // Each position appears exactly once.
117        let mut positions: Vec<u32> = (0..total_len as u32).collect();
118        fisher_yates(&mut positions)?;
119
120        // Step 2: walk the permutation, peeling off variable-size groups
121        // and turning each group into a chunk. Each chunk's bytes thus
122        // come from non-contiguous, randomly-chosen original positions.
123        let mut chunks: Vec<LockedBytes> = Vec::new();
124        let mut layout_bytes: Vec<u8> = Vec::new();
125        let mut cursor = 0usize;
126        while cursor < positions.len() {
127            let remaining = positions.len() - cursor;
128            let size = if remaining <= self.max_chunk {
129                remaining
130            } else {
131                let pick = sample_range(self.min_chunk, self.max_chunk)?;
132                // Ensure we leave at least `min` bytes for at least one
133                // more chunk.
134                pick.min(remaining.saturating_sub(self.min_chunk))
135                    .max(self.min_chunk)
136                    .min(self.max_chunk)
137                    .min(remaining)
138            };
139
140            // Build the chunk's bytes by reading from each picked position.
141            let mut chunk_bytes: Vec<u8> = Vec::with_capacity(size);
142            for &pos in &positions[cursor..cursor + size] {
143                chunk_bytes.push(bytes[pos as usize]);
144            }
145            chunks.push(LockedBytes::from_slice(&chunk_bytes));
146            zero_buffer(&mut chunk_bytes);
147            drop(chunk_bytes);
148
149            // Append the layout entry: u32 size + size × u32 positions.
150            // Size fits in u32 because max_chunk <= u32::MAX (practically
151            // <= 4 by default).
152            layout_bytes.extend_from_slice(&(size as u32).to_le_bytes());
153            for &pos in &positions[cursor..cursor + size] {
154                layout_bytes.extend_from_slice(&pos.to_le_bytes());
155            }
156
157            cursor += size;
158        }
159
160        let layout = LockedBytes::from_slice(&layout_bytes);
161        zero_buffer(&mut layout_bytes);
162        drop(layout_bytes);
163        drop(positions);
164
165        Ok(Fragments::from_parts(chunks, layout, total_len))
166    }
167
168    fn defragment(&self, fragments: &Fragments) -> Result<RawKey> {
169        let mut out = alloc::vec![0u8; fragments.total_len()];
170        self.defragment_into(fragments, &mut out)?;
171        Ok(RawKey::new(out))
172    }
173
174    fn defragment_into(&self, fragments: &Fragments, out: &mut [u8]) -> Result<()> {
175        let layout = fragments.layout().as_bytes();
176        let chunks = fragments.chunks();
177        let total_len = fragments.total_len();
178
179        if out.len() != total_len {
180            return Err(Error::Defragment(alloc::string::ToString::to_string(
181                "scratch buffer size does not match fragments.total_len()",
182            )));
183        }
184        let mut layout_cursor = 0usize;
185        for chunk in chunks {
186            // Read size prefix.
187            if layout_cursor + 4 > layout.len() {
188                return Err(Error::Defragment(alloc::string::ToString::to_string(
189                    "layout buffer truncated before size prefix",
190                )));
191            }
192            let size_raw: [u8; 4] = layout[layout_cursor..layout_cursor + 4]
193                .try_into()
194                .map_err(|_| {
195                    Error::Defragment(alloc::string::ToString::to_string("layout slice"))
196                })?;
197            let size = u32::from_le_bytes(size_raw) as usize;
198            layout_cursor += 4;
199
200            if size != chunk.as_bytes().len() {
201                return Err(Error::Defragment(alloc::string::ToString::to_string(
202                    "layout size does not match chunk length",
203                )));
204            }
205            if layout_cursor + size * 4 > layout.len() {
206                return Err(Error::Defragment(alloc::string::ToString::to_string(
207                    "layout buffer truncated before position list",
208                )));
209            }
210
211            // Place each byte at its recorded original position.
212            for (i, byte) in chunk.as_bytes().iter().enumerate() {
213                let pos_raw: [u8; 4] = layout[layout_cursor + i * 4..layout_cursor + (i + 1) * 4]
214                    .try_into()
215                    .map_err(|_| {
216                        Error::Defragment(alloc::string::ToString::to_string("layout slice"))
217                    })?;
218                let pos = u32::from_le_bytes(pos_raw) as usize;
219                if pos >= total_len {
220                    return Err(Error::Defragment(alloc::string::ToString::to_string(
221                        "layout position out of range",
222                    )));
223                }
224                out[pos] = *byte;
225            }
226            layout_cursor += size * 4;
227        }
228
229        if layout_cursor != layout.len() {
230            return Err(Error::Defragment(alloc::string::ToString::to_string(
231                "trailing bytes in layout buffer",
232            )));
233        }
234
235        Ok(())
236    }
237
238    fn describe(&self) -> Cow<'_, str> {
239        Cow::Borrowed("random")
240    }
241}
242
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
mod tests {
    use super::*;

    fn key(bytes: &[u8]) -> RawKey {
        RawKey::new(bytes.to_vec())
    }

    /// Key material of `len` bytes following the pattern `0, 1, 2, ...` (mod 256).
    fn sequential(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i & 0xff) as u8).collect()
    }

    /// Fragments with a single chunk and a hand-written layout, for
    /// exercising `defragment`'s validation directly.
    fn hand_built(chunk: &[u8], layout_words: &[u32], total_len: usize) -> Fragments {
        let mut layout: Vec<u8> = Vec::new();
        for word in layout_words {
            layout.extend_from_slice(&word.to_le_bytes());
        }
        let chunks = alloc::vec![LockedBytes::from_slice(chunk)];
        Fragments::from_parts(chunks, LockedBytes::from_slice(&layout), total_len)
    }

    #[test]
    fn round_trip_short_key() {
        let frag = RandomFragmenter::new();
        let original = key(&[0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        let fragments = frag.fragment(&original).unwrap();
        let recovered = frag.defragment(&fragments).unwrap();
        assert_eq!(recovered.as_bytes(), original.as_bytes());
    }

    #[test]
    fn round_trip_many_sizes() {
        let frag = RandomFragmenter::new();
        for len in [1usize, 7, 16, 32, 64, 128, 255, 256, 500, 1024] {
            let bytes = sequential(len);
            let original = key(&bytes);
            let fragments = frag.fragment(&original).unwrap();
            let recovered = frag.defragment(&fragments).unwrap();
            assert_eq!(recovered.as_bytes(), &bytes[..], "mismatch at len {len}");
        }
    }

    #[test]
    fn custom_chunk_range_round_trips_and_respects_max() {
        let frag = RandomFragmenter::with_chunk_range(2, 3);
        let bytes = sequential(50);
        let fragments = frag.fragment(&key(&bytes)).unwrap();
        let mut chunk_total = 0;
        for chunk in fragments.chunks() {
            let len = chunk.as_bytes().len();
            assert!((1..=3).contains(&len), "chunk of {len} bytes");
            chunk_total += len;
        }
        assert_eq!(chunk_total, bytes.len());
        let recovered = frag.defragment(&fragments).unwrap();
        assert_eq!(recovered.as_bytes(), &bytes[..]);
    }

    #[test]
    fn layout_is_a_permutation_of_all_positions() {
        let frag = RandomFragmenter::new();
        let fragments = frag.fragment(&key(&sequential(64))).unwrap();
        let mut chunk_count = 0;
        for chunk in fragments.chunks() {
            assert!(!chunk.as_bytes().is_empty());
            chunk_count += 1;
        }
        let layout = fragments.layout().as_bytes();
        assert_eq!(layout.len(), 4 * (chunk_count + 64));

        let word = |at: usize| u32::from_le_bytes(layout[at..at + 4].try_into().unwrap()) as usize;
        let mut positions = Vec::new();
        let mut cursor = 0;
        while cursor < layout.len() {
            let size = word(cursor);
            cursor += 4;
            for i in 0..size {
                positions.push(word(cursor + 4 * i));
            }
            cursor += 4 * size;
        }
        positions.sort_unstable();
        assert!(positions.iter().copied().eq(0..64));
    }

    #[test]
    fn defragment_into_overwrites_caller_buffer() {
        let frag = RandomFragmenter::new();
        let bytes = sequential(40);
        let fragments = frag.fragment(&key(&bytes)).unwrap();
        let mut out = [0xAAu8; 40];
        frag.defragment_into(&fragments, &mut out).unwrap();
        assert_eq!(&out[..], &bytes[..]);
    }

    #[test]
    fn defragment_into_rejects_wrong_buffer_size() {
        let frag = RandomFragmenter::new();
        let fragments = frag.fragment(&key(&[1u8, 2, 3, 4, 5])).unwrap();
        let mut short = [0u8; 4];
        assert!(matches!(
            frag.defragment_into(&fragments, &mut short),
            Err(Error::Defragment(_))
        ));
    }

    #[test]
    fn hand_built_layout_round_trips() {
        // One chunk [7, 9]: byte 7 came from position 1, byte 9 from position 0.
        let fragments = hand_built(&[7u8, 9], &[2, 1, 0], 2);
        let recovered = RandomFragmenter::new().defragment(&fragments).unwrap();
        assert_eq!(recovered.as_bytes(), &[9u8, 7][..]);
    }

    #[test]
    fn malformed_layouts_are_rejected() {
        let frag = RandomFragmenter::new();
        // Position 5 is outside a 2-byte key.
        let out_of_range = hand_built(&[7u8, 9], &[2, 0, 5], 2);
        assert!(matches!(frag.defragment(&out_of_range), Err(Error::Defragment(_))));
        // Extra word after the last chunk's entry.
        let trailing = hand_built(&[7u8, 9], &[2, 1, 0, 0], 2);
        assert!(matches!(frag.defragment(&trailing), Err(Error::Defragment(_))));
        // Size prefix disagrees with the chunk length.
        let wrong_size = hand_built(&[7u8, 9], &[3, 0, 1, 2], 3);
        assert!(matches!(frag.defragment(&wrong_size), Err(Error::Defragment(_))));
    }

    #[test]
    fn empty_key_rejected() {
        let frag = RandomFragmenter::new();
        let err = frag.fragment(&key(&[])).unwrap_err();
        assert!(matches!(err, Error::Fragment(_)));
    }

    #[test]
    fn two_calls_produce_different_layouts() {
        let frag = RandomFragmenter::new();
        let original = key(&sequential(32));
        let a = frag.fragment(&original).unwrap();
        let b = frag.fragment(&original).unwrap();
        assert_ne!(a.layout().as_bytes(), b.layout().as_bytes());
    }

    #[test]
    fn describe_returns_random() {
        assert_eq!(RandomFragmenter::new().describe(), "random");
    }
}