use super::KernelType;
use crate::MobilePlatform;
use trustformers_core::error::Result;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdInstructions {
Neon,
Sve,
AdvSimd,
None,
}
#[derive(Debug, Clone)]
pub struct VectorizationStrategy {
pub instruction_set: SimdInstructions,
pub vector_width: usize,
pub data_type: SimdDataType,
pub alignment: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdDataType {
Float32,
Float16,
Int8,
Int16,
Int32,
}
pub struct SimdOptimizer {
platform: MobilePlatform,
available_instructions: Vec<SimdInstructions>,
capabilities: SimdCapabilities,
}
#[derive(Debug, Clone)]
struct SimdCapabilities {
has_fma: bool,
has_dot_product: bool,
has_fp16: bool,
has_bf16: bool,
has_int8_matmul: bool,
max_vector_width: usize,
}
impl SimdOptimizer {
pub fn new(platform: MobilePlatform) -> Self {
let available_instructions = Self::detect_simd_support(&platform);
let capabilities = Self::detect_capabilities(&platform, &available_instructions);
Self {
platform,
available_instructions,
capabilities,
}
}
pub fn can_vectorize(&self, kernel: &KernelType) -> bool {
if self.available_instructions.is_empty() {
return false;
}
matches!(
kernel,
KernelType::Conv2d
| KernelType::Linear
| KernelType::BatchNorm
| KernelType::Activation
| KernelType::Pooling
| KernelType::Custom(_)
)
}
pub fn vectorize_kernel(
&self,
kernel: &KernelType,
input_shapes: &[Vec<usize>],
) -> Result<KernelType> {
let strategy = self.select_vectorization_strategy(kernel, input_shapes)?;
match strategy.instruction_set {
SimdInstructions::Neon => self.vectorize_with_neon(kernel, &strategy),
SimdInstructions::AdvSimd => self.vectorize_with_advsimd(kernel, &strategy),
SimdInstructions::Sve => self.vectorize_with_sve(kernel, &strategy),
SimdInstructions::None => Ok(kernel.clone()),
}
}
pub fn optimal_vector_width(&self, data_type: SimdDataType) -> usize {
let base_width = self.capabilities.max_vector_width;
match data_type {
SimdDataType::Float32 => base_width / 32,
SimdDataType::Float16 => base_width / 16,
SimdDataType::Int8 => base_width / 8,
SimdDataType::Int16 => base_width / 16,
SimdDataType::Int32 => base_width / 32,
}
}
fn detect_simd_support(platform: &MobilePlatform) -> Vec<SimdInstructions> {
let mut instructions = Vec::new();
match platform {
MobilePlatform::Ios => {
instructions.push(SimdInstructions::Neon);
instructions.push(SimdInstructions::AdvSimd);
},
MobilePlatform::Android => {
if cfg!(target_arch = "aarch64") {
instructions.push(SimdInstructions::Neon);
instructions.push(SimdInstructions::AdvSimd);
} else if cfg!(target_arch = "arm") {
instructions.push(SimdInstructions::Neon);
}
},
MobilePlatform::Generic => {
if cfg!(any(target_arch = "aarch64", target_arch = "arm")) {
instructions.push(SimdInstructions::Neon);
}
},
}
instructions
}
fn detect_capabilities(
platform: &MobilePlatform,
instructions: &[SimdInstructions],
) -> SimdCapabilities {
let mut caps = SimdCapabilities {
has_fma: false,
has_dot_product: false,
has_fp16: false,
has_bf16: false,
has_int8_matmul: false,
max_vector_width: 64,
};
if instructions.contains(&SimdInstructions::AdvSimd) {
caps.has_fma = true;
caps.has_fp16 = true;
caps.max_vector_width = 128;
if matches!(platform, MobilePlatform::Ios) {
caps.has_dot_product = true;
caps.has_int8_matmul = true;
}
} else if instructions.contains(&SimdInstructions::Neon) {
caps.max_vector_width = 128;
caps.has_fma = cfg!(target_arch = "aarch64");
}
caps
}
fn select_vectorization_strategy(
&self,
kernel: &KernelType,
input_shapes: &[Vec<usize>],
) -> Result<VectorizationStrategy> {
let instruction_set =
self.available_instructions.first().copied().unwrap_or(SimdInstructions::None);
let data_type = if self.capabilities.has_fp16 && self.should_use_fp16(kernel) {
SimdDataType::Float16
} else if self.should_use_int8(kernel) {
SimdDataType::Int8
} else {
SimdDataType::Float32
};
let vector_width = self.capabilities.max_vector_width;
let alignment = if vector_width >= 128 { 16 } else { 8 };
Ok(VectorizationStrategy {
instruction_set,
vector_width,
data_type,
alignment,
})
}
fn should_use_fp16(&self, kernel: &KernelType) -> bool {
matches!(kernel, KernelType::Activation | KernelType::BatchNorm)
}
fn should_use_int8(&self, kernel: &KernelType) -> bool {
false }
fn vectorize_with_neon(
&self,
kernel: &KernelType,
strategy: &VectorizationStrategy,
) -> Result<KernelType> {
let vectorized_name = format!("Neon{:?}", kernel);
Ok(KernelType::Custom(vectorized_name))
}
fn vectorize_with_advsimd(
&self,
kernel: &KernelType,
strategy: &VectorizationStrategy,
) -> Result<KernelType> {
let vectorized_name = format!("AdvSimd{:?}", kernel);
Ok(KernelType::Custom(vectorized_name))
}
fn vectorize_with_sve(
&self,
kernel: &KernelType,
strategy: &VectorizationStrategy,
) -> Result<KernelType> {
let vectorized_name = format!("Sve{:?}", kernel);
Ok(KernelType::Custom(vectorized_name))
}
}
pub struct NeonOptimizations;
impl NeonOptimizations {
pub fn generate_vadd_f32() -> &'static str {
r#"
void neon_vadd_f32(const float* a, const float* b, float* c, size_t n) {
size_t i = 0;
// Process 4 floats at a time
for (; i + 4 <= n; i += 4) {
float32x4_t va = vld1q_f32(a + i);
float32x4_t vb = vld1q_f32(b + i);
float32x4_t vc = vaddq_f32(va, vb);
vst1q_f32(c + i, vc);
}
// Handle remaining elements
for (; i < n; i++) {
c[i] = a[i] + b[i];
}
}
"#
}
pub fn generate_vmul_f32() -> &'static str {
r#"
void neon_vmul_f32(const float* a, const float* b, float* c, size_t n) {
size_t i = 0;
for (; i + 4 <= n; i += 4) {
float32x4_t va = vld1q_f32(a + i);
float32x4_t vb = vld1q_f32(b + i);
float32x4_t vc = vmulq_f32(va, vb);
vst1q_f32(c + i, vc);
}
for (; i < n; i++) {
c[i] = a[i] * b[i];
}
}
"#
}
pub fn generate_vfma_f32() -> &'static str {
r#"
void neon_vfma_f32(const float* a, const float* b, const float* c, float* d, size_t n) {
size_t i = 0;
for (; i + 4 <= n; i += 4) {
float32x4_t va = vld1q_f32(a + i);
float32x4_t vb = vld1q_f32(b + i);
float32x4_t vc = vld1q_f32(c + i);
float32x4_t vd = vfmaq_f32(vc, va, vb); // d = a * b + c
vst1q_f32(d + i, vd);
}
for (; i < n; i++) {
d[i] = a[i] * b[i] + c[i];
}
}
"#
}
pub fn generate_relu_f32() -> &'static str {
r#"
void neon_relu_f32(const float* input, float* output, size_t n) {
const float32x4_t zero = vdupq_n_f32(0.0f);
size_t i = 0;
// Process 4 elements at a time
for (; i + 4 <= n; i += 4) {
float32x4_t x = vld1q_f32(input + i);
float32x4_t result = vmaxq_f32(x, zero);
vst1q_f32(output + i, result);
}
// Handle remaining elements
for (; i < n; i++) {
output[i] = fmaxf(input[i], 0.0f);
}
}
"#
}
pub fn generate_dot_product_f32() -> &'static str {
r#"
float neon_dot_product_f32(const float* a, const float* b, size_t n) {
float32x4_t sum = vdupq_n_f32(0.0f);
size_t i = 0;
// Unroll by 4 for better performance
for (; i + 16 <= n; i += 16) {
float32x4_t a0 = vld1q_f32(a + i);
float32x4_t b0 = vld1q_f32(b + i);
float32x4_t a1 = vld1q_f32(a + i + 4);
float32x4_t b1 = vld1q_f32(b + i + 4);
float32x4_t a2 = vld1q_f32(a + i + 8);
float32x4_t b2 = vld1q_f32(b + i + 8);
float32x4_t a3 = vld1q_f32(a + i + 12);
float32x4_t b3 = vld1q_f32(b + i + 12);
sum = vfmaq_f32(sum, a0, b0);
sum = vfmaq_f32(sum, a1, b1);
sum = vfmaq_f32(sum, a2, b2);
sum = vfmaq_f32(sum, a3, b3);
}
// Process remaining groups of 4
for (; i + 4 <= n; i += 4) {
float32x4_t a0 = vld1q_f32(a + i);
float32x4_t b0 = vld1q_f32(b + i);
sum = vfmaq_f32(sum, a0, b0);
}
// Horizontal sum
float result = vaddvq_f32(sum);
// Handle remaining elements
for (; i < n; i++) {
result += a[i] * b[i];
}
return result;
}
"#
}
}
pub struct AdvSimdOptimizations;
impl AdvSimdOptimizations {
pub fn generate_fp16_ops() -> &'static str {
r#"
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void advsimd_vadd_f16(const __fp16* a, const __fp16* b, __fp16* c, size_t n) {
size_t i = 0;
// Process 8 FP16 values at a time
for (; i + 8 <= n; i += 8) {
float16x8_t va = vld1q_f16(a + i);
float16x8_t vb = vld1q_f16(b + i);
float16x8_t vc = vaddq_f16(va, vb);
vst1q_f16(c + i, vc);
}
// Handle remaining
for (; i < n; i++) {
c[i] = a[i] + b[i];
}
}
#endif
"#
}
pub fn generate_int8_dot_product() -> &'static str {
r#"
#ifdef __ARM_FEATURE_DOTPROD
int32_t advsimd_dot_product_i8(const int8_t* a, const int8_t* b, size_t n) {
int32x4_t sum = vdupq_n_s32(0);
size_t i = 0;
// Process 16 INT8 values at a time (4 dot products)
for (; i + 16 <= n; i += 16) {
int8x16_t va = vld1q_s8(a + i);
int8x16_t vb = vld1q_s8(b + i);
// Split into 4-element groups for dot product
int8x8_t va_low = vget_low_s8(va);
int8x8_t va_high = vget_high_s8(va);
int8x8_t vb_low = vget_low_s8(vb);
int8x8_t vb_high = vget_high_s8(vb);
// Accumulate dot products
sum = vdotq_s32(sum, va_low, vb_low);
sum = vdotq_s32(sum, va_high, vb_high);
}
// Horizontal sum
int32_t result = vaddvq_s32(sum);
// Handle remaining elements
for (; i < n; i++) {
result += (int32_t)a[i] * (int32_t)b[i];
}
return result;
}
#endif
"#
}
}
pub struct SimdPerformanceEstimator;
impl SimdPerformanceEstimator {
pub fn estimate_speedup(
instruction_set: SimdInstructions,
data_type: SimdDataType,
operation: &KernelType,
) -> f32 {
let vector_speedup = match (instruction_set, data_type) {
(SimdInstructions::Neon, SimdDataType::Float32) => 4.0,
(SimdInstructions::Neon, SimdDataType::Float16) => 8.0,
(SimdInstructions::Neon, SimdDataType::Int8) => 16.0,
(SimdInstructions::AdvSimd, SimdDataType::Float32) => 4.0,
(SimdInstructions::AdvSimd, SimdDataType::Float16) => 8.0,
(SimdInstructions::AdvSimd, SimdDataType::Int8) => 16.0,
(SimdInstructions::Sve, _) => 8.0, (SimdInstructions::None, _) => 1.0,
_ => 2.0,
};
let operation_efficiency = match operation {
KernelType::Conv2d => 0.8, KernelType::Linear => 0.9, KernelType::Activation => 0.95, KernelType::BatchNorm => 0.85, KernelType::Pooling => 0.7, _ => 0.6,
};
vector_speedup * operation_efficiency
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_optimizer_creation() {
let optimizer = SimdOptimizer::new(MobilePlatform::Ios);
assert!(!optimizer.available_instructions.is_empty());
}
#[test]
fn test_vectorization_check() {
let optimizer = SimdOptimizer::new(MobilePlatform::Ios);
assert!(optimizer.can_vectorize(&KernelType::Conv2d));
assert!(optimizer.can_vectorize(&KernelType::Linear));
assert!(optimizer.can_vectorize(&KernelType::Activation));
}
#[test]
fn test_optimal_vector_width() {
let optimizer = SimdOptimizer::new(MobilePlatform::Ios);
assert_eq!(optimizer.optimal_vector_width(SimdDataType::Float32), 4);
assert_eq!(optimizer.optimal_vector_width(SimdDataType::Float16), 8);
assert_eq!(optimizer.optimal_vector_width(SimdDataType::Int8), 16);
}
#[test]
fn test_performance_estimation() {
let speedup = SimdPerformanceEstimator::estimate_speedup(
SimdInstructions::Neon,
SimdDataType::Float32,
&KernelType::Linear,
);
assert!(speedup > 1.0);
assert!(speedup <= 4.0);
}
}