1use rayon::prelude::*;
20use scirs2_core::ndarray::{Array1, Array2};
21use std::time::Instant;
22
/// Tuning parameters for the parallel constraint routines in this file.
///
/// Of these fields, only `chunk_size` is read by the code visible here (as the
/// rayon minimum split length in `ParallelFeasibilityChecker`); the others are
/// stored but not consulted in the visible code.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Requested worker-thread count; `Default` sets it to 0.
    /// NOTE(review): presumably 0 means "let rayon choose" — nothing visible
    /// here reads this field, confirm against callers.
    pub num_threads: usize,
    /// Minimum number of points handed to one rayon task by the batch
    /// methods; `Default` is 64. A value of 0 is treated as 1 by those methods.
    pub chunk_size: usize,
    /// Flag for SIMD-oriented code paths; `Default` is `true`.
    /// NOTE(review): not read anywhere in the visible code — confirm intended use.
    pub use_simd: bool,
    /// Prefetch look-ahead distance; `Default` is 8.
    /// NOTE(review): the unit (elements vs. cache lines) is not established here.
    pub prefetch_distance: usize,
}
39
40impl Default for ParallelConfig {
41 fn default() -> Self {
42 Self {
43 num_threads: 0,
44 chunk_size: 64,
45 use_simd: true,
46 prefetch_distance: 8,
47 }
48 }
49}
50
/// A constraint on `f32` vectors offering a feasibility test, a projection,
/// and a scalar violation measure.
///
/// The `Send + Sync` bound lets boxed constraints be used from the rayon
/// parallel iterators elsewhere in this file.
pub trait FastConstraint: Send + Sync {
    /// Returns `true` when `x` satisfies the constraint.
    fn is_feasible(&self, x: &Array1<f32>) -> bool;

    /// Returns the constraint's projection of `x` as a new array; `x` itself
    /// is not modified.
    fn project(&self, x: &Array1<f32>) -> Array1<f32>;

    /// Returns a scalar measure of how much `x` violates the constraint.
    /// NOTE(review): the implementations in this file return non-negative
    /// values — confirm whether that is a contract for other implementors.
    fn violation(&self, x: &Array1<f32>) -> f32;
}
70
/// Axis-aligned box constraint: `lb[i] <= x[i] <= ub[i]` for every component.
///
/// `BoxConstraint::new` validates the bounds, but both fields are public, so
/// the ordering is not enforced after construction.
#[derive(Debug, Clone)]
pub struct BoxConstraint {
    /// Per-component lower bounds.
    pub lb: Array1<f32>,
    /// Per-component upper bounds; same length as `lb` when built via `new`.
    pub ub: Array1<f32>,
}
86
87impl BoxConstraint {
88 pub fn new(lb: Array1<f32>, ub: Array1<f32>) -> Result<Self, String> {
95 if lb.len() != ub.len() {
96 return Err(format!(
97 "BoxConstraint: lb.len()={} != ub.len()={}",
98 lb.len(),
99 ub.len()
100 ));
101 }
102 for (i, (&l, &u)) in lb.iter().zip(ub.iter()).enumerate() {
103 if l > u {
104 return Err(format!("BoxConstraint: lb[{i}]={l} > ub[{i}]={u}"));
105 }
106 }
107 Ok(Self { lb, ub })
108 }
109}
110
111impl FastConstraint for BoxConstraint {
112 fn is_feasible(&self, x: &Array1<f32>) -> bool {
113 x.iter()
115 .zip(self.lb.iter())
116 .zip(self.ub.iter())
117 .map(|((&xi, &li), &ui)| if xi < li || xi > ui { 1u8 } else { 0u8 })
118 .sum::<u8>()
119 == 0
120 }
121
122 fn project(&self, x: &Array1<f32>) -> Array1<f32> {
123 let n = x.len();
124 let mut out = vec![0.0f32; n];
125 for i in 0..n {
126 out[i] = x[i].clamp(self.lb[i], self.ub[i]);
127 }
128 Array1::from(out)
129 }
130
131 fn violation(&self, x: &Array1<f32>) -> f32 {
132 let mut v = 0.0f32;
133 for i in 0..x.len() {
134 let below = (self.lb[i] - x[i]).max(0.0);
135 let above = (x[i] - self.ub[i]).max(0.0);
136 v += below * below + above * above;
137 }
138 v.sqrt()
139 }
140}
141
/// Closed Euclidean ball constraint: `||x - center||_2 <= radius`.
#[derive(Debug, Clone)]
pub struct L2BallConstraint {
    /// Centre of the ball.
    pub center: Array1<f32>,
    /// Radius of the ball; `L2BallConstraint::new` requires it to be positive.
    pub radius: f32,
}
156
157impl L2BallConstraint {
158 pub fn new(center: Array1<f32>, radius: f32) -> Result<Self, String> {
164 if radius <= 0.0 {
165 return Err(format!(
166 "L2BallConstraint: radius must be positive, got {radius}"
167 ));
168 }
169 Ok(Self { center, radius })
170 }
171
172 fn dist_sq(&self, x: &Array1<f32>) -> f32 {
173 let mut s = 0.0f32;
174 for i in 0..x.len() {
175 let d = x[i] - self.center[i];
176 s += d * d;
177 }
178 s
179 }
180}
181
182impl FastConstraint for L2BallConstraint {
183 fn is_feasible(&self, x: &Array1<f32>) -> bool {
184 self.dist_sq(x) <= self.radius * self.radius
185 }
186
187 fn project(&self, x: &Array1<f32>) -> Array1<f32> {
188 let dist_sq = self.dist_sq(x);
189 if dist_sq <= self.radius * self.radius {
190 return x.clone();
191 }
192 let dist = dist_sq.sqrt();
193 let scale = self.radius / dist;
194 let n = x.len();
195 let mut out = vec![0.0f32; n];
196 for i in 0..n {
197 out[i] = self.center[i] + (x[i] - self.center[i]) * scale;
198 }
199 Array1::from(out)
200 }
201
202 fn violation(&self, x: &Array1<f32>) -> f32 {
203 let dist = self.dist_sq(x).sqrt();
204 (dist - self.radius).max(0.0)
205 }
206}
207
/// Half-space constraint `normal · x <= offset`.
///
/// Despite the name, the feasibility test in this file is an inequality, so
/// the feasible set is the half-space on one side of the hyperplane.
#[derive(Debug, Clone)]
pub struct HyperplaneConstraint {
    /// Normal vector `a` of the bounding hyperplane; non-zero when built via `new`.
    pub normal: Array1<f32>,
    /// Right-hand side `b` of `a · x <= b`.
    pub offset: f32,
}
223
224impl HyperplaneConstraint {
225 pub fn new(normal: Array1<f32>, offset: f32) -> Result<Self, String> {
231 let norm_sq: f32 = normal.iter().map(|&v| v * v).sum();
232 if norm_sq == 0.0 {
233 return Err("HyperplaneConstraint: normal vector must be non-zero".to_string());
234 }
235 Ok(Self { normal, offset })
236 }
237
238 fn dot(&self, x: &Array1<f32>) -> f32 {
239 let mut s = 0.0f32;
240 for i in 0..x.len() {
241 s += self.normal[i] * x[i];
242 }
243 s
244 }
245
246 fn norm_sq(&self) -> f32 {
247 let mut s = 0.0f32;
248 for &v in self.normal.iter() {
249 s += v * v;
250 }
251 s
252 }
253}
254
255impl FastConstraint for HyperplaneConstraint {
256 fn is_feasible(&self, x: &Array1<f32>) -> bool {
257 self.dot(x) <= self.offset
258 }
259
260 fn project(&self, x: &Array1<f32>) -> Array1<f32> {
261 let ax = self.dot(x);
262 if ax <= self.offset {
263 return x.clone();
264 }
265 let scale = (ax - self.offset) / self.norm_sq();
267 let n = x.len();
268 let mut out = vec![0.0f32; n];
269 for i in 0..n {
270 out[i] = x[i] - scale * self.normal[i];
271 }
272 Array1::from(out)
273 }
274
275 fn violation(&self, x: &Array1<f32>) -> f32 {
276 (self.dot(x) - self.offset).max(0.0)
277 }
278}
279
/// Probability-simplex constraint: all components non-negative and summing
/// to 1 (the feasibility test in this file allows a sum error below `1e-5`).
#[derive(Debug, Clone)]
pub struct SimplexConstraint {
    /// Expected vector length; `is_feasible` rejects vectors of any other length.
    pub dim: usize,
}
292
293impl SimplexConstraint {
294 pub fn new(dim: usize) -> Self {
296 Self { dim }
297 }
298}
299
300impl FastConstraint for SimplexConstraint {
301 fn is_feasible(&self, x: &Array1<f32>) -> bool {
302 if x.len() != self.dim {
303 return false;
304 }
305 let sum: f32 = x.iter().sum();
306 let all_nonneg = x
307 .iter()
308 .map(|&v| if v < 0.0 { 1u8 } else { 0u8 })
309 .sum::<u8>()
310 == 0;
311 all_nonneg && (sum - 1.0).abs() < 1e-5
312 }
313
314 fn project(&self, x: &Array1<f32>) -> Array1<f32> {
315 let n = x.len();
316 let mut sorted: Vec<f32> = x.iter().copied().collect();
318 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
319
320 let mut cumsum = 0.0f32;
321 let mut rho = 0usize;
322 for (j, &s) in sorted.iter().enumerate() {
323 cumsum += s;
324 if s > (cumsum - 1.0) / (j as f32 + 1.0) {
325 rho = j;
326 }
327 }
328 let cumsum_rho: f32 = sorted[..=rho].iter().sum();
329 let theta = (cumsum_rho - 1.0) / (rho as f32 + 1.0);
330
331 let mut out = vec![0.0f32; n];
332 for i in 0..n {
333 out[i] = (x[i] - theta).max(0.0);
334 }
335 Array1::from(out)
336 }
337
338 fn violation(&self, x: &Array1<f32>) -> f32 {
339 let sum: f32 = x.iter().sum();
340 let sum_viol = (sum - 1.0).abs();
341 let neg_viol: f32 = x.iter().map(|&v| (-v).max(0.0)).sum();
342 sum_viol + neg_viol
343 }
344}
345
/// Evaluates a set of constraints over batches of points using rayon.
pub struct ParallelFeasibilityChecker {
    /// Constraints applied to every point, in insertion order.
    constraints: Vec<Box<dyn FastConstraint>>,
    /// Parallel tuning parameters; the batch methods read `chunk_size` from it.
    config: ParallelConfig,
}
359
360impl ParallelFeasibilityChecker {
361 pub fn new(config: ParallelConfig) -> Self {
363 Self {
364 constraints: Vec::new(),
365 config,
366 }
367 }
368
369 pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
371 self.constraints.push(constraint);
372 }
373
374 pub fn num_constraints(&self) -> usize {
376 self.constraints.len()
377 }
378
379 pub fn check_batch(&self, points: &Array2<f32>) -> Vec<bool> {
384 let (n_points, dim) = points.dim();
385 let constraints = &self.constraints;
386 let chunk_size = self.config.chunk_size.max(1);
387
388 (0..n_points)
389 .into_par_iter()
390 .with_min_len(chunk_size)
391 .map(|i| {
392 let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
393 if row.len() != dim {
394 return false;
395 }
396 constraints.iter().all(|c| c.is_feasible(&row))
397 })
398 .collect()
399 }
400
401 pub fn violation_matrix(&self, points: &Array2<f32>) -> Array2<f32> {
406 let (n_points, _dim) = points.dim();
407 let n_constraints = self.constraints.len();
408
409 if n_constraints == 0 || n_points == 0 {
410 return Array2::zeros((n_points, n_constraints));
411 }
412
413 let chunk_size = self.config.chunk_size.max(1);
414 let constraints = &self.constraints;
415
416 let rows: Vec<Vec<f32>> = (0..n_points)
417 .into_par_iter()
418 .with_min_len(chunk_size)
419 .map(|i| {
420 let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
421 constraints.iter().map(|c| c.violation(&row)).collect()
422 })
423 .collect();
424
425 let mut out = Array2::zeros((n_points, n_constraints));
426 for (i, row) in rows.iter().enumerate() {
427 for (j, &v) in row.iter().enumerate() {
428 out[[i, j]] = v;
429 }
430 }
431 out
432 }
433
434 pub fn project_batch(&self, points: &Array2<f32>, max_iter: usize) -> Array2<f32> {
439 let (n_points, dim) = points.dim();
440 if n_points == 0 || dim == 0 {
441 return points.clone();
442 }
443
444 let chunk_size = self.config.chunk_size.max(1);
445 let constraints = &self.constraints;
446
447 let projected: Vec<Vec<f32>> = (0..n_points)
448 .into_par_iter()
449 .with_min_len(chunk_size)
450 .map(|i| {
451 let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
452 let result = dykstra_project(&row, constraints.as_slice(), max_iter);
453 result.into_raw_vec_and_offset().0
454 })
455 .collect();
456
457 let mut out = Array2::zeros((n_points, dim));
458 for (i, row) in projected.iter().enumerate() {
459 for (j, &v) in row.iter().enumerate() {
460 if j < dim {
461 out[[i, j]] = v;
462 }
463 }
464 }
465 out
466 }
467}
468
469fn dykstra_project(
474 x: &Array1<f32>,
475 constraints: &[Box<dyn FastConstraint>],
476 max_iter: usize,
477) -> Array1<f32> {
478 if constraints.is_empty() {
479 return x.clone();
480 }
481 let n = x.len();
482 let m = constraints.len();
483
484 let mut z = x.clone();
486 let mut increments: Vec<Array1<f32>> = vec![Array1::zeros(n); m];
487
488 for _ in 0..max_iter {
489 let prev = z.clone();
490 for (k, constraint) in constraints.iter().enumerate() {
491 let y = &z + &increments[k];
492 let proj = constraint.project(&y);
493 for j in 0..n {
495 increments[k][j] = y[j] - proj[j];
496 }
497 z = proj;
498 }
499 let max_diff = z
501 .iter()
502 .zip(prev.iter())
503 .map(|(&a, &b)| (a - b).abs())
504 .fold(0.0f32, f32::max);
505 if max_diff < 1e-6 {
506 break;
507 }
508 }
509 z
510}
511
/// Outcome of `ConstraintGraph::propagate_parallel`.
#[derive(Debug, Clone)]
pub struct PropagationResult {
    /// `true` when a sweep changed no component by the tolerance or more
    /// before the sweep limit was reached.
    pub converged: bool,
    /// Number of sweeps executed.
    pub iterations: usize,
    /// Number of constraints whose `is_feasible` is `false` at the final point.
    pub num_violations: usize,
}
526
/// Constraints together with the variables each one touches, plus the
/// conflict graph linking constraints that share a variable.
pub struct ConstraintGraph {
    /// Declared number of variables; returned by `num_vars()`. The visible
    /// code does not validate variable indices against it.
    num_vars: usize,
    /// Constraints in insertion order; index `k` identifies constraint `k`.
    constraints: Vec<Box<dyn FastConstraint>>,
    /// `var_indices[k]` lists the variable indices associated with constraint `k`.
    var_indices: Vec<Vec<usize>>,
    /// `adjacency[k]` lists the constraints that share at least one variable
    /// with constraint `k`.
    adjacency: Vec<Vec<usize>>,
}
541
542impl ConstraintGraph {
543 pub fn new(num_vars: usize) -> Self {
545 Self {
546 num_vars,
547 constraints: Vec::new(),
548 var_indices: Vec::new(),
549 adjacency: Vec::new(),
550 }
551 }
552
553 pub fn num_constraints(&self) -> usize {
555 self.constraints.len()
556 }
557
558 pub fn num_vars(&self) -> usize {
560 self.num_vars
561 }
562
563 pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>, var_indices: Vec<usize>) {
568 let new_idx = self.constraints.len();
569 self.constraints.push(constraint);
570
571 let mut neighbors = Vec::new();
573 for (existing_idx, existing_vars) in self.var_indices.iter().enumerate() {
574 let shares = var_indices.iter().any(|v| existing_vars.contains(v));
575 if shares {
576 neighbors.push(existing_idx);
577 self.adjacency[existing_idx].push(new_idx);
578 }
579 }
580
581 self.var_indices.push(var_indices);
582 self.adjacency.push(neighbors);
583 }
584
585 pub fn independent_sets(&self) -> Vec<Vec<usize>> {
595 let n = self.constraints.len();
596 if n == 0 {
597 return Vec::new();
598 }
599
600 let mut order: Vec<usize> = (0..n).collect();
602 order.sort_by_key(|&i| std::cmp::Reverse(self.adjacency[i].len()));
603
604 let mut colors: Vec<Option<usize>> = vec![None; n];
605 let mut num_colors = 0usize;
606
607 for &node in &order {
608 let used_colors: std::collections::HashSet<usize> = self.adjacency[node]
610 .iter()
611 .filter_map(|&nb| colors[nb])
612 .collect();
613
614 let color = (0..).find(|c| !used_colors.contains(c)).unwrap_or(0);
615 colors[node] = Some(color);
616 if color >= num_colors {
617 num_colors = color + 1;
618 }
619 }
620
621 let mut sets: Vec<Vec<usize>> = vec![Vec::new(); num_colors];
622 for (node, color) in colors.iter().enumerate() {
623 if let Some(c) = color {
624 sets[*c].push(node);
625 }
626 }
627 sets
628 }
629
630 pub fn propagate_parallel(&self, x: &mut Array1<f32>) -> PropagationResult {
639 let max_iter = 50usize;
640 let tol = 1e-6f32;
641 let sets = self.independent_sets();
642
643 let mut iterations = 0usize;
644 let mut converged = false;
645
646 for _global_iter in 0..max_iter {
647 iterations += 1;
648 let prev = x.clone();
649
650 for set in &sets {
651 if set.is_empty() {
652 continue;
653 }
654 let projections: Vec<(usize, Array1<f32>)> = set
658 .par_iter()
659 .map(|&c_idx| {
660 let proj = self.constraints[c_idx].project(x);
661 (c_idx, proj)
662 })
663 .collect();
664
665 for (c_idx, proj) in &projections {
667 for &var in &self.var_indices[*c_idx] {
668 if var < x.len() {
669 x[var] = proj[var];
670 }
671 }
672 }
673 }
674
675 let max_diff = x
677 .iter()
678 .zip(prev.iter())
679 .map(|(&a, &b)| (a - b).abs())
680 .fold(0.0f32, f32::max);
681
682 if max_diff < tol {
683 converged = true;
684 break;
685 }
686 }
687
688 let num_violations = self
690 .constraints
691 .iter()
692 .filter(|c| !c.is_feasible(x))
693 .count();
694
695 PropagationResult {
696 converged,
697 iterations,
698 num_violations,
699 }
700 }
701}
702
/// Evaluates several box constraints stored in flat, contiguous buffers.
///
/// NOTE(review): the visible code uses plain iterator chains, not explicit
/// SIMD intrinsics; any vectorisation is left to the compiler.
pub struct SimdConstraintEvaluator {
    /// Lower bounds of all constraints, concatenated: constraint `c` occupies
    /// `lb[c * dim .. (c + 1) * dim]`.
    lb: Vec<f32>,
    /// Upper bounds, laid out like `lb`.
    ub: Vec<f32>,
    /// Number of components per constraint (0 when there are no constraints).
    dim: usize,
    /// Number of box constraints stored.
    num_constraints: usize,
}
724
725impl SimdConstraintEvaluator {
726 pub fn new(bounds: Vec<(Vec<f32>, Vec<f32>)>) -> Result<Self, String> {
734 if bounds.is_empty() {
735 return Ok(Self {
736 lb: Vec::new(),
737 ub: Vec::new(),
738 dim: 0,
739 num_constraints: 0,
740 });
741 }
742
743 let dim = bounds[0].0.len();
744 for (k, (l, u)) in bounds.iter().enumerate() {
745 if l.len() != dim {
746 return Err(format!(
747 "SimdConstraintEvaluator: bounds[{k}].0.len()={} != dim={dim}",
748 l.len()
749 ));
750 }
751 if u.len() != dim {
752 return Err(format!(
753 "SimdConstraintEvaluator: bounds[{k}].1.len()={} != dim={dim}",
754 u.len()
755 ));
756 }
757 for i in 0..dim {
758 if l[i] > u[i] {
759 return Err(format!(
760 "SimdConstraintEvaluator: bounds[{k}].lb[{i}]={} > ub[{i}]={}",
761 l[i], u[i]
762 ));
763 }
764 }
765 }
766
767 let num_constraints = bounds.len();
768 let mut lb_flat = vec![0.0f32; num_constraints * dim];
769 let mut ub_flat = vec![0.0f32; num_constraints * dim];
770
771 for (k, (l, u)) in bounds.iter().enumerate() {
772 for i in 0..dim {
773 lb_flat[k * dim + i] = l[i];
774 ub_flat[k * dim + i] = u[i];
775 }
776 }
777
778 Ok(Self {
779 lb: lb_flat,
780 ub: ub_flat,
781 dim,
782 num_constraints,
783 })
784 }
785
786 pub fn num_constraints(&self) -> usize {
788 self.num_constraints
789 }
790
791 pub fn evaluate(&self, x: &[f32]) -> Vec<f32> {
796 (0..self.num_constraints)
797 .map(|c| {
798 let base = c * self.dim;
799 let lb_slice = &self.lb[base..base + self.dim];
800 let ub_slice = &self.ub[base..base + self.dim];
801 let v: f32 = x
803 .iter()
804 .zip(lb_slice.iter())
805 .zip(ub_slice.iter())
806 .map(|((&xi, &lbi), &ubi)| {
807 let lb_viol = (lbi - xi).max(0.0);
808 let ub_viol = (xi - ubi).max(0.0);
809 lb_viol * lb_viol + ub_viol * ub_viol
810 })
811 .sum();
812 v.sqrt()
813 })
814 .collect()
815 }
816
817 pub fn evaluate_batch(&self, points: &Array2<f32>) -> Array2<f32> {
821 let (n_points, _dim) = points.dim();
822 let n_c = self.num_constraints;
823
824 if n_points == 0 || n_c == 0 {
825 return Array2::zeros((n_points, n_c));
826 }
827
828 let rows: Vec<Vec<f32>> = (0..n_points)
829 .into_par_iter()
830 .map(|i| {
831 let row: Vec<f32> = points
832 .slice(scirs2_core::ndarray::s![i, ..])
833 .iter()
834 .copied()
835 .collect();
836 self.evaluate(&row)
837 })
838 .collect();
839
840 let mut out = Array2::zeros((n_points, n_c));
841 for (i, row) in rows.iter().enumerate() {
842 for (j, &v) in row.iter().enumerate() {
843 out[[i, j]] = v;
844 }
845 }
846 out
847 }
848
849 pub fn is_feasible_fast(&self, x: &[f32]) -> bool {
854 (0..self.num_constraints).all(|c| {
855 let base = c * self.dim;
856 let lb_slice = &self.lb[base..base + self.dim];
858 let ub_slice = &self.ub[base..base + self.dim];
859 let max_viol: f32 = x
860 .iter()
861 .zip(lb_slice.iter())
862 .zip(ub_slice.iter())
863 .map(|((&xi, &lbi), &ubi)| (lbi - xi).max(0.0) + (xi - ubi).max(0.0))
864 .sum();
865 max_viol == 0.0
866 })
867 }
868}
869
/// Outcome of `IncrementalParallelSolver::solve`.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Point returned by the projection routine.
    pub solution: Array1<f32>,
    /// `true` when `num_violations` is 0.
    pub feasible: bool,
    /// Iteration budget that `solve` handed to the projection routine. This is
    /// the limit, not a count of iterations actually performed.
    pub iterations: usize,
    /// Number of constraints whose `is_feasible` is `false` at `solution`.
    pub num_violations: usize,
    /// Wall-clock time spent in `solve`, in microseconds.
    pub solve_time_us: u64,
}
888
/// Projection-based solver that caches its last solution so later solves can
/// start from it.
pub struct IncrementalParallelSolver {
    /// Parallel tuning parameters; stored but not read by the visible code,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    config: ParallelConfig,
    /// Constraints in insertion order.
    constraints: Vec<Box<dyn FastConstraint>>,
    /// Result of the most recent `solve`, if any.
    solution: Option<Array1<f32>>,
    /// Whether `solution` is believed feasible for the current constraints:
    /// set from the outcome of `solve`, cleared by `invalidate` or when an
    /// added constraint rejects the cached solution.
    solution_valid: bool,
}
904
905impl IncrementalParallelSolver {
906 pub fn new(config: ParallelConfig) -> Self {
908 Self {
909 config,
910 constraints: Vec::new(),
911 solution: None,
912 solution_valid: false,
913 }
914 }
915
916 pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
919 if self.solution_valid {
920 if let Some(ref sol) = self.solution {
921 if !constraint.is_feasible(sol) {
922 self.solution_valid = false;
923 }
924 }
925 }
926 self.constraints.push(constraint);
927 }
928
929 pub fn remove_constraint(&mut self, idx: usize) -> bool {
934 if idx >= self.constraints.len() {
935 return false;
936 }
937 self.constraints.remove(idx);
938 true
941 }
942
943 pub fn invalidate(&mut self) {
945 self.solution_valid = false;
946 }
947
948 pub fn current_solution(&self) -> Option<&Array1<f32>> {
950 self.solution.as_ref()
951 }
952
953 pub fn num_constraints(&self) -> usize {
955 self.constraints.len()
956 }
957
958 pub fn solve(&mut self, init: Array1<f32>, max_iter: usize) -> SolverResult {
965 let start = Instant::now();
966
967 let start_point = if self.solution_valid {
968 self.solution.clone().unwrap_or_else(|| init.clone())
969 } else {
970 init.clone()
971 };
972
973 let actual_max_iter = if self.solution_valid {
975 (max_iter / 4).max(1)
976 } else {
977 max_iter
978 };
979
980 let result = dykstra_project(&start_point, &self.constraints, actual_max_iter);
981
982 let num_violations = self
984 .constraints
985 .iter()
986 .filter(|c| !c.is_feasible(&result))
987 .count();
988 let feasible = num_violations == 0;
989
990 let elapsed_us = start.elapsed().as_micros() as u64;
991
992 self.solution = Some(result.clone());
993 self.solution_valid = feasible;
994
995 SolverResult {
996 solution: result,
997 feasible,
998 iterations: actual_max_iter,
999 num_violations,
1000 solve_time_us: elapsed_us,
1001 }
1002 }
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Shared fixture: the box `[0, 1] x [0, 2] x [0, 3]`.
    fn make_box() -> BoxConstraint {
        let lower = Array1::from(vec![0.0f32, 0.0, 0.0]);
        let upper = Array1::from(vec![1.0f32, 2.0, 3.0]);
        BoxConstraint::new(lower, upper).expect("valid box")
    }

    /// An interior point is feasible and has zero violation.
    #[test]
    fn test_box_constraint_feasible() {
        let constraint = make_box();
        let interior = Array1::from(vec![0.5f32, 1.0, 2.0]);
        assert!(constraint.is_feasible(&interior));
        assert_eq!(constraint.violation(&interior), 0.0);
    }

    /// Projection clamps each out-of-range component to the nearest bound.
    #[test]
    fn test_box_constraint_project() {
        let constraint = make_box();
        let outside = Array1::from(vec![-1.0f32, 3.0, 5.0]);
        let clamped = constraint.project(&outside);
        assert!((clamped[0] - 0.0).abs() < 1e-5, "clamped to lb");
        assert!((clamped[1] - 2.0).abs() < 1e-5, "clamped to ub");
        assert!((clamped[2] - 3.0).abs() < 1e-5, "clamped to ub");
        assert!(constraint.is_feasible(&clamped));
    }

    /// Outside points land on the sphere; inside points are left alone.
    #[test]
    fn test_l2_ball_project() {
        let ball =
            L2BallConstraint::new(Array1::from(vec![0.0f32, 0.0]), 1.0).expect("valid ball");

        let far = Array1::from(vec![3.0f32, 4.0]);
        let on_surface = ball.project(&far);
        let dist: f32 = on_surface.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (dist - 1.0).abs() < 1e-4,
            "projected onto ball surface, dist={dist}"
        );
        assert!(ball.is_feasible(&on_surface));

        let inside = Array1::from(vec![0.1f32, 0.1]);
        let unchanged = ball.project(&inside);
        assert!(
            (unchanged[0] - inside[0]).abs() < 1e-5,
            "inside point unchanged"
        );
    }

    /// A violating point is moved onto the half-space `x0 + x1 <= 1`.
    #[test]
    fn test_hyperplane_project() {
        let half_space = HyperplaneConstraint::new(Array1::from(vec![1.0f32, 1.0]), 1.0)
            .expect("valid hyperplane");

        let violating = Array1::from(vec![2.0f32, 2.0]);
        let fixed = half_space.project(&violating);
        let ax: f32 = fixed[0] + fixed[1];
        assert!(
            ax <= 1.0 + 1e-5,
            "projected point satisfies a^T x <= b, got {ax}"
        );
        assert!(half_space.is_feasible(&fixed));

        let ok = Array1::from(vec![0.3f32, 0.3]);
        assert!(half_space.is_feasible(&ok));
        assert_eq!(half_space.violation(&ok), 0.0);
    }

    /// The simplex projection is non-negative and sums to one.
    #[test]
    fn test_simplex_project() {
        let simplex = SimplexConstraint::new(4);
        let raw = Array1::from(vec![1.0f32, 2.0, 3.0, 4.0]);
        let projected = simplex.project(&raw);

        let sum: f32 = projected.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-4,
            "sum of projected point should be 1, got {sum}"
        );
        for &v in projected.iter() {
            assert!(v >= -1e-5, "all components should be non-negative, got {v}");
        }
        assert!(simplex.is_feasible(&projected));
    }

    /// Half of a 100-point batch is feasible, half violates the first bound.
    #[test]
    fn test_parallel_checker_batch() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));

        let inside_row = [0.5f32, 1.0, 1.5];
        let outside_row = [5.0f32, 0.5, 0.5];
        let data: Vec<f32> = (0..100usize)
            .flat_map(|i| if i < 50 { inside_row } else { outside_row })
            .collect();

        let points = Array2::from_shape_vec((100, 3), data).expect("valid shape");
        let results = checker.check_batch(&points);
        assert_eq!(results.len(), 100);
        let feasible_count = results.iter().filter(|&&f| f).count();
        assert_eq!(
            feasible_count, 50,
            "expected 50 feasible points, got {feasible_count}"
        );
    }

    /// The violation matrix has one row per point and one column per constraint.
    #[test]
    fn test_parallel_checker_violations() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));
        let ball =
            L2BallConstraint::new(Array1::from(vec![0.5f32, 1.0, 1.5]), 2.0).expect("valid ball");
        checker.add_constraint(Box::new(ball));

        let points = Array2::from_shape_vec((2, 3), vec![0.5f32, 1.0, 1.5, 2.0, 3.0, 4.0])
            .expect("valid shape");
        let mat = checker.violation_matrix(&points);
        assert_eq!(mat.dim(), (2, 2), "expected (2, 2) violation matrix");
        assert!(mat[[0, 0]] < 1e-5, "first point, box violation should be 0");
    }

    /// Every row of a projected batch satisfies the box.
    #[test]
    fn test_parallel_checker_project_batch() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));

        let points = Array2::from_shape_vec((2, 3), vec![-1.0f32, 5.0, 10.0, -2.0, 3.0, 7.0])
            .expect("valid shape");
        let projected = checker.project_batch(&points, 50);
        assert_eq!(projected.dim(), (2, 3));

        let reference = make_box();
        for i in 0..2usize {
            let row: Array1<f32> = projected.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
            assert!(
                reference.is_feasible(&row),
                "projected row {i} should be feasible"
            );
        }
    }

    /// Independent sets never contain adjacent constraints and cover each
    /// constraint exactly once.
    #[test]
    fn test_constraint_graph_independent_sets() {
        let mut graph = ConstraintGraph::new(3);
        for vars in [vec![0usize, 1], vec![0, 2], vec![1, 2]] {
            graph.add_constraint(Box::new(make_box()), vars);
        }
        let sets = graph.independent_sets();

        for set in &sets {
            for (pos, &a) in set.iter().enumerate() {
                for &b in &set[pos + 1..] {
                    assert!(
                        !graph.adjacency[a].contains(&b),
                        "constraints {a} and {b} are adjacent but in the same independent set"
                    );
                }
            }
        }

        let mut seen = std::collections::HashSet::new();
        for &c in sets.iter().flatten() {
            assert!(seen.insert(c), "constraint {c} appears in multiple sets");
        }
        assert_eq!(seen.len(), 3, "all 3 constraints must appear");
    }

    /// Batch evaluation matches row-by-row evaluation.
    #[test]
    fn test_simd_evaluator_batch() {
        let evaluator = SimdConstraintEvaluator::new(vec![
            (vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
            (vec![-1.0f32, -1.0, -1.0], vec![1.0f32, 1.0, 1.0]),
        ])
        .expect("valid bounds");

        let data = vec![0.5f32, 1.5, 2.5, 2.0, 0.0, 0.0];
        let points = Array2::from_shape_vec((2, 3), data.clone()).expect("valid shape");
        let batch = evaluator.evaluate_batch(&points);

        for (i, row) in data.chunks(3).enumerate() {
            let seq = evaluator.evaluate(row);
            for j in 0..evaluator.num_constraints() {
                assert!(
                    (batch[[i, j]] - seq[j]).abs() < 1e-5,
                    "batch[{i},{j}]={} != seq[{j}]={}",
                    batch[[i, j]],
                    seq[j]
                );
            }
        }
    }

    /// The fast check accepts an interior point and rejects an exterior one.
    #[test]
    fn test_simd_evaluator_fast_feasibility() {
        let evaluator = SimdConstraintEvaluator::new(vec![
            (vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
            (vec![-5.0f32, -5.0, -5.0], vec![5.0f32, 5.0, 5.0]),
        ])
        .expect("valid bounds");

        let feasible = [0.5f32, 1.0, 2.0];
        assert!(evaluator.is_feasible_fast(&feasible), "point is feasible");

        let infeasible = [2.0f32, 1.0, 2.0];
        assert!(
            !evaluator.is_feasible_fast(&infeasible),
            "point is infeasible"
        );
    }

    /// Solving stays feasible after a tighter constraint is added.
    #[test]
    fn test_incremental_solver_add_constraint() {
        let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());
        solver.add_constraint(Box::new(make_box()));

        let first = solver.solve(Array1::from(vec![5.0f32, 5.0, 5.0]), 50);
        assert!(
            first.feasible,
            "solution should be feasible after solving with box constraint"
        );
        assert!(make_box().is_feasible(&first.solution));

        let tight_box = BoxConstraint::new(
            Array1::from(vec![0.0f32, 0.0, 0.0]),
            Array1::from(vec![0.5f32, 0.5, 0.5]),
        )
        .expect("valid box");
        solver.add_constraint(Box::new(tight_box));

        let second = solver.solve(Array1::from(vec![1.0f32, 2.0, 3.0]), 100);
        assert!(
            second.feasible,
            "solution should be feasible after adding tighter constraint"
        );
        assert!(second.solution[0] <= 0.5 + 1e-4);
        assert!(second.solution[1] <= 0.5 + 1e-4);
        assert!(second.solution[2] <= 0.5 + 1e-4);
    }

    /// A second solve from a valid cached solution reports no larger an
    /// iteration figure than the first.
    #[test]
    fn test_incremental_solver_warmstart() {
        let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());
        solver.add_constraint(Box::new(make_box()));

        let init = Array1::from(vec![0.5f32, 1.0, 1.5]);
        let cold_result = solver.solve(init.clone(), 100);
        assert!(cold_result.feasible);

        let warm_result = solver.solve(init, 100);
        assert!(warm_result.feasible);
        assert!(
            warm_result.iterations <= cold_result.iterations,
            "warm start should use fewer or equal iterations: warm={} cold={}",
            warm_result.iterations,
            cold_result.iterations
        );
    }
}