use crate::bufcache::{BufCache, BufCacheCons};
use crate::generator::{
GeneratorChaCha12, GeneratorChaCha20, GeneratorChaCha8, GeneratorCrc, NextRandom,
};
use crate::kdf::kdf;
use anyhow as ah;
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicIsize, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
/// Selects which pseudo-random generator backs a `DtStream`.
///
/// `ChaCha20` is the default variant.
#[derive(Debug, Default)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum DtStreamType {
    /// Generate data with `GeneratorChaCha8`.
    ChaCha8,
    /// Generate data with `GeneratorChaCha12`.
    ChaCha12,
    /// Generate data with `GeneratorChaCha20` (default).
    #[default]
    ChaCha20,
    /// Generate data with `GeneratorCrc`.
    Crc,
}
/// One unit of generated data handed from the worker thread to the consumer.
pub struct DtStreamChunk {
    /// Generated bytes. The worker thread in this module always sends `Some`.
    pub data: Option<Vec<u8>>,
    /// Test-only sequence number of the chunk within one worker run
    /// (starts at 0 and wraps around after 255).
    #[cfg(test)]
    pub index: u8,
}
/// Best-effort attempt to lower the scheduling priority of the calling thread,
/// so that background pattern generation yields CPU time to other work.
///
/// On Linux/Android this raises the nice value by 7 via `nice(2)`. The result
/// is intentionally ignored: failing to lower the priority is harmless.
/// On all other targets this function does nothing.
fn try_lower_thread_priority() {
    #[cfg(any(target_os = "linux", target_os = "android"))]
    {
        // SAFETY: `nice` only takes an integer by value and does not access any
        // memory owned by Rust; it has no preconditions to uphold.
        let _ = unsafe { libc::nice(7) };
    }
}
/// Body of the background generator thread owned by a `DtStream`.
///
/// Derives a per-thread seed from `seed`, `thread_id` and `round_id`, builds the
/// generator selected by `stype`, seeks it to `byte_offset` and then keeps
/// sending chunks of `base_size * chunk_factor` bytes over `tx` until `abort`
/// is set. Each chunk buffer is taken from `cache_cons`; if `invert_pattern` is
/// set, every byte is XORed with `0xFF`.
///
/// `level` counts chunks sent but not yet taken by the consumer. Once it reaches
/// `DtStream::MAX_THRES`, the thread parks on the `sleep` condvar until the
/// consumer (or `DtStream::stop`) wakes it.
///
/// If the generator cannot seek to `byte_offset`, an error is printed, `error`
/// is set and the thread returns without producing any data.
///
/// # Panics
/// Panics if the receiving end of `tx` has been dropped or if the `sleep`
/// mutex is poisoned.
#[allow(clippy::too_many_arguments)]
fn thread_worker(
    stype: DtStreamType,
    chunk_factor: usize,
    seed: Vec<u8>,
    thread_id: u32,
    round_id: u64,
    mut cache_cons: BufCacheCons,
    byte_offset: u64,
    invert_pattern: bool,
    abort: Arc<AtomicBool>,
    error: Arc<AtomicBool>,
    level: Arc<AtomicIsize>,
    sleep: Arc<(Mutex<bool>, Condvar)>,
    tx: Sender<DtStreamChunk>,
) {
    let thread_seed = kdf(&seed, thread_id, round_id);
    // The base seed is no longer needed once the per-thread seed is derived.
    drop(seed);

    let mut generator: Box<dyn NextRandom> = match stype {
        DtStreamType::ChaCha8 => Box::new(GeneratorChaCha8::new(&thread_seed)),
        DtStreamType::ChaCha12 => Box::new(GeneratorChaCha12::new(&thread_seed)),
        DtStreamType::ChaCha20 => Box::new(GeneratorChaCha20::new(&thread_seed)),
        DtStreamType::Crc => Box::new(GeneratorCrc::new(&thread_seed)),
    };
    if let Err(e) = generator.seek(byte_offset) {
        eprintln!("ERROR in generator thread {}: {}", thread_id, e);
        error.store(true, Ordering::Relaxed);
        return;
    }
    let chunk_size = generator.get_base_size() * chunk_factor;

    try_lower_thread_priority();

    #[cfg(test)]
    let mut index = 0;
    let mut cur_level = level.load(Ordering::Relaxed);
    while !abort.load(Ordering::Acquire) {
        if cur_level < DtStream::MAX_THRES {
            // Queue not full: generate and send one more chunk.
            let mut data = cache_cons.pull(chunk_size);
            generator.next(&mut data, chunk_factor);
            debug_assert_eq!(data.len(), chunk_size);
            if invert_pattern {
                for x in &mut data {
                    *x ^= 0xFFu8;
                }
            }
            let chunk = DtStreamChunk {
                data: Some(data),
                #[cfg(test)]
                index,
            };
            #[cfg(test)]
            {
                index = index.wrapping_add(1);
            }
            tx.send(chunk).expect("Worker thread: Send failed.");
            cur_level = level.fetch_add(1, Ordering::Relaxed) + 1;
        } else {
            // Queue full: park until the consumer drains it or we are aborted.
            let mut sleeping = sleep.0.lock().expect("Thread Condvar lock poison");
            // Re-check `abort` while holding the lock. `DtStream::stop` stores
            // `abort` *before* taking this lock in `wake_thread`, so either we
            // observe the abort here, or `stop` observes `sleeping == true` and
            // wakes us. Without this check an abort landing between the loop
            // condition and the lock would be missed, we would sleep forever,
            // and `stop` would hang in `join`.
            if abort.load(Ordering::Acquire) {
                break;
            }
            *sleeping = true;
            // Loop guards against spurious condvar wake-ups.
            while *sleeping {
                sleeping = sleep.1.wait(sleeping).expect("Thread Condvar wait poison");
            }
            cur_level = level.load(Ordering::Relaxed);
        }
    }
}
/// Consumer-side handle of a background data generator.
///
/// Owns at most one worker thread (`thread_worker`) that produces pattern chunks
/// and sends them over an mpsc channel; `get_chunk` polls that channel without
/// blocking.
pub struct DtStream {
    /// Generator algorithm used by the worker.
    stype: DtStreamType,
    /// Base seed; the worker derives its own seed from it via `kdf` together
    /// with `thread_id` and `round_id`.
    seed: Vec<u8>,
    /// If true, the worker XORs every generated byte with `0xFF`.
    invert_pattern: bool,
    /// Identifier passed to the worker's key derivation and to the cache
    /// consumer registration.
    thread_id: u32,
    /// Round number passed to the worker's key derivation.
    round_id: u64,
    /// Receiving end of the worker's chunk channel; `None` until first started.
    rx: Option<Receiver<DtStreamChunk>>,
    /// Shared buffer cache; every start registers a new consumer on it for
    /// the worker thread.
    cache: Rc<RefCell<BufCache>>,
    /// True between a start and the next stop.
    is_active: bool,
    /// Join handle of the running worker thread, if any.
    thread_join: Option<thread::JoinHandle<()>>,
    /// Set to ask the worker thread to exit.
    abort: Arc<AtomicBool>,
    /// Set by the worker when it fails (e.g. its generator cannot seek).
    error: Arc<AtomicBool>,
    /// Number of chunks sent by the worker but not yet taken by `get_chunk`.
    level: Arc<AtomicIsize>,
    /// `(sleeping flag, condvar)` the worker parks on while `level` is too high.
    sleep: Arc<(Mutex<bool>, Condvar)>,
}
impl DtStream {
    /// Queue level at which the worker stops generating and parks itself.
    ///
    /// `level` counts chunks the worker has sent that `get_chunk` has not yet
    /// taken; the worker only generates while `level` is below this value.
    const MAX_THRES: isize = 10;
    /// Once `get_chunk` drains `level` to this value or below, it wakes the
    /// worker so that the queue gets refilled.
    const LO_THRES: isize = 6;

    /// Create a new, inactive stream.
    ///
    /// No worker thread is spawned yet; call [`DtStream::activate`] to start
    /// generating data. The worker derives its generator seed from `seed`,
    /// `thread_id` and `round_id`; `invert_pattern` makes it XOR every generated
    /// byte with `0xFF`.
    pub fn new(
        stype: DtStreamType,
        seed: Vec<u8>,
        invert_pattern: bool,
        thread_id: u32,
        round_id: u64,
        cache: Rc<RefCell<BufCache>>,
    ) -> DtStream {
        DtStream {
            stype,
            seed,
            invert_pattern,
            thread_id,
            round_id,
            rx: None,
            cache,
            is_active: false,
            thread_join: None,
            abort: Arc::new(AtomicBool::new(false)),
            error: Arc::new(AtomicBool::new(false)),
            level: Arc::new(AtomicIsize::new(0)),
            sleep: Arc::new((Mutex::new(false), Condvar::new())),
        }
    }

    /// Wake the worker thread if it is parked on the sleep condvar.
    fn wake_thread(&self) {
        let mut sleeping = self.sleep.0.lock().expect("Wake Condvar lock poison");
        if *sleeping {
            *sleeping = false;
            self.sleep.1.notify_one();
        }
    }

    /// Stop the worker thread (if any) and mark the stream inactive.
    ///
    /// Signals `abort`, wakes a parked worker, joins it and finally clears
    /// `abort` again so the stream can be restarted.
    ///
    /// # Panics
    /// Panics if the worker thread panicked.
    fn stop(&mut self) {
        self.is_active = false;
        // `abort` must be stored before `wake_thread` takes the sleep lock;
        // the worker re-checks it under that lock before parking.
        self.abort.store(true, Ordering::Release);
        self.wake_thread();
        if let Some(thread_join) = self.thread_join.take() {
            thread_join.join().expect("Thread join failed");
        }
        self.abort.store(false, Ordering::Release);
    }

    /// Spawn a worker thread that generates data starting at `byte_offset`,
    /// producing chunks of the generator's base size times `chunk_factor`.
    ///
    /// Resets the abort/error flags and the queue level, creates a fresh
    /// channel and registers a new cache consumer for `thread_id`.
    ///
    /// # Panics
    /// Panics if the stream is already active or a worker is still attached.
    fn start(&mut self, byte_offset: u64, chunk_factor: usize) {
        assert!(!self.is_active);
        assert!(self.thread_join.is_none());

        self.abort.store(false, Ordering::Release);
        self.error.store(false, Ordering::Release);
        self.level.store(0, Ordering::Release);

        let (tx, rx) = channel();
        self.rx = Some(rx);

        let thread_stype = self.stype;
        let thread_chunk_factor = chunk_factor;
        let thread_seed = self.seed.clone();
        let thread_id = self.thread_id;
        let thread_round_id = self.round_id;
        let thread_cache_cons = self.cache.borrow_mut().new_consumer(self.thread_id);
        let thread_byte_offset = byte_offset;
        let thread_invert_pattern = self.invert_pattern;
        let thread_abort = Arc::clone(&self.abort);
        let thread_error = Arc::clone(&self.error);
        let thread_level = Arc::clone(&self.level);
        let thread_sleep = Arc::clone(&self.sleep);

        self.thread_join = Some(thread::spawn(move || {
            thread_worker(
                thread_stype,
                thread_chunk_factor,
                thread_seed,
                thread_id,
                thread_round_id,
                thread_cache_cons,
                thread_byte_offset,
                thread_invert_pattern,
                thread_abort,
                thread_error,
                thread_level,
                thread_sleep,
                tx,
            );
        }));
        self.is_active = true;
    }

    /// True if the worker thread has reported a failure.
    #[inline]
    fn is_thread_error(&self) -> bool {
        self.error.load(Ordering::Relaxed)
    }

    /// (Re)start the stream at `byte_offset` with the given `chunk_factor`,
    /// stopping any previously running worker first.
    ///
    /// Currently always returns `Ok(())`.
    pub fn activate(&mut self, byte_offset: u64, chunk_factor: usize) -> ah::Result<()> {
        self.stop();
        self.start(byte_offset, chunk_factor);
        Ok(())
    }

    /// True if a worker has been started and not stopped since.
    #[inline]
    pub fn is_active(&self) -> bool {
        self.is_active
    }

    /// The `BASE_SIZE` constant of the selected generator.
    ///
    /// The worker sizes each chunk as `get_base_size() * chunk_factor`;
    /// presumably `get_base_size()` equals this constant — verify in the
    /// generator module.
    pub fn get_chunk_size(&self) -> usize {
        match self.stype {
            DtStreamType::ChaCha8 => GeneratorChaCha8::BASE_SIZE,
            DtStreamType::ChaCha12 => GeneratorChaCha12::BASE_SIZE,
            DtStreamType::ChaCha20 => GeneratorChaCha20::BASE_SIZE,
            DtStreamType::Crc => GeneratorCrc::BASE_SIZE,
        }
    }

    /// The `DEFAULT_CHUNK_FACTOR` constant of the selected generator.
    pub fn get_default_chunk_factor(&self) -> usize {
        match self.stype {
            DtStreamType::ChaCha8 => GeneratorChaCha8::DEFAULT_CHUNK_FACTOR,
            DtStreamType::ChaCha12 => GeneratorChaCha12::DEFAULT_CHUNK_FACTOR,
            DtStreamType::ChaCha20 => GeneratorChaCha20::DEFAULT_CHUNK_FACTOR,
            DtStreamType::Crc => GeneratorCrc::DEFAULT_CHUNK_FACTOR,
        }
    }

    /// Non-blocking fetch of the next generated chunk.
    ///
    /// Returns `Ok(Some(chunk))` if a chunk was queued, or `Ok(None)` if the
    /// worker has not produced one yet (in which case the worker is woken so it
    /// can refill the queue).
    ///
    /// # Errors
    /// Returns an error if the stream is not active, if the worker reported an
    /// error, if no channel exists, or if the worker thread has terminated
    /// (e.g. by panicking) and no queued chunks remain.
    #[inline]
    pub fn get_chunk(&mut self) -> ah::Result<Option<DtStreamChunk>> {
        if !self.is_active() {
            return Err(ah::format_err!("Generator stream is not active."));
        }
        if self.is_thread_error() {
            return Err(ah::format_err!(
                "Generator stream thread aborted with an error."
            ));
        }
        let Some(rx) = &self.rx else {
            return Err(ah::format_err!("Generator stream RX channel not present."));
        };
        let chunk = match rx.try_recv() {
            Ok(chunk) => chunk,
            Err(std::sync::mpsc::TryRecvError::Empty) => {
                // Queue is empty: make sure the worker is not parked, so a
                // missed wake-up cannot stall it indefinitely.
                self.wake_thread();
                return Ok(None);
            }
            Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                // The worker dropped its sender, i.e. its thread exited while
                // the stream is still active. Report it instead of returning
                // `Ok(None)` forever, which would make polling callers spin.
                if self.is_thread_error() {
                    return Err(ah::format_err!(
                        "Generator stream thread aborted with an error."
                    ));
                }
                return Err(ah::format_err!(
                    "Generator stream thread terminated unexpectedly."
                ));
            }
        };
        // One chunk consumed: lower the level and wake the worker once the
        // queue has drained far enough.
        if self.level.fetch_sub(1, Ordering::Relaxed) - 1 <= DtStream::LO_THRES {
            self.wake_thread();
        }
        Ok(Some(chunk))
    }
}
impl Drop for DtStream {
fn drop(&mut self) {
self.stop();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::disktest::DisktestQuiet;
    use std::time::Duration;

    impl DtStream {
        /// Test helper: poll [`DtStream::get_chunk`] until a chunk arrives.
        ///
        /// Sleeps 1 ms between polls and panics if `get_chunk` returns an error.
        pub fn wait_chunk(&mut self) -> DtStreamChunk {
            loop {
                match self.get_chunk().unwrap() {
                    Some(chunk) => return chunk,
                    None => thread::sleep(Duration::from_millis(1)),
                }
            }
        }
    }

    /// Build an inactive stream over the fixed test seed `[1, 2, 3]`
    /// (thread 0, round 0), backed by its own fresh buffer cache.
    fn new_test_stream(algorithm: DtStreamType, invert_pattern: bool) -> DtStream {
        let cache = Rc::new(RefCell::new(BufCache::new(DisktestQuiet::Normal)));
        DtStream::new(algorithm, vec![1, 2, 3], invert_pattern, 0, 0, cache)
    }

    /// Borrow the bytes of a chunk; chunks from the worker always carry data.
    fn payload(chunk: &DtStreamChunk) -> &[u8] {
        chunk.data.as_deref().unwrap()
    }

    /// The first byte of each of the first five chunks must match known
    /// reference values, and chunk indices must count up from zero.
    fn run_base_test(algorithm: DtStreamType) {
        println!("stream base test");
        let mut stream = new_test_stream(algorithm, false);
        stream.activate(0, stream.get_default_chunk_factor()).unwrap();
        assert!(stream.is_active());
        assert!(stream.get_chunk_size() > 0);
        assert!(stream.get_default_chunk_factor() > 0);

        let first_bytes: Vec<u8> = (0u8..5)
            .map(|count| {
                let chunk = stream.wait_chunk();
                let first = payload(&chunk)[0];
                println!(
                    "{}: index={} data[0]={} (current level = {})",
                    count,
                    chunk.index,
                    first,
                    stream.level.load(Ordering::Relaxed)
                );
                assert_eq!(chunk.index, count);
                first
            })
            .collect();

        let expected: [u8; 5] = match algorithm {
            DtStreamType::ChaCha8 => [66, 209, 254, 224, 203],
            DtStreamType::ChaCha12 => [200, 202, 12, 60, 234],
            DtStreamType::ChaCha20 => [206, 236, 87, 55, 170],
            DtStreamType::Crc => [108, 99, 114, 196, 213],
        };
        assert_eq!(first_bytes, expected);
    }

    /// A stream started one chunk into the sequence must yield, as its first
    /// chunk, the second chunk of a stream started at offset 0.
    fn run_offset_test(algorithm: DtStreamType) {
        println!("stream offset test");
        let mut from_start = new_test_stream(algorithm, false);
        let chunk_factor = from_start.get_default_chunk_factor();
        from_start.activate(0, chunk_factor).unwrap();

        let mut shifted = new_test_stream(algorithm, false);
        let one_chunk = from_start.get_chunk_size() as u64 * chunk_factor as u64;
        shifted.activate(one_chunk, chunk_factor).unwrap();

        let start_first = from_start.wait_chunk();
        let shifted_first = shifted.wait_chunk();
        assert!(payload(&start_first) != payload(&shifted_first));

        let start_second = from_start.wait_chunk();
        assert!(payload(&start_second) == payload(&shifted_first));
    }

    /// An inverted stream must produce the bitwise complement of the plain one.
    fn run_invert_test(algorithm: DtStreamType) {
        println!("stream invert test");
        let mut plain = new_test_stream(algorithm, false);
        let chunk_factor = plain.get_default_chunk_factor();
        plain.activate(0, chunk_factor).unwrap();

        let mut inverted = new_test_stream(algorithm, true);
        inverted.activate(0, chunk_factor).unwrap();

        let plain_chunk = plain.wait_chunk();
        let inverted_chunk = inverted.wait_chunk();
        let reinverted: Vec<u8> = payload(&inverted_chunk).iter().map(|x| !x).collect();

        assert!(payload(&plain_chunk) != payload(&inverted_chunk));
        assert!(payload(&plain_chunk) == reinverted.as_slice());
    }

    /// Run every stream test against one algorithm.
    fn run_all_tests(algorithm: DtStreamType) {
        run_base_test(algorithm);
        run_offset_test(algorithm);
        run_invert_test(algorithm);
    }

    #[test]
    fn test_chacha8() {
        run_all_tests(DtStreamType::ChaCha8);
    }

    #[test]
    fn test_chacha12() {
        run_all_tests(DtStreamType::ChaCha12);
    }

    #[test]
    fn test_chacha20() {
        run_all_tests(DtStreamType::ChaCha20);
    }

    #[test]
    fn test_crc() {
        run_all_tests(DtStreamType::Crc);
    }
}