Skip to main content

sklears_simd/
vector.rs

1//! # SIMD Vector Operations Framework
2//!
3//! High-performance SIMD-optimized vector operations library providing comprehensive
4//! mathematical operations with automatic platform detection and optimization.
5//!
6//! ## Features
7//!
8//! - **Multi-Platform SIMD**: Automatic detection and use of SSE2, AVX2, AVX512, NEON
9//! - **Comprehensive Operations**: Basic arithmetic, advanced math, statistics, comparisons
10//! - **Performance Optimized**: Hand-tuned intrinsics with scalar fallbacks
11//! - **Type Safe**: Compile-time platform feature detection
12//! - **No-std Compatible**: Supports embedded and constrained environments
13//! - **Extensive Testing**: Comprehensive test coverage with accuracy verification
14//!
15//! ## Architecture
16//!
17//! ```text
18//! SIMD Vector Framework
19//! ├── Basic Operations (dot product, norms, fundamentals)
20//! ├── Arithmetic Operations (add, multiply, FMA, element-wise)
21//! ├── Comparison Operations (min/max, logical operations)
22//! ├── Math Functions (trigonometric, exponential functions)
23//! ├── Statistics Operations (mean, histogram, quantile)
24//! └── Platform Intrinsics (SSE2, AVX2, AVX512, NEON)
25//! ```
26//!
27//! ## Usage Examples
28//!
29//! ```rust
30//! use sklears_simd::vector::{dot_product, norm_l2, add_vec};
31//!
32//! // Basic vector operations
33//! let a = vec![1.0, 2.0, 3.0, 4.0];
34//! let b = vec![5.0, 6.0, 7.0, 8.0];
35//!
36//! // SIMD-optimized dot product
37//! let dot = dot_product(&a, &b);
38//!
39//! // L2 norm computation
40//! let norm = norm_l2(&a);
41//!
42//! // Element-wise addition
43//! let mut result = vec![0.0; a.len()];
44//! add_vec(&a, &b, &mut result);
45//! ```
46
47pub mod arithmetic_ops;
48pub mod basic_operations;
49pub mod comparison_ops;
50pub mod intrinsics;
51pub mod math_functions;
52pub mod statistics_ops;
53
54#[allow(non_snake_case)]
55#[cfg(all(test, not(feature = "no-std")))]
56pub mod integration_test;
57
58// Re-export all public functions for unified API
59pub use arithmetic_ops::{
60    abs_vec, add_vec, divide_vec, fma, multiply_vec, neg_vec, reciprocal_vec, scale_vec,
61    square_vec, subtract_vec,
62};
63pub use basic_operations::{cosine_similarity, dot_product, euclidean_distance, norm_l1, norm_l2};
64pub use comparison_ops::{
65    and_vec, eq_vec, ge_vec, gt_vec, le_vec, lt_vec, ne_vec, not_vec, or_vec, xor_vec,
66};
67pub use intrinsics::{
68    detect_simd_capabilities, optimal_chunk_size, simd_width_f32, F32x4, SimdCapabilities,
69};
70pub use math_functions::{cos_vec, exp_vec, ln_vec, pow_vec, sin_vec, sqrt_vec, tan_vec};
71pub use statistics_ops::{
72    dot_product as stats_dot_product, max_vec, mean_vec, min_max_vec, min_vec,
73    norm_l1 as stats_norm_l1, norm_l2 as stats_norm_l2, norm_l2_squared, product_vec, std_dev_vec,
74    sum_vec, variance_vec,
75};
76
77// Export sum for activation module
78pub use statistics_ops::sum_vec as sum;
79
80// Additional exports for other modules
81pub use arithmetic_ops::scale_vec_inplace as scale;
82pub use statistics_ops::mean_vec as mean;
83pub use statistics_ops::{min_max_vec as min_max, variance_vec as variance};
84// lt_vec and simd_width_f32 are already exported above
85pub use basic_operations::norm_l2 as norm;
86
87// Function aliases for compatibility
88pub use arithmetic_ops::add_vec as add_simd;
89pub use arithmetic_ops::fma as fma_simd;
90
91// Constants are already defined in the main constants module below
92
93// Advanced operations that use combinations of basic operations
94pub use basic_operations::{cross_product, outer_product};
95
96#[cfg(feature = "no-std")]
97use alloc::vec;
98#[cfg(feature = "no-std")]
99use alloc::vec::Vec;
100#[cfg(not(feature = "no-std"))]
101use std::vec::Vec;
102
103// Import f32 constants conditionally
104#[cfg(feature = "no-std")]
105use core::f32::consts;
106#[cfg(not(feature = "no-std"))]
107use std::f32::consts;
108
/// Tunable settings consulted by the SIMD vector operations.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// Whether operations may fall back to scalar code paths.
    pub enable_scalar_fallback: bool,
    /// Smallest vector length (in elements) for which SIMD paths are used.
    pub simd_threshold: usize,
    /// Whether approximation functions should run accuracy checks.
    pub enable_accuracy_checks: bool,
}

impl Default for SimdConfig {
    /// Builds the default configuration: scalar fallback enabled, a 16-element SIMD
    /// threshold, and accuracy checks switched on only when `debug_assertions` is set.
    fn default() -> Self {
        // Accuracy checks follow the build profile rather than a fixed value.
        let enable_accuracy_checks = cfg!(debug_assertions);

        Self {
            enable_scalar_fallback: true,
            simd_threshold: 16,
            enable_accuracy_checks,
        }
    }
}
129
130// Global SIMD configuration (thread-local for thread safety)
131#[cfg(not(feature = "no-std"))]
132thread_local! {
133    static SIMD_CONFIG: std::cell::RefCell<SimdConfig> = std::cell::RefCell::new(SimdConfig::default());
134}
135
136#[cfg(feature = "no-std")]
137static mut SIMD_CONFIG: Option<SimdConfig> = None;
138
139/// Set global SIMD configuration
140pub fn set_simd_config(config: SimdConfig) {
141    #[cfg(not(feature = "no-std"))]
142    {
143        SIMD_CONFIG.with(|c| *c.borrow_mut() = config);
144    }
145    #[cfg(feature = "no-std")]
146    {
147        unsafe {
148            SIMD_CONFIG = Some(config);
149        }
150    }
151}
152
153/// Get current SIMD configuration
154pub fn get_simd_config() -> SimdConfig {
155    #[cfg(not(feature = "no-std"))]
156    {
157        SIMD_CONFIG.with(|c| c.borrow().clone())
158    }
159    #[cfg(feature = "no-std")]
160    {
161        unsafe { core::ptr::addr_of!(SIMD_CONFIG).read().unwrap_or_default() }
162    }
163}
164
165/// Platform-specific feature detection and optimization
166pub struct PlatformInfo {
167    /// Detected SIMD capabilities
168    pub capabilities: SimdCapabilities,
169    /// Optimal chunk size for current platform
170    pub optimal_chunk_size: usize,
171    /// Recommended alignment for vectors
172    pub recommended_alignment: usize,
173}
174
175/// Detect platform capabilities and optimization parameters
176pub fn detect_platform_info() -> PlatformInfo {
177    let capabilities = detect_simd_capabilities();
178    let optimal_chunk_size = optimal_chunk_size(1000, None); // Default array length for estimation
179    let recommended_alignment = intrinsics::preferred_alignment_f32();
180
181    PlatformInfo {
182        capabilities,
183        optimal_chunk_size,
184        recommended_alignment,
185    }
186}
187
/// Allocate a zero-filled `Vec<f32>` of `size` elements for use with SIMD operations.
///
/// The `_alignment` argument is currently ignored: the buffer is an ordinary `Vec`, so it
/// carries only the allocator's default alignment for `f32` data. A dedicated aligned
/// allocation is left for a fuller implementation.
pub fn allocate_aligned_vec(size: usize, _alignment: usize) -> Vec<f32> {
    vec![0.0_f32; size]
}
194
/// Report whether the first element of `slice` sits at an address that is a multiple of
/// `alignment`.
///
/// An `alignment` of zero yields `true` only for address zero, following
/// `usize::is_multiple_of`.
pub fn is_properly_aligned(slice: &[f32], alignment: usize) -> bool {
    let address = slice.as_ptr() as usize;
    address.is_multiple_of(alignment)
}
199
200/// Performance benchmarking utilities for SIMD operations
201#[cfg(not(feature = "no-std"))]
202pub mod benchmarks {
203    use super::*;
204    use std::time::{Duration, Instant};
205
206    /// Benchmark result for SIMD operations
207    #[derive(Debug, Clone)]
208    pub struct BenchmarkResult {
209        /// Operation name
210        pub operation: String,
211        /// Total execution time
212        pub duration: Duration,
213        /// Operations per second
214        pub ops_per_sec: f64,
215        /// Throughput in elements per second
216        pub elements_per_sec: f64,
217        /// SIMD platform used
218        pub platform: String,
219    }
220
221    /// Benchmark a vector operation
222    pub fn benchmark_operation<F>(
223        name: &str,
224        vector_size: usize,
225        iterations: usize,
226        operation: F,
227    ) -> BenchmarkResult
228    where
229        F: Fn(),
230    {
231        // Warmup
232        for _ in 0..10 {
233            operation();
234        }
235
236        let start = Instant::now();
237        for _ in 0..iterations {
238            operation();
239        }
240        let duration = start.elapsed();
241
242        let platform_info = detect_platform_info();
243        let platform_name = platform_info.capabilities.platform_name();
244
245        BenchmarkResult {
246            operation: name.to_string(),
247            duration,
248            ops_per_sec: iterations as f64 / duration.as_secs_f64(),
249            elements_per_sec: (iterations * vector_size) as f64 / duration.as_secs_f64(),
250            platform: platform_name.to_string(),
251        }
252    }
253
254    /// Compare performance across different vector sizes
255    pub fn benchmark_scaling<F>(
256        name: &str,
257        sizes: &[usize],
258        iterations: usize,
259        operation_factory: F,
260    ) -> Vec<BenchmarkResult>
261    where
262        F: Fn(usize) -> Box<dyn Fn()>,
263    {
264        sizes
265            .iter()
266            .map(|&size| {
267                let operation = operation_factory(size);
268                benchmark_operation(name, size, iterations, operation)
269            })
270            .collect()
271    }
272}
273
/// Accuracy verification utilities for approximation functions
pub mod accuracy {
    use super::*;

    /// Error statistics of an approximation measured against a reference function.
    #[derive(Debug, Clone)]
    pub struct AccuracyResult {
        /// Largest absolute error over all test points (NaN if any error was NaN).
        pub max_abs_error: f32,
        /// Root mean square of the errors (0.0 when there were no test points).
        pub rms_error: f32,
        /// Mean of the absolute errors (0.0 when there were no test points).
        pub mean_abs_error: f32,
        /// Number of test points evaluated.
        pub test_points: usize,
        /// Grade derived from `max_abs_error`.
        pub grade: AccuracyGrade,
    }

    /// Letter grade assigned from the maximum absolute error.
    #[derive(Debug, Clone, PartialEq)]
    pub enum AccuracyGrade {
        A, // Excellent (< 1e-6)
        B, // Very Good (< 1e-5)
        C, // Good (< 1e-4)
        D, // Acceptable (< 1e-3)
        F, // Poor (>= 1e-3, or NaN)
    }

    /// Map a maximum absolute error onto its grade; NaN fails every threshold and yields `F`.
    fn grade_for(max_abs_error: f32) -> AccuracyGrade {
        if max_abs_error < 1e-6 {
            AccuracyGrade::A
        } else if max_abs_error < 1e-5 {
            AccuracyGrade::B
        } else if max_abs_error < 1e-4 {
            AccuracyGrade::C
        } else if max_abs_error < 1e-3 {
            AccuracyGrade::D
        } else {
            AccuracyGrade::F
        }
    }

    /// Measure how closely `approximation` matches `reference` on `test_inputs`.
    ///
    /// # Arguments
    ///
    /// * `approximation` - batch function; receives all inputs and an output slice of the
    ///   same length (pre-filled with zeros) that it is expected to fill.
    /// * `reference` - scalar reference implementation, called once per input in order.
    /// * `test_inputs` - the points at which both functions are evaluated.
    ///
    /// # Returns
    ///
    /// An [`AccuracyResult`] holding the maximum, mean and RMS error plus a grade. An empty
    /// input produces zero errors and grade `A`. If any error is NaN (for example because the
    /// approximation returned NaN), `max_abs_error` is NaN and the grade is `F`.
    pub fn test_accuracy<F, R>(
        approximation: F,
        reference: R,
        test_inputs: &[f32],
    ) -> AccuracyResult
    where
        F: Fn(&[f32], &mut [f32]),
        R: Fn(f32) -> f32,
    {
        let mut approx_results = vec![0.0; test_inputs.len()];
        approximation(test_inputs, &mut approx_results);

        let mut max_abs_error = 0.0f32;
        let mut abs_error_sum = 0.0f32;
        let mut squared_error_sum = 0.0f32;

        for (&input, &approx) in test_inputs.iter().zip(&approx_results) {
            let error = approx - reference(input);
            let abs_error = error.abs();

            // `f32::max` would silently discard a NaN operand and hide a broken
            // approximation, so NaN is propagated explicitly. Once the running maximum is
            // NaN every later comparison is false and it stays NaN.
            if abs_error.is_nan() || abs_error > max_abs_error {
                max_abs_error = abs_error;
            }
            abs_error_sum += abs_error;
            squared_error_sum += error * error;
        }

        // Avoid 0/0 = NaN when no test points were supplied.
        let (mean_abs_error, rms_error) = if test_inputs.is_empty() {
            (0.0, 0.0)
        } else {
            let count = test_inputs.len() as f32;
            (abs_error_sum / count, (squared_error_sum / count).sqrt())
        };

        AccuracyResult {
            max_abs_error,
            rms_error,
            mean_abs_error,
            test_points: test_inputs.len(),
            grade: grade_for(max_abs_error),
        }
    }

    /// Generate sorted, de-duplicated test inputs for mathematical functions.
    ///
    /// # Arguments
    ///
    /// * `range_start`, `range_end` - bounds of the sampled interval.
    /// * `num_points` - number of evenly spaced samples, taken at
    ///   `range_start + i * (range_end - range_start) / num_points` for `i` in
    ///   `0..num_points`; `range_end` itself is therefore not one of the regular samples.
    /// * `include_special_values` - additionally include those of a fixed set of special
    ///   values (zero, multiples of pi, e, small magnitudes and their negations) that lie
    ///   inside `[range_start, range_end]`.
    ///
    /// # Returns
    ///
    /// The inputs in ascending order with equal neighbours removed. Sorting uses the IEEE
    /// total order, so NaN samples (possible when a bound is NaN or infinite) do not cause a
    /// panic; they end up at the extremes of the result.
    pub fn generate_test_inputs(
        range_start: f32,
        range_end: f32,
        num_points: usize,
        include_special_values: bool,
    ) -> Vec<f32> {
        // Special values that often cause accuracy issues.
        const SPECIAL_VALUES: &[f32] = &[
            0.0,
            -0.0,
            consts::PI,
            -consts::PI,
            consts::PI / 2.0,
            -consts::PI / 2.0,
            consts::PI / 4.0,
            -consts::PI / 4.0,
            consts::E,
            -consts::E,
            1.0,
            -1.0,
            2.0,
            -2.0,
            10.0,
            -10.0,
            0.1,
            -0.1,
            0.001,
            -0.001,
            1e-6,
            -1e-6,
        ];

        let mut inputs = Vec::with_capacity(num_points + SPECIAL_VALUES.len());

        // Regular sampling across the range.
        let step = (range_end - range_start) / (num_points as f32);
        inputs.extend((0..num_points).map(|i| range_start + i as f32 * step));

        if include_special_values {
            inputs.extend(
                SPECIAL_VALUES
                    .iter()
                    .copied()
                    .filter(|&value| value >= range_start && value <= range_end),
            );
        }

        // `total_cmp` is a total order over all f32 values, so sorting cannot panic on NaN.
        inputs.sort_by(f32::total_cmp);
        inputs.dedup();
        inputs
    }
}
403
404/// Utility functions for vector operations
405pub mod utils {
406    use super::*;
407
408    /// Check if two vectors have compatible lengths for binary operations
409    pub fn check_compatible_lengths(a: &[f32], b: &[f32]) -> Result<(), &'static str> {
410        if a.len() != b.len() {
411            Err("Vectors must have the same length")
412        } else {
413            Ok(())
414        }
415    }
416
417    /// Check if input and output vectors have compatible lengths
418    pub fn check_io_lengths(input: &[f32], output: &[f32]) -> Result<(), &'static str> {
419        check_compatible_lengths(input, output)
420    }
421
422    /// Validate that vector is not empty
423    pub fn check_not_empty(vec: &[f32]) -> Result<(), &'static str> {
424        if vec.is_empty() {
425            Err("Vector cannot be empty")
426        } else {
427            Ok(())
428        }
429    }
430
431    /// Get the optimal chunk size for current platform
432    pub fn get_platform_chunk_size() -> usize {
433        detect_platform_info().optimal_chunk_size
434    }
435
436    /// Split vector into SIMD-friendly chunks
437    pub fn chunk_vector(vec: &[f32], chunk_size: usize) -> (&[f32], &[f32]) {
438        let simd_len = (vec.len() / chunk_size) * chunk_size;
439        vec.split_at(simd_len)
440    }
441
442    /// Process vector in chunks with remainder handling
443    pub fn process_chunks<F, R>(
444        vec: &[f32],
445        chunk_size: usize,
446        mut chunk_processor: F,
447        mut remainder_processor: R,
448    ) where
449        F: FnMut(&[f32]),
450        R: FnMut(&[f32]),
451    {
452        let (chunks, remainder) = chunk_vector(vec, chunk_size);
453
454        for chunk in chunks.chunks_exact(chunk_size) {
455            chunk_processor(chunk);
456        }
457
458        if !remainder.is_empty() {
459            remainder_processor(remainder);
460        }
461    }
462
463    /// Convert degrees to radians
464    pub fn degrees_to_radians(degrees: f32) -> f32 {
465        degrees * consts::PI / 180.0
466    }
467
468    /// Convert radians to degrees
469    pub fn radians_to_degrees(radians: f32) -> f32 {
470        radians * 180.0 / consts::PI
471    }
472
473    /// Safe division with zero handling
474    pub fn safe_divide(numerator: f32, denominator: f32) -> f32 {
475        if denominator.abs() < f32::EPSILON {
476            if numerator >= 0.0 {
477                f32::INFINITY
478            } else {
479                f32::NEG_INFINITY
480            }
481        } else {
482            numerator / denominator
483        }
484    }
485
486    /// Clamp value to range [min, max]
487    pub fn clamp(value: f32, min: f32, max: f32) -> f32 {
488        if value < min {
489            min
490        } else if value > max {
491            max
492        } else {
493            value
494        }
495    }
496}
497
/// Commonly used mathematical and SIMD layout constants.
pub mod constants {
    #[cfg(feature = "no-std")]
    use core::f32::consts;
    #[cfg(not(feature = "no-std"))]
    use std::f32::consts;

    // --- Mathematical constants (f32) ---

    /// Pi as `f32`.
    pub const PI_F32: f32 = consts::PI;
    /// Euler's number as `f32`.
    pub const E_F32: f32 = consts::E;
    /// Natural logarithm of 2 as `f32`.
    pub const LN_2_F32: f32 = consts::LN_2;
    /// Natural logarithm of 10 as `f32`.
    pub const LN_10_F32: f32 = consts::LN_10;
    /// Square root of 2 as `f32`.
    pub const SQRT_2_F32: f32 = consts::SQRT_2;

    // --- Number of f32 lanes per SIMD register (register width / 32 bits) ---

    /// SSE2: 128-bit registers hold 4 floats.
    pub const SSE2_VECTOR_SIZE: usize = 4;
    /// AVX2: 256-bit registers hold 8 floats.
    pub const AVX2_VECTOR_SIZE: usize = 8;
    /// AVX512: 512-bit registers hold 16 floats.
    pub const AVX512_VECTOR_SIZE: usize = 16;
    /// NEON: 128-bit registers hold 4 floats.
    pub const NEON_VECTOR_SIZE: usize = 4;

    // --- Alignment in bytes matching each register width ---

    /// SSE2: 128-bit alignment.
    pub const SSE2_ALIGNMENT: usize = 16;
    /// AVX2: 256-bit alignment.
    pub const AVX2_ALIGNMENT: usize = 32;
    /// AVX512: 512-bit alignment.
    pub const AVX512_ALIGNMENT: usize = 64;
    /// NEON: 128-bit alignment.
    pub const NEON_ALIGNMENT: usize = 16;
}
524
// Unit tests for the configuration, platform detection and helper modules of this file.
// The module is compiled only for std test builds, so no `alloc` imports are needed here.
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// A configuration that was set can be read back unchanged.
    #[test]
    fn test_simd_config() {
        let config = SimdConfig::default();
        set_simd_config(config.clone());

        let retrieved_config = get_simd_config();
        assert_eq!(retrieved_config.simd_threshold, config.simd_threshold);
        assert_eq!(
            retrieved_config.enable_scalar_fallback,
            config.enable_scalar_fallback
        );
        assert_eq!(
            retrieved_config.enable_accuracy_checks,
            config.enable_accuracy_checks
        );
    }

    /// Platform detection reports plausible tuning parameters.
    #[test]
    fn test_platform_detection() {
        let platform_info = detect_platform_info();
        assert!(platform_info.optimal_chunk_size >= 4);
        assert!(platform_info.recommended_alignment >= 4);

        // Test that capabilities are detected
        let caps = platform_info.capabilities;
        println!("SIMD Capabilities: {:?}", caps);
    }

    /// The allocation helper returns a zero-filled vector of the requested length.
    #[test]
    fn test_aligned_allocation() {
        let vec = allocate_aligned_vec(16, 16);
        assert_eq!(vec.len(), 16);
        assert_eq!(vec[0], 0.0);
    }

    /// Exercises the validation, chunking and scalar helpers in `utils`.
    #[test]
    fn test_utils() {
        use utils::*;

        // Test length compatibility
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0];

        assert!(check_compatible_lengths(&a, &b).is_ok());
        assert!(check_compatible_lengths(&a, &c).is_err());

        // Test empty check
        let empty_vec: Vec<f32> = vec![];
        assert!(check_not_empty(&empty_vec).is_err());
        assert!(check_not_empty(&a).is_ok());

        // Test chunking
        let vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let (chunks, remainder) = chunk_vector(&vec, 4);
        assert_eq!(chunks.len(), 8); // 2 complete chunks of 4
        assert_eq!(remainder.len(), 1); // 1 remainder element

        // Test mathematical utilities
        assert!((degrees_to_radians(180.0) - constants::PI_F32).abs() < f32::EPSILON);
        assert!((radians_to_degrees(constants::PI_F32) - 180.0).abs() < f32::EPSILON);

        assert_eq!(safe_divide(10.0, 2.0), 5.0);
        assert_eq!(safe_divide(10.0, 0.0), f32::INFINITY);
        assert_eq!(safe_divide(-10.0, 0.0), f32::NEG_INFINITY);

        assert_eq!(clamp(5.0, 1.0, 10.0), 5.0);
        assert_eq!(clamp(-5.0, 1.0, 10.0), 1.0);
        assert_eq!(clamp(15.0, 1.0, 10.0), 10.0);
    }

    /// The grade follows the maximum absolute error measured by `test_accuracy`.
    #[test]
    fn test_accuracy_grading() {
        use accuracy::{test_accuracy, AccuracyGrade};

        let inputs = [0.0f32, 0.5, 1.0];

        // An approximation identical to the reference has zero error: grade A.
        let exact = test_accuracy(|x, out| out.copy_from_slice(x), |x| x, &inputs);
        assert_eq!(exact.grade, AccuracyGrade::A);
        assert_eq!(exact.test_points, inputs.len());
        assert_eq!(exact.max_abs_error, 0.0);

        // A constant offset of 0.5 is far above the 1e-3 threshold: grade F.
        let poor = test_accuracy(
            |x, out| {
                for (o, &v) in out.iter_mut().zip(x) {
                    *o = v + 0.5;
                }
            },
            |x| x,
            &inputs,
        );
        assert_eq!(poor.grade, AccuracyGrade::F);
        assert_eq!(poor.max_abs_error, 0.5);

        // An error of 5e-5 lies between the B (1e-5) and C (1e-4) thresholds: grade C.
        let fair = test_accuracy(|_x, out| out[0] = 5e-5, |_| 0.0, &[0.0]);
        assert_eq!(fair.grade, AccuracyGrade::C);

        assert!(exact.grade != poor.grade);
    }
}
612
// Integration tests for the combined SIMD vector API.
// The module is compiled only for std test builds, so no `alloc` imports are needed here.
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod integration_tests {
    use super::*;

    /// Placeholder for an end-to-end workflow combining several vector operations.
    #[test]
    fn test_basic_workflow() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        // TODO: exercise `dot_product(&a, &b)`, `norm_l2(&a)` and `mean(&a)` here and assert
        // on their results once their signatures are confirmed against the submodules.

        // Placeholder assertions for now
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), 8);
    }

    /// Prints the detected platform parameters and sanity-checks their lower bounds.
    #[test]
    fn test_platform_optimization_paths() {
        let platform_info = detect_platform_info();
        println!("SIMD capabilities: {:?}", platform_info.capabilities);
        println!(
            "Platform name: {}",
            platform_info.capabilities.platform_name()
        );
        println!("Optimal chunk size: {}", platform_info.optimal_chunk_size);
        println!(
            "Recommended alignment: {}",
            platform_info.recommended_alignment
        );

        // Basic capability testing
        assert!(platform_info.optimal_chunk_size >= 1);
        assert!(platform_info.recommended_alignment >= 4);
    }
}
654        // Basic capability testing
655        assert!(platform_info.optimal_chunk_size >= 1);
656        assert!(platform_info.recommended_alignment >= 4);
657    }
658}