// kizzasi_logic/parallel_solver.rs

1//! Advanced Parallelization for Constraint Evaluation
2//!
//! Provides rayon-parallel batch projection, parallel constraint graph traversal,
4//! and SIMD-optimized constraint checking for high-performance constraint satisfaction.
5//!
6//! # Key Components
7//!
8//! - [`ParallelFeasibilityChecker`] — parallel batch feasibility checking with rayon
9//! - [`ConstraintGraph`] — dependency-aware parallel graph traversal with graph coloring
10//! - [`SimdConstraintEvaluator`] — SIMD/auto-vectorization friendly box constraint evaluation
11//! - [`IncrementalParallelSolver`] — warm-start incremental solving with parallel projections
12//!
13//! # Design Notes
14//!
15//! All inner loops are written in LLVM auto-vectorization friendly patterns
16//! (no early exits, simple accumulator patterns) so the compiler can emit SIMD
17//! instructions (SSE/AVX/NEON) automatically.
18
19use rayon::prelude::*;
20use scirs2_core::ndarray::{Array1, Array2};
21use std::time::Instant;
22
23// ============================================================================
24// Configuration
25// ============================================================================
26
/// Tuning knobs for the parallel constraint-solving machinery.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of threads — 0 means auto-detect (uses rayon global pool)
    pub num_threads: usize,
    /// Number of points per worker chunk (default 64)
    pub chunk_size: usize,
    /// Enable SIMD-optimized inner loops
    pub use_simd: bool,
    /// Data prefetch distance (default 8)
    pub prefetch_distance: usize,
}

impl Default for ParallelConfig {
    /// Defaults: auto thread count, 64-point chunks, SIMD enabled,
    /// prefetch distance of 8.
    fn default() -> Self {
        ParallelConfig {
            num_threads: 0,
            chunk_size: 64,
            use_simd: true,
            prefetch_distance: 8,
        }
    }
}
50
51// ============================================================================
52// FastConstraint trait
53// ============================================================================
54
/// A constraint represented as a simple function over f32 arrays.
///
/// Implementations must be `Send + Sync` so they can be shared across rayon threads
/// (the batch checkers in this module call these methods from worker threads).
pub trait FastConstraint: Send + Sync {
    /// Return `true` if point `x` satisfies this constraint
    fn is_feasible(&self, x: &Array1<f32>) -> bool;

    /// Project `x` onto the feasible set of this constraint.
    /// If `x` is already feasible, returns a clone of `x`.
    fn project(&self, x: &Array1<f32>) -> Array1<f32>;

    /// Return the constraint violation (0 if feasible, > 0 otherwise).
    /// The violation is a non-negative scalar indicating how far `x` is from feasibility.
    /// NOTE(review): the metric differs per implementation (L2 clip distance for
    /// boxes, radial excess for balls, linear slack for hyperplanes), so
    /// violations are not directly comparable across constraint types.
    fn violation(&self, x: &Array1<f32>) -> f32;
}
70
71// ============================================================================
72// BoxConstraint: lb <= x <= ub
73// ============================================================================
74
/// Box (bound) constraint: `lb[i] <= x[i] <= ub[i]` for all `i`.
///
/// This is the most common constraint type in practice. Projection is
/// element-wise clipping, which is trivially SIMD-friendly.
#[derive(Debug, Clone)]
pub struct BoxConstraint {
    /// Lower bounds (element-wise); same length as `ub`, enforced by [`BoxConstraint::new`]
    pub lb: Array1<f32>,
    /// Upper bounds (element-wise); `lb[i] <= ub[i]` enforced by [`BoxConstraint::new`]
    pub ub: Array1<f32>,
}
86
87impl BoxConstraint {
88    /// Create a new box constraint.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error string if `lb` and `ub` have different lengths or if any
93    /// `lb[i] > ub[i]`.
94    pub fn new(lb: Array1<f32>, ub: Array1<f32>) -> Result<Self, String> {
95        if lb.len() != ub.len() {
96            return Err(format!(
97                "BoxConstraint: lb.len()={} != ub.len()={}",
98                lb.len(),
99                ub.len()
100            ));
101        }
102        for (i, (&l, &u)) in lb.iter().zip(ub.iter()).enumerate() {
103            if l > u {
104                return Err(format!("BoxConstraint: lb[{i}]={l} > ub[{i}]={u}"));
105            }
106        }
107        Ok(Self { lb, ub })
108    }
109}
110
111impl FastConstraint for BoxConstraint {
112    fn is_feasible(&self, x: &Array1<f32>) -> bool {
113        // Written as a fold so LLVM can vectorize (no short-circuit)
114        x.iter()
115            .zip(self.lb.iter())
116            .zip(self.ub.iter())
117            .map(|((&xi, &li), &ui)| if xi < li || xi > ui { 1u8 } else { 0u8 })
118            .sum::<u8>()
119            == 0
120    }
121
122    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
123        let n = x.len();
124        let mut out = vec![0.0f32; n];
125        for i in 0..n {
126            out[i] = x[i].clamp(self.lb[i], self.ub[i]);
127        }
128        Array1::from(out)
129    }
130
131    fn violation(&self, x: &Array1<f32>) -> f32 {
132        let mut v = 0.0f32;
133        for i in 0..x.len() {
134            let below = (self.lb[i] - x[i]).max(0.0);
135            let above = (x[i] - self.ub[i]).max(0.0);
136            v += below * below + above * above;
137        }
138        v.sqrt()
139    }
140}
141
142// ============================================================================
143// L2BallConstraint: ||x - center||_2 <= radius
144// ============================================================================
145
/// L2 ball constraint: `||x - center||_2 <= radius`.
///
/// Projection onto the L2 ball: if inside, keep as-is; if outside, scale to the boundary.
#[derive(Debug, Clone)]
pub struct L2BallConstraint {
    /// Centre of the ball; its length defines the expected dimension of `x`
    pub center: Array1<f32>,
    /// Radius of the ball (must be positive; enforced by [`L2BallConstraint::new`])
    pub radius: f32,
}
156
157impl L2BallConstraint {
158    /// Create a new L2 ball constraint.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error string if `radius <= 0`.
163    pub fn new(center: Array1<f32>, radius: f32) -> Result<Self, String> {
164        if radius <= 0.0 {
165            return Err(format!(
166                "L2BallConstraint: radius must be positive, got {radius}"
167            ));
168        }
169        Ok(Self { center, radius })
170    }
171
172    fn dist_sq(&self, x: &Array1<f32>) -> f32 {
173        let mut s = 0.0f32;
174        for i in 0..x.len() {
175            let d = x[i] - self.center[i];
176            s += d * d;
177        }
178        s
179    }
180}
181
182impl FastConstraint for L2BallConstraint {
183    fn is_feasible(&self, x: &Array1<f32>) -> bool {
184        self.dist_sq(x) <= self.radius * self.radius
185    }
186
187    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
188        let dist_sq = self.dist_sq(x);
189        if dist_sq <= self.radius * self.radius {
190            return x.clone();
191        }
192        let dist = dist_sq.sqrt();
193        let scale = self.radius / dist;
194        let n = x.len();
195        let mut out = vec![0.0f32; n];
196        for i in 0..n {
197            out[i] = self.center[i] + (x[i] - self.center[i]) * scale;
198        }
199        Array1::from(out)
200    }
201
202    fn violation(&self, x: &Array1<f32>) -> f32 {
203        let dist = self.dist_sq(x).sqrt();
204        (dist - self.radius).max(0.0)
205    }
206}
207
208// ============================================================================
209// HyperplaneConstraint: a^T x <= b
210// ============================================================================
211
/// Hyperplane (half-space) constraint: `a^T x <= b`.
///
/// Projection: if feasible keep x; otherwise project onto the hyperplane
/// `{y : a^T y = b}` along the normal direction.
#[derive(Debug, Clone)]
pub struct HyperplaneConstraint {
    /// Normal vector `a` (non-zero; enforced by [`HyperplaneConstraint::new`])
    pub normal: Array1<f32>,
    /// Offset `b`
    pub offset: f32,
}
223
224impl HyperplaneConstraint {
225    /// Create a new hyperplane constraint.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error string if `normal` is the zero vector.
230    pub fn new(normal: Array1<f32>, offset: f32) -> Result<Self, String> {
231        let norm_sq: f32 = normal.iter().map(|&v| v * v).sum();
232        if norm_sq == 0.0 {
233            return Err("HyperplaneConstraint: normal vector must be non-zero".to_string());
234        }
235        Ok(Self { normal, offset })
236    }
237
238    fn dot(&self, x: &Array1<f32>) -> f32 {
239        let mut s = 0.0f32;
240        for i in 0..x.len() {
241            s += self.normal[i] * x[i];
242        }
243        s
244    }
245
246    fn norm_sq(&self) -> f32 {
247        let mut s = 0.0f32;
248        for &v in self.normal.iter() {
249            s += v * v;
250        }
251        s
252    }
253}
254
255impl FastConstraint for HyperplaneConstraint {
256    fn is_feasible(&self, x: &Array1<f32>) -> bool {
257        self.dot(x) <= self.offset
258    }
259
260    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
261        let ax = self.dot(x);
262        if ax <= self.offset {
263            return x.clone();
264        }
265        // Project: x - ((a^T x - b) / ||a||^2) * a
266        let scale = (ax - self.offset) / self.norm_sq();
267        let n = x.len();
268        let mut out = vec![0.0f32; n];
269        for i in 0..n {
270            out[i] = x[i] - scale * self.normal[i];
271        }
272        Array1::from(out)
273    }
274
275    fn violation(&self, x: &Array1<f32>) -> f32 {
276        (self.dot(x) - self.offset).max(0.0)
277    }
278}
279
280// ============================================================================
281// SimplexConstraint: x >= 0, sum(x) = 1
282// ============================================================================
283
/// Probability simplex constraint: `x[i] >= 0` for all `i` and `sum(x) = 1`.
///
/// Uses the O(n log n) algorithm by Duchi et al. (2008).
#[derive(Debug, Clone)]
pub struct SimplexConstraint {
    /// Dimension of the simplex
    pub dim: usize,
}

impl SimplexConstraint {
    /// Create a simplex constraint over `dim` variables.
    pub fn new(dim: usize) -> Self {
        SimplexConstraint { dim }
    }
}
299
300impl FastConstraint for SimplexConstraint {
301    fn is_feasible(&self, x: &Array1<f32>) -> bool {
302        if x.len() != self.dim {
303            return false;
304        }
305        let sum: f32 = x.iter().sum();
306        let all_nonneg = x
307            .iter()
308            .map(|&v| if v < 0.0 { 1u8 } else { 0u8 })
309            .sum::<u8>()
310            == 0;
311        all_nonneg && (sum - 1.0).abs() < 1e-5
312    }
313
314    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
315        let n = x.len();
316        // Duchi et al. O(n log n) simplex projection
317        let mut sorted: Vec<f32> = x.iter().copied().collect();
318        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
319
320        let mut cumsum = 0.0f32;
321        let mut rho = 0usize;
322        for (j, &s) in sorted.iter().enumerate() {
323            cumsum += s;
324            if s > (cumsum - 1.0) / (j as f32 + 1.0) {
325                rho = j;
326            }
327        }
328        let cumsum_rho: f32 = sorted[..=rho].iter().sum();
329        let theta = (cumsum_rho - 1.0) / (rho as f32 + 1.0);
330
331        let mut out = vec![0.0f32; n];
332        for i in 0..n {
333            out[i] = (x[i] - theta).max(0.0);
334        }
335        Array1::from(out)
336    }
337
338    fn violation(&self, x: &Array1<f32>) -> f32 {
339        let sum: f32 = x.iter().sum();
340        let sum_viol = (sum - 1.0).abs();
341        let neg_viol: f32 = x.iter().map(|&v| (-v).max(0.0)).sum();
342        sum_viol + neg_viol
343    }
344}
345
346// ============================================================================
347// ParallelFeasibilityChecker
348// ============================================================================
349
/// Parallel batch feasibility checker.
///
/// Checks N points against M constraints simultaneously using rayon.
/// The outermost parallelism is over points (each point is independent);
/// within a point, constraints are checked sequentially for cache efficiency.
pub struct ParallelFeasibilityChecker {
    // Heterogeneous constraint set; boxed trait objects so different
    // constraint types can mix in one checker.
    constraints: Vec<Box<dyn FastConstraint>>,
    // Tuning knobs; the batch methods read only `chunk_size`.
    config: ParallelConfig,
}
359
360impl ParallelFeasibilityChecker {
361    /// Create a new checker with the given configuration.
362    pub fn new(config: ParallelConfig) -> Self {
363        Self {
364            constraints: Vec::new(),
365            config,
366        }
367    }
368
369    /// Add a constraint to the checker.
370    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
371        self.constraints.push(constraint);
372    }
373
374    /// Number of registered constraints.
375    pub fn num_constraints(&self) -> usize {
376        self.constraints.len()
377    }
378
379    /// Check feasibility of a batch of points.
380    ///
381    /// `points` must be a `(num_points, dim)` matrix in row-major order.
382    /// Returns one `bool` per point indicating whether all constraints are satisfied.
383    pub fn check_batch(&self, points: &Array2<f32>) -> Vec<bool> {
384        let (n_points, dim) = points.dim();
385        let constraints = &self.constraints;
386        let chunk_size = self.config.chunk_size.max(1);
387
388        (0..n_points)
389            .into_par_iter()
390            .with_min_len(chunk_size)
391            .map(|i| {
392                let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
393                if row.len() != dim {
394                    return false;
395                }
396                constraints.iter().all(|c| c.is_feasible(&row))
397            })
398            .collect()
399    }
400
401    /// Compute violation for each point against all constraints.
402    ///
403    /// Returns a `(num_points, num_constraints)` matrix where entry `[i, j]`
404    /// is the violation of point `i` against constraint `j` (0 if feasible).
405    pub fn violation_matrix(&self, points: &Array2<f32>) -> Array2<f32> {
406        let (n_points, _dim) = points.dim();
407        let n_constraints = self.constraints.len();
408
409        if n_constraints == 0 || n_points == 0 {
410            return Array2::zeros((n_points, n_constraints));
411        }
412
413        let chunk_size = self.config.chunk_size.max(1);
414        let constraints = &self.constraints;
415
416        let rows: Vec<Vec<f32>> = (0..n_points)
417            .into_par_iter()
418            .with_min_len(chunk_size)
419            .map(|i| {
420                let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
421                constraints.iter().map(|c| c.violation(&row)).collect()
422            })
423            .collect();
424
425        let mut out = Array2::zeros((n_points, n_constraints));
426        for (i, row) in rows.iter().enumerate() {
427            for (j, &v) in row.iter().enumerate() {
428                out[[i, j]] = v;
429            }
430        }
431        out
432    }
433
434    /// Project all points to the constraint-satisfying region.
435    ///
436    /// Uses Dykstra's alternating projections algorithm in parallel — each point
437    /// is independent, so rayon parallelism is embarrassingly parallel here.
438    pub fn project_batch(&self, points: &Array2<f32>, max_iter: usize) -> Array2<f32> {
439        let (n_points, dim) = points.dim();
440        if n_points == 0 || dim == 0 {
441            return points.clone();
442        }
443
444        let chunk_size = self.config.chunk_size.max(1);
445        let constraints = &self.constraints;
446
447        let projected: Vec<Vec<f32>> = (0..n_points)
448            .into_par_iter()
449            .with_min_len(chunk_size)
450            .map(|i| {
451                let row: Array1<f32> = points.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
452                let result = dykstra_project(&row, constraints.as_slice(), max_iter);
453                result.into_raw_vec_and_offset().0
454            })
455            .collect();
456
457        let mut out = Array2::zeros((n_points, dim));
458        for (i, row) in projected.iter().enumerate() {
459            for (j, &v) in row.iter().enumerate() {
460                if j < dim {
461                    out[[i, j]] = v;
462                }
463            }
464        }
465        out
466    }
467}
468
/// Dykstra's alternating projections algorithm.
///
/// Projects `x` onto the intersection of all constraint sets using the
/// incremental projections variant with correction vectors.
///
/// Unlike plain cyclic projections, Dykstra's method re-adds a per-constraint
/// correction before each projection, which (for convex sets) makes the
/// iterates converge to the projection of `x` onto the intersection rather
/// than an arbitrary intersection point.
fn dykstra_project(
    x: &Array1<f32>,
    constraints: &[Box<dyn FastConstraint>],
    max_iter: usize,
) -> Array1<f32> {
    // No constraints: the whole space is feasible, x projects to itself.
    if constraints.is_empty() {
        return x.clone();
    }
    let n = x.len();
    let m = constraints.len();

    // Dykstra's algorithm: maintain increment vectors p_k for each constraint
    let mut z = x.clone();
    let mut increments: Vec<Array1<f32>> = vec![Array1::zeros(n); m];

    for _ in 0..max_iter {
        let prev = z.clone();
        for (k, constraint) in constraints.iter().enumerate() {
            // Re-add this constraint's stored correction before projecting.
            let y = &z + &increments[k];
            let proj = constraint.project(&y);
            // Update increment: p_k = y - proj(y)
            for j in 0..n {
                increments[k][j] = y[j] - proj[j];
            }
            z = proj;
        }
        // Check convergence: ||z_new - z_old||_inf
        let max_diff = z
            .iter()
            .zip(prev.iter())
            .map(|(&a, &b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        if max_diff < 1e-6 {
            break;
        }
    }
    z
}
511
512// ============================================================================
513// ConstraintGraph
514// ============================================================================
515
/// Result of constraint propagation through the graph.
#[derive(Debug, Clone)]
pub struct PropagationResult {
    /// Whether propagation converged within the iteration budget
    pub converged: bool,
    /// Number of iterations performed (at most the internal budget of 50)
    pub iterations: usize,
    /// Number of constraints still violated at termination;
    /// 0 does not imply `converged` (and vice versa)
    pub num_violations: usize,
}
526
/// Constraint dependency graph.
///
/// Models which constraints share variables. This enables:
/// - Graph coloring to find independent constraint sets
/// - Parallel propagation within independent sets
/// - Efficient incremental solving after small changes
pub struct ConstraintGraph {
    // Total number of variables; used only as metadata by the visible code.
    num_vars: usize,
    // Constraint k is `constraints[k]`, touching `var_indices[k]`.
    constraints: Vec<Box<dyn FastConstraint>>,
    /// Which variable indices each constraint touches
    var_indices: Vec<Vec<usize>>,
    /// Adjacency list: constraints that share at least one variable
    adjacency: Vec<Vec<usize>>,
}
541
impl ConstraintGraph {
    /// Create a new empty constraint graph over `num_vars` variables.
    pub fn new(num_vars: usize) -> Self {
        Self {
            num_vars,
            constraints: Vec::new(),
            var_indices: Vec::new(),
            adjacency: Vec::new(),
        }
    }

    /// Number of constraints in the graph.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Number of variables in the graph.
    pub fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Add a constraint that touches the given variable indices.
    ///
    /// The adjacency list is updated so that this constraint becomes adjacent
    /// to all existing constraints that share at least one variable index.
    ///
    /// Note: each call scans every existing constraint's variable list, so
    /// building a large graph is quadratic in the number of constraints.
    /// `var_indices` is trusted as-is — [`propagate_parallel`](Self::propagate_parallel)
    /// relies on its accuracy for race-free parallel writes.
    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>, var_indices: Vec<usize>) {
        let new_idx = self.constraints.len();
        self.constraints.push(constraint);

        // Compute adjacency with existing constraints
        let mut neighbors = Vec::new();
        for (existing_idx, existing_vars) in self.var_indices.iter().enumerate() {
            let shares = var_indices.iter().any(|v| existing_vars.contains(v));
            if shares {
                // Record the edge in both directions.
                neighbors.push(existing_idx);
                self.adjacency[existing_idx].push(new_idx);
            }
        }

        self.var_indices.push(var_indices);
        self.adjacency.push(neighbors);
    }

    /// Greedy graph coloring — returns independent sets (constraints with same color).
    ///
    /// Uses Welsh-Powell ordering (decreasing degree) for better coloring quality.
    /// Constraints in the same independent set can be checked/projected in parallel.
    ///
    /// # Returns
    ///
    /// A `Vec<Vec<usize>>` where each inner `Vec` is a set of constraint indices
    /// that are mutually non-adjacent (i.e., share no variables).
    pub fn independent_sets(&self) -> Vec<Vec<usize>> {
        let n = self.constraints.len();
        if n == 0 {
            return Vec::new();
        }

        // Welsh-Powell: sort by degree descending
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by_key(|&i| std::cmp::Reverse(self.adjacency[i].len()));

        let mut colors: Vec<Option<usize>> = vec![None; n];
        let mut num_colors = 0usize;

        for &node in &order {
            // Find the smallest color not used by any neighbor
            let used_colors: std::collections::HashSet<usize> = self.adjacency[node]
                .iter()
                .filter_map(|&nb| colors[nb])
                .collect();

            // `(0..)` is unbounded so `find` always succeeds; `unwrap_or(0)`
            // is purely defensive.
            let color = (0..).find(|c| !used_colors.contains(c)).unwrap_or(0);
            colors[node] = Some(color);
            if color >= num_colors {
                num_colors = color + 1;
            }
        }

        // Bucket the nodes by color.
        let mut sets: Vec<Vec<usize>> = vec![Vec::new(); num_colors];
        for (node, color) in colors.iter().enumerate() {
            if let Some(c) = color {
                sets[*c].push(node);
            }
        }
        sets
    }

    /// Parallel constraint propagation.
    ///
    /// Projects `x` using constraints in independent-set order.
    /// Within each independent set, constraints are applied in parallel via rayon
    /// (each one operates on disjoint variable subsets, so there are no races).
    ///
    /// Falls back to sequential application when variable sets overlap within a color
    /// (the graph coloring guarantees no overlap for correctly registered constraints).
    pub fn propagate_parallel(&self, x: &mut Array1<f32>) -> PropagationResult {
        // Fixed internal budget/tolerance — not configurable from ParallelConfig.
        let max_iter = 50usize;
        let tol = 1e-6f32;
        let sets = self.independent_sets();

        let mut iterations = 0usize;
        let mut converged = false;

        for _global_iter in 0..max_iter {
            iterations += 1;
            let prev = x.clone();

            for set in &sets {
                if set.is_empty() {
                    continue;
                }
                // Within a color group: constraints touch disjoint variables,
                // so we can project independently and merge results.
                // Compute projections in parallel, then write back.
                // Each projection reads the same snapshot of `x` taken at the
                // start of this color group.
                let projections: Vec<(usize, Array1<f32>)> = set
                    .par_iter()
                    .map(|&c_idx| {
                        let proj = self.constraints[c_idx].project(x);
                        (c_idx, proj)
                    })
                    .collect();

                // Apply variable updates (write disjoint variable subsets)
                // Only the variables a constraint registered are written back;
                // out-of-range indices are silently skipped.
                for (c_idx, proj) in &projections {
                    for &var in &self.var_indices[*c_idx] {
                        if var < x.len() {
                            x[var] = proj[var];
                        }
                    }
                }
            }

            // Convergence check
            let max_diff = x
                .iter()
                .zip(prev.iter())
                .map(|(&a, &b)| (a - b).abs())
                .fold(0.0f32, f32::max);

            if max_diff < tol {
                converged = true;
                break;
            }
        }

        // Count remaining violations
        let num_violations = self
            .constraints
            .iter()
            .filter(|c| !c.is_feasible(x))
            .count();

        PropagationResult {
            converged,
            iterations,
            num_violations,
        }
    }
}
702
703// ============================================================================
704// SimdConstraintEvaluator
705// ============================================================================
706
/// SIMD-accelerated batch box constraint evaluation.
///
/// Stores bounds as flat `Vec<f32>` arrays for cache-friendly sequential access.
/// Inner loops are written to be friendly to LLVM's auto-vectorizer:
/// - No branches inside the accumulation loop
/// - Sequential memory access patterns
/// - Simple arithmetic (mul/add/max)
pub struct SimdConstraintEvaluator {
    /// Flattened lower bounds: `lb[c * dim + i]` is the lower bound for constraint `c`, dim `i`
    lb: Vec<f32>,
    /// Flattened upper bounds (same constraint-major layout as `lb`)
    ub: Vec<f32>,
    /// Dimension of each constraint (all must be equal; 0 when no constraints)
    dim: usize,
    /// Number of constraints
    num_constraints: usize,
}
724
725impl SimdConstraintEvaluator {
726    /// Create a new evaluator from a list of `(lb, ub)` bound pairs.
727    ///
728    /// All bound vectors must have the same length (the variable dimension).
729    ///
730    /// # Errors
731    ///
732    /// Returns an error string if bounds have different lengths or if `lb[i] > ub[i]`.
733    pub fn new(bounds: Vec<(Vec<f32>, Vec<f32>)>) -> Result<Self, String> {
734        if bounds.is_empty() {
735            return Ok(Self {
736                lb: Vec::new(),
737                ub: Vec::new(),
738                dim: 0,
739                num_constraints: 0,
740            });
741        }
742
743        let dim = bounds[0].0.len();
744        for (k, (l, u)) in bounds.iter().enumerate() {
745            if l.len() != dim {
746                return Err(format!(
747                    "SimdConstraintEvaluator: bounds[{k}].0.len()={} != dim={dim}",
748                    l.len()
749                ));
750            }
751            if u.len() != dim {
752                return Err(format!(
753                    "SimdConstraintEvaluator: bounds[{k}].1.len()={} != dim={dim}",
754                    u.len()
755                ));
756            }
757            for i in 0..dim {
758                if l[i] > u[i] {
759                    return Err(format!(
760                        "SimdConstraintEvaluator: bounds[{k}].lb[{i}]={} > ub[{i}]={}",
761                        l[i], u[i]
762                    ));
763                }
764            }
765        }
766
767        let num_constraints = bounds.len();
768        let mut lb_flat = vec![0.0f32; num_constraints * dim];
769        let mut ub_flat = vec![0.0f32; num_constraints * dim];
770
771        for (k, (l, u)) in bounds.iter().enumerate() {
772            for i in 0..dim {
773                lb_flat[k * dim + i] = l[i];
774                ub_flat[k * dim + i] = u[i];
775            }
776        }
777
778        Ok(Self {
779            lb: lb_flat,
780            ub: ub_flat,
781            dim,
782            num_constraints,
783        })
784    }
785
786    /// Number of registered box constraints.
787    pub fn num_constraints(&self) -> usize {
788        self.num_constraints
789    }
790
791    /// Evaluate all box constraints for a single point `x`.
792    ///
793    /// Returns a `Vec<f32>` of length `num_constraints` where each entry is the
794    /// L2-distance violation (0 if feasible).  Inner loop is SIMD-friendly.
795    pub fn evaluate(&self, x: &[f32]) -> Vec<f32> {
796        (0..self.num_constraints)
797            .map(|c| {
798                let base = c * self.dim;
799                let lb_slice = &self.lb[base..base + self.dim];
800                let ub_slice = &self.ub[base..base + self.dim];
801                // Auto-vectorizable: no branches, simple arithmetic
802                let v: f32 = x
803                    .iter()
804                    .zip(lb_slice.iter())
805                    .zip(ub_slice.iter())
806                    .map(|((&xi, &lbi), &ubi)| {
807                        let lb_viol = (lbi - xi).max(0.0);
808                        let ub_viol = (xi - ubi).max(0.0);
809                        lb_viol * lb_viol + ub_viol * ub_viol
810                    })
811                    .sum();
812                v.sqrt()
813            })
814            .collect()
815    }
816
817    /// Batch evaluation: very cache-friendly inner loop.
818    ///
819    /// `points` is `(num_points, dim)`. Returns `(num_points, num_constraints)`.
820    pub fn evaluate_batch(&self, points: &Array2<f32>) -> Array2<f32> {
821        let (n_points, _dim) = points.dim();
822        let n_c = self.num_constraints;
823
824        if n_points == 0 || n_c == 0 {
825            return Array2::zeros((n_points, n_c));
826        }
827
828        let rows: Vec<Vec<f32>> = (0..n_points)
829            .into_par_iter()
830            .map(|i| {
831                let row: Vec<f32> = points
832                    .slice(scirs2_core::ndarray::s![i, ..])
833                    .iter()
834                    .copied()
835                    .collect();
836                self.evaluate(&row)
837            })
838            .collect();
839
840        let mut out = Array2::zeros((n_points, n_c));
841        for (i, row) in rows.iter().enumerate() {
842            for (j, &v) in row.iter().enumerate() {
843                out[[i, j]] = v;
844            }
845        }
846        out
847    }
848
849    /// Fast feasibility check with early exit on first violation.
850    ///
851    /// Returns `true` only if `x` satisfies all box constraints.
852    /// Exits as soon as a violation is found, without evaluating remaining constraints.
853    pub fn is_feasible_fast(&self, x: &[f32]) -> bool {
854        (0..self.num_constraints).all(|c| {
855            let base = c * self.dim;
856            // Inner loop: SIMD-friendly accumulation using zip for cache locality
857            let lb_slice = &self.lb[base..base + self.dim];
858            let ub_slice = &self.ub[base..base + self.dim];
859            let max_viol: f32 = x
860                .iter()
861                .zip(lb_slice.iter())
862                .zip(ub_slice.iter())
863                .map(|((&xi, &lbi), &ubi)| (lbi - xi).max(0.0) + (xi - ubi).max(0.0))
864                .sum();
865            max_viol == 0.0
866        })
867    }
868}
869
870// ============================================================================
871// IncrementalParallelSolver
872// ============================================================================
873
/// Result from the incremental parallel solver.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// The solution (projected point)
    pub solution: Array1<f32>,
    /// Whether the solution satisfies all constraints
    /// NOTE(review): presumably equivalent to `num_violations == 0` — confirm
    /// against the solve method (not visible in this chunk)
    pub feasible: bool,
    /// Number of projection iterations performed
    pub iterations: usize,
    /// Number of constraints violated at termination
    pub num_violations: usize,
    /// Wall-clock solve time in microseconds
    pub solve_time_us: u64,
}
888
/// Multi-threaded incremental constraint solver.
///
/// Avoids a full re-solve when the constraint set changes:
/// - On constraint *addition*: if the cached solution already satisfies the new
///   constraint, `solution_valid` stays `true` and the next solve warm-starts;
///   otherwise the cache is invalidated and the next solve starts fresh.
/// - On constraint *removal*: the cached solution remains feasible for the
///   remaining (smaller) constraint set, so `solution_valid` is preserved.
pub struct IncrementalParallelSolver {
    // Parallelism settings; not read by any visible method (kept for future use).
    #[allow(dead_code)]
    config: ParallelConfig,
    // Registered constraints, in insertion order.
    constraints: Vec<Box<dyn FastConstraint>>,
    // Most recent solve result, if any solve has run.
    solution: Option<Array1<f32>>,
    // Whether `solution` is known to satisfy all current constraints.
    solution_valid: bool,
}
904
905impl IncrementalParallelSolver {
906    /// Create a new solver with the given configuration.
907    pub fn new(config: ParallelConfig) -> Self {
908        Self {
909            config,
910            constraints: Vec::new(),
911            solution: None,
912            solution_valid: false,
913        }
914    }
915
916    /// Add a constraint. If a valid solution already exists and it satisfies the new
917    /// constraint, `solution_valid` remains `true` — avoiding a full re-solve.
918    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
919        if self.solution_valid {
920            if let Some(ref sol) = self.solution {
921                if !constraint.is_feasible(sol) {
922                    self.solution_valid = false;
923                }
924            }
925        }
926        self.constraints.push(constraint);
927    }
928
929    /// Remove a constraint by index. Returns `false` if `idx` is out of range.
930    ///
931    /// After removal the cached solution (if any) is still feasible for the remaining
932    /// constraints, so `solution_valid` is preserved.
933    pub fn remove_constraint(&mut self, idx: usize) -> bool {
934        if idx >= self.constraints.len() {
935            return false;
936        }
937        self.constraints.remove(idx);
938        // Solution is still feasible for the remaining (smaller) constraint set
939        // — no need to invalidate.
940        true
941    }
942
943    /// Force a full re-solve on the next call to `solve`.
944    pub fn invalidate(&mut self) {
945        self.solution_valid = false;
946    }
947
948    /// Return a reference to the cached solution, if any.
949    pub fn current_solution(&self) -> Option<&Array1<f32>> {
950        self.solution.as_ref()
951    }
952
953    /// Number of registered constraints.
954    pub fn num_constraints(&self) -> usize {
955        self.constraints.len()
956    }
957
958    /// Solve incrementally.
959    ///
960    /// - If `solution_valid`, warm-starts from the cached solution (fewer iterations).
961    /// - Otherwise, starts from `init`.
962    ///
963    /// Uses Dykstra's alternating projections in parallel (via [`ParallelFeasibilityChecker`]).
964    pub fn solve(&mut self, init: Array1<f32>, max_iter: usize) -> SolverResult {
965        let start = Instant::now();
966
967        let start_point = if self.solution_valid {
968            self.solution.clone().unwrap_or_else(|| init.clone())
969        } else {
970            init.clone()
971        };
972
973        // Determine iteration budget: warm starts need fewer iterations
974        let actual_max_iter = if self.solution_valid {
975            (max_iter / 4).max(1)
976        } else {
977            max_iter
978        };
979
980        let result = dykstra_project(&start_point, &self.constraints, actual_max_iter);
981
982        // Assess feasibility
983        let num_violations = self
984            .constraints
985            .iter()
986            .filter(|c| !c.is_feasible(&result))
987            .count();
988        let feasible = num_violations == 0;
989
990        let elapsed_us = start.elapsed().as_micros() as u64;
991
992        self.solution = Some(result.clone());
993        self.solution_valid = feasible;
994
995        SolverResult {
996            solution: result,
997            feasible,
998            iterations: actual_max_iter,
999            num_violations,
1000            solve_time_us: elapsed_us,
1001        }
1002    }
1003}
1004
1005// ============================================================================
1006// Tests
1007// ============================================================================
1008
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Shared fixture: the axis-aligned box [0,1] × [0,2] × [0,3] in R^3.
    fn make_box() -> BoxConstraint {
        BoxConstraint::new(
            Array1::from(vec![0.0f32, 0.0, 0.0]),
            Array1::from(vec![1.0f32, 2.0, 3.0]),
        )
        .expect("valid box")
    }

    // -----------------------------------------------------------------------
    // 1. BoxConstraint: feasibility
    // -----------------------------------------------------------------------
    #[test]
    fn test_box_constraint_feasible() {
        let bc = make_box();
        // Strictly interior point: feasible with exactly zero violation.
        let x = Array1::from(vec![0.5f32, 1.0, 2.0]);
        assert!(bc.is_feasible(&x));
        assert_eq!(bc.violation(&x), 0.0);
    }

    // -----------------------------------------------------------------------
    // 2. BoxConstraint: projection
    // -----------------------------------------------------------------------
    #[test]
    fn test_box_constraint_project() {
        let bc = make_box();
        let x = Array1::from(vec![-1.0f32, 3.0, 5.0]); // out of bounds
        let p = bc.project(&x);
        // Box projection is per-component clamping to [lb, ub].
        assert!((p[0] - 0.0).abs() < 1e-5, "clamped to lb");
        assert!((p[1] - 2.0).abs() < 1e-5, "clamped to ub");
        assert!((p[2] - 3.0).abs() < 1e-5, "clamped to ub");
        assert!(bc.is_feasible(&p));
    }

    // -----------------------------------------------------------------------
    // 3. L2BallConstraint: projection
    // -----------------------------------------------------------------------
    #[test]
    fn test_l2_ball_project() {
        let center = Array1::from(vec![0.0f32, 0.0]);
        let ball = L2BallConstraint::new(center, 1.0).expect("valid ball");

        let outside = Array1::from(vec![3.0f32, 4.0]); // dist = 5
        let p = ball.project(&outside);
        // Center is the origin, so the norm of p is its distance to the center.
        let dist: f32 = p.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (dist - 1.0).abs() < 1e-4,
            "projected onto ball surface, dist={dist}"
        );
        assert!(ball.is_feasible(&p));

        // A point already inside the ball must be returned unchanged.
        let inside = Array1::from(vec![0.1f32, 0.1]);
        let p2 = ball.project(&inside);
        assert!((p2[0] - inside[0]).abs() < 1e-5, "inside point unchanged");
    }

    // -----------------------------------------------------------------------
    // 4. HyperplaneConstraint: projection
    // -----------------------------------------------------------------------
    #[test]
    fn test_hyperplane_project() {
        // Constraint: x[0] + x[1] <= 1
        let normal = Array1::from(vec![1.0f32, 1.0]);
        let hp = HyperplaneConstraint::new(normal, 1.0).expect("valid hyperplane");

        let violating = Array1::from(vec![2.0f32, 2.0]); // a^T x = 4 > 1
        let p = hp.project(&violating);
        let ax: f32 = p[0] + p[1];
        assert!(
            ax <= 1.0 + 1e-5,
            "projected point satisfies a^T x <= b, got {ax}"
        );
        assert!(hp.is_feasible(&p));

        // A point on the feasible side is untouched by the violation metric.
        let ok = Array1::from(vec![0.3f32, 0.3]);
        assert!(hp.is_feasible(&ok));
        assert_eq!(hp.violation(&ok), 0.0);
    }

    // -----------------------------------------------------------------------
    // 5. SimplexConstraint: projection
    // -----------------------------------------------------------------------
    #[test]
    fn test_simplex_project() {
        let simplex = SimplexConstraint::new(4);
        let x = Array1::from(vec![1.0f32, 2.0, 3.0, 4.0]); // sum = 10
        let p = simplex.project(&x);

        // Probability-simplex invariants: components sum to 1 and are >= 0.
        let sum: f32 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-4,
            "sum of projected point should be 1, got {sum}"
        );
        for &v in p.iter() {
            assert!(v >= -1e-5, "all components should be non-negative, got {v}");
        }
        assert!(simplex.is_feasible(&p));
    }

    // -----------------------------------------------------------------------
    // 6. ParallelFeasibilityChecker: batch feasibility
    // -----------------------------------------------------------------------
    #[test]
    fn test_parallel_checker_batch() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));

        // Build 100 points: first 50 feasible, last 50 infeasible
        let mut data = vec![0.0f32; 100 * 3];
        for i in 0..50usize {
            data[i * 3] = 0.5;
            data[i * 3 + 1] = 1.0;
            data[i * 3 + 2] = 1.5;
        }
        for i in 50..100usize {
            data[i * 3] = 5.0; // out of [0,1]
            data[i * 3 + 1] = 0.5;
            data[i * 3 + 2] = 0.5;
        }
        let points = Array2::from_shape_vec((100, 3), data).expect("valid shape");
        let results = checker.check_batch(&points);
        assert_eq!(results.len(), 100);
        let feasible_count = results.iter().filter(|&&f| f).count();
        assert_eq!(
            feasible_count, 50,
            "expected 50 feasible points, got {feasible_count}"
        );
    }

    // -----------------------------------------------------------------------
    // 7. ParallelFeasibilityChecker: violation matrix shape
    // -----------------------------------------------------------------------
    #[test]
    fn test_parallel_checker_violations() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));
        checker.add_constraint(Box::new(
            L2BallConstraint::new(Array1::from(vec![0.5f32, 1.0, 1.5]), 2.0).expect("valid ball"),
        ));

        // Two points × two constraints → (2, 2) violation matrix.
        let data: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 3.0, 4.0];
        let points = Array2::from_shape_vec((2, 3), data).expect("valid shape");
        let mat = checker.violation_matrix(&points);
        assert_eq!(mat.dim(), (2, 2), "expected (2, 2) violation matrix");
        // First point is feasible for box constraint
        assert!(mat[[0, 0]] < 1e-5, "first point, box violation should be 0");
    }

    // -----------------------------------------------------------------------
    // 8. ParallelFeasibilityChecker: project_batch
    // -----------------------------------------------------------------------
    #[test]
    fn test_parallel_checker_project_batch() {
        let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
        checker.add_constraint(Box::new(make_box()));

        // All infeasible points
        let data: Vec<f32> = vec![-1.0, 5.0, 10.0, -2.0, 3.0, 7.0];
        let points = Array2::from_shape_vec((2, 3), data).expect("valid shape");
        let projected = checker.project_batch(&points, 50);
        assert_eq!(projected.dim(), (2, 3));
        // Every projected row must land inside the box.
        for i in 0..2usize {
            let row: Array1<f32> = projected.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
            assert!(
                make_box().is_feasible(&row),
                "projected row {i} should be feasible"
            );
        }
    }

    // -----------------------------------------------------------------------
    // 9. ConstraintGraph: independent sets
    // -----------------------------------------------------------------------
    #[test]
    fn test_constraint_graph_independent_sets() {
        // Two constraints sharing variable 0: should be in different sets
        let mut graph = ConstraintGraph::new(3);
        graph.add_constraint(Box::new(make_box()), vec![0, 1]);
        graph.add_constraint(Box::new(make_box()), vec![0, 2]); // shares var 0
        graph.add_constraint(Box::new(make_box()), vec![1, 2]); // shares var 1 with c0, var 2 with c1

        let sets = graph.independent_sets();
        // Verify each set contains no two adjacent constraints
        for set in &sets {
            for (a_idx, &a) in set.iter().enumerate() {
                for &b in set.iter().skip(a_idx + 1) {
                    assert!(
                        !graph.adjacency[a].contains(&b),
                        "constraints {a} and {b} are adjacent but in the same independent set"
                    );
                }
            }
        }
        // All constraints must appear in exactly one set
        let mut seen = std::collections::HashSet::new();
        for set in &sets {
            for &c in set {
                assert!(seen.insert(c), "constraint {c} appears in multiple sets");
            }
        }
        assert_eq!(seen.len(), 3, "all 3 constraints must appear");
    }

    // -----------------------------------------------------------------------
    // 10. SimdConstraintEvaluator: batch evaluation matches sequential
    // -----------------------------------------------------------------------
    #[test]
    fn test_simd_evaluator_batch() {
        let bounds = vec![
            (vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
            (vec![-1.0f32, -1.0, -1.0], vec![1.0f32, 1.0, 1.0]),
        ];
        let evaluator = SimdConstraintEvaluator::new(bounds).expect("valid bounds");

        let data = vec![0.5f32, 1.5, 2.5, 2.0, 0.0, 0.0];
        let points = Array2::from_shape_vec((2, 3), data.clone()).expect("valid shape");

        let batch = evaluator.evaluate_batch(&points);

        // Compare with sequential
        for i in 0..2usize {
            let row = &data[i * 3..(i + 1) * 3];
            let seq = evaluator.evaluate(row);
            for j in 0..evaluator.num_constraints() {
                assert!(
                    (batch[[i, j]] - seq[j]).abs() < 1e-5,
                    "batch[{i},{j}]={} != seq[{j}]={}",
                    batch[[i, j]],
                    seq[j]
                );
            }
        }
    }

    // -----------------------------------------------------------------------
    // 11. SimdConstraintEvaluator: fast feasibility early exit
    // -----------------------------------------------------------------------
    #[test]
    fn test_simd_evaluator_fast_feasibility() {
        let bounds = vec![
            (vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
            (vec![-5.0f32, -5.0, -5.0], vec![5.0f32, 5.0, 5.0]),
        ];
        let evaluator = SimdConstraintEvaluator::new(bounds).expect("valid bounds");

        let feasible = vec![0.5f32, 1.0, 2.0];
        assert!(evaluator.is_feasible_fast(&feasible), "point is feasible");

        let infeasible = vec![2.0f32, 1.0, 2.0]; // violates first constraint (x[0] > 1)
        assert!(
            !evaluator.is_feasible_fast(&infeasible),
            "point is infeasible"
        );
    }

    // -----------------------------------------------------------------------
    // 12. IncrementalParallelSolver: add constraint updates solution
    // -----------------------------------------------------------------------
    #[test]
    fn test_incremental_solver_add_constraint() {
        let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());

        // Add a box constraint [0,1]^3
        solver.add_constraint(Box::new(make_box()));

        // Solve with an infeasible init
        let init = Array1::from(vec![5.0f32, 5.0, 5.0]);
        let result = solver.solve(init, 50);

        assert!(
            result.feasible,
            "solution should be feasible after solving with box constraint"
        );
        assert!(make_box().is_feasible(&result.solution));

        // Add a tighter constraint: [0, 0.5]^3
        let tight_box = BoxConstraint::new(
            Array1::from(vec![0.0f32, 0.0, 0.0]),
            Array1::from(vec![0.5f32, 0.5, 0.5]),
        )
        .expect("valid box");
        // Current solution might violate the new constraint (it's at [1,2,3] boundary)
        solver.add_constraint(Box::new(tight_box));

        // Re-solve — should produce a feasible result for both constraints
        let init2 = Array1::from(vec![1.0f32, 2.0, 3.0]);
        let result2 = solver.solve(init2, 100);
        assert!(
            result2.feasible,
            "solution should be feasible after adding tighter constraint"
        );
        assert!(result2.solution[0] <= 0.5 + 1e-4);
        assert!(result2.solution[1] <= 0.5 + 1e-4);
        assert!(result2.solution[2] <= 0.5 + 1e-4);
    }

    // -----------------------------------------------------------------------
    // 13. IncrementalParallelSolver: warm start uses fewer iterations
    // -----------------------------------------------------------------------
    #[test]
    fn test_incremental_solver_warmstart() {
        let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());
        solver.add_constraint(Box::new(make_box()));

        // Cold solve
        let init = Array1::from(vec![0.5f32, 1.0, 1.5]);
        let cold_result = solver.solve(init.clone(), 100);
        assert!(cold_result.feasible);

        // Solution is now cached and valid — next call uses warm start (fewer iterations)
        let warm_result = solver.solve(init, 100);
        assert!(warm_result.feasible);
        assert!(
            warm_result.iterations <= cold_result.iterations,
            "warm start should use fewer or equal iterations: warm={} cold={}",
            warm_result.iterations,
            cold_result.iterations
        );
    }
}