Skip to main content

libsvm_rs/
qmatrix.rs

1//! Q matrix implementations for the SMO solver.
2//!
3//! The Q matrix encodes the quadratic form in the SVM dual problem:
4//! `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` for classification,
5//! or just `K(i,j)` for one-class, or a signed/indexed variant for SVR.
6//!
7//! Each implementation wraps a `Kernel` and a `Cache`, providing the
8//! `QMatrix` trait that the solver consumes.
9
10use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
/// Trait for Q matrix access used by the SMO solver.
///
/// Takes `&mut self` because `get_q` mutates the internal cache.
/// The solver owns `Box<dyn QMatrix>` and copies row data into its
/// own buffers to avoid lifetime issues.
pub trait QMatrix {
    /// Get the first `len` entries of row `i` of the Q matrix (for a
    /// symmetric kernel this is the same as column `i`).
    ///
    /// Callers may rely on at least `len` elements; the implementations in
    /// this module return a slice of exactly `len` elements. Entries missing
    /// from the cache are computed from the kernel on demand, which is why
    /// this takes `&mut self`.
    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];

    /// Get the diagonal of Q: `QD[i] = Q[i][i]`, kept in full `f64` precision.
    fn get_qd(&self) -> &[f64];

    /// Swap indices i and j in all internal data structures, so that later
    /// `get_q` / `get_qd` calls see rows/columns `i` and `j` exchanged.
    fn swap_index(&mut self, i: usize, j: usize);
}
29
30// ─── SVC_Q ──────────────────────────────────────────────────────────
31
32/// Q matrix for C-SVC and ν-SVC classification.
33///
34/// `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` stored as `Qfloat` (f32).
35pub struct SvcQ<'a> {
36    kernel: Kernel<'a>,
37    cache: Cache,
38    y: Vec<i8>,
39    qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44        let l = x.len();
45        let kernel = Kernel::new(x, param);
46        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48        let y = y.to_vec();
49        Self { kernel, cache, y, qd }
50    }
51}
52
53#[allow(clippy::needless_range_loop)]
54impl<'a> QMatrix for SvcQ<'a> {
55    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
56        let (data, start) = self.cache.get_data(i, len);
57        if start < len {
58            let yi = self.y[i] as f64;
59            for j in start..len {
60                let kval = self.kernel.evaluate(i, j);
61                data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
62            }
63        }
64        &data[..len]
65    }
66
67    fn get_qd(&self) -> &[f64] {
68        &self.qd
69    }
70
71    fn swap_index(&mut self, i: usize, j: usize) {
72        self.cache.swap_index(i, j);
73        self.kernel.swap_index(i, j);
74        self.y.swap(i, j);
75        self.qd.swap(i, j);
76    }
77}
78
79// ─── ONE_CLASS_Q ────────────────────────────────────────────────────
80
81/// Q matrix for one-class SVM.
82///
83/// `Q[i][j] = K(x[i], x[j])` — no label scaling.
84pub struct OneClassQ<'a> {
85    kernel: Kernel<'a>,
86    cache: Cache,
87    qd: Vec<f64>,
88}
89
90impl<'a> OneClassQ<'a> {
91    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
92        let l = x.len();
93        let kernel = Kernel::new(x, param);
94        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
95        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
96        Self { kernel, cache, qd }
97    }
98}
99
100#[allow(clippy::needless_range_loop)]
101impl<'a> QMatrix for OneClassQ<'a> {
102    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
103        let (data, start) = self.cache.get_data(i, len);
104        if start < len {
105            for j in start..len {
106                data[j] = self.kernel.evaluate(i, j) as Qfloat;
107            }
108        }
109        &data[..len]
110    }
111
112    fn get_qd(&self) -> &[f64] {
113        &self.qd
114    }
115
116    fn swap_index(&mut self, i: usize, j: usize) {
117        self.cache.swap_index(i, j);
118        self.kernel.swap_index(i, j);
119        self.qd.swap(i, j);
120    }
121}
122
123// ─── SVR_Q ──────────────────────────────────────────────────────────
124
125/// Q matrix for ε-SVR and ν-SVR regression.
126///
127/// The regression dual has 2l variables (α_i^+ and α_i^-).
128/// The underlying kernel cache stores only l rows of actual kernel
129/// evaluations; `get_q` reorders/signs them into a double-buffered output.
130pub struct SvrQ<'a> {
131    kernel: Kernel<'a>,
132    cache: Cache,
133    /// Number of original data points.
134    l: usize,
135    /// Sign of each of the 2l variables: +1 for first l, -1 for second l.
136    sign: Vec<i8>,
137    /// Maps each of the 2l indices to the original data index [0..l).
138    index: Vec<usize>,
139    /// Diagonal of the 2l×2l Q matrix.
140    qd: Vec<f64>,
141    /// Double buffer for returning Q rows (solver may hold two simultaneously).
142    buffer: [Vec<Qfloat>; 2],
143    /// Toggle between the two buffers.
144    next_buffer: usize,
145}
146
147impl<'a> SvrQ<'a> {
148    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
149        let l = x.len();
150        let kernel = Kernel::new(x, param);
151        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
152
153        let mut sign = vec![0i8; 2 * l];
154        let mut index = vec![0usize; 2 * l];
155        let mut qd = vec![0.0f64; 2 * l];
156
157        for k in 0..l {
158            sign[k] = 1;
159            sign[k + l] = -1;
160            index[k] = k;
161            index[k + l] = k;
162            let kk = kernel.evaluate(k, k);
163            qd[k] = kk;
164            qd[k + l] = kk;
165        }
166
167        let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
168
169        Self {
170            kernel,
171            cache,
172            l,
173            sign,
174            index,
175            qd,
176            buffer,
177            next_buffer: 0,
178        }
179    }
180}
181
182#[allow(clippy::needless_range_loop)]
183impl<'a> QMatrix for SvrQ<'a> {
184    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
185        let real_i = self.index[i];
186        let l = self.l;
187
188        // Fetch (or fill) the full kernel row for the real data index
189        let (data, start) = self.cache.get_data(real_i, l);
190        if start < l {
191            for j in start..l {
192                data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
193            }
194        }
195
196        // Reorder and apply signs into the output buffer
197        let buf_idx = self.next_buffer;
198        self.next_buffer = 1 - self.next_buffer;
199        let si = self.sign[i] as f32;
200        let buf = &mut self.buffer[buf_idx];
201        for j in 0..len {
202            buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
203        }
204        &self.buffer[buf_idx][..len]
205    }
206
207    fn get_qd(&self) -> &[f64] {
208        &self.qd
209    }
210
211    fn swap_index(&mut self, i: usize, j: usize) {
212        self.sign.swap(i, j);
213        self.index.swap(i, j);
214        self.qd.swap(i, j);
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs.iter().map(|&(i, v)| SvmNode { index: i, value: v }).collect()
    }

    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    /// Materializes the leading `n × n` block of `q` by requesting every row.
    fn full_matrix<Q: QMatrix>(q: &mut Q, n: usize) -> Vec<Vec<Qfloat>> {
        (0..n).map(|i| q.get_q(i, n).to_vec()).collect()
    }

    /// Asserts that `get_qd()[i]` agrees with the diagonal entry of row `i`.
    fn assert_diagonal_matches<Q: QMatrix>(q: &mut Q, n: usize) {
        for i in 0..n {
            let from_row = q.get_q(i, n)[i] as f64;
            let from_qd = q.get_qd()[i];
            assert!(
                (from_row - from_qd).abs() < 1e-6,
                "row {} diagonal {} != QD {}",
                i,
                from_row,
                from_qd
            );
        }
    }

    /// Checks that `swap_index(a, b)` exchanges rows/columns `a` and `b` of the
    /// leading `n × n` block and the matching diagonal entries.
    fn assert_swap_permutes<Q: QMatrix>(q: &mut Q, n: usize, a: usize, b: usize) {
        let before = full_matrix(q, n);
        let qd_before = q.get_qd().to_vec();

        q.swap_index(a, b);

        let after = full_matrix(q, n);
        let qd_after = q.get_qd().to_vec();
        let perm = |k: usize| if k == a { b } else if k == b { a } else { k };

        for (i, row) in after.iter().enumerate() {
            assert!(
                (qd_after[i] - qd_before[perm(i)]).abs() < 1e-12,
                "QD[{}] not permuted",
                i
            );
            for (j, &v) in row.iter().enumerate() {
                let expected = before[perm(i)][perm(j)];
                assert!(
                    (v - expected).abs() < 1e-6,
                    "after swap Q[{}][{}]={} but expected {}",
                    i,
                    j,
                    v,
                    expected
                );
            }
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // K(x,x) = 1 for RBF, QD[i] = y[i]*y[i]*1 = 1
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    #[allow(clippy::needless_range_loop)]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // Collect full matrix
        let mut matrix = vec![vec![0.0f32; l]; l];
        for i in 0..l {
            let row = q.get_q(i, l).to_vec();
            matrix[i][..l].copy_from_slice(&row[..l]);
        }

        // Check symmetry
        for i in 0..l {
            for j in 0..l {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i, j, matrix[i][j], j, i, matrix[j][i]
                );
            }
        }

        // Check sign: Q[0][1] should be negative (y[0]*y[1] = -1)
        assert!(matrix[0][1] < 0.0);
        // Q[0][2] should be positive (y[0]*y[2] = +1)
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        // All values should be positive (kernel values are always positive for RBF)
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        // Diagonal should be 1.0
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len(); // 4

        // Get two rows — they use different buffers
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        // Row 0: sign[0]=+1, index[0]=0 → K(0, index[j]) * sign[0] * sign[j]
        // Row 1: sign[1]=+1, index[1]=1 → K(1, index[j]) * sign[1] * sign[j]
        // Both should have non-zero entries
        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // For index 0 (sign +1) and index 2 (sign -1, real_idx 0):
        // Q[0][2] = sign[0]*sign[2]*K(0,0) = 1*(-1)*1 = -1
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }

    #[test]
    fn qd_matches_diagonal_for_every_variant() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, -0.5)]),
        ];
        let y = vec![1i8, -1, -1];
        let param = default_rbf_param();
        let l = data.len();

        assert_diagonal_matches(&mut SvcQ::new(&data, &param, &y), l);
        assert_diagonal_matches(&mut OneClassQ::new(&data, &param), l);
        assert_diagonal_matches(&mut SvrQ::new(&data, &param), 2 * l);
    }

    #[test]
    fn svc_q_swap_index_permutes_rows_and_diagonal() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1, 1];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        // Swap two points with different labels so the label swap matters too.
        assert_swap_permutes(&mut q, data.len(), 0, 1);
    }

    #[test]
    fn svr_q_swap_index_permutes_rows_and_diagonal() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        // Swap a "+" variable of point 0 with the "-" variable of point 1.
        assert_swap_permutes(&mut q, 2 * data.len(), 0, 3);
    }
}