1use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14pub trait QMatrix {
20 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23 fn get_qd(&self) -> &[f64];
25
26 fn swap_index(&mut self, i: usize, j: usize);
28}
29
30pub struct SvcQ<'a> {
36 kernel: Kernel<'a>,
37 cache: Cache,
38 y: Vec<i8>,
39 qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
45 let l = x.len();
46 let kernel = Kernel::new(x, param);
47 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
53 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
54 let y = y.to_vec();
55 Self {
56 kernel,
57 cache,
58 y,
59 qd,
60 }
61 }
62}
63
64#[allow(clippy::needless_range_loop)]
65impl<'a> QMatrix for SvcQ<'a> {
66 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
67 let (data, start) = self.cache.get_data(i, len);
68 if start < len {
69 let yi = self.y[i] as f64;
70 for j in start..len {
71 let kval = self.kernel.evaluate(i, j);
72 data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
73 }
74 }
75 &data[..len]
76 }
77
78 fn get_qd(&self) -> &[f64] {
79 &self.qd
80 }
81
82 fn swap_index(&mut self, i: usize, j: usize) {
83 self.cache.swap_index(i, j);
84 self.kernel.swap_index(i, j);
85 self.y.swap(i, j);
86 self.qd.swap(i, j);
87 }
88}
89
90pub struct OneClassQ<'a> {
96 kernel: Kernel<'a>,
97 cache: Cache,
98 qd: Vec<f64>,
99}
100
101impl<'a> OneClassQ<'a> {
102 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
104 let l = x.len();
105 let kernel = Kernel::new(x, param);
106 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
108 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
109 Self { kernel, cache, qd }
110 }
111}
112
113#[allow(clippy::needless_range_loop)]
114impl<'a> QMatrix for OneClassQ<'a> {
115 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
116 let (data, start) = self.cache.get_data(i, len);
117 if start < len {
118 for j in start..len {
119 data[j] = self.kernel.evaluate(i, j) as Qfloat;
120 }
121 }
122 &data[..len]
123 }
124
125 fn get_qd(&self) -> &[f64] {
126 &self.qd
127 }
128
129 fn swap_index(&mut self, i: usize, j: usize) {
130 self.cache.swap_index(i, j);
131 self.kernel.swap_index(i, j);
132 self.qd.swap(i, j);
133 }
134}
135
136pub struct SvrQ<'a> {
144 kernel: Kernel<'a>,
145 cache: Cache,
146 l: usize,
148 sign: Vec<i8>,
150 index: Vec<usize>,
152 qd: Vec<f64>,
154 buffer: [Vec<Qfloat>; 2],
156 next_buffer: usize,
158}
159
160impl<'a> SvrQ<'a> {
161 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
163 let l = x.len();
164 let kernel = Kernel::new(x, param);
165 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
167
168 let mut sign = vec![0i8; 2 * l];
169 let mut index = vec![0usize; 2 * l];
170 let mut qd = vec![0.0f64; 2 * l];
171
172 for k in 0..l {
173 sign[k] = 1;
174 sign[k + l] = -1;
175 index[k] = k;
176 index[k + l] = k;
177 let kk = kernel.evaluate(k, k);
178 qd[k] = kk;
179 qd[k + l] = kk;
180 }
181
182 let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
183
184 Self {
185 kernel,
186 cache,
187 l,
188 sign,
189 index,
190 qd,
191 buffer,
192 next_buffer: 0,
193 }
194 }
195}
196
197#[allow(clippy::needless_range_loop)]
198impl<'a> QMatrix for SvrQ<'a> {
199 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
200 let real_i = self.index[i];
201 let l = self.l;
202
203 let (data, start) = self.cache.get_data(real_i, l);
205 if start < l {
206 for j in start..l {
207 data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
208 }
209 }
210
211 let buf_idx = self.next_buffer;
213 self.next_buffer = 1 - self.next_buffer;
214 let si = self.sign[i] as f32;
215 let buf = &mut self.buffer[buf_idx];
216 for j in 0..len {
217 buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
218 }
219 &self.buffer[buf_idx][..len]
220 }
221
222 fn get_qd(&self) -> &[f64] {
223 &self.qd
224 }
225
226 fn swap_index(&mut self, i: usize, j: usize) {
227 self.sign.swap(i, j);
228 self.index.swap(i, j);
229 self.qd.swap(i, j);
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Builds a feature vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(i, v)| SvmNode { index: i, value: v })
            .collect()
    }

    /// RBF parameters shared by all tests (`gamma = 0.5`, `cache_size = 1.0`).
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // An RBF kernel gives K(x, x) = 1, so every diagonal entry should be 1.
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    #[allow(clippy::needless_range_loop)]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // Materialize the full matrix row by row.
        let mut matrix = vec![vec![0.0f32; l]; l];
        for i in 0..l {
            let row = q.get_q(i, l).to_vec();
            matrix[i][..l].copy_from_slice(&row[..l]);
        }

        // Q must be symmetric.
        for i in 0..l {
            for j in 0..l {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i,
                    j,
                    matrix[i][j],
                    j,
                    i,
                    matrix[j][i]
                );
            }
        }

        // Opposite labels give a negative entry, equal labels a positive one.
        assert!(matrix[0][1] < 0.0);
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        // Raw kernel values: all positive, with K(0, 0) = 1.
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();

        // Two consecutive calls use the two scratch rows in turn.
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // Position 2 (= 0 + l) is sample 0 with sign -1:
        // Q[0][2] = (+1) * (-1) * K(0, 0) = -1.
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }
}