1use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14pub trait QMatrix {
20 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23 fn get_qd(&self) -> &[f64];
25
26 fn swap_index(&mut self, i: usize, j: usize);
28}
29
30pub struct SvcQ<'a> {
36 kernel: Kernel<'a>,
37 cache: Cache,
38 y: Vec<i8>,
39 qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44 let l = x.len();
45 let kernel = Kernel::new(x, param);
46 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48 let y = y.to_vec();
49 Self { kernel, cache, y, qd }
50 }
51}
52
53#[allow(clippy::needless_range_loop)]
54impl<'a> QMatrix for SvcQ<'a> {
55 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
56 let (data, start) = self.cache.get_data(i, len);
57 if start < len {
58 let yi = self.y[i] as f64;
59 for j in start..len {
60 let kval = self.kernel.evaluate(i, j);
61 data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
62 }
63 }
64 &data[..len]
65 }
66
67 fn get_qd(&self) -> &[f64] {
68 &self.qd
69 }
70
71 fn swap_index(&mut self, i: usize, j: usize) {
72 self.cache.swap_index(i, j);
73 self.kernel.swap_index(i, j);
74 self.y.swap(i, j);
75 self.qd.swap(i, j);
76 }
77}
78
79pub struct OneClassQ<'a> {
85 kernel: Kernel<'a>,
86 cache: Cache,
87 qd: Vec<f64>,
88}
89
90impl<'a> OneClassQ<'a> {
91 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
92 let l = x.len();
93 let kernel = Kernel::new(x, param);
94 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
95 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
96 Self { kernel, cache, qd }
97 }
98}
99
100#[allow(clippy::needless_range_loop)]
101impl<'a> QMatrix for OneClassQ<'a> {
102 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
103 let (data, start) = self.cache.get_data(i, len);
104 if start < len {
105 for j in start..len {
106 data[j] = self.kernel.evaluate(i, j) as Qfloat;
107 }
108 }
109 &data[..len]
110 }
111
112 fn get_qd(&self) -> &[f64] {
113 &self.qd
114 }
115
116 fn swap_index(&mut self, i: usize, j: usize) {
117 self.cache.swap_index(i, j);
118 self.kernel.swap_index(i, j);
119 self.qd.swap(i, j);
120 }
121}
122
123pub struct SvrQ<'a> {
131 kernel: Kernel<'a>,
132 cache: Cache,
133 l: usize,
135 sign: Vec<i8>,
137 index: Vec<usize>,
139 qd: Vec<f64>,
141 buffer: [Vec<Qfloat>; 2],
143 next_buffer: usize,
145}
146
147impl<'a> SvrQ<'a> {
148 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
149 let l = x.len();
150 let kernel = Kernel::new(x, param);
151 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
152
153 let mut sign = vec![0i8; 2 * l];
154 let mut index = vec![0usize; 2 * l];
155 let mut qd = vec![0.0f64; 2 * l];
156
157 for k in 0..l {
158 sign[k] = 1;
159 sign[k + l] = -1;
160 index[k] = k;
161 index[k + l] = k;
162 let kk = kernel.evaluate(k, k);
163 qd[k] = kk;
164 qd[k + l] = kk;
165 }
166
167 let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
168
169 Self {
170 kernel,
171 cache,
172 l,
173 sign,
174 index,
175 qd,
176 buffer,
177 next_buffer: 0,
178 }
179 }
180}
181
182#[allow(clippy::needless_range_loop)]
183impl<'a> QMatrix for SvrQ<'a> {
184 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
185 let real_i = self.index[i];
186 let l = self.l;
187
188 let (data, start) = self.cache.get_data(real_i, l);
190 if start < l {
191 for j in start..l {
192 data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
193 }
194 }
195
196 let buf_idx = self.next_buffer;
198 self.next_buffer = 1 - self.next_buffer;
199 let si = self.sign[i] as f32;
200 let buf = &mut self.buffer[buf_idx];
201 for j in 0..len {
202 buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
203 }
204 &self.buffer[buf_idx][..len]
205 }
206
207 fn get_qd(&self) -> &[f64] {
208 &self.qd
209 }
210
211 fn swap_index(&mut self, i: usize, j: usize) {
212 self.sign.swap(i, j);
213 self.index.swap(i, j);
214 self.qd.swap(i, j);
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs.iter().map(|&(i, v)| SvmNode { index: i, value: v }).collect()
    }

    /// RBF kernel with `gamma = 0.5` and `cache_size = 1.0`.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // The RBF kernel gives K(x, x) = exp(0) = 1, independent of the label.
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    // Index loops are needed to compare each entry with its transpose.
    #[allow(clippy::needless_range_loop)]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        let matrix: Vec<Vec<Qfloat>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        for i in 0..l {
            for j in 0..l {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i, j, matrix[i][j], j, i, matrix[j][i]
                );
            }
        }

        // Opposite labels flip the sign; equal labels keep it positive.
        assert!(matrix[0][1] < 0.0);
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        // Without labels every RBF entry is strictly positive.
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // Variable 2 is the `-1` copy of training vector 0: Q[0][2] = -K(x0, x0) = -1.
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
        // Likewise Q[1][3] = -K(x1, x1) = -1.
        assert!((row1[3] - (-1.0)).abs() < 1e-6, "Q[1][3] = {}", row1[3]);
    }

    #[test]
    fn svr_q_diagonal_is_duplicated() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let q = SvrQ::new(&data, &param);
        let qd = q.get_qd();
        assert_eq!(qd.len(), 2 * data.len());
        for &d in qd {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    fn svr_q_swap_index_permutes_variables() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
        ];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        // Before: sign = [+1, +1, -1, -1], index = [0, 1, 0, 1].
        // After:  sign = [-1, +1, -1, +1], index = [1, 1, 0, 0].
        q.swap_index(0, 3);
        let row = q.get_q(0, 4).to_vec();
        assert!((row[0] - 1.0).abs() < 1e-6, "Q[0][0] = {}", row[0]);
        assert!((row[1] + 1.0).abs() < 1e-6, "Q[0][1] = {}", row[1]);
        assert!(row[2] > 0.0 && row[3] < 0.0);
        assert!((row[2] + row[3]).abs() < 1e-6);
    }
}