amari_gpu/
verification.rs

1//! GPU Verification Framework for Phase 4B
2//!
3//! This module implements boundary verification strategies for GPU-accelerated
4//! geometric algebra operations, addressing the challenge that phantom types
5//! cannot cross GPU memory boundaries while maintaining mathematical correctness.
6
7use crate::{GpuCliffordAlgebra, GpuError};
8use amari_core::Multivector;
9use amari_info_geom::Parameter;
10use std::collections::HashMap;
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13use thiserror::Error;
14
/// Errors produced by the GPU verification layer.
#[derive(Error, Debug)]
pub enum GpuVerificationError {
    /// Generic verification failure carrying a human-readable description
    /// (e.g. mismatched batch sizes or a result component outside tolerance).
    #[error("Verification failed: {0}")]
    VerificationFailed(String),

    /// Expected and actual signature tuples differ (presumably `(P, Q, R)`
    /// as returned by `VerifiedMultivector::signature`).
    #[error("Signature mismatch: expected {expected:?}, got {actual:?}")]
    SignatureMismatch {
        expected: (usize, usize, usize),
        actual: (usize, usize, usize),
    },

    /// Too many sampled results disagreed with the CPU reference during
    /// statistical verification: `failed` out of `total` sampled elements.
    #[error("Statistical verification failed: {failed}/{total} samples failed")]
    StatisticalMismatch { failed: usize, total: usize },

    /// A mathematical invariant (e.g. finite magnitude) does not hold.
    #[error("Mathematical invariant violated: {invariant}")]
    InvariantViolation { invariant: String },

    /// The underlying GPU call failed; `#[from]` lets `?` convert a `GpuError`.
    #[error("GPU operation failed: {0}")]
    GpuOperation(#[from] GpuError),

    /// The verified operation took longer than the configured budget.
    #[error("Performance budget exceeded: {actual:?} > {budget:?}")]
    PerformanceBudgetExceeded { actual: Duration, budget: Duration },
}
38
/// How much of a GPU batch is checked against the CPU reference.
#[derive(Debug, Clone)]
pub enum VerificationStrategy {
    /// Check every input invariant and compare every result with the CPU
    /// geometric product (most expensive).
    Strict,
    /// Check a deterministic subset whose size is governed by `sample_rate`
    /// (expected to be a fraction in `[0, 1]`).
    Statistical { sample_rate: f64 },
    /// Check only the first and last elements of the batch.
    Boundary,
    /// Cheapest mode: boundary input checks and a finiteness check on each
    /// result, with no CPU comparison.
    Minimal,
}

/// Platform-aware verification configuration.
///
/// Bundles the sampling strategy, the wall-clock budget for one verified
/// call, the comparison tolerance, and the switch for input invariant checks.
#[derive(Debug, Clone)]
pub struct VerificationConfig {
    /// Which elements get verified.
    pub strategy: VerificationStrategy,
    /// Maximum wall-clock time one verified operation may take.
    pub performance_budget: Duration,
    /// Comparison tolerance: relative when the reference component's magnitude
    /// exceeds it, absolute otherwise.
    pub tolerance: f64,
    /// When `false`, input invariants are not checked before GPU dispatch.
    pub enable_invariant_checking: bool,
}

impl Default for VerificationConfig {
    /// 10% statistical sampling, a 10 ms budget, a `1e-12` tolerance, and
    /// invariant checking switched on.
    fn default() -> Self {
        let strategy = VerificationStrategy::Statistical { sample_rate: 0.1 };
        let performance_budget = Duration::from_millis(10);
        let tolerance = 1e-12;

        VerificationConfig {
            strategy,
            performance_budget,
            tolerance,
            enable_invariant_checking: true,
        }
    }
}
71
72/// Verified multivector with signature information preserved
73#[derive(Debug, Clone)]
74pub struct VerifiedMultivector<const P: usize, const Q: usize, const R: usize> {
75    pub inner: Multivector<P, Q, R>,
76    verification_hash: u64,
77    _phantom: PhantomData<(SignatureP<P>, SignatureQ<Q>, SignatureR<R>)>,
78}
79
80impl<const P: usize, const Q: usize, const R: usize> VerifiedMultivector<P, Q, R> {
81    /// Create verified multivector with compile-time signature checking
82    pub fn new(inner: Multivector<P, Q, R>) -> Self {
83        let verification_hash = Self::compute_verification_hash(&inner);
84        Self {
85            inner,
86            verification_hash,
87            _phantom: PhantomData,
88        }
89    }
90
91    /// Extract inner multivector for GPU operations (loses verification)
92    pub fn into_inner(self) -> Multivector<P, Q, R> {
93        self.inner
94    }
95
96    /// Get reference to inner multivector
97    pub fn inner(&self) -> &Multivector<P, Q, R> {
98        &self.inner
99    }
100
101    /// Verify mathematical invariants
102    pub fn verify_invariants(&self) -> Result<(), GpuVerificationError> {
103        // Check magnitude invariant
104        let magnitude = self.inner.magnitude();
105        if !magnitude.is_finite() {
106            return Err(GpuVerificationError::InvariantViolation {
107                invariant: "Magnitude must be finite".to_string(),
108            });
109        }
110
111        // Verify signature consistency
112        if !self.verify_signature_constraints() {
113            return Err(GpuVerificationError::InvariantViolation {
114                invariant: "Signature constraints violated".to_string(),
115            });
116        }
117
118        Ok(())
119    }
120
121    /// Compute verification hash for integrity checking
122    fn compute_verification_hash(mv: &Multivector<P, Q, R>) -> u64 {
123        use std::collections::hash_map::DefaultHasher;
124        use std::hash::{Hash, Hasher};
125
126        let mut hasher = DefaultHasher::new();
127
128        // Hash signature
129        (P, Q, R).hash(&mut hasher);
130
131        // Hash coefficients (with tolerance for floating point)
132        for i in 0..mv.dimension() {
133            let coeff = mv.get(i);
134            let normalized = (coeff * 1e12).round() as i64; // 12 decimal places
135            normalized.hash(&mut hasher);
136        }
137
138        hasher.finish()
139    }
140
141    /// Verify signature constraints are satisfied
142    fn verify_signature_constraints(&self) -> bool {
143        // Check that P + Q + R matches dimension
144        let expected_dim = 1 << (P + Q + R);
145        self.inner.dimension() == expected_dim
146    }
147
148    /// Get signature tuple
149    pub fn signature() -> (usize, usize, usize) {
150        (P, Q, R)
151    }
152}
153
// Phantom types for compile-time signature verification.
//
// These zero-sized markers are not constructed in this file; they appear only
// inside `PhantomData` in `VerifiedMultivector`, tagging it with its signature.

/// Marker carrying the `P` component of the signature.
struct SignatureP<const P: usize>;
/// Marker carrying the `Q` component of the signature.
struct SignatureQ<const Q: usize>;
/// Marker carrying the `R` component of the signature.
struct SignatureR<const R: usize>;
158
159/// GPU boundary verification system
160pub struct GpuBoundaryVerifier {
161    config: VerificationConfig,
162    performance_stats: PerformanceStats,
163}
164
165impl GpuBoundaryVerifier {
166    /// Create new boundary verifier with configuration
167    pub fn new(config: VerificationConfig) -> Self {
168        Self {
169            config,
170            performance_stats: PerformanceStats::new(),
171        }
172    }
173
174    /// Verify batch geometric product with boundary checking
175    pub async fn verified_batch_geometric_product<
176        const P: usize,
177        const Q: usize,
178        const R: usize,
179    >(
180        &mut self,
181        gpu: &GpuCliffordAlgebra,
182        a_batch: &[VerifiedMultivector<P, Q, R>],
183        b_batch: &[VerifiedMultivector<P, Q, R>],
184    ) -> Result<Vec<VerifiedMultivector<P, Q, R>>, GpuVerificationError> {
185        let start_time = Instant::now();
186
187        // 1. Pre-GPU verification phase
188        self.verify_input_batch_invariants(a_batch, b_batch)?;
189
190        // 2. Extract raw data for GPU (loses phantom types temporarily)
191        let raw_a = self.extract_raw_coefficients(a_batch);
192        let raw_b = self.extract_raw_coefficients(b_batch);
193
194        // 3. GPU computation (unverified internally)
195        let raw_result = gpu
196            .batch_geometric_product(&raw_a, &raw_b)
197            .await
198            .map_err(GpuVerificationError::GpuOperation)?;
199
200        // 4. Post-GPU verification and phantom type restoration
201        let verified_result =
202            self.verify_and_restore_types::<P, Q, R>(&raw_result, a_batch, b_batch)?;
203
204        // 5. Performance tracking
205        let elapsed = start_time.elapsed();
206        self.performance_stats
207            .record_operation(elapsed, a_batch.len());
208
209        if elapsed > self.config.performance_budget {
210            return Err(GpuVerificationError::PerformanceBudgetExceeded {
211                actual: elapsed,
212                budget: self.config.performance_budget,
213            });
214        }
215
216        Ok(verified_result)
217    }
218
219    /// Verify input batch mathematical invariants
220    fn verify_input_batch_invariants<const P: usize, const Q: usize, const R: usize>(
221        &self,
222        a_batch: &[VerifiedMultivector<P, Q, R>],
223        b_batch: &[VerifiedMultivector<P, Q, R>],
224    ) -> Result<(), GpuVerificationError> {
225        if a_batch.len() != b_batch.len() {
226            return Err(GpuVerificationError::VerificationFailed(
227                "Batch sizes must match".to_string(),
228            ));
229        }
230
231        if !self.config.enable_invariant_checking {
232            return Ok(());
233        }
234
235        // Verify invariants based on strategy
236        match &self.config.strategy {
237            VerificationStrategy::Strict => {
238                // Verify all elements
239                for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
240                    a.verify_invariants().map_err(|e| {
241                        GpuVerificationError::VerificationFailed(format!("Input A[{}]: {}", i, e))
242                    })?;
243                    b.verify_invariants().map_err(|e| {
244                        GpuVerificationError::VerificationFailed(format!("Input B[{}]: {}", i, e))
245                    })?;
246                }
247            }
248            VerificationStrategy::Statistical { sample_rate } => {
249                // Verify random sample
250                let sample_size = ((a_batch.len() as f64) * sample_rate).ceil() as usize;
251                let indices = self.select_random_indices(a_batch.len(), sample_size);
252
253                for &idx in &indices {
254                    a_batch[idx].verify_invariants()?;
255                    b_batch[idx].verify_invariants()?;
256                }
257            }
258            VerificationStrategy::Boundary | VerificationStrategy::Minimal => {
259                // Only verify first and last elements
260                if !a_batch.is_empty() {
261                    a_batch[0].verify_invariants()?;
262                    b_batch[0].verify_invariants()?;
263
264                    if a_batch.len() > 1 {
265                        let last = a_batch.len() - 1;
266                        a_batch[last].verify_invariants()?;
267                        b_batch[last].verify_invariants()?;
268                    }
269                }
270            }
271        }
272
273        Ok(())
274    }
275
276    /// Extract raw coefficients for GPU computation
277    fn extract_raw_coefficients<const P: usize, const Q: usize, const R: usize>(
278        &self,
279        batch: &[VerifiedMultivector<P, Q, R>],
280    ) -> Vec<f64> {
281        let basis_count = 1 << (P + Q + R);
282        let mut raw_data = Vec::with_capacity(batch.len() * basis_count);
283
284        for mv in batch {
285            for i in 0..basis_count {
286                raw_data.push(mv.inner.get(i));
287            }
288        }
289
290        raw_data
291    }
292
293    /// Verify GPU results and restore phantom types
294    fn verify_and_restore_types<const P: usize, const Q: usize, const R: usize>(
295        &self,
296        raw_result: &[f64],
297        a_batch: &[VerifiedMultivector<P, Q, R>],
298        b_batch: &[VerifiedMultivector<P, Q, R>],
299    ) -> Result<Vec<VerifiedMultivector<P, Q, R>>, GpuVerificationError> {
300        let basis_count = 1 << (P + Q + R);
301        let batch_size = raw_result.len() / basis_count;
302
303        if batch_size != a_batch.len() {
304            return Err(GpuVerificationError::VerificationFailed(
305                "Result batch size mismatch".to_string(),
306            ));
307        }
308
309        let mut verified_results = Vec::with_capacity(batch_size);
310
311        for i in 0..batch_size {
312            let start_idx = i * basis_count;
313            let end_idx = start_idx + basis_count;
314
315            let coefficients = raw_result[start_idx..end_idx].to_vec();
316            let result_mv = Multivector::<P, Q, R>::from_coefficients(coefficients);
317
318            // Verify result based on strategy
319            match &self.config.strategy {
320                VerificationStrategy::Strict => {
321                    // Full verification: check against CPU computation
322                    let expected = a_batch[i].inner.geometric_product(&b_batch[i].inner);
323                    self.verify_approximately_equal(&result_mv, &expected, i)?;
324                }
325                VerificationStrategy::Statistical { sample_rate } => {
326                    // Statistical verification: check random samples
327                    if self.should_verify_sample(i, *sample_rate) {
328                        let expected = a_batch[i].inner.geometric_product(&b_batch[i].inner);
329                        self.verify_approximately_equal(&result_mv, &expected, i)?;
330                    }
331                }
332                VerificationStrategy::Boundary => {
333                    // Boundary verification: check first and last
334                    if i == 0 || i == batch_size - 1 {
335                        let expected = a_batch[i].inner.geometric_product(&b_batch[i].inner);
336                        self.verify_approximately_equal(&result_mv, &expected, i)?;
337                    }
338                }
339                VerificationStrategy::Minimal => {
340                    // Minimal verification: basic sanity checks only
341                    if !result_mv.magnitude().is_finite() {
342                        return Err(GpuVerificationError::InvariantViolation {
343                            invariant: format!("Result[{}] magnitude is not finite", i),
344                        });
345                    }
346                }
347            }
348
349            verified_results.push(VerifiedMultivector::new(result_mv));
350        }
351
352        Ok(verified_results)
353    }
354
355    /// Verify two multivectors are approximately equal
356    fn verify_approximately_equal<const P: usize, const Q: usize, const R: usize>(
357        &self,
358        actual: &Multivector<P, Q, R>,
359        expected: &Multivector<P, Q, R>,
360        index: usize,
361    ) -> Result<(), GpuVerificationError> {
362        let basis_count = 1 << (P + Q + R);
363
364        for i in 0..basis_count {
365            let diff = (actual.get(i) - expected.get(i)).abs();
366            let rel_error = if expected.get(i).abs() > self.config.tolerance {
367                diff / expected.get(i).abs()
368            } else {
369                diff
370            };
371
372            if rel_error > self.config.tolerance {
373                return Err(GpuVerificationError::VerificationFailed(
374                    format!(
375                        "Verification failed at result[{}], component[{}]: expected {}, got {}, error {}",
376                        index, i, expected.get(i), actual.get(i), rel_error
377                    )
378                ));
379            }
380        }
381
382        Ok(())
383    }
384
385    /// Select random indices for statistical sampling
386    fn select_random_indices(&self, total: usize, sample_size: usize) -> Vec<usize> {
387        use std::collections::HashSet;
388
389        let mut indices = HashSet::new();
390        let sample_size = sample_size.min(total);
391
392        // Simple deterministic "random" selection for reproducibility
393        let step = total / sample_size.max(1);
394        for i in 0..sample_size {
395            indices.insert((i * step) % total);
396        }
397
398        // Ensure we always include first and last
399        if total > 0 {
400            indices.insert(0);
401            if total > 1 {
402                indices.insert(total - 1);
403            }
404        }
405
406        indices.into_iter().collect()
407    }
408
409    /// Determine if a sample should be verified
410    fn should_verify_sample(&self, index: usize, sample_rate: f64) -> bool {
411        // Simple deterministic sampling based on index
412        let hash = index.wrapping_mul(2654435761); // Large prime
413        let normalized = (hash as f64) / (u32::MAX as f64);
414        normalized < sample_rate
415    }
416
417    /// Get performance statistics
418    pub fn performance_stats(&self) -> &PerformanceStats {
419        &self.performance_stats
420    }
421}
422
/// Performance tracking for verification operations.
///
/// Accumulates, across calls to `record_operation`, the number of
/// operations, their total and maximum duration, and the total number of
/// elements processed.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    operation_count: usize,
    total_duration: Duration,
    total_elements: usize,
    max_duration: Duration,
}

impl PerformanceStats {
    /// Create statistics with every counter at zero.
    fn new() -> Self {
        Self {
            operation_count: 0,
            total_duration: Duration::ZERO,
            total_elements: 0,
            max_duration: Duration::ZERO,
        }
    }

    /// Record one operation that took `duration` and processed
    /// `element_count` elements.
    ///
    /// Totals saturate instead of overflowing: `Duration` addition panics on
    /// overflow, and integer addition panics on overflow in debug builds.
    fn record_operation(&mut self, duration: Duration, element_count: usize) {
        self.operation_count = self.operation_count.saturating_add(1);
        self.total_duration = self.total_duration.saturating_add(duration);
        self.total_elements = self.total_elements.saturating_add(element_count);
        self.max_duration = self.max_duration.max(duration);
    }

    /// Average operation duration (rounded down to whole nanoseconds), or
    /// zero when nothing has been recorded.
    pub fn average_duration(&self) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.operation_count == 0 {
            return Duration::ZERO;
        }

        // Divide in u128 nanoseconds. `Duration / (count as u32)` truncated the
        // count, skewing the average past u32::MAX operations and panicking
        // with divide-by-zero when the count was a multiple of 2^32.
        let avg_nanos = self.total_duration.as_nanos() / self.operation_count as u128;
        // avg <= total, whose seconds fit in u64, so both casts are lossless.
        Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Average throughput in elements per second, or `0.0` when no time has
    /// been recorded.
    pub fn average_throughput(&self) -> f64 {
        if self.total_duration.as_secs_f64() > 0.0 {
            self.total_elements as f64 / self.total_duration.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Percentage by which the average duration exceeds `baseline_duration`
    /// (negative when faster); `0.0` for a zero baseline.
    pub fn verification_overhead_percent(&self, baseline_duration: Duration) -> f64 {
        if baseline_duration.as_secs_f64() > 0.0 {
            let overhead = self.average_duration().as_secs_f64() / baseline_duration.as_secs_f64();
            (overhead - 1.0) * 100.0
        } else {
            0.0
        }
    }

    /// Number of recorded operations.
    pub fn operation_count(&self) -> usize {
        self.operation_count
    }

    /// Total number of elements processed.
    pub fn total_elements(&self) -> usize {
        self.total_elements
    }

    /// Longest single recorded operation.
    pub fn max_duration(&self) -> Duration {
        self.max_duration
    }
}
494
495/// Statistical verification for large GPU batches
496pub struct StatisticalGpuVerifier<const P: usize, const Q: usize, const R: usize> {
497    sample_rate: f64,
498    tolerance: f64,
499    verification_cache: HashMap<u64, bool>,
500}
501
502impl<const P: usize, const Q: usize, const R: usize> StatisticalGpuVerifier<P, Q, R> {
503    /// Create new statistical verifier
504    pub fn new(sample_rate: f64, tolerance: f64) -> Self {
505        Self {
506            sample_rate,
507            tolerance,
508            verification_cache: HashMap::new(),
509        }
510    }
511
512    /// Verify batch result through statistical sampling
513    pub async fn verify_batch_statistical(
514        &mut self,
515        _gpu: &GpuCliffordAlgebra,
516        inputs: &[(VerifiedMultivector<P, Q, R>, VerifiedMultivector<P, Q, R>)],
517        gpu_results: &[Multivector<P, Q, R>],
518    ) -> Result<Vec<VerifiedMultivector<P, Q, R>>, GpuVerificationError> {
519        if inputs.len() != gpu_results.len() {
520            return Err(GpuVerificationError::VerificationFailed(
521                "Input and result batch sizes must match".to_string(),
522            ));
523        }
524
525        let sample_size = (inputs.len() as f64 * self.sample_rate).ceil() as usize;
526        let indices = self.select_random_indices(inputs.len(), sample_size);
527
528        let mut failed_samples = 0;
529
530        for &idx in &indices {
531            let (a, b) = &inputs[idx];
532            let expected = a.inner.geometric_product(&b.inner);
533            let actual = &gpu_results[idx];
534
535            if !self.approximately_equal(&expected, actual) {
536                failed_samples += 1;
537
538                // Cache failed verification
539                let hash = self.compute_input_hash(a, b);
540                self.verification_cache.insert(hash, false);
541            }
542        }
543
544        // Allow small number of failures for statistical verification
545        let failure_rate = failed_samples as f64 / indices.len() as f64;
546        let max_failure_rate = 0.01; // 1% maximum failure rate
547
548        if failure_rate > max_failure_rate {
549            return Err(GpuVerificationError::StatisticalMismatch {
550                failed: failed_samples,
551                total: indices.len(),
552            });
553        }
554
555        // If samples pass, assume entire batch is correct with verification restoration
556        let verified_results = gpu_results
557            .iter()
558            .map(|mv| VerifiedMultivector::new(mv.clone()))
559            .collect();
560
561        Ok(verified_results)
562    }
563
564    /// Check if two multivectors are approximately equal
565    fn approximately_equal(&self, a: &Multivector<P, Q, R>, b: &Multivector<P, Q, R>) -> bool {
566        let basis_count = 1 << (P + Q + R);
567
568        for i in 0..basis_count {
569            let diff = (a.get(i) - b.get(i)).abs();
570            let rel_error = if b.get(i).abs() > self.tolerance {
571                diff / b.get(i).abs()
572            } else {
573                diff
574            };
575
576            if rel_error > self.tolerance {
577                return false;
578            }
579        }
580
581        true
582    }
583
584    /// Select random indices for sampling
585    fn select_random_indices(&self, total: usize, sample_size: usize) -> Vec<usize> {
586        let mut indices = Vec::new();
587        let sample_size = sample_size.min(total);
588
589        if total == 0 {
590            return indices;
591        }
592
593        // Always include first and last
594        indices.push(0);
595        if total > 1 {
596            indices.push(total - 1);
597        }
598
599        // Add random intermediate indices
600        let step = if sample_size > 2 {
601            total / (sample_size - 2).max(1)
602        } else {
603            total
604        };
605
606        for i in 1..sample_size.saturating_sub(1) {
607            let idx = (i * step) % total;
608            if !indices.contains(&idx) {
609                indices.push(idx);
610            }
611        }
612
613        indices.sort_unstable();
614        indices
615    }
616
617    /// Compute hash for input pair caching
618    fn compute_input_hash(
619        &self,
620        a: &VerifiedMultivector<P, Q, R>,
621        b: &VerifiedMultivector<P, Q, R>,
622    ) -> u64 {
623        use std::collections::hash_map::DefaultHasher;
624        use std::hash::{Hash, Hasher};
625
626        let mut hasher = DefaultHasher::new();
627        a.verification_hash.hash(&mut hasher);
628        b.verification_hash.hash(&mut hasher);
629        hasher.finish()
630    }
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero multivector wraps cleanly and reports its signature.
    #[test]
    fn test_verified_multivector_creation() {
        let verified = VerifiedMultivector::new(Multivector::<3, 0, 0>::zero());

        assert_eq!(VerifiedMultivector::<3, 0, 0>::signature(), (3, 0, 0));
        assert!(verified.verify_invariants().is_ok());
    }

    /// The default configuration samples 10% with a 10 ms budget.
    #[test]
    fn test_verification_config_default() {
        let config = VerificationConfig::default();

        let sample_rate = match config.strategy {
            VerificationStrategy::Statistical { sample_rate } => sample_rate,
            _ => panic!("Expected statistical strategy"),
        };
        assert!((sample_rate - 0.1).abs() < 1e-10);

        assert_eq!(config.performance_budget, Duration::from_millis(10));
        assert!((config.tolerance - 1e-12).abs() < 1e-15);
        assert!(config.enable_invariant_checking);
    }

    /// Counters, averages, and maxima accumulate across operations.
    #[test]
    fn test_performance_stats() {
        let mut stats = PerformanceStats::new();
        for (millis, elements) in [(5, 100), (10, 200)] {
            stats.record_operation(Duration::from_millis(millis), elements);
        }

        assert_eq!(stats.operation_count(), 2);
        assert_eq!(stats.total_elements(), 300);
        assert_eq!(stats.average_duration(), Duration::from_micros(7_500));
        assert_eq!(stats.max_duration(), Duration::from_millis(10));
        assert!(stats.average_throughput() > 0.0);
    }

    /// Sampling respects the requested size and always covers both ends.
    #[test]
    fn test_statistical_verifier_sampling() {
        let verifier = StatisticalGpuVerifier::<3, 0, 0>::new(0.1, 1e-12);

        let indices = verifier.select_random_indices(100, 10);
        assert!(indices.len() <= 10);
        assert!(indices.contains(&0), "first element must be sampled");
        assert!(indices.contains(&99), "last element must be sampled");

        assert!(verifier.select_random_indices(0, 5).is_empty());
    }

    /// A fresh verifier starts with empty statistics.
    #[tokio::test]
    async fn test_boundary_verifier_creation() {
        let verifier = GpuBoundaryVerifier::new(VerificationConfig::default());
        let stats = verifier.performance_stats();

        assert_eq!(stats.operation_count(), 0);
        assert_eq!(stats.average_duration(), Duration::ZERO);
    }
}