use super::ComputeBackend;
/// Lifecycle of a lazily configured SIMD backend.
///
/// The default value is [`SimdBackendState::Uninitialized`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SimdBackendState {
    /// Configuration has not been attempted yet (or was reset).
    #[default]
    Uninitialized,
    /// Configuration is currently in progress.
    Configuring,
    /// Configuration finished and the backend can be used.
    Ready,
    /// Configuration did not succeed.
    Failed,
}
/// SIMD backend selection whose configuration step is deferred until first use.
#[derive(Debug)]
pub struct LazySimdConfig {
    /// Where this configuration currently is in its lifecycle.
    state: SimdBackendState,
    /// Backend chosen by feature detection at construction time.
    best_backend: ComputeBackend,
    /// Whether AMX was reported as available at construction time.
    amx_supported: bool,
    /// AMX tile parameters; `None` until configured (and always `None` without AMX).
    tile_config: Option<AmxTileConfig>,
}
/// Parameters describing an AMX tile setup.
///
/// All fields default to zero.
#[derive(Clone, Copy, Debug, Default)]
pub struct AmxTileConfig {
    /// Palette identifier for the tile configuration.
    pub palette: u8,
    /// First row of the tile (presumably the row at which operations start — confirm).
    pub start_row: u8,
    /// Number of rows in the tile.
    pub rows: u8,
    /// Width of each tile row in bytes.
    pub bytes_per_row: u16,
}
impl LazySimdConfig {
    /// Creates a configuration in the [`SimdBackendState::Uninitialized`] state.
    ///
    /// Backend and AMX detection run immediately; only the tile configuration
    /// is deferred until [`Self::ensure_ready`] is called.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: SimdBackendState::Uninitialized,
            best_backend: Self::detect_best_backend(),
            amx_supported: Self::detect_amx(),
            tile_config: None,
        }
    }

    /// Selects the most capable backend for the current target.
    ///
    /// On x86_64 the CPU is probed at runtime in the order AVX-512F, AVX2,
    /// SSE2. On aarch64 NEON is chosen unconditionally. Everything else
    /// (including an x86_64 CPU reporting none of the above) gets the scalar
    /// backend.
    fn detect_best_backend() -> ComputeBackend {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                return ComputeBackend::Avx512;
            }
            if is_x86_feature_detected!("avx2") {
                return ComputeBackend::Avx2;
            }
            if is_x86_feature_detected!("sse2") {
                return ComputeBackend::Sse2;
            }
        }
        // Exactly one of the two blocks below survives cfg-stripping and
        // becomes the tail expression, so no target ends up with an
        // unreachable fallback after an unconditional `return`.
        #[cfg(target_arch = "aarch64")]
        {
            ComputeBackend::Neon
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            ComputeBackend::Scalar
        }
    }

    /// Reports whether AMX is available.
    ///
    /// Currently a placeholder: both the x86_64 and the non-x86_64 branch
    /// return `false`.
    fn detect_amx() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            false
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }

    /// Performs the deferred configuration on first call and returns the
    /// selected backend.
    ///
    /// When called in the `Uninitialized` state this installs the AMX tile
    /// configuration (only if AMX is supported) and moves to `Ready`.
    /// Subsequent calls in the `Ready` state just return the backend.
    ///
    /// # Errors
    ///
    /// Returns the current state as the error when it is
    /// [`SimdBackendState::Failed`] or [`SimdBackendState::Configuring`].
    pub fn ensure_ready(&mut self) -> Result<ComputeBackend, SimdBackendState> {
        match self.state {
            SimdBackendState::Ready => return Ok(self.best_backend),
            blocked @ (SimdBackendState::Failed | SimdBackendState::Configuring) => {
                return Err(blocked);
            }
            SimdBackendState::Uninitialized => {}
        }

        self.state = SimdBackendState::Configuring;
        if self.amx_supported {
            self.tile_config = Some(AmxTileConfig {
                palette: 1,
                start_row: 0,
                rows: 16,
                bytes_per_row: 64,
            });
        }
        self.state = SimdBackendState::Ready;
        Ok(self.best_backend)
    }

    /// Returns the current lifecycle state.
    #[must_use]
    pub fn state(&self) -> SimdBackendState {
        self.state
    }

    /// Returns the backend chosen at construction time.
    #[must_use]
    pub fn best_backend(&self) -> ComputeBackend {
        self.best_backend
    }

    /// Returns whether AMX was detected at construction time.
    #[must_use]
    pub fn has_amx(&self) -> bool {
        self.amx_supported
    }

    /// Returns to the `Uninitialized` state and discards any tile
    /// configuration. The detected backend and AMX flag are kept.
    pub fn reset(&mut self) {
        self.state = SimdBackendState::Uninitialized;
        self.tile_config = None;
    }
}
impl Default for LazySimdConfig {
fn default() -> Self {
Self::new()
}
}
/// How many elements a single unrolled loop iteration handles.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UnrollFactor {
    /// No unrolling: one element per iteration.
    None,
    /// Two elements per iteration.
    X2,
    /// Four elements per iteration.
    X4,
    /// Eight elements per iteration.
    X8,
}
impl UnrollFactor {
    /// Returns the number of elements processed per iteration (1, 2, 4 or 8).
    #[must_use]
    pub fn value(&self) -> usize {
        // Every factor is a power of two, so express it as a shift amount.
        let log2: u32 = match *self {
            Self::None => 0,
            Self::X2 => 1,
            Self::X4 => 2,
            Self::X8 => 3,
        };
        1_usize << log2
    }

    /// Chooses the unroll factor used for `backend`.
    ///
    /// AVX-512 maps to ×8, AVX2 to ×4, SSE2 and NEON to ×2; every other
    /// backend is not unrolled.
    #[must_use]
    pub fn for_backend(backend: ComputeBackend) -> Self {
        match backend {
            ComputeBackend::Sse2 | ComputeBackend::Neon => Self::X2,
            ComputeBackend::Avx2 => Self::X4,
            ComputeBackend::Avx512 => Self::X8,
            _ => Self::None,
        }
    }
}
/// Cursor that splits `total` elements into fixed-size chunks plus a shorter tail.
#[derive(Debug)]
pub struct UnrollTailIterator {
    /// Total number of elements being covered.
    total: usize,
    /// Start index of the next full chunk to hand out.
    position: usize,
    /// Number of elements in each full chunk.
    chunk_size: usize,
}
impl UnrollTailIterator {
    /// Creates a cursor over `total` elements, chunked by `factor.value()`.
    ///
    /// The chunk size is always at least 1 because `UnrollFactor::value`
    /// never returns 0.
    #[must_use]
    pub fn new(total: usize, factor: UnrollFactor) -> Self {
        Self {
            total,
            position: 0,
            chunk_size: factor.value(),
        }
    }

    /// Number of full chunks that fit into `total`.
    #[must_use]
    pub fn full_iterations(&self) -> usize {
        self.total / self.chunk_size
    }

    /// Number of elements left over after all full chunks.
    #[must_use]
    pub fn tail_size(&self) -> usize {
        self.total % self.chunk_size
    }

    /// Returns `true` when `total` is not a multiple of the chunk size.
    #[must_use]
    pub fn has_tail(&self) -> bool {
        self.tail_size() > 0
    }

    /// Hands out the next full chunk as a half-open `(start, end)` index
    /// range and advances the cursor, or returns `None` once fewer than
    /// `chunk_size` elements remain.
    pub fn next_chunk(&mut self) -> Option<(usize, usize)> {
        // Compare against the remaining length rather than computing
        // `position + chunk_size`, which overflows when `total` is close to
        // `usize::MAX`.
        let remaining = self.total.saturating_sub(self.position);
        if remaining < self.chunk_size {
            return None;
        }
        let start = self.position;
        // Cannot overflow: `start + chunk_size <= total` was just established.
        let end = start + self.chunk_size;
        self.position = end;
        Some((start, end))
    }

    /// Returns the half-open `(start, end)` index range of the tail, or
    /// `None` when there is no tail. Independent of the cursor position.
    #[must_use]
    pub fn tail_range(&self) -> Option<(usize, usize)> {
        let tail = self.tail_size();
        (tail > 0).then(|| (self.total - tail, self.total))
    }
}
pub fn unroll_tail_process<T, U, F, G>(
data: &[T],
factor: UnrollFactor,
mut process_chunk: F,
mut process_elem: G,
) -> Vec<U>
where
F: FnMut(&[T]) -> U,
G: FnMut(&T) -> U,
{
let mut iter = UnrollTailIterator::new(data.len(), factor);
let mut results =
Vec::with_capacity(iter.full_iterations() + if iter.has_tail() { 1 } else { 0 });
while let Some((start, end)) = iter.next_chunk() {
results.push(process_chunk(&data[start..end]));
}
if let Some((start, end)) = iter.tail_range() {
for elem in &data[start..end] {
results.push(process_elem(elem));
}
}
results
}
// Unit tests live in a separate `tests` submodule file and are compiled
// only for test builds.
#[cfg(test)]
mod tests;