Skip to main content

dynomite/entropy/
send.rs

1//! Entropy sender.
2//!
3//! Drives the client side of the reconciliation channel: dial the
4//! peer, negotiate, then push a snapshot from a [`SnapshotSource`]
5//! one chunk at a time.
6
7use std::io::{Read, Write};
8use std::net::{SocketAddr, TcpStream};
9use std::time::Duration;
10
11use aes::Aes128;
12use cbc::cipher::block_padding::Pkcs7;
13use cbc::cipher::{BlockEncryptMut, KeyIvInit};
14use tokio::io::AsyncWriteExt;
15use tokio::net::TcpStream as TokioTcpStream;
16use tokio::task::JoinHandle;
17
18use crate::entropy::util::EntropyMaterial;
19use crate::entropy::{
20    BoxedSnapshotSource, EntropyConfig, EntropyError, EntropyResult, NegotiationHeader,
21    SnapshotHeader, SnapshotSource, ENTROPY_COMMAND_SEND, ENTROPY_MAGIC,
22};
23
24type Aes128CbcEnc = cbc::Encryptor<Aes128>;
25
26/// Entropy sender driver.
27///
28/// Operationally, calling [`run`](EntropySender::run) spawns a
29/// tokio task that performs one push and then exits.
30///
31/// # Examples
32///
33/// ```no_run
34/// use std::path::PathBuf;
35/// use std::sync::Arc;
36/// use dynomite::entropy::{
37///     EntropyConfig, EntropySender, RedisLocalSnapshot,
38/// };
39///
40/// # async fn run() {
41/// let cfg = EntropyConfig {
42///     key_file: PathBuf::from("conf/recon_key.pem"),
43///     iv_file: PathBuf::from("conf/recon_iv.pem"),
44///     listen_addr: "127.0.0.1:8105".parse().unwrap(),
45///     send_addr: None,
46///     peer_endpoint: "127.0.0.1:8105".parse().unwrap(),
47///     buffer_size: 16 * 1024,
48///     header_size: 1024,
49///     encrypt: true,
50/// };
51/// let source = Arc::new(RedisLocalSnapshot::default());
52/// let handle = EntropySender::run(cfg, source);
53/// handle.await.unwrap().unwrap();
54/// # }
55/// ```
56pub struct EntropySender;
57
58impl EntropySender {
59    /// Spawn a tokio task that performs one snapshot push.
60    ///
61    /// The returned [`JoinHandle`] resolves to the push result.
62    /// The task uses `cfg.peer_endpoint` to dial the receiver and
63    /// `cfg.send_addr` (when set) for the local bind.
64    #[must_use]
65    pub fn run(
66        cfg: EntropyConfig,
67        source: BoxedSnapshotSource,
68    ) -> JoinHandle<EntropyResult<usize>> {
69        tokio::spawn(async move { Self::push(cfg, source).await })
70    }
71
72    /// Perform one snapshot push and return the number of plaintext
73    /// bytes transferred. Used directly by tests; the spawned
74    /// variant just calls this on a tokio task.
75    ///
76    /// # Errors
77    /// Forwards anything from key loading, source acquisition,
78    /// dialing, or transport.
79    pub async fn push(cfg: EntropyConfig, source: BoxedSnapshotSource) -> EntropyResult<usize> {
80        Self::push_with_material(cfg, source, None).await
81    }
82
83    /// Perform one snapshot push using a pre-loaded
84    /// [`EntropyMaterial`].
85    ///
86    /// When `override_material` is `None` and `cfg.encrypt` is
87    /// set, key + IV bytes are loaded from `cfg.key_file` /
88    /// `cfg.iv_file` (the historical behaviour). When
89    /// `override_material` is `Some(_)`, the supplied material is
90    /// used verbatim and the on-disk paths are not touched. The
91    /// override is consumed by the [`crate::entropy::driver::EntropyDriver`]
92    /// run loop, which loads the AES material once at startup
93    /// and drives many cycles from the same handle.
94    ///
95    /// # Errors
96    /// Forwards anything from key loading (when no override is
97    /// supplied), source acquisition, dialing, or transport.
98    pub async fn push_with_material(
99        cfg: EntropyConfig,
100        source: BoxedSnapshotSource,
101        override_material: Option<EntropyMaterial>,
102    ) -> EntropyResult<usize> {
103        cfg.validate()?;
104        let material = if cfg.encrypt {
105            match override_material {
106                Some(m) => Some(m),
107                None => Some(crate::entropy::util::load_material(
108                    &cfg.key_file,
109                    &cfg.iv_file,
110                )?),
111            }
112        } else {
113            None
114        };
115        let snapshot = collect_snapshot(source).await?;
116
117        let mut stream = dial(&cfg).await?;
118        write_negotiation(&mut stream, &cfg).await?;
119        write_snapshot_header(&mut stream, &cfg, snapshot.len()).await?;
120        write_chunks(&mut stream, &cfg, &snapshot, material.as_ref()).await?;
121        stream.shutdown().await?;
122        Ok(snapshot.len())
123    }
124}
125
126async fn collect_snapshot(source: BoxedSnapshotSource) -> EntropyResult<Vec<u8>> {
127    tokio::task::spawn_blocking(move || source.snapshot())
128        .await
129        .map_err(|e| EntropyError::Source(format!("snapshot task panicked: {e}")))?
130}
131
132async fn dial(cfg: &EntropyConfig) -> EntropyResult<TokioTcpStream> {
133    let socket = match cfg.peer_endpoint {
134        SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
135        SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
136    };
137    if let Some(local) = cfg.send_addr {
138        socket.bind(local)?;
139    }
140    let stream = socket.connect(cfg.peer_endpoint).await?;
141    stream.set_nodelay(true)?;
142    Ok(stream)
143}
144
145async fn write_negotiation(stream: &mut TokioTcpStream, cfg: &EntropyConfig) -> EntropyResult<()> {
146    let hdr = NegotiationHeader {
147        magic: ENTROPY_MAGIC,
148        command: ENTROPY_COMMAND_SEND,
149        header_size: u32::try_from(cfg.header_size)
150            .map_err(|_| EntropyError::Config("header_size > u32::MAX".to_string()))?,
151        buffer_size: u32::try_from(cfg.buffer_size)
152            .map_err(|_| EntropyError::Config("buffer_size > u32::MAX".to_string()))?,
153        cipher_size: u32::try_from(cipher_capacity(cfg.buffer_size))
154            .map_err(|_| EntropyError::Config("cipher_size > u32::MAX".to_string()))?,
155    };
156    stream.write_all(&hdr.to_wire()).await?;
157    Ok(())
158}
159
160async fn write_snapshot_header(
161    stream: &mut TokioTcpStream,
162    cfg: &EntropyConfig,
163    total_len: usize,
164) -> EntropyResult<()> {
165    let total_len_u32 = u32::try_from(total_len)
166        .map_err(|_| EntropyError::Source(format!("snapshot too large: {total_len} bytes")))?;
167    let hdr = SnapshotHeader {
168        total_len: total_len_u32,
169        encrypt_flag: u32::from(cfg.encrypt),
170    };
171    let bytes = hdr.to_wire(cfg.header_size)?;
172    stream.write_all(&bytes).await?;
173    Ok(())
174}
175
176async fn write_chunks(
177    stream: &mut TokioTcpStream,
178    cfg: &EntropyConfig,
179    snapshot: &[u8],
180    material: Option<&EntropyMaterial>,
181) -> EntropyResult<()> {
182    let buf = cfg.buffer_size;
183    let mut offset = 0;
184    while offset < snapshot.len() {
185        let end = (offset + buf).min(snapshot.len());
186        let plaintext = &snapshot[offset..end];
187        let payload: Vec<u8> = if let Some(mat) = material {
188            encrypt_chunk(plaintext, mat)?
189        } else {
190            plaintext.to_vec()
191        };
192        let chunk_len = u32::try_from(payload.len())
193            .map_err(|_| EntropyError::Protocol("chunk too large".to_string()))?;
194        stream.write_all(&chunk_len.to_be_bytes()).await?;
195        stream.write_all(&payload).await?;
196        offset = end;
197    }
198    Ok(())
199}
200
201/// AES-128-CBC encrypt one chunk using the entropy key + IV.
202///
203/// Pure helper exposed for the integration tests.
204///
205/// # Errors
206/// [`EntropyError::Crypto`] if the cipher reports a fault.
207pub fn encrypt_chunk(plaintext: &[u8], material: &EntropyMaterial) -> EntropyResult<Vec<u8>> {
208    let key = material.key().as_bytes();
209    let iv = material.iv().as_bytes();
210    let cipher = Aes128CbcEnc::new(key.into(), iv.into());
211    Ok(cipher.encrypt_padded_vec_mut::<Pkcs7>(plaintext))
212}
213
/// Upper bound on the ciphertext length for a plaintext chunk of at most
/// `buffer_size` bytes.
///
/// PKCS#7 always appends between 1 and 16 bytes of padding (one AES
/// block), so an `n`-byte chunk encrypts to at most `n + 16` bytes.
/// Production servers can safely advertise more headroom.
///
/// The addition saturates at `usize::MAX`. An absurd `buffer_size`
/// therefore cannot panic in debug builds or wrap around to a too-small
/// capacity in release builds.
pub(crate) fn cipher_capacity(buffer_size: usize) -> usize {
    // AES block length in bytes, which is also the maximum PKCS#7 padding.
    const AES_BLOCK_LEN: usize = 16;
    buffer_size.saturating_add(AES_BLOCK_LEN)
}
220
/// Default snapshot source: pulls a Redis snapshot from a local
/// Redis instance.
///
/// The reference engine pipes the on-disk AOF file produced by
/// `BGREWRITEAOF` over the entropy channel, and the Rust default follows
/// that contract. It issues a `BGREWRITEAOF` over the configured Redis
/// TCP endpoint, waits for the command to be acknowledged, then reads
/// the AOF file from disk.
///
/// Redis acknowledges `BGREWRITEAOF` as soon as it accepts the request;
/// the acknowledgement does not wait for the background rewrite to
/// finish. Embedders that need a different snapshot strategy can plug in
/// their own [`SnapshotSource`] implementation instead.
///
/// # Examples
///
/// ```
/// use dynomite::entropy::RedisLocalSnapshot;
/// let source = RedisLocalSnapshot::default();
/// drop(source);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisLocalSnapshot {
    /// Address of the local Redis instance.
    pub redis_addr: SocketAddr,
    /// Path to the AOF file on disk.
    pub aof_path: std::path::PathBuf,
    /// Connect/read/write timeout for the Redis connection.
    pub timeout: Duration,
    /// Number of `BGREWRITEAOF` retries before giving up. The
    /// reference engine retries once after a 10 second sleep; we
    /// expose the count and pause as parameters.
    pub bgrewrite_retries: u32,
    /// Pause between `BGREWRITEAOF` retries.
    pub bgrewrite_retry_pause: Duration,
}

impl Default for RedisLocalSnapshot {
    /// The defaults are:
    /// - Redis at `127.0.0.1:22122`;
    /// - AOF at `/mnt/data/nfredis/appendonly.aof`;
    /// - a 30 s socket timeout;
    /// - one `BGREWRITEAOF` retry after a 10 s pause.
    fn default() -> Self {
        Self {
            redis_addr: "127.0.0.1:22122".parse().expect("static literal parses"),
            aof_path: std::path::PathBuf::from("/mnt/data/nfredis/appendonly.aof"),
            timeout: Duration::from_secs(30),
            bgrewrite_retries: 1,
            bgrewrite_retry_pause: Duration::from_secs(10),
        }
    }
}
265
266impl RedisLocalSnapshot {
267    /// Override the Redis address.
268    #[must_use]
269    pub fn with_redis_addr(mut self, addr: SocketAddr) -> Self {
270        self.redis_addr = addr;
271        self
272    }
273
274    /// Override the AOF on-disk path.
275    #[must_use]
276    pub fn with_aof_path(mut self, path: std::path::PathBuf) -> Self {
277        self.aof_path = path;
278        self
279    }
280
281    fn bgrewriteaof(&self) -> EntropyResult<()> {
282        let mut last_err: Option<EntropyError> = None;
283        for attempt in 0..=self.bgrewrite_retries {
284            match self.try_bgrewriteaof() {
285                Ok(()) => return Ok(()),
286                Err(e) => {
287                    last_err = Some(e);
288                    if attempt < self.bgrewrite_retries {
289                        std::thread::sleep(self.bgrewrite_retry_pause);
290                    }
291                }
292            }
293        }
294        Err(last_err.unwrap_or_else(|| {
295            EntropyError::Source("bgrewriteaof failed without error".to_string())
296        }))
297    }
298
299    fn try_bgrewriteaof(&self) -> EntropyResult<()> {
300        let mut sock = TcpStream::connect_timeout(&self.redis_addr, self.timeout)
301            .map_err(|e| EntropyError::Source(format!("connect to redis: {e}")))?;
302        sock.set_read_timeout(Some(self.timeout))
303            .map_err(|e| EntropyError::Source(format!("redis timeout: {e}")))?;
304        sock.set_write_timeout(Some(self.timeout))
305            .map_err(|e| EntropyError::Source(format!("redis timeout: {e}")))?;
306        sock.write_all(b"*1\r\n$13\r\nBGREWRITEAOF\r\n")
307            .map_err(|e| EntropyError::Source(format!("redis write: {e}")))?;
308        let mut buf = [0u8; 256];
309        let n = sock
310            .read(&mut buf)
311            .map_err(|e| EntropyError::Source(format!("redis read: {e}")))?;
312        let reply = &buf[..n];
313        if reply.first() != Some(&b'+') {
314            return Err(EntropyError::Source(format!(
315                "BGREWRITEAOF rejected: {}",
316                String::from_utf8_lossy(reply)
317            )));
318        }
319        Ok(())
320    }
321}
322
323impl SnapshotSource for RedisLocalSnapshot {
324    fn snapshot(&self) -> EntropyResult<Vec<u8>> {
325        self.bgrewriteaof()?;
326        // Brief pause to let Redis flush the AOF rewrite to disk.
327        std::thread::sleep(Duration::from_secs(1));
328        let bytes = std::fs::read(&self.aof_path).map_err(|e| {
329            EntropyError::Source(format!("read AOF {}: {e}", self.aof_path.display()))
330        })?;
331        Ok(bytes)
332    }
333}
334
/// In-memory snapshot source useful in tests and embedders that
/// already hold the snapshot in RAM.
///
/// Each call to [`SnapshotSource::snapshot`] returns a fresh copy of the
/// wrapped bytes, so one source can serve any number of pushes.
///
/// # Examples
///
/// ```
/// use std::sync::Arc;
/// use dynomite::entropy::send::StaticSnapshot;
/// use dynomite::entropy::SnapshotSource;
/// let s = Arc::new(StaticSnapshot::new(b"hello".to_vec()));
/// assert_eq!(s.snapshot().unwrap(), b"hello");
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticSnapshot {
    bytes: Vec<u8>,
}

impl StaticSnapshot {
    /// Wrap a fixed byte vector; it is served verbatim as the snapshot.
    #[must_use]
    pub fn new(bytes: Vec<u8>) -> Self {
        Self { bytes }
    }
}
358
359impl SnapshotSource for StaticSnapshot {
360    fn snapshot(&self) -> EntropyResult<Vec<u8>> {
361        Ok(self.bytes.clone())
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy::util::{EntropyIv, EntropyKey, ENTROPY_IV_LEN, ENTROPY_KEY_LEN};

    fn material() -> EntropyMaterial {
        EntropyMaterial::new(
            EntropyKey::from_bytes([0x10; ENTROPY_KEY_LEN]),
            EntropyIv::from_bytes([0x42; ENTROPY_IV_LEN]),
        )
    }

    /// Decrypt `ct` with the same key/IV `encrypt_chunk` uses and strip
    /// the PKCS#7 padding.
    fn decrypt(ct: &[u8], mat: &EntropyMaterial) -> Vec<u8> {
        use aes::Aes128;
        use cbc::cipher::block_padding::Pkcs7;
        use cbc::cipher::{BlockDecryptMut, KeyIvInit};
        type Dec = cbc::Decryptor<Aes128>;

        Dec::new(mat.key().as_bytes().into(), mat.iv().as_bytes().into())
            .decrypt_padded_vec_mut::<Pkcs7>(ct)
            .unwrap()
    }

    #[test]
    fn encrypt_chunk_round_trips_with_pkcs7() {
        let mat = material();
        let pt = b"hello entropy world";
        let ct = encrypt_chunk(pt, &mat).unwrap();
        // PKCS#7 always extends to the next 16-byte block: 19 -> 32.
        assert_eq!(ct.len(), 32);
        assert_eq!(decrypt(&ct, &mat), pt);
    }

    #[test]
    fn encrypt_chunk_adds_full_block_when_aligned() {
        let mat = material();
        let pt = [0xABu8; 16];
        let ct = encrypt_chunk(&pt, &mat).unwrap();
        // A block-aligned plaintext still gets a whole padding block.
        assert_eq!(ct.len(), 32);
        assert_eq!(decrypt(&ct, &mat), pt);
    }

    #[test]
    fn encrypt_chunk_handles_empty_plaintext() {
        let mat = material();
        let ct = encrypt_chunk(&[], &mat).unwrap();
        assert_eq!(ct.len(), 16);
        assert!(decrypt(&ct, &mat).is_empty());
    }

    #[test]
    fn ciphertext_fits_in_advertised_capacity() {
        let mat = material();
        for len in 0..=64 {
            let pt = vec![0x5Au8; len];
            let ct = encrypt_chunk(&pt, &mat).unwrap();
            assert!(ct.len() <= cipher_capacity(len), "len={len}");
        }
    }

    #[test]
    fn cipher_capacity_includes_pkcs7_block() {
        assert_eq!(cipher_capacity(16), 32);
        assert_eq!(cipher_capacity(15), 31);
    }

    #[test]
    fn static_snapshot_returns_payload() {
        let s = StaticSnapshot::new(b"abc".to_vec());
        assert_eq!(s.snapshot().unwrap(), b"abc");
    }

    #[test]
    fn arc_static_snapshot_returns_payload() {
        let s: BoxedSnapshotSource = std::sync::Arc::new(StaticSnapshot::new(vec![1, 2, 3]));
        assert_eq!(s.snapshot().unwrap(), vec![1, 2, 3]);
    }
}