1use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14pub trait QMatrix {
20 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23 fn get_qd(&self) -> &[f64];
25
26 fn swap_index(&mut self, i: usize, j: usize);
28}
29
30pub struct SvcQ<'a> {
36 kernel: Kernel<'a>,
37 cache: Cache,
38 y: Vec<i8>,
39 qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44 let l = x.len();
45 let kernel = Kernel::new(x, param);
46 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
52 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
53 let y = y.to_vec();
54 Self {
55 kernel,
56 cache,
57 y,
58 qd,
59 }
60 }
61}
62
63#[allow(clippy::needless_range_loop)]
64impl<'a> QMatrix for SvcQ<'a> {
65 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
66 let (data, start) = self.cache.get_data(i, len);
67 if start < len {
68 let yi = self.y[i] as f64;
69 for j in start..len {
70 let kval = self.kernel.evaluate(i, j);
71 data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
72 }
73 }
74 &data[..len]
75 }
76
77 fn get_qd(&self) -> &[f64] {
78 &self.qd
79 }
80
81 fn swap_index(&mut self, i: usize, j: usize) {
82 self.cache.swap_index(i, j);
83 self.kernel.swap_index(i, j);
84 self.y.swap(i, j);
85 self.qd.swap(i, j);
86 }
87}
88
89pub struct OneClassQ<'a> {
95 kernel: Kernel<'a>,
96 cache: Cache,
97 qd: Vec<f64>,
98}
99
100impl<'a> OneClassQ<'a> {
101 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
102 let l = x.len();
103 let kernel = Kernel::new(x, param);
104 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
106 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
107 Self { kernel, cache, qd }
108 }
109}
110
111#[allow(clippy::needless_range_loop)]
112impl<'a> QMatrix for OneClassQ<'a> {
113 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
114 let (data, start) = self.cache.get_data(i, len);
115 if start < len {
116 for j in start..len {
117 data[j] = self.kernel.evaluate(i, j) as Qfloat;
118 }
119 }
120 &data[..len]
121 }
122
123 fn get_qd(&self) -> &[f64] {
124 &self.qd
125 }
126
127 fn swap_index(&mut self, i: usize, j: usize) {
128 self.cache.swap_index(i, j);
129 self.kernel.swap_index(i, j);
130 self.qd.swap(i, j);
131 }
132}
133
134pub struct SvrQ<'a> {
142 kernel: Kernel<'a>,
143 cache: Cache,
144 l: usize,
146 sign: Vec<i8>,
148 index: Vec<usize>,
150 qd: Vec<f64>,
152 buffer: [Vec<Qfloat>; 2],
154 next_buffer: usize,
156}
157
158impl<'a> SvrQ<'a> {
159 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
160 let l = x.len();
161 let kernel = Kernel::new(x, param);
162 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
164
165 let mut sign = vec![0i8; 2 * l];
166 let mut index = vec![0usize; 2 * l];
167 let mut qd = vec![0.0f64; 2 * l];
168
169 for k in 0..l {
170 sign[k] = 1;
171 sign[k + l] = -1;
172 index[k] = k;
173 index[k + l] = k;
174 let kk = kernel.evaluate(k, k);
175 qd[k] = kk;
176 qd[k + l] = kk;
177 }
178
179 let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
180
181 Self {
182 kernel,
183 cache,
184 l,
185 sign,
186 index,
187 qd,
188 buffer,
189 next_buffer: 0,
190 }
191 }
192}
193
194#[allow(clippy::needless_range_loop)]
195impl<'a> QMatrix for SvrQ<'a> {
196 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
197 let real_i = self.index[i];
198 let l = self.l;
199
200 let (data, start) = self.cache.get_data(real_i, l);
202 if start < l {
203 for j in start..l {
204 data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
205 }
206 }
207
208 let buf_idx = self.next_buffer;
210 self.next_buffer = 1 - self.next_buffer;
211 let si = self.sign[i] as f32;
212 let buf = &mut self.buffer[buf_idx];
213 for j in 0..len {
214 buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
215 }
216 &self.buffer[buf_idx][..len]
217 }
218
219 fn get_qd(&self) -> &[f64] {
220 &self.qd
221 }
222
223 fn swap_index(&mut self, i: usize, j: usize) {
224 self.sign.swap(i, j);
225 self.index.swap(i, j);
226 self.qd.swap(i, j);
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(i, v)| SvmNode { index: i, value: v })
            .collect()
    }

    /// RBF kernel with `gamma = 0.5` and `cache_size = 1.0`; other fields default.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // RBF: K(x, x) = exp(-gamma * 0) = 1.
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    // Indices are needed to compare Q[i][j] against its transpose Q[j][i].
    #[allow(clippy::needless_range_loop)]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // Each row is copied out because the slice returned by `get_q`
        // borrows `q` mutably.
        let matrix: Vec<Vec<Qfloat>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        for i in 0..l {
            for j in 0..l {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i,
                    j,
                    matrix[i][j],
                    j,
                    i,
                    matrix[j][i]
                );
            }
        }

        // Opposite labels give a negative entry, equal labels a positive one.
        assert!(matrix[0][1] < 0.0);
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();

        // Copy each row before the next call: the returned slice borrows `q`.
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // Variable 2 is the negated copy of vector 0: -K(x0, x0) = -1.
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }

    #[test]
    fn svr_q_swap_index_permutes_rows_and_columns() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();

        let before = q.get_q(2, l2).to_vec();
        q.swap_index(0, 2);
        let after = q.get_q(0, l2).to_vec();

        // After swapping variables 0 and 2, the new row 0 is the old row 2
        // with columns 0 and 2 exchanged.
        let mut expected = before.clone();
        expected.swap(0, 2);
        for (a, e) in after.iter().zip(&expected) {
            assert!((a - e).abs() < 1e-6, "{:?} != {:?}", after, expected);
        }
    }
}