Skip to main content

libsvm_rs/
qmatrix.rs

1//! Q matrix implementations for the SMO solver.
2//!
3//! The Q matrix encodes the quadratic form in the SVM dual problem:
4//! `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` for classification,
5//! or just `K(i,j)` for one-class, or a signed/indexed variant for SVR.
6//!
7//! Each implementation wraps a `Kernel` and a `Cache`, providing the
8//! `QMatrix` trait that the solver consumes.
9
10use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
/// Trait for Q matrix access used by the SMO solver.
///
/// Takes `&mut self` because `get_q` mutates the internal cache.
/// The solver owns `Box<dyn QMatrix>` and copies row data into its
/// own buffers to avoid lifetime issues.
pub trait QMatrix {
    /// Get column `i` of the Q matrix, with at least `len` elements.
    ///
    /// The returned slice borrows from `self`, so it has to be dropped
    /// (or copied out) before any other method of the matrix is called.
    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];

    /// Get the diagonal of Q: `QD[i] = Q[i][i]`.
    ///
    /// Unlike `get_q`, the diagonal is exposed in `f64` precision.
    fn get_qd(&self) -> &[f64];

    /// Swap indices i and j in all internal data structures.
    fn swap_index(&mut self, i: usize, j: usize);
}
29
30// ─── SVC_Q ──────────────────────────────────────────────────────────
31
32/// Q matrix for C-SVC and ν-SVC classification.
33///
34/// `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` stored as `Qfloat` (f32).
35pub struct SvcQ<'a> {
36    kernel: Kernel<'a>,
37    cache: Cache,
38    y: Vec<i8>,
39    qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44        let l = x.len();
45        let kernel = Kernel::new(x, param);
46        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48        let y = y.to_vec();
49        Self {
50            kernel,
51            cache,
52            y,
53            qd,
54        }
55    }
56}
57
58#[allow(clippy::needless_range_loop)]
59impl<'a> QMatrix for SvcQ<'a> {
60    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
61        let (data, start) = self.cache.get_data(i, len);
62        if start < len {
63            let yi = self.y[i] as f64;
64            for j in start..len {
65                let kval = self.kernel.evaluate(i, j);
66                data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
67            }
68        }
69        &data[..len]
70    }
71
72    fn get_qd(&self) -> &[f64] {
73        &self.qd
74    }
75
76    fn swap_index(&mut self, i: usize, j: usize) {
77        self.cache.swap_index(i, j);
78        self.kernel.swap_index(i, j);
79        self.y.swap(i, j);
80        self.qd.swap(i, j);
81    }
82}
83
84// ─── ONE_CLASS_Q ────────────────────────────────────────────────────
85
86/// Q matrix for one-class SVM.
87///
88/// `Q[i][j] = K(x[i], x[j])` — no label scaling.
89pub struct OneClassQ<'a> {
90    kernel: Kernel<'a>,
91    cache: Cache,
92    qd: Vec<f64>,
93}
94
95impl<'a> OneClassQ<'a> {
96    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
97        let l = x.len();
98        let kernel = Kernel::new(x, param);
99        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
100        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
101        Self { kernel, cache, qd }
102    }
103}
104
105#[allow(clippy::needless_range_loop)]
106impl<'a> QMatrix for OneClassQ<'a> {
107    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
108        let (data, start) = self.cache.get_data(i, len);
109        if start < len {
110            for j in start..len {
111                data[j] = self.kernel.evaluate(i, j) as Qfloat;
112            }
113        }
114        &data[..len]
115    }
116
117    fn get_qd(&self) -> &[f64] {
118        &self.qd
119    }
120
121    fn swap_index(&mut self, i: usize, j: usize) {
122        self.cache.swap_index(i, j);
123        self.kernel.swap_index(i, j);
124        self.qd.swap(i, j);
125    }
126}
127
128// ─── SVR_Q ──────────────────────────────────────────────────────────
129
130/// Q matrix for ε-SVR and ν-SVR regression.
131///
132/// The regression dual has 2l variables (α_i^+ and α_i^-).
133/// The underlying kernel cache stores only l rows of actual kernel
134/// evaluations; `get_q` reorders/signs them into a double-buffered output.
135pub struct SvrQ<'a> {
136    kernel: Kernel<'a>,
137    cache: Cache,
138    /// Number of original data points.
139    l: usize,
140    /// Sign of each of the 2l variables: +1 for first l, -1 for second l.
141    sign: Vec<i8>,
142    /// Maps each of the 2l indices to the original data index [0..l).
143    index: Vec<usize>,
144    /// Diagonal of the 2l×2l Q matrix.
145    qd: Vec<f64>,
146    /// Double buffer for returning Q rows (solver may hold two simultaneously).
147    buffer: [Vec<Qfloat>; 2],
148    /// Toggle between the two buffers.
149    next_buffer: usize,
150}
151
152impl<'a> SvrQ<'a> {
153    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
154        let l = x.len();
155        let kernel = Kernel::new(x, param);
156        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
157
158        let mut sign = vec![0i8; 2 * l];
159        let mut index = vec![0usize; 2 * l];
160        let mut qd = vec![0.0f64; 2 * l];
161
162        for k in 0..l {
163            sign[k] = 1;
164            sign[k + l] = -1;
165            index[k] = k;
166            index[k + l] = k;
167            let kk = kernel.evaluate(k, k);
168            qd[k] = kk;
169            qd[k + l] = kk;
170        }
171
172        let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
173
174        Self {
175            kernel,
176            cache,
177            l,
178            sign,
179            index,
180            qd,
181            buffer,
182            next_buffer: 0,
183        }
184    }
185}
186
187#[allow(clippy::needless_range_loop)]
188impl<'a> QMatrix for SvrQ<'a> {
189    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
190        let real_i = self.index[i];
191        let l = self.l;
192
193        // Fetch (or fill) the full kernel row for the real data index
194        let (data, start) = self.cache.get_data(real_i, l);
195        if start < l {
196            for j in start..l {
197                data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
198            }
199        }
200
201        // Reorder and apply signs into the output buffer
202        let buf_idx = self.next_buffer;
203        self.next_buffer = 1 - self.next_buffer;
204        let si = self.sign[i] as f32;
205        let buf = &mut self.buffer[buf_idx];
206        for j in 0..len {
207            buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
208        }
209        &self.buffer[buf_idx][..len]
210    }
211
212    fn get_qd(&self) -> &[f64] {
213        &self.qd
214    }
215
216    fn swap_index(&mut self, i: usize, j: usize) {
217        self.sign.swap(i, j);
218        self.index.swap(i, j);
219        self.qd.swap(i, j);
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Turns `(index, value)` pairs into a sparse vector.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        let mut nodes = Vec::with_capacity(pairs.len());
        for &(index, value) in pairs {
            nodes.push(SvmNode { index, value });
        }
        nodes
    }

    /// RBF kernel with gamma 0.5 and a cache size of 1.0; everything else
    /// is left at its default.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let samples = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let labels = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&samples, &param, &labels);

        // RBF gives K(x, x) = 1, so every diagonal entry must be 1.
        let all_one = q.get_qd().iter().all(|&d| (d - 1.0).abs() < 1e-15);
        assert!(all_one);
    }

    #[test]
    #[allow(clippy::needless_range_loop)]
    fn svc_q_symmetry_and_sign() {
        let samples = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let labels = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&samples, &param, &labels);
        let l = samples.len();

        // Materialise the whole matrix, one copied row at a time.
        let matrix: Vec<Vec<f32>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        // Every entry must match its mirror image.
        for i in 0..l {
            for j in 0..l {
                let gap = (matrix[i][j] - matrix[j][i]).abs();
                assert!(
                    gap < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i,
                    j,
                    matrix[i][j],
                    j,
                    i,
                    matrix[j][i]
                );
            }
        }

        // Opposite labels (y[0] * y[1] = -1) give a negative entry ...
        assert!(matrix[0][1] < 0.0);
        // ... equal labels (y[0] * y[2] = +1) a positive one.
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let samples = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&samples, &param);

        let row = q.get_q(0, 2);
        // RBF kernel values are strictly positive and nothing flips them.
        assert!(row.iter().all(|&v| v > 0.0));
        // The diagonal entry is K(x, x) = 1.
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let samples = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&samples, &param);
        let width = 2 * samples.len(); // 4

        // Two consecutive calls land in different internal buffers.
        let first = q.get_q(0, width).to_vec();
        let second = q.get_q(1, width).to_vec();

        // Row i holds sign[i] * sign[j] * K(index[i], index[j]); neither
        // row may be all zeros.
        assert!(first.iter().any(|&v| v != 0.0));
        assert!(second.iter().any(|&v| v != 0.0));

        // Variable 0 (sign +1) and variable 2 (sign -1) share data point 0:
        // Q[0][2] = (+1) * (-1) * K(0, 0) = -1.
        assert!((first[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", first[2]);
    }
}