Skip to main content

libsvm_rs/
qmatrix.rs

1//! Q matrix implementations for the SMO solver.
2//!
3//! The Q matrix encodes the quadratic form in the SVM dual problem:
4//! `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` for classification,
5//! or just `K(i,j)` for one-class, or a signed/indexed variant for SVR.
6//!
7//! Each implementation wraps a `Kernel` and a `Cache`, providing the
8//! `QMatrix` trait that the solver consumes.
9
10use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14/// Trait for Q matrix access used by the SMO solver.
15///
16/// Takes `&mut self` because `get_q` mutates the internal cache.
17/// The solver owns `Box<dyn QMatrix>` and copies row data into its
18/// own buffers to avoid lifetime issues.
19pub trait QMatrix {
20    /// Get column `i` of the Q matrix, with at least `len` elements.
21    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23    /// Get the diagonal of Q: `QD[i] = Q[i][i]`.
24    fn get_qd(&self) -> &[f64];
25
26    /// Swap indices i and j in all internal data structures.
27    fn swap_index(&mut self, i: usize, j: usize);
28}
29
30// ─── SVC_Q ──────────────────────────────────────────────────────────
31
32/// Q matrix for C-SVC and ν-SVC classification.
33///
34/// `Q[i][j] = y[i] * y[j] * K(x[i], x[j])` stored as `Qfloat` (f32).
35pub struct SvcQ<'a> {
36    kernel: Kernel<'a>,
37    cache: Cache,
38    y: Vec<i8>,
39    qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44        let l = x.len();
45        let kernel = Kernel::new(x, param);
46        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48        let y = y.to_vec();
49        Self { kernel, cache, y, qd }
50    }
51}
52
53impl<'a> QMatrix for SvcQ<'a> {
54    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
55        let (data, start) = self.cache.get_data(i, len);
56        if start < len {
57            let yi = self.y[i] as f64;
58            for j in start..len {
59                let kval = self.kernel.evaluate(i, j);
60                data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
61            }
62        }
63        &data[..len]
64    }
65
66    fn get_qd(&self) -> &[f64] {
67        &self.qd
68    }
69
70    fn swap_index(&mut self, i: usize, j: usize) {
71        self.cache.swap_index(i, j);
72        self.kernel.swap_index(i, j);
73        self.y.swap(i, j);
74        self.qd.swap(i, j);
75    }
76}
77
78// ─── ONE_CLASS_Q ────────────────────────────────────────────────────
79
80/// Q matrix for one-class SVM.
81///
82/// `Q[i][j] = K(x[i], x[j])` — no label scaling.
83pub struct OneClassQ<'a> {
84    kernel: Kernel<'a>,
85    cache: Cache,
86    qd: Vec<f64>,
87}
88
89impl<'a> OneClassQ<'a> {
90    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
91        let l = x.len();
92        let kernel = Kernel::new(x, param);
93        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
94        let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
95        Self { kernel, cache, qd }
96    }
97}
98
99impl<'a> QMatrix for OneClassQ<'a> {
100    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
101        let (data, start) = self.cache.get_data(i, len);
102        if start < len {
103            for j in start..len {
104                data[j] = self.kernel.evaluate(i, j) as Qfloat;
105            }
106        }
107        &data[..len]
108    }
109
110    fn get_qd(&self) -> &[f64] {
111        &self.qd
112    }
113
114    fn swap_index(&mut self, i: usize, j: usize) {
115        self.cache.swap_index(i, j);
116        self.kernel.swap_index(i, j);
117        self.qd.swap(i, j);
118    }
119}
120
121// ─── SVR_Q ──────────────────────────────────────────────────────────
122
123/// Q matrix for ε-SVR and ν-SVR regression.
124///
125/// The regression dual has 2l variables (α_i^+ and α_i^-).
126/// The underlying kernel cache stores only l rows of actual kernel
127/// evaluations; `get_q` reorders/signs them into a double-buffered output.
128pub struct SvrQ<'a> {
129    kernel: Kernel<'a>,
130    cache: Cache,
131    /// Number of original data points.
132    l: usize,
133    /// Sign of each of the 2l variables: +1 for first l, -1 for second l.
134    sign: Vec<i8>,
135    /// Maps each of the 2l indices to the original data index [0..l).
136    index: Vec<usize>,
137    /// Diagonal of the 2l×2l Q matrix.
138    qd: Vec<f64>,
139    /// Double buffer for returning Q rows (solver may hold two simultaneously).
140    buffer: [Vec<Qfloat>; 2],
141    /// Toggle between the two buffers.
142    next_buffer: usize,
143}
144
145impl<'a> SvrQ<'a> {
146    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
147        let l = x.len();
148        let kernel = Kernel::new(x, param);
149        let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
150
151        let mut sign = vec![0i8; 2 * l];
152        let mut index = vec![0usize; 2 * l];
153        let mut qd = vec![0.0f64; 2 * l];
154
155        for k in 0..l {
156            sign[k] = 1;
157            sign[k + l] = -1;
158            index[k] = k;
159            index[k + l] = k;
160            let kk = kernel.evaluate(k, k);
161            qd[k] = kk;
162            qd[k + l] = kk;
163        }
164
165        let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
166
167        Self {
168            kernel,
169            cache,
170            l,
171            sign,
172            index,
173            qd,
174            buffer,
175            next_buffer: 0,
176        }
177    }
178}
179
180impl<'a> QMatrix for SvrQ<'a> {
181    fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
182        let real_i = self.index[i];
183        let l = self.l;
184
185        // Fetch (or fill) the full kernel row for the real data index
186        let (data, start) = self.cache.get_data(real_i, l);
187        if start < l {
188            for j in start..l {
189                data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
190            }
191        }
192
193        // Reorder and apply signs into the output buffer
194        let buf_idx = self.next_buffer;
195        self.next_buffer = 1 - self.next_buffer;
196        let si = self.sign[i] as f32;
197        let buf = &mut self.buffer[buf_idx];
198        for j in 0..len {
199            buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
200        }
201        &self.buffer[buf_idx][..len]
202    }
203
204    fn get_qd(&self) -> &[f64] {
205        &self.qd
206    }
207
208    fn swap_index(&mut self, i: usize, j: usize) {
209        self.sign.swap(i, j);
210        self.index.swap(i, j);
211        self.qd.swap(i, j);
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Turns `(index, value)` pairs into a sparse feature vector.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        let mut nodes = Vec::with_capacity(pairs.len());
        for &(index, value) in pairs {
            nodes.push(SvmNode { index, value });
        }
        nodes
    }

    /// RBF parameters shared by every test below.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);

        // RBF gives K(x, x) = 1, hence QD[i] = y[i] * y[i] * 1 = 1.
        assert!(q.get_qd().iter().all(|&d| (d - 1.0).abs() < 1e-15));
    }

    #[test]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // Materialise the whole matrix, one copied row at a time.
        let matrix: Vec<Vec<f32>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        // Every entry must match its mirror image.
        for (i, row) in matrix.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                let mirrored = matrix[j][i];
                assert!(
                    (value - mirrored).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i, j, value, j, i, mirrored
                );
            }
        }

        // y[0] * y[1] = -1, so Q[0][1] is negative.
        assert!(matrix[0][1] < 0.0);
        // y[0] * y[2] = +1, so Q[0][2] is positive.
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        // RBF kernel values are strictly positive and no label sign is applied.
        assert!(row.iter().all(|&v| v > 0.0));
        // The diagonal entry is K(x, x) = 1.
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len(); // 4

        // Two consecutive calls land in different buffers.
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        // Row 0: sign[0]=+1, index[0]=0 → sign[0] * sign[j] * K(0, index[j])
        // Row 1: sign[1]=+1, index[1]=1 → sign[1] * sign[j] * K(1, index[j])
        // Neither row may come back all zero.
        for row in [&row0, &row1].iter() {
            assert!(row.iter().any(|&v| v != 0.0));
        }

        // Variable 0 has sign +1, variable 2 has sign -1 and real index 0:
        // Q[0][2] = sign[0] * sign[2] * K(0, 0) = 1 * (-1) * 1 = -1
        assert!((row0[2] + 1.0).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }
}