dynomite/entropy/
receive.rs1use std::io::Write;
8use std::net::{SocketAddr, TcpStream};
9use std::time::Duration;
10
11use aes::Aes128;
12use cbc::cipher::block_padding::Pkcs7;
13use cbc::cipher::{BlockDecryptMut, KeyIvInit};
14use tokio::io::AsyncReadExt;
15use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
16use tokio::task::JoinHandle;
17
18use crate::entropy::util::EntropyMaterial;
19use crate::entropy::{
20 BoxedSnapshotSink, EntropyConfig, EntropyError, EntropyResult, NegotiationHeader,
21 SnapshotHeader, SnapshotSink, MAX_CIPHER_SIZE, MAX_SNAPSHOT_SIZE, SAFE_PREALLOC,
22};
23
/// AES-128 in CBC mode, decrypting direction. `decrypt_chunk` builds one of
/// these per chunk from the configured key and IV.
24type Aes128CbcDec = cbc::Decryptor<Aes128>;
25
26pub struct EntropyReceiver {
54 listener: TcpListener,
55 cfg: EntropyConfig,
56 sink: BoxedSnapshotSink,
57}
58
59impl EntropyReceiver {
60 pub async fn run(
72 cfg: EntropyConfig,
73 sink: BoxedSnapshotSink,
74 ) -> EntropyResult<JoinHandle<EntropyResult<()>>> {
75 let recv = Self::bind(cfg, sink).await?;
76 Ok(tokio::spawn(async move { recv.accept_loop().await }))
77 }
78
79 pub async fn bind(cfg: EntropyConfig, sink: BoxedSnapshotSink) -> EntropyResult<Self> {
85 cfg.validate()?;
86 if cfg.encrypt {
91 let _ = crate::entropy::util::load_material(&cfg.key_file, &cfg.iv_file)?;
92 }
93 let listener = TcpListener::bind(cfg.listen_addr).await?;
94 Ok(Self {
95 listener,
96 cfg,
97 sink,
98 })
99 }
100
101 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
107 self.listener.local_addr()
108 }
109
110 pub async fn accept_one(self) -> EntropyResult<usize> {
115 let (sock, _peer) = self.listener.accept().await?;
116 handle_one(sock, &self.cfg, self.sink.as_ref()).await
117 }
118
119 pub async fn accept_loop(self) -> EntropyResult<()> {
121 loop {
122 let (sock, peer) = self.listener.accept().await?;
123 tracing::debug!(?peer, "entropy receiver accepted connection");
124 if let Err(e) = handle_one(sock, &self.cfg, self.sink.as_ref()).await {
125 tracing::warn!(?e, "entropy worker errored");
126 }
127 }
128 }
129}
130
131async fn handle_one(
132 mut stream: TokioTcpStream,
133 cfg: &EntropyConfig,
134 sink: &dyn SnapshotSink,
135) -> EntropyResult<usize> {
136 stream.set_nodelay(true)?;
137 let neg = read_negotiation(&mut stream).await?;
138 let snap = read_snapshot_header(&mut stream, &neg).await?;
139 let material = if snap.encrypt_flag == 1 {
140 Some(crate::entropy::util::load_material(
141 &cfg.key_file,
142 &cfg.iv_file,
143 )?)
144 } else {
145 None
146 };
147 let plaintext = read_chunks(&mut stream, &neg, &snap, material.as_ref()).await?;
148 sink.apply(&plaintext).map_err(|e| match e {
149 EntropyError::Sink(msg) => EntropyError::Sink(msg),
150 other => EntropyError::Sink(other.to_string()),
151 })?;
152 Ok(plaintext.len())
153}
154
155async fn read_negotiation(stream: &mut TokioTcpStream) -> EntropyResult<NegotiationHeader> {
156 let mut wire = [0u8; NegotiationHeader::SIZE];
157 stream.read_exact(&mut wire).await?;
158 NegotiationHeader::from_wire(&wire)
159}
160
161async fn read_snapshot_header(
162 stream: &mut TokioTcpStream,
163 neg: &NegotiationHeader,
164) -> EntropyResult<SnapshotHeader> {
165 let mut buf = vec![0u8; neg.header_size as usize];
166 stream.read_exact(&mut buf).await?;
167 SnapshotHeader::from_wire(&buf)
168}
169
170async fn read_chunks(
171 stream: &mut TokioTcpStream,
172 neg: &NegotiationHeader,
173 snap: &SnapshotHeader,
174 material: Option<&EntropyMaterial>,
175) -> EntropyResult<Vec<u8>> {
176 let total_len = snap.total_len as usize;
177 if total_len > MAX_SNAPSHOT_SIZE {
178 return Err(EntropyError::Protocol(format!(
179 "snapshot total_len {total_len} exceeds MAX_SNAPSHOT_SIZE"
180 )));
181 }
182 let mut plaintext = Vec::with_capacity(total_len.min(SAFE_PREALLOC));
186 let mut len_buf = [0u8; 4];
187 while plaintext.len() < total_len {
188 stream.read_exact(&mut len_buf).await?;
189 let chunk_len = u32::from_be_bytes(len_buf) as usize;
190 if chunk_len == 0 {
191 return Err(EntropyError::Protocol("zero-length chunk".to_string()));
192 }
193 if chunk_len > MAX_CIPHER_SIZE || chunk_len > neg.cipher_size as usize {
194 return Err(EntropyError::Protocol(format!(
195 "chunk_len {chunk_len} exceeds negotiated cipher_size {}",
196 neg.cipher_size
197 )));
198 }
199 let mut payload = vec![0u8; chunk_len];
200 stream.read_exact(&mut payload).await?;
201 let chunk_plain = if let Some(mat) = material {
202 decrypt_chunk(&payload, mat)?
203 } else {
204 payload
205 };
206 let take = (total_len - plaintext.len()).min(chunk_plain.len());
207 if take < chunk_plain.len() {
208 return Err(EntropyError::Protocol(format!(
209 "chunk overshoots total_len: have {} more after {} bytes",
210 chunk_plain.len() - take,
211 plaintext.len() + take
212 )));
213 }
214 plaintext.extend_from_slice(&chunk_plain[..take]);
215 }
216
217 let mut probe = [0u8; 1];
221 match stream.read(&mut probe).await {
222 Ok(0) | Err(_) => {}
223 Ok(_) => {
224 return Err(EntropyError::Protocol(
225 "trailing bytes after declared total_len".to_string(),
226 ));
227 }
228 }
229
230 Ok(plaintext)
231}
232
233pub fn decrypt_chunk(ciphertext: &[u8], material: &EntropyMaterial) -> EntropyResult<Vec<u8>> {
242 if ciphertext.is_empty() || !ciphertext.len().is_multiple_of(16) {
243 return Err(EntropyError::Crypto(format!(
244 "ciphertext length {} is not a positive multiple of 16",
245 ciphertext.len()
246 )));
247 }
248 let key = material.key().as_bytes();
249 let iv = material.iv().as_bytes();
250 let cipher = Aes128CbcDec::new(key.into(), iv.into());
251 cipher
252 .decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
253 .map_err(|e| EntropyError::Crypto(format!("PKCS#7 unpad failed: {e}")))
254}
255
/// Snapshot sink that forwards the received bytes to a server over a plain
/// TCP connection.
///
/// The default target is `127.0.0.1:22122` with a 30-second timeout.
#[derive(Debug, Clone)]
pub struct RedisReplaySink {
    /// Address the snapshot is written to.
    pub redis_addr: SocketAddr,
    /// Upper bound applied to both the connect and the write.
    pub timeout: Duration,
}

impl Default for RedisReplaySink {
    /// Builds a sink aimed at `127.0.0.1:22122` with a 30-second timeout.
    fn default() -> Self {
        Self {
            // Built from parts rather than parsed from a string, so there is
            // no runtime parse and no panic path.
            redis_addr: SocketAddr::from(([127, 0, 0, 1], 22122)),
            timeout: Duration::from_secs(30),
        }
    }
}

impl RedisReplaySink {
    /// Returns the sink with its target address replaced by `addr`; the
    /// timeout is left unchanged.
    #[must_use]
    pub fn with_redis_addr(mut self, addr: SocketAddr) -> Self {
        self.redis_addr = addr;
        self
    }
}
298
299impl SnapshotSink for RedisReplaySink {
300 fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
301 let mut sock = TcpStream::connect_timeout(&self.redis_addr, self.timeout)
302 .map_err(|e| EntropyError::Sink(format!("connect to redis: {e}")))?;
303 sock.set_write_timeout(Some(self.timeout))
304 .map_err(|e| EntropyError::Sink(format!("redis timeout: {e}")))?;
305 sock.write_all(snapshot)
306 .map_err(|e| EntropyError::Sink(format!("redis write: {e}")))?;
307 Ok(())
308 }
309}
310
311#[derive(Default)]
325pub struct MemorySink {
326 inner: parking_lot::Mutex<Vec<u8>>,
327}
328
329impl MemorySink {
330 #[must_use]
332 pub fn take(&self) -> Vec<u8> {
333 std::mem::take(&mut *self.inner.lock())
334 }
335
336 #[must_use]
338 pub fn snapshot(&self) -> Vec<u8> {
339 self.inner.lock().clone()
340 }
341}
342
343impl SnapshotSink for MemorySink {
344 fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
345 let mut buf = self.inner.lock();
346 buf.clear();
347 buf.extend_from_slice(snapshot);
348 Ok(())
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy::send::encrypt_chunk;
    use crate::entropy::util::{EntropyIv, EntropyKey, ENTROPY_IV_LEN, ENTROPY_KEY_LEN};

    /// Fixed key/IV pair shared by all crypto tests so results are
    /// deterministic.
    fn material() -> EntropyMaterial {
        let key = EntropyKey::from_bytes([0x10; ENTROPY_KEY_LEN]);
        let iv = EntropyIv::from_bytes([0x42; ENTROPY_IV_LEN]);
        EntropyMaterial::new(key, iv)
    }

    /// An empty ciphertext is refused with a crypto error.
    #[test]
    fn decrypt_chunk_rejects_short_buffer() {
        let outcome = decrypt_chunk(&[], &material());
        assert!(matches!(outcome.unwrap_err(), EntropyError::Crypto(_)));
    }

    /// A length that is not a multiple of the block size is refused.
    #[test]
    fn decrypt_chunk_rejects_misaligned() {
        let outcome = decrypt_chunk(&[0u8; 17], &material());
        assert!(matches!(outcome.unwrap_err(), EntropyError::Crypto(_)));
    }

    /// Encrypting and then decrypting returns the original bytes.
    #[test]
    fn encrypt_decrypt_round_trip() {
        let keys = material();
        let pt = b"the quick brown fox jumps over the lazy dog";
        let sealed = encrypt_chunk(pt, &keys).unwrap();
        let opened = decrypt_chunk(&sealed, &keys).unwrap();
        assert_eq!(opened, pt);
    }

    /// Corrupting the final ciphertext byte makes unpadding fail for this
    /// fixed key, IV and plaintext.
    #[test]
    fn decrypt_chunk_rejects_tamper() {
        let keys = material();
        let pt = b"the quick brown fox";
        let mut sealed = encrypt_chunk(pt, &keys).unwrap();
        *sealed.last_mut().unwrap() ^= 0xff;
        let outcome = decrypt_chunk(&sealed, &keys);
        assert!(matches!(outcome.unwrap_err(), EntropyError::Crypto(_)));
    }

    /// `apply` stores the bytes, `snapshot` copies them, `take` drains them.
    #[test]
    fn memory_sink_round_trips() {
        let sink = MemorySink::default();
        sink.apply(b"abc").unwrap();
        assert_eq!(sink.snapshot(), b"abc");
        let drained = sink.take();
        assert_eq!(drained, b"abc");
        assert!(sink.snapshot().is_empty());
    }
}
410}