Skip to main content

dynomite/entropy/
receive.rs

1//! Entropy receiver.
2//!
3//! Drives the server side of the reconciliation channel: bind,
4//! accept, validate the negotiation, decrypt incoming chunks, and
5//! hand the plaintext snapshot to a [`SnapshotSink`].
6
7use std::io::Write;
8use std::net::{SocketAddr, TcpStream};
9use std::time::Duration;
10
11use aes::Aes128;
12use cbc::cipher::block_padding::Pkcs7;
13use cbc::cipher::{BlockDecryptMut, KeyIvInit};
14use tokio::io::AsyncReadExt;
15use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
16use tokio::task::JoinHandle;
17
18use crate::entropy::util::EntropyMaterial;
19use crate::entropy::{
20    BoxedSnapshotSink, EntropyConfig, EntropyError, EntropyResult, NegotiationHeader,
21    SnapshotHeader, SnapshotSink, MAX_CIPHER_SIZE, MAX_SNAPSHOT_SIZE, SAFE_PREALLOC,
22};
23
24type Aes128CbcDec = cbc::Decryptor<Aes128>;
25
26/// Entropy receiver driver.
27///
28/// # Examples
29///
30/// ```no_run
31/// use std::path::PathBuf;
32/// use std::sync::Arc;
33/// use dynomite::entropy::{
34///     receive::RedisReplaySink, EntropyConfig, EntropyReceiver,
35/// };
36///
37/// # async fn run() {
38/// let cfg = EntropyConfig {
39///     key_file: PathBuf::from("conf/recon_key.pem"),
40///     iv_file: PathBuf::from("conf/recon_iv.pem"),
41///     listen_addr: "127.0.0.1:8105".parse().unwrap(),
42///     send_addr: None,
43///     peer_endpoint: "127.0.0.1:8105".parse().unwrap(),
44///     buffer_size: 16 * 1024,
45///     header_size: 1024,
46///     encrypt: true,
47/// };
48/// let sink = Arc::new(RedisReplaySink::default());
49/// let handle = EntropyReceiver::run(cfg, sink).await.unwrap();
50/// handle.abort();
51/// # }
52/// ```
53pub struct EntropyReceiver {
54    listener: TcpListener,
55    cfg: EntropyConfig,
56    sink: BoxedSnapshotSink,
57}
58
59impl EntropyReceiver {
60    /// Bind a receiver to `cfg.listen_addr` and spawn the accept
61    /// loop on a tokio task.
62    ///
63    /// Each accepted connection is handled in line on the same
64    /// task, mirroring the reference engine's single-threaded
65    /// entropy loop. The returned handle resolves to `Ok(())` only
66    /// after the listener is shut down (e.g. by aborting the task)
67    /// or to an error if the bind fails.
68    ///
69    /// # Errors
70    /// Forwards anything from key loading or socket bind.
71    pub async fn run(
72        cfg: EntropyConfig,
73        sink: BoxedSnapshotSink,
74    ) -> EntropyResult<JoinHandle<EntropyResult<()>>> {
75        let recv = Self::bind(cfg, sink).await?;
76        Ok(tokio::spawn(async move { recv.accept_loop().await }))
77    }
78
79    /// Bind without spawning. Used by tests that want to drive the
80    /// accept loop on the caller's task.
81    ///
82    /// # Errors
83    /// Forwards anything from key loading or socket bind.
84    pub async fn bind(cfg: EntropyConfig, sink: BoxedSnapshotSink) -> EntropyResult<Self> {
85        cfg.validate()?;
86        // Eagerly validate that key material is loadable when
87        // encryption is in use. Repeated reads happen on every
88        // accept so the listener is also useful as an end-to-end
89        // smoke test.
90        if cfg.encrypt {
91            let _ = crate::entropy::util::load_material(&cfg.key_file, &cfg.iv_file)?;
92        }
93        let listener = TcpListener::bind(cfg.listen_addr).await?;
94        Ok(Self {
95            listener,
96            cfg,
97            sink,
98        })
99    }
100
101    /// Local address the receiver is bound to. Useful for tests
102    /// that bind to `:0` and then dial the kernel-assigned port.
103    ///
104    /// # Errors
105    /// Forwarded from the underlying socket call.
106    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
107        self.listener.local_addr()
108    }
109
110    /// Accept exactly one connection, process it, and return.
111    ///
112    /// # Errors
113    /// Forwards I/O, protocol, and crypto errors from the worker.
114    pub async fn accept_one(self) -> EntropyResult<usize> {
115        let (sock, _peer) = self.listener.accept().await?;
116        handle_one(sock, &self.cfg, self.sink.as_ref()).await
117    }
118
119    /// Accept connections until the task is aborted.
120    pub async fn accept_loop(self) -> EntropyResult<()> {
121        loop {
122            let (sock, peer) = self.listener.accept().await?;
123            tracing::debug!(?peer, "entropy receiver accepted connection");
124            if let Err(e) = handle_one(sock, &self.cfg, self.sink.as_ref()).await {
125                tracing::warn!(?e, "entropy worker errored");
126            }
127        }
128    }
129}
130
131async fn handle_one(
132    mut stream: TokioTcpStream,
133    cfg: &EntropyConfig,
134    sink: &dyn SnapshotSink,
135) -> EntropyResult<usize> {
136    stream.set_nodelay(true)?;
137    let neg = read_negotiation(&mut stream).await?;
138    let snap = read_snapshot_header(&mut stream, &neg).await?;
139    let material = if snap.encrypt_flag == 1 {
140        Some(crate::entropy::util::load_material(
141            &cfg.key_file,
142            &cfg.iv_file,
143        )?)
144    } else {
145        None
146    };
147    let plaintext = read_chunks(&mut stream, &neg, &snap, material.as_ref()).await?;
148    sink.apply(&plaintext).map_err(|e| match e {
149        EntropyError::Sink(msg) => EntropyError::Sink(msg),
150        other => EntropyError::Sink(other.to_string()),
151    })?;
152    Ok(plaintext.len())
153}
154
155async fn read_negotiation(stream: &mut TokioTcpStream) -> EntropyResult<NegotiationHeader> {
156    let mut wire = [0u8; NegotiationHeader::SIZE];
157    stream.read_exact(&mut wire).await?;
158    NegotiationHeader::from_wire(&wire)
159}
160
161async fn read_snapshot_header(
162    stream: &mut TokioTcpStream,
163    neg: &NegotiationHeader,
164) -> EntropyResult<SnapshotHeader> {
165    let mut buf = vec![0u8; neg.header_size as usize];
166    stream.read_exact(&mut buf).await?;
167    SnapshotHeader::from_wire(&buf)
168}
169
170async fn read_chunks(
171    stream: &mut TokioTcpStream,
172    neg: &NegotiationHeader,
173    snap: &SnapshotHeader,
174    material: Option<&EntropyMaterial>,
175) -> EntropyResult<Vec<u8>> {
176    let total_len = snap.total_len as usize;
177    if total_len > MAX_SNAPSHOT_SIZE {
178        return Err(EntropyError::Protocol(format!(
179            "snapshot total_len {total_len} exceeds MAX_SNAPSHOT_SIZE"
180        )));
181    }
182    // Cap the upfront allocation to avoid a malicious or malformed
183    // total_len triggering an oversized prealloc; the Vec grows as
184    // plaintext arrives if the snapshot is genuinely larger.
185    let mut plaintext = Vec::with_capacity(total_len.min(SAFE_PREALLOC));
186    let mut len_buf = [0u8; 4];
187    while plaintext.len() < total_len {
188        stream.read_exact(&mut len_buf).await?;
189        let chunk_len = u32::from_be_bytes(len_buf) as usize;
190        if chunk_len == 0 {
191            return Err(EntropyError::Protocol("zero-length chunk".to_string()));
192        }
193        if chunk_len > MAX_CIPHER_SIZE || chunk_len > neg.cipher_size as usize {
194            return Err(EntropyError::Protocol(format!(
195                "chunk_len {chunk_len} exceeds negotiated cipher_size {}",
196                neg.cipher_size
197            )));
198        }
199        let mut payload = vec![0u8; chunk_len];
200        stream.read_exact(&mut payload).await?;
201        let chunk_plain = if let Some(mat) = material {
202            decrypt_chunk(&payload, mat)?
203        } else {
204            payload
205        };
206        let take = (total_len - plaintext.len()).min(chunk_plain.len());
207        if take < chunk_plain.len() {
208            return Err(EntropyError::Protocol(format!(
209                "chunk overshoots total_len: have {} more after {} bytes",
210                chunk_plain.len() - take,
211                plaintext.len() + take
212            )));
213        }
214        plaintext.extend_from_slice(&chunk_plain[..take]);
215    }
216
217    // Drain any final framing length the sender might emit; for
218    // well-behaved senders the stream is now at EOF. If extra
219    // bytes arrive, treat that as a protocol violation.
220    let mut probe = [0u8; 1];
221    match stream.read(&mut probe).await {
222        Ok(0) | Err(_) => {}
223        Ok(_) => {
224            return Err(EntropyError::Protocol(
225                "trailing bytes after declared total_len".to_string(),
226            ));
227        }
228    }
229
230    Ok(plaintext)
231}
232
233/// AES-128-CBC decrypt one ciphertext chunk using the entropy
234/// key + IV. Returned plaintext has any PKCS#7 padding stripped.
235///
236/// Pure helper exposed for the integration tests.
237///
238/// # Errors
239/// [`EntropyError::Crypto`] if the input length is not a positive
240/// multiple of 16 or the trailing block has invalid padding.
241pub fn decrypt_chunk(ciphertext: &[u8], material: &EntropyMaterial) -> EntropyResult<Vec<u8>> {
242    if ciphertext.is_empty() || !ciphertext.len().is_multiple_of(16) {
243        return Err(EntropyError::Crypto(format!(
244            "ciphertext length {} is not a positive multiple of 16",
245            ciphertext.len()
246        )));
247    }
248    let key = material.key().as_bytes();
249    let iv = material.iv().as_bytes();
250    let cipher = Aes128CbcDec::new(key.into(), iv.into());
251    cipher
252        .decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
253        .map_err(|e| EntropyError::Crypto(format!("PKCS#7 unpad failed: {e}")))
254}
255
/// Default sink: pushes the decrypted snapshot to a local Redis
/// instance over a fresh TCP connection.
///
/// The reference engine's receiver opens a connection to Redis on
/// `127.0.0.1:22122` and writes each per-key payload as it
/// arrives. The Rust default is a single-shot equivalent: it
/// connects to the configured Redis endpoint, writes the entire
/// decrypted snapshot, and closes. Embedders that want a custom
/// replay strategy plug in their own [`SnapshotSink`] through the
/// Stage 13 API.
///
/// # Examples
///
/// ```
/// use dynomite::entropy::receive::RedisReplaySink;
/// let sink = RedisReplaySink::default();
/// drop(sink);
/// ```
#[derive(Debug, Clone)]
pub struct RedisReplaySink {
    /// Address of the local Redis instance.
    pub redis_addr: SocketAddr,
    /// Timeout applied to the connect attempt and to socket writes.
    pub timeout: Duration,
}
280
281impl Default for RedisReplaySink {
282    fn default() -> Self {
283        Self {
284            redis_addr: "127.0.0.1:22122".parse().expect("static literal parses"),
285            timeout: Duration::from_secs(30),
286        }
287    }
288}
289
290impl RedisReplaySink {
291    /// Override the Redis address.
292    #[must_use]
293    pub fn with_redis_addr(mut self, addr: SocketAddr) -> Self {
294        self.redis_addr = addr;
295        self
296    }
297}
298
299impl SnapshotSink for RedisReplaySink {
300    fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
301        let mut sock = TcpStream::connect_timeout(&self.redis_addr, self.timeout)
302            .map_err(|e| EntropyError::Sink(format!("connect to redis: {e}")))?;
303        sock.set_write_timeout(Some(self.timeout))
304            .map_err(|e| EntropyError::Sink(format!("redis timeout: {e}")))?;
305        sock.write_all(snapshot)
306            .map_err(|e| EntropyError::Sink(format!("redis write: {e}")))?;
307        Ok(())
308    }
309}
310
311/// In-memory sink used by tests and embedders that just want to
312/// inspect the decrypted snapshot.
313///
314/// # Examples
315///
316/// ```
317/// use std::sync::Arc;
318/// use dynomite::entropy::receive::MemorySink;
319/// use dynomite::entropy::SnapshotSink;
320/// let sink = Arc::new(MemorySink::default());
321/// sink.apply(b"hi").unwrap();
322/// assert_eq!(sink.take(), b"hi");
323/// ```
324#[derive(Default)]
325pub struct MemorySink {
326    inner: parking_lot::Mutex<Vec<u8>>,
327}
328
329impl MemorySink {
330    /// Drain the buffered snapshot, leaving the sink empty.
331    #[must_use]
332    pub fn take(&self) -> Vec<u8> {
333        std::mem::take(&mut *self.inner.lock())
334    }
335
336    /// Borrow the current contents (clone).
337    #[must_use]
338    pub fn snapshot(&self) -> Vec<u8> {
339        self.inner.lock().clone()
340    }
341}
342
343impl SnapshotSink for MemorySink {
344    fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
345        let mut buf = self.inner.lock();
346        buf.clear();
347        buf.extend_from_slice(snapshot);
348        Ok(())
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy::send::encrypt_chunk;
    use crate::entropy::util::{EntropyIv, EntropyKey, ENTROPY_IV_LEN, ENTROPY_KEY_LEN};

    /// Fixed key/IV pair shared by the crypto tests.
    fn material() -> EntropyMaterial {
        let key = EntropyKey::from_bytes([0x10; ENTROPY_KEY_LEN]);
        let iv = EntropyIv::from_bytes([0x42; ENTROPY_IV_LEN]);
        EntropyMaterial::new(key, iv)
    }

    /// An empty ciphertext is refused before any decryption happens.
    #[test]
    fn decrypt_chunk_rejects_short_buffer() {
        let outcome = decrypt_chunk(&[], &material());
        assert!(matches!(outcome, Err(EntropyError::Crypto(_))));
    }

    /// A length that is not a multiple of the block size is refused.
    #[test]
    fn decrypt_chunk_rejects_misaligned() {
        let outcome = decrypt_chunk(&[0u8; 17], &material());
        assert!(matches!(outcome, Err(EntropyError::Crypto(_))));
    }

    /// Encrypting and then decrypting returns the original bytes.
    #[test]
    fn encrypt_decrypt_round_trip() {
        let mat = material();
        let pt = b"the quick brown fox jumps over the lazy dog";
        let sealed = encrypt_chunk(pt, &mat).unwrap();
        let opened = decrypt_chunk(&sealed, &mat).unwrap();
        assert_eq!(opened, pt);
    }

    /// Corrupting the final ciphertext block is detected at unpad
    /// time.
    #[test]
    fn decrypt_chunk_rejects_tamper() {
        let mat = material();
        let pt = b"the quick brown fox";
        let mut sealed = encrypt_chunk(pt, &mat).unwrap();
        // Invert every bit of the final ciphertext byte. That
        // scrambles the last (padding-carrying) plaintext block, so
        // the PKCS#7 check is expected to fail for this fixed key.
        let tail = sealed.last_mut().unwrap();
        *tail ^= 0xff;
        let outcome = decrypt_chunk(&sealed, &mat);
        assert!(matches!(outcome, Err(EntropyError::Crypto(_))));
    }

    /// `apply` stores, `snapshot` copies, and `take` drains.
    #[test]
    fn memory_sink_round_trips() {
        let sink = MemorySink::default();
        sink.apply(b"abc").unwrap();
        assert_eq!(sink.snapshot(), b"abc");
        assert_eq!(sink.take(), b"abc");
        assert!(sink.snapshot().is_empty());
    }
}
410}