#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::boxed::Box;
#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::vec::Vec;
#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::{
format,
vec,
};
use core::sync::atomic::{
AtomicBool,
AtomicUsize,
Ordering,
};
#[cfg(feature = "std")]
use std::boxed::Box;
use std::sync::{
Arc,
OnceLock,
RwLock,
RwLockReadGuard,
RwLockWriteGuard,
};
use std::thread;
use std::time::Duration;
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(feature = "std")]
use std::{
format,
vec,
};
use crate::{
OptimizationLevel,
keccak_p,
};
/// Pins the calling thread to a CPU core derived from `thread_id` and `strategy`.
///
/// Pinning is best effort: nothing happens when the strategy is
/// [`AffinityStrategy::Disabled`], when no cores are reported, or when the
/// selected index has no matching core id, and the outcome of the pin call
/// itself is ignored.
///
/// Core selection:
/// - `Spread` and `Custom` both use `thread_id % core_count`.
/// - `Compact` uses `thread_id % ceil(core_count / 2)`.
#[cfg(all(
    feature = "thread-affinity",
    any(target_os = "linux", target_os = "windows", target_os = "macos")
))]
fn set_thread_affinity(thread_id: usize, strategy: AffinityStrategy) {
    use std::sync::OnceLock;

    // The core count is detected on first use and reused afterwards.
    static CPU_COUNT: OnceLock<usize> = OnceLock::new();

    if let AffinityStrategy::Disabled = strategy {
        return;
    }

    let total_cores = *CPU_COUNT.get_or_init(|| match core_affinity::get_core_ids() {
        Some(ids) => ids.len(),
        None => num_cpus::get(),
    });
    if total_cores == 0 {
        return;
    }

    // Number of cores the thread ids are wrapped around.
    let wrap_at = match strategy {
        AffinityStrategy::Disabled => return,
        AffinityStrategy::Spread | AffinityStrategy::Custom => total_cores,
        AffinityStrategy::Compact => total_cores.div_ceil(2),
    };
    let target_cpu = thread_id % wrap_at;

    let chosen_core =
        core_affinity::get_core_ids().and_then(|ids| ids.get(target_cpu).copied());
    if let Some(core_id) = chosen_core {
        // Failure to pin is not an error for the caller.
        let _ = core_affinity::set_for_current(core_id);
    }
}
/// No-op variant compiled when the `thread-affinity` feature is enabled but the
/// target OS is not Linux, Windows or macOS; both arguments are ignored.
#[cfg(all(
feature = "thread-affinity",
not(any(target_os = "linux", target_os = "windows", target_os = "macos"))
))]
fn set_thread_affinity(_thread_id: usize, _strategy: AffinityStrategy) {}
/// No-op variant compiled when the `thread-affinity` feature is disabled; both
/// arguments are ignored.
#[cfg(not(feature = "thread-affinity"))]
fn set_thread_affinity(_thread_id: usize, _strategy: AffinityStrategy) {
}
/// Policy used by `set_thread_affinity` to choose the core a worker thread is pinned to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AffinityStrategy {
/// Never pin threads.
Disabled,
/// Pin thread `i` to core `i % core_count`.
Spread,
/// Pin thread `i` to core `i % ceil(core_count / 2)`, i.e. only the first half of the cores.
Compact,
/// NOTE(review): currently handled exactly like `Spread` by `set_thread_affinity`;
/// no custom mapping is read anywhere in this file.
Custom,
}
/// Tuning knobs for `CryptoThreadPool`.
#[derive(Debug, Clone)]
pub struct ThreadingConfig {
/// Number of worker threads spawned for a parallel batch; values `<= 1` force
/// sequential processing.
pub num_threads: usize,
/// Batches with fewer states than this are processed sequentially.
pub min_work_size: usize,
/// Upper bound on the number of states a worker claims per chunk.
pub max_work_per_thread: usize,
/// NOTE(review): not read by any code in this file — confirm whether it is
/// consumed elsewhere or is still to be wired in.
pub timeout: Duration,
/// When `false`, worker threads never call `set_thread_affinity`.
pub enable_affinity: bool,
/// Strategy handed to `set_thread_affinity` when `enable_affinity` is `true`.
pub affinity_strategy: AffinityStrategy,
}
impl Default for ThreadingConfig {
fn default() -> Self {
Self {
num_threads: num_cpus::get(),
min_work_size: 1024, max_work_per_thread: 64 * 1024, timeout: Duration::from_secs(30),
enable_affinity: true,
affinity_strategy: AffinityStrategy::Spread,
}
}
}
impl ThreadingConfig {
pub fn security_optimized() -> Self {
Self {
num_threads: 1, min_work_size: usize::MAX, max_work_per_thread: usize::MAX,
timeout: Duration::from_secs(5),
enable_affinity: false,
affinity_strategy: AffinityStrategy::Disabled,
}
}
pub fn performance_optimized() -> Self {
Self {
num_threads: num_cpus::get(),
min_work_size: 512, max_work_per_thread: 32 * 1024, timeout: Duration::from_secs(60),
enable_affinity: true,
affinity_strategy: AffinityStrategy::Spread,
}
}
pub fn balanced() -> Self {
Self {
num_threads: num_cpus::get().div_ceil(2), min_work_size: 2048, max_work_per_thread: 128 * 1024, timeout: Duration::from_secs(30),
enable_affinity: true,
affinity_strategy: AffinityStrategy::Compact,
}
}
}
/// Snapshot returned by `CryptoWorker::get_stats`.
#[derive(Debug, Clone)]
pub struct WorkerStats {
/// Identifier of the worker the snapshot was taken from.
pub worker_id: usize,
/// Value of the work distribution's completed-item counter at snapshot time.
/// NOTE(review): that counter is shared by every worker of a batch, so this is
/// a batch-wide total rather than a per-worker figure.
pub work_items_processed: usize,
}
/// Lock-free bookkeeping that hands out disjoint index ranges of a batch to
/// workers and counts how many items have been written back.
#[derive(Debug)]
struct WorkDistribution {
    /// Number of items in the batch; chunks never extend past this index.
    total_items: usize,
    /// Index of the first item that has not been handed out yet.
    /// Always stays within `0..=total_items`.
    current_position: AtomicUsize,
    /// Flag raised by `mark_completed`.
    completed: AtomicBool,
    /// Running total reported through `increment_completed`.
    completed_count: AtomicUsize,
}

impl WorkDistribution {
    /// Creates the bookkeeping for a batch of `total_items` items with nothing
    /// handed out and nothing completed.
    fn new(total_items: usize) -> Self {
        Self {
            total_items,
            current_position: AtomicUsize::new(0),
            completed: AtomicBool::new(false),
            completed_count: AtomicUsize::new(0),
        }
    }

    /// Claims the next `[start, end)` range of at most `chunk_size` items.
    ///
    /// Returns `None` once every item has been handed out, and also for
    /// `chunk_size == 0` (an empty chunk would never advance the cursor, so a
    /// caller looping on `Some` would spin forever).
    ///
    /// The cursor is advanced with a compare-and-swap loop that clamps at
    /// `total_items`. Unlike an unconditional `fetch_add`, it can neither
    /// overflow for huge chunk sizes nor wrap around and hand out an
    /// already-claimed range after many exhausted calls.
    fn get_next_chunk(&self, chunk_size: usize) -> Option<(usize, usize)> {
        if chunk_size == 0 {
            return None;
        }
        let total = self.total_items;
        let start = self
            .current_position
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |pos| {
                (pos < total).then(|| pos.saturating_add(chunk_size).min(total))
            })
            .ok()?;
        Some((start, start.saturating_add(chunk_size).min(total)))
    }

    /// Raises the completion flag.
    fn mark_completed(&self) {
        self.completed.store(true, Ordering::Release);
    }

    /// Adds `count` to the number of items reported as written back.
    fn increment_completed(&self, count: usize) {
        self.completed_count.fetch_add(count, Ordering::AcqRel);
    }

    /// Returns `true` when at least `total_items` items have been reported
    /// through `increment_completed`.
    fn is_all_work_completed(&self) -> bool {
        self.completed_count.load(Ordering::Acquire) >= self.total_items
    }

    /// Returns whether `mark_completed` has been called.
    // Not called by the non-test code visible in this file.
    #[allow(dead_code)]
    fn is_completed(&self) -> bool {
        self.completed.load(Ordering::Acquire)
    }
}
/// One worker of a parallel batch: claims chunks from the shared
/// `WorkDistribution`, permutes them and writes them into the shared result buffer.
#[derive(Debug)]
struct CryptoWorker {
// Identifier reported by `get_worker_id` and `get_stats`.
#[allow(dead_code)] id: usize,
// Chunk dispenser and completion counter shared with the other workers.
work_dist: Arc<WorkDistribution>,
// Output buffer shared with the other workers; written at the same indices
// as the input states.
results: Arc<RwLock<Vec<[u64; 25]>>>,
// Copy of the pool configuration (chunk sizing, affinity settings).
config: ThreadingConfig,
}
/// Takes the write lock on the shared result buffer.
///
/// A poisoned lock (a previous holder panicked) is not treated as an error:
/// the guard is recovered from the poison error so the buffer stays writable.
fn acquire_results_write<'a>(
    results: &'a RwLock<Vec<[u64; 25]>>,
) -> RwLockWriteGuard<'a, Vec<[u64; 25]>> {
    match results.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Takes the read lock on the shared result buffer.
///
/// A poisoned lock (a previous holder panicked) is not treated as an error:
/// the guard is recovered from the poison error so the buffer stays readable.
fn acquire_results_read<'a>(
    results: &'a RwLock<Vec<[u64; 25]>>,
) -> RwLockReadGuard<'a, Vec<[u64; 25]>> {
    match results.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
impl CryptoWorker {
    /// Returns this worker's id together with the current value of the shared
    /// completed-item counter.
    ///
    /// NOTE(review): the counter lives in the `WorkDistribution` shared by all
    /// workers of a batch, so `work_items_processed` is the batch-wide total,
    /// not the number of items handled by this particular worker.
    // Not called by the non-test code visible in this file.
    #[allow(dead_code)]
    pub fn get_stats(&self) -> WorkerStats {
        WorkerStats {
            worker_id: self.id,
            work_items_processed: self.work_dist.completed_count.load(Ordering::Acquire),
        }
    }

    /// Returns the identifier this worker was created with.
    // Not called by the non-test code visible in this file.
    #[allow(dead_code)]
    pub fn get_worker_id(&self) -> usize {
        self.id
    }

    /// Bundles the shared batch state into a worker.
    fn new(
        id: usize,
        work_dist: Arc<WorkDistribution>,
        results: Arc<RwLock<Vec<[u64; 25]>>>,
        config: ThreadingConfig,
    ) -> Self {
        Self {
            id,
            work_dist,
            results,
            config,
        }
    }

    /// Repeatedly claims a chunk of `states`, permutes every state in it and
    /// writes the outputs to the same indices of the shared result buffer,
    /// until the work distribution has no chunks left.
    ///
    /// The chunk size is the smaller of `max_work_per_thread` and an even
    /// share of the batch (`states.len() / num_threads`), but never less than
    /// one item.
    fn process_keccak_parallel(&self, states: &[[u64; 25]], level: OptimizationLevel) {
        // Both clamps matter: a batch smaller than the thread count (or a zero
        // `max_work_per_thread`) would otherwise produce a chunk size of zero,
        // which never advances the work cursor, and a zero thread count would
        // divide by zero.
        let fair_share = states.len() / self.config.num_threads.max(1);
        let chunk_size = self.config.max_work_per_thread.min(fair_share).max(1);

        while let Some((start, end)) = self.work_dist.get_next_chunk(chunk_size) {
            // Defensive clamp: never read past the input, even if the
            // distribution was sized for a different batch.
            let chunk = states.get(start..end.min(states.len())).unwrap_or(&[]);

            // Permute outside the lock so workers only serialize on the copy.
            let local_results: Vec<[u64; 25]> = chunk
                .iter()
                .map(|input| {
                    let mut state = *input;
                    self.apply_keccak_optimization(&mut state, level);
                    state
                })
                .collect();

            let mut results_guard = acquire_results_write(self.results.as_ref());
            let written = match results_guard.get_mut(start..) {
                Some(slots) => {
                    // Only as many outputs as fit into the result buffer.
                    let count = slots.len().min(local_results.len());
                    slots[..count].copy_from_slice(&local_results[..count]);
                    count
                }
                None => 0,
            };
            drop(results_guard);

            if written > 0 {
                self.work_dist.increment_completed(written);
            }
        }
    }

    /// Applies the permutation selected by `level` to `state` in place:
    /// `keccak_p(state, 24)` for `Reference`, `crate::f1600` for every other
    /// level.
    fn apply_keccak_optimization(&self, state: &mut [u64; 25], level: OptimizationLevel) {
        match level {
            OptimizationLevel::Reference => {
                keccak_p(state, 24);
            }
            OptimizationLevel::Basic | OptimizationLevel::Advanced | OptimizationLevel::Maximum => {
                crate::f1600(state)
            }
        }
    }
}
/// Entry point for batch Keccak processing; spawns worker threads per call
/// according to its configuration.
#[derive(Debug)]
pub struct CryptoThreadPool {
// Settings applied to every batch processed by this pool.
config: ThreadingConfig,
// Set by `shutdown()` and when joining a worker thread fails.
// NOTE(review): nothing in this file reads the flag.
shutdown: Arc<AtomicBool>,
}
impl CryptoThreadPool {
pub fn new(config: ThreadingConfig) -> Self {
Self {
config,
shutdown: Arc::new(AtomicBool::new(false)),
}
}
pub fn process_keccak_states(
&self,
states: &[[u64; 25]],
level: OptimizationLevel,
) -> Result<Vec<[u64; 25]>, Box<dyn std::error::Error + Send + Sync>> {
if states.len() < self.config.min_work_size || self.config.num_threads <= 1 {
return self.process_sequential(states, level);
}
let work_dist = Arc::new(WorkDistribution::new(states.len()));
let results = Arc::new(RwLock::new(vec![[0u64; 25]; states.len()]));
let shutdown = Arc::clone(&self.shutdown);
let mut handles = Vec::new();
for thread_id in 0..self.config.num_threads {
let worker = CryptoWorker::new(
thread_id,
Arc::clone(&work_dist),
Arc::clone(&results),
self.config.clone(),
);
let states_clone = states.to_vec();
let handle = thread::spawn(move || {
if worker.config.enable_affinity {
set_thread_affinity(thread_id, worker.config.affinity_strategy);
}
worker.process_keccak_parallel(&states_clone, level);
});
handles.push(handle);
}
for handle in handles {
if let Err(e) = handle.join() {
shutdown.store(true, Ordering::Release);
return Err(format!("Thread join error: {:?}", e).into());
}
}
work_dist.mark_completed();
let max_retries = 100; let mut retries = 0;
while !work_dist.is_all_work_completed() && retries < max_retries {
thread::yield_now();
retries += 1;
if retries % 10 == 0 {
let completed = work_dist.completed_count.load(Ordering::Acquire);
if completed >= work_dist.total_items {
break;
}
}
}
if !work_dist.is_all_work_completed() {
let completed = work_dist.completed_count.load(Ordering::Acquire);
return Err(format!(
"Incomplete processing after timeout: {} of {} items completed",
completed, work_dist.total_items
)
.into());
}
let results_guard = acquire_results_read(results.as_ref());
Ok(results_guard.clone())
}
fn process_sequential(
&self,
states: &[[u64; 25]],
level: OptimizationLevel,
) -> Result<Vec<[u64; 25]>, Box<dyn std::error::Error + Send + Sync>> {
let mut results = Vec::with_capacity(states.len());
for state in states {
let mut result_state = *state;
match level {
OptimizationLevel::Reference => {
keccak_p(&mut result_state, 24);
}
OptimizationLevel::Basic => {
#[cfg(all(
target_arch = "x86_64",
feature = "asm",
target_feature = "avx2",
not(cross_compile)
))]
unsafe {
crate::x86::p1600_avx2(&mut result_state);
}
#[cfg(not(all(
target_arch = "x86_64",
target_feature = "avx2",
not(cross_compile)
)))]
{
keccak_p(&mut result_state, 24);
}
}
OptimizationLevel::Advanced => {
#[cfg(all(
target_arch = "x86_64",
feature = "asm",
target_feature = "avx2",
not(cross_compile)
))]
unsafe {
crate::x86::p1600_avx2(&mut result_state);
}
#[cfg(not(all(
target_arch = "x86_64",
target_feature = "avx2",
not(cross_compile)
)))]
{
keccak_p(&mut result_state, 24);
}
}
OptimizationLevel::Maximum => {
#[cfg(all(
target_arch = "x86_64",
feature = "asm",
target_feature = "avx512f"
))]
unsafe {
crate::x86::p1600_avx512(&mut result_state);
}
#[cfg(all(
target_arch = "x86_64",
feature = "asm",
target_feature = "avx2",
not(target_feature = "avx512f"),
not(cross_compile)
))]
unsafe {
crate::x86::p1600_avx2(&mut result_state);
}
#[cfg(not(all(
target_arch = "x86_64",
any(target_feature = "avx2", target_feature = "avx512f")
)))]
{
keccak_p(&mut result_state, 24);
}
}
}
results.push(result_state);
}
Ok(results)
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
}
}
/// Process-wide pool instance; populated at most once by `init_global_thread_pool`.
static GLOBAL_THREAD_POOL: OnceLock<Arc<CryptoThreadPool>> = OnceLock::new();
/// Installs the process-wide pool built from `config`.
///
/// Only the first call has an effect; if a global pool already exists, the
/// existing pool is kept and `config` is dropped unused.
pub fn init_global_thread_pool(config: ThreadingConfig) {
    let _ = GLOBAL_THREAD_POOL.get_or_init(move || {
        let pool = CryptoThreadPool::new(config);
        Arc::new(pool)
    });
}
/// Returns a shared handle to the process-wide pool, or `None` when
/// `init_global_thread_pool` has not been called yet.
pub fn get_global_thread_pool() -> Option<Arc<CryptoThreadPool>> {
    GLOBAL_THREAD_POOL.get().map(Arc::clone)
}
/// Processes `states` with the process-wide pool when one has been
/// initialized, and with a temporary pool using `ThreadingConfig::default()`
/// otherwise.
///
/// # Errors
///
/// Propagates any error from `CryptoThreadPool::process_keccak_states`.
pub fn process_keccak_states_global(
    states: &[[u64; 25]],
    level: OptimizationLevel,
) -> Result<Vec<[u64; 25]>, Box<dyn std::error::Error + Send + Sync>> {
    match get_global_thread_pool() {
        Some(shared_pool) => shared_pool.process_keccak_states(states, level),
        None => {
            // No global pool yet: use a throwaway pool for this call only.
            let fallback_pool = CryptoThreadPool::new(ThreadingConfig::default());
            fallback_pool.process_keccak_states(states, level)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is usable and enables `Spread` affinity.
    #[test]
    #[cfg(feature = "std")]
    fn test_threading_config_defaults() {
        let config = ThreadingConfig::default();
        assert!(config.num_threads > 0);
        assert!(config.min_work_size > 0);
        assert!(config.max_work_per_thread > 0);
        assert_eq!(config.affinity_strategy, AffinityStrategy::Spread);
        assert!(config.enable_affinity);
    }

    /// The security preset is single-threaded with affinity disabled.
    #[test]
    #[cfg(feature = "std")]
    fn test_threading_config_security_optimized() {
        let config = ThreadingConfig::security_optimized();
        assert_eq!(config.num_threads, 1);
        assert_eq!(config.min_work_size, usize::MAX);
        assert_eq!(config.affinity_strategy, AffinityStrategy::Disabled);
        assert!(!config.enable_affinity);
    }

    /// The performance preset allows parallel work and uses `Spread` affinity.
    #[test]
    #[cfg(feature = "std")]
    fn test_threading_config_performance_optimized() {
        let config = ThreadingConfig::performance_optimized();
        assert!(config.num_threads > 0);
        assert!(config.min_work_size < usize::MAX);
        assert_eq!(config.affinity_strategy, AffinityStrategy::Spread);
        assert!(config.enable_affinity);
    }

    /// The balanced preset uses `Compact` affinity.
    #[test]
    #[cfg(feature = "std")]
    fn test_threading_config_balanced() {
        let config = ThreadingConfig::balanced();
        assert!(config.num_threads > 0);
        assert!(config.min_work_size > 0);
        assert_eq!(config.affinity_strategy, AffinityStrategy::Compact);
        assert!(config.enable_affinity);
    }

    /// Chunks are handed out back to back and the completion flag is sticky.
    #[test]
    #[cfg(feature = "std")]
    fn test_work_distribution() {
        let work_dist = WorkDistribution::new(100);
        assert!(!work_dist.is_completed());
        let chunk1 = work_dist.get_next_chunk(25);
        assert_eq!(chunk1, Some((0, 25)));
        let chunk2 = work_dist.get_next_chunk(25);
        assert_eq!(chunk2, Some((25, 50)));
        work_dist.mark_completed();
        assert!(work_dist.is_completed());
    }

    /// A panic while holding the write lock must not make the buffer unusable.
    #[test]
    #[cfg(feature = "std")]
    fn poisoned_results_lock_still_persists_writes() {
        let results = Arc::new(RwLock::new(vec![[0u64; 25]; 2]));
        let results_for_panic = Arc::clone(&results);
        let panicker = thread::spawn(move || {
            let _guard = results_for_panic
                .write()
                .expect("lock results buffer for poison test");
            panic!("intentional test panic while holding write lock");
        });
        assert!(panicker.join().is_err());
        {
            let mut guard = acquire_results_write(results.as_ref());
            guard[1] = [42u64; 25];
        }
        let guard = acquire_results_read(results.as_ref());
        assert_eq!(guard[1], [42u64; 25]);
    }

    /// A fresh worker reports its id and a zero progress counter.
    #[test]
    #[cfg(feature = "std")]
    fn test_worker_id_and_stats() {
        let work_dist = Arc::new(WorkDistribution::new(10));
        let results = Arc::new(RwLock::new(vec![[0u64; 25]; 10]));
        let config = ThreadingConfig::default();
        let worker = CryptoWorker::new(42, Arc::clone(&work_dist), Arc::clone(&results), config);
        assert_eq!(worker.get_worker_id(), 42);
        let stats = worker.get_stats();
        assert_eq!(stats.worker_id, 42);
        assert_eq!(stats.work_items_processed, 0);
    }

    /// The single-threaded path returns one changed state per input.
    #[test]
    #[cfg(feature = "std")]
    fn test_sequential_processing() {
        let config = ThreadingConfig::security_optimized();
        let pool = CryptoThreadPool::new(config);
        let states = vec![[0u64; 25], [1u64; 25], [2u64; 25]];
        let results = pool
            .process_keccak_states(&states, OptimizationLevel::Reference)
            .expect("Failed to process Keccak states in thread pool");
        assert_eq!(results.len(), states.len());
        for (original, result) in states.iter().zip(results.iter()) {
            assert_ne!(original, result);
        }
    }

    /// The multi-threaded path must produce exactly the sequential results,
    /// in input order. The other pool tests use batches below `min_work_size`
    /// (or a single thread), so this is the only test that spawns workers.
    #[test]
    #[cfg(feature = "std")]
    fn test_parallel_matches_sequential() {
        let parallel_pool = CryptoThreadPool::new(ThreadingConfig {
            num_threads: 4,
            min_work_size: 1,
            max_work_per_thread: 8,
            timeout: Duration::from_secs(5),
            enable_affinity: false,
            affinity_strategy: AffinityStrategy::Disabled,
        });
        let sequential_pool = CryptoThreadPool::new(ThreadingConfig::security_optimized());
        let states: Vec<[u64; 25]> = (0..64u64).map(|seed| [seed; 25]).collect();

        let parallel = parallel_pool
            .process_keccak_states(&states, OptimizationLevel::Reference)
            .expect("parallel processing should succeed");
        let sequential = sequential_pool
            .process_keccak_states(&states, OptimizationLevel::Reference)
            .expect("sequential processing should succeed");

        assert_eq!(parallel.len(), states.len());
        assert_eq!(parallel, sequential);
    }

    /// The global pool can be initialized and used through the free function.
    #[test]
    #[cfg(feature = "std")]
    fn test_global_thread_pool() {
        let config = ThreadingConfig::balanced();
        init_global_thread_pool(config);
        let pool = get_global_thread_pool();
        assert!(pool.is_some());
        let states = vec![[0u64; 25]; 10];
        let results = process_keccak_states_global(&states, OptimizationLevel::Reference)
            .expect("Failed to process Keccak states with global thread pool");
        assert_eq!(results.len(), states.len());
    }
}