1use crate::ViolationComputable;
10use scirs2_core::ndarray::Array1;
11use std::time::{Duration, Instant};
12
13#[derive(Debug, Clone)]
15pub struct HierarchicalConstraint<C: ViolationComputable> {
16 pub constraint: C,
18 pub priority: u32,
20 pub weight: f32,
22}
23
24impl<C: ViolationComputable> HierarchicalConstraint<C> {
25 pub fn new(constraint: C, priority: u32, weight: f32) -> Self {
27 Self {
28 constraint,
29 priority,
30 weight,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct HierarchicalRelaxation<C: ViolationComputable> {
38 constraints: Vec<HierarchicalConstraint<C>>,
40 max_iterations_per_level: usize,
42 tolerance: f32,
44}
45
46impl<C: ViolationComputable + Clone> HierarchicalRelaxation<C> {
47 pub fn new(max_iterations_per_level: usize, tolerance: f32) -> Self {
49 Self {
50 constraints: Vec::new(),
51 max_iterations_per_level,
52 tolerance,
53 }
54 }
55
56 pub fn add_constraint(&mut self, constraint: C, priority: u32, weight: f32) {
58 self.constraints
59 .push(HierarchicalConstraint::new(constraint, priority, weight));
60 self.constraints.sort_by(|a, b| b.priority.cmp(&a.priority));
62 }
63
64 pub fn solve(&self, initial: &Array1<f32>) -> ApproximateSolution {
66 let start_time = Instant::now();
67 let mut current = initial.clone();
68 let mut satisfied_constraints = 0;
69 let mut total_violation = 0.0;
70
71 let max_priority = self
73 .constraints
74 .iter()
75 .map(|c| c.priority)
76 .max()
77 .unwrap_or(0);
78
79 for priority_level in (0..=max_priority).rev() {
81 let level_constraints: Vec<_> = self
82 .constraints
83 .iter()
84 .filter(|c| c.priority == priority_level)
85 .collect();
86
87 if level_constraints.is_empty() {
88 continue;
89 }
90
91 for _ in 0..self.max_iterations_per_level {
93 let mut improved = false;
94
95 for hc in &level_constraints {
96 let current_slice = current.as_slice().unwrap_or(&[]);
97 if !hc.constraint.check(current_slice) {
98 let violation = hc.constraint.violation(current_slice);
100 if violation > self.tolerance {
101 for i in 0..current.len() {
103 let epsilon = 0.001;
104 let mut perturbed = current.clone();
105 perturbed[i] += epsilon;
106
107 let viol_plus =
108 hc.constraint.violation(perturbed.as_slice().unwrap_or(&[]));
109 let grad = (viol_plus - violation) / epsilon;
110
111 if grad.abs() > 1e-6 {
112 current[i] -= 0.01 * hc.weight * grad.signum();
113 improved = true;
114 }
115 }
116 }
117 }
118 }
119
120 if !improved {
121 break;
122 }
123 }
124
125 for hc in &level_constraints {
127 let current_slice = current.as_slice().unwrap_or(&[]);
128 if hc.constraint.check(current_slice) {
129 satisfied_constraints += 1;
130 } else {
131 total_violation += hc.constraint.violation(current_slice);
132 }
133 }
134 }
135
136 ApproximateSolution {
137 solution: current,
138 satisfied_constraints,
139 total_constraints: self.constraints.len(),
140 total_violation,
141 computation_time: start_time.elapsed(),
142 optimality_gap: None,
143 }
144 }
145
146 pub fn num_constraints(&self) -> usize {
148 self.constraints.len()
149 }
150}
151
152#[derive(Debug, Clone)]
154pub struct BoundedErrorSolver<C: ViolationComputable> {
155 constraints: Vec<C>,
157 error_bound: f32,
159 step_size: f32,
161 max_iterations: usize,
163}
164
165impl<C: ViolationComputable + Clone> BoundedErrorSolver<C> {
166 pub fn new(error_bound: f32, step_size: f32, max_iterations: usize) -> Self {
168 Self {
169 constraints: Vec::new(),
170 error_bound,
171 step_size,
172 max_iterations,
173 }
174 }
175
176 pub fn add_constraint(&mut self, constraint: C) {
178 self.constraints.push(constraint);
179 }
180
181 pub fn solve(&self, initial: &Array1<f32>) -> ApproximateSolution {
183 let start_time = Instant::now();
184 let mut current = initial.clone();
185 let mut total_violation = 0.0;
186
187 for iter in 0..self.max_iterations {
188 let mut max_violation: f32 = 0.0;
189 let mut any_violation = false;
190
191 for constraint in &self.constraints {
193 let current_slice = current.as_slice().unwrap_or(&[]);
194 let violation = constraint.violation(current_slice);
195
196 if violation > self.error_bound {
197 any_violation = true;
198 max_violation = max_violation.max(violation);
199
200 for i in 0..current.len() {
202 let epsilon = 0.001;
203 let mut perturbed = current.clone();
204 perturbed[i] += epsilon;
205
206 let viol_plus = constraint.violation(perturbed.as_slice().unwrap_or(&[]));
207 let grad = (viol_plus - violation) / epsilon;
208
209 current[i] -= self.step_size * grad;
210 }
211 }
212 }
213
214 if !any_violation || max_violation <= self.error_bound {
216 break;
217 }
218
219 if iter % 10 == 0 && iter > 0 {
221 }
223 }
224
225 let mut satisfied = 0;
227 for constraint in &self.constraints {
228 let current_slice = current.as_slice().unwrap_or(&[]);
229 let violation = constraint.violation(current_slice);
230 total_violation += violation;
231 if violation <= self.error_bound {
232 satisfied += 1;
233 }
234 }
235
236 ApproximateSolution {
237 solution: current,
238 satisfied_constraints: satisfied,
239 total_constraints: self.constraints.len(),
240 total_violation,
241 computation_time: start_time.elapsed(),
242 optimality_gap: Some(self.error_bound),
243 }
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct AnytimeSolver<C: ViolationComputable> {
250 constraints: Vec<C>,
252 initial_solution: Array1<f32>,
254 best_solution: Option<Array1<f32>>,
256 best_violation: f32,
258 iterations: usize,
260 step_size: f32,
262}
263
264impl<C: ViolationComputable + Clone> AnytimeSolver<C> {
265 pub fn new(constraints: Vec<C>, initial_solution: Array1<f32>, step_size: f32) -> Self {
267 Self {
268 constraints,
269 initial_solution,
270 best_solution: None,
271 best_violation: f32::INFINITY,
272 iterations: 0,
273 step_size,
274 }
275 }
276
277 pub fn solve_for_duration(&mut self, duration: Duration) -> ApproximateSolution {
279 let start_time = Instant::now();
280 let mut current = self.initial_solution.clone();
281
282 while start_time.elapsed() < duration {
283 self.iterations += 1;
284
285 let mut total_violation = 0.0;
287 for constraint in &self.constraints {
288 let current_slice = current.as_slice().unwrap_or(&[]);
289 total_violation += constraint.violation(current_slice).max(0.0);
290 }
291
292 if total_violation < self.best_violation {
294 self.best_violation = total_violation;
295 self.best_solution = Some(current.clone());
296 }
297
298 for constraint in &self.constraints {
300 let current_slice = current.as_slice().unwrap_or(&[]);
301 let violation = constraint.violation(current_slice);
302
303 if violation > 0.0 {
304 for i in 0..current.len() {
305 let epsilon = 0.001;
306 let mut perturbed = current.clone();
307 perturbed[i] += epsilon;
308
309 let viol_plus = constraint.violation(perturbed.as_slice().unwrap_or(&[]));
310 let grad = (viol_plus - violation) / epsilon;
311
312 current[i] -= self.step_size * grad;
313 }
314 }
315 }
316
317 if self.iterations.is_multiple_of(100) {
319 self.step_size *= 0.99; }
321 }
322
323 let solution = self
324 .best_solution
325 .clone()
326 .unwrap_or_else(|| self.initial_solution.clone());
327
328 let mut satisfied = 0;
330 for constraint in &self.constraints {
331 let sol_slice = solution.as_slice().unwrap_or(&[]);
332 if constraint.check(sol_slice) {
333 satisfied += 1;
334 }
335 }
336
337 ApproximateSolution {
338 solution,
339 satisfied_constraints: satisfied,
340 total_constraints: self.constraints.len(),
341 total_violation: self.best_violation,
342 computation_time: start_time.elapsed(),
343 optimality_gap: None,
344 }
345 }
346
347 pub fn solve_for_iterations(&mut self, num_iterations: usize) -> ApproximateSolution {
349 let start_time = Instant::now();
350 let mut current = self.initial_solution.clone();
351
352 for _ in 0..num_iterations {
353 self.iterations += 1;
354
355 let mut total_violation = 0.0;
357 for constraint in &self.constraints {
358 let current_slice = current.as_slice().unwrap_or(&[]);
359 total_violation += constraint.violation(current_slice).max(0.0);
360 }
361
362 if total_violation < self.best_violation {
364 self.best_violation = total_violation;
365 self.best_solution = Some(current.clone());
366 }
367
368 for constraint in &self.constraints {
370 let current_slice = current.as_slice().unwrap_or(&[]);
371 let violation = constraint.violation(current_slice);
372
373 if violation > 0.0 {
374 for i in 0..current.len() {
375 let epsilon = 0.001;
376 let mut perturbed = current.clone();
377 perturbed[i] += epsilon;
378
379 let viol_plus = constraint.violation(perturbed.as_slice().unwrap_or(&[]));
380 let grad = (viol_plus - violation) / epsilon;
381
382 current[i] -= self.step_size * grad;
383 }
384 }
385 }
386 }
387
388 let solution = self
389 .best_solution
390 .clone()
391 .unwrap_or_else(|| self.initial_solution.clone());
392
393 let mut satisfied = 0;
394 for constraint in &self.constraints {
395 let sol_slice = solution.as_slice().unwrap_or(&[]);
396 if constraint.check(sol_slice) {
397 satisfied += 1;
398 }
399 }
400
401 ApproximateSolution {
402 solution,
403 satisfied_constraints: satisfied,
404 total_constraints: self.constraints.len(),
405 total_violation: self.best_violation,
406 computation_time: start_time.elapsed(),
407 optimality_gap: None,
408 }
409 }
410
411 pub fn best_solution(&self) -> Option<&Array1<f32>> {
413 self.best_solution.as_ref()
414 }
415
416 pub fn iterations(&self) -> usize {
418 self.iterations
419 }
420}
421
422#[derive(Debug, Clone)]
424pub struct ApproximateSolution {
425 pub solution: Array1<f32>,
427 pub satisfied_constraints: usize,
429 pub total_constraints: usize,
431 pub total_violation: f32,
433 pub computation_time: Duration,
435 pub optimality_gap: Option<f32>,
437}
438
439impl ApproximateSolution {
440 pub fn satisfaction_ratio(&self) -> f32 {
442 if self.total_constraints == 0 {
443 1.0
444 } else {
445 self.satisfied_constraints as f32 / self.total_constraints as f32
446 }
447 }
448
449 pub fn is_feasible(&self) -> bool {
451 self.satisfied_constraints == self.total_constraints
452 }
453
454 pub fn average_violation(&self) -> f32 {
456 if self.total_constraints == 0 {
457 0.0
458 } else {
459 self.total_violation / self.total_constraints as f32
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearConstraint;

    /// Two stacked upper bounds: the priority-2 bound is relaxed first, then the
    /// tighter priority-1 bound.
    #[test]
    fn test_hierarchical_relaxation() {
        let mut relaxation = HierarchicalRelaxation::new(200, 0.01);
        relaxation.add_constraint(LinearConstraint::less_eq(vec![1.0], 5.0), 2, 2.0);
        relaxation.add_constraint(LinearConstraint::less_eq(vec![1.0], 3.0), 1, 1.0);

        let start = Array1::from_vec(vec![10.0]);
        let outcome = relaxation.solve(&start);

        assert!(outcome.solution[0] < 10.0);
        assert_eq!(relaxation.num_constraints(), 2);
        assert!(outcome.total_violation < 5.0);
    }

    /// A box constraint `2 <= x <= 5` starting from outside the box.
    #[test]
    fn test_bounded_error_solver() {
        let mut bounded = BoundedErrorSolver::new(0.5, 0.1, 100);
        bounded.add_constraint(LinearConstraint::less_eq(vec![1.0], 5.0));
        bounded.add_constraint(LinearConstraint::greater_eq(vec![1.0], 2.0));

        let outcome = bounded.solve(&Array1::from_vec(vec![10.0]));

        assert!(outcome.average_violation() <= 0.5);
        assert!(outcome.satisfaction_ratio() > 0.0);
    }

    /// Fixed iteration budget on the box `0 <= x <= 5`.
    #[test]
    fn test_anytime_solver_iterations() {
        let bounds = vec![
            LinearConstraint::less_eq(vec![1.0], 5.0),
            LinearConstraint::greater_eq(vec![1.0], 0.0),
        ];
        let mut anytime = AnytimeSolver::new(bounds, Array1::from_vec(vec![10.0]), 0.1);

        let outcome = anytime.solve_for_iterations(100);
        let x = outcome.solution[0];

        assert!(x >= 0.0);
        assert!(x <= 6.0);
        assert_eq!(anytime.iterations(), 100);
    }

    /// Fixed time budget on a two-variable half-space.
    #[test]
    fn test_anytime_solver_duration() {
        let budget = Duration::from_millis(10);
        let mut anytime = AnytimeSolver::new(
            vec![LinearConstraint::less_eq(vec![1.0, 1.0], 10.0)],
            Array1::from_vec(vec![15.0, 15.0]),
            0.1,
        );

        let outcome = anytime.solve_for_duration(budget);

        assert!(outcome.computation_time >= budget);
        assert!(outcome.total_violation < 100.0);
    }

    /// Derived metrics on a hand-built result.
    #[test]
    fn test_approximate_solution_metrics() {
        let metrics = ApproximateSolution {
            solution: Array1::from_vec(vec![5.0]),
            satisfied_constraints: 3,
            total_constraints: 5,
            total_violation: 2.5,
            computation_time: Duration::from_millis(100),
            optimality_gap: Some(0.1),
        };

        assert_eq!(metrics.satisfaction_ratio(), 0.6);
        assert!(!metrics.is_feasible());
        assert_eq!(metrics.average_violation(), 0.5);
    }
}