1pub mod arithmetic_ops;
48pub mod basic_operations;
49pub mod comparison_ops;
50pub mod intrinsics;
51pub mod math_functions;
52pub mod statistics_ops;
53
54#[allow(non_snake_case)]
55#[cfg(all(test, not(feature = "no-std")))]
56pub mod integration_test;
57
58pub use arithmetic_ops::{
60 abs_vec, add_vec, divide_vec, fma, multiply_vec, neg_vec, reciprocal_vec, scale_vec,
61 square_vec, subtract_vec,
62};
63pub use basic_operations::{cosine_similarity, dot_product, euclidean_distance, norm_l1, norm_l2};
64pub use comparison_ops::{
65 and_vec, eq_vec, ge_vec, gt_vec, le_vec, lt_vec, ne_vec, not_vec, or_vec, xor_vec,
66};
67pub use intrinsics::{
68 detect_simd_capabilities, optimal_chunk_size, simd_width_f32, F32x4, SimdCapabilities,
69};
70pub use math_functions::{cos_vec, exp_vec, ln_vec, pow_vec, sin_vec, sqrt_vec, tan_vec};
71pub use statistics_ops::{
72 dot_product as stats_dot_product, max_vec, mean_vec, min_max_vec, min_vec,
73 norm_l1 as stats_norm_l1, norm_l2 as stats_norm_l2, norm_l2_squared, product_vec, std_dev_vec,
74 sum_vec, variance_vec,
75};
76
77pub use statistics_ops::sum_vec as sum;
79
80pub use arithmetic_ops::scale_vec_inplace as scale;
82pub use statistics_ops::mean_vec as mean;
83pub use statistics_ops::{min_max_vec as min_max, variance_vec as variance};
84pub use basic_operations::norm_l2 as norm;
86
87pub use arithmetic_ops::add_vec as add_simd;
89pub use arithmetic_ops::fma as fma_simd;
90
91pub use basic_operations::{cross_product, outer_product};
95
96#[cfg(feature = "no-std")]
97use alloc::vec;
98#[cfg(feature = "no-std")]
99use alloc::vec::Vec;
100#[cfg(not(feature = "no-std"))]
101use std::vec::Vec;
102
103#[cfg(feature = "no-std")]
105use core::f32::consts;
106#[cfg(not(feature = "no-std"))]
107use std::f32::consts;
108
/// Runtime tuning options for the SIMD helpers in this module.
///
/// The active configuration is read with [`get_simd_config`] and replaced with
/// [`set_simd_config`]; its `Default` implementation provides the starting values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimdConfig {
    /// Whether scalar code paths are allowed (default `true`).
    ///
    /// NOTE(review): presumably consulted when SIMD is unavailable or not worthwhile —
    /// confirm against the kernels that read it.
    pub enable_scalar_fallback: bool,
    /// Size threshold for SIMD dispatch (default `16`).
    ///
    /// NOTE(review): presumably the minimum element count before SIMD paths are used —
    /// confirm against the kernels that read it.
    pub simd_threshold: usize,
    /// Whether extra accuracy checks are enabled (defaults to `cfg!(debug_assertions)`).
    pub enable_accuracy_checks: bool,
}
119
120impl Default for SimdConfig {
121 fn default() -> Self {
122 Self {
123 enable_scalar_fallback: true,
124 simd_threshold: 16, enable_accuracy_checks: cfg!(debug_assertions),
126 }
127 }
128}
129
130#[cfg(not(feature = "no-std"))]
132thread_local! {
133 static SIMD_CONFIG: std::cell::RefCell<SimdConfig> = std::cell::RefCell::new(SimdConfig::default());
134}
135
/// Global SIMD configuration for `no-std` builds, written by `set_simd_config` and read
/// by `get_simd_config`. `None` means no configuration has been set yet; the getter then
/// falls back to `SimdConfig::default()`.
///
/// NOTE(review): unlike the `std` build, which keeps a separate configuration per thread
/// (`thread_local!`), this is a single process-wide value accessed through `unsafe` with
/// no synchronization. Concurrent access from several threads or interrupt handlers would
/// be a data race — confirm this is only used on single-threaded targets.
#[cfg(feature = "no-std")]
static mut SIMD_CONFIG: Option<SimdConfig> = None;
138
139pub fn set_simd_config(config: SimdConfig) {
141 #[cfg(not(feature = "no-std"))]
142 {
143 SIMD_CONFIG.with(|c| *c.borrow_mut() = config);
144 }
145 #[cfg(feature = "no-std")]
146 {
147 unsafe {
148 SIMD_CONFIG = Some(config);
149 }
150 }
151}
152
153pub fn get_simd_config() -> SimdConfig {
155 #[cfg(not(feature = "no-std"))]
156 {
157 SIMD_CONFIG.with(|c| c.borrow().clone())
158 }
159 #[cfg(feature = "no-std")]
160 {
161 unsafe { core::ptr::addr_of!(SIMD_CONFIG).read().unwrap_or_default() }
162 }
163}
164
165pub struct PlatformInfo {
167 pub capabilities: SimdCapabilities,
169 pub optimal_chunk_size: usize,
171 pub recommended_alignment: usize,
173}
174
175pub fn detect_platform_info() -> PlatformInfo {
177 let capabilities = detect_simd_capabilities();
178 let optimal_chunk_size = optimal_chunk_size(1000, None); let recommended_alignment = intrinsics::preferred_alignment_f32();
180
181 PlatformInfo {
182 capabilities,
183 optimal_chunk_size,
184 recommended_alignment,
185 }
186}
187
/// Allocates a zero-filled `Vec<f32>` of length `size`.
///
/// NOTE(review): the `_alignment` argument is currently ignored. The buffer has only the
/// natural alignment of `f32`. Callers that need stronger alignment must check it with
/// [`is_properly_aligned`].
pub fn allocate_aligned_vec(size: usize, _alignment: usize) -> Vec<f32> {
    let mut buffer = Vec::with_capacity(size);
    buffer.resize(size, 0.0);
    buffer
}
194
/// Reports whether the first element of `slice` sits at an address that is a multiple of
/// `alignment`.
///
/// An `alignment` of `0` yields `true` only for a null address, which a slice never
/// has, so in practice it returns `false`.
pub fn is_properly_aligned(slice: &[f32], alignment: usize) -> bool {
    let address = slice.as_ptr().addr();
    address.is_multiple_of(alignment)
}
199
200#[cfg(not(feature = "no-std"))]
202pub mod benchmarks {
203 use super::*;
204 use std::time::{Duration, Instant};
205
206 #[derive(Debug, Clone)]
208 pub struct BenchmarkResult {
209 pub operation: String,
211 pub duration: Duration,
213 pub ops_per_sec: f64,
215 pub elements_per_sec: f64,
217 pub platform: String,
219 }
220
221 pub fn benchmark_operation<F>(
223 name: &str,
224 vector_size: usize,
225 iterations: usize,
226 operation: F,
227 ) -> BenchmarkResult
228 where
229 F: Fn(),
230 {
231 for _ in 0..10 {
233 operation();
234 }
235
236 let start = Instant::now();
237 for _ in 0..iterations {
238 operation();
239 }
240 let duration = start.elapsed();
241
242 let platform_info = detect_platform_info();
243 let platform_name = platform_info.capabilities.platform_name();
244
245 BenchmarkResult {
246 operation: name.to_string(),
247 duration,
248 ops_per_sec: iterations as f64 / duration.as_secs_f64(),
249 elements_per_sec: (iterations * vector_size) as f64 / duration.as_secs_f64(),
250 platform: platform_name.to_string(),
251 }
252 }
253
254 pub fn benchmark_scaling<F>(
256 name: &str,
257 sizes: &[usize],
258 iterations: usize,
259 operation_factory: F,
260 ) -> Vec<BenchmarkResult>
261 where
262 F: Fn(usize) -> Box<dyn Fn()>,
263 {
264 sizes
265 .iter()
266 .map(|&size| {
267 let operation = operation_factory(size);
268 benchmark_operation(name, size, iterations, operation)
269 })
270 .collect()
271 }
272}
273
pub mod accuracy {
    //! Tools for grading a vectorised approximation against a scalar reference.

    use super::*;

    /// Error statistics produced by [`test_accuracy`].
    #[derive(Debug, Clone)]
    pub struct AccuracyResult {
        /// Largest absolute error over all test points.
        pub max_abs_error: f32,
        /// Root-mean-square of the per-point errors.
        pub rms_error: f32,
        /// Mean of the absolute per-point errors.
        pub mean_abs_error: f32,
        /// Number of inputs that were evaluated.
        pub test_points: usize,
        /// Grade derived from `max_abs_error`.
        pub grade: AccuracyGrade,
    }

    /// Grade assigned by [`test_accuracy`] from the maximum absolute error.
    #[derive(Debug, Clone, PartialEq)]
    pub enum AccuracyGrade {
        /// Maximum absolute error below `1e-6`.
        A,
        /// Maximum absolute error below `1e-5`.
        B,
        /// Maximum absolute error below `1e-4`.
        C,
        /// Maximum absolute error below `1e-3`.
        D,
        /// Maximum absolute error of `1e-3` or more, including infinite errors.
        F,
    }

    /// Signed error of one approximated value against its reference value.
    ///
    /// Exact agreement (including equal infinities) and "both NaN" count as zero error.
    /// Any other disagreement that would produce NaN, such as a NaN approximation of a
    /// finite reference, counts as an infinite error. Otherwise `f32::max` would skip the
    /// NaN and the failure would be hidden.
    fn point_error(approx: f32, reference: f32) -> f32 {
        if approx == reference || (approx.is_nan() && reference.is_nan()) {
            return 0.0;
        }
        let error = approx - reference;
        if error.is_nan() {
            f32::INFINITY
        } else {
            error
        }
    }

    /// Runs `approximation` over `test_inputs` and compares each output with
    /// `reference(input)`.
    ///
    /// `approximation` receives the inputs and an output buffer of the same length,
    /// pre-filled with `0.0`. A NaN on only one side counts as an infinite error and
    /// therefore grades `F`. An empty `test_inputs` yields all-zero statistics (and,
    /// vacuously, grade `A`) rather than NaN averages.
    pub fn test_accuracy<F, R>(
        approximation: F,
        reference: R,
        test_inputs: &[f32],
    ) -> AccuracyResult
    where
        F: Fn(&[f32], &mut [f32]),
        R: Fn(f32) -> f32,
    {
        let mut approx_results = vec![0.0; test_inputs.len()];
        approximation(test_inputs, &mut approx_results);

        let errors: Vec<f32> = test_inputs
            .iter()
            .zip(&approx_results)
            .map(|(&input, &approx)| point_error(approx, reference(input)))
            .collect();

        let test_points = errors.len();
        let (max_abs_error, mean_abs_error, rms_error) = if test_points == 0 {
            (0.0, 0.0, 0.0)
        } else {
            let count = test_points as f32;
            let max = errors.iter().fold(0.0f32, |acc, e| acc.max(e.abs()));
            let mean = errors.iter().map(|e| e.abs()).sum::<f32>() / count;
            let rms = (errors.iter().map(|e| e * e).sum::<f32>() / count).sqrt();
            (max, mean, rms)
        };

        let grade = match max_abs_error {
            e if e < 1e-6 => AccuracyGrade::A,
            e if e < 1e-5 => AccuracyGrade::B,
            e if e < 1e-4 => AccuracyGrade::C,
            e if e < 1e-3 => AccuracyGrade::D,
            _ => AccuracyGrade::F,
        };

        AccuracyResult {
            max_abs_error,
            rms_error,
            mean_abs_error,
            test_points,
            grade,
        }
    }

    /// Builds a sorted, duplicate-free list of test inputs.
    ///
    /// `num_points` evenly spaced samples start at `range_start` with step
    /// `(range_end - range_start) / num_points`, so `range_end` itself is not sampled.
    /// When `include_special_values` is set, these well-known values are added if they
    /// lie in the closed range `[range_start, range_end]`: signed zeros, ±π, ±π/2, ±π/4,
    /// ±e, ±1, ±2, ±10, ±0.1, ±0.001 and ±1e-6.
    ///
    /// Values are ordered with `f32::total_cmp` and deduplicated by bit pattern, so
    /// `-0.0` and `0.0` are both kept. NaN samples from a degenerate (e.g. infinite)
    /// range no longer cause a panic; they sort to the ends.
    pub fn generate_test_inputs(
        range_start: f32,
        range_end: f32,
        num_points: usize,
        include_special_values: bool,
    ) -> Vec<f32> {
        let special_values = [
            0.0,
            -0.0,
            consts::PI,
            -consts::PI,
            consts::PI / 2.0,
            -consts::PI / 2.0,
            consts::PI / 4.0,
            -consts::PI / 4.0,
            consts::E,
            -consts::E,
            1.0,
            -1.0,
            2.0,
            -2.0,
            10.0,
            -10.0,
            0.1,
            -0.1,
            0.001,
            -0.001,
            1e-6,
            -1e-6,
        ];

        let mut inputs = Vec::with_capacity(num_points + special_values.len());

        let step = (range_end - range_start) / (num_points as f32);
        inputs.extend((0..num_points).map(|i| range_start + i as f32 * step));

        if include_special_values {
            inputs.extend(
                special_values
                    .iter()
                    .copied()
                    .filter(|&value| value >= range_start && value <= range_end),
            );
        }

        // `total_cmp` is a total order: it never panics on NaN and places -0.0 before 0.0.
        // Deduplicating on the bit pattern (not `==`, which treats the zeros as equal)
        // keeps the deliberately included signed zero.
        inputs.sort_unstable_by(f32::total_cmp);
        inputs.dedup_by(|a, b| a.to_bits() == b.to_bits());
        inputs
    }
}
403
404pub mod utils {
406 use super::*;
407
408 pub fn check_compatible_lengths(a: &[f32], b: &[f32]) -> Result<(), &'static str> {
410 if a.len() != b.len() {
411 Err("Vectors must have the same length")
412 } else {
413 Ok(())
414 }
415 }
416
417 pub fn check_io_lengths(input: &[f32], output: &[f32]) -> Result<(), &'static str> {
419 check_compatible_lengths(input, output)
420 }
421
422 pub fn check_not_empty(vec: &[f32]) -> Result<(), &'static str> {
424 if vec.is_empty() {
425 Err("Vector cannot be empty")
426 } else {
427 Ok(())
428 }
429 }
430
431 pub fn get_platform_chunk_size() -> usize {
433 detect_platform_info().optimal_chunk_size
434 }
435
436 pub fn chunk_vector(vec: &[f32], chunk_size: usize) -> (&[f32], &[f32]) {
438 let simd_len = (vec.len() / chunk_size) * chunk_size;
439 vec.split_at(simd_len)
440 }
441
442 pub fn process_chunks<F, R>(
444 vec: &[f32],
445 chunk_size: usize,
446 mut chunk_processor: F,
447 mut remainder_processor: R,
448 ) where
449 F: FnMut(&[f32]),
450 R: FnMut(&[f32]),
451 {
452 let (chunks, remainder) = chunk_vector(vec, chunk_size);
453
454 for chunk in chunks.chunks_exact(chunk_size) {
455 chunk_processor(chunk);
456 }
457
458 if !remainder.is_empty() {
459 remainder_processor(remainder);
460 }
461 }
462
463 pub fn degrees_to_radians(degrees: f32) -> f32 {
465 degrees * consts::PI / 180.0
466 }
467
468 pub fn radians_to_degrees(radians: f32) -> f32 {
470 radians * 180.0 / consts::PI
471 }
472
473 pub fn safe_divide(numerator: f32, denominator: f32) -> f32 {
475 if denominator.abs() < f32::EPSILON {
476 if numerator >= 0.0 {
477 f32::INFINITY
478 } else {
479 f32::NEG_INFINITY
480 }
481 } else {
482 numerator / denominator
483 }
484 }
485
486 pub fn clamp(value: f32, min: f32, max: f32) -> f32 {
488 if value < min {
489 min
490 } else if value > max {
491 max
492 } else {
493 value
494 }
495 }
496}
497
pub mod constants {
    //! Frequently used `f32` mathematical constants and per-instruction-set SIMD
    //! register geometry.

    // `core` is nameable in both `std` and `no-std` builds, so no cfg switch is needed.
    use core::f32::consts;

    /// Size of one `f32` lane in bytes. Each alignment below is one full register.
    const F32_LANE_BYTES: usize = core::mem::size_of::<f32>();

    /// π.
    pub const PI_F32: f32 = consts::PI;
    /// Euler's number e.
    pub const E_F32: f32 = consts::E;
    /// Natural logarithm of 2.
    pub const LN_2_F32: f32 = consts::LN_2;
    /// Natural logarithm of 10.
    pub const LN_10_F32: f32 = consts::LN_10;
    /// Square root of 2.
    pub const SQRT_2_F32: f32 = consts::SQRT_2;

    /// `f32` lanes per SSE2 register.
    pub const SSE2_VECTOR_SIZE: usize = 4;
    /// `f32` lanes per AVX2 register.
    pub const AVX2_VECTOR_SIZE: usize = 8;
    /// `f32` lanes per AVX-512 register.
    pub const AVX512_VECTOR_SIZE: usize = 16;
    /// `f32` lanes per NEON register.
    pub const NEON_VECTOR_SIZE: usize = 4;

    /// Bytes in one SSE2 register of `f32` lanes (16).
    pub const SSE2_ALIGNMENT: usize = SSE2_VECTOR_SIZE * F32_LANE_BYTES;
    /// Bytes in one AVX2 register of `f32` lanes (32).
    pub const AVX2_ALIGNMENT: usize = AVX2_VECTOR_SIZE * F32_LANE_BYTES;
    /// Bytes in one AVX-512 register of `f32` lanes (64).
    pub const AVX512_ALIGNMENT: usize = AVX512_VECTOR_SIZE * F32_LANE_BYTES;
    /// Bytes in one NEON register of `f32` lanes (16).
    pub const NEON_ALIGNMENT: usize = NEON_VECTOR_SIZE * F32_LANE_BYTES;
}
524
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    //! Unit tests for configuration, platform detection, allocation and the `utils`
    //! helpers. The whole module only builds without `no-std`, so it needs no `alloc`
    //! imports.

    use super::*;

    #[test]
    fn test_simd_config() {
        let expected = SimdConfig::default();
        set_simd_config(expected.clone());

        let actual = get_simd_config();
        assert_eq!(actual.simd_threshold, expected.simd_threshold);
        assert_eq!(
            actual.enable_scalar_fallback,
            expected.enable_scalar_fallback
        );
    }

    #[test]
    fn test_platform_detection() {
        let info = detect_platform_info();
        assert!(info.optimal_chunk_size >= 4);
        assert!(info.recommended_alignment >= 4);

        println!("SIMD Capabilities: {:?}", info.capabilities);
    }

    #[test]
    fn test_aligned_allocation() {
        let buffer = allocate_aligned_vec(16, 16);
        assert_eq!(buffer.len(), 16);
        assert_eq!(buffer[0], 0.0);
    }

    #[test]
    fn test_utils() {
        use utils::*;

        // Length validation.
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let c = [7.0_f32, 8.0];
        assert!(check_compatible_lengths(&a, &b).is_ok());
        assert!(check_compatible_lengths(&a, &c).is_err());

        // Emptiness validation.
        let empty: [f32; 0] = [];
        assert!(check_not_empty(&empty).is_err());
        assert!(check_not_empty(&a).is_ok());

        // Nine elements split into two chunks of four plus one leftover.
        let nine: Vec<f32> = (1..=9).map(|n| n as f32).collect();
        let (chunks, remainder) = chunk_vector(&nine, 4);
        assert_eq!(chunks.len(), 8);
        assert_eq!(remainder.len(), 1);

        // Angle conversions round-trip through π.
        assert!((degrees_to_radians(180.0) - constants::PI_F32).abs() < f32::EPSILON);
        assert!((radians_to_degrees(constants::PI_F32) - 180.0).abs() < f32::EPSILON);

        // Division saturates to a signed infinity for a zero denominator.
        assert_eq!(safe_divide(10.0, 2.0), 5.0);
        assert_eq!(safe_divide(10.0, 0.0), f32::INFINITY);
        assert_eq!(safe_divide(-10.0, 0.0), f32::NEG_INFINITY);

        // Clamping inside, below and above the range.
        assert_eq!(clamp(5.0, 1.0, 10.0), 5.0);
        assert_eq!(clamp(-5.0, 1.0, 10.0), 1.0);
        assert_eq!(clamp(15.0, 1.0, 10.0), 10.0);
    }

    #[test]
    fn test_accuracy_grading() {
        use accuracy::AccuracyGrade;

        let best = AccuracyGrade::A;
        let worst = AccuracyGrade::F;

        assert_ne!(best, worst);
        assert_eq!(best, AccuracyGrade::A);
    }
}
612
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod integration_tests {
    //! Smoke tests that exercise the module's public surface end to end. The whole
    //! module only builds without `no-std`, so it needs no `alloc` imports.

    use super::*;

    #[test]
    fn test_basic_workflow() {
        let ascending: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let descending: [f32; 8] = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        assert_eq!(ascending.len(), descending.len());
        assert_eq!(ascending.len(), 8);
    }

    #[test]
    fn test_platform_optimization_paths() {
        let info = detect_platform_info();

        println!("SIMD capabilities: {:?}", info.capabilities);
        println!("Platform name: {}", info.capabilities.platform_name());
        println!("Optimal chunk size: {}", info.optimal_chunk_size);
        println!("Recommended alignment: {}", info.recommended_alignment);

        assert!(info.optimal_chunk_size >= 1);
        assert!(info.recommended_alignment >= 4);
    }
}
658}