stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! Parallel AEAD pipeline: multiple OS threads, each with its own cipher state.
//!
//! Work is sharded with round-robin scheduling so hot paths do not share a single lock.
//! Queues are bounded for backpressure; each job carries a one-shot reply channel.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread::JoinHandle;

use crossbeam_channel::{Receiver, Sender};

use crate::config::AeadCipher as AeadVariant;
use crate::crypto::AeadCipher;
use crate::error::Result;

/// Default worker count when `CryptoConfig::aead_worker_count` is `0` (auto).
///
/// [`AeadPipeline::new`] substitutes this value whenever it is given a `worker_count` of `0`.
pub const DEFAULT_AEAD_WORKER_COUNT: usize = 10;

/// Depth of each worker’s inbound queue (frames waiting for crypto).
///
/// Each worker gets its own bounded channel of this capacity. Once a worker has this many
/// pending jobs, further submissions routed to it block until it catches up (backpressure).
const PER_WORKER_QUEUE: usize = 1024;

enum Job {
    Encrypt {
        nonce: [u8; 12],
        plaintext: Vec<u8>,
        reply: Sender<Result<Vec<u8>>>,
    },
    Decrypt {
        nonce: [u8; 12],
        ciphertext: Vec<u8>,
        reply: Sender<Result<Vec<u8>>>,
    },
}

/// Multi-threaded AEAD encrypt/decrypt pool (one cipher instance per thread, same key material).
///
/// Requests are handed to workers in round-robin order, and each call blocks until its worker
/// replies. Dropping the pipeline closes every worker queue and joins the worker threads.
pub struct AeadPipeline {
    /// Inbound job queue of each worker, indexed by worker number.
    txs: Vec<Sender<Job>>,
    /// Request counter (wraps on overflow); `next % txs.len()` selects the worker for a request.
    next: AtomicUsize,
    /// Join handles of the worker threads, joined when the pipeline is dropped.
    handles: Vec<JoinHandle<()>>,
}

impl AeadPipeline {
    /// Build a pipeline with `worker_count` dedicated threads. `0` selects [`DEFAULT_AEAD_WORKER_COUNT`].
    pub fn new(variant: AeadVariant, key: &[u8; 32], worker_count: usize) -> Result<Self> {
        let n = if worker_count == 0 {
            DEFAULT_AEAD_WORKER_COUNT
        } else {
            worker_count
        };

        let mut txs = Vec::with_capacity(n);
        let mut handles = Vec::with_capacity(n);

        for idx in 0..n {
            let (tx, rx) = crossbeam_channel::bounded::<Job>(PER_WORKER_QUEUE);
            let key = *key;
            let name = format!("srx-aead-{idx}");
            let handle = std::thread::Builder::new()
                .name(name)
                .spawn(move || worker_main(variant, &key, rx))
                .map_err(|e| {
                    crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
                        e.to_string(),
                    ))
                })?;
            txs.push(tx);
            handles.push(handle);
        }

        Ok(Self {
            txs,
            next: AtomicUsize::new(0),
            handles,
        })
    }

    /// Number of parallel AEAD worker threads.
    #[must_use]
    pub fn worker_count(&self) -> usize {
        self.txs.len()
    }

    /// Encrypt using the next worker (round-robin).
    pub fn encrypt(&self, nonce: [u8; 12], plaintext: Vec<u8>) -> Result<Vec<u8>> {
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
        let job = Job::Encrypt {
            nonce,
            plaintext,
            reply: reply_tx,
        };
        self.dispatch(job)?;
        reply_rx.recv().map_err(|_| {
            crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
                "AEAD worker disconnected".into(),
            ))
        })?
    }

    /// Decrypt using the next worker (round-robin).
    pub fn decrypt(&self, nonce: [u8; 12], ciphertext: Vec<u8>) -> Result<Vec<u8>> {
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
        let job = Job::Decrypt {
            nonce,
            ciphertext,
            reply: reply_tx,
        };
        self.dispatch(job)?;
        reply_rx.recv().map_err(|_| {
            crate::error::SrxError::Crypto(crate::error::CryptoError::DecryptionFailed(
                "AEAD worker disconnected".into(),
            ))
        })?
    }

    fn dispatch(&self, job: Job) -> Result<()> {
        let i = self.next.fetch_add(1, Ordering::Relaxed) % self.txs.len();
        self.txs[i].send(job).map_err(|_| {
            crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
                "AEAD worker queue closed".into(),
            ))
        })
    }
}

impl Drop for AeadPipeline {
    /// Shut the pool down by closing every worker queue, then waiting for each worker thread.
    ///
    /// A worker leaves its receive loop once its queue is disconnected and drained. A join
    /// error (a worker that panicked) is deliberately ignored.
    fn drop(&mut self) {
        // Dropping all senders disconnects each worker's channel.
        drop(std::mem::take(&mut self.txs));
        self.handles.drain(..).for_each(|worker| {
            let _ = worker.join();
        });
    }
}

/// Body of one AEAD worker thread.
///
/// Builds a private cipher from `variant` and `key`, then serves jobs from `rx` until every
/// sender is gone and the queue is empty. If the cipher cannot be built, the thread returns
/// immediately without serving anything. Each result goes back on the job's own reply channel.
fn worker_main(variant: AeadVariant, key: &[u8; 32], rx: Receiver<Job>) {
    let cipher = if let Ok(built) = AeadCipher::new(variant, key.as_slice()) {
        built
    } else {
        return;
    };

    for job in rx.iter() {
        let (reply, outcome) = match job {
            Job::Encrypt {
                nonce,
                plaintext,
                reply,
            } => (reply, cipher.encrypt(&nonce, &plaintext)),
            Job::Decrypt {
                nonce,
                ciphertext,
                reply,
            } => (reply, cipher.decrypt(&nonce, &ciphertext)),
        };
        // If the requester's receiver is gone, the result is simply discarded.
        let _ = reply.send(outcome);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AeadCipher as Variant;

    #[test]
    fn pipeline_encrypt_decrypt_roundtrip() {
        let key = [0x5Au8; 32];
        let pipe = AeadPipeline::new(Variant::ChaCha20Poly1305, &key, 4).unwrap();
        let nonce = [3u8; 12];
        let plain = b"parallel crypto pipeline".to_vec();
        let ct = pipe.encrypt(nonce, plain.clone()).unwrap();
        let pt = pipe.decrypt(nonce, ct).unwrap();
        assert_eq!(pt, plain);
    }

    #[test]
    fn default_worker_count_uses_ten_threads() {
        let key = [0u8; 32];
        let pipe = AeadPipeline::new(Variant::ChaCha20Poly1305, &key, 0).unwrap();
        // `0` means "auto" and must resolve to the documented default of ten workers.
        assert_eq!(DEFAULT_AEAD_WORKER_COUNT, 10);
        assert_eq!(pipe.worker_count(), DEFAULT_AEAD_WORKER_COUNT);
    }

    #[test]
    fn explicit_worker_count_is_honoured() {
        let pipe = AeadPipeline::new(Variant::ChaCha20Poly1305, &[7u8; 32], 3).unwrap();
        assert_eq!(pipe.worker_count(), 3);
    }

    #[test]
    fn round_robin_decrypts_across_workers() {
        // More requests than workers, and encrypt/decrypt of one message land on different
        // workers, so every per-thread cipher instance must agree on the key material.
        let pipe = AeadPipeline::new(Variant::ChaCha20Poly1305, &[0x42u8; 32], 3).unwrap();
        for i in 0u8..10 {
            let nonce = [i; 12];
            let plain = vec![i; usize::from(i) + 1];
            let ct = pipe.encrypt(nonce, plain.clone()).unwrap();
            assert_eq!(pipe.decrypt(nonce, ct).unwrap(), plain);
        }
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let pipe = AeadPipeline::new(Variant::ChaCha20Poly1305, &[9u8; 32], 2).unwrap();
        let nonce = [1u8; 12];
        let mut ct = pipe.encrypt(nonce, b"integrity".to_vec()).unwrap();
        *ct.last_mut().expect("AEAD output is never empty") ^= 0x01;
        assert!(pipe.decrypt(nonce, ct).is_err());
    }

    #[test]
    fn pipeline_is_shareable_across_threads() {
        let pipe = std::sync::Arc::new(
            AeadPipeline::new(Variant::ChaCha20Poly1305, &[0x33u8; 32], 4).unwrap(),
        );
        let callers: Vec<_> = (0u8..4)
            .map(|t| {
                let pipe = std::sync::Arc::clone(&pipe);
                std::thread::spawn(move || {
                    for i in 0u8..8 {
                        let mut nonce = [0u8; 12];
                        nonce[0] = t;
                        nonce[1] = i;
                        let plain = vec![t, i, 0xAA];
                        let ct = pipe.encrypt(nonce, plain.clone()).unwrap();
                        assert_eq!(pipe.decrypt(nonce, ct).unwrap(), plain);
                    }
                })
            })
            .collect();
        for caller in callers {
            caller.join().unwrap();
        }
    }
}