Skip to main content

sklears_simd/
safety.rs

1//! Safety and correctness enhancements for SIMD operations
2//!
3//! This module provides safe wrappers, bounds checking, overflow detection,
4//! and special value handling for all SIMD operations.
5
6#[cfg(not(feature = "no-std"))]
7use std::fmt;
8#[cfg(not(feature = "no-std"))]
9use std::string::ToString;
10
11#[cfg(feature = "no-std")]
12use alloc::string::{String, ToString};
13#[cfg(feature = "no-std")]
14use alloc::vec::Vec;
15#[cfg(feature = "no-std")]
16use alloc::{format, vec};
17#[cfg(feature = "no-std")]
18use core::fmt;
19
/// Enhanced error type for SIMD safety violations.
///
/// Every fallible helper in this module reports failures through this enum,
/// usually via the [`SafeSimdResult`] alias defined below.
#[derive(Debug, Clone, PartialEq)]
pub enum SimdSafetyError {
    /// `index` was not smaller than the `length` of the accessed buffer.
    IndexOutOfBounds { index: usize, length: usize },
    /// Two slices that must have equal lengths did not.
    InvalidSliceLength { expected: usize, actual: usize },
    /// An operation on finite operands produced a non-finite result.
    /// `operation` names the operation and `values` holds the operands
    /// (widened to `f64`).
    ArithmeticOverflow { operation: String, values: Vec<f64> },
    /// A value was NaN or infinite; `reason` describes which (and possibly where).
    InvalidFloatingPoint { value: f64, reason: String },
    /// A divisor was zero.
    DivisionByZero,
    /// The square root of a negative `value` was requested.
    NegativeSquareRoot { value: f64 },
    /// A range or domain check failed. The meaning of `start`/`end` depends on
    /// the reporting function: e.g. slice bounds where `start > end`, or a
    /// rejected value paired with the upper end of the valid domain.
    InvalidRange { start: f64, end: f64 },
    /// Fewer items were available than the operation requires.
    InsufficientData { required: usize, available: usize },
}

impl fmt::Display for SimdSafetyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IndexOutOfBounds { index, length } => {
                write!(f, "Index {} out of bounds for length {}", index, length)
            }
            Self::InvalidSliceLength { expected, actual } => {
                write!(
                    f,
                    "Invalid slice length: expected {}, got {}",
                    expected, actual
                )
            }
            Self::ArithmeticOverflow { operation, values } => {
                write!(
                    f,
                    "Arithmetic overflow in operation '{}' with values: {:?}",
                    operation, values
                )
            }
            Self::InvalidFloatingPoint { value, reason } => {
                write!(f, "Invalid floating point value {}: {}", value, reason)
            }
            Self::DivisionByZero => write!(f, "Division by zero"),
            Self::NegativeSquareRoot { value } => {
                write!(f, "Square root of negative number: {}", value)
            }
            // `InvalidRange` is also used for domain errors where `start < end`
            // (e.g. `ln` of a non-positive value is reported as `(value, +inf)`),
            // so the message must not claim any ordering between the bounds.
            Self::InvalidRange { start, end } => {
                write!(f, "Invalid range: start {}, end {}", start, end)
            }
            Self::InsufficientData {
                required,
                available,
            } => {
                write!(
                    f,
                    "Insufficient data: required {}, available {}",
                    required, available
                )
            }
        }
    }
}

#[cfg(not(feature = "no-std"))]
impl std::error::Error for SimdSafetyError {}

#[cfg(feature = "no-std")]
impl core::error::Error for SimdSafetyError {}

/// Result type returned by every fallible helper in this module.
pub type SafeSimdResult<T> = Result<T, SimdSafetyError>;
86
87/// Safe SIMD vector operations with comprehensive bounds checking
88pub struct SafeSimdOps;
89
90impl SafeSimdOps {
91    /// Safely validate floating point values
92    pub fn validate_f32(value: f32) -> SafeSimdResult<f32> {
93        if value.is_nan() {
94            Err(SimdSafetyError::InvalidFloatingPoint {
95                value: value as f64,
96                reason: "NaN (Not a Number)".to_string(),
97            })
98        } else if value.is_infinite() {
99            Err(SimdSafetyError::InvalidFloatingPoint {
100                value: value as f64,
101                reason: "Infinity".to_string(),
102            })
103        } else {
104            Ok(value)
105        }
106    }
107
108    /// Safely validate floating point values
109    pub fn validate_f64(value: f64) -> SafeSimdResult<f64> {
110        if value.is_nan() {
111            Err(SimdSafetyError::InvalidFloatingPoint {
112                value,
113                reason: "NaN (Not a Number)".to_string(),
114            })
115        } else if value.is_infinite() {
116            Err(SimdSafetyError::InvalidFloatingPoint {
117                value,
118                reason: "Infinity".to_string(),
119            })
120        } else {
121            Ok(value)
122        }
123    }
124
125    /// Validate an entire slice of f32 values
126    pub fn validate_f32_slice(values: &[f32]) -> SafeSimdResult<()> {
127        for (i, &value) in values.iter().enumerate() {
128            Self::validate_f32(value).map_err(|e| match e {
129                SimdSafetyError::InvalidFloatingPoint { value, reason } => {
130                    SimdSafetyError::InvalidFloatingPoint {
131                        value,
132                        reason: format!("at index {}: {}", i, reason),
133                    }
134                }
135                other => other,
136            })?;
137        }
138        Ok(())
139    }
140
141    /// Validate an entire slice of f64 values
142    pub fn validate_f64_slice(values: &[f64]) -> SafeSimdResult<()> {
143        for (i, &value) in values.iter().enumerate() {
144            Self::validate_f64(value).map_err(|e| match e {
145                SimdSafetyError::InvalidFloatingPoint { value, reason } => {
146                    SimdSafetyError::InvalidFloatingPoint {
147                        value,
148                        reason: format!("at index {}: {}", i, reason),
149                    }
150                }
151                other => other,
152            })?;
153        }
154        Ok(())
155    }
156
157    /// Safe addition with overflow detection
158    pub fn safe_add_f32(a: f32, b: f32) -> SafeSimdResult<f32> {
159        Self::validate_f32(a)?;
160        Self::validate_f32(b)?;
161
162        let result = a + b;
163        Self::validate_f32(result).map_err(|_| SimdSafetyError::ArithmeticOverflow {
164            operation: "addition".to_string(),
165            values: vec![a as f64, b as f64],
166        })
167    }
168
169    /// Safe subtraction with overflow detection
170    pub fn safe_sub_f32(a: f32, b: f32) -> SafeSimdResult<f32> {
171        Self::validate_f32(a)?;
172        Self::validate_f32(b)?;
173
174        let result = a - b;
175        Self::validate_f32(result).map_err(|_| SimdSafetyError::ArithmeticOverflow {
176            operation: "subtraction".to_string(),
177            values: vec![a as f64, b as f64],
178        })
179    }
180
181    /// Safe multiplication with overflow detection
182    pub fn safe_mul_f32(a: f32, b: f32) -> SafeSimdResult<f32> {
183        Self::validate_f32(a)?;
184        Self::validate_f32(b)?;
185
186        let result = a * b;
187        Self::validate_f32(result).map_err(|_| SimdSafetyError::ArithmeticOverflow {
188            operation: "multiplication".to_string(),
189            values: vec![a as f64, b as f64],
190        })
191    }
192
193    /// Safe division with zero and overflow checking
194    pub fn safe_div_f32(a: f32, b: f32) -> SafeSimdResult<f32> {
195        Self::validate_f32(a)?;
196        Self::validate_f32(b)?;
197
198        if b == 0.0 {
199            return Err(SimdSafetyError::DivisionByZero);
200        }
201
202        let result = a / b;
203        Self::validate_f32(result).map_err(|_| SimdSafetyError::ArithmeticOverflow {
204            operation: "division".to_string(),
205            values: vec![a as f64, b as f64],
206        })
207    }
208
209    /// Safe square root with negative number checking
210    pub fn safe_sqrt_f32(value: f32) -> SafeSimdResult<f32> {
211        Self::validate_f32(value)?;
212
213        if value < 0.0 {
214            return Err(SimdSafetyError::NegativeSquareRoot {
215                value: value as f64,
216            });
217        }
218
219        let result = value.sqrt();
220        Self::validate_f32(result)
221    }
222
223    /// Safe logarithm with domain checking
224    pub fn safe_ln_f32(value: f32) -> SafeSimdResult<f32> {
225        Self::validate_f32(value)?;
226
227        if value <= 0.0 {
228            return Err(SimdSafetyError::InvalidRange {
229                start: value as f64,
230                end: f64::INFINITY,
231            });
232        }
233
234        let result = value.ln();
235        Self::validate_f32(result)
236    }
237
238    /// Safe exponential with overflow checking
239    pub fn safe_exp_f32(value: f32) -> SafeSimdResult<f32> {
240        Self::validate_f32(value)?;
241
242        // Check for potential overflow before computing
243        if value > 88.0 {
244            // exp(88) ≈ 1.6e38, close to f32::MAX
245            return Err(SimdSafetyError::ArithmeticOverflow {
246                operation: "exponential".to_string(),
247                values: vec![value as f64],
248            });
249        }
250
251        let result = value.exp();
252        Self::validate_f32(result)
253    }
254
255    /// Safe vector dot product with bounds checking
256    pub fn safe_dot_product_f32(a: &[f32], b: &[f32]) -> SafeSimdResult<f32> {
257        if a.len() != b.len() {
258            return Err(SimdSafetyError::InvalidSliceLength {
259                expected: a.len(),
260                actual: b.len(),
261            });
262        }
263
264        Self::validate_f32_slice(a)?;
265        Self::validate_f32_slice(b)?;
266
267        let mut result = 0.0f32;
268        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
269            let product = Self::safe_mul_f32(x, y).map_err(|e| match e {
270                SimdSafetyError::ArithmeticOverflow { operation, values } => {
271                    SimdSafetyError::ArithmeticOverflow {
272                        operation: format!("{} at index {}", operation, i),
273                        values,
274                    }
275                }
276                other => other,
277            })?;
278
279            result = Self::safe_add_f32(result, product).map_err(|e| match e {
280                SimdSafetyError::ArithmeticOverflow { operation, values } => {
281                    SimdSafetyError::ArithmeticOverflow {
282                        operation: format!(
283                            "{} in dot product accumulation at index {}",
284                            operation, i
285                        ),
286                        values,
287                    }
288                }
289                other => other,
290            })?;
291        }
292
293        Ok(result)
294    }
295
296    /// Safe vector normalization
297    pub fn safe_normalize_f32(vector: &[f32]) -> SafeSimdResult<Vec<f32>> {
298        if vector.is_empty() {
299            return Err(SimdSafetyError::InsufficientData {
300                required: 1,
301                available: 0,
302            });
303        }
304
305        Self::validate_f32_slice(vector)?;
306
307        let dot_product = Self::safe_dot_product_f32(vector, vector)?;
308        let norm = Self::safe_sqrt_f32(dot_product)?;
309
310        if norm == 0.0 {
311            return Err(SimdSafetyError::DivisionByZero);
312        }
313
314        let mut normalized = Vec::with_capacity(vector.len());
315        for &value in vector {
316            let normalized_value = Self::safe_div_f32(value, norm)?;
317            normalized.push(normalized_value);
318        }
319
320        Ok(normalized)
321    }
322
323    /// Safe array indexing with bounds checking
324    pub fn safe_get<T>(slice: &[T], index: usize) -> SafeSimdResult<&T> {
325        if index >= slice.len() {
326            Err(SimdSafetyError::IndexOutOfBounds {
327                index,
328                length: slice.len(),
329            })
330        } else {
331            Ok(&slice[index])
332        }
333    }
334
335    /// Safe mutable array indexing with bounds checking
336    pub fn safe_get_mut<T>(slice: &mut [T], index: usize) -> SafeSimdResult<&mut T> {
337        let length = slice.len();
338        if index >= length {
339            Err(SimdSafetyError::IndexOutOfBounds { index, length })
340        } else {
341            Ok(&mut slice[index])
342        }
343    }
344
345    /// Safe slice creation with bounds checking
346    pub fn safe_slice<T>(slice: &[T], start: usize, end: usize) -> SafeSimdResult<&[T]> {
347        if start > end {
348            return Err(SimdSafetyError::InvalidRange {
349                start: start as f64,
350                end: end as f64,
351            });
352        }
353
354        if end > slice.len() {
355            return Err(SimdSafetyError::IndexOutOfBounds {
356                index: end,
357                length: slice.len(),
358            });
359        }
360
361        Ok(&slice[start..end])
362    }
363
364    /// Check if all values in slice are finite (not NaN or infinite)
365    pub fn all_finite_f32(values: &[f32]) -> bool {
366        values.iter().all(|&x| x.is_finite())
367    }
368
369    /// Check if all values in slice are finite (not NaN or infinite)  
370    pub fn all_finite_f64(values: &[f64]) -> bool {
371        values.iter().all(|&x| x.is_finite())
372    }
373
374    /// Replace NaN and infinite values with safe alternatives
375    pub fn sanitize_f32_slice(values: &mut [f32], nan_replacement: f32, inf_replacement: f32) {
376        for value in values.iter_mut() {
377            if value.is_nan() {
378                *value = nan_replacement;
379            } else if value.is_infinite() {
380                *value = if value.is_sign_positive() {
381                    inf_replacement
382                } else {
383                    -inf_replacement
384                };
385            }
386        }
387    }
388
389    /// Replace NaN and infinite values with safe alternatives
390    pub fn sanitize_f64_slice(values: &mut [f64], nan_replacement: f64, inf_replacement: f64) {
391        for value in values.iter_mut() {
392            if value.is_nan() {
393                *value = nan_replacement;
394            } else if value.is_infinite() {
395                *value = if value.is_sign_positive() {
396                    inf_replacement
397                } else {
398                    -inf_replacement
399                };
400            }
401        }
402    }
403}
404
405/// Debug mode bounds checking wrapper
406#[derive(Debug, Clone)]
407pub struct DebugBoundsChecker<T> {
408    data: Vec<T>,
409    #[allow(dead_code)] // Identifies the buffer in panic/debug messages
410    name: String,
411}
412
413impl<T> DebugBoundsChecker<T> {
414    pub fn new(data: Vec<T>, name: String) -> Self {
415        Self { data, name }
416    }
417
418    #[cfg(debug_assertions)]
419    pub fn get(&self, index: usize) -> SafeSimdResult<&T> {
420        if index >= self.data.len() {
421            Err(SimdSafetyError::IndexOutOfBounds {
422                index,
423                length: self.data.len(),
424            })
425        } else {
426            Ok(&self.data[index])
427        }
428    }
429
430    #[cfg(not(debug_assertions))]
431    pub fn get(&self, index: usize) -> SafeSimdResult<&T> {
432        Ok(unsafe { self.data.get_unchecked(index) })
433    }
434
435    #[cfg(debug_assertions)]
436    pub fn get_mut(&mut self, index: usize) -> SafeSimdResult<&mut T> {
437        let length = self.data.len();
438        if index >= length {
439            Err(SimdSafetyError::IndexOutOfBounds { index, length })
440        } else {
441            Ok(&mut self.data[index])
442        }
443    }
444
445    #[cfg(not(debug_assertions))]
446    pub fn get_mut(&mut self, index: usize) -> SafeSimdResult<&mut T> {
447        Ok(unsafe { self.data.get_unchecked_mut(index) })
448    }
449
450    pub fn len(&self) -> usize {
451        self.data.len()
452    }
453
454    pub fn is_empty(&self) -> bool {
455        self.data.is_empty()
456    }
457
458    pub fn as_slice(&self) -> &[T] {
459        &self.data
460    }
461
462    pub fn as_mut_slice(&mut self) -> &mut [T] {
463        &mut self.data
464    }
465}
466
467/// Memory safety guarantees for SIMD operations
468pub struct MemorySafetyGuard;
469
470impl MemorySafetyGuard {
471    /// Ensure proper alignment for SIMD operations
472    pub fn check_alignment(ptr: *const u8, alignment: usize) -> bool {
473        (ptr as usize).is_multiple_of(alignment)
474    }
475
476    /// Create aligned vector for SIMD operations
477    pub fn create_aligned_vec<T>(size: usize, alignment: usize) -> Vec<T>
478    where
479        T: Default + Clone,
480    {
481        let mut vec = Vec::with_capacity(size + alignment);
482        vec.resize(size, T::default());
483
484        // Ensure alignment (simplified approach)
485        while !(vec.as_ptr() as usize).is_multiple_of(alignment) {
486            vec.reserve(1);
487        }
488
489        vec
490    }
491
492    /// Validate memory range for SIMD operations
493    pub fn validate_memory_range(ptr: *const u8, size: usize) -> SafeSimdResult<()> {
494        if ptr.is_null() {
495            return Err(SimdSafetyError::InvalidRange {
496                start: 0.0,
497                end: 0.0,
498            });
499        }
500
501        if size == 0 {
502            return Err(SimdSafetyError::InsufficientData {
503                required: 1,
504                available: 0,
505            });
506        }
507
508        Ok(())
509    }
510}
511
/// Unit tests for the safety helpers (std builds only).
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use core::ptr;

    #[test]
    fn test_validate_f32() {
        assert!(SafeSimdOps::validate_f32(1.0).is_ok());
        assert!(SafeSimdOps::validate_f32(-1.0).is_ok());
        assert!(SafeSimdOps::validate_f32(0.0).is_ok());

        assert!(SafeSimdOps::validate_f32(f32::NAN).is_err());
        assert!(SafeSimdOps::validate_f32(f32::INFINITY).is_err());
        assert!(SafeSimdOps::validate_f32(f32::NEG_INFINITY).is_err());
    }

    #[test]
    fn test_safe_arithmetic() {
        assert_eq!(
            SafeSimdOps::safe_add_f32(2.0, 3.0).expect("operation should succeed"),
            5.0
        );
        assert_eq!(
            SafeSimdOps::safe_sub_f32(5.0, 3.0).expect("operation should succeed"),
            2.0
        );
        assert_eq!(
            SafeSimdOps::safe_mul_f32(3.0, 4.0).expect("operation should succeed"),
            12.0
        );
        assert_eq!(
            SafeSimdOps::safe_div_f32(12.0, 4.0).expect("operation should succeed"),
            3.0
        );

        assert!(SafeSimdOps::safe_div_f32(1.0, 0.0).is_err());
        assert!(SafeSimdOps::safe_sqrt_f32(-1.0).is_err());
        assert!(SafeSimdOps::safe_ln_f32(-1.0).is_err());
        assert!(SafeSimdOps::safe_ln_f32(0.0).is_err());
    }

    #[test]
    fn test_safe_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];

        let result = SafeSimdOps::safe_dot_product_f32(&a, &b).expect("operation should succeed");
        assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32

        let c = [1.0f32, 2.0];
        assert!(SafeSimdOps::safe_dot_product_f32(&a, &c).is_err());
    }

    #[test]
    fn test_safe_normalize() {
        let vector = [3.0f32, 4.0];
        let normalized =
            SafeSimdOps::safe_normalize_f32(&vector).expect("operation should succeed");

        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        let zero_vector = [0.0f32, 0.0];
        assert!(SafeSimdOps::safe_normalize_f32(&zero_vector).is_err());

        let empty_vector: [f32; 0] = [];
        assert!(SafeSimdOps::safe_normalize_f32(&empty_vector).is_err());
    }

    #[test]
    fn test_safe_indexing() {
        let data = [1, 2, 3, 4, 5];

        assert_eq!(
            *SafeSimdOps::safe_get(&data, 2).expect("operation should succeed"),
            3
        );
        assert!(SafeSimdOps::safe_get(&data, 10).is_err());

        let slice = SafeSimdOps::safe_slice(&data, 1, 4).expect("slice operation should succeed");
        assert_eq!(slice, &[2, 3, 4]);

        assert!(SafeSimdOps::safe_slice(&data, 4, 1).is_err());
        assert!(SafeSimdOps::safe_slice(&data, 0, 10).is_err());
    }

    #[test]
    fn test_finite_checks() {
        let finite_values = [1.0f32, 2.0, 3.0];
        assert!(SafeSimdOps::all_finite_f32(&finite_values));

        let mixed_values = [1.0f32, f32::NAN, 3.0];
        assert!(!SafeSimdOps::all_finite_f32(&mixed_values));

        let inf_values = [1.0f32, f32::INFINITY, 3.0];
        assert!(!SafeSimdOps::all_finite_f32(&inf_values));
    }

    #[test]
    fn test_sanitize_values() {
        let mut values = [1.0f32, f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 5.0];
        SafeSimdOps::sanitize_f32_slice(&mut values, 0.0, 1000.0);

        assert_eq!(values[0], 1.0);
        assert_eq!(values[1], 0.0); // NaN replaced with 0.0
        assert_eq!(values[2], 1000.0); // +Inf replaced with 1000.0
        assert_eq!(values[3], -1000.0); // -Inf replaced with -1000.0
        assert_eq!(values[4], 5.0);

        assert!(SafeSimdOps::all_finite_f32(&values));
    }

    #[test]
    fn test_debug_bounds_checker() {
        let data = vec![1, 2, 3, 4, 5];
        let checker = DebugBoundsChecker::new(data, "test".to_string());

        assert_eq!(*checker.get(2).expect("index should be valid"), 3);
        assert!(checker.get(10).is_err());
        assert_eq!(checker.len(), 5);
        assert!(!checker.is_empty());
    }

    #[test]
    fn test_memory_safety() {
        let data = [1u8, 2, 3, 4];
        let ptr = data.as_ptr();

        assert!(MemorySafetyGuard::validate_memory_range(ptr, data.len()).is_ok());
        assert!(MemorySafetyGuard::validate_memory_range(ptr::null(), 0).is_err());

        let aligned_vec: Vec<f32> = MemorySafetyGuard::create_aligned_vec(10, 16);
        assert_eq!(aligned_vec.len(), 10);
    }

    #[test]
    fn test_arithmetic_overflow_detection() {
        // Multiplying a value near f32::MAX by 3 overflows to infinity.
        let large_val = f32::MAX / 2.0;
        assert!(SafeSimdOps::safe_mul_f32(large_val, 3.0).is_err());

        // exp(100) is far beyond f32::MAX.
        assert!(SafeSimdOps::safe_exp_f32(100.0).is_err());

        // Ordinary values stay finite.
        assert!(SafeSimdOps::safe_mul_f32(2.0, 3.0).is_ok());
        assert!(SafeSimdOps::safe_exp_f32(1.0).is_ok());
    }

    #[test]
    fn test_error_display() {
        let error = SimdSafetyError::IndexOutOfBounds {
            index: 5,
            length: 3,
        };
        let display_str = format!("{}", error);
        assert!(display_str.contains("Index 5 out of bounds for length 3"));

        let div_error = SimdSafetyError::DivisionByZero;
        assert_eq!(format!("{}", div_error), "Division by zero");
    }
}