#[cfg(not(feature = "no-std"))]
use std::any;
#[cfg(not(feature = "no-std"))]
use std::boxed::Box;
#[cfg(not(feature = "no-std"))]
use std::collections::HashMap;
#[cfg(not(feature = "no-std"))]
use std::fmt::Debug;
#[cfg(not(feature = "no-std"))]
use std::string::{String, ToString};
#[cfg(not(feature = "no-std"))]
use std::vec::Vec;
#[cfg(feature = "no-std")]
use alloc::boxed::Box;
#[cfg(feature = "no-std")]
use alloc::collections::BTreeMap as HashMap;
#[cfg(feature = "no-std")]
use alloc::format;
#[cfg(feature = "no-std")]
use alloc::string::{String, ToString};
#[cfg(feature = "no-std")]
use alloc::vec::Vec;
#[cfg(feature = "no-std")]
use core::any;
#[cfg(feature = "no-std")]
use core::fmt::Debug;
/// A named operation that consumes a slice of `T` and produces an
/// implementation-chosen output.
///
/// The trait fixes only the signatures below; what an operation computes and
/// when it fails is decided by each implementor.
pub trait SimdOperation<T> {
    /// Value returned by a successful [`execute`](Self::execute).
    type Output;
    /// Error returned by a failed [`execute`](Self::execute).
    type Error;
    /// Runs the operation on `input`.
    fn execute(&self, input: &[T]) -> Result<Self::Output, Self::Error>;
    /// Width the implementation reports as optimal.
    ///
    /// NOTE(review): presumably a lane count (elements per vector) — confirm.
    fn optimal_width(&self) -> usize;
    /// Whether the operation reports itself as usable; the criteria are up to
    /// the implementor.
    fn is_supported(&self) -> bool;
    /// Static identifier for the operation.
    fn name(&self) -> &'static str;
}
/// Arithmetic over slices of `T`, returning a freshly allocated `Vec<T>` or a
/// [`SimdError`].
///
/// Only signatures are declared here. The method names suggest element-wise
/// arithmetic, but the exact semantics (including how mismatched lengths are
/// handled) are defined by each implementation — TODO confirm against
/// implementors.
pub trait VectorArithmetic<T> {
    /// Combines `a` and `b`; presumably element-wise `a + b`.
    fn add(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
    /// Combines `a` and `b`; presumably element-wise `a - b`.
    fn sub(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
    /// Combines `a` and `b`; presumably element-wise `a * b`.
    fn mul(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
    /// Combines `a` and `b`; presumably element-wise `a / b`.
    fn div(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
    /// Three-operand combination.
    ///
    /// NOTE(review): presumably a fused multiply-add (`a * b + c`); the operand
    /// order is not established by this declaration — confirm.
    fn fma(&self, a: &[T], b: &[T], c: &[T]) -> Result<Vec<T>, SimdError>;
    /// Combines every element of `vector` with `scalar`; presumably a
    /// multiplication.
    fn scale(&self, vector: &[T], scalar: T) -> Result<Vec<T>, SimdError>;
}
/// Reductions that collapse one or two slices of `T` into a single `T`.
///
/// Only signatures are declared here; the behaviour on empty input and the
/// precise definitions are up to each implementation.
pub trait VectorReduction<T> {
    /// Reduces `vector` to one value; by name, the sum of its elements.
    fn sum(&self, vector: &[T]) -> Result<T, SimdError>;
    /// Reduces `vector` to one value; by name, its smallest element.
    fn min(&self, vector: &[T]) -> Result<T, SimdError>;
    /// Reduces `vector` to one value; by name, its largest element.
    fn max(&self, vector: &[T]) -> Result<T, SimdError>;
    /// Reduces `a` and `b` to one value; by name, their dot product.
    fn dot_product(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
    /// Reduces `vector` to one value; by name, a vector norm.
    ///
    /// NOTE(review): which norm (L1, L2, ...) is not specified here — confirm.
    fn norm(&self, vector: &[T]) -> Result<T, SimdError>;
    /// Reduces `vector` to one value; by name, the arithmetic mean.
    fn mean(&self, vector: &[T]) -> Result<T, SimdError>;
}
/// Distance measures between two slices of `T`, each returning a single `T`.
///
/// Only signatures are declared here; the formulas implied by the method names
/// and the handling of mismatched or empty inputs are defined by implementors.
pub trait DistanceMetric<T> {
    /// Distance between `a` and `b`; by name, the Euclidean (L2) distance.
    fn euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
    /// Distance between `a` and `b`; by name, the Manhattan (L1) distance.
    fn manhattan_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
    /// Distance between `a` and `b`; by name, the cosine distance.
    ///
    /// NOTE(review): whether this is `1 - cosine_similarity` is not shown
    /// here — confirm.
    fn cosine_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
    /// Distance between `a` and `b`; by name, the Euclidean distance without
    /// the final square root.
    fn squared_euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
}
/// A function applied to a slice of `T`, together with its derivative.
///
/// `apply`, `derivative`, `name` and `supports_inplace` are supplied by the
/// implementor; `apply_inplace` has a default implementation built on top of
/// `apply`.
pub trait ActivationFunction<T: Copy> {
    /// Evaluates the function over `input` and returns the results in a new
    /// vector.
    fn apply(&self, input: &[T]) -> Result<Vec<T>, SimdError>;

    /// Evaluates the derivative over `input` and returns the results in a new
    /// vector.
    fn derivative(&self, input: &[T]) -> Result<Vec<T>, SimdError>;

    /// Static identifier for the function.
    fn name(&self) -> &'static str;

    /// Whether [`apply_inplace`](Self::apply_inplace) may be used.
    fn supports_inplace(&self) -> bool;

    /// Overwrites `input` with the result of [`apply`](Self::apply).
    ///
    /// The default implementation still allocates: it calls `apply` and copies
    /// the returned vector back into `input`.
    ///
    /// # Errors
    ///
    /// * [`SimdError::UnsupportedOperation`] when
    ///   [`supports_inplace`](Self::supports_inplace) returns `false`.
    /// * Any error returned by `apply`.
    /// * [`SimdError::DimensionMismatch`] when `apply` returns a vector whose
    ///   length differs from `input.len()`. `input` is left untouched in that
    ///   case.
    fn apply_inplace(&self, input: &mut [T]) -> Result<(), SimdError> {
        if !self.supports_inplace() {
            return Err(SimdError::UnsupportedOperation(
                "In-place operation not supported".to_string(),
            ));
        }
        let result = self.apply(input)?;
        // `copy_from_slice` panics on a length mismatch; surface a misbehaving
        // `apply` as a recoverable error instead.
        if result.len() != input.len() {
            return Err(SimdError::DimensionMismatch {
                expected: input.len(),
                actual: result.len(),
            });
        }
        input.copy_from_slice(&result);
        Ok(())
    }
}
/// A two-argument function over slices of `T` that yields a single `T`, plus a
/// batch form that yields a matrix.
///
/// Only signatures are declared here; the concrete kernel is chosen by each
/// implementor.
pub trait KernelFunction<T> {
    /// Evaluates the kernel for the pair (`a`, `b`).
    fn compute(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
    /// Builds a matrix from `vectors`.
    ///
    /// NOTE(review): presumably entry `[i][j]` is `compute(vectors[i],
    /// vectors[j])` — confirm against implementors.
    fn kernel_matrix(&self, vectors: &[&[T]]) -> Result<Vec<Vec<T>>, SimdError>;
    /// Static identifier for the kernel.
    fn name(&self) -> &'static str;
    /// Whether the kernel reports having parameters; the meaning is up to the
    /// implementor.
    fn has_parameters(&self) -> bool;
}
/// Operations on matrices represented as slices of `Vec<T>`.
///
/// NOTE(review): each inner `Vec<T>` is presumably one row (row-major); this
/// declaration does not establish the layout — confirm against implementors.
/// Shape validation and error conditions are also left to implementors.
pub trait MatrixOperations<T> {
    /// Combines `matrix` and `vector` into a vector; by name, the
    /// matrix-vector product.
    fn matrix_vector_multiply(&self, matrix: &[Vec<T>], vector: &[T]) -> Result<Vec<T>, SimdError>;
    /// Combines `a` and `b` into a matrix; by name, the matrix product.
    fn matrix_multiply(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
    /// Produces a new matrix from `matrix`; by name, its transpose.
    fn transpose(&self, matrix: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
    /// Combines `a` and `b` into a matrix; by name, their element-wise sum.
    fn elementwise_add(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
}
/// Building blocks for centroid-based clustering over points given as slices
/// of `T`.
///
/// Only signatures are declared here; distance measure, validation and error
/// conditions are defined by each implementation.
pub trait ClusteringOperations<T> {
    /// Produces a matrix from `points` and `centroids`.
    ///
    /// NOTE(review): presumably entry `[i][j]` is the distance from
    /// `points[i]` to `centroids[j]` — confirm.
    fn point_to_centroid_distances(
        &self,
        points: &[&[T]],
        centroids: &[&[T]],
    ) -> Result<Vec<Vec<T>>, SimdError>;
    /// Produces a set of vectors from `points`, `assignments` and `k`.
    ///
    /// NOTE(review): presumably `assignments[i]` is the cluster index of
    /// `points[i]` and `k` is the number of clusters — confirm.
    fn update_centroids(
        &self,
        points: &[&[T]],
        assignments: &[usize],
        k: usize,
    ) -> Result<Vec<Vec<T>>, SimdError>;
    /// Reduces `points`, `centroids` and `assignments` to one value.
    ///
    /// NOTE(review): "wcss" presumably stands for within-cluster sum of
    /// squares — confirm.
    fn wcss(
        &self,
        points: &[&[T]],
        centroids: &[&[T]],
        assignments: &[usize],
    ) -> Result<T, SimdError>;
}
/// Error type shared by the traits and helpers in this file.
#[derive(Debug, Clone)]
pub enum SimdError {
    /// Two sizes that were required to agree did not.
    DimensionMismatch { expected: usize, actual: usize },
    /// An input that must contain at least one element was empty.
    EmptyInput,
    /// The operation is not available on the current platform.
    UnsupportedPlatform,
    /// The requested operation is not supported; the payload says which.
    UnsupportedOperation(String),
    /// A numerical problem occurred; the payload describes it.
    NumericalError(String),
    /// The parameter called `name` had the unacceptable value `value`.
    InvalidParameter { name: String, value: String },
    /// Memory could not be allocated.
    AllocationError,
    /// An external library reported a failure.
    ExternalLibraryError(String),
    /// The input was rejected; the payload explains why.
    InvalidInput(String),
    /// An argument was rejected; the payload explains why.
    InvalidArgument(String),
    /// The functionality has not been implemented.
    NotImplemented(String),
    /// Any failure not covered by the other variants.
    Other(String),
}

impl core::fmt::Display for SimdError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Variants without a payload have a fixed message.
            Self::EmptyInput => f.write_str("Input data is empty"),
            Self::UnsupportedPlatform => {
                f.write_str("SIMD operation not supported on this platform")
            }
            Self::AllocationError => f.write_str("Memory allocation failed"),
            // Variants with named fields.
            Self::DimensionMismatch { expected, actual } => write!(
                f,
                "Dimension mismatch: expected {}, got {}",
                expected, actual
            ),
            Self::InvalidParameter { name, value } => {
                write!(f, "Invalid parameter {}: {}", name, value)
            }
            // Variants carrying a single message.
            Self::UnsupportedOperation(op) => write!(f, "Unsupported operation: {}", op),
            Self::NumericalError(msg) => write!(f, "Numerical error: {}", msg),
            Self::ExternalLibraryError(msg) => write!(f, "External library error: {}", msg),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
            Self::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
            Self::NotImplemented(msg) => write!(f, "Not implemented: {}", msg),
            Self::Other(msg) => write!(f, "Error: {}", msg),
        }
    }
}

// The `Error` trait is taken from `std` or `core` depending on the `no-std`
// feature; exactly one of the two impls is compiled.
#[cfg(not(feature = "no-std"))]
impl std::error::Error for SimdError {}

#[cfg(feature = "no-std")]
impl core::error::Error for SimdError {}
/// Hands out boxed [`SimdOperation`] implementations whose error type is
/// [`SimdError`].
///
/// How an implementation is chosen is up to the implementor; this trait only
/// fixes the signatures.
pub trait SimdDispatcher<T> {
    /// The `Output` type of every operation returned by this dispatcher.
    type Operation;
    /// Returns the implementation the dispatcher picks on its own.
    fn select_implementation(
        &self,
    ) -> Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>;
    /// Lists implementation names known to the dispatcher.
    ///
    /// NOTE(review): presumably these are the names accepted by
    /// `force_implementation` — confirm.
    fn available_implementations(&self) -> Vec<&'static str>;
    /// Returns the implementation identified by `name`, or `None`.
    ///
    /// NOTE(review): `None` presumably means the name is unknown or the
    /// implementation is unavailable — confirm.
    fn force_implementation(
        &self,
        name: &str,
    ) -> Option<Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>>;
}
/// Getter/setter pairs for three configuration values: a SIMD width, a
/// scalar-fallback flag and a precision tolerance.
///
/// How consumers interpret these values is not defined by this trait.
pub trait SimdConfig {
    /// Stores the SIMD width.
    fn set_simd_width(&mut self, width: usize);
    /// Returns the stored SIMD width.
    fn simd_width(&self) -> usize;
    /// Enables or disables the scalar fallback flag.
    fn set_scalar_fallback(&mut self, enabled: bool);
    /// Returns the scalar fallback flag.
    fn scalar_fallback_enabled(&self) -> bool;
    /// Stores the precision tolerance.
    fn set_precision_tolerance(&mut self, tolerance: f64);
    /// Returns the stored precision tolerance.
    fn precision_tolerance(&self) -> f64;
}
/// Plain-data configuration holding the three values exposed through
/// [`SimdConfig`]. All fields are public and may be set directly.
#[derive(Debug, Clone)]
pub struct DefaultSimdConfig {
    /// SIMD width; the `Default` impl in this file takes it from
    /// `crate::SIMD_CAPS.best_f32_width()`.
    pub simd_width: usize,
    /// Scalar fallback flag.
    pub scalar_fallback: bool,
    /// Precision tolerance.
    pub precision_tolerance: f64,
}
impl Default for DefaultSimdConfig {
    /// Builds the default configuration: the width reported by
    /// `crate::SIMD_CAPS.best_f32_width()`, scalar fallback enabled and a
    /// tolerance of `1e-6`.
    fn default() -> Self {
        Self {
            // `SIMD_CAPS` is defined elsewhere in the crate (not in this file).
            simd_width: crate::SIMD_CAPS.best_f32_width(),
            scalar_fallback: true,
            precision_tolerance: 1e-6,
        }
    }
}
/// Field-backed implementation: every setter stores its argument unchanged in
/// the matching public field and every getter returns that field.
impl SimdConfig for DefaultSimdConfig {
    /// Stores `width` in `simd_width` without validation.
    fn set_simd_width(&mut self, width: usize) {
        self.simd_width = width;
    }
    /// Returns `simd_width`.
    fn simd_width(&self) -> usize {
        self.simd_width
    }
    /// Stores `enabled` in `scalar_fallback`.
    fn set_scalar_fallback(&mut self, enabled: bool) {
        self.scalar_fallback = enabled;
    }
    /// Returns `scalar_fallback`.
    fn scalar_fallback_enabled(&self) -> bool {
        self.scalar_fallback
    }
    /// Stores `tolerance` in `precision_tolerance` without validation.
    fn set_precision_tolerance(&mut self, tolerance: f64) {
        self.precision_tolerance = tolerance;
    }
    /// Returns `precision_tolerance`.
    fn precision_tolerance(&self) -> f64 {
        self.precision_tolerance
    }
}
/// Combinators for building larger operations out of a [`SimdOperation`].
///
/// Both methods consume `self` and return a wrapper struct; the trait has no
/// default implementations, so the construction is up to each implementor.
pub trait ComposableOperation<T>: SimdOperation<T> {
    /// Pairs `self` with `other` in a [`ComposedOperation`].
    fn compose<Other>(self, other: Other) -> ComposedOperation<Self, Other>
    where
        Self: Sized,
        Other: SimdOperation<T>;
    /// Pairs `self` with the closure `f`, which accepts `Self::Output`, in a
    /// [`MappedOperation`].
    fn map<F, U>(self, f: F) -> MappedOperation<Self, F>
    where
        Self: Sized,
        F: Fn(Self::Output) -> U;
}
/// Holds two values, `first` and `second`, as produced by
/// `ComposableOperation::compose`.
///
/// NOTE(review): no code in the reviewed portion of this file reads the
/// fields or implements `SimdOperation` for this type, which is why the
/// dead-code lint is silenced.
#[allow(dead_code)]
pub struct ComposedOperation<First, Second> {
    first: First,
    second: Second,
}

impl<First, Second> ComposedOperation<First, Second> {
    /// Stores `first` and `second` unchanged.
    pub fn new(first: First, second: Second) -> Self {
        ComposedOperation { first, second }
    }
}
/// Holds an operation together with a mapper value, as produced by
/// `ComposableOperation::map`.
///
/// NOTE(review): no code in the reviewed portion of this file reads the
/// fields or implements `SimdOperation` for this type, which is why the
/// dead-code lint is silenced.
#[allow(dead_code)]
pub struct MappedOperation<Op, F> {
    operation: Op,
    mapper: F,
}

impl<Op, F> MappedOperation<Op, F> {
    /// Stores `operation` and `mapper` unchanged.
    pub fn new(operation: Op, mapper: F) -> Self {
        MappedOperation { operation, mapper }
    }
}
/// Extension of [`SimdOperation`] with a chunked execution entry point and two
/// sizing hints.
///
/// How work is split or scheduled is not defined here; implementors decide.
pub trait ParallelSimdOperation<T>: SimdOperation<T> {
    /// Runs the operation on `input` using the given `chunk_size`.
    ///
    /// NOTE(review): presumably `input` is processed in chunks of
    /// `chunk_size` elements, possibly concurrently — confirm.
    fn execute_parallel(&self, input: &[T], chunk_size: usize)
        -> Result<Self::Output, Self::Error>;
    /// Chunk size the implementation recommends for `input_size` elements.
    fn optimal_chunk_size(&self, input_size: usize) -> usize;
    /// Whether the implementation recommends `execute_parallel` for
    /// `input_size` elements.
    fn should_parallelize(&self, input_size: usize) -> bool;
}
/// Name-keyed store of type-erased values.
///
/// Values are kept as `Box<dyn Any + Send + Sync>` and recovered by
/// downcasting. `HashMap` here is whatever the file-level imports bind it to:
/// `std::collections::HashMap`, or `alloc::collections::BTreeMap` when the
/// `no-std` feature is enabled.
pub struct SimdRegistry {
    operations: HashMap<String, Box<dyn any::Any + Send + Sync>>,
}

impl Default for SimdRegistry {
    /// Same as [`SimdRegistry::new`].
    fn default() -> Self {
        SimdRegistry::new()
    }
}

impl SimdRegistry {
    /// Creates a registry with no entries.
    pub fn new() -> Self {
        let operations = HashMap::new();
        Self { operations }
    }

    /// Stores `operation` under `name`, replacing any entry already
    /// registered with that name.
    pub fn register<T: 'static + Send + Sync>(&mut self, name: String, operation: T) {
        let erased: Box<dyn any::Any + Send + Sync> = Box::new(operation);
        self.operations.insert(name, erased);
    }

    /// Looks up `name` and returns the stored value as a `&T`.
    ///
    /// Returns `None` when nothing is registered under `name` or when the
    /// stored value is not a `T`.
    pub fn get<T: 'static>(&self, name: &str) -> Option<&T> {
        let erased = self.operations.get(name)?;
        erased.downcast_ref::<T>()
    }

    /// Returns the registered names in the map's iteration order.
    pub fn list_operations(&self) -> Vec<&String> {
        let mut names = Vec::with_capacity(self.operations.len());
        names.extend(self.operations.keys());
        names
    }
}
/// Implements `SimdOperation<f32>` for `$type`.
///
/// Arguments: the implementing type, the `Output` type, and a string literal
/// returned by `name()`.
///
/// The generated impl expects `$type` to provide
/// `compute(&self, &[f32]) -> Result<$output, SimdError>`. The expansion
/// names `SimdOperation` and `SimdError` without a path, so both must be in
/// scope where the macro is invoked.
#[macro_export]
macro_rules! impl_simd_operation {
    ($type:ty, $output:ty, $name:literal) => {
        impl SimdOperation<f32> for $type {
            type Output = $output;
            type Error = SimdError;

            // Rejects empty input, otherwise defers to the type's `compute`.
            fn execute(&self, input: &[f32]) -> Result<Self::Output, Self::Error> {
                if input.is_empty() {
                    Err(SimdError::EmptyInput)
                } else {
                    self.compute(input)
                }
            }

            // Width reported by the crate-level capability table.
            fn optimal_width(&self) -> usize {
                $crate::SIMD_CAPS.best_f32_width()
            }

            // Supported exactly when the reported width exceeds one.
            fn is_supported(&self) -> bool {
                self.optimal_width() > 1
            }

            fn name(&self) -> &'static str {
                $name
            }
        }
    };
}
pub mod utils {
use super::*;
/// Checks that `a` and `b` contain the same number of elements.
///
/// # Errors
///
/// Returns `SimdError::DimensionMismatch` with `expected = a.len()` and
/// `actual = b.len()` when the lengths differ.
pub fn validate_same_length<T>(a: &[T], b: &[T]) -> Result<(), SimdError> {
    let (expected, actual) = (a.len(), b.len());
    if expected == actual {
        return Ok(());
    }
    Err(SimdError::DimensionMismatch { expected, actual })
}
/// Checks that `slice` contains at least one element.
///
/// # Errors
///
/// Returns `SimdError::EmptyInput` when `slice` is empty.
pub fn validate_not_empty<T>(slice: &[T]) -> Result<(), SimdError> {
    match slice {
        [] => Err(SimdError::EmptyInput),
        _ => Ok(()),
    }
}
/// Checks that every value in `slice` is finite (neither NaN nor infinite).
///
/// An empty slice passes.
///
/// # Errors
///
/// Returns `SimdError::NumericalError` naming the first non-finite value
/// encountered.
pub fn validate_finite(slice: &[f32]) -> Result<(), SimdError> {
    match slice.iter().find(|candidate| !candidate.is_finite()) {
        Some(offender) => Err(SimdError::NumericalError(format!(
            "Non-finite value encountered: {}",
            offender
        ))),
        None => Ok(()),
    }
}
/// Splits `slice` into consecutive sub-slices of at most `chunk_size`
/// elements; only the last chunk may be shorter.
///
/// # Panics
///
/// Panics if `chunk_size` is zero, as `slice::chunks` does.
pub fn create_chunks<T>(slice: &[T], chunk_size: usize) -> impl Iterator<Item = &[T]> {
    <[T]>::chunks(slice, chunk_size)
}
/// Suggests a chunk size for processing `input_size` elements with vectors of
/// `simd_width` elements.
///
/// The preferred size is `simd_width * 64`. When a quarter of the input is
/// smaller than that, the quarter is used instead, but never less than
/// `simd_width`.
///
/// A `simd_width` of zero is treated as one, so the result is always at least
/// 1 and is safe to pass to `slice::chunks`. The multiplication saturates
/// instead of overflowing for very large widths.
pub fn optimal_chunk_size(input_size: usize, simd_width: usize) -> usize {
    // A zero width would otherwise yield a zero chunk size.
    let width = simd_width.max(1);
    let base_chunk = width.saturating_mul(64);
    let max_chunk = input_size / 4;
    if max_chunk < base_chunk {
        max_chunk.max(width)
    } else {
        base_chunk
    }
}
}
/// Unit tests for the traits, error type, registry and helpers in this file.
///
/// The module is compiled only for test builds without the `no-std` feature,
/// so `vec!` and `Vec` come from the standard prelude.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Test double whose `compute` adds one to every element.
    struct MockVectorAdd;

    impl MockVectorAdd {
        fn compute(&self, input: &[f32]) -> Result<Vec<f32>, SimdError> {
            Ok(input.iter().map(|&x| x + 1.0).collect())
        }
    }

    impl_simd_operation!(MockVectorAdd, Vec<f32>, "mock_vector_add");

    /// The macro-generated impl forwards to `compute` and reports its name.
    #[test]
    fn test_simd_operation_trait() {
        let op = MockVectorAdd;
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let result = op.execute(&input).expect("operation should succeed");
        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0]);
        assert_eq!(op.name(), "mock_vector_add");
        assert!(op.optimal_width() >= 1);
    }

    /// `Display` output contains the expected key phrases.
    #[test]
    fn test_simd_error_display() {
        let error = SimdError::DimensionMismatch {
            expected: 4,
            actual: 3,
        };
        assert!(error.to_string().contains("Dimension mismatch"));
        let error = SimdError::EmptyInput;
        assert!(error.to_string().contains("empty"));
    }

    /// Defaults are sane and every setter is reflected by its getter.
    #[test]
    fn test_default_simd_config() {
        let mut config = DefaultSimdConfig::default();
        assert!(config.simd_width() >= 1);
        assert!(config.scalar_fallback_enabled());
        assert_eq!(config.precision_tolerance(), 1e-6);
        config.set_simd_width(8);
        assert_eq!(config.simd_width(), 8);
        config.set_scalar_fallback(false);
        assert!(!config.scalar_fallback_enabled());
        config.set_precision_tolerance(1e-8);
        assert_eq!(config.precision_tolerance(), 1e-8);
    }

    /// Registered values can be listed and fetched; unknown names yield `None`.
    #[test]
    fn test_simd_registry() {
        let mut registry = SimdRegistry::new();
        registry.register("test_op".to_string(), MockVectorAdd);
        let operations = registry.list_operations();
        assert_eq!(operations.len(), 1);
        assert_eq!(operations[0], "test_op");
        let op = registry.get::<MockVectorAdd>("test_op");
        assert!(op.is_some());
        let nonexistent = registry.get::<MockVectorAdd>("nonexistent");
        assert!(nonexistent.is_none());
    }

    /// Length, emptiness and finiteness validators accept and reject as expected.
    #[test]
    fn test_validation_utils() {
        use utils::*;
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0];
        assert!(validate_same_length(&a, &b).is_ok());
        assert!(validate_same_length(&a, &c).is_err());
        assert!(validate_not_empty(&a).is_ok());
        assert!(validate_not_empty(&Vec::<f32>::new()).is_err());
        let finite = vec![1.0, 2.0, 3.0];
        let infinite = vec![1.0, f32::INFINITY, 3.0];
        let nan = vec![1.0, f32::NAN, 3.0];
        assert!(validate_finite(&finite).is_ok());
        assert!(validate_finite(&infinite).is_err());
        assert!(validate_finite(&nan).is_err());
    }

    /// Chunking preserves order and the suggested chunk size stays in range.
    #[test]
    fn test_chunk_utilities() {
        use utils::*;
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let chunks: Vec<&[i32]> = create_chunks(&data, 3).collect();
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0], &[1, 2, 3]);
        assert_eq!(chunks[1], &[4, 5, 6]);
        assert_eq!(chunks[2], &[7, 8, 9]);
        assert_eq!(chunks[3], &[10]);
        let chunk_size = optimal_chunk_size(1000, 8);
        assert!(chunk_size >= 8);
        assert!(chunk_size <= 1000);
    }
}