#![cfg_attr(test, allow(dead_code))]
#[cfg(feature = "parallel")]
use super::dispatch_tables::ParallelTable;
#[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "s390x",
target_arch = "powerpc64",
target_arch = "riscv64"
))]
use super::kernels::required_caps;
use super::{
dispatch_tables::{DispatchTable, StreamingTable},
kernels::{Blake3KernelId, Kernel, kernel},
};
#[cfg(target_arch = "x86_64")]
use crate::platform::caps::x86;
use crate::{backend::cache::OnceCache, hashes::crypto::dispatch_util::SizeClassDispatch, platform::Caps};
/// Decides whether the AVX2 `hash_many` one-chunk fast path is allowed.
///
/// Returns `true` only when `caps` reports `x86::AVX512_READY` and reports
/// none of the AMX capability bits (`AMX_TILE`, `AMX_INT8`, `AMX_BF16`,
/// `AMX_FP16`, `AMX_COMPLEX`).
#[cfg(target_arch = "x86_64")]
#[inline]
#[must_use]
fn allow_avx2_hash_many_one_chunk_fast_path(caps: Caps) -> bool {
    // Without AVX512_READY the fast path is never allowed.
    if !caps.has(x86::AVX512_READY) {
        return false;
    }

    // Any single AMX capability vetoes the fast path.
    let reports_amx = caps.has(x86::AMX_TILE)
        || caps.has(x86::AMX_INT8)
        || caps.has(x86::AMX_BF16)
        || caps.has(x86::AMX_FP16)
        || caps.has(x86::AMX_COMPLEX);

    !reports_amx
}
/// Reports whether `hash_many` should treat this CPU as having a wide pipeline.
///
/// CPUs without the `x86::AMD` capability always yield `true`; CPUs with it
/// yield `true` only when `x86::AMD_ZEN5` is also set.
#[cfg(target_arch = "x86_64")]
#[inline]
#[must_use]
fn is_wide_pipeline_for_hash_many(caps: Caps) -> bool {
    // "AMD implies ZEN5" written as a single boolean expression.
    !caps.has(x86::AMD) || caps.has(x86::AMD_ZEN5)
}
/// One slot of the size-class dispatch table: the kernel chosen for that class.
#[derive(Clone, Copy)]
struct Entry {
/// Kernel to use for inputs that fall into this slot's size class.
kernel: Kernel,
}
/// Size-class dispatch table whose four slots already hold concrete kernels.
///
/// `select` walks `boundaries` in order and returns the first slot whose bound
/// is not exceeded by the input length, falling back to `l`.
#[derive(Clone, Copy)]
struct ActiveDispatch {
/// Inclusive upper length bounds for the `xs`, `s` and `m` classes, in that order.
boundaries: [usize; 3],
/// Entry for lengths `<= boundaries[0]`.
xs: Entry,
/// Entry for lengths above `boundaries[0]` and `<= boundaries[1]`.
s: Entry,
/// Entry for lengths above `boundaries[1]` and `<= boundaries[2]`.
m: Entry,
/// Entry for lengths above `boundaries[2]`.
l: Entry,
}
/// Everything `resolved()` derives from the platform capabilities and stores
/// in the `RESOLVED` cache.
#[derive(Clone, Copy)]
struct ResolvedDispatch {
/// Size-class table used by the one-shot entry points in this module.
active: ActiveDispatch,
/// Parallel tables, one per hashing mode (only with the `parallel` feature).
#[cfg(feature = "parallel")]
parallel: ParallelDispatch,
/// Kernel selection handed to streaming hashers via `hasher_dispatch()`.
hasher: HasherDispatch,
/// Cached result of `allow_avx2_hash_many_one_chunk_fast_path(caps)`.
#[cfg(target_arch = "x86_64")]
avx2_hash_many_one_chunk_fast_path: bool,
/// Cached result of `is_wide_pipeline_for_hash_many(caps)`.
#[cfg(target_arch = "x86_64")]
hash_many_wide_pipeline: bool,
/// Whether the capabilities required by the `X86Avx2` kernel are present.
#[cfg(target_arch = "x86_64")]
avx2_available: bool,
}
/// Process-wide cache for the resolved dispatch state; filled by `resolved()`
/// through `OnceCache::get_or_init`.
// NOTE(review): presumably initialised at most once and safe to read from any
// thread; confirm against the `OnceCache` implementation.
static RESOLVED: OnceCache<ResolvedDispatch> = OnceCache::new();
/// Parallel tables for each combination of hashing mode (plain, keyed, derive)
/// and entry style (one-shot, XOF, streaming).
///
/// As built by `resolved()`, the six one-shot/XOF fields all hold a copy of the
/// one-shot parallel table and the three streaming fields all hold a copy of
/// the streaming parallel table.
#[derive(Clone, Copy)]
#[cfg(feature = "parallel")]
pub(crate) struct ParallelDispatch {
/// Table for plain one-shot hashing.
pub(crate) oneshot: ParallelTable,
/// Table for keyed one-shot hashing.
pub(crate) keyed_oneshot: ParallelTable,
/// Table for derive-mode one-shot hashing.
pub(crate) derive_oneshot: ParallelTable,
/// Table for plain XOF output.
pub(crate) xof: ParallelTable,
/// Table for keyed XOF output.
pub(crate) keyed_xof: ParallelTable,
/// Table for derive-mode XOF output.
pub(crate) derive_xof: ParallelTable,
/// Table for plain streaming hashing.
pub(crate) streaming: ParallelTable,
/// Table for keyed streaming hashing.
pub(crate) keyed_streaming: ParallelTable,
/// Table for derive-mode streaming hashing.
pub(crate) derive_streaming: ParallelTable,
}
/// Kernel choices used by streaming hashers, copied out of the resolved state
/// by `hasher_dispatch()`.
#[derive(Clone, Copy)]
pub(crate) struct HasherDispatch {
/// Per-size-class kernels, built from the same boundaries and kernels as the
/// active one-shot dispatch table.
size_classes: SizeClassDispatch<Kernel>,
/// Kernel resolved from the streaming table's `stream` entry.
stream_kernel: Kernel,
/// Kernel resolved from the streaming table's `bulk` entry.
table_bulk_kernel: Kernel,
/// Update lengths at or above this value use `size_classes` instead of
/// `table_bulk_kernel` (see `bulk_kernel_for_update`).
bulk_sizeclass_threshold: usize,
}
impl HasherDispatch {
    /// Returns the kernel stored for the streaming path.
    #[inline]
    #[must_use]
    pub(crate) fn stream_kernel(&self) -> Kernel {
        let Self { stream_kernel, .. } = *self;
        stream_kernel
    }

    /// Picks the bulk kernel for an update of `input_len` bytes.
    ///
    /// Updates shorter than `bulk_sizeclass_threshold` use the fixed table bulk
    /// kernel; updates at or above the threshold use size-class selection.
    #[inline]
    #[must_use]
    pub(crate) fn bulk_kernel_for_update(&self, input_len: usize) -> Kernel {
        if input_len < self.bulk_sizeclass_threshold {
            return self.table_bulk_kernel;
        }
        self.size_class_kernel(input_len)
    }

    /// Returns the size-class kernel for an input of `len` bytes.
    #[inline]
    #[must_use]
    pub(crate) fn size_class_kernel(&self, len: usize) -> Kernel {
        let classes = &self.size_classes;
        classes.select(len)
    }
}
/// Maps a requested kernel id to one whose required capabilities are present
/// in `caps`.
///
/// `Portable` is returned unchanged. x86-64 ids fall back along
/// AVX-512 -> AVX2 -> SSE4.1 -> `Portable`, starting at the requested id.
/// Every other architecture-specific id is either kept or replaced by
/// `Portable`.
#[inline]
#[must_use]
fn resolve(id: Blake3KernelId, caps: Caps) -> Blake3KernelId {
    // No-op use of `caps`, so builds in which no arm below reads it do not
    // warn about an unused parameter.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let _ = caps;

    match id {
        Blake3KernelId::Portable => Blake3KernelId::Portable,

        #[cfg(target_arch = "x86_64")]
        Blake3KernelId::X86Avx512 => {
            // Try the widest kernel first, then step down.
            for candidate in [
                Blake3KernelId::X86Avx512,
                Blake3KernelId::X86Avx2,
                Blake3KernelId::X86Sse41,
            ] {
                if caps.has(required_caps(candidate)) {
                    return candidate;
                }
            }
            Blake3KernelId::Portable
        }

        #[cfg(target_arch = "x86_64")]
        Blake3KernelId::X86Avx2 => {
            for candidate in [Blake3KernelId::X86Avx2, Blake3KernelId::X86Sse41] {
                if caps.has(required_caps(candidate)) {
                    return candidate;
                }
            }
            Blake3KernelId::Portable
        }

        #[cfg(target_arch = "x86_64")]
        Blake3KernelId::X86Sse41 => {
            if caps.has(required_caps(id)) {
                id
            } else {
                Blake3KernelId::Portable
            }
        }

        #[cfg(target_arch = "aarch64")]
        Blake3KernelId::Aarch64Neon => {
            if caps.has(required_caps(id)) {
                id
            } else {
                Blake3KernelId::Portable
            }
        }

        #[cfg(target_arch = "s390x")]
        Blake3KernelId::S390xVector => {
            if caps.has(required_caps(id)) {
                id
            } else {
                Blake3KernelId::Portable
            }
        }

        #[cfg(target_arch = "powerpc64")]
        Blake3KernelId::PowerVsx => {
            if caps.has(required_caps(id)) {
                id
            } else {
                Blake3KernelId::Portable
            }
        }

        #[cfg(target_arch = "riscv64")]
        Blake3KernelId::RiscvV => {
            if caps.has(required_caps(id)) {
                id
            } else {
                Blake3KernelId::Portable
            }
        }
    }
}
#[inline]
#[must_use]
fn resolved() -> ResolvedDispatch {
RESOLVED.get_or_init(|| {
let caps = crate::platform::caps();
let table: &'static DispatchTable = super::dispatch_tables::select_table_for_caps(caps);
let stream_table: &'static StreamingTable = super::dispatch_tables::select_streaming_table_for_caps(caps);
#[cfg(feature = "parallel")]
let oneshot_parallel_table: &'static ParallelTable = super::dispatch_tables::select_parallel_table_for_caps(caps);
#[cfg(feature = "parallel")]
let streaming_parallel_table: &'static ParallelTable =
super::dispatch_tables::select_streaming_parallel_table_for_caps(caps);
let xs_id = resolve(table.xs, caps);
let s_id = resolve(table.s, caps);
let m_id = resolve(table.m, caps);
let l_id = resolve(table.l, caps);
let stream_id = resolve(stream_table.stream, caps);
let bulk_id = resolve(stream_table.bulk, caps);
let active = ActiveDispatch {
boundaries: table.boundaries,
xs: Entry { kernel: kernel(xs_id) },
s: Entry { kernel: kernel(s_id) },
m: Entry { kernel: kernel(m_id) },
l: Entry { kernel: kernel(l_id) },
};
#[cfg(feature = "parallel")]
let oneshot_base = *oneshot_parallel_table;
#[cfg(feature = "parallel")]
let streaming_base = *streaming_parallel_table;
let size_classes = SizeClassDispatch {
boundaries: active.boundaries,
xs: active.xs.kernel,
s: active.s.kernel,
m: active.m.kernel,
l: active.l.kernel,
};
let hasher = HasherDispatch {
size_classes,
stream_kernel: kernel(stream_id),
table_bulk_kernel: kernel(bulk_id),
bulk_sizeclass_threshold: stream_table.bulk_sizeclass_threshold,
};
ResolvedDispatch {
active,
#[cfg(feature = "parallel")]
parallel: ParallelDispatch {
oneshot: oneshot_base,
keyed_oneshot: oneshot_base,
derive_oneshot: oneshot_base,
xof: oneshot_base,
keyed_xof: oneshot_base,
derive_xof: oneshot_base,
streaming: streaming_base,
keyed_streaming: streaming_base,
derive_streaming: streaming_base,
},
hasher,
#[cfg(target_arch = "x86_64")]
avx2_hash_many_one_chunk_fast_path: allow_avx2_hash_many_one_chunk_fast_path(caps),
#[cfg(target_arch = "x86_64")]
hash_many_wide_pipeline: is_wide_pipeline_for_hash_many(caps),
#[cfg(target_arch = "x86_64")]
avx2_available: caps.has(required_caps(Blake3KernelId::X86Avx2)),
}
})
}
#[inline]
#[must_use]
fn active() -> ActiveDispatch {
resolved().active
}
/// Returns a copy of the resolved parallel dispatch tables.
#[cfg(feature = "parallel")]
#[inline]
#[must_use]
fn active_parallel() -> ParallelDispatch {
    let ResolvedDispatch { parallel, .. } = resolved();
    parallel
}
/// Picks the dispatch entry for an input of `len` bytes.
///
/// The three boundaries are inclusive upper bounds, checked in ascending
/// slot order; lengths above the last boundary use the `l` entry.
#[inline]
#[must_use]
fn select(d: &ActiveDispatch, len: usize) -> Entry {
    let [xs_limit, s_limit, m_limit] = d.boundaries;
    match len {
        n if n <= xs_limit => d.xs,
        n if n <= s_limit => d.s,
        n if n <= m_limit => d.m,
        _ => d.l,
    }
}
/// Returns the name of the kernel the active dispatch table selects for an
/// input of `len` bytes.
///
/// Compiled only for tests or with the `diag` feature.
#[cfg(any(test, feature = "diag"))]
#[inline]
#[must_use]
pub fn kernel_name_for_len(len: usize) -> &'static str {
    let chosen = select(&active(), len);
    chosen.kernel.name
}
/// Builds a [`super::Blake3XofReader`] over `data`, using the kernel the
/// active dispatch table selects for `data.len()`.
///
/// Inputs no longer than `super::CHUNK_LEN` go through
/// `super::xof_oneshot_single_chunk`. Longer inputs go through
/// `super::root_output_oneshot`, and the result is wrapped with
/// `Blake3XofReader::from_output`. Both paths pass `super::IV` and `0`.
#[inline]
#[must_use]
pub fn xof(data: &[u8]) -> super::Blake3XofReader {
    let kernel = select(&active(), data.len()).kernel;

    if data.len() > super::CHUNK_LEN {
        let policy = super::control::policy_kind_from_flags(0, true);
        let output = super::root_output_oneshot(kernel, super::IV, 0, policy, data);
        super::Blake3XofReader::from_output(output)
    } else {
        super::xof_oneshot_single_chunk(kernel, super::IV, 0, data)
    }
}
#[inline]
#[must_use]
pub(crate) fn hasher_dispatch() -> HasherDispatch {
resolved().hasher
}
/// Crate-visible accessor for the resolved parallel dispatch tables.
#[cfg(feature = "parallel")]
#[inline]
#[must_use]
pub(crate) fn parallel_dispatch() -> ParallelDispatch {
    let tables = active_parallel();
    tables
}
/// Returns the cached decision on whether the AVX2 `hash_many` one-chunk fast
/// path is allowed on this machine.
#[cfg(target_arch = "x86_64")]
#[inline]
#[must_use]
pub(crate) fn avx2_hash_many_one_chunk_fast_path() -> bool {
    let state = resolved();
    state.avx2_hash_many_one_chunk_fast_path
}
/// Returns the cached wide-pipeline flag computed for `hash_many`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[must_use]
pub(crate) fn hash_many_wide_pipeline() -> bool {
    let state = resolved();
    state.hash_many_wide_pipeline
}
/// Returns whether the capabilities required by the AVX2 kernel were present
/// when the dispatch state was resolved.
#[cfg(target_arch = "x86_64")]
#[inline]
#[must_use]
pub(crate) fn avx2_available() -> bool {
    let state = resolved();
    state.avx2_available
}