use std::io::Write;
use std::net::{SocketAddr, TcpStream};
use std::time::Duration;
use aes::Aes128;
use cbc::cipher::block_padding::Pkcs7;
use cbc::cipher::{BlockDecryptMut, KeyIvInit};
use tokio::io::AsyncReadExt;
use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
use tokio::task::JoinHandle;
use crate::entropy::util::EntropyMaterial;
use crate::entropy::{
BoxedSnapshotSink, EntropyConfig, EntropyError, EntropyResult, NegotiationHeader,
SnapshotHeader, SnapshotSink, MAX_CIPHER_SIZE, MAX_SNAPSHOT_SIZE, SAFE_PREALLOC,
};
/// AES-128 block cipher in CBC mode, decryption direction; used by `decrypt_chunk`.
type Aes128CbcDec = cbc::Decryptor<Aes128>;
/// TCP receiver for entropy snapshot transfers.
///
/// Holds the bound listener, the configuration it was bound with, and the sink
/// that every successfully reassembled snapshot is applied to.
pub struct EntropyReceiver {
    /// Listening socket, bound to `cfg.listen_addr` in `bind`.
    listener: TcpListener,
    /// Receiver configuration (listen address, key/IV file paths, `encrypt` flag);
    /// `bind` runs `validate()` on it before storing it.
    cfg: EntropyConfig,
    /// Destination for each reassembled plaintext snapshot.
    sink: BoxedSnapshotSink,
}
impl EntropyReceiver {
    /// Binds a receiver and spawns its [`accept_loop`](Self::accept_loop) on the
    /// current Tokio runtime.
    ///
    /// # Errors
    /// Fails with whatever [`bind`](Self::bind) returns. Errors raised later by the
    /// accept loop are reported through the returned [`JoinHandle`].
    pub async fn run(
        cfg: EntropyConfig,
        sink: BoxedSnapshotSink,
    ) -> EntropyResult<JoinHandle<EntropyResult<()>>> {
        let receiver = Self::bind(cfg, sink).await?;
        let handle = tokio::spawn(receiver.accept_loop());
        Ok(handle)
    }

    /// Validates `cfg`, checks key material when encryption is enabled, and binds
    /// the listening socket on `cfg.listen_addr`.
    ///
    /// # Errors
    /// Returns the error from `cfg.validate()`, from loading the key/IV files (only
    /// when `cfg.encrypt` is set), or from binding the listener.
    pub async fn bind(cfg: EntropyConfig, sink: BoxedSnapshotSink) -> EntropyResult<Self> {
        cfg.validate()?;
        // The material is loaded here purely as a startup check so a bad key/IV
        // path is reported now rather than on the first encrypted connection.
        if cfg.encrypt {
            let _ = crate::entropy::util::load_material(&cfg.key_file, &cfg.iv_file)?;
        }
        let listener = TcpListener::bind(cfg.listen_addr).await?;
        Ok(Self { listener, cfg, sink })
    }

    /// Returns the address the listener is actually bound to (e.g. to discover
    /// the OS-assigned port after binding to port 0).
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.listener.local_addr()
    }

    /// Accepts exactly one connection, serves it, and consumes the receiver.
    ///
    /// Returns the number of plaintext bytes handed to the sink.
    ///
    /// # Errors
    /// Returns the accept error, or any error from serving the connection.
    pub async fn accept_one(self) -> EntropyResult<usize> {
        let (socket, _) = self.listener.accept().await?;
        handle_one(socket, &self.cfg, self.sink.as_ref()).await
    }

    /// Serves connections forever, one after another.
    ///
    /// A failure while serving a single connection is logged and the loop moves
    /// on; only an error from `accept` itself ends the loop.
    ///
    /// NOTE(review): connections are handled sequentially, so a peer that stalls
    /// mid-transfer holds up every later connection.
    ///
    /// # Errors
    /// Returns the first error produced by `accept`.
    pub async fn accept_loop(self) -> EntropyResult<()> {
        loop {
            let (sock, peer) = self.listener.accept().await?;
            tracing::debug!(?peer, "entropy receiver accepted connection");
            let outcome = handle_one(sock, &self.cfg, self.sink.as_ref()).await;
            if let Err(e) = outcome {
                tracing::warn!(?e, "entropy worker errored");
            }
        }
    }
}
/// Serves a single entropy-transfer connection end to end.
///
/// Steps:
/// 1. Reads the negotiation header, then the snapshot header.
/// 2. Loads key material when the snapshot is marked encrypted.
/// 3. Reassembles (and, if encrypted, decrypts) the chunk stream.
/// 4. Hands the plaintext to `sink`.
///
/// Returns the number of plaintext bytes delivered to the sink.
///
/// # Errors
/// - Socket I/O errors.
/// - `EntropyError::Protocol` when the snapshot's `encrypt_flag` is neither 0
///   nor 1.
/// - `EntropyError::Protocol` when the receiver is configured with
///   `cfg.encrypt` but the peer sent a plaintext snapshot.
/// - Errors from header parsing, `load_material`, or `read_chunks`.
/// - `EntropyError::Sink` for any failure reported by the sink.
async fn handle_one(
    mut stream: TokioTcpStream,
    cfg: &EntropyConfig,
    sink: &dyn SnapshotSink,
) -> EntropyResult<usize> {
    stream.set_nodelay(true)?;
    let neg = read_negotiation(&mut stream).await?;
    let snap = read_snapshot_header(&mut stream, &neg).await?;
    let material = match snap.encrypt_flag {
        // Material is re-read from disk for each encrypted connection.
        1 => Some(crate::entropy::util::load_material(
            &cfg.key_file,
            &cfg.iv_file,
        )?),
        // A receiver configured for encryption must not let the peer downgrade
        // the transfer to plaintext simply by clearing the flag.
        0 if cfg.encrypt => {
            return Err(EntropyError::Protocol(
                "peer sent plaintext snapshot but receiver requires encryption".to_string(),
            ));
        }
        0 => None,
        // Unknown values used to be treated as plaintext; reject them instead.
        other => {
            return Err(EntropyError::Protocol(format!(
                "unknown snapshot encrypt_flag {other}"
            )));
        }
    };
    let plaintext = read_chunks(&mut stream, &neg, &snap, material.as_ref()).await?;
    // NOTE(review): `apply` is synchronous; a sink doing blocking I/O (such as
    // `RedisReplaySink`) blocks this runtime worker thread for its duration.
    // Every sink failure is normalised to `EntropyError::Sink`.
    sink.apply(&plaintext).map_err(|e| match e {
        EntropyError::Sink(msg) => EntropyError::Sink(msg),
        other => EntropyError::Sink(other.to_string()),
    })?;
    Ok(plaintext.len())
}
/// Reads and parses the fixed-size negotiation header that `handle_one` expects
/// at the start of every connection.
///
/// # Errors
/// Propagates socket read errors (including EOF before `NegotiationHeader::SIZE`
/// bytes arrive) and any error from `NegotiationHeader::from_wire`.
async fn read_negotiation(stream: &mut TokioTcpStream) -> EntropyResult<NegotiationHeader> {
    let mut header_bytes = [0u8; NegotiationHeader::SIZE];
    stream.read_exact(&mut header_bytes).await?;
    let parsed = NegotiationHeader::from_wire(&header_bytes);
    parsed
}
async fn read_snapshot_header(
stream: &mut TokioTcpStream,
neg: &NegotiationHeader,
) -> EntropyResult<SnapshotHeader> {
let mut buf = vec![0u8; neg.header_size as usize];
stream.read_exact(&mut buf).await?;
SnapshotHeader::from_wire(&buf)
}
/// Reads length-prefixed chunks until `snap.total_len` plaintext bytes have been
/// reassembled, decrypting each chunk with `material` when it is present.
///
/// Wire format per chunk:
/// - a 4-byte big-endian length;
/// - then that many payload bytes.
///
/// Each payload must be non-empty and no larger than both `MAX_CIPHER_SIZE` and
/// the negotiated `neg.cipher_size`. The chunks must add up to exactly
/// `total_len`. After the last chunk the stream is probed once to check that the
/// peer sent nothing further.
///
/// # Errors
/// - `EntropyError::Protocol` for:
///   - a `total_len` above `MAX_SNAPSHOT_SIZE` (or unrepresentable as `usize`);
///   - a zero-length or oversized chunk;
///   - a chunk that overshoots `total_len`;
///   - trailing bytes after the last chunk.
/// - `EntropyError::Crypto` from `decrypt_chunk`.
/// - I/O errors while reading length prefixes or payloads.
async fn read_chunks(
    stream: &mut TokioTcpStream,
    neg: &NegotiationHeader,
    snap: &SnapshotHeader,
    material: Option<&EntropyMaterial>,
) -> EntropyResult<Vec<u8>> {
    // `try_from` rather than `as`: on targets whose `usize` is narrower than the
    // wire field, `as` would truncate and let an oversized snapshot slip under
    // the limit check.
    let total_len = match usize::try_from(snap.total_len) {
        Ok(len) if len <= MAX_SNAPSHOT_SIZE => len,
        _ => {
            return Err(EntropyError::Protocol(format!(
                "snapshot total_len {} exceeds MAX_SNAPSHOT_SIZE",
                snap.total_len
            )));
        }
    };
    // `total_len` is peer-controlled, so cap the up-front allocation.
    let mut plaintext = Vec::with_capacity(total_len.min(SAFE_PREALLOC));
    let mut len_buf = [0u8; 4];
    while plaintext.len() < total_len {
        stream.read_exact(&mut len_buf).await?;
        // u32 -> usize is lossless on 32- and 64-bit targets.
        let chunk_len = u32::from_be_bytes(len_buf) as usize;
        if chunk_len == 0 {
            return Err(EntropyError::Protocol("zero-length chunk".to_string()));
        }
        if chunk_len > MAX_CIPHER_SIZE || chunk_len > neg.cipher_size as usize {
            return Err(EntropyError::Protocol(format!(
                "chunk_len {chunk_len} exceeds negotiated cipher_size {}",
                neg.cipher_size
            )));
        }
        let mut payload = vec![0u8; chunk_len];
        stream.read_exact(&mut payload).await?;
        // NOTE(review): the same `material` (key *and* IV) is used for every
        // chunk, so identical chunk prefixes produce identical ciphertext.
        // Fixing that needs a protocol change on the sender side too.
        let chunk_plain = match material {
            Some(mat) => decrypt_chunk(&payload, mat)?,
            None => payload,
        };
        let remaining = total_len - plaintext.len();
        if chunk_plain.len() > remaining {
            return Err(EntropyError::Protocol(format!(
                "chunk overshoots total_len: have {} more after {} bytes",
                chunk_plain.len() - remaining,
                total_len
            )));
        }
        plaintext.extend_from_slice(&chunk_plain);
    }
    // Probe for trailing data.
    // NOTE(review): this read completes only when the peer sends more bytes or
    // closes/half-closes its side. A peer that keeps the connection open stalls
    // here.
    let mut probe = [0u8; 1];
    match stream.read(&mut probe).await {
        // EOF is the expected outcome. A read error (e.g. a reset after the
        // peer finished sending) is deliberately ignored because the full
        // snapshot has already been received.
        Ok(0) | Err(_) => {}
        Ok(_) => {
            return Err(EntropyError::Protocol(
                "trailing bytes after declared total_len".to_string(),
            ));
        }
    }
    Ok(plaintext)
}
/// Decrypts one AES-128-CBC ciphertext chunk and strips its PKCS#7 padding.
///
/// The key and IV are taken from `material`.
///
/// # Errors
/// Returns `EntropyError::Crypto` in two cases:
/// - `ciphertext` is empty or not a multiple of the 16-byte AES block size;
/// - the decrypted padding is invalid.
///
/// NOTE(review): CBC without a MAC is malleable. If the peer can tell a padding
/// failure apart from other errors, this can act as a padding oracle. Consider
/// authenticated encryption if the protocol can change.
pub fn decrypt_chunk(ciphertext: &[u8], material: &EntropyMaterial) -> EntropyResult<Vec<u8>> {
    let len = ciphertext.len();
    let block_aligned = len != 0 && len.is_multiple_of(16);
    if !block_aligned {
        return Err(EntropyError::Crypto(format!(
            "ciphertext length {} is not a positive multiple of 16",
            len
        )));
    }
    Aes128CbcDec::new(
        material.key().as_bytes().into(),
        material.iv().as_bytes().into(),
    )
    .decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
    .map_err(|e| EntropyError::Crypto(format!("PKCS#7 unpad failed: {e}")))
}
/// Snapshot sink that forwards the raw snapshot bytes, unmodified, over a
/// blocking `std::net::TcpStream` connection to `redis_addr`.
///
/// A fresh connection is opened for every `apply` call.
pub struct RedisReplaySink {
    /// Address the snapshot is written to.
    pub redis_addr: SocketAddr,
    /// Timeout used for the connection's socket operations in `apply`.
    pub timeout: Duration,
}
impl Default for RedisReplaySink {
fn default() -> Self {
Self {
redis_addr: "127.0.0.1:22122".parse().expect("static literal parses"),
timeout: Duration::from_secs(30),
}
}
}
impl RedisReplaySink {
    /// Returns this sink retargeted at `addr`, keeping every other setting.
    #[must_use]
    pub fn with_redis_addr(self, addr: SocketAddr) -> Self {
        Self {
            redis_addr: addr,
            ..self
        }
    }
}
impl SnapshotSink for RedisReplaySink {
    /// Connects to `redis_addr`, writes `snapshot` verbatim, then closes the
    /// connection gracefully.
    ///
    /// After writing, the write side is half-closed, and anything the server
    /// sends back is read and discarded until it closes. Dropping the socket
    /// while replies sit unread in the receive buffer makes the kernel send a
    /// RST instead of a FIN. The server may then discard snapshot bytes it has
    /// not consumed yet.
    ///
    /// Uses blocking `std::net` I/O. Connect, write and read each use `timeout`.
    ///
    /// # Errors
    /// Returns `EntropyError::Sink` if any of these fails:
    /// - connecting;
    /// - setting the socket timeouts;
    /// - writing the snapshot;
    /// - half-closing the socket.
    fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
        let mut sock = TcpStream::connect_timeout(&self.redis_addr, self.timeout)
            .map_err(|e| EntropyError::Sink(format!("connect to redis: {e}")))?;
        sock.set_write_timeout(Some(self.timeout))
            .map_err(|e| EntropyError::Sink(format!("redis timeout: {e}")))?;
        // Bounds the drain below so a server that never closes cannot hang us.
        sock.set_read_timeout(Some(self.timeout))
            .map_err(|e| EntropyError::Sink(format!("redis timeout: {e}")))?;
        sock.write_all(snapshot)
            .map_err(|e| EntropyError::Sink(format!("redis write: {e}")))?;
        sock.shutdown(std::net::Shutdown::Write)
            .map_err(|e| EntropyError::Sink(format!("redis shutdown: {e}")))?;
        // Best-effort drain. Its only purpose is to avoid closing with unread
        // data, so read errors and timeouts here are deliberately ignored.
        let _ = std::io::copy(&mut sock, &mut std::io::sink());
        Ok(())
    }
}
/// In-memory snapshot sink that keeps only the most recently applied snapshot.
#[derive(Default)]
pub struct MemorySink {
    /// Bytes of the last snapshot passed to `apply`; empty initially and after `take`.
    inner: parking_lot::Mutex<Vec<u8>>,
}
impl MemorySink {
    /// Removes and returns the stored snapshot, leaving the sink empty.
    #[must_use]
    pub fn take(&self) -> Vec<u8> {
        let mut guard = self.inner.lock();
        std::mem::take(&mut *guard)
    }

    /// Returns a copy of the stored snapshot without clearing it.
    #[must_use]
    pub fn snapshot(&self) -> Vec<u8> {
        let guard = self.inner.lock();
        guard.to_vec()
    }
}
impl SnapshotSink for MemorySink {
    /// Replaces the stored bytes with `snapshot`; never fails.
    fn apply(&self, snapshot: &[u8]) -> EntropyResult<()> {
        let mut stored = self.inner.lock();
        // Drop the previous contents but keep the allocation for reuse.
        stored.truncate(0);
        stored.extend(snapshot.iter().copied());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy::send::encrypt_chunk;
    use crate::entropy::util::{EntropyIv, EntropyKey, ENTROPY_IV_LEN, ENTROPY_KEY_LEN};

    /// Fixed key/IV pair shared by the crypto tests.
    fn material() -> EntropyMaterial {
        EntropyMaterial::new(
            EntropyKey::from_bytes([0x10; ENTROPY_KEY_LEN]),
            EntropyIv::from_bytes([0x42; ENTROPY_IV_LEN]),
        )
    }

    #[test]
    fn decrypt_chunk_rejects_short_buffer() {
        let mat = material();
        let err = decrypt_chunk(&[], &mat).unwrap_err();
        assert!(matches!(err, EntropyError::Crypto(_)));
    }

    #[test]
    fn decrypt_chunk_rejects_misaligned() {
        let mat = material();
        let err = decrypt_chunk(&[0u8; 17], &mat).unwrap_err();
        assert!(matches!(err, EntropyError::Crypto(_)));
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let mat = material();
        let pt = b"the quick brown fox jumps over the lazy dog";
        let ct = encrypt_chunk(pt, &mat).unwrap();
        let plain = decrypt_chunk(&ct, &mat).unwrap();
        assert_eq!(plain, pt);
    }

    #[test]
    fn encrypt_decrypt_round_trip_block_aligned() {
        // With PKCS#7, a block-aligned plaintext gets a full extra padding block;
        // make sure decryption strips exactly that block.
        let mat = material();
        let pt = [0xabu8; 32];
        let ct = encrypt_chunk(&pt, &mat).unwrap();
        let plain = decrypt_chunk(&ct, &mat).unwrap();
        assert_eq!(plain, pt);
    }

    #[test]
    fn decrypt_chunk_rejects_tamper() {
        let mat = material();
        let pt = b"the quick brown fox";
        let mut ct = encrypt_chunk(pt, &mat).unwrap();
        // Flipping a byte of the last block scrambles the padding on decrypt.
        let last = ct.last_mut().unwrap();
        *last ^= 0xff;
        let err = decrypt_chunk(&ct, &mat).unwrap_err();
        assert!(matches!(err, EntropyError::Crypto(_)));
    }

    #[test]
    fn memory_sink_round_trips() {
        let sink = MemorySink::default();
        sink.apply(b"abc").unwrap();
        assert_eq!(sink.snapshot(), b"abc");
        let drained = sink.take();
        assert_eq!(drained, b"abc");
        assert!(sink.snapshot().is_empty());
    }

    #[test]
    fn memory_sink_apply_replaces_previous_snapshot() {
        // `apply` must overwrite, not append, even when the new snapshot is shorter.
        let sink = MemorySink::default();
        sink.apply(b"a longer first snapshot").unwrap();
        sink.apply(b"xy").unwrap();
        assert_eq!(sink.snapshot(), b"xy");
    }
}