1use std::io::{Read, Write};
8use std::net::{SocketAddr, TcpStream};
9use std::time::Duration;
10
11use aes::Aes128;
12use cbc::cipher::block_padding::Pkcs7;
13use cbc::cipher::{BlockEncryptMut, KeyIvInit};
14use tokio::io::AsyncWriteExt;
15use tokio::net::TcpStream as TokioTcpStream;
16use tokio::task::JoinHandle;
17
18use crate::entropy::util::EntropyMaterial;
19use crate::entropy::{
20 BoxedSnapshotSource, EntropyConfig, EntropyError, EntropyResult, NegotiationHeader,
21 SnapshotHeader, SnapshotSource, ENTROPY_COMMAND_SEND, ENTROPY_MAGIC,
22};
23
/// AES-128 in CBC mode (encrypting half). Padding is selected per call;
/// [`encrypt_chunk`] uses PKCS#7.
type Aes128CbcEnc = cbc::Encryptor<Aes128>;
25
/// Entry point for pushing an entropy snapshot to a peer.
///
/// All functionality lives in associated functions; the type itself carries no
/// state, so it cheaply derives the common traits.
#[derive(Debug, Clone, Copy, Default)]
pub struct EntropySender;
57
58impl EntropySender {
59 #[must_use]
65 pub fn run(
66 cfg: EntropyConfig,
67 source: BoxedSnapshotSource,
68 ) -> JoinHandle<EntropyResult<usize>> {
69 tokio::spawn(async move { Self::push(cfg, source).await })
70 }
71
72 pub async fn push(cfg: EntropyConfig, source: BoxedSnapshotSource) -> EntropyResult<usize> {
80 Self::push_with_material(cfg, source, None).await
81 }
82
83 pub async fn push_with_material(
99 cfg: EntropyConfig,
100 source: BoxedSnapshotSource,
101 override_material: Option<EntropyMaterial>,
102 ) -> EntropyResult<usize> {
103 cfg.validate()?;
104 let material = if cfg.encrypt {
105 match override_material {
106 Some(m) => Some(m),
107 None => Some(crate::entropy::util::load_material(
108 &cfg.key_file,
109 &cfg.iv_file,
110 )?),
111 }
112 } else {
113 None
114 };
115 let snapshot = collect_snapshot(source).await?;
116
117 let mut stream = dial(&cfg).await?;
118 write_negotiation(&mut stream, &cfg).await?;
119 write_snapshot_header(&mut stream, &cfg, snapshot.len()).await?;
120 write_chunks(&mut stream, &cfg, &snapshot, material.as_ref()).await?;
121 stream.shutdown().await?;
122 Ok(snapshot.len())
123 }
124}
125
126async fn collect_snapshot(source: BoxedSnapshotSource) -> EntropyResult<Vec<u8>> {
127 tokio::task::spawn_blocking(move || source.snapshot())
128 .await
129 .map_err(|e| EntropyError::Source(format!("snapshot task panicked: {e}")))?
130}
131
132async fn dial(cfg: &EntropyConfig) -> EntropyResult<TokioTcpStream> {
133 let socket = match cfg.peer_endpoint {
134 SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
135 SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
136 };
137 if let Some(local) = cfg.send_addr {
138 socket.bind(local)?;
139 }
140 let stream = socket.connect(cfg.peer_endpoint).await?;
141 stream.set_nodelay(true)?;
142 Ok(stream)
143}
144
145async fn write_negotiation(stream: &mut TokioTcpStream, cfg: &EntropyConfig) -> EntropyResult<()> {
146 let hdr = NegotiationHeader {
147 magic: ENTROPY_MAGIC,
148 command: ENTROPY_COMMAND_SEND,
149 header_size: u32::try_from(cfg.header_size)
150 .map_err(|_| EntropyError::Config("header_size > u32::MAX".to_string()))?,
151 buffer_size: u32::try_from(cfg.buffer_size)
152 .map_err(|_| EntropyError::Config("buffer_size > u32::MAX".to_string()))?,
153 cipher_size: u32::try_from(cipher_capacity(cfg.buffer_size))
154 .map_err(|_| EntropyError::Config("cipher_size > u32::MAX".to_string()))?,
155 };
156 stream.write_all(&hdr.to_wire()).await?;
157 Ok(())
158}
159
160async fn write_snapshot_header(
161 stream: &mut TokioTcpStream,
162 cfg: &EntropyConfig,
163 total_len: usize,
164) -> EntropyResult<()> {
165 let total_len_u32 = u32::try_from(total_len)
166 .map_err(|_| EntropyError::Source(format!("snapshot too large: {total_len} bytes")))?;
167 let hdr = SnapshotHeader {
168 total_len: total_len_u32,
169 encrypt_flag: u32::from(cfg.encrypt),
170 };
171 let bytes = hdr.to_wire(cfg.header_size)?;
172 stream.write_all(&bytes).await?;
173 Ok(())
174}
175
176async fn write_chunks(
177 stream: &mut TokioTcpStream,
178 cfg: &EntropyConfig,
179 snapshot: &[u8],
180 material: Option<&EntropyMaterial>,
181) -> EntropyResult<()> {
182 let buf = cfg.buffer_size;
183 let mut offset = 0;
184 while offset < snapshot.len() {
185 let end = (offset + buf).min(snapshot.len());
186 let plaintext = &snapshot[offset..end];
187 let payload: Vec<u8> = if let Some(mat) = material {
188 encrypt_chunk(plaintext, mat)?
189 } else {
190 plaintext.to_vec()
191 };
192 let chunk_len = u32::try_from(payload.len())
193 .map_err(|_| EntropyError::Protocol("chunk too large".to_string()))?;
194 stream.write_all(&chunk_len.to_be_bytes()).await?;
195 stream.write_all(&payload).await?;
196 offset = end;
197 }
198 Ok(())
199}
200
201pub fn encrypt_chunk(plaintext: &[u8], material: &EntropyMaterial) -> EntropyResult<Vec<u8>> {
208 let key = material.key().as_bytes();
209 let iv = material.iv().as_bytes();
210 let cipher = Aes128CbcEnc::new(key.into(), iv.into());
211 Ok(cipher.encrypt_padded_vec_mut::<Pkcs7>(plaintext))
212}
213
/// Upper bound on the ciphertext length [`encrypt_chunk`] produces for a
/// plaintext of at most `buffer_size` bytes.
///
/// PKCS#7 always adds 1..=16 bytes of padding to reach a multiple of the AES
/// block size, so `buffer_size + 16` is never too small. It is exact only when
/// `buffer_size` is a multiple of 16.
///
/// The sum saturates at `usize::MAX` instead of overflowing. Callers that narrow
/// the result (e.g. `u32::try_from`) then reject it instead of panicking in debug
/// builds or silently wrapping in release builds.
pub(crate) fn cipher_capacity(buffer_size: usize) -> usize {
    /// AES block size in bytes; PKCS#7 adds at most one full block.
    const AES_BLOCK_SIZE: usize = 16;
    buffer_size.saturating_add(AES_BLOCK_SIZE)
}
220
/// A [`SnapshotSource`] that asks a local Redis to rewrite its append-only file
/// (`BGREWRITEAOF`) and then returns that file's bytes.
pub struct RedisLocalSnapshot {
    /// Endpoint that receives the `BGREWRITEAOF` command.
    pub redis_addr: SocketAddr,
    /// Append-only file read after the rewrite has been requested.
    pub aof_path: std::path::PathBuf,
    /// Connect, read, and write timeout for the Redis connection.
    pub timeout: Duration,
    /// Extra `BGREWRITEAOF` attempts after the first failure
    /// (total attempts = `bgrewrite_retries + 1`).
    pub bgrewrite_retries: u32,
    /// Pause between a failed attempt and the next one.
    pub bgrewrite_retry_pause: Duration,
}

impl Default for RedisLocalSnapshot {
    /// Defaults:
    /// - Redis at `127.0.0.1:22122`
    /// - AOF at `/mnt/data/nfredis/appendonly.aof`
    /// - 30 s timeout
    /// - 1 retry with a 10 s pause
    fn default() -> Self {
        let redis_addr = SocketAddr::from(([127, 0, 0, 1], 22122));
        let aof_path = std::path::PathBuf::from("/mnt/data/nfredis/appendonly.aof");
        Self {
            redis_addr,
            aof_path,
            timeout: Duration::from_secs(30),
            bgrewrite_retries: 1,
            bgrewrite_retry_pause: Duration::from_secs(10),
        }
    }
}
265
266impl RedisLocalSnapshot {
267 #[must_use]
269 pub fn with_redis_addr(mut self, addr: SocketAddr) -> Self {
270 self.redis_addr = addr;
271 self
272 }
273
274 #[must_use]
276 pub fn with_aof_path(mut self, path: std::path::PathBuf) -> Self {
277 self.aof_path = path;
278 self
279 }
280
281 fn bgrewriteaof(&self) -> EntropyResult<()> {
282 let mut last_err: Option<EntropyError> = None;
283 for attempt in 0..=self.bgrewrite_retries {
284 match self.try_bgrewriteaof() {
285 Ok(()) => return Ok(()),
286 Err(e) => {
287 last_err = Some(e);
288 if attempt < self.bgrewrite_retries {
289 std::thread::sleep(self.bgrewrite_retry_pause);
290 }
291 }
292 }
293 }
294 Err(last_err.unwrap_or_else(|| {
295 EntropyError::Source("bgrewriteaof failed without error".to_string())
296 }))
297 }
298
299 fn try_bgrewriteaof(&self) -> EntropyResult<()> {
300 let mut sock = TcpStream::connect_timeout(&self.redis_addr, self.timeout)
301 .map_err(|e| EntropyError::Source(format!("connect to redis: {e}")))?;
302 sock.set_read_timeout(Some(self.timeout))
303 .map_err(|e| EntropyError::Source(format!("redis timeout: {e}")))?;
304 sock.set_write_timeout(Some(self.timeout))
305 .map_err(|e| EntropyError::Source(format!("redis timeout: {e}")))?;
306 sock.write_all(b"*1\r\n$13\r\nBGREWRITEAOF\r\n")
307 .map_err(|e| EntropyError::Source(format!("redis write: {e}")))?;
308 let mut buf = [0u8; 256];
309 let n = sock
310 .read(&mut buf)
311 .map_err(|e| EntropyError::Source(format!("redis read: {e}")))?;
312 let reply = &buf[..n];
313 if reply.first() != Some(&b'+') {
314 return Err(EntropyError::Source(format!(
315 "BGREWRITEAOF rejected: {}",
316 String::from_utf8_lossy(reply)
317 )));
318 }
319 Ok(())
320 }
321}
322
323impl SnapshotSource for RedisLocalSnapshot {
324 fn snapshot(&self) -> EntropyResult<Vec<u8>> {
325 self.bgrewriteaof()?;
326 std::thread::sleep(Duration::from_secs(1));
328 let bytes = std::fs::read(&self.aof_path).map_err(|e| {
329 EntropyError::Source(format!("read AOF {}: {e}", self.aof_path.display()))
330 })?;
331 Ok(bytes)
332 }
333}
334
/// A snapshot source that serves a fixed, in-memory payload.
pub struct StaticSnapshot {
    /// The payload handed out on every snapshot request.
    bytes: Vec<u8>,
}

impl StaticSnapshot {
    /// Wraps `bytes` so they can be served as a snapshot.
    #[must_use]
    pub fn new(bytes: Vec<u8>) -> Self {
        StaticSnapshot { bytes }
    }
}
358
359impl SnapshotSource for StaticSnapshot {
360 fn snapshot(&self) -> EntropyResult<Vec<u8>> {
361 Ok(self.bytes.clone())
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy::util::{EntropyIv, EntropyKey, ENTROPY_IV_LEN, ENTROPY_KEY_LEN};

    /// Deterministic key/IV pair shared by the encryption tests.
    fn material() -> EntropyMaterial {
        let key = EntropyKey::from_bytes([0x10; ENTROPY_KEY_LEN]);
        let iv = EntropyIv::from_bytes([0x42; ENTROPY_IV_LEN]);
        EntropyMaterial::new(key, iv)
    }

    #[test]
    fn encrypt_chunk_round_trips_with_pkcs7() {
        use aes::Aes128;
        use cbc::cipher::block_padding::Pkcs7;
        use cbc::cipher::{BlockDecryptMut, KeyIvInit};
        type Dec = cbc::Decryptor<Aes128>;

        let m = material();
        let plaintext = b"hello entropy world";
        let ciphertext = encrypt_chunk(plaintext, &m).unwrap();

        // PKCS#7 output is block-aligned and never shorter than the input.
        assert!(ciphertext.len() >= plaintext.len());
        assert_eq!(ciphertext.len() % 16, 0);

        let decryptor = Dec::new(m.key().as_bytes().into(), m.iv().as_bytes().into());
        let recovered = decryptor.decrypt_padded_vec_mut::<Pkcs7>(&ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn cipher_capacity_includes_pkcs7_block() {
        let cases = [(16, 32), (15, 31)];
        for (input, expected) in cases {
            assert_eq!(cipher_capacity(input), expected);
        }
    }

    #[test]
    fn static_snapshot_returns_payload() {
        let source = StaticSnapshot::new(b"abc".to_vec());
        let got = source.snapshot().unwrap();
        assert_eq!(got, b"abc");
    }

    #[test]
    fn arc_static_snapshot_returns_payload() {
        let shared: BoxedSnapshotSource = std::sync::Arc::new(StaticSnapshot::new(vec![1, 2, 3]));
        let got = shared.snapshot().unwrap();
        assert_eq!(got, vec![1, 2, 3]);
    }
}
413}