use std::collections::VecDeque;
use std::sync::Arc;
use bytes::Bytes;
use tokio::sync::{Mutex, Notify, broadcast, mpsc};
use crate::crypto::AeadPipeline;
use crate::pipeline::SrxPipeline;
use crate::transport::TransportKind;
/// State shared by every clone of a `WorkerPool` and by the worker tasks it spawns.
struct WorkerPoolInner {
/// Transport this pool is associated with; in this file it is only attached to log events.
kind: TransportKind,
/// Number of worker tasks spawned by `run` / `run_with_pipeline`.
pool_size: usize,
/// Optional AEAD pipeline handed out via `WorkerPool::crypto`; `None` when built with `new`.
crypto: Option<Arc<AeadPipeline>>,
/// Broadcast sender used to signal shutdown to workers and to the `run*` futures.
shutdown_tx: broadcast::Sender<()>,
/// FIFO of frames: pushed at the back by `enqueue`, popped from the front by workers.
queue: Mutex<VecDeque<Bytes>>,
/// Signalled by `enqueue` so that idle workers re-check `queue`.
notify: Notify,
}
/// Handle to a pool of worker tasks that consume queued frames.
///
/// Cloning is cheap: every clone points at the same shared state through the `Arc`,
/// so frames enqueued or a shutdown requested through one clone are seen by all.
#[derive(Clone)]
pub struct WorkerPool {
/// Shared queue, shutdown channel and configuration.
inner: Arc<WorkerPoolInner>,
}
impl WorkerPool {
pub fn new(kind: TransportKind, pool_size: usize) -> Self {
let (shutdown_tx, _) = broadcast::channel(16);
Self {
inner: Arc::new(WorkerPoolInner {
kind,
pool_size,
crypto: None,
shutdown_tx,
queue: Mutex::new(VecDeque::new()),
notify: Notify::new(),
}),
}
}
pub fn with_crypto(kind: TransportKind, pool_size: usize, crypto: Arc<AeadPipeline>) -> Self {
let (shutdown_tx, _) = broadcast::channel(16);
Self {
inner: Arc::new(WorkerPoolInner {
kind,
pool_size,
crypto: Some(crypto),
shutdown_tx,
queue: Mutex::new(VecDeque::new()),
notify: Notify::new(),
}),
}
}
#[must_use]
pub fn crypto(&self) -> Option<Arc<AeadPipeline>> {
self.inner.crypto.clone()
}
pub async fn enqueue(&self, data: Bytes) {
self.inner.queue.lock().await.push_back(data);
self.inner.notify.notify_waiters();
}
pub fn shutdown(&self) {
let _ = self.inner.shutdown_tx.send(());
}
pub fn shutdown_rx(&self) -> broadcast::Receiver<()> {
self.inner.shutdown_tx.subscribe()
}
pub async fn run(&self) -> crate::error::Result<()> {
let mut handles = Vec::new();
for _ in 0..self.inner.pool_size {
let inner = self.inner.clone();
let mut shutdown_rx = inner.shutdown_tx.subscribe();
handles.push(tokio::spawn(async move {
loop {
while let Some(bytes) = inner.queue.lock().await.pop_front() {
tracing::trace!(
?inner.kind,
len = bytes.len(),
has_crypto = inner.crypto.is_some(),
"dequeued frame"
);
}
tokio::select! {
_ = shutdown_rx.recv() => break,
_ = inner.notify.notified() => {}
}
}
tracing::trace!(?inner.kind, "worker stopped");
}));
}
let mut main_rx = self.inner.shutdown_tx.subscribe();
let _ = main_rx.recv().await;
for h in handles {
let _ = h.await;
}
Ok(())
}
pub async fn run_with_pipeline(
&self,
pipeline: Arc<SrxPipeline>,
output_tx: mpsc::Sender<Vec<u8>>,
) -> crate::error::Result<()> {
let mut handles = Vec::new();
for _ in 0..self.inner.pool_size {
let inner = self.inner.clone();
let pipe = pipeline.clone();
let tx = output_tx.clone();
let mut shutdown_rx = inner.shutdown_tx.subscribe();
handles.push(tokio::spawn(async move {
loop {
while let Some(bytes) = inner.queue.lock().await.pop_front() {
match pipe.process_incoming(&bytes) {
Ok(payload) => {
tracing::debug!(
?inner.kind,
payload_len = payload.len(),
"frame decrypted"
);
if tx.send(payload).await.is_err() {
return; }
}
Err(e) => {
tracing::warn!(
?inner.kind,
error = %e,
"frame processing failed"
);
}
}
}
tokio::select! {
_ = shutdown_rx.recv() => break,
_ = inner.notify.notified() => {}
}
}
tracing::trace!(?inner.kind, "pipeline worker stopped");
}));
}
let mut main_rx = self.inner.shutdown_tx.subscribe();
let _ = main_rx.recv().await;
for h in handles {
let _ = h.await;
}
Ok(())
}
}
#[cfg(test)]
impl WorkerPool {
    /// Test-only helper: number of frames currently waiting in the queue.
    async fn test_queue_len(&self) -> usize {
        let pending = self.inner.queue.lock().await;
        pending.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// `run` must return once `shutdown` is signalled.
    #[tokio::test]
    async fn shutdown_unblocks_run() {
        let pool = WorkerPool::new(TransportKind::Tcp, 2);
        let runner = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.run().await })
        };

        // Give the workers time to start before asking them to stop.
        tokio::time::sleep(Duration::from_millis(30)).await;
        pool.shutdown();

        runner.await.unwrap().unwrap();
    }

    /// Frames queued before `run` starts are consumed by the workers.
    #[tokio::test]
    async fn enqueue_drained_before_shutdown() {
        let pool = WorkerPool::new(TransportKind::Tcp, 1);
        pool.enqueue(Bytes::from_static(b"a")).await;
        pool.enqueue(Bytes::from_static(b"b")).await;

        let runner = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.run().await })
        };

        tokio::time::sleep(Duration::from_millis(50)).await;
        let remaining = pool.test_queue_len().await;
        assert_eq!(remaining, 0);

        pool.shutdown();
        runner.await.unwrap().unwrap();
    }

    /// An envelope produced by one pipeline is recovered by a pool running a
    /// second pipeline built from the same seed and key.
    #[tokio::test]
    async fn run_with_pipeline_decrypts_frames() {
        use crate::config::{AeadCipher as Variant, MimicryMode};
        use crate::masking::padding::PaddingStrategy;
        use crate::seed::SeedRng;
        use crate::session::Session;
        use crate::transport::TransportManager;

        let seed = [0xAAu8; 32];
        let key = [0xBBu8; 32];
        let aead =
            Arc::new(crate::crypto::AeadPipeline::new(Variant::ChaCha20Poly1305, &key, 2).unwrap());

        // Sending side: wraps the plaintext into an envelope.
        let mut outbound = SrxPipeline::new(
            Session::new(1, seed, key),
            aead.clone(),
            PaddingStrategy::new(SeedRng::new(seed), 32),
            MimicryMode::None,
            None,
            TransportManager::new(),
        );
        let envelope = outbound.prepare_outgoing(b"worker-test").unwrap();

        // Receiving side: shared with the pool's workers.
        let inbound = Arc::new(SrxPipeline::new(
            Session::new(2, seed, key),
            aead,
            PaddingStrategy::new(SeedRng::new(seed), 32),
            MimicryMode::None,
            None,
            TransportManager::new(),
        ));

        let pool = WorkerPool::new(TransportKind::Tcp, 1);
        let (out_tx, mut out_rx) = mpsc::channel(16);
        let runner = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.run_with_pipeline(inbound, out_tx).await })
        };

        pool.enqueue(Bytes::from(envelope)).await;

        let received = tokio::time::timeout(Duration::from_secs(2), out_rx.recv()).await;
        let payload = received.expect("timeout").expect("channel closed");
        assert_eq!(payload, b"worker-test");

        pool.shutdown();
        runner.await.unwrap().unwrap();
    }
}