use rayon::prelude::*;
use scirs2_core::ndarray::{Array1, Array2};
use std::time::Instant;
/// Tuning knobs for the parallel constraint utilities in this module.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Requested worker-thread count. Not read anywhere in this module; presumably `0` means
    /// "use rayon's default pool" — TODO confirm with the callers that set it.
    pub num_threads: usize,
    /// Lower bound on how many points a single rayon task processes; passed (clamped to at
    /// least 1) to `with_min_len` by `ParallelFeasibilityChecker`.
    pub chunk_size: usize,
    /// Not read anywhere in this module.
    pub use_simd: bool,
    /// Not read anywhere in this module.
    pub prefetch_distance: usize,
}

impl Default for ParallelConfig {
    /// `num_threads: 0`, `chunk_size: 64`, `use_simd: true`, `prefetch_distance: 8`.
    fn default() -> Self {
        let (num_threads, chunk_size) = (0, 64);
        ParallelConfig {
            use_simd: true,
            prefetch_distance: 8,
            chunk_size,
            num_threads,
        }
    }
}
/// A constraint on `f32` vectors that can be tested, measured, and projected onto.
///
/// The `Send + Sync` bound lets boxed constraints be shared with rayon worker threads, as the
/// batch checkers and graph propagation in this module do. NOTE(review): Dykstra's method
/// (used by the solvers here) assumes each `project` is a projection onto a convex set; the
/// trait itself cannot enforce that.
pub trait FastConstraint: Send + Sync {
    /// Returns `true` if `x` satisfies the constraint.
    fn is_feasible(&self, x: &Array1<f32>) -> bool;

    /// Non-negative measure of how far `x` is from satisfying the constraint; the metric is
    /// implementation-defined and is typically 0 for feasible points.
    fn violation(&self, x: &Array1<f32>) -> f32;

    /// Maps `x` to a point that satisfies the constraint.
    fn project(&self, x: &Array1<f32>) -> Array1<f32>;
}
/// Axis-aligned box `{x : lb[i] <= x[i] <= ub[i]}`.
#[derive(Debug, Clone)]
pub struct BoxConstraint {
    /// Per-component lower bounds.
    pub lb: Array1<f32>,
    /// Per-component upper bounds; `new` requires the same length as `lb`.
    pub ub: Array1<f32>,
}
impl BoxConstraint {
    /// Creates a box constraint after validating the bounds.
    ///
    /// Infinite bounds are accepted (they leave that side of the component unbounded).
    ///
    /// # Errors
    /// Returns `Err` if `lb` and `ub` differ in length, if any bound is NaN, or if
    /// `lb[i] > ub[i]` for some `i`.
    pub fn new(lb: Array1<f32>, ub: Array1<f32>) -> Result<Self, String> {
        if lb.len() != ub.len() {
            return Err(format!(
                "BoxConstraint: lb.len()={} != ub.len()={}",
                lb.len(),
                ub.len()
            ));
        }
        for (i, (&l, &u)) in lb.iter().zip(ub.iter()).enumerate() {
            // `l > u` is false whenever either side is NaN, so NaN needs its own check; a NaN
            // bound would otherwise reach `f32::clamp` in `project`, which panics on NaN.
            if l.is_nan() || u.is_nan() {
                return Err(format!("BoxConstraint: lb[{i}]={l} or ub[{i}]={u} is NaN"));
            }
            if l > u {
                return Err(format!("BoxConstraint: lb[{i}]={l} > ub[{i}]={u}"));
            }
        }
        Ok(Self { lb, ub })
    }
}
impl FastConstraint for BoxConstraint {
    /// Returns `true` when every component satisfies `lb[i] <= x[i] <= ub[i]`.
    ///
    /// Components are paired with the bounds via `zip`, so only the first
    /// `min(x.len(), lb.len(), ub.len())` entries are checked. A NaN component is not flagged
    /// (both comparisons are `false`), matching `violation`, which reports 0 for it.
    fn is_feasible(&self, x: &Array1<f32>) -> bool {
        // `any` short-circuits on the first violated component. Summing per-component `u8`
        // flags overflowed once more than 255 components were violated: a panic in debug
        // builds, and a wrap-around to "feasible" in release builds.
        !x.iter()
            .zip(self.lb.iter())
            .zip(self.ub.iter())
            .any(|((&xi, &li), &ui)| xi < li || xi > ui)
    }

    /// Clamps each component of `x` into `[lb[i], ub[i]]`.
    ///
    /// # Panics
    /// Panics if `x` is longer than the bounds, or if a bound pair violates the preconditions
    /// of `f32::clamp` (`lb[i] > ub[i]` or a NaN bound), which is possible because the fields
    /// are public and can bypass `BoxConstraint::new`.
    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
        Array1::from_iter((0..x.len()).map(|i| x[i].clamp(self.lb[i], self.ub[i])))
    }

    /// Euclidean distance from `x` to the box: the square root of the summed squared amounts
    /// by which each component falls below `lb[i]` or rises above `ub[i]`.
    ///
    /// # Panics
    /// Panics if `x` is longer than the bounds.
    fn violation(&self, x: &Array1<f32>) -> f32 {
        (0..x.len())
            .map(|i| {
                let below = (self.lb[i] - x[i]).max(0.0);
                let above = (x[i] - self.ub[i]).max(0.0);
                below * below + above * above
            })
            .fold(0.0f32, |acc, sq| acc + sq)
            .sqrt()
    }
}
/// Closed Euclidean ball `{x : ||x - center||_2 <= radius}`.
#[derive(Debug, Clone)]
pub struct L2BallConstraint {
    /// Centre of the ball.
    pub center: Array1<f32>,
    /// Radius; `new` requires it to be strictly positive.
    pub radius: f32,
}
impl L2BallConstraint {
    /// Creates a ball constraint.
    ///
    /// # Errors
    /// Returns `Err` if `radius` is not strictly positive, including when it is NaN.
    pub fn new(center: Array1<f32>, radius: f32) -> Result<Self, String> {
        // `radius <= 0.0` alone is false for NaN, which used to let a NaN radius through and
        // made every later projection of an outside point produce NaN coordinates.
        if radius.is_nan() || radius <= 0.0 {
            return Err(format!(
                "L2BallConstraint: radius must be positive, got {radius}"
            ));
        }
        Ok(Self { center, radius })
    }

    /// Squared Euclidean distance between `x` and `center` over the first `x.len()`
    /// components. Panics if `x` is longer than `center`.
    fn dist_sq(&self, x: &Array1<f32>) -> f32 {
        (0..x.len()).fold(0.0f32, |acc, i| {
            let d = x[i] - self.center[i];
            acc + d * d
        })
    }
}
impl FastConstraint for L2BallConstraint {
    /// Feasible iff `||x - center||^2 <= radius^2`; comparing squares avoids a square root.
    fn is_feasible(&self, x: &Array1<f32>) -> bool {
        let radius_sq = self.radius * self.radius;
        self.dist_sq(x) <= radius_sq
    }

    /// Points inside the ball come back unchanged; points outside are pulled radially toward
    /// `center` until they lie on the sphere of radius `radius`.
    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
        let d_sq = self.dist_sq(x);
        if d_sq <= self.radius * self.radius {
            return x.clone();
        }
        let shrink = self.radius / d_sq.sqrt();
        Array1::from_iter((0..x.len()).map(|i| {
            let c = self.center[i];
            c + (x[i] - c) * shrink
        }))
    }

    /// Distance from `x` to the ball: `max(||x - center|| - radius, 0)`.
    fn violation(&self, x: &Array1<f32>) -> f32 {
        let dist = self.dist_sq(x).sqrt();
        (dist - self.radius).max(0.0)
    }
}
/// Closed half-space `{x : <normal, x> <= offset}`.
///
/// Despite the name, this is an inequality constraint: every point on or below the
/// hyperplane `<normal, x> = offset` is feasible.
#[derive(Debug, Clone)]
pub struct HyperplaneConstraint {
    /// Outward normal `a` of the bounding hyperplane.
    pub normal: Array1<f32>,
    /// Right-hand side `b` of `a . x <= b`.
    pub offset: f32,
}
impl HyperplaneConstraint {
    /// Creates a half-space constraint.
    ///
    /// # Errors
    /// Returns `Err` if `normal` is the zero vector, if its squared norm is not finite (a NaN
    /// or infinite component, or components so large that `|a|^2` overflows `f32`), or if
    /// `offset` is NaN. Each of these would make the projection step
    /// `(a.x - b) / |a|^2` NaN or silently zero.
    pub fn new(normal: Array1<f32>, offset: f32) -> Result<Self, String> {
        let norm_sq: f32 = normal.iter().map(|&v| v * v).sum();
        if norm_sq == 0.0 {
            return Err("HyperplaneConstraint: normal vector must be non-zero".to_string());
        }
        // `norm_sq == 0.0` is false for NaN and inf, which used to let them through.
        if !norm_sq.is_finite() {
            return Err(format!(
                "HyperplaneConstraint: squared norm of normal vector must be finite, got {norm_sq}"
            ));
        }
        if offset.is_nan() {
            return Err("HyperplaneConstraint: offset must not be NaN".to_string());
        }
        Ok(Self { normal, offset })
    }

    /// `<normal, x>` over the first `x.len()` components. Panics if `x` is longer than `normal`.
    fn dot(&self, x: &Array1<f32>) -> f32 {
        (0..x.len()).fold(0.0f32, |acc, i| acc + self.normal[i] * x[i])
    }

    /// Squared Euclidean norm of `normal`.
    fn norm_sq(&self) -> f32 {
        self.normal.iter().fold(0.0f32, |acc, &v| acc + v * v)
    }
}
impl FastConstraint for HyperplaneConstraint {
    /// Feasible iff `<normal, x> <= offset`.
    fn is_feasible(&self, x: &Array1<f32>) -> bool {
        let lhs = self.dot(x);
        lhs <= self.offset
    }

    /// Orthogonal projection onto the half-space. Feasible points come back unchanged; others
    /// move along `-normal` by `(<normal, x> - offset) / |normal|^2`.
    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
        let lhs = self.dot(x);
        if lhs <= self.offset {
            return x.clone();
        }
        let step = (lhs - self.offset) / self.norm_sq();
        Array1::from_iter((0..x.len()).map(|i| x[i] - step * self.normal[i]))
    }

    /// Amount by which `<normal, x>` exceeds `offset`, or 0 when feasible. This is not
    /// divided by `|normal|`, so it is not the Euclidean distance to the half-space.
    fn violation(&self, x: &Array1<f32>) -> f32 {
        let excess = self.dot(x) - self.offset;
        excess.max(0.0)
    }
}
/// Probability simplex `{x : x_i >= 0, sum_i x_i = 1}` for points of length `dim`.
#[derive(Debug, Clone)]
pub struct SimplexConstraint {
    /// Expected point length; the feasibility check rejects points of any other length.
    pub dim: usize,
}
impl SimplexConstraint {
    /// Creates a simplex constraint for points of length `dim`.
    pub fn new(dim: usize) -> Self {
        SimplexConstraint { dim }
    }
}
impl FastConstraint for SimplexConstraint {
    /// `true` when `x` has exactly `dim` components, all non-negative, summing to 1 within
    /// an absolute tolerance of `1e-5`.
    fn is_feasible(&self, x: &Array1<f32>) -> bool {
        if x.len() != self.dim {
            return false;
        }
        // `all` short-circuits. Summing per-component `u8` flags overflowed once more than
        // 255 components were negative (debug panic / release wrap-around to "feasible").
        let all_nonneg = x.iter().all(|&v| v >= 0.0);
        let sum: f32 = x.iter().sum();
        all_nonneg && (sum - 1.0).abs() < 1e-5
    }

    /// Euclidean projection onto the probability simplex with the standard sort-based
    /// algorithm (e.g. Duchi et al., 2008):
    /// 1. sort `x` descending into `u`;
    /// 2. take the last `rho` with `u[rho] > (sum(u[..=rho]) - 1) / (rho + 1)`;
    /// 3. return `max(x - theta, 0)` with `theta = (sum(u[..=rho]) - 1) / (rho + 1)`.
    ///
    /// The length of `x` is used, not `self.dim`. An empty `x` yields an empty array; the
    /// zero-dimensional simplex is empty, so there is no meaningful projection.
    fn project(&self, x: &Array1<f32>) -> Array1<f32> {
        if x.is_empty() {
            // Previously this fell through to `sorted[..=0]` on an empty slice and panicked.
            return Array1::zeros(0);
        }
        let mut sorted: Vec<f32> = x.iter().copied().collect();
        // `total_cmp` is a total order. `partial_cmp(..).unwrap_or(Equal)` is not when NaN is
        // present, and since Rust 1.81 the std sorts may panic on such comparators.
        sorted.sort_unstable_by(|a, b| b.total_cmp(a));
        // theta for rho = 0 (the fallback if no index qualifies); the j = 0 candidate below
        // recomputes exactly this value.
        let mut theta = sorted[0] - 1.0;
        let mut cumsum = 0.0f32;
        for (j, &s) in sorted.iter().enumerate() {
            cumsum += s;
            let candidate = (cumsum - 1.0) / (j as f32 + 1.0);
            if s > candidate {
                // The running prefix sum at the last qualifying index gives theta directly,
                // so no second summation pass over `sorted[..=rho]` is needed.
                theta = candidate;
            }
        }
        Array1::from_iter(x.iter().map(|&v| (v - theta).max(0.0)))
    }

    /// `|sum(x) - 1| + sum(max(-x_i, 0))`. Any length of `x` is accepted.
    fn violation(&self, x: &Array1<f32>) -> f32 {
        let sum: f32 = x.iter().sum();
        let sum_viol = (sum - 1.0).abs();
        let neg_viol: f32 = x.iter().map(|&v| (-v).max(0.0)).sum();
        sum_viol + neg_viol
    }
}
/// Evaluates a shared list of constraints against every row of a point matrix, in parallel.
///
/// Each row of the `(n_points, dim)` input is one point. Rows are distributed over rayon's
/// thread pool with at least `config.chunk_size` (minimum 1) rows per task.
pub struct ParallelFeasibilityChecker {
    constraints: Vec<Box<dyn FastConstraint>>,
    config: ParallelConfig,
}
impl ParallelFeasibilityChecker {
    /// Creates a checker with no constraints.
    pub fn new(config: ParallelConfig) -> Self {
        Self {
            constraints: Vec::new(),
            config,
        }
    }

    /// Appends a constraint; every constraint is applied to every point.
    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
        self.constraints.push(constraint);
    }

    /// Number of registered constraints.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Minimum number of rows per rayon task: `chunk_size`, never below 1.
    fn min_rows_per_task(&self) -> usize {
        self.config.chunk_size.max(1)
    }

    /// Returns, for each row of `points`, whether it satisfies every constraint.
    /// With no constraints, every row is reported feasible.
    pub fn check_batch(&self, points: &Array2<f32>) -> Vec<bool> {
        let constraints = &self.constraints;
        (0..points.nrows())
            .into_par_iter()
            .with_min_len(self.min_rows_per_task())
            .map(|i| {
                // The trait takes `&Array1`, so each row is copied into an owned array.
                let row: Array1<f32> = points.row(i).to_owned();
                constraints.iter().all(|c| c.is_feasible(&row))
            })
            .collect()
    }

    /// Returns an `(n_points, n_constraints)` matrix whose `[i, j]` entry is constraint
    /// `j`'s `violation` for row `i`.
    pub fn violation_matrix(&self, points: &Array2<f32>) -> Array2<f32> {
        let n_points = points.nrows();
        let n_constraints = self.constraints.len();
        if n_constraints == 0 || n_points == 0 {
            return Array2::zeros((n_points, n_constraints));
        }
        let constraints = &self.constraints;
        let rows: Vec<Vec<f32>> = (0..n_points)
            .into_par_iter()
            .with_min_len(self.min_rows_per_task())
            .map(|i| {
                let row: Array1<f32> = points.row(i).to_owned();
                constraints.iter().map(|c| c.violation(&row)).collect()
            })
            .collect();
        let mut out = Array2::zeros((n_points, n_constraints));
        for (i, row) in rows.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                out[[i, j]] = v;
            }
        }
        out
    }

    /// Projects every row of `points` onto the intersection of all constraints with
    /// Dykstra's alternating projections (at most `max_iter` sweeps per row).
    ///
    /// Input with no rows or zero columns is returned unchanged. A projection of unexpected
    /// length is truncated or zero-padded to `dim`.
    pub fn project_batch(&self, points: &Array2<f32>, max_iter: usize) -> Array2<f32> {
        let (n_points, dim) = points.dim();
        if n_points == 0 || dim == 0 {
            return points.clone();
        }
        let constraints = &self.constraints;
        let projected: Vec<Vec<f32>> = (0..n_points)
            .into_par_iter()
            .with_min_len(self.min_rows_per_task())
            .map(|i| {
                let row: Array1<f32> = points.row(i).to_owned();
                // `to_vec` yields elements in logical order. `into_raw_vec_and_offset` returns
                // them in storage order, which is reversed when `row` keeps a negative stride
                // (`to_owned` preserves the layout of contiguous data) and `dykstra_project`
                // hands it back untouched (no constraints, or `max_iter == 0`).
                dykstra_project(&row, constraints.as_slice(), max_iter).to_vec()
            })
            .collect();
        let mut out = Array2::zeros((n_points, dim));
        for (i, row) in projected.iter().enumerate() {
            for (j, &v) in row.iter().take(dim).enumerate() {
                out[[i, j]] = v;
            }
        }
        out
    }
}
/// Projects `x` onto the intersection of `constraints` with Dykstra's algorithm.
///
/// Each sweep visits the constraints in order: it projects `z + p_k` with constraint `k`
/// and stores the residual as the new correction `p_k`. The loop stops after `max_iter`
/// sweeps, or earlier once a full sweep moves no coordinate by `1e-6` or more. With no
/// constraints, `x` is returned unchanged.
///
/// NOTE(review): stopping on a small change in `z` alone can end Dykstra's method early
/// while the corrections are still changing. Consider also monitoring the correction
/// updates — confirm before relying on tight tolerances.
fn dykstra_project(
    x: &Array1<f32>,
    constraints: &[Box<dyn FastConstraint>],
    max_iter: usize,
) -> Array1<f32> {
    if constraints.is_empty() {
        return x.clone();
    }
    let len = x.len();
    // One correction vector per constraint, all starting at zero.
    let mut corrections: Vec<Array1<f32>> = vec![Array1::zeros(len); constraints.len()];
    let mut z = x.clone();
    for _sweep in 0..max_iter {
        let before = z.clone();
        for (constraint, correction) in constraints.iter().zip(corrections.iter_mut()) {
            let shifted = &z + &*correction;
            let projected = constraint.project(&shifted);
            // Positional update: panics if the projection is shorter than `x`.
            for j in 0..len {
                correction[j] = shifted[j] - projected[j];
            }
            z = projected;
        }
        let largest_move = z
            .iter()
            .zip(before.iter())
            .map(|(&a, &b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        if largest_move < 1e-6 {
            break;
        }
    }
    z
}
/// Summary returned by `ConstraintGraph::propagate_parallel`.
#[derive(Debug, Clone)]
pub struct PropagationResult {
    /// `true` if a sweep changed no coordinate by the tolerance or more before the sweep limit.
    pub converged: bool,
    /// Number of sweeps performed.
    pub iterations: usize,
    /// Number of constraints still infeasible at the final point.
    pub num_violations: usize,
}
/// Constraints over a shared variable vector, plus an interference graph that links any two
/// constraints touching at least one common variable index.
pub struct ConstraintGraph {
    /// Reported by `num_vars()`; not used to validate `var_indices`.
    num_vars: usize,
    constraints: Vec<Box<dyn FastConstraint>>,
    /// `var_indices[c]`: coordinates constraint `c` writes back during propagation.
    var_indices: Vec<Vec<usize>>,
    /// `adjacency[c]`: constraints sharing at least one variable index with `c`.
    adjacency: Vec<Vec<usize>>,
}
impl ConstraintGraph {
    /// Creates an empty graph over `num_vars` variables.
    pub fn new(num_vars: usize) -> Self {
        ConstraintGraph {
            num_vars,
            constraints: Vec::new(),
            var_indices: Vec::new(),
            adjacency: Vec::new(),
        }
    }

    /// Number of registered constraints.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Number of variables passed to `new`.
    pub fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Registers `constraint` as acting on `var_indices` and links it, in both directions,
    /// to every existing constraint that shares a variable index with it.
    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>, var_indices: Vec<usize>) {
        let new_idx = self.constraints.len();
        let neighbors: Vec<usize> = self
            .var_indices
            .iter()
            .enumerate()
            .filter(|(_, existing)| var_indices.iter().any(|v| existing.contains(v)))
            .map(|(idx, _)| idx)
            .collect();
        for &nb in &neighbors {
            self.adjacency[nb].push(new_idx);
        }
        self.constraints.push(constraint);
        self.var_indices.push(var_indices);
        self.adjacency.push(neighbors);
    }

    /// Partitions the constraints into groups with no shared variables, using greedy
    /// colouring in descending-degree order (ties keep index order). Group `k` holds the
    /// constraints coloured `k`, in increasing index order.
    pub fn independent_sets(&self) -> Vec<Vec<usize>> {
        let n = self.constraints.len();
        if n == 0 {
            return Vec::new();
        }
        let mut order: Vec<usize> = (0..n).collect();
        // Stable sort, highest degree first.
        order.sort_by(|&a, &b| self.adjacency[b].len().cmp(&self.adjacency[a].len()));
        let mut colors: Vec<Option<usize>> = vec![None; n];
        let mut num_colors = 0usize;
        for node in order {
            let taken: std::collections::HashSet<usize> = self.adjacency[node]
                .iter()
                .filter_map(|&nb| colors[nb])
                .collect();
            // Smallest colour not used by an already-coloured neighbour.
            let mut color = 0usize;
            while taken.contains(&color) {
                color += 1;
            }
            colors[node] = Some(color);
            num_colors = num_colors.max(color + 1);
        }
        let mut sets: Vec<Vec<usize>> = vec![Vec::new(); num_colors];
        for (node, color) in colors.into_iter().enumerate() {
            if let Some(c) = color {
                sets[c].push(node);
            }
        }
        sets
    }

    /// Projects `x` toward the constraints, one independent set at a time.
    ///
    /// Within a set, every constraint projects the same snapshot of `x` in parallel, and each
    /// then writes back only the coordinates in its `var_indices` (indices `>= x.len()` are
    /// skipped). Sweeps repeat until no coordinate moves by `1e-6` or more over a sweep, or
    /// 50 sweeps have run.
    ///
    /// NOTE(review): `project` may read coordinates outside a constraint's `var_indices` (for
    /// example an L2 ball over the whole vector). The update is only consistent if
    /// `var_indices` covers every coordinate the constraint depends on — confirm with callers.
    pub fn propagate_parallel(&self, x: &mut Array1<f32>) -> PropagationResult {
        const MAX_SWEEPS: usize = 50;
        const TOL: f32 = 1e-6;
        let sets = self.independent_sets();
        let mut iterations = 0usize;
        let mut converged = false;
        for sweep in 1..=MAX_SWEEPS {
            iterations = sweep;
            let before = x.clone();
            for set in sets.iter().filter(|set| !set.is_empty()) {
                let snapshot: &Array1<f32> = x;
                let projections: Vec<(usize, Array1<f32>)> = set
                    .par_iter()
                    .map(|&c_idx| (c_idx, self.constraints[c_idx].project(snapshot)))
                    .collect();
                for (c_idx, proj) in projections {
                    for &var in &self.var_indices[c_idx] {
                        if var < x.len() {
                            x[var] = proj[var];
                        }
                    }
                }
            }
            let max_diff = x
                .iter()
                .zip(before.iter())
                .map(|(&a, &b)| (a - b).abs())
                .fold(0.0f32, f32::max);
            if max_diff < TOL {
                converged = true;
                break;
            }
        }
        let final_x: &Array1<f32> = x;
        let num_violations = self
            .constraints
            .iter()
            .filter(|c| !c.is_feasible(final_x))
            .count();
        PropagationResult {
            converged,
            iterations,
            num_violations,
        }
    }
}
/// Many axis-aligned box constraints of one dimension, stored in two flat buffers.
///
/// Constraint `c` owns `lb[c * dim..(c + 1) * dim]` and the matching slice of `ub`. Despite
/// the name, no explicit SIMD intrinsics are used; evaluation is plain iterator code.
pub struct SimdConstraintEvaluator {
    lb: Vec<f32>,
    ub: Vec<f32>,
    dim: usize,
    num_constraints: usize,
}
impl SimdConstraintEvaluator {
    /// Builds an evaluator from `(lower, upper)` bound pairs, one pair per constraint.
    ///
    /// An empty `bounds` yields an evaluator with no constraints.
    ///
    /// # Errors
    /// Returns `Err` if any bound vector's length differs from the first lower-bound
    /// vector's length, or if some `lower[i] > upper[i]`.
    pub fn new(bounds: Vec<(Vec<f32>, Vec<f32>)>) -> Result<Self, String> {
        let dim = match bounds.first() {
            Some((first_lb, _)) => first_lb.len(),
            None => {
                return Ok(Self {
                    lb: Vec::new(),
                    ub: Vec::new(),
                    dim: 0,
                    num_constraints: 0,
                });
            }
        };
        for (k, (l, u)) in bounds.iter().enumerate() {
            if l.len() != dim {
                return Err(format!(
                    "SimdConstraintEvaluator: bounds[{k}].0.len()={} != dim={dim}",
                    l.len()
                ));
            }
            if u.len() != dim {
                return Err(format!(
                    "SimdConstraintEvaluator: bounds[{k}].1.len()={} != dim={dim}",
                    u.len()
                ));
            }
            if let Some(i) = (0..dim).find(|&i| l[i] > u[i]) {
                return Err(format!(
                    "SimdConstraintEvaluator: bounds[{k}].lb[{i}]={} > ub[{i}]={}",
                    l[i], u[i]
                ));
            }
        }
        let num_constraints = bounds.len();
        let mut lb = Vec::with_capacity(num_constraints * dim);
        let mut ub = Vec::with_capacity(num_constraints * dim);
        // Every vector was checked to have length `dim`, so appending keeps the flat layout.
        for (l, u) in &bounds {
            lb.extend_from_slice(l);
            ub.extend_from_slice(u);
        }
        Ok(Self {
            lb,
            ub,
            dim,
            num_constraints,
        })
    }

    /// Number of box constraints.
    pub fn num_constraints(&self) -> usize {
        self.num_constraints
    }

    /// Lower- and upper-bound slices of constraint `c`.
    fn bounds_of(&self, c: usize) -> (&[f32], &[f32]) {
        let range = c * self.dim..(c + 1) * self.dim;
        (&self.lb[range.clone()], &self.ub[range])
    }

    /// Euclidean distance from `x` to each box, one entry per constraint.
    /// Only the first `min(x.len(), dim)` components are compared.
    pub fn evaluate(&self, x: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.num_constraints);
        for c in 0..self.num_constraints {
            let (lo, hi) = self.bounds_of(c);
            let sq: f32 = x
                .iter()
                .zip(lo)
                .zip(hi)
                .map(|((&xi, &l), &u)| {
                    let below = (l - xi).max(0.0);
                    let above = (xi - u).max(0.0);
                    below * below + above * above
                })
                .sum();
            out.push(sq.sqrt());
        }
        out
    }

    /// Applies `evaluate` to every row of `points` in parallel. Returns an
    /// `(n_points, num_constraints)` matrix.
    pub fn evaluate_batch(&self, points: &Array2<f32>) -> Array2<f32> {
        let n_points = points.nrows();
        let n_c = self.num_constraints;
        if n_points == 0 || n_c == 0 {
            return Array2::zeros((n_points, n_c));
        }
        let per_row: Vec<Vec<f32>> = (0..n_points)
            .into_par_iter()
            .map(|i| self.evaluate(&points.row(i).to_vec()))
            .collect();
        let mut out = Array2::zeros((n_points, n_c));
        for (i, vals) in per_row.iter().enumerate() {
            for (j, &v) in vals.iter().enumerate() {
                out[[i, j]] = v;
            }
        }
        out
    }

    /// `true` when `x` lies in every box. Stops at the first violated component.
    ///
    /// A NaN component contributes 0 (`f32::max` ignores NaN), so it is not reported as
    /// infeasible.
    pub fn is_feasible_fast(&self, x: &[f32]) -> bool {
        (0..self.num_constraints).all(|c| {
            let (lo, hi) = self.bounds_of(c);
            // Each term is >= 0 and never NaN, so "sum == 0" is the same as "every term == 0".
            x.iter()
                .zip(lo)
                .zip(hi)
                .all(|((&xi, &l), &u)| (l - xi).max(0.0) + (xi - u).max(0.0) == 0.0)
        })
    }
}
/// Outcome of one `IncrementalParallelSolver::solve` call.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Point produced by the projection.
    pub solution: Array1<f32>,
    /// `true` exactly when `num_violations == 0`.
    pub feasible: bool,
    /// Iteration budget given to the projection. This is not necessarily the number of
    /// sweeps actually run, since the projection may stop early.
    pub iterations: usize,
    /// Number of constraints whose `is_feasible` check fails at `solution`.
    pub num_violations: usize,
    /// Wall-clock time of the solve, in microseconds.
    pub solve_time_us: u64,
}
/// Projection-based solver that caches its last answer and, while that answer is still
/// known to be feasible, reuses it as a warm start.
pub struct IncrementalParallelSolver {
    // Stored for API symmetry with the other parallel types. No method below reads it yet,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    config: ParallelConfig,
    constraints: Vec<Box<dyn FastConstraint>>,
    /// Result of the most recent `solve`, if any.
    solution: Option<Array1<f32>>,
    /// Whether `solution` is known to satisfy every current constraint.
    solution_valid: bool,
}
impl IncrementalParallelSolver {
    /// Creates a solver with no constraints and no cached solution.
    pub fn new(config: ParallelConfig) -> Self {
        IncrementalParallelSolver {
            config,
            constraints: Vec::new(),
            solution: None,
            solution_valid: false,
        }
    }

    /// Adds a constraint. If the cached solution is currently valid but violates the new
    /// constraint, it is marked invalid.
    pub fn add_constraint(&mut self, constraint: Box<dyn FastConstraint>) {
        let cache_broken = self.solution_valid
            && matches!(&self.solution, Some(sol) if !constraint.is_feasible(sol));
        if cache_broken {
            self.solution_valid = false;
        }
        self.constraints.push(constraint);
    }

    /// Removes the constraint at `idx`. Returns `false` (and does nothing) if `idx` is out of
    /// range. Dropping a constraint can only enlarge the feasible set, so the cached
    /// solution's validity is kept.
    pub fn remove_constraint(&mut self, idx: usize) -> bool {
        let in_range = idx < self.constraints.len();
        if in_range {
            self.constraints.remove(idx);
        }
        in_range
    }

    /// Forces the next `solve` to start from its `init` argument.
    pub fn invalidate(&mut self) {
        self.solution_valid = false;
    }

    /// Result of the most recent `solve`, valid or not.
    pub fn current_solution(&self) -> Option<&Array1<f32>> {
        self.solution.as_ref()
    }

    /// Number of registered constraints.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Projects a start point onto the intersection of all constraints with Dykstra's method.
    ///
    /// When the cached solution is still valid, `init` is ignored. The solve then restarts
    /// from the cached solution with a quarter of `max_iter` (at least 1). Otherwise it starts
    /// from `init` with the full budget. The result is cached and marked valid only if it
    /// satisfies every constraint. `SolverResult::iterations` reports the budget used.
    pub fn solve(&mut self, init: Array1<f32>, max_iter: usize) -> SolverResult {
        let start = Instant::now();
        let (start_point, budget) = if self.solution_valid {
            (self.solution.clone().unwrap_or(init), (max_iter / 4).max(1))
        } else {
            (init, max_iter)
        };
        let solution = dykstra_project(&start_point, &self.constraints, budget);
        let num_violations = self
            .constraints
            .iter()
            .filter(|c| !c.is_feasible(&solution))
            .count();
        let feasible = num_violations == 0;
        let solve_time_us = start.elapsed().as_micros() as u64;
        self.solution = Some(solution.clone());
        self.solution_valid = feasible;
        SolverResult {
            solution,
            feasible,
            iterations: budget,
            num_violations,
            solve_time_us,
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the constraint types, the parallel batch checker, the constraint-graph
// colouring, the flat-buffer evaluator and the incremental solver in the parent module.
use super::*;
use scirs2_core::ndarray::Array1;
// Shared fixture: the 3-D box [0, 1] x [0, 2] x [0, 3].
fn make_box() -> BoxConstraint {
BoxConstraint::new(
Array1::from(vec![0.0f32, 0.0, 0.0]),
Array1::from(vec![1.0f32, 2.0, 3.0]),
)
.expect("valid box")
}
// A point strictly inside the box is feasible and has zero violation.
#[test]
fn test_box_constraint_feasible() {
let bc = make_box();
let x = Array1::from(vec![0.5f32, 1.0, 2.0]);
assert!(bc.is_feasible(&x));
assert_eq!(bc.violation(&x), 0.0);
}
// Projection clamps each out-of-range component to the nearer bound.
#[test]
fn test_box_constraint_project() {
let bc = make_box();
// x is below lb in component 0 and above ub in components 1 and 2.
let x = Array1::from(vec![-1.0f32, 3.0, 5.0]); let p = bc.project(&x);
assert!((p[0] - 0.0).abs() < 1e-5, "clamped to lb");
assert!((p[1] - 2.0).abs() < 1e-5, "clamped to ub");
assert!((p[2] - 3.0).abs() < 1e-5, "clamped to ub");
assert!(bc.is_feasible(&p));
}
// Outside points land on the unit circle; inside points are returned unchanged.
#[test]
fn test_l2_ball_project() {
let center = Array1::from(vec![0.0f32, 0.0]);
let ball = L2BallConstraint::new(center, 1.0).expect("valid ball");
// (3, 4) is at distance 5 from the origin, well outside the unit ball.
let outside = Array1::from(vec![3.0f32, 4.0]); let p = ball.project(&outside);
let dist: f32 = p.iter().map(|&v| v * v).sum::<f32>().sqrt();
assert!(
(dist - 1.0).abs() < 1e-4,
"projected onto ball surface, dist={dist}"
);
assert!(ball.is_feasible(&p));
let inside = Array1::from(vec![0.1f32, 0.1]);
let p2 = ball.project(&inside);
assert!((p2[0] - inside[0]).abs() < 1e-5, "inside point unchanged");
}
// Half-space x0 + x1 <= 1: a violating point is moved into it; a feasible one has no violation.
#[test]
fn test_hyperplane_project() {
let normal = Array1::from(vec![1.0f32, 1.0]);
let hp = HyperplaneConstraint::new(normal, 1.0).expect("valid hyperplane");
// (2, 2) gives a.x = 4 > 1.
let violating = Array1::from(vec![2.0f32, 2.0]); let p = hp.project(&violating);
let ax: f32 = p[0] + p[1];
assert!(
ax <= 1.0 + 1e-5,
"projected point satisfies a^T x <= b, got {ax}"
);
assert!(hp.is_feasible(&p));
let ok = Array1::from(vec![0.3f32, 0.3]);
assert!(hp.is_feasible(&ok));
assert_eq!(hp.violation(&ok), 0.0);
}
// Simplex projection yields non-negative components summing to 1.
#[test]
fn test_simplex_project() {
let simplex = SimplexConstraint::new(4);
// (1, 2, 3, 4) projects to (0, 0, 0, 1); only the sum and signs are asserted here.
let x = Array1::from(vec![1.0f32, 2.0, 3.0, 4.0]); let p = simplex.project(&x);
let sum: f32 = p.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-4,
"sum of projected point should be 1, got {sum}"
);
for &v in p.iter() {
assert!(v >= -1e-5, "all components should be non-negative, got {v}");
}
assert!(simplex.is_feasible(&p));
}
// Batch check over 100 rows: the first 50 are inside the box, the last 50 are not.
#[test]
fn test_parallel_checker_batch() {
let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
checker.add_constraint(Box::new(make_box()));
let mut data = vec![0.0f32; 100 * 3];
for i in 0..50usize {
data[i * 3] = 0.5;
data[i * 3 + 1] = 1.0;
data[i * 3 + 2] = 1.5;
}
for i in 50..100usize {
// Component 0 = 5.0 exceeds the box's upper bound of 1.0.
data[i * 3] = 5.0; data[i * 3 + 1] = 0.5;
data[i * 3 + 2] = 0.5;
}
let points = Array2::from_shape_vec((100, 3), data).expect("valid shape");
let results = checker.check_batch(&points);
assert_eq!(results.len(), 100);
let feasible_count = results.iter().filter(|&&f| f).count();
assert_eq!(
feasible_count, 50,
"expected 50 feasible points, got {feasible_count}"
);
}
// Violation matrix shape is (points, constraints); row 0 (the ball centre) has zero box violation.
#[test]
fn test_parallel_checker_violations() {
let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
checker.add_constraint(Box::new(make_box()));
checker.add_constraint(Box::new(
L2BallConstraint::new(Array1::from(vec![0.5f32, 1.0, 1.5]), 2.0).expect("valid ball"),
));
let data: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 3.0, 4.0];
let points = Array2::from_shape_vec((2, 3), data).expect("valid shape");
let mat = checker.violation_matrix(&points);
assert_eq!(mat.dim(), (2, 2), "expected (2, 2) violation matrix");
assert!(mat[[0, 0]] < 1e-5, "first point, box violation should be 0");
}
// With a single box constraint, Dykstra projection must land every row inside the box.
#[test]
fn test_parallel_checker_project_batch() {
let mut checker = ParallelFeasibilityChecker::new(ParallelConfig::default());
checker.add_constraint(Box::new(make_box()));
let data: Vec<f32> = vec![-1.0, 5.0, 10.0, -2.0, 3.0, 7.0];
let points = Array2::from_shape_vec((2, 3), data).expect("valid shape");
let projected = checker.project_batch(&points, 50);
assert_eq!(projected.dim(), (2, 3));
for i in 0..2usize {
let row: Array1<f32> = projected.slice(scirs2_core::ndarray::s![i, ..]).to_owned();
assert!(
make_box().is_feasible(&row),
"projected row {i} should be feasible"
);
}
}
// Three constraints on variable pairs {0,1}, {0,2}, {1,2} share variables pairwise.
// The colouring must keep adjacent constraints apart and place each constraint exactly once.
#[test]
fn test_constraint_graph_independent_sets() {
let mut graph = ConstraintGraph::new(3);
graph.add_constraint(Box::new(make_box()), vec![0, 1]);
graph.add_constraint(Box::new(make_box()), vec![0, 2]); graph.add_constraint(Box::new(make_box()), vec![1, 2]);
let sets = graph.independent_sets();
for set in &sets {
for (a_idx, &a) in set.iter().enumerate() {
for &b in set.iter().skip(a_idx + 1) {
assert!(
!graph.adjacency[a].contains(&b),
"constraints {a} and {b} are adjacent but in the same independent set"
);
}
}
}
let mut seen = std::collections::HashSet::new();
for set in &sets {
for &c in set {
assert!(seen.insert(c), "constraint {c} appears in multiple sets");
}
}
assert_eq!(seen.len(), 3, "all 3 constraints must appear");
}
// Parallel batch evaluation must match per-row sequential evaluation.
#[test]
fn test_simd_evaluator_batch() {
let bounds = vec![
(vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
(vec![-1.0f32, -1.0, -1.0], vec![1.0f32, 1.0, 1.0]),
];
let evaluator = SimdConstraintEvaluator::new(bounds).expect("valid bounds");
let data = vec![0.5f32, 1.5, 2.5, 2.0, 0.0, 0.0];
let points = Array2::from_shape_vec((2, 3), data.clone()).expect("valid shape");
let batch = evaluator.evaluate_batch(&points);
for i in 0..2usize {
let row = &data[i * 3..(i + 1) * 3];
let seq = evaluator.evaluate(row);
for j in 0..evaluator.num_constraints() {
assert!(
(batch[[i, j]] - seq[j]).abs() < 1e-5,
"batch[{i},{j}]={} != seq[{j}]={}",
batch[[i, j]],
seq[j]
);
}
}
}
// is_feasible_fast accepts a point inside both boxes and rejects one outside the first box.
#[test]
fn test_simd_evaluator_fast_feasibility() {
let bounds = vec![
(vec![0.0f32, 0.0, 0.0], vec![1.0f32, 2.0, 3.0]),
(vec![-5.0f32, -5.0, -5.0], vec![5.0f32, 5.0, 5.0]),
];
let evaluator = SimdConstraintEvaluator::new(bounds).expect("valid bounds");
let feasible = vec![0.5f32, 1.0, 2.0];
assert!(evaluator.is_feasible_fast(&feasible), "point is feasible");
// Component 0 = 2.0 exceeds the first box's upper bound of 1.0.
let infeasible = vec![2.0f32, 1.0, 2.0]; assert!(
!evaluator.is_feasible_fast(&infeasible),
"point is infeasible"
);
}
// Adding a tighter box invalidates the cached solution (1, 2, 3), so the second solve
// restarts from init2 and must satisfy both boxes.
#[test]
fn test_incremental_solver_add_constraint() {
let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());
solver.add_constraint(Box::new(make_box()));
let init = Array1::from(vec![5.0f32, 5.0, 5.0]);
let result = solver.solve(init, 50);
assert!(
result.feasible,
"solution should be feasible after solving with box constraint"
);
assert!(make_box().is_feasible(&result.solution));
let tight_box = BoxConstraint::new(
Array1::from(vec![0.0f32, 0.0, 0.0]),
Array1::from(vec![0.5f32, 0.5, 0.5]),
)
.expect("valid box");
solver.add_constraint(Box::new(tight_box));
let init2 = Array1::from(vec![1.0f32, 2.0, 3.0]);
let result2 = solver.solve(init2, 100);
assert!(
result2.feasible,
"solution should be feasible after adding tighter constraint"
);
assert!(result2.solution[0] <= 0.5 + 1e-4);
assert!(result2.solution[1] <= 0.5 + 1e-4);
assert!(result2.solution[2] <= 0.5 + 1e-4);
}
// The second solve reuses the still-valid cached solution with a quarter of the budget;
// `iterations` reports that budget, so it must not exceed the cold solve's.
#[test]
fn test_incremental_solver_warmstart() {
let mut solver = IncrementalParallelSolver::new(ParallelConfig::default());
solver.add_constraint(Box::new(make_box()));
let init = Array1::from(vec![0.5f32, 1.0, 1.5]);
let cold_result = solver.solve(init.clone(), 100);
assert!(cold_result.feasible);
let warm_result = solver.solve(init, 100);
assert!(warm_result.feasible);
assert!(
warm_result.iterations <= cold_result.iterations,
"warm start should use fewer or equal iterations: warm={} cold={}",
warm_result.iterations,
cold_result.iterations
);
}
}