#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(feature = "std")]
use std::sync::mpsc;
#[cfg(feature = "std")]
use std::thread;
use lib_q_core::Result;
/// Tuning knobs for the multi-threaded Saturnin helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Requested number of worker threads; `0` means "detect automatically".
    pub thread_count: usize,
    /// Buffers shorter than this many bytes are processed on the calling thread.
    pub min_parallel_size: usize,
    /// Upper bound applied to the thread count, whether requested or detected.
    pub max_threads: usize,
}

impl Default for ParallelConfig {
    /// Auto-detected thread count, a 64 KiB parallelism threshold and at most
    /// 16 threads.
    fn default() -> Self {
        Self {
            thread_count: 0,
            min_parallel_size: 64 * 1024,
            max_threads: 16,
        }
    }
}

impl ParallelConfig {
    /// Same as [`ParallelConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the requested thread count (`0` = auto-detect).
    #[must_use]
    pub fn with_thread_count(mut self, count: usize) -> Self {
        self.thread_count = count;
        self
    }

    /// Sets the minimum buffer size, in bytes, for which threads are used.
    #[must_use]
    pub fn with_min_parallel_size(mut self, size: usize) -> Self {
        self.min_parallel_size = size;
        self
    }

    /// Sets the upper bound on the number of threads.
    #[must_use]
    pub fn with_max_threads(mut self, max: usize) -> Self {
        self.max_threads = max;
        self
    }

    /// Number of threads that will actually be used.
    ///
    /// This is the requested `thread_count`, or the detected parallelism when
    /// `thread_count == 0`, capped at `max_threads`. The result is never below 1,
    /// because the calling thread always does the work even when `max_threads` is 0.
    #[must_use]
    pub fn effective_thread_count(&self) -> usize {
        let requested = if self.thread_count > 0 {
            self.thread_count
        } else {
            Self::detected_parallelism()
        };
        requested.min(self.max_threads).max(1)
    }

    /// Parallelism reported by the OS, falling back to 1 if unavailable or
    /// when built without `std`.
    fn detected_parallelism() -> usize {
        #[cfg(feature = "std")]
        {
            thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1)
        }
        #[cfg(not(feature = "std"))]
        {
            1
        }
    }
}
/// Saturnin block-cipher wrapper that can spread multi-block buffers across
/// several threads.
///
/// Single-block calls delegate straight to the wrapped
/// [`crate::core::SaturninCore`]. The `*_blocks_parallel` methods split a
/// buffer of whole 32-byte blocks into contiguous segments and process each
/// segment on its own scoped thread, as governed by [`ParallelConfig`].
pub struct ParallelSaturninCore {
    base_core: crate::core::SaturninCore,
    config: ParallelConfig,
}

impl ParallelSaturninCore {
    /// Number of bytes in one block handed to `encrypt_block` / `decrypt_block`.
    const BLOCK_SIZE: usize = 32;

    /// Creates a wrapper around `SaturninCore::new(num_rounds, domain)` using the
    /// default [`ParallelConfig`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SaturninCore::new`.
    pub fn new(num_rounds: usize, domain: u8) -> Result<Self> {
        Self::with_config(num_rounds, domain, ParallelConfig::new())
    }

    /// Like [`Self::new`], but with an explicit threading configuration.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SaturninCore::new`.
    pub fn with_config(num_rounds: usize, domain: u8, config: ParallelConfig) -> Result<Self> {
        let base_core = crate::core::SaturninCore::new(num_rounds, domain)?;
        Ok(Self { base_core, config })
    }

    /// Encrypts a single block in place on the calling thread.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped core.
    pub fn encrypt_block(&self, key: &[u8], block: &mut [u8]) -> Result<()> {
        self.base_core.encrypt_block(key, block)
    }

    /// Decrypts a single block in place on the calling thread.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped core.
    pub fn decrypt_block(&self, key: &[u8], block: &mut [u8]) -> Result<()> {
        self.base_core.decrypt_block(key, block)
    }

    /// Encrypts every 32-byte block of `data` in place.
    ///
    /// Worker threads are used when `data` is at least `min_parallel_size` bytes
    /// and the `std` feature is enabled; otherwise blocks are processed on the
    /// calling thread.
    ///
    /// # Errors
    ///
    /// * `InvalidMessageSize` if `data.len()` is not a multiple of 32.
    /// * Any error returned by the underlying block encryption.
    /// * `InvalidAlgorithm` if a worker thread panics.
    ///
    /// On error, some blocks may already have been encrypted, so `data` can be
    /// left partially transformed.
    pub fn encrypt_blocks_parallel(&self, key: &[u8], data: &mut [u8]) -> Result<()> {
        self.process_blocks(key, data, |core, k, block| core.encrypt_block(k, block))
    }

    /// Decrypts every 32-byte block of `data` in place.
    ///
    /// Threading rules, errors and partial-failure behavior are the same as for
    /// [`Self::encrypt_blocks_parallel`].
    pub fn decrypt_blocks_parallel(&self, key: &[u8], data: &mut [u8]) -> Result<()> {
        self.process_blocks(key, data, |core, k, block| core.decrypt_block(k, block))
    }

    /// Threading configuration in use.
    pub fn config(&self) -> &ParallelConfig {
        &self.config
    }

    /// Wrapped single-threaded core.
    pub fn base_core(&self) -> &crate::core::SaturninCore {
        &self.base_core
    }

    /// Shared driver: validates the length, then picks the sequential or
    /// threaded path for the per-block operation `op`.
    fn process_blocks<F>(&self, key: &[u8], data: &mut [u8], op: F) -> Result<()>
    where
        F: Fn(&crate::core::SaturninCore, &[u8], &mut [u8]) -> Result<()> + Sync,
    {
        if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
            return Err(lib_q_core::Error::InvalidMessageSize {
                max: data.len() - (data.len() % Self::BLOCK_SIZE),
                actual: data.len(),
            });
        }
        if data.len() < self.config.min_parallel_size {
            return self.process_blocks_sequential(key, data, &op);
        }
        #[cfg(feature = "std")]
        {
            self.process_blocks_threaded(key, data, &op)
        }
        #[cfg(not(feature = "std"))]
        {
            self.process_blocks_sequential(key, data, &op)
        }
    }

    /// Applies `op` to each block of `data` on the calling thread, stopping at
    /// the first error.
    fn process_blocks_sequential<F>(&self, key: &[u8], data: &mut [u8], op: &F) -> Result<()>
    where
        F: Fn(&crate::core::SaturninCore, &[u8], &mut [u8]) -> Result<()>,
    {
        data.chunks_mut(Self::BLOCK_SIZE)
            .try_for_each(|block| op(&self.base_core, key, block))
    }

    /// Splits `data` into at most `effective_thread_count()` contiguous segments
    /// and runs `op` over each segment in place on a scoped thread.
    #[cfg(feature = "std")]
    fn process_blocks_threaded<F>(&self, key: &[u8], data: &mut [u8], op: &F) -> Result<()>
    where
        F: Fn(&crate::core::SaturninCore, &[u8], &mut [u8]) -> Result<()> + Sync,
    {
        let block_count = data.len() / Self::BLOCK_SIZE;
        // Never start more workers than there are blocks. This also covers an
        // empty buffer, which previously left the collector waiting forever.
        let thread_count = self.config.effective_thread_count().min(block_count);
        if thread_count <= 1 {
            return self.process_blocks_sequential(key, data, op);
        }
        // ceil(blocks / threads) blocks per segment; the last segment may be
        // shorter. Non-zero because block_count >= thread_count >= 2.
        let segment_len = block_count.div_ceil(thread_count) * Self::BLOCK_SIZE;
        let core = &self.base_core;

        // Scoped threads borrow `key` and disjoint slices of `data` directly, so
        // neither the key nor the plaintext is copied into extra heap buffers.
        thread::scope(|scope| {
            let workers: Vec<_> = data
                .chunks_mut(segment_len)
                .map(move |segment| {
                    scope.spawn(move || {
                        segment
                            .chunks_mut(Self::BLOCK_SIZE)
                            .try_for_each(|block| op(core, key, block))
                    })
                })
                .collect();

            // Join every worker (no short-circuit) so a panic becomes an error
            // here rather than re-panicking at scope exit; keep the first error.
            workers
                .into_iter()
                .map(|worker| {
                    worker.join().unwrap_or_else(|_| {
                        Err(lib_q_core::Error::InvalidAlgorithm {
                            algorithm: "Thread communication failed",
                        })
                    })
                })
                .fold(Ok(()), |acc: Result<()>, outcome| acc.and(outcome))
        })
    }
}
/// Saturnin hash wrapper that carries a [`ParallelConfig`] next to a
/// [`crate::hash::SaturninHash`].
///
/// NOTE(review): hashing is currently delegated unchanged to the wrapped
/// hasher; the stored configuration is not consulted by `hash_parallel`.
pub struct ParallelSaturninHash {
    base_hash: crate::hash::SaturninHash,
    config: ParallelConfig,
}

impl ParallelSaturninHash {
    /// Builds a hasher with the default [`ParallelConfig`].
    pub fn new() -> Self {
        Self::with_config(ParallelConfig::default())
    }

    /// Builds a hasher that stores the given configuration.
    pub fn with_config(config: ParallelConfig) -> Self {
        Self {
            base_hash: crate::hash::SaturninHash::new(),
            config,
        }
    }

    /// Hashes `data` by forwarding to the wrapped hasher.
    ///
    /// # Errors
    ///
    /// Propagates any error from `SaturninHash::hash`.
    pub fn hash_parallel(&self, data: &[u8]) -> Result<Vec<u8>> {
        let hasher = &self.base_hash;
        hasher.hash(data)
    }

    /// Stored threading configuration.
    pub fn config(&self) -> &ParallelConfig {
        &self.config
    }

    /// Wrapped single-threaded hasher.
    pub fn base_hash(&self) -> &crate::hash::SaturninHash {
        &self.base_hash
    }
}

impl Default for ParallelSaturninHash {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::vec;

    use super::*;

    /// Config that makes even tiny buffers take the threaded path.
    fn forced_parallel(threads: usize) -> ParallelConfig {
        ParallelConfig::new()
            .with_thread_count(threads)
            .with_min_parallel_size(0)
    }

    #[test]
    fn test_parallel_config_creation() {
        let config = ParallelConfig::new();
        assert_eq!(config.thread_count, 0);
        assert_eq!(config.min_parallel_size, 64 * 1024);
        assert_eq!(config.max_threads, 16);
    }

    #[test]
    fn test_parallel_config_customization() {
        let config = ParallelConfig::new()
            .with_thread_count(4)
            .with_min_parallel_size(32 * 1024)
            .with_max_threads(8);
        assert_eq!(config.thread_count, 4);
        assert_eq!(config.min_parallel_size, 32 * 1024);
        assert_eq!(config.max_threads, 8);
    }

    #[test]
    fn test_effective_thread_count_respects_cap() {
        let capped = ParallelConfig::new().with_thread_count(32).with_max_threads(16);
        assert_eq!(capped.effective_thread_count(), 16);
        let requested = ParallelConfig::new().with_thread_count(4).with_max_threads(8);
        assert_eq!(requested.effective_thread_count(), 4);
    }

    #[test]
    fn test_parallel_core_creation() {
        let core = ParallelSaturninCore::new(16, 7).unwrap();
        assert_eq!(core.base_core().num_rounds(), 16);
        assert_eq!(core.base_core().domain(), 7);
    }

    #[test]
    fn test_parallel_core_single_block() -> Result<()> {
        let core = ParallelSaturninCore::new(16, 7)?;
        let key = [0u8; 32];
        let mut block = [0u8; 32];
        core.encrypt_block(&key, &mut block)?;
        core.decrypt_block(&key, &mut block)?;
        assert_eq!(block, [0u8; 32]);
        Ok(())
    }

    /// Default config: 128 bytes is below the threshold, so this covers the
    /// sequential fallback.
    #[test]
    fn test_parallel_core_multiple_blocks() -> Result<()> {
        let core = ParallelSaturninCore::new(16, 7)?;
        let key = [0x12u8; 32];
        let mut data = vec![0x34u8; 128];
        core.encrypt_blocks_parallel(&key, &mut data)?;
        core.decrypt_blocks_parallel(&key, &mut data)?;
        assert_eq!(data, vec![0x34u8; 128]);
        Ok(())
    }

    /// Forces the threaded path (4 blocks over 4 threads) and checks it matches
    /// block-by-block encryption with the base core.
    #[test]
    fn test_parallel_core_vs_sequential_equivalence() -> Result<()> {
        let parallel_core = ParallelSaturninCore::with_config(16, 7, forced_parallel(4))?;
        let base_core = parallel_core.base_core();
        let key = [0x12u8; 32];
        let mut data1 = vec![0x34u8; 128];
        let mut data2 = data1.clone();
        parallel_core.encrypt_blocks_parallel(&key, &mut data1)?;
        for chunk in data2.chunks_mut(32) {
            base_core.encrypt_block(&key, chunk)?;
        }
        assert_eq!(data1, data2);
        Ok(())
    }

    /// 7 blocks over 4 threads gives segments of 2, 2, 2 and 1 blocks.
    #[test]
    fn test_parallel_core_uneven_partition_roundtrip() -> Result<()> {
        let core = ParallelSaturninCore::with_config(16, 7, forced_parallel(4))?;
        let key = [0x5au8; 32];
        let mut original = vec![0u8; 7 * 32];
        for (byte, value) in original.iter_mut().zip(0u8..) {
            *byte = value;
        }

        let mut expected = original.clone();
        for chunk in expected.chunks_mut(32) {
            core.encrypt_block(&key, chunk)?;
        }

        let mut data = original.clone();
        core.encrypt_blocks_parallel(&key, &mut data)?;
        assert_eq!(data, expected);
        core.decrypt_blocks_parallel(&key, &mut data)?;
        assert_eq!(data, original);
        Ok(())
    }

    #[test]
    fn test_parallel_core_rejects_partial_block() {
        let core = ParallelSaturninCore::new(16, 7).unwrap();
        let key = [0u8; 32];
        let mut data = [0u8; 33];
        assert!(core.encrypt_blocks_parallel(&key, &mut data).is_err());
        assert!(core.decrypt_blocks_parallel(&key, &mut data).is_err());
    }

    #[test]
    fn test_parallel_hash_creation() {
        let hash = ParallelSaturninHash::new();
        assert_eq!(hash.base_hash().output_size(), 32);
    }

    #[test]
    fn test_parallel_hash_operation() -> Result<()> {
        let hash = ParallelSaturninHash::new();
        let data = b"Hello, World!";
        let result = hash.hash_parallel(data)?;
        assert_eq!(result.len(), 32);
        Ok(())
    }
}