1use super::{FittedSVMDataPoint, Kernel, RowVector, SVMParams};
4
5impl<const D: usize, K> SVMParams<D, K>
6where
7 K: Kernel<D>,
8{
9 #[allow(clippy::similar_names, reason = "necessary for clarity")]
11 #[allow(clippy::too_many_lines, reason = "SMO algorithm implementation")]
12 pub fn fit_binary<I>(self, data: I) -> BinarySVM<D, K>
13 where
14 I: IntoIterator<Item = (RowVector<D>, bool)>,
15 {
16 let mut data_points: Vec<_> = data
18 .into_iter()
19 .map(|(x, y)| FittedSVMDataPoint { x, y, alpha: 0.0 })
20 .collect();
21
22 let mut bias = 0.0;
23 let kernel = self.kernel;
24 let n = data_points.len();
25
26 if n == 0 {
27 return BinarySVM {
28 kernel,
29 bias,
30 data_points,
31 };
32 }
33
34 for _ in 0..self.max_iter {
36 let mut num_changed_alphas = 0;
37
38 for i in 0..n {
39 let mut f_xi = bias;
41 for j in 0..n {
42 if data_points[j].alpha > 0.0 {
43 let y_j = if data_points[j].y { 1.0 } else { -1.0 };
44 f_xi += data_points[j].alpha
45 * y_j
46 * kernel.compute(&data_points[j].x, &data_points[i].x);
47 }
48 }
49 let y_i = if data_points[i].y { 1.0 } else { -1.0 };
50 let e_i = f_xi - y_i;
51
52 let r_i = e_i * y_i;
54 if (r_i < -self.tol && data_points[i].alpha < self.c)
55 || (r_i > self.tol && data_points[i].alpha > 0.0)
56 {
57 let mut j = i;
59 while j == i {
60 j = (j + 1) % n;
61 }
62
63 let mut f_x_j = bias;
65 for k in 0..n {
66 if data_points[k].alpha > 0.0 {
67 let y_k = if data_points[k].y { 1.0 } else { -1.0 };
68 f_x_j += data_points[k].alpha
69 * y_k
70 * kernel.compute(&data_points[k].x, &data_points[j].x);
71 }
72 }
73 let y_j = if data_points[j].y { 1.0 } else { -1.0 };
74 let e_j = f_x_j - y_j;
75
76 let alpha_i_old = data_points[i].alpha;
78 let alpha_j_old = data_points[j].alpha;
79
80 let (l, h) = if data_points[i].y == data_points[j].y {
82 let l = f32::max(0.0, alpha_i_old + alpha_j_old - self.c);
83 let h = f32::min(self.c, alpha_i_old + alpha_j_old);
84 (l, h)
85 } else {
86 let l = f32::max(0.0, alpha_j_old - alpha_i_old);
87 let h = f32::min(self.c, self.c + alpha_j_old - alpha_i_old);
88 (l, h)
89 };
90 if (h - l).abs() < 1e-8 {
91 continue;
92 }
93
94 let k_ii = kernel.compute(&data_points[i].x, &data_points[i].x);
96 let k_jj = kernel.compute(&data_points[j].x, &data_points[j].x);
97 let k_ij = kernel.compute(&data_points[i].x, &data_points[j].x);
98 let eta = 2.0f32.mul_add(k_ij, -k_ii) - k_jj;
99 if eta >= 0.0 {
100 continue;
101 }
102
103 let alpha_j_new = alpha_j_old - y_j * (e_i - e_j) / eta;
105 let alpha_j_new = f32::min(h, f32::max(l, alpha_j_new));
106 if (alpha_j_new - alpha_j_old).abs() < 1e-5 {
107 continue;
108 }
109
110 let alpha_i_new = (y_i * y_j).mul_add(alpha_j_old - alpha_j_new, alpha_i_old);
112
113 let b1 = (y_j * (alpha_j_new - alpha_j_old)).mul_add(
114 -k_ij,
115 (y_i * (alpha_i_new - alpha_i_old)).mul_add(-k_ii, bias - e_i),
116 );
117
118 let b2 = (y_j * (alpha_j_new - alpha_j_old)).mul_add(
119 -k_jj,
120 (y_i * (alpha_i_new - alpha_i_old)).mul_add(-k_ij, bias - e_j),
121 );
122
123 bias = if alpha_i_new > 0.0 && alpha_i_new < self.c {
124 b1
125 } else if alpha_j_new > 0.0 && alpha_j_new < self.c {
126 b2
127 } else {
128 f32::midpoint(b1, b2)
129 };
130
131 data_points[i].alpha = alpha_i_new;
133 data_points[j].alpha = alpha_j_new;
134 num_changed_alphas += 1;
135 }
136 }
137
138 if num_changed_alphas == 0 {
139 break;
140 }
141 }
142
143 BinarySVM {
144 kernel,
145 bias,
146 data_points,
147 }
148 }
149}
150
151#[derive(Debug, Clone, PartialEq)]
158pub struct BinarySVM<const D: usize, K>
159where
160 K: Kernel<D>,
161{
162 kernel: K,
164 bias: f32,
166 data_points: Vec<FittedSVMDataPoint<D>>,
168}
169
170impl<const D: usize, K> BinarySVM<D, K>
171where
172 K: Kernel<D>,
173{
174 #[must_use]
176 pub fn params() -> SVMParams<D, K> {
177 SVMParams::new()
178 }
179
180 pub fn predict(&self, x: &RowVector<D>) -> bool {
182 self.decision_function(x) >= 0.0
183 }
184
185 pub fn decision_function(&self, x: &RowVector<D>) -> f32 {
189 let mut decision_value = self.bias;
191
192 for point in &self.data_points {
193 if point.alpha > 0.0 {
194 let y = if point.y { 1.0 } else { -1.0 };
195 decision_value += point.alpha * y * self.kernel.compute(&point.x, x);
196 }
197 }
198
199 decision_value
200 }
201
202 pub fn support_vectors(&self) -> impl Iterator<Item = &FittedSVMDataPoint<D>> {
204 self.data_points.iter().filter(|p| p.alpha > 0.0)
205 }
206
207 pub fn num_support_vectors(&self) -> usize {
211 self.support_vectors().count()
212 }
213
214 pub const fn bias(&self) -> f32 {
216 self.bias
217 }
218}