rusvm/smo/
ws.rs

1use super::subproblem::{compute_step, Subproblem};
2use crate::kernel::Kernel;
3use crate::problem::DualProblem;
4use crate::status::Status;
5
6fn find_mvp_signed(
7    problem: &dyn DualProblem,
8    status: &mut Status,
9    active_set: &Vec<usize>,
10    sign: f64,
11) -> (f64, f64, usize, usize) {
12    let mut g_min = f64::INFINITY;
13    let mut g_max = f64::NEG_INFINITY;
14    let mut idx_i: usize = 0;
15    let mut idx_j: usize = 0;
16    for (idx, &i) in active_set.iter().enumerate() {
17        let g_i = status.ka[i] + problem.d_dloss(i, status.a[i]);
18        status.g[i] = g_i;
19        if problem.sign(i) * sign >= 0.0 {
20            if status.a[i] > problem.lb(i) && g_i > g_max {
21                idx_i = idx;
22                g_max = g_i;
23            }
24            if status.a[i] < problem.ub(i) && g_i < g_min {
25                idx_j = idx;
26                g_min = g_i;
27            }
28        }
29    }
30    (g_max - g_min, g_max + g_min, idx_i, idx_j)
31}
32
33pub fn find_mvp(
34    problem: &dyn DualProblem,
35    status: &mut Status,
36    active_set: &Vec<usize>,
37) -> (usize, usize) {
38    let (dij, idx_i, idx_j) = if status.asum == problem.max_asum() {
39        let (dij_p, sij_p, idx_i_p, idx_j_p) = find_mvp_signed(problem, status, active_set, 1.0);
40        let (dij_n, sij_n, idx_i_n, idx_j_n) = find_mvp_signed(problem, status, active_set, -1.0);
41        status.b = -0.25 * (sij_p + sij_n);
42        status.c = 0.25 * (sij_n - sij_p);
43        if dij_p >= dij_n {
44            (dij_p, idx_i_p, idx_j_p)
45        } else {
46            (dij_n, idx_i_n, idx_j_n)
47        }
48    } else {
49        let (dij, sij, idx_i, idx_j) = find_mvp_signed(problem, status, active_set, 0.0);
50        status.b = -0.5 * sij;
51        status.opt_status.violation = dij;
52        (dij, idx_i, idx_j)
53    };
54    status.opt_status.violation = dij;
55    (idx_i, idx_j)
56}
57
58pub fn find_ws2(
59    problem: &dyn DualProblem,
60    kernel: &mut dyn Kernel,
61    idx_i0: usize,
62    idx_j1: usize,
63    status: &Status,
64    active_set: &Vec<usize>,
65    sign: f64,
66) -> (usize, usize) {
67    let i0 = active_set[idx_i0];
68    let j1 = active_set[idx_j1];
69    let gi0 = status.g[i0];
70    let gj1 = status.g[j1];
71    let mut max_d0 = 0.0;
72    let mut max_d1 = 0.0;
73    let mut idx_j0 = idx_j1;
74    let mut idx_i1 = idx_i0;
75
76    let diags: Vec<f64> = active_set.iter().map(|&i| kernel.diag(i)).collect();
77    kernel.use_rows([i0, j1].as_slice(), &active_set, &mut |kij: Vec<&[f64]>| {
78        let ki0 = kij[0];
79        let kj1 = kij[1];
80        let ki0i0 = ki0[idx_i0];
81        let kj1j1 = kj1[idx_j1];
82        let max_ti0 = status.a[i0] - problem.lb(i0);
83        let max_tj1 = problem.ub(j1) - status.a[j1];
84
85        for (idx_r, &r) in active_set.iter().enumerate() {
86            if sign * problem.sign(r) < 0.0 {
87                continue;
88            }
89            let gr = status.g[r];
90            let krr = diags[idx_r];
91
92            let pi0r = gi0 - gr;
93            let d_upr = problem.ub(r) - status.a[r];
94            if d_upr > 0.0 && pi0r > 0.0 {
95                let step = compute_step(
96                    problem,
97                    Subproblem {
98                        ij: (i0, r),
99                        max_t: f64::min(max_ti0, d_upr),
100                        q0: ki0i0 + krr - 2.0 * ki0[idx_r],
101                        p0: status.ka[i0] - status.ka[r],
102                    },
103                    status,
104                );
105                if step.dvalue > max_d0 {
106                    idx_j0 = idx_r;
107                    max_d0 = step.dvalue;
108                }
109            }
110
111            let prj1 = gr - gj1;
112            let d_dnr = status.a[r] - problem.lb(r);
113            if d_dnr > 0.0 && prj1 > 0.0 {
114                let step = compute_step(
115                    problem,
116                    Subproblem {
117                        ij: (r, j1),
118                        max_t: f64::min(max_tj1, d_dnr),
119                        q0: kj1j1 + krr - 2.0 * kj1[idx_r],
120                        p0: status.ka[r] - status.ka[j1],
121                    },
122                    status,
123                );
124                if step.dvalue > max_d1 {
125                    idx_i1 = idx_r;
126                    max_d1 = step.dvalue;
127                }
128            }
129        }
130    });
131    if max_d0 > max_d1 {
132        (idx_i0, idx_j0)
133    } else {
134        (idx_i1, idx_j1)
135    }
136}