1use crate::cache::{Cache, Qfloat};
11use crate::kernel::Kernel;
12use crate::types::{SvmNode, SvmParameter};
13
14pub trait QMatrix {
20 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat];
22
23 fn get_qd(&self) -> &[f64];
25
26 fn swap_index(&mut self, i: usize, j: usize);
28}
29
30pub struct SvcQ<'a> {
36 kernel: Kernel<'a>,
37 cache: Cache,
38 y: Vec<i8>,
39 qd: Vec<f64>,
40}
41
42impl<'a> SvcQ<'a> {
43 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter, y: &[i8]) -> Self {
44 let l = x.len();
45 let kernel = Kernel::new(x, param);
46 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
47 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
48 let y = y.to_vec();
49 Self {
50 kernel,
51 cache,
52 y,
53 qd,
54 }
55 }
56}
57
58#[allow(clippy::needless_range_loop)]
59impl<'a> QMatrix for SvcQ<'a> {
60 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
61 let (data, start) = self.cache.get_data(i, len);
62 if start < len {
63 let yi = self.y[i] as f64;
64 for j in start..len {
65 let kval = self.kernel.evaluate(i, j);
66 data[j] = (yi * self.y[j] as f64 * kval) as Qfloat;
67 }
68 }
69 &data[..len]
70 }
71
72 fn get_qd(&self) -> &[f64] {
73 &self.qd
74 }
75
76 fn swap_index(&mut self, i: usize, j: usize) {
77 self.cache.swap_index(i, j);
78 self.kernel.swap_index(i, j);
79 self.y.swap(i, j);
80 self.qd.swap(i, j);
81 }
82}
83
84pub struct OneClassQ<'a> {
90 kernel: Kernel<'a>,
91 cache: Cache,
92 qd: Vec<f64>,
93}
94
95impl<'a> OneClassQ<'a> {
96 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
97 let l = x.len();
98 let kernel = Kernel::new(x, param);
99 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
100 let qd: Vec<f64> = (0..l).map(|i| kernel.evaluate(i, i)).collect();
101 Self { kernel, cache, qd }
102 }
103}
104
105#[allow(clippy::needless_range_loop)]
106impl<'a> QMatrix for OneClassQ<'a> {
107 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
108 let (data, start) = self.cache.get_data(i, len);
109 if start < len {
110 for j in start..len {
111 data[j] = self.kernel.evaluate(i, j) as Qfloat;
112 }
113 }
114 &data[..len]
115 }
116
117 fn get_qd(&self) -> &[f64] {
118 &self.qd
119 }
120
121 fn swap_index(&mut self, i: usize, j: usize) {
122 self.cache.swap_index(i, j);
123 self.kernel.swap_index(i, j);
124 self.qd.swap(i, j);
125 }
126}
127
128pub struct SvrQ<'a> {
136 kernel: Kernel<'a>,
137 cache: Cache,
138 l: usize,
140 sign: Vec<i8>,
142 index: Vec<usize>,
144 qd: Vec<f64>,
146 buffer: [Vec<Qfloat>; 2],
148 next_buffer: usize,
150}
151
152impl<'a> SvrQ<'a> {
153 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
154 let l = x.len();
155 let kernel = Kernel::new(x, param);
156 let cache = Cache::new(l, (param.cache_size * 1048576.0) as usize);
157
158 let mut sign = vec![0i8; 2 * l];
159 let mut index = vec![0usize; 2 * l];
160 let mut qd = vec![0.0f64; 2 * l];
161
162 for k in 0..l {
163 sign[k] = 1;
164 sign[k + l] = -1;
165 index[k] = k;
166 index[k + l] = k;
167 let kk = kernel.evaluate(k, k);
168 qd[k] = kk;
169 qd[k + l] = kk;
170 }
171
172 let buffer = [vec![0.0 as Qfloat; 2 * l], vec![0.0 as Qfloat; 2 * l]];
173
174 Self {
175 kernel,
176 cache,
177 l,
178 sign,
179 index,
180 qd,
181 buffer,
182 next_buffer: 0,
183 }
184 }
185}
186
187#[allow(clippy::needless_range_loop)]
188impl<'a> QMatrix for SvrQ<'a> {
189 fn get_q(&mut self, i: usize, len: usize) -> &[Qfloat] {
190 let real_i = self.index[i];
191 let l = self.l;
192
193 let (data, start) = self.cache.get_data(real_i, l);
195 if start < l {
196 for j in start..l {
197 data[j] = self.kernel.evaluate(real_i, j) as Qfloat;
198 }
199 }
200
201 let buf_idx = self.next_buffer;
203 self.next_buffer = 1 - self.next_buffer;
204 let si = self.sign[i] as f32;
205 let buf = &mut self.buffer[buf_idx];
206 for j in 0..len {
207 buf[j] = si * (self.sign[j] as f32) * data[self.index[j]];
208 }
209 &self.buffer[buf_idx][..len]
210 }
211
212 fn get_qd(&self) -> &[f64] {
213 &self.qd
214 }
215
216 fn swap_index(&mut self, i: usize, j: usize) {
217 self.sign.swap(i, j);
218 self.index.swap(i, j);
219 self.qd.swap(i, j);
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KernelType, SvmNode, SvmParameter};

    /// Builds a sparse feature vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(i, v)| SvmNode { index: i, value: v })
            .collect()
    }

    /// RBF kernel with `gamma = 0.5` and `cache_size = 1.0`; other fields are defaulted.
    fn default_rbf_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            cache_size: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn svc_q_diagonal_equals_one_for_rbf() {
        let data = vec![
            make_nodes(&[(1, 1.0), (2, 0.0)]),
            make_nodes(&[(1, 0.0), (2, 1.0)]),
        ];
        let y = vec![1i8, -1i8];
        let param = default_rbf_param();
        let q = SvcQ::new(&data, &param, &y);
        // An RBF kernel gives K(x, x) = exp(0) = 1 for every x.
        for &d in q.get_qd() {
            assert!((d - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    fn svc_q_symmetry_and_sign() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 2.0)]),
            make_nodes(&[(1, 3.0)]),
        ];
        let y = vec![1i8, -1i8, 1i8];
        let param = default_rbf_param();
        let mut q = SvcQ::new(&data, &param, &y);
        let l = data.len();

        // Materialise the full l x l matrix one row at a time.
        let matrix: Vec<Vec<_>> = (0..l).map(|i| q.get_q(i, l).to_vec()).collect();

        for (i, row) in matrix.iter().enumerate() {
            for (j, &q_ij) in row.iter().enumerate() {
                let q_ji = matrix[j][i];
                assert!(
                    (q_ij - q_ji).abs() < 1e-6,
                    "Q[{},{}]={} != Q[{},{}]={}",
                    i,
                    j,
                    q_ij,
                    j,
                    i,
                    q_ji
                );
            }
        }

        // Opposite labels give a negative entry, equal labels a positive one.
        assert!(matrix[0][1] < 0.0);
        assert!(matrix[0][2] > 0.0);
    }

    #[test]
    fn one_class_q_no_sign_scaling() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = OneClassQ::new(&data, &param);

        let row = q.get_q(0, 2);
        // Without label scaling every RBF entry is positive and the diagonal is 1.
        assert!(row[0] > 0.0);
        assert!(row[1] > 0.0);
        assert!((row[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn svr_q_double_buffer() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();
        let row0 = q.get_q(0, l2).to_vec();
        let row1 = q.get_q(1, l2).to_vec();

        assert!(row0.iter().any(|&v| v != 0.0));
        assert!(row1.iter().any(|&v| v != 0.0));

        // Index 2 is the negated copy of vector 0: Q[0][2] = (+1)(-1)K(x0, x0) = -1.
        assert!((row0[2] - (-1.0)).abs() < 1e-6, "Q[0][2] = {}", row0[2]);
    }

    #[test]
    fn svr_q_swap_index_permutes_rows_and_columns() {
        let data = vec![make_nodes(&[(1, 1.0)]), make_nodes(&[(1, 2.0)])];
        let param = default_rbf_param();
        let mut q = SvrQ::new(&data, &param);
        let l2 = 2 * data.len();

        let before = q.get_q(0, l2).to_vec();
        q.swap_index(0, 2);
        // Old index 0 now lives at 2, and columns 0 and 2 are exchanged.
        let after = q.get_q(2, l2).to_vec();
        let permuted = [2usize, 1, 0, 3];
        for (j, &src) in permuted.iter().enumerate() {
            assert!(
                (after[j] - before[src]).abs() < 1e-6,
                "after[{}]={} != before[{}]={}",
                j,
                after[j],
                src,
                before[src]
            );
        }
    }
}