stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! Worker pool for processing incoming traffic across transport types.
//!
//! Uses the same [`crate::crypto::AeadPipeline`] model as the crypto layer: many
//! parallel workers so encryption/decryption never blocks the I/O reactor.

use std::collections::VecDeque;
use std::sync::Arc;

use bytes::Bytes;
use tokio::sync::{Mutex, Notify, broadcast, mpsc};

use crate::crypto::AeadPipeline;
use crate::pipeline::SrxPipeline;
use crate::transport::TransportKind;

/// Shared state behind every clone of a [`WorkerPool`].
struct WorkerPoolInner {
    /// Transport this pool serves; attached to the workers' tracing events.
    kind: TransportKind,
    /// Number of worker tasks spawned by each `run` / `run_with_pipeline` call.
    pool_size: usize,
    /// AEAD pipeline supplied through `with_crypto`, if any.
    crypto: Option<Arc<AeadPipeline>>,
    /// Shutdown broadcast; each worker and each `run*` call subscribes its own receiver.
    shutdown_tx: broadcast::Sender<()>,
    /// Frames waiting to be picked up by a worker (pushed at the back, popped from the front).
    queue: Mutex<VecDeque<Bytes>>,
    /// Signalled by `enqueue` so idle workers re-check the queue.
    notify: Notify,
}

/// A pool of async workers handling traffic for a specific transport.
///
/// Cloning is cheap: every clone shares the same queue, shutdown channel and
/// configuration through one `Arc`. This lets one clone `enqueue`/`shutdown`
/// while another clone is inside `run`.
#[derive(Clone)]
pub struct WorkerPool {
    inner: Arc<WorkerPoolInner>,
}

impl WorkerPool {
    pub fn new(kind: TransportKind, pool_size: usize) -> Self {
        let (shutdown_tx, _) = broadcast::channel(16);
        Self {
            inner: Arc::new(WorkerPoolInner {
                kind,
                pool_size,
                crypto: None,
                shutdown_tx,
                queue: Mutex::new(VecDeque::new()),
                notify: Notify::new(),
            }),
        }
    }

    /// Attach a pre-built AEAD pipeline (e.g. from [`SrxConfig::crypto`](crate::config::SrxConfig)).
    pub fn with_crypto(kind: TransportKind, pool_size: usize, crypto: Arc<AeadPipeline>) -> Self {
        let (shutdown_tx, _) = broadcast::channel(16);
        Self {
            inner: Arc::new(WorkerPoolInner {
                kind,
                pool_size,
                crypto: Some(crypto),
                shutdown_tx,
                queue: Mutex::new(VecDeque::new()),
                notify: Notify::new(),
            }),
        }
    }

    /// Optional AEAD pipeline attached at construction.
    #[must_use]
    pub fn crypto(&self) -> Option<Arc<AeadPipeline>> {
        self.inner.crypto.clone()
    }

    /// Push a raw frame chunk for workers to process (decrypt/dispatch hot path).
    pub async fn enqueue(&self, data: Bytes) {
        self.inner.queue.lock().await.push_back(data);
        self.inner.notify.notify_waiters();
    }

    /// Signal all workers to exit; unblocks [`WorkerPool::run`] after join.
    pub fn shutdown(&self) {
        let _ = self.inner.shutdown_tx.send(());
    }

    /// Subscribe to the shutdown broadcast (for external listeners).
    pub fn shutdown_rx(&self) -> broadcast::Receiver<()> {
        self.inner.shutdown_tx.subscribe()
    }

    /// Spawns `pool_size` workers that dequeue from the internal queue until [`WorkerPool::shutdown`].
    pub async fn run(&self) -> crate::error::Result<()> {
        let mut handles = Vec::new();
        for _ in 0..self.inner.pool_size {
            let inner = self.inner.clone();
            let mut shutdown_rx = inner.shutdown_tx.subscribe();
            handles.push(tokio::spawn(async move {
                loop {
                    while let Some(bytes) = inner.queue.lock().await.pop_front() {
                        tracing::trace!(
                            ?inner.kind,
                            len = bytes.len(),
                            has_crypto = inner.crypto.is_some(),
                            "dequeued frame"
                        );
                    }
                    tokio::select! {
                        _ = shutdown_rx.recv() => break,
                        _ = inner.notify.notified() => {}
                    }
                }
                tracing::trace!(?inner.kind, "worker stopped");
            }));
        }
        let mut main_rx = self.inner.shutdown_tx.subscribe();
        let _ = main_rx.recv().await;
        for h in handles {
            let _ = h.await;
        }
        Ok(())
    }

    /// Spawns `pool_size` workers that dequeue frames, process them through
    /// the [`SrxPipeline`] recv path, and forward decrypted payloads to `output_tx`.
    ///
    /// Each worker shares the pipeline (read-only for `process_incoming`).
    pub async fn run_with_pipeline(
        &self,
        pipeline: Arc<SrxPipeline>,
        output_tx: mpsc::Sender<Vec<u8>>,
    ) -> crate::error::Result<()> {
        let mut handles = Vec::new();
        for _ in 0..self.inner.pool_size {
            let inner = self.inner.clone();
            let pipe = pipeline.clone();
            let tx = output_tx.clone();
            let mut shutdown_rx = inner.shutdown_tx.subscribe();

            handles.push(tokio::spawn(async move {
                loop {
                    while let Some(bytes) = inner.queue.lock().await.pop_front() {
                        match pipe.process_incoming(&bytes) {
                            Ok(payload) => {
                                tracing::debug!(
                                    ?inner.kind,
                                    payload_len = payload.len(),
                                    "frame decrypted"
                                );
                                if tx.send(payload).await.is_err() {
                                    return; // output channel closed
                                }
                            }
                            Err(e) => {
                                tracing::warn!(
                                    ?inner.kind,
                                    error = %e,
                                    "frame processing failed"
                                );
                            }
                        }
                    }
                    tokio::select! {
                        _ = shutdown_rx.recv() => break,
                        _ = inner.notify.notified() => {}
                    }
                }
                tracing::trace!(?inner.kind, "pipeline worker stopped");
            }));
        }
        let mut main_rx = self.inner.shutdown_tx.subscribe();
        let _ = main_rx.recv().await;
        for h in handles {
            let _ = h.await;
        }
        Ok(())
    }
}

#[cfg(test)]
impl WorkerPool {
    /// Test-only: how many frames are still waiting in the queue.
    async fn test_queue_len(&self) -> usize {
        let pending = self.inner.queue.lock().await;
        pending.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on any single wait. A lost shutdown or wakeup then fails the
    /// test instead of hanging the whole test run.
    const TEST_TIMEOUT: Duration = Duration::from_secs(2);

    /// Polls until `pool`'s queue is empty; panics after [`TEST_TIMEOUT`].
    async fn wait_until_drained(pool: &WorkerPool) {
        tokio::time::timeout(TEST_TIMEOUT, async {
            while pool.test_queue_len().await != 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        })
        .await
        .expect("queue was not drained in time");
    }

    /// Awaits a spawned `run*` task. Panics if it does not finish within
    /// [`TEST_TIMEOUT`], if it panicked, or if it returned an error.
    async fn join_run(jh: tokio::task::JoinHandle<crate::error::Result<()>>) {
        tokio::time::timeout(TEST_TIMEOUT, jh)
            .await
            .expect("run did not return after shutdown")
            .expect("run task panicked")
            .expect("run returned an error");
    }

    #[tokio::test]
    async fn shutdown_unblocks_run() {
        let pool = WorkerPool::new(TransportKind::Tcp, 2);
        let pool2 = pool.clone();
        let jh = tokio::spawn(async move { pool2.run().await });
        // Let `run` subscribe first: a shutdown broadcast with no subscribers is dropped.
        tokio::time::sleep(Duration::from_millis(30)).await;
        pool.shutdown();
        join_run(jh).await;
    }

    #[tokio::test]
    async fn enqueue_drained_before_shutdown() {
        let pool = WorkerPool::new(TransportKind::Tcp, 1);
        pool.enqueue(Bytes::from_static(b"a")).await;
        pool.enqueue(Bytes::from_static(b"b")).await;
        let pool2 = pool.clone();
        let jh = tokio::spawn(async move { pool2.run().await });
        wait_until_drained(&pool).await;
        assert_eq!(pool.test_queue_len().await, 0);
        pool.shutdown();
        join_run(jh).await;
    }

    #[tokio::test]
    async fn enqueue_while_running_is_drained() {
        let pool = WorkerPool::new(TransportKind::Tcp, 2);
        let pool2 = pool.clone();
        let jh = tokio::spawn(async move { pool2.run().await });
        // Let the workers go idle before any frame arrives.
        tokio::time::sleep(Duration::from_millis(30)).await;
        for chunk in [b"x", b"y", b"z"] {
            pool.enqueue(Bytes::from_static(chunk)).await;
        }
        wait_until_drained(&pool).await;
        pool.shutdown();
        join_run(jh).await;
    }

    #[tokio::test]
    async fn run_with_pipeline_decrypts_frames() {
        use crate::config::{AeadCipher as Variant, MimicryMode};
        use crate::masking::padding::PaddingStrategy;
        use crate::seed::SeedRng;
        use crate::session::Session;
        use crate::transport::TransportManager;

        let seed = [0xAAu8; 32];
        let key = [0xBBu8; 32];
        let aead =
            Arc::new(crate::crypto::AeadPipeline::new(Variant::ChaCha20Poly1305, &key, 2).unwrap());

        // Build sender pipeline to prepare a frame.
        let mut sender = SrxPipeline::new(
            Session::new(1, seed, key),
            aead.clone(),
            PaddingStrategy::new(SeedRng::new(seed), 32),
            MimicryMode::None,
            None,
            TransportManager::new(),
        );
        let envelope = sender.prepare_outgoing(b"worker-test").unwrap();

        // Build receiver pipeline (shared, read-only process_incoming).
        let receiver = Arc::new(SrxPipeline::new(
            Session::new(2, seed, key),
            aead,
            PaddingStrategy::new(SeedRng::new(seed), 32),
            MimicryMode::None,
            None,
            TransportManager::new(),
        ));

        let pool = WorkerPool::new(TransportKind::Tcp, 1);
        let (tx, mut rx) = mpsc::channel(16);

        let pool2 = pool.clone();
        let jh = tokio::spawn(async move { pool2.run_with_pipeline(receiver, tx).await });

        pool.enqueue(Bytes::from(envelope)).await;
        let payload = tokio::time::timeout(TEST_TIMEOUT, rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert_eq!(payload, b"worker-test");

        pool.shutdown();
        join_run(jh).await;
    }
}