1use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14pub trait QMatrix {
20 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23 fn get_qd(&self) -> &[f64];
25
26 fn swap_index(&mut self, i: usize, j: usize);
28}
29
30pub struct SvcQ<'a> {
36 kernel: Kernel<'a>,
37 cache: Cache,
38 y: Vec<i8>,
39 qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44 let l = x.len();
45 let kernel = Kernel::new(x, param);
46 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48 let y = y.to_vec();
49 Self { kernel, cache, y, qd }
50 }
51}
52
53impl<'a> QMatrix for SvcQ<'a> {
54 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
55 let (data, start) = self.cache.get_data(i, len);
56 if start < len {
57 let yi = self.y[i] as f64;
58 for j in start..len {
59 let kval = self.kernel.evaluate(i, j);
60 data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
61 }
62 }
63 &data[..len]
64 }
65
66 fn get_qd(&self) -> &[f64] {
67 &self.qd
68 }
69
70 fn swap_index(&mut self, i: usize, j: usize) {
71 self.cache.swap_index(i, j);
72 self.kernel.swap_index(i, j);
73 self.y.swap(i, j);
74 self.qd.swap(i, j);
75 }
76}
77
78pub struct OneClassQ<'a> {
84 kernel: Kernel<'a>,
85 cache: Cache,
86 qd: Vec<f64>,
87}
88
89impl<'a> OneClassQ<'a> {
90 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
91 let l = x.len();
92 let kernel = Kernel::new(x, param);
93 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
94 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
95 Self { kernel, cache, qd }
96 }
97}
98
99impl<'a> QMatrix for OneClassQ<'a> {
100 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
101 let (data, start) = self.cache.get_data(i, len);
102 if start < len {
103 for j in start..len {
104 data[j] = self.kernel.evaluate(i, j) as Qfloat;
105 }
106 }
107 &data[..len]
108 }
109
110 fn get_qd(&self) -> &[f64] {
111 &self.qd
112 }
113
114 fn swap_index(&mut self, i: usize, j: usize) {
115 self.cache.swap_index(i, j);
116 self.kernel.swap_index(i, j);
117 self.qd.swap(i, j);
118 }
119}
120
121pub struct SvrQ<'a> {
129 kernel: Kernel<'a>,
130 cache: Cache,
131 l: usize,
133 sign: Vec<i8>,
135 index: Vec<usize>,
137 qd: Vec<f64>,
139 buffer: [Vec<Qfloat>; 2],
141 next_buffer: usize,
143}
144
145impl<'a> SvrQ<'a> {
146 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
147 let l = x.len();
148 let kernel = Kernel::new(x, param);
149 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
150
151 let mut sign = vec![0i8; 2 * l];
152 let mut index = vec![0usize; 2 * l];
153 let mut qd = vec![0.0f64; 2 * l];
154
155 for k in 0..l {
156 sign[k] = 1;
157 sign[k + l] = -1;
158 index[k] = k;
159 index[k + l] = k;
160 let kk = kernel.evaluate(k, k);
161 qd[k] = kk;
162 qd[k + l] = kk;
163 }
164
165 let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
166
167 Self {
168 kernel,
169 cache,
170 l,
171 sign,
172 index,
173 qd,
174 buffer,
175 next_buffer: 0,
176 }
177 }
178}
179
180impl<'a> QMatrix for SvrQ<'a> {
181 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
182 let real_i = self.index[i];
183 let l = self.l;
184
185 let (data, start) = self.cache.get_data(real_i, l);
187 if start < l {
188 for j in start..l {
189 data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
190 }
191 }
192
193 let buf_idx = self.next_buffer;
195 self.next_buffer = 1 - self.next_buffer;
196 let si = self.sign[i] as f32;
197 let buf = &mut self.buffer[buf_idx];
198 for j in 0..len {
199 buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
200 }
201 &self.buffer[buf_idx][..len]
202 }
203
204 fn get_qd(&self) -> &[f64] {
205 &self.qd
206 }
207
208 fn swap_index(&mut self, i: usize, j: usize) {
209 self.sign.swap(i, j);
210 self.index.swap(i, j);
211 self.qd.swap(i, j);
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    /// RBF kernel with `gamma = 0.5` and `cache_size = 1.0`.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // RBF gives K(x, x) = exp(0) = 1, independent of the label.
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // `get_q` borrows `q` mutably, so each row is copied out before the next call.
        let matrix: Vec<Vec<Qfloat>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        for i in 0..l {
            for j in 0..l {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i,
                    j,
                    matrix[i][j],
                    j,
                    i,
                    matrix[j][i]
                );
            }
        }

        // Opposite labels give a negative entry, equal labels a positive one.
        assert!(matrix[0][1] < 0.0);
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // Variable 2 is sample 0 with the opposite sign: Q[0][2] = -K(x0, x0) = -1.
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }

    #[test]
    fn svr_q_swap_index_moves_signs() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);

        // Move the negative copy of sample 0 to position 0:
        // sign = [-1, 1, 1, -1], index = [0, 1, 0, 1].
        q.swap_index(0, 2);
        let row = q.get_q(0, 4).to_vec();

        assert!((row[0] - 1.0).abs() < 1e-6, "Q[0][0] = {}", row[0]);
        assert!(row[1] < 0.0, "Q[0][1] = {}", row[1]);
        assert!((row[2] + 1.0).abs() < 1e-6, "Q[0][2] = {}", row[2]);
        assert!(row[3] > 0.0, "Q[0][3] = {}", row[3]);
    }
}