1use crate::constraint::ViolationComputable;
7use scirs2_core::ndarray::Array2;
8use std::collections::HashMap;
9
/// Evaluates a fixed set of constraints over many points at once.
///
/// Verdicts can optionally be memoized on a grid: a point is turned into a key
/// by dividing every coordinate by `cache_resolution` and rounding, so points
/// that land in the same grid cell share one cached verdict.
pub struct BatchConstraintChecker<C> {
    /// Constraints that must all hold for a point to count as satisfied.
    constraints: Vec<C>,
    /// Whether `check_batch` consults and fills `cache`.
    cache_enabled: bool,
    /// Memoized verdicts, keyed by the discretized point.
    cache: HashMap<Vec<i32>, bool>,
    /// Grid cell width used when turning a point into a cache key.
    cache_resolution: f32,
}
21
22impl<C: ViolationComputable> BatchConstraintChecker<C> {
23 pub fn new(constraints: Vec<C>) -> Self {
25 Self {
26 constraints,
27 cache_enabled: false,
28 cache: HashMap::new(),
29 cache_resolution: 0.1,
30 }
31 }
32
33 pub fn with_caching(mut self, resolution: f32) -> Self {
35 self.cache_enabled = true;
36 self.cache_resolution = resolution;
37 self
38 }
39
40 pub fn check_batch(&mut self, points: &Array2<f32>) -> Vec<bool> {
42 let (n_points, _) = points.dim();
43 let mut results = Vec::with_capacity(n_points);
44
45 for i in 0..n_points {
46 let point = points.row(i);
47 let point_slice: Vec<f32> = point.iter().copied().collect();
48
49 if self.cache_enabled {
50 let key = self.discretize(&point_slice);
51 if let Some(&cached) = self.cache.get(&key) {
52 results.push(cached);
53 continue;
54 }
55
56 let satisfied = self.check_point(&point_slice);
57 self.cache.insert(key, satisfied);
58 results.push(satisfied);
59 } else {
60 results.push(self.check_point(&point_slice));
61 }
62 }
63
64 results
65 }
66
67 fn check_point(&self, point: &[f32]) -> bool {
69 self.constraints.iter().all(|c| c.check(point))
70 }
71
72 pub fn violation_batch(&self, points: &Array2<f32>) -> Vec<f32> {
74 let (n_points, _) = points.dim();
75 let mut violations = Vec::with_capacity(n_points);
76
77 for i in 0..n_points {
78 let point = points.row(i);
79 let point_slice: Vec<f32> = point.iter().copied().collect();
80
81 let total_violation: f32 = self
82 .constraints
83 .iter()
84 .map(|c| c.violation(&point_slice))
85 .sum();
86
87 violations.push(total_violation);
88 }
89
90 violations
91 }
92
93 fn discretize(&self, point: &[f32]) -> Vec<i32> {
95 point
96 .iter()
97 .map(|&x| (x / self.cache_resolution).round() as i32)
98 .collect()
99 }
100
101 pub fn clear_cache(&mut self) {
103 self.cache.clear();
104 }
105
106 pub fn cache_stats(&self) -> CacheStats {
108 CacheStats {
109 entries: self.cache.len(),
110 enabled: self.cache_enabled,
111 }
112 }
113
114 pub fn num_constraints(&self) -> usize {
116 self.constraints.len()
117 }
118}
119
/// Snapshot of a checker's cache state.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, and
/// `Default` for an empty, disabled cache.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries held by the cache when the snapshot was taken.
    pub entries: usize,
    /// Whether caching was enabled when the snapshot was taken.
    pub enabled: bool,
}
126
/// Batch checker over a fixed set of constraints.
///
/// Its methods are only available when the constraint type is `Send + Sync`.
pub struct ParallelConstraintChecker<C> {
    /// Constraints evaluated for every point.
    constraints: Vec<C>,
}
135
136impl<C: ViolationComputable + Send + Sync> ParallelConstraintChecker<C> {
137 pub fn new(constraints: Vec<C>) -> Self {
139 Self { constraints }
140 }
141
142 pub fn check_batch(&self, points: &Array2<f32>) -> Vec<bool> {
144 let (n_points, _) = points.dim();
145 let mut results = Vec::with_capacity(n_points);
146
147 for i in 0..n_points {
149 let point = points.row(i);
150 let point_slice: Vec<f32> = point.iter().copied().collect();
151 let satisfied = self.constraints.iter().all(|c| c.check(&point_slice));
152 results.push(satisfied);
153 }
154
155 results
156 }
157
158 pub fn violation_batch(&self, points: &Array2<f32>) -> Vec<f32> {
160 let (n_points, _) = points.dim();
161 let mut violations = Vec::with_capacity(n_points);
162
163 for i in 0..n_points {
164 let point = points.row(i);
165 let point_slice: Vec<f32> = point.iter().copied().collect();
166 let total: f32 = self
167 .constraints
168 .iter()
169 .map(|c| c.violation(&point_slice))
170 .sum();
171 violations.push(total);
172 }
173
174 violations
175 }
176}
177
/// Evaluator that can stop early when a constraint flagged as critical fails.
pub struct LazyConstraintEvaluator<C> {
    /// Constraints in evaluation order, each paired with its "critical" flag.
    constraints: Vec<(C, bool)>,
}
186
187impl<C: ViolationComputable> LazyConstraintEvaluator<C> {
188 pub fn new() -> Self {
190 Self {
191 constraints: Vec::new(),
192 }
193 }
194
195 pub fn add_constraint(&mut self, constraint: C, is_critical: bool) {
197 self.constraints.push((constraint, is_critical));
198 }
199
200 pub fn check_lazy(&self, point: &[f32]) -> (bool, usize) {
202 for (i, (constraint, is_critical)) in self.constraints.iter().enumerate() {
203 if !constraint.check(point) && *is_critical {
204 return (false, i);
206 }
207 }
208 (true, self.constraints.len())
209 }
210
211 pub fn violation_lazy(&self, point: &[f32], threshold: f32) -> (f32, bool) {
213 let mut total_violation = 0.0;
214
215 for (constraint, is_critical) in &self.constraints {
216 let viol = constraint.violation(point);
217 total_violation += viol;
218
219 if *is_critical && viol > threshold {
221 return (total_violation, true);
222 }
223 }
224
225 (total_violation, false)
226 }
227}
228
229impl<C: ViolationComputable> Default for LazyConstraintEvaluator<C> {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
/// Fixed set of constraints evaluated per point and per constraint.
pub struct VectorizedConstraints<C> {
    /// Constraints, in the column order used by the matrix-style results.
    constraints: Vec<C>,
}
243
244impl<C: ViolationComputable> VectorizedConstraints<C> {
245 pub fn new(constraints: Vec<C>) -> Self {
247 Self { constraints }
248 }
249
250 pub fn violation_matrix(&self, points: &Array2<f32>) -> Array2<f32> {
252 let (n_points, _dim) = points.dim();
253 let n_constraints = self.constraints.len();
254
255 let mut violations = Array2::zeros((n_points, n_constraints));
256
257 for i in 0..n_points {
258 let point = points.row(i);
259 let point_slice: Vec<f32> = point.iter().copied().collect();
260
261 for (j, constraint) in self.constraints.iter().enumerate() {
262 violations[[i, j]] = constraint.violation(&point_slice);
263 }
264 }
265
266 violations
267 }
268
269 pub fn satisfaction_matrix(&self, points: &Array2<f32>) -> Vec<Vec<bool>> {
271 let (n_points, _) = points.dim();
272 let mut satisfaction = Vec::with_capacity(n_points);
273
274 for i in 0..n_points {
275 let point = points.row(i);
276 let point_slice: Vec<f32> = point.iter().copied().collect();
277
278 let row: Vec<bool> = self
279 .constraints
280 .iter()
281 .map(|c| c.check(&point_slice))
282 .collect();
283
284 satisfaction.push(row);
285 }
286
287 satisfaction
288 }
289
290 pub fn violation_counts(&self, points: &Array2<f32>) -> Vec<usize> {
292 let (n_points, _) = points.dim();
293 let mut counts = vec![0; self.constraints.len()];
294
295 for i in 0..n_points {
296 let point = points.row(i);
297 let point_slice: Vec<f32> = point.iter().copied().collect();
298
299 for (j, constraint) in self.constraints.iter().enumerate() {
300 if !constraint.check(&point_slice) {
301 counts[j] += 1;
302 }
303 }
304 }
305
306 counts
307 }
308}
309
/// Constraint set that tracks how often each constraint has been observed to
/// fail, so more frequently failing constraints can be tried first.
pub struct AdaptiveConstraintOrder<C> {
    /// Constraints in their original order.
    constraints: Vec<C>,
    /// Recorded failure count for each constraint, indexed like `constraints`.
    violation_counts: Vec<usize>,
    /// Number of checks performed since creation or the last reset.
    check_count: usize,
}
320
321impl<C: ViolationComputable> AdaptiveConstraintOrder<C> {
322 pub fn new(constraints: Vec<C>) -> Self {
324 let n = constraints.len();
325 Self {
326 constraints,
327 violation_counts: vec![0; n],
328 check_count: 0,
329 }
330 }
331
332 pub fn check_adaptive(&mut self, point: &[f32]) -> bool {
334 self.check_count += 1;
335
336 let mut indices: Vec<usize> = (0..self.constraints.len()).collect();
338 indices.sort_by_key(|&i| std::cmp::Reverse(self.violation_counts[i]));
339
340 for &i in &indices {
341 if !self.constraints[i].check(point) {
342 self.violation_counts[i] += 1;
343 return false;
344 }
345 }
346
347 true
348 }
349
350 pub fn get_statistics(&self) -> Vec<(usize, f32)> {
352 self.violation_counts
353 .iter()
354 .enumerate()
355 .map(|(i, &count)| {
356 let rate = if self.check_count > 0 {
357 count as f32 / self.check_count as f32
358 } else {
359 0.0
360 };
361 (i, rate)
362 })
363 .collect()
364 }
365
366 pub fn reset_statistics(&mut self) {
368 self.violation_counts.fill(0);
369 self.check_count = 0;
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintBuilder;

    /// Two bounds checked together over a column of four one-dimensional points.
    #[test]
    fn test_batch_checking() {
        let lower_bound = ConstraintBuilder::new()
            .name("x_positive")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let upper_bound = ConstraintBuilder::new()
            .name("x_bounded")
            .less_eq(10.0)
            .build()
            .unwrap();

        let mut checker = BatchConstraintChecker::new(vec![lower_bound, upper_bound]);

        // Four points with a single coordinate each.
        let points = Array2::from_shape_vec((4, 1), vec![-1.0, 5.0, 15.0, 3.0]).unwrap();

        let verdicts = checker.check_batch(&points);
        assert_eq!(verdicts, vec![false, true, false, true]);
    }

    /// Violation totals for a single upper bound of 5.
    #[test]
    fn test_batch_violations() {
        let bound = ConstraintBuilder::new()
            .name("bound")
            .less_eq(5.0)
            .build()
            .unwrap();

        let checker = BatchConstraintChecker::new(vec![bound]);

        let points = Array2::from_shape_vec((3, 1), vec![3.0, 7.0, 10.0]).unwrap();
        let totals = checker.violation_batch(&points);

        assert_eq!(totals[0], 0.0);
        assert_eq!(totals[1], 2.0);
        assert_eq!(totals[2], 5.0);
    }

    /// Enabling caching is reported and a batch leaves at least one entry behind.
    #[test]
    fn test_caching() {
        let range = ConstraintBuilder::new()
            .name("test")
            .in_range(0.0, 10.0)
            .build()
            .unwrap();

        let mut checker = BatchConstraintChecker::new(vec![range]).with_caching(0.1);

        let points = Array2::from_shape_vec((2, 1), vec![5.0, 5.05]).unwrap();
        let _ = checker.check_batch(&points);

        // Only a lower bound is asserted: the two points may or may not
        // share a grid cell.
        let stats = checker.cache_stats();
        assert!(stats.enabled);
        assert!(stats.entries >= 1);
    }

    /// A failing critical constraint stops evaluation at its index.
    #[test]
    fn test_lazy_evaluation() {
        let critical = ConstraintBuilder::new()
            .name("critical")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let non_critical = ConstraintBuilder::new()
            .name("non_critical")
            .less_eq(100.0)
            .build()
            .unwrap();

        let mut evaluator = LazyConstraintEvaluator::new();
        evaluator.add_constraint(critical, true);
        evaluator.add_constraint(non_critical, false);

        // The critical constraint sits at index 0 and rejects -1.
        let (satisfied, stopped_at) = evaluator.check_lazy(&[-1.0]);
        assert!(!satisfied);
        assert_eq!(stopped_at, 0);

        // Nothing fails for 5, so the reported index is the constraint count.
        let (satisfied, stopped_at) = evaluator.check_lazy(&[5.0]);
        assert!(satisfied);
        assert_eq!(stopped_at, 2);
    }

    /// The constraint that fails more often ends up with the higher failure rate.
    #[test]
    fn test_adaptive_ordering() {
        let rarely_violated = ConstraintBuilder::new()
            .name("rarely_violated")
            .greater_eq(-100.0)
            .build()
            .unwrap();
        let often_violated = ConstraintBuilder::new()
            .name("often_violated")
            .less_eq(5.0)
            .build()
            .unwrap();

        let mut adaptive = AdaptiveConstraintOrder::new(vec![rarely_violated, often_violated]);

        adaptive.check_adaptive(&[10.0]);
        adaptive.check_adaptive(&[3.0]);
        adaptive.check_adaptive(&[15.0]);

        let stats = adaptive.get_statistics();
        assert!(stats[1].1 > stats[0].1);
    }
}