1use anyhow::{Context, Result};
2use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey};
3use zstd::bulk::Compressor;
4
5use crate::topology::CpuGeneration;
6
/// 2 MiB chunk size, in bytes.
pub const CHUNK_2_MB: usize = 2 << 20;
/// 4 MiB chunk size, in bytes.
pub const CHUNK_4_MB: usize = 4 << 20;
/// 8 MiB chunk size, in bytes.
pub const CHUNK_8_MB: usize = 8 << 20;
10
/// Strategy a `ChunkController` uses to pick the chunk size for splitting payloads.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChunkPolicy {
    /// Always use 2 MiB chunks.
    Fixed2M,
    /// Always use 4 MiB chunks.
    Fixed4M,
    /// Always use 8 MiB chunks.
    Fixed8M,
    /// Start at 8 MiB (2 MiB on Zen2) and move between 8 MiB and 4 MiB based on the
    /// runtime metrics fed to the controller.
    Adaptive8To4,
}
18
/// One observation window of pipeline health, consumed by `ChunkController::observe`.
#[derive(Clone, Copy, Debug)]
pub struct RuntimeMetrics {
    /// 99th-percentile PUT latency for the window (the `_ms` suffix indicates milliseconds).
    pub put_p99_ms: f64,
    /// Median queue depth observed during the window.
    pub median_queue_depth: f64,
}
24
25#[derive(Debug, Clone)]
26pub struct ChunkController {
27 policy: ChunkPolicy,
28 current_chunk_bytes: usize,
29 downshift_streak: u32,
30 upshift_streak: u32,
31}
32
/// A single compressed-then-sealed chunk emitted by `process_payload_crc_zstd_aes`.
#[derive(Clone, Debug)]
pub struct EncryptedChunk {
    /// Nonce counter the chunk was sealed with; required again to restore it.
    pub nonce_id: u64,
    /// CRC32C of the uncompressed plaintext; also used (little-endian) as the AEAD associated data.
    pub crc32c: u32,
    /// Length of the uncompressed plaintext chunk, in bytes.
    pub plain_len: usize,
    /// Length of `bytes`, in bytes.
    pub cipher_len: usize,
    /// zstd-compressed ciphertext followed by the AES-256-GCM authentication tag.
    pub bytes: Vec<u8>,
}
41
42#[derive(Debug, Clone)]
43pub struct PipelineOutput {
44 pub chunks: Vec<EncryptedChunk>,
45 pub total_ciphertext_bytes: usize,
46}
47
48impl ChunkPolicy {
49 pub fn from_str(value: &str) -> Result<Self> {
50 match value {
51 "fixed_2m" => Ok(Self::Fixed2M),
52 "fixed_4m" => Ok(Self::Fixed4M),
53 "fixed_8m" => Ok(Self::Fixed8M),
54 "adaptive_8_to_4" => Ok(Self::Adaptive8To4),
55 _ => anyhow::bail!("unsupported chunk policy: {value}"),
56 }
57 }
58}
59
60impl ChunkController {
61 pub fn new(cpu_generation: CpuGeneration, policy: ChunkPolicy) -> Self {
62 let current_chunk_bytes = match policy {
63 ChunkPolicy::Fixed2M => CHUNK_2_MB,
64 ChunkPolicy::Fixed4M => CHUNK_4_MB,
65 ChunkPolicy::Fixed8M => CHUNK_8_MB,
66 ChunkPolicy::Adaptive8To4 => match cpu_generation {
67 CpuGeneration::Zen2 => CHUNK_2_MB,
68 CpuGeneration::Zen3 | CpuGeneration::Unknown => CHUNK_8_MB,
69 },
70 };
71
72 Self {
73 policy,
74 current_chunk_bytes,
75 downshift_streak: 0,
76 upshift_streak: 0,
77 }
78 }
79
80 pub fn current_chunk_bytes(&self) -> usize {
81 self.current_chunk_bytes
82 }
83
84 pub fn policy(&self) -> ChunkPolicy {
85 self.policy
86 }
87
88 pub fn observe(&mut self, metrics: RuntimeMetrics) {
89 if self.policy != ChunkPolicy::Adaptive8To4 {
90 return;
91 }
92
93 let should_downshift = metrics.put_p99_ms > 200.0 && metrics.median_queue_depth > 64.0;
94 let should_upshift = metrics.put_p99_ms < 120.0 && metrics.median_queue_depth < 32.0;
95
96 if self.current_chunk_bytes == CHUNK_8_MB {
97 if should_downshift {
98 self.downshift_streak += 1;
99 } else {
100 self.downshift_streak = 0;
101 }
102 self.upshift_streak = 0;
103 if self.downshift_streak >= 3 {
104 self.current_chunk_bytes = CHUNK_4_MB;
105 self.downshift_streak = 0;
106 }
107 return;
108 }
109
110 if self.current_chunk_bytes == CHUNK_4_MB {
111 if should_upshift {
112 self.upshift_streak += 1;
113 } else {
114 self.upshift_streak = 0;
115 }
116 self.downshift_streak = 0;
117 if self.upshift_streak >= 30 {
118 self.current_chunk_bytes = CHUNK_8_MB;
119 self.upshift_streak = 0;
120 }
121 }
122 }
123}
124
125pub fn process_payload_crc_zstd_aes(
126 payload: &[u8],
127 chunk_bytes: usize,
128 key_bytes: &[u8; 32],
129 nonce_seed: u64,
130) -> Result<PipelineOutput> {
131 if chunk_bytes == 0 {
132 anyhow::bail!("chunk_bytes must be > 0");
133 }
134
135 let unbound = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
136 .map_err(|_| anyhow::anyhow!("failed to initialize AES-256-GCM key"))?;
137 let key = LessSafeKey::new(unbound);
138 let mut compressor = Compressor::new(1).context("failed initializing zstd compressor")?;
139
140 let mut compressed_buffer = vec![0_u8; zstd::zstd_safe::compress_bound(chunk_bytes)];
141 let mut output = PipelineOutput {
142 chunks: Vec::new(),
143 total_ciphertext_bytes: 0,
144 };
145
146 for (chunk_idx, chunk) in payload.chunks(chunk_bytes).enumerate() {
147 let crc = crc32c::crc32c(chunk);
148 let compressed_len = compressor
149 .compress_to_buffer(chunk, &mut compressed_buffer)
150 .with_context(|| format!("zstd compression failed at chunk {chunk_idx}"))?;
151
152 let nonce_id = nonce_seed + chunk_idx as u64 + 1;
153 let nonce = nonce_for_chunk(nonce_id);
154 let tag = key
155 .seal_in_place_separate_tag(
156 nonce,
157 Aad::from(crc.to_le_bytes()),
158 &mut compressed_buffer[..compressed_len],
159 )
160 .map_err(|_| anyhow::anyhow!("AES-256-GCM encryption failed at chunk {chunk_idx}"))?;
161
162 let mut bytes = Vec::with_capacity(compressed_len + tag.as_ref().len());
163 bytes.extend_from_slice(&compressed_buffer[..compressed_len]);
164 bytes.extend_from_slice(tag.as_ref());
165
166 output.total_ciphertext_bytes += bytes.len();
167 output.chunks.push(EncryptedChunk {
168 nonce_id,
169 crc32c: crc,
170 plain_len: chunk.len(),
171 cipher_len: bytes.len(),
172 bytes,
173 });
174 }
175
176 Ok(output)
177}
178
179pub fn restore_chunk_zstd_aes(
180 encrypted_chunk: &[u8],
181 expected_crc32c: u32,
182 expected_plain_len: usize,
183 key_bytes: &[u8; 32],
184 nonce_id: u64,
185) -> Result<Vec<u8>> {
186 if encrypted_chunk.len() < aead::AES_256_GCM.tag_len() {
187 anyhow::bail!(
188 "encrypted chunk is too small: {} bytes",
189 encrypted_chunk.len()
190 );
191 }
192
193 let unbound = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
194 .map_err(|_| anyhow::anyhow!("failed to initialize AES-256-GCM key"))?;
195 let key = LessSafeKey::new(unbound);
196 let nonce = nonce_for_chunk(nonce_id);
197
198 let mut sealed = encrypted_chunk.to_vec();
199 let decrypted = key
200 .open_in_place(nonce, Aad::from(expected_crc32c.to_le_bytes()), &mut sealed)
201 .map_err(|_| anyhow::anyhow!("AES-256-GCM decrypt failed"))?;
202
203 let restored = zstd::bulk::decompress(decrypted, expected_plain_len)
204 .context("zstd decompression failed during restore")?;
205 let crc = crc32c::crc32c(&restored);
206 if crc != expected_crc32c {
207 anyhow::bail!(
208 "crc32c mismatch after decrypt/decompress: expected {}, got {}",
209 expected_crc32c,
210 crc
211 );
212 }
213
214 Ok(restored)
215}
216
217fn nonce_for_chunk(chunk: u64) -> Nonce {
218 let mut nonce = [0_u8; 12];
219 nonce[4..].copy_from_slice(&chunk.to_be_bytes());
220 Nonce::assume_unique_for_key(nonce)
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adaptive_policy_downshifts_after_three_windows() {
        let mut controller = ChunkController::new(CpuGeneration::Zen3, ChunkPolicy::Adaptive8To4);
        for _ in 0..2 {
            controller.observe(RuntimeMetrics {
                put_p99_ms: 250.0,
                median_queue_depth: 100.0,
            });
            assert_eq!(controller.current_chunk_bytes(), CHUNK_8_MB);
        }
        controller.observe(RuntimeMetrics {
            put_p99_ms: 250.0,
            median_queue_depth: 100.0,
        });
        assert_eq!(controller.current_chunk_bytes(), CHUNK_4_MB);
    }

    #[test]
    fn adaptive_policy_upshifts_after_thirty_windows() {
        let mut controller = ChunkController::new(CpuGeneration::Zen3, ChunkPolicy::Adaptive8To4);
        for _ in 0..3 {
            controller.observe(RuntimeMetrics {
                put_p99_ms: 300.0,
                median_queue_depth: 120.0,
            });
        }
        assert_eq!(controller.current_chunk_bytes(), CHUNK_4_MB);
        for _ in 0..30 {
            controller.observe(RuntimeMetrics {
                put_p99_ms: 80.0,
                median_queue_depth: 20.0,
            });
        }
        assert_eq!(controller.current_chunk_bytes(), CHUNK_8_MB);
    }

    #[test]
    fn adaptive_policy_downshift_streak_resets_on_calm_window() {
        let mut controller = ChunkController::new(CpuGeneration::Zen3, ChunkPolicy::Adaptive8To4);
        let hot = RuntimeMetrics {
            put_p99_ms: 250.0,
            median_queue_depth: 100.0,
        };
        let calm = RuntimeMetrics {
            put_p99_ms: 150.0,
            median_queue_depth: 50.0,
        };
        // Two hot windows, a calm one that resets the streak, then two more hot windows:
        // never three in a row, so the size must not change.
        for metrics in [hot, hot, calm, hot, hot] {
            controller.observe(metrics);
        }
        assert_eq!(controller.current_chunk_bytes(), CHUNK_8_MB);
    }

    #[test]
    fn fixed_policies_ignore_metrics() {
        for (policy, expected) in [
            (ChunkPolicy::Fixed2M, CHUNK_2_MB),
            (ChunkPolicy::Fixed4M, CHUNK_4_MB),
            (ChunkPolicy::Fixed8M, CHUNK_8_MB),
        ] {
            let mut controller = ChunkController::new(CpuGeneration::Zen3, policy);
            for _ in 0..40 {
                controller.observe(RuntimeMetrics {
                    put_p99_ms: 500.0,
                    median_queue_depth: 500.0,
                });
            }
            assert_eq!(controller.current_chunk_bytes(), expected);
            assert_eq!(controller.policy(), policy);
        }
    }

    #[test]
    fn chunk_policy_parses_known_names_and_rejects_unknown() {
        let cases = [
            ("fixed_2m", ChunkPolicy::Fixed2M),
            ("fixed_4m", ChunkPolicy::Fixed4M),
            ("fixed_8m", ChunkPolicy::Fixed8M),
            ("adaptive_8_to_4", ChunkPolicy::Adaptive8To4),
        ];
        for (name, expected) in cases {
            let parsed = ChunkPolicy::from_str(name).expect("known policy name should parse");
            assert_eq!(parsed, expected);
        }
        assert!(ChunkPolicy::from_str("fixed_16m").is_err());
        assert!(ChunkPolicy::from_str("FIXED_2M").is_err());
        assert!(ChunkPolicy::from_str("").is_err());
    }

    #[test]
    fn process_payload_emits_multiple_chunks() {
        let payload = vec![7_u8; CHUNK_2_MB + 1024];
        let output = process_payload_crc_zstd_aes(&payload, CHUNK_2_MB, &[0x42_u8; 32], 100)
            .expect("pipeline processing should succeed");
        assert_eq!(output.chunks.len(), 2);
        assert!(output.total_ciphertext_bytes > 0);
    }

    #[test]
    fn process_payload_rejects_zero_chunk_bytes() {
        assert!(process_payload_crc_zstd_aes(b"abc", 0, &[0_u8; 32], 0).is_err());
    }

    #[test]
    fn process_payload_on_empty_input_emits_no_chunks() {
        let output = process_payload_crc_zstd_aes(&[], CHUNK_2_MB, &[1_u8; 32], 0)
            .expect("empty payload should be accepted");
        assert!(output.chunks.is_empty());
        assert_eq!(output.total_ciphertext_bytes, 0);
    }

    #[test]
    fn process_payload_assigns_consecutive_nonce_ids_after_seed() {
        let payload = vec![3_u8; 3 * 1024];
        let output = process_payload_crc_zstd_aes(&payload, 1024, &[0x5A_u8; 32], 10)
            .expect("pipeline processing should succeed");
        let ids: Vec<u64> = output.chunks.iter().map(|c| c.nonce_id).collect();
        assert_eq!(ids, vec![11, 12, 13]);
        assert!(output
            .chunks
            .iter()
            .all(|c| c.plain_len == 1024 && c.cipher_len == c.bytes.len()));
        let summed: usize = output.chunks.iter().map(|c| c.cipher_len).sum();
        assert_eq!(output.total_ciphertext_bytes, summed);
    }

    #[test]
    fn restore_chunk_round_trip_works() {
        let payload = vec![9_u8; CHUNK_2_MB + 321];
        let output = process_payload_crc_zstd_aes(&payload, CHUNK_2_MB, &[0xAA_u8; 32], 42)
            .expect("pipeline processing should succeed");
        let mut restored = Vec::new();
        for chunk in output.chunks {
            let plain = restore_chunk_zstd_aes(
                &chunk.bytes,
                chunk.crc32c,
                chunk.plain_len,
                &[0xAA_u8; 32],
                chunk.nonce_id,
            )
            .expect("restore should succeed");
            restored.extend_from_slice(&plain);
        }
        assert_eq!(restored, payload);
    }

    #[test]
    fn restore_rejects_wrong_key_crc_or_nonce() {
        let key = [0x11_u8; 32];
        let payload = b"restore me".repeat(64);
        let output = process_payload_crc_zstd_aes(&payload, 256, &key, 7)
            .expect("pipeline processing should succeed");
        let chunk = &output.chunks[0];

        let wrong_key = restore_chunk_zstd_aes(
            &chunk.bytes,
            chunk.crc32c,
            chunk.plain_len,
            &[0x22_u8; 32],
            chunk.nonce_id,
        );
        assert!(wrong_key.is_err());

        let wrong_crc = restore_chunk_zstd_aes(
            &chunk.bytes,
            chunk.crc32c ^ 1,
            chunk.plain_len,
            &key,
            chunk.nonce_id,
        );
        assert!(wrong_crc.is_err());

        let wrong_nonce = restore_chunk_zstd_aes(
            &chunk.bytes,
            chunk.crc32c,
            chunk.plain_len,
            &key,
            chunk.nonce_id + 1,
        );
        assert!(wrong_nonce.is_err());
    }

    #[test]
    fn restore_rejects_truncated_chunk() {
        assert!(restore_chunk_zstd_aes(&[0_u8; 4], 0, 0, &[0_u8; 32], 1).is_err());
    }
}