1use super::subproblem::{compute_step, Subproblem};
2use crate::kernel::Kernel;
3use crate::problem::DualProblem;
4use crate::status::Status;
5
6fn find_mvp_signed(
7 problem: &dyn DualProblem,
8 status: &mut Status,
9 active_set: &Vec<usize>,
10 sign: f64,
11) -> (f64, f64, usize, usize) {
12 let mut g_min = f64::INFINITY;
13 let mut g_max = f64::NEG_INFINITY;
14 let mut idx_i: usize = 0;
15 let mut idx_j: usize = 0;
16 for (idx, &i) in active_set.iter().enumerate() {
17 let g_i = status.ka[i] + problem.d_dloss(i, status.a[i]);
18 status.g[i] = g_i;
19 if problem.sign(i) * sign >= 0.0 {
20 if status.a[i] > problem.lb(i) && g_i > g_max {
21 idx_i = idx;
22 g_max = g_i;
23 }
24 if status.a[i] < problem.ub(i) && g_i < g_min {
25 idx_j = idx;
26 g_min = g_i;
27 }
28 }
29 }
30 (g_max - g_min, g_max + g_min, idx_i, idx_j)
31}
32
33pub fn find_mvp(
34 problem: &dyn DualProblem,
35 status: &mut Status,
36 active_set: &Vec<usize>,
37) -> (usize, usize) {
38 let (dij, idx_i, idx_j) = if status.asum == problem.max_asum() {
39 let (dij_p, sij_p, idx_i_p, idx_j_p) = find_mvp_signed(problem, status, active_set, 1.0);
40 let (dij_n, sij_n, idx_i_n, idx_j_n) = find_mvp_signed(problem, status, active_set, -1.0);
41 status.b = -0.25 * (sij_p + sij_n);
42 status.c = 0.25 * (sij_n - sij_p);
43 if dij_p >= dij_n {
44 (dij_p, idx_i_p, idx_j_p)
45 } else {
46 (dij_n, idx_i_n, idx_j_n)
47 }
48 } else {
49 let (dij, sij, idx_i, idx_j) = find_mvp_signed(problem, status, active_set, 0.0);
50 status.b = -0.5 * sij;
51 status.opt_status.violation = dij;
52 (dij, idx_i, idx_j)
53 };
54 status.opt_status.violation = dij;
55 (idx_i, idx_j)
56}
57
58pub fn find_ws2(
59 problem: &dyn DualProblem,
60 kernel: &mut dyn Kernel,
61 idx_i0: usize,
62 idx_j1: usize,
63 status: &Status,
64 active_set: &Vec<usize>,
65 sign: f64,
66) -> (usize, usize) {
67 let i0 = active_set[idx_i0];
68 let j1 = active_set[idx_j1];
69 let gi0 = status.g[i0];
70 let gj1 = status.g[j1];
71 let mut max_d0 = 0.0;
72 let mut max_d1 = 0.0;
73 let mut idx_j0 = idx_j1;
74 let mut idx_i1 = idx_i0;
75
76 let diags: Vec<f64> = active_set.iter().map(|&i| kernel.diag(i)).collect();
77 kernel.use_rows([i0, j1].as_slice(), &active_set, &mut |kij: Vec<&[f64]>| {
78 let ki0 = kij[0];
79 let kj1 = kij[1];
80 let ki0i0 = ki0[idx_i0];
81 let kj1j1 = kj1[idx_j1];
82 let max_ti0 = status.a[i0] - problem.lb(i0);
83 let max_tj1 = problem.ub(j1) - status.a[j1];
84
85 for (idx_r, &r) in active_set.iter().enumerate() {
86 if sign * problem.sign(r) < 0.0 {
87 continue;
88 }
89 let gr = status.g[r];
90 let krr = diags[idx_r];
91
92 let pi0r = gi0 - gr;
93 let d_upr = problem.ub(r) - status.a[r];
94 if d_upr > 0.0 && pi0r > 0.0 {
95 let step = compute_step(
96 problem,
97 Subproblem {
98 ij: (i0, r),
99 max_t: f64::min(max_ti0, d_upr),
100 q0: ki0i0 + krr - 2.0 * ki0[idx_r],
101 p0: status.ka[i0] - status.ka[r],
102 },
103 status,
104 );
105 if step.dvalue > max_d0 {
106 idx_j0 = idx_r;
107 max_d0 = step.dvalue;
108 }
109 }
110
111 let prj1 = gr - gj1;
112 let d_dnr = status.a[r] - problem.lb(r);
113 if d_dnr > 0.0 && prj1 > 0.0 {
114 let step = compute_step(
115 problem,
116 Subproblem {
117 ij: (r, j1),
118 max_t: f64::min(max_tj1, d_dnr),
119 q0: kj1j1 + krr - 2.0 * kj1[idx_r],
120 p0: status.ka[r] - status.ka[j1],
121 },
122 status,
123 );
124 if step.dvalue > max_d1 {
125 idx_i1 = idx_r;
126 max_d1 = step.dvalue;
127 }
128 }
129 }
130 });
131 if max_d0 > max_d1 {
132 (idx_i0, idx_j0)
133 } else {
134 (idx_i1, idx_j1)
135 }
136}