kizzasi_logic/
performance.rs

1//! Performance optimizations for constraint checking
2//!
//! This module provides batch, cached, lazy, and adaptively ordered constraint
//! checking for efficient evaluation of constraints over many points.
5
6use crate::constraint::ViolationComputable;
7use scirs2_core::ndarray::Array2;
8use std::collections::HashMap;
9
10// ============================================================================
11// Batch Constraint Checking
12// ============================================================================
13
/// Checks many points against a shared set of constraints in one call,
/// optionally memoizing results on a discretized grid.
pub struct BatchConstraintChecker<C> {
    /// Constraints every point must satisfy.
    constraints: Vec<C>,
    /// Whether batch checks consult and populate `cache`.
    cache_enabled: bool,
    /// Grid-cell indices of a point -> whether that point satisfied all constraints.
    cache: HashMap<Vec<i32>, bool>,
    /// Grid cell size used to turn a point into a cache key.
    cache_resolution: f32,
}
21
22impl<C: ViolationComputable> BatchConstraintChecker<C> {
23    /// Create a new batch checker
24    pub fn new(constraints: Vec<C>) -> Self {
25        Self {
26            constraints,
27            cache_enabled: false,
28            cache: HashMap::new(),
29            cache_resolution: 0.1,
30        }
31    }
32
33    /// Enable caching with specified resolution
34    pub fn with_caching(mut self, resolution: f32) -> Self {
35        self.cache_enabled = true;
36        self.cache_resolution = resolution;
37        self
38    }
39
40    /// Check multiple points in batch
41    pub fn check_batch(&mut self, points: &Array2<f32>) -> Vec<bool> {
42        let (n_points, _) = points.dim();
43        let mut results = Vec::with_capacity(n_points);
44
45        for i in 0..n_points {
46            let point = points.row(i);
47            let point_slice: Vec<f32> = point.iter().copied().collect();
48
49            if self.cache_enabled {
50                let key = self.discretize(&point_slice);
51                if let Some(&cached) = self.cache.get(&key) {
52                    results.push(cached);
53                    continue;
54                }
55
56                let satisfied = self.check_point(&point_slice);
57                self.cache.insert(key, satisfied);
58                results.push(satisfied);
59            } else {
60                results.push(self.check_point(&point_slice));
61            }
62        }
63
64        results
65    }
66
67    /// Check if a single point satisfies all constraints
68    fn check_point(&self, point: &[f32]) -> bool {
69        self.constraints.iter().all(|c| c.check(point))
70    }
71
72    /// Compute violations for batch of points
73    pub fn violation_batch(&self, points: &Array2<f32>) -> Vec<f32> {
74        let (n_points, _) = points.dim();
75        let mut violations = Vec::with_capacity(n_points);
76
77        for i in 0..n_points {
78            let point = points.row(i);
79            let point_slice: Vec<f32> = point.iter().copied().collect();
80
81            let total_violation: f32 = self
82                .constraints
83                .iter()
84                .map(|c| c.violation(&point_slice))
85                .sum();
86
87            violations.push(total_violation);
88        }
89
90        violations
91    }
92
93    /// Discretize point for caching
94    fn discretize(&self, point: &[f32]) -> Vec<i32> {
95        point
96            .iter()
97            .map(|&x| (x / self.cache_resolution).round() as i32)
98            .collect()
99    }
100
101    /// Clear the cache
102    pub fn clear_cache(&mut self) {
103        self.cache.clear();
104    }
105
106    /// Get cache statistics
107    pub fn cache_stats(&self) -> CacheStats {
108        CacheStats {
109            entries: self.cache.len(),
110            enabled: self.cache_enabled,
111        }
112    }
113
114    /// Get constraint count
115    pub fn num_constraints(&self) -> usize {
116        self.constraints.len()
117    }
118}
119
/// Snapshot of a [`BatchConstraintChecker`]'s cache state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of discretized points currently stored in the cache.
    pub entries: usize,
    /// Whether caching is enabled on the checker.
    pub enabled: bool,
}
126
127// ============================================================================
128// Parallel Constraint Checking
129// ============================================================================
130
/// Constraint checker intended for parallel evaluation of large point batches.
///
/// The current implementation evaluates points sequentially on the calling thread;
/// no `rayon` or thread pool is used. Its methods already require `C: Send + Sync`,
/// presumably so a parallel backend can be added without an API change.
pub struct ParallelConstraintChecker<C> {
    /// Constraints every point must satisfy.
    constraints: Vec<C>,
}
135
136impl<C: ViolationComputable + Send + Sync> ParallelConstraintChecker<C> {
137    /// Create a new parallel checker
138    pub fn new(constraints: Vec<C>) -> Self {
139        Self { constraints }
140    }
141
142    /// Check points in parallel (sequential fallback if rayon not available)
143    pub fn check_batch(&self, points: &Array2<f32>) -> Vec<bool> {
144        let (n_points, _) = points.dim();
145        let mut results = Vec::with_capacity(n_points);
146
147        // Sequential implementation (can be parallelized with rayon)
148        for i in 0..n_points {
149            let point = points.row(i);
150            let point_slice: Vec<f32> = point.iter().copied().collect();
151            let satisfied = self.constraints.iter().all(|c| c.check(&point_slice));
152            results.push(satisfied);
153        }
154
155        results
156    }
157
158    /// Compute violations in parallel
159    pub fn violation_batch(&self, points: &Array2<f32>) -> Vec<f32> {
160        let (n_points, _) = points.dim();
161        let mut violations = Vec::with_capacity(n_points);
162
163        for i in 0..n_points {
164            let point = points.row(i);
165            let point_slice: Vec<f32> = point.iter().copied().collect();
166            let total: f32 = self
167                .constraints
168                .iter()
169                .map(|c| c.violation(&point_slice))
170                .sum();
171            violations.push(total);
172        }
173
174        violations
175    }
176}
177
178// ============================================================================
179// Lazy Constraint Evaluation
180// ============================================================================
181
/// Evaluates constraints in insertion order and stops as soon as a critical
/// constraint makes further evaluation pointless.
pub struct LazyConstraintEvaluator<C> {
    /// Registered constraints paired with their `is_critical` flag.
    constraints: Vec<(C, bool)>,
}
186
187impl<C: ViolationComputable> LazyConstraintEvaluator<C> {
188    /// Create a new lazy evaluator
189    pub fn new() -> Self {
190        Self {
191            constraints: Vec::new(),
192        }
193    }
194
195    /// Add a constraint with criticality flag
196    pub fn add_constraint(&mut self, constraint: C, is_critical: bool) {
197        self.constraints.push((constraint, is_critical));
198    }
199
200    /// Check constraints lazily (stop on first critical violation)
201    pub fn check_lazy(&self, point: &[f32]) -> (bool, usize) {
202        for (i, (constraint, is_critical)) in self.constraints.iter().enumerate() {
203            if !constraint.check(point) && *is_critical {
204                // Critical constraint violated, stop immediately
205                return (false, i);
206            }
207        }
208        (true, self.constraints.len())
209    }
210
211    /// Compute violation with early stopping
212    pub fn violation_lazy(&self, point: &[f32], threshold: f32) -> (f32, bool) {
213        let mut total_violation = 0.0;
214
215        for (constraint, is_critical) in &self.constraints {
216            let viol = constraint.violation(point);
217            total_violation += viol;
218
219            // Early stop if violation exceeds threshold
220            if *is_critical && viol > threshold {
221                return (total_violation, true);
222            }
223        }
224
225        (total_violation, false)
226    }
227}
228
229impl<C: ViolationComputable> Default for LazyConstraintEvaluator<C> {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235// ============================================================================
236// Vectorized Constraint Operations
237// ============================================================================
238
/// Evaluates every constraint against every point of a batch, producing
/// per-(point, constraint) results.
pub struct VectorizedConstraints<C> {
    /// Constraints evaluated for each point, in column order.
    constraints: Vec<C>,
}
243
244impl<C: ViolationComputable> VectorizedConstraints<C> {
245    /// Create vectorized constraint evaluator
246    pub fn new(constraints: Vec<C>) -> Self {
247        Self { constraints }
248    }
249
250    /// Evaluate all constraints on all points, returning matrix of violations
251    pub fn violation_matrix(&self, points: &Array2<f32>) -> Array2<f32> {
252        let (n_points, _dim) = points.dim();
253        let n_constraints = self.constraints.len();
254
255        let mut violations = Array2::zeros((n_points, n_constraints));
256
257        for i in 0..n_points {
258            let point = points.row(i);
259            let point_slice: Vec<f32> = point.iter().copied().collect();
260
261            for (j, constraint) in self.constraints.iter().enumerate() {
262                violations[[i, j]] = constraint.violation(&point_slice);
263            }
264        }
265
266        violations
267    }
268
269    /// Get satisfaction matrix (bool for each point and constraint)
270    pub fn satisfaction_matrix(&self, points: &Array2<f32>) -> Vec<Vec<bool>> {
271        let (n_points, _) = points.dim();
272        let mut satisfaction = Vec::with_capacity(n_points);
273
274        for i in 0..n_points {
275            let point = points.row(i);
276            let point_slice: Vec<f32> = point.iter().copied().collect();
277
278            let row: Vec<bool> = self
279                .constraints
280                .iter()
281                .map(|c| c.check(&point_slice))
282                .collect();
283
284            satisfaction.push(row);
285        }
286
287        satisfaction
288    }
289
290    /// Count violations per constraint across all points
291    pub fn violation_counts(&self, points: &Array2<f32>) -> Vec<usize> {
292        let (n_points, _) = points.dim();
293        let mut counts = vec![0; self.constraints.len()];
294
295        for i in 0..n_points {
296            let point = points.row(i);
297            let point_slice: Vec<f32> = point.iter().copied().collect();
298
299            for (j, constraint) in self.constraints.iter().enumerate() {
300                if !constraint.check(&point_slice) {
301                    counts[j] += 1;
302                }
303            }
304        }
305
306        counts
307    }
308}
309
310// ============================================================================
311// Adaptive Constraint Ordering
312// ============================================================================
313
314/// Adaptive constraint ordering for efficient early termination
315pub struct AdaptiveConstraintOrder<C> {
316    constraints: Vec<C>,
317    violation_counts: Vec<usize>,
318    check_count: usize,
319}
320
321impl<C: ViolationComputable> AdaptiveConstraintOrder<C> {
322    /// Create new adaptive ordering
323    pub fn new(constraints: Vec<C>) -> Self {
324        let n = constraints.len();
325        Self {
326            constraints,
327            violation_counts: vec![0; n],
328            check_count: 0,
329        }
330    }
331
332    /// Check constraints in adaptive order
333    pub fn check_adaptive(&mut self, point: &[f32]) -> bool {
334        self.check_count += 1;
335
336        // Sort constraints by violation frequency (most violated first)
337        let mut indices: Vec<usize> = (0..self.constraints.len()).collect();
338        indices.sort_by_key(|&i| std::cmp::Reverse(self.violation_counts[i]));
339
340        for &i in &indices {
341            if !self.constraints[i].check(point) {
342                self.violation_counts[i] += 1;
343                return false;
344            }
345        }
346
347        true
348    }
349
350    /// Get violation statistics
351    pub fn get_statistics(&self) -> Vec<(usize, f32)> {
352        self.violation_counts
353            .iter()
354            .enumerate()
355            .map(|(i, &count)| {
356                let rate = if self.check_count > 0 {
357                    count as f32 / self.check_count as f32
358                } else {
359                    0.0
360                };
361                (i, rate)
362            })
363            .collect()
364    }
365
366    /// Reset statistics
367    pub fn reset_statistics(&mut self) {
368        self.violation_counts.fill(0);
369        self.check_count = 0;
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintBuilder;

    #[test]
    fn test_batch_checking() {
        let c1 = ConstraintBuilder::new()
            .name("x_positive")
            .greater_eq(0.0)
            .build()
            .unwrap();

        let c2 = ConstraintBuilder::new()
            .name("x_bounded")
            .less_eq(10.0)
            .build()
            .unwrap();

        let mut checker = BatchConstraintChecker::new(vec![c1, c2]);

        // Create batch of points
        let points = Array2::from_shape_vec(
            (4, 1),
            vec![
                -1.0, // violates c1
                5.0,  // satisfies both
                15.0, // violates c2
                3.0,  // satisfies both
            ],
        )
        .unwrap();

        let results = checker.check_batch(&points);
        assert_eq!(results, vec![false, true, false, true]);
    }

    #[test]
    fn test_batch_violations() {
        let c = ConstraintBuilder::new()
            .name("bound")
            .less_eq(5.0)
            .build()
            .unwrap();

        let checker = BatchConstraintChecker::new(vec![c]);

        let points = Array2::from_shape_vec((3, 1), vec![3.0, 7.0, 10.0]).unwrap();
        let violations = checker.violation_batch(&points);

        assert_eq!(violations[0], 0.0); // 3 <= 5, no violation
        assert_eq!(violations[1], 2.0); // 7 - 5 = 2
        assert_eq!(violations[2], 5.0); // 10 - 5 = 5
    }

    #[test]
    fn test_caching() {
        let c = ConstraintBuilder::new()
            .name("test")
            .in_range(0.0, 10.0)
            .build()
            .unwrap();

        let mut checker = BatchConstraintChecker::new(vec![c]).with_caching(0.1);

        // 5.0 / 0.1 and 5.04 / 0.1 both round to grid cell 50, so they share one entry.
        // (5.05 / 0.1 sits on the .5 rounding boundary and is not a reliable example.)
        let points = Array2::from_shape_vec((2, 1), vec![5.0, 5.04]).unwrap();
        let results = checker.check_batch(&points);
        assert_eq!(results, vec![true, true]);

        let stats = checker.cache_stats();
        assert!(stats.enabled);
        assert_eq!(stats.entries, 1);

        checker.clear_cache();
        assert_eq!(checker.cache_stats().entries, 0);
    }

    #[test]
    fn test_lazy_evaluation() {
        let c1 = ConstraintBuilder::new()
            .name("critical")
            .greater_eq(0.0)
            .build()
            .unwrap();

        let c2 = ConstraintBuilder::new()
            .name("non_critical")
            .less_eq(100.0)
            .build()
            .unwrap();

        let mut evaluator = LazyConstraintEvaluator::new();
        evaluator.add_constraint(c1, true); // critical
        evaluator.add_constraint(c2, false); // non-critical

        // Violates critical constraint, should stop early
        let (satisfied, stopped_at) = evaluator.check_lazy(&[-1.0]);
        assert!(!satisfied);
        assert_eq!(stopped_at, 0);

        // Satisfies all
        let (satisfied, stopped_at) = evaluator.check_lazy(&[5.0]);
        assert!(satisfied);
        assert_eq!(stopped_at, 2);

        // Only the non-critical constraint is violated: it does not affect the result.
        let (satisfied, stopped_at) = evaluator.check_lazy(&[200.0]);
        assert!(satisfied);
        assert_eq!(stopped_at, 2);
    }

    #[test]
    fn test_lazy_violation_early_stop() {
        let c1 = ConstraintBuilder::new()
            .name("critical")
            .greater_eq(0.0)
            .build()
            .unwrap();

        let c2 = ConstraintBuilder::new()
            .name("non_critical")
            .less_eq(100.0)
            .build()
            .unwrap();

        let mut evaluator = LazyConstraintEvaluator::new();
        evaluator.add_constraint(c1, true);
        evaluator.add_constraint(c2, false);

        let (total, stopped) = evaluator.violation_lazy(&[5.0], 0.5);
        assert_eq!(total, 0.0);
        assert!(!stopped);

        let (total, stopped) = evaluator.violation_lazy(&[-1.0], 0.5);
        assert!(stopped);
        assert!(total > 0.0);
    }

    #[test]
    fn test_vectorized_constraints() {
        let c = ConstraintBuilder::new()
            .name("bound")
            .less_eq(5.0)
            .build()
            .unwrap();

        let vectorized = VectorizedConstraints::new(vec![c]);
        let points = Array2::from_shape_vec((3, 1), vec![3.0, 7.0, 10.0]).unwrap();

        let matrix = vectorized.violation_matrix(&points);
        assert_eq!(matrix.dim(), (3, 1));
        assert_eq!(matrix[[0, 0]], 0.0);
        assert_eq!(matrix[[1, 0]], 2.0);
        assert_eq!(matrix[[2, 0]], 5.0);

        assert_eq!(
            vectorized.satisfaction_matrix(&points),
            vec![vec![true], vec![false], vec![false]]
        );
        assert_eq!(vectorized.violation_counts(&points), vec![2]);
    }

    #[test]
    fn test_adaptive_ordering() {
        let c1 = ConstraintBuilder::new()
            .name("rarely_violated")
            .greater_eq(-100.0)
            .build()
            .unwrap();

        let c2 = ConstraintBuilder::new()
            .name("often_violated")
            .less_eq(5.0)
            .build()
            .unwrap();

        let mut adaptive = AdaptiveConstraintOrder::new(vec![c1, c2]);

        // Check several points, c2 violated more often
        adaptive.check_adaptive(&[10.0]); // violates c2
        adaptive.check_adaptive(&[3.0]); // satisfies both
        adaptive.check_adaptive(&[15.0]); // violates c2

        let stats = adaptive.get_statistics();
        assert!(stats[1].1 > stats[0].1); // c2 violated more frequently
    }
}