use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread::JoinHandle;
use crossbeam_channel::{Receiver, Sender};
use crate::config::AeadCipher as AeadVariant;
use crate::crypto::AeadCipher;
use crate::error::Result;
/// Number of worker threads `AeadPipeline::new` starts when it is called with
/// `worker_count == 0`.
pub const DEFAULT_AEAD_WORKER_COUNT: usize = 10;
/// Capacity of the bounded job channel created for each worker thread.
const PER_WORKER_QUEUE: usize = 1024;
/// A unit of work queued to a worker thread.
///
/// Both variants carry the 12-byte nonce, the input buffer, and the sender on
/// which the worker posts the outcome of the cipher call.
enum Job {
/// Encrypt `plaintext` under `nonce`.
Encrypt {
nonce: [u8; 12],
plaintext: Vec<u8>,
// Channel on which the worker returns the encryption result.
reply: Sender<Result<Vec<u8>>>,
},
/// Decrypt `ciphertext` under `nonce`.
Decrypt {
nonce: [u8; 12],
ciphertext: Vec<u8>,
// Channel on which the worker returns the decryption result.
reply: Sender<Result<Vec<u8>>>,
},
}
/// Pool of worker threads that perform AEAD encryption and decryption.
///
/// Jobs are handed to the workers in round-robin order over per-worker
/// bounded channels. Dropping the pipeline closes those channels and joins
/// the worker threads.
pub struct AeadPipeline {
/// One job sender per worker, in spawn order.
txs: Vec<Sender<Job>>,
/// Counter used to pick the next worker (taken modulo `txs.len()`).
next: AtomicUsize,
/// Join handles of the spawned worker threads.
handles: Vec<JoinHandle<()>>,
}
impl AeadPipeline {
pub fn new(variant: AeadVariant, key: &[u8; 32], worker_count: usize) -> Result<Self> {
let n = if worker_count == 0 {
DEFAULT_AEAD_WORKER_COUNT
} else {
worker_count
};
let mut txs = Vec::with_capacity(n);
let mut handles = Vec::with_capacity(n);
for idx in 0..n {
let (tx, rx) = crossbeam_channel::bounded::<Job>(PER_WORKER_QUEUE);
let key = *key;
let name = format!("srx-aead-{idx}");
let handle = std::thread::Builder::new()
.name(name)
.spawn(move || worker_main(variant, &key, rx))
.map_err(|e| {
crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
e.to_string(),
))
})?;
txs.push(tx);
handles.push(handle);
}
Ok(Self {
txs,
next: AtomicUsize::new(0),
handles,
})
}
#[must_use]
pub fn worker_count(&self) -> usize {
self.txs.len()
}
pub fn encrypt(&self, nonce: [u8; 12], plaintext: Vec<u8>) -> Result<Vec<u8>> {
let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
let job = Job::Encrypt {
nonce,
plaintext,
reply: reply_tx,
};
self.dispatch(job)?;
reply_rx.recv().map_err(|_| {
crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
"AEAD worker disconnected".into(),
))
})?
}
pub fn decrypt(&self, nonce: [u8; 12], ciphertext: Vec<u8>) -> Result<Vec<u8>> {
let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
let job = Job::Decrypt {
nonce,
ciphertext,
reply: reply_tx,
};
self.dispatch(job)?;
reply_rx.recv().map_err(|_| {
crate::error::SrxError::Crypto(crate::error::CryptoError::DecryptionFailed(
"AEAD worker disconnected".into(),
))
})?
}
fn dispatch(&self, job: Job) -> Result<()> {
let i = self.next.fetch_add(1, Ordering::Relaxed) % self.txs.len();
self.txs[i].send(job).map_err(|_| {
crate::error::SrxError::Crypto(crate::error::CryptoError::EncryptionFailed(
"AEAD worker queue closed".into(),
))
})
}
}
impl Drop for AeadPipeline {
    /// Shuts the pool down: closes every job queue, then waits for each
    /// worker thread to finish.
    fn drop(&mut self) {
        // The senders must go first; a worker only leaves its receive loop
        // once its queue is disconnected, so joining before this would hang.
        drop(std::mem::take(&mut self.txs));
        self.handles.drain(..).for_each(|worker| {
            // A worker that panicked yields `Err` here; it is ignored on purpose.
            let _ = worker.join();
        });
    }
}
/// Body of a worker thread: builds a cipher from `variant` and `key`, then
/// serves jobs from `rx` until the channel is disconnected.
///
/// If the cipher cannot be constructed the function returns immediately,
/// which drops `rx` and thereby closes this worker's queue.
fn worker_main(variant: AeadVariant, key: &[u8; 32], rx: Receiver<Job>) {
    if let Ok(cipher) = AeadCipher::new(variant, key.as_slice()) {
        // `iter()` blocks for the next job and ends once all senders are gone.
        for job in rx.iter() {
            let (reply, outcome) = match job {
                Job::Encrypt {
                    nonce,
                    plaintext,
                    reply,
                } => (reply, cipher.encrypt(&nonce, &plaintext)),
                Job::Decrypt {
                    nonce,
                    ciphertext,
                    reply,
                } => (reply, cipher.decrypt(&nonce, &ciphertext)),
            };
            // The requester may have stopped waiting; a failed send is ignored.
            let _ = reply.send(outcome);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AeadCipher as Variant;

    /// Ciphertext produced by the pipeline decrypts back to the original
    /// plaintext when the same nonce is supplied.
    #[test]
    fn pipeline_encrypt_decrypt_roundtrip() {
        let message = b"parallel crypto pipeline".to_vec();
        let nonce = [3u8; 12];
        let secret = [0x5Au8; 32];
        let pipeline = AeadPipeline::new(Variant::ChaCha20Poly1305, &secret, 4).unwrap();

        let sealed = pipeline.encrypt(nonce, message.clone()).unwrap();
        let opened = pipeline.decrypt(nonce, sealed).unwrap();

        assert_eq!(opened, message);
    }

    /// A requested worker count of zero falls back to the default pool size.
    #[test]
    fn default_worker_count_uses_ten_threads() {
        let secret = [0u8; 32];
        let pipeline = AeadPipeline::new(Variant::ChaCha20Poly1305, &secret, 0).unwrap();

        assert_eq!(pipeline.worker_count(), DEFAULT_AEAD_WORKER_COUNT);
    }
}