#[cfg(feature = "std")]
extern crate std;
#[cfg(all(test, feature = "std"))]
use core::{
assert_eq,
assert_ne,
};
use crate::keccak_p;
/// Implementation tier requested from the dispatch functions in this module.
///
/// Whether a tier is actually usable is decided at compile time from the
/// target architecture, target features and crate features; see
/// [`OptimizationLevel::is_available`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// Portable code path; usable in every build.
    Reference,
    /// Usable on x86_64 AVX2 builds and on AArch64 SHA3 builds.
    Basic,
    /// Usable on x86_64 AVX2 builds.
    Advanced,
    /// Usable on x86_64 AVX-512F builds.
    Maximum,
}

impl OptimizationLevel {
    /// x86_64 build with the `asm` feature and AVX2 enabled, not cross-compiled.
    const X86_AVX2: bool = cfg!(all(
        target_arch = "x86_64",
        feature = "asm",
        target_feature = "avx2",
        not(cross_compile)
    ));

    /// x86_64 build with the `asm` feature and AVX-512F enabled, not cross-compiled.
    const X86_AVX512: bool = cfg!(all(
        target_arch = "x86_64",
        feature = "asm",
        target_feature = "avx512f",
        not(cross_compile)
    ));

    /// Non-Windows AArch64 build with `asm`, `arm64_sha3`, `std` and the
    /// `sha3` target feature enabled.
    const ARM_SHA3: bool = cfg!(all(
        target_arch = "aarch64",
        feature = "asm",
        feature = "arm64_sha3",
        target_feature = "sha3",
        feature = "std",
        not(target_os = "windows")
    ));

    /// Returns the highest tier whose compile-time requirements are met,
    /// preferring AVX-512F, then AVX2, then AArch64 SHA3, and finally the
    /// portable reference tier.
    pub fn best_available() -> Self {
        match (Self::X86_AVX512, Self::X86_AVX2, Self::ARM_SHA3) {
            (true, _, _) => Self::Maximum,
            (false, true, _) => Self::Advanced,
            (false, false, true) => Self::Basic,
            (false, false, false) => Self::Reference,
        }
    }

    /// Reports whether this tier's compile-time requirements are met.
    ///
    /// `Reference` is always available.
    pub fn is_available(self) -> bool {
        match self {
            Self::Reference => true,
            Self::Basic => Self::X86_AVX2 || Self::ARM_SHA3,
            Self::Advanced => Self::X86_AVX2,
            Self::Maximum => Self::X86_AVX512,
        }
    }
}
/// Permutes `state` in place using the implementation tier named by `level`.
///
/// Exactly one permutation routine runs per call:
///
/// * `Reference` calls `keccak_p(state, 24)`.
/// * `Basic` uses the AArch64 SHA3 assembly routine or the x86_64 AVX2
///   routine, whichever `cfg` gate is satisfied.
/// * `Advanced` uses the x86_64 AVX2 routine.
/// * `Maximum` uses the x86_64 AVX-512 routine, or the AVX2 routine when
///   AVX-512F is not enabled.
///
/// If the gate for the requested tier is not satisfied at compile time, the
/// call falls back to `keccak_p(state, 24)`.
pub fn p1600_optimized(state: &mut [u64; 25], level: OptimizationLevel) {
    match level {
        OptimizationLevel::Reference => {
            keccak_p(state, 24);
        }
        OptimizationLevel::Basic => {
            // The AArch64 and x86_64 gates are mutually exclusive through
            // `target_arch`, so at most one of the first two blocks exists.
            #[cfg(all(
                target_arch = "aarch64",
                feature = "asm",
                feature = "arm64_sha3",
                target_feature = "sha3",
                feature = "std",
                not(target_os = "windows")
            ))]
            {
                // SAFETY: compiled only when the `sha3` target feature is
                // statically enabled. NOTE(review): confirm the routine has no
                // further requirements.
                unsafe { crate::armv8::p1600_armv8_sha3_asm(state, 24) };
            }
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when the `avx2` target feature is
                // statically enabled. NOTE(review): confirm the routine has no
                // further requirements.
                unsafe { crate::x86::p1600_avx2(state) };
            }
            #[cfg(not(any(
                all(
                    target_arch = "aarch64",
                    feature = "asm",
                    feature = "arm64_sha3",
                    target_feature = "sha3",
                    feature = "std",
                    not(target_os = "windows")
                ),
                all(
                    target_arch = "x86_64",
                    feature = "asm",
                    target_feature = "avx2",
                    not(cross_compile)
                )
            )))]
            {
                keccak_p(state, 24);
            }
        }
        OptimizationLevel::Advanced => {
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when the `avx2` target feature is
                // statically enabled. NOTE(review): confirm the routine has no
                // further requirements.
                unsafe { crate::x86::p1600_avx2(state) };
            }
            #[cfg(not(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            )))]
            {
                keccak_p(state, 24);
            }
        }
        OptimizationLevel::Maximum => {
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx512f",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when the `avx512f` target feature is
                // statically enabled. NOTE(review): confirm the routine has no
                // further requirements.
                unsafe { crate::x86::p1600_avx512(state) };
            }
            // AVX2 is only the fallback when AVX-512F is absent. Builds with
            // AVX-512F normally have AVX2 enabled as well, so without the
            // `not(avx512f)` guard both blocks would be compiled in and the
            // state would be permuted twice.
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(target_feature = "avx512f"),
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when the `avx2` target feature is
                // statically enabled. NOTE(review): confirm the routine has no
                // further requirements.
                unsafe { crate::x86::p1600_avx2(state) };
            }
            #[cfg(not(all(
                target_arch = "x86_64",
                feature = "asm",
                any(target_feature = "avx2", target_feature = "avx512f"),
                not(cross_compile)
            )))]
            {
                keccak_p(state, 24);
            }
        }
    }
}
/// Absorbs `data` into `state` with the routine selected by `level` and
/// returns that routine's result.
///
/// `Reference` always uses `fast_loop_absorb_reference`, which returns the
/// number of bytes it consumed. The other tiers use the x86_64 AVX2 or
/// AVX-512 routines when their `cfg` gates are satisfied and otherwise fall
/// back to the reference routine.
///
/// NOTE(review): the integer passed to the x86 routines (1, 4 or 8) is
/// forwarded unchanged; its meaning is defined in `crate::x86`.
pub fn fast_loop_absorb_optimized(
    state: &mut [u64; 25],
    data: &[u8],
    level: OptimizationLevel,
) -> usize {
    // In every arm exactly one `cfg` block survives compilation, and that
    // block is the arm's value.
    match level {
        OptimizationLevel::Reference => fast_loop_absorb_reference(state, data),
        OptimizationLevel::Basic => {
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when `avx2` is statically enabled.
                // NOTE(review): confirm the routine has no further requirements.
                unsafe { crate::x86::fast_loop_absorb_avx2(state, 1, data) }
            }
            #[cfg(not(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            )))]
            {
                fast_loop_absorb_reference(state, data)
            }
        }
        OptimizationLevel::Advanced => {
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when `avx2` is statically enabled.
                // NOTE(review): confirm the routine has no further requirements.
                unsafe { crate::x86::fast_loop_absorb_avx2(state, 4, data) }
            }
            #[cfg(not(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile)
            )))]
            {
                fast_loop_absorb_reference(state, data)
            }
        }
        OptimizationLevel::Maximum => {
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx512f",
                not(cross_compile)
            ))]
            {
                // SAFETY: compiled only when `avx512f` is statically enabled.
                // NOTE(review): confirm the routine has no further requirements.
                unsafe { crate::x86::fast_loop_absorb_avx512(state, 8, data) }
            }
            #[cfg(all(
                target_arch = "x86_64",
                feature = "asm",
                target_feature = "avx2",
                not(cross_compile),
                not(target_feature = "avx512f")
            ))]
            {
                // SAFETY: compiled only when `avx2` is statically enabled.
                // NOTE(review): confirm the routine has no further requirements.
                unsafe { crate::x86::fast_loop_absorb_avx2(state, 4, data) }
            }
            #[cfg(not(all(
                target_arch = "x86_64",
                feature = "asm",
                any(target_feature = "avx2", target_feature = "avx512f"),
                not(cross_compile)
            )))]
            {
                fast_loop_absorb_reference(state, data)
            }
        }
    }
}
/// Portable absorb loop.
///
/// Each complete 8-byte group of `data` is read as a little-endian `u64`,
/// XORed into `state[0]`, and followed by `keccak_p(state, 24)`. Trailing
/// bytes that do not fill a whole group are left untouched.
///
/// Returns the number of bytes consumed, which is always a multiple of 8.
fn fast_loop_absorb_reference(state: &mut [u64; 25], data: &[u8]) -> usize {
    const LANE_BYTES: usize = size_of::<u64>();

    let mut consumed = 0;
    for lane in data.chunks_exact(LANE_BYTES) {
        let mut word = [0u8; LANE_BYTES];
        word.copy_from_slice(lane);
        state[0] ^= u64::from_le_bytes(word);
        keccak_p(state, 24);
        consumed += LANE_BYTES;
    }
    consumed
}
#[cfg(feature = "simd")]
pub mod parallel {
#[cfg(keccak_portable_simd)]
mod batch {
use super::super::*;
use crate::advanced_simd;
/// Permutes every state in `states` in place, batching states through the
/// `advanced_simd::parallel` kernels where `level` allows it.
///
/// * `Reference` - every state goes through `keccak_p(state, 24)`.
/// * `Basic` - states are taken two at a time (2x kernel).
/// * `Advanced` - states are taken four at a time (4x kernel).
/// * `Maximum` - states are taken eight at a time (8x kernel).
///
/// A trailing group shorter than the tier's width is split greedily into the
/// next narrower kernels (4, then 2), with any last single state handled by
/// `keccak_p`.
///
/// The kernels are handed array references that point into `states` itself.
/// Passing a freshly built array literal instead would permute copies and
/// drop the results.
pub fn p1600_parallel(states: &mut [[u64; 25]], level: OptimizationLevel) {
    let width = match level {
        OptimizationLevel::Reference => 1,
        OptimizationLevel::Basic => 2,
        OptimizationLevel::Advanced => 4,
        OptimizationLevel::Maximum => 8,
    };
    for group in states.chunks_mut(width) {
        permute_group_in_place(group);
    }
}

/// Permutes one group of at most eight states in place, always using the
/// widest kernel that still fits the remaining states.
fn permute_group_in_place(group: &mut [[u64; 25]]) {
    let mut rest = group;
    while !rest.is_empty() {
        // Move the slice out of `rest` so the tail returned by the split can
        // be stored back into it without overlapping borrows.
        let current = core::mem::take(&mut rest);
        rest = if current.len() >= 8 {
            let (lanes, tail) = current
                .split_first_chunk_mut::<8>()
                .expect("at least eight states remain");
            advanced_simd::parallel::p1600_parallel_8x(lanes);
            tail
        } else if current.len() >= 4 {
            let (lanes, tail) = current
                .split_first_chunk_mut::<4>()
                .expect("at least four states remain");
            advanced_simd::parallel::p1600_parallel_4x(lanes);
            tail
        } else if current.len() >= 2 {
            let (lanes, tail) = current
                .split_first_chunk_mut::<2>()
                .expect("at least two states remain");
            advanced_simd::parallel::p1600_parallel_2x(lanes);
            tail
        } else {
            let (single, tail) = current
                .split_first_mut()
                .expect("loop guard ensures a state remains");
            keccak_p(single, 24);
            tail
        };
    }
}
pub fn fast_loop_absorb_parallel(
states: &mut [[u64; 25]],
data: &[u8],
level: OptimizationLevel,
) -> usize {
match level {
OptimizationLevel::Reference => {
let mut min_offset = usize::MAX;
for state in states.iter_mut() {
let offset = fast_loop_absorb_reference(state, data);
min_offset = min_offset.min(offset);
}
min_offset
}
OptimizationLevel::Basic => {
let mut min_offset = usize::MAX;
for chunk in states.chunks_mut(2) {
if chunk.len() == 2 {
let offset =
advanced_simd::fast_loop_absorb_advanced(&mut chunk[0], data, 2);
min_offset = min_offset.min(offset);
} else {
let offset = fast_loop_absorb_reference(&mut chunk[0], data);
min_offset = min_offset.min(offset);
}
}
min_offset
}
OptimizationLevel::Advanced => {
let mut min_offset = usize::MAX;
for chunk in states.chunks_mut(4) {
match chunk.len() {
4 => {
let offset = advanced_simd::fast_loop_absorb_advanced(
&mut chunk[0],
data,
4,
);
min_offset = min_offset.min(offset);
}
_ => {
for state in chunk.iter_mut() {
let offset = fast_loop_absorb_reference(state, data);
min_offset = min_offset.min(offset);
}
}
}
}
min_offset
}
OptimizationLevel::Maximum => {
let mut min_offset = usize::MAX;
for chunk in states.chunks_mut(8) {
match chunk.len() {
8 => {
let offset = advanced_simd::fast_loop_absorb_advanced(
&mut chunk[0],
data,
8,
);
min_offset = min_offset.min(offset);
}
_ => {
for state in chunk.iter_mut() {
let offset = fast_loop_absorb_reference(state, data);
min_offset = min_offset.min(offset);
}
}
}
}
min_offset
}
}
}
}
#[cfg(keccak_portable_simd)]
pub use batch::{
fast_loop_absorb_parallel,
p1600_parallel,
};
#[cfg(feature = "multithreading")]
use super::{
OptimizationLevel,
keccak_p,
};
/// Permutes copies of `states` and returns them in a new vector (`std` build).
///
/// The work is first handed to
/// `crate::multithreading::process_keccak_states_global`. If that call
/// reports an error, the error is discarded and each state is instead copied
/// and permuted sequentially with `keccak_p(state, 24)`; this fallback does
/// not consult `level`.
///
/// As written, this function always returns `Ok`.
#[cfg(all(feature = "multithreading", feature = "std"))]
pub fn p1600_multithreaded(
    states: &[[u64; 25]],
    level: OptimizationLevel,
) -> Result<Vec<[u64; 25]>, Box<dyn std::error::Error + Send + Sync>> {
    use crate::multithreading::process_keccak_states_global;

    match process_keccak_states_global(states, level) {
        Ok(permuted) => Ok(permuted),
        // Best-effort fallback: permute sequentially on the calling thread.
        Err(_) => Ok(states
            .iter()
            .map(|original| {
                let mut permuted = *original;
                keccak_p(&mut permuted, 24);
                permuted
            })
            .collect()),
    }
}
/// Permutes copies of `states` and returns them in a new vector
/// (`alloc`-only build, no `std`).
///
/// The work is first handed to
/// `crate::multithreading::process_keccak_states_global`. If that call
/// reports an error, the error is discarded and each state is copied and
/// permuted sequentially with the routine matching `level`.
///
/// In the fallback, every tier's `cfg` gates are exact complements of each
/// other, so each state is permuted exactly once in every build
/// configuration. The gates use the same
/// `x86_64` + `asm` + target-feature + `not(cross_compile)` predicate as the
/// rest of this file.
///
/// As written, this function always returns `Ok`.
#[cfg(all(feature = "multithreading", not(feature = "std"), feature = "alloc"))]
pub fn p1600_multithreaded(
    states: &[[u64; 25]],
    level: OptimizationLevel,
) -> Result<alloc::vec::Vec<[u64; 25]>, alloc::boxed::Box<dyn core::error::Error + Send + Sync>>
{
    extern crate alloc;
    use crate::multithreading::process_keccak_states_global;

    if let Ok(results) = process_keccak_states_global(states, level) {
        return Ok(results);
    }

    // Best-effort fallback: permute sequentially on the calling thread.
    let mut results = alloc::vec::Vec::with_capacity(states.len());
    for state in states {
        let mut result_state = *state;
        match level {
            OptimizationLevel::Reference => {
                keccak_p(&mut result_state, 24);
            }
            // Both tiers map to the AVX2 routine on x86_64.
            OptimizationLevel::Basic | OptimizationLevel::Advanced => {
                #[cfg(all(
                    target_arch = "x86_64",
                    feature = "asm",
                    target_feature = "avx2",
                    not(cross_compile)
                ))]
                {
                    // SAFETY: compiled only when `avx2` is statically enabled.
                    // NOTE(review): confirm the routine has no further
                    // requirements.
                    unsafe { crate::x86::p1600_avx2(&mut result_state) };
                }
                #[cfg(not(all(
                    target_arch = "x86_64",
                    feature = "asm",
                    target_feature = "avx2",
                    not(cross_compile)
                )))]
                {
                    keccak_p(&mut result_state, 24);
                }
            }
            OptimizationLevel::Maximum => {
                #[cfg(all(
                    target_arch = "x86_64",
                    feature = "asm",
                    target_feature = "avx512f",
                    not(cross_compile)
                ))]
                {
                    // SAFETY: compiled only when `avx512f` is statically
                    // enabled. NOTE(review): confirm the routine has no
                    // further requirements.
                    unsafe { crate::x86::p1600_avx512(&mut result_state) };
                }
                #[cfg(all(
                    target_arch = "x86_64",
                    feature = "asm",
                    target_feature = "avx2",
                    not(target_feature = "avx512f"),
                    not(cross_compile)
                ))]
                {
                    // SAFETY: compiled only when `avx2` is statically enabled.
                    // NOTE(review): confirm the routine has no further
                    // requirements.
                    unsafe { crate::x86::p1600_avx2(&mut result_state) };
                }
                #[cfg(not(all(
                    target_arch = "x86_64",
                    feature = "asm",
                    any(target_feature = "avx2", target_feature = "avx512f"),
                    not(cross_compile)
                )))]
                {
                    keccak_p(&mut result_state, 24);
                }
            }
        }
        results.push(result_state);
    }
    Ok(results)
}
}
#[cfg(test)]
#[allow(clippy::unreadable_literal)] mod tests {
#[cfg(feature = "std")]
use super::*;
/// The reference tier and the tier reported as best must both be available.
#[test]
#[cfg(feature = "std")]
fn test_optimization_level_availability() {
    // The portable tier has no compile-time requirements.
    assert!(OptimizationLevel::Reference.is_available());
    // Whatever tier is advertised as best must itself be usable.
    assert!(OptimizationLevel::best_available().is_available());
}
/// `p1600_optimized` at `Reference` level must match a direct
/// `keccak_p(state, 24)` call on the same input.
#[test]
#[cfg(feature = "std")]
fn test_p1600_optimized_consistency() {
    let seed = {
        let mut lanes = [0u64; 25];
        lanes[0] = 0x1234567890ABCDEF;
        lanes
    };
    let mut via_dispatch = seed;
    let mut via_direct_call = seed;

    p1600_optimized(&mut via_dispatch, OptimizationLevel::Reference);
    keccak_p(&mut via_direct_call, 24);

    assert_eq!(via_dispatch, via_direct_call);
}
/// Absorbing a non-empty message at `Reference` level must consume some
/// bytes and leave lane 0 non-zero.
#[test]
#[cfg(feature = "std")]
fn test_fast_loop_absorb_optimized() {
    let message = b"Hello, World! This is a test message for optimized absorption.";
    let mut lanes = [0u64; 25];

    let consumed = fast_loop_absorb_optimized(&mut lanes, message, OptimizationLevel::Reference);

    assert!(consumed > 0);
    assert_ne!(lanes[0], 0);
}
/// Runs four distinct states through `parallel::p1600_parallel` at `Basic`
/// level and, if anything changed, requires that every state changed.
#[test]
#[cfg(all(feature = "std", feature = "simd", keccak_portable_simd))]
fn test_parallel_processing() {
    use super::{
        OptimizationLevel,
        parallel,
    };

    let mut states = [[0u64; 25]; 4];
    for (i, state) in states.iter_mut().enumerate() {
        state[0] = 0x1234567890ABCDEF + i as u64;
        state[1] = 0xFEDCBA0987654321 + i as u64;
    }
    // Arrays of `u64` are `Copy`, so this snapshots the inputs.
    let original_states = states;

    parallel::p1600_parallel(&mut states, OptimizationLevel::Basic);

    // NOTE(review): if no state changed at all the test returns early and
    // passes without checking anything; consider asserting against
    // `keccak_p` output instead.
    if states == original_states {
        return;
    }

    for (i, (after, before)) in states.iter().zip(original_states.iter()).enumerate() {
        assert!(
            after != before,
            "State {} should have been modified by parallel processing",
            i
        );
    }
}
}