key_vault/fragment/
random.rs1use alloc::borrow::Cow;
42use alloc::vec::Vec;
43
44use super::util::{fisher_yates, sample_range, zero_buffer};
45use super::{FragmentStrategy, Fragments};
46use crate::Result;
47use crate::error::Error;
48use crate::fetcher::RawKey;
49use crate::memory::LockedBytes;
50
/// Smallest chunk size, in bytes, used by `RandomFragmenter::new`.
const DEFAULT_MIN_CHUNK: usize = 1;
/// Largest chunk size, in bytes, used by `RandomFragmenter::new`.
const DEFAULT_MAX_CHUNK: usize = 4;

/// Fragmentation strategy that scatters key bytes across randomly sized chunks.
///
/// Chunk sizes are bounded by `min_chunk` and `max_chunk`. Construct it with
/// [`RandomFragmenter::new`] (default range) or [`RandomFragmenter::with_chunk_range`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomFragmenter {
    /// Lower chunk-size bound; the constructors guarantee it is at least 1.
    min_chunk: usize,
    /// Upper chunk-size bound; the constructors guarantee it is at least `min_chunk`.
    max_chunk: usize,
}

impl Default for RandomFragmenter {
    /// Equivalent to [`RandomFragmenter::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RandomFragmenter {
    /// Creates a fragmenter using the default chunk range of 1 to 4 bytes.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            min_chunk: DEFAULT_MIN_CHUNK,
            max_chunk: DEFAULT_MAX_CHUNK,
        }
    }

    /// Creates a fragmenter with a custom chunk range.
    ///
    /// Arguments are clamped rather than rejected: `min` is raised to at least 1, and `max`
    /// is raised to at least the clamped `min`.
    #[must_use]
    pub fn with_chunk_range(min: usize, max: usize) -> Self {
        let min_chunk = min.max(1);
        let max_chunk = max.max(min_chunk);
        Self {
            min_chunk,
            max_chunk,
        }
    }
}
95
96impl FragmentStrategy for RandomFragmenter {
97 #[allow(clippy::cast_possible_truncation)]
100 fn fragment(&self, key: &RawKey) -> Result<Fragments> {
101 let bytes = key.as_bytes();
102 let total_len = bytes.len();
103 if total_len == 0 {
104 return Err(Error::Fragment(alloc::string::ToString::to_string(
105 "empty key cannot be fragmented",
106 )));
107 }
108 if total_len > u32::MAX as usize {
110 return Err(Error::Fragment(alloc::string::ToString::to_string(
111 "key too large for fragmentation",
112 )));
113 }
114
115 let mut positions: Vec<u32> = (0..total_len as u32).collect();
118 fisher_yates(&mut positions)?;
119
120 let mut chunks: Vec<LockedBytes> = Vec::new();
124 let mut layout_bytes: Vec<u8> = Vec::new();
125 let mut cursor = 0usize;
126 while cursor < positions.len() {
127 let remaining = positions.len() - cursor;
128 let size = if remaining <= self.max_chunk {
129 remaining
130 } else {
131 let pick = sample_range(self.min_chunk, self.max_chunk)?;
132 pick.min(remaining.saturating_sub(self.min_chunk))
135 .max(self.min_chunk)
136 .min(self.max_chunk)
137 .min(remaining)
138 };
139
140 let mut chunk_bytes: Vec<u8> = Vec::with_capacity(size);
142 for &pos in &positions[cursor..cursor + size] {
143 chunk_bytes.push(bytes[pos as usize]);
144 }
145 chunks.push(LockedBytes::from_slice(&chunk_bytes));
146 zero_buffer(&mut chunk_bytes);
147 drop(chunk_bytes);
148
149 layout_bytes.extend_from_slice(&(size as u32).to_le_bytes());
153 for &pos in &positions[cursor..cursor + size] {
154 layout_bytes.extend_from_slice(&pos.to_le_bytes());
155 }
156
157 cursor += size;
158 }
159
160 let layout = LockedBytes::from_slice(&layout_bytes);
161 zero_buffer(&mut layout_bytes);
162 drop(layout_bytes);
163 drop(positions);
164
165 Ok(Fragments::from_parts(chunks, layout, total_len))
166 }
167
168 fn defragment(&self, fragments: &Fragments) -> Result<RawKey> {
169 let mut out = alloc::vec![0u8; fragments.total_len()];
170 self.defragment_into(fragments, &mut out)?;
171 Ok(RawKey::new(out))
172 }
173
174 fn defragment_into(&self, fragments: &Fragments, out: &mut [u8]) -> Result<()> {
175 let layout = fragments.layout().as_bytes();
176 let chunks = fragments.chunks();
177 let total_len = fragments.total_len();
178
179 if out.len() != total_len {
180 return Err(Error::Defragment(alloc::string::ToString::to_string(
181 "scratch buffer size does not match fragments.total_len()",
182 )));
183 }
184 let mut layout_cursor = 0usize;
185 for chunk in chunks {
186 if layout_cursor + 4 > layout.len() {
188 return Err(Error::Defragment(alloc::string::ToString::to_string(
189 "layout buffer truncated before size prefix",
190 )));
191 }
192 let size_raw: [u8; 4] = layout[layout_cursor..layout_cursor + 4]
193 .try_into()
194 .map_err(|_| {
195 Error::Defragment(alloc::string::ToString::to_string("layout slice"))
196 })?;
197 let size = u32::from_le_bytes(size_raw) as usize;
198 layout_cursor += 4;
199
200 if size != chunk.as_bytes().len() {
201 return Err(Error::Defragment(alloc::string::ToString::to_string(
202 "layout size does not match chunk length",
203 )));
204 }
205 if layout_cursor + size * 4 > layout.len() {
206 return Err(Error::Defragment(alloc::string::ToString::to_string(
207 "layout buffer truncated before position list",
208 )));
209 }
210
211 for (i, byte) in chunk.as_bytes().iter().enumerate() {
213 let pos_raw: [u8; 4] = layout[layout_cursor + i * 4..layout_cursor + (i + 1) * 4]
214 .try_into()
215 .map_err(|_| {
216 Error::Defragment(alloc::string::ToString::to_string("layout slice"))
217 })?;
218 let pos = u32::from_le_bytes(pos_raw) as usize;
219 if pos >= total_len {
220 return Err(Error::Defragment(alloc::string::ToString::to_string(
221 "layout position out of range",
222 )));
223 }
224 out[pos] = *byte;
225 }
226 layout_cursor += size * 4;
227 }
228
229 if layout_cursor != layout.len() {
230 return Err(Error::Defragment(alloc::string::ToString::to_string(
231 "trailing bytes in layout buffer",
232 )));
233 }
234
235 Ok(())
236 }
237
238 fn describe(&self) -> Cow<'_, str> {
239 Cow::Borrowed("random")
240 }
241}
242
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
mod tests {
    use super::*;

    /// Wraps a copy of `bytes` in an owned `RawKey`.
    fn key(bytes: &[u8]) -> RawKey {
        RawKey::new(Vec::from(bytes))
    }

    /// Fragments `original` with `frag` and immediately reassembles it.
    fn round_trip(frag: &RandomFragmenter, original: &RawKey) -> RawKey {
        let pieces = frag.fragment(original).unwrap();
        frag.defragment(&pieces).unwrap()
    }

    #[test]
    fn round_trip_short_key() {
        let source: Vec<u8> = (0u8..10).collect();
        let original = key(&source);
        let restored = round_trip(&RandomFragmenter::new(), &original);
        assert_eq!(restored.as_bytes(), original.as_bytes());
    }

    #[test]
    fn round_trip_many_sizes() {
        let fragmenter = RandomFragmenter::new();
        let lengths = [1usize, 7, 16, 32, 64, 128, 255, 256, 500, 1024];
        lengths.iter().copied().for_each(|len| {
            // `as u8` keeps only the low byte, i.e. the index modulo 256.
            let source: Vec<u8> = (0..len).map(|idx| idx as u8).collect();
            let restored = round_trip(&fragmenter, &key(&source));
            assert_eq!(restored.as_bytes(), source.as_slice(), "mismatch at len {len}");
        });
    }

    #[test]
    fn empty_key_rejected() {
        let outcome = RandomFragmenter::new().fragment(&key(&[]));
        assert!(matches!(outcome, Err(Error::Fragment(_))));
    }

    #[test]
    fn two_calls_produce_different_layouts() {
        let fragmenter = RandomFragmenter::new();
        let source: Vec<u8> = (0u8..32).collect();
        let original = key(&source);
        let first = fragmenter.fragment(&original).unwrap();
        let second = fragmenter.fragment(&original).unwrap();
        assert_ne!(first.layout().as_bytes(), second.layout().as_bytes());
    }

    #[test]
    fn describe_returns_random() {
        let fragmenter = RandomFragmenter::new();
        assert_eq!(fragmenter.describe(), "random");
    }
}