pr_ml/svm/
binary.rs

1//! Binary classification support vector machine.
2
3use super::{FittedSVMDataPoint, Kernel, RowVector, SVMParams};
4
5impl<const D: usize, K> SVMParams<D, K>
6where
7    K: Kernel<D>,
8{
9    /// Fits a [`BinarySVM`] to the given data.
10    #[allow(clippy::similar_names, reason = "necessary for clarity")]
11    #[allow(clippy::too_many_lines, reason = "SMO algorithm implementation")]
12    pub fn fit_binary<I>(self, data: I) -> BinarySVM<D, K>
13    where
14        I: IntoIterator<Item = (RowVector<D>, bool)>,
15    {
16        // Collect data into vectors
17        let mut data_points: Vec<_> = data
18            .into_iter()
19            .map(|(x, y)| FittedSVMDataPoint { x, y, alpha: 0.0 })
20            .collect();
21
22        let mut bias = 0.0;
23        let kernel = self.kernel;
24        let n = data_points.len();
25
26        if n == 0 {
27            return BinarySVM {
28                kernel,
29                bias,
30                data_points,
31            };
32        }
33
34        // SMO algorithm - no upfront kernel cache to save memory
35        for _ in 0..self.max_iter {
36            let mut num_changed_alphas = 0;
37
38            for i in 0..n {
39                // Calculate E_i = f(x_i) - y_i
40                let mut f_xi = bias;
41                for j in 0..n {
42                    if data_points[j].alpha > 0.0 {
43                        let y_j = if data_points[j].y { 1.0 } else { -1.0 };
44                        f_xi += data_points[j].alpha
45                            * y_j
46                            * kernel.compute(&data_points[j].x, &data_points[i].x);
47                    }
48                }
49                let y_i = if data_points[i].y { 1.0 } else { -1.0 };
50                let e_i = f_xi - y_i;
51
52                // Check KKT conditions
53                let r_i = e_i * y_i;
54                if (r_i < -self.tol && data_points[i].alpha < self.c)
55                    || (r_i > self.tol && data_points[i].alpha > 0.0)
56                {
57                    // Select j != i
58                    let mut j = i;
59                    while j == i {
60                        j = (j + 1) % n;
61                    }
62
63                    // Calculate E_j
64                    let mut f_x_j = bias;
65                    for k in 0..n {
66                        if data_points[k].alpha > 0.0 {
67                            let y_k = if data_points[k].y { 1.0 } else { -1.0 };
68                            f_x_j += data_points[k].alpha
69                                * y_k
70                                * kernel.compute(&data_points[k].x, &data_points[j].x);
71                        }
72                    }
73                    let y_j = if data_points[j].y { 1.0 } else { -1.0 };
74                    let e_j = f_x_j - y_j;
75
76                    // Save old alphas
77                    let alpha_i_old = data_points[i].alpha;
78                    let alpha_j_old = data_points[j].alpha;
79
80                    // Compute bounds L and H
81                    let (l, h) = if data_points[i].y == data_points[j].y {
82                        let l = f32::max(0.0, alpha_i_old + alpha_j_old - self.c);
83                        let h = f32::min(self.c, alpha_i_old + alpha_j_old);
84                        (l, h)
85                    } else {
86                        let l = f32::max(0.0, alpha_j_old - alpha_i_old);
87                        let h = f32::min(self.c, self.c + alpha_j_old - alpha_i_old);
88                        (l, h)
89                    };
90                    if (h - l).abs() < 1e-8 {
91                        continue;
92                    }
93
94                    // Compute eta
95                    let k_ii = kernel.compute(&data_points[i].x, &data_points[i].x);
96                    let k_jj = kernel.compute(&data_points[j].x, &data_points[j].x);
97                    let k_ij = kernel.compute(&data_points[i].x, &data_points[j].x);
98                    let eta = 2.0f32.mul_add(k_ij, -k_ii) - k_jj;
99                    if eta >= 0.0 {
100                        continue;
101                    }
102
103                    // Update alpha_j
104                    let alpha_j_new = alpha_j_old - y_j * (e_i - e_j) / eta;
105                    let alpha_j_new = f32::min(h, f32::max(l, alpha_j_new));
106                    if (alpha_j_new - alpha_j_old).abs() < 1e-5 {
107                        continue;
108                    }
109
110                    // Update alpha_i & bias
111                    let alpha_i_new = (y_i * y_j).mul_add(alpha_j_old - alpha_j_new, alpha_i_old);
112
113                    let b1 = (y_j * (alpha_j_new - alpha_j_old)).mul_add(
114                        -k_ij,
115                        (y_i * (alpha_i_new - alpha_i_old)).mul_add(-k_ii, bias - e_i),
116                    );
117
118                    let b2 = (y_j * (alpha_j_new - alpha_j_old)).mul_add(
119                        -k_jj,
120                        (y_i * (alpha_i_new - alpha_i_old)).mul_add(-k_ij, bias - e_j),
121                    );
122
123                    bias = if alpha_i_new > 0.0 && alpha_i_new < self.c {
124                        b1
125                    } else if alpha_j_new > 0.0 && alpha_j_new < self.c {
126                        b2
127                    } else {
128                        f32::midpoint(b1, b2)
129                    };
130
131                    // Update alphas
132                    data_points[i].alpha = alpha_i_new;
133                    data_points[j].alpha = alpha_j_new;
134                    num_changed_alphas += 1;
135                }
136            }
137
138            if num_changed_alphas == 0 {
139                break;
140            }
141        }
142
143        BinarySVM {
144            kernel,
145            bias,
146            data_points,
147        }
148    }
149}
150
151/// A fitted binary classification support vector machine.
152///
153/// # Type Parameters
154///
155/// - `D` - The dimension or number of features.
156/// - `K` - The type of the kernel function.
157#[derive(Debug, Clone, PartialEq)]
158pub struct BinarySVM<const D: usize, K>
159where
160    K: Kernel<D>,
161{
162    /// The kernel function.
163    kernel: K,
164    /// Bias term.
165    bias: f32,
166    /// Data points used in training.
167    data_points: Vec<FittedSVMDataPoint<D>>,
168}
169
170impl<const D: usize, K> BinarySVM<D, K>
171where
172    K: Kernel<D>,
173{
174    /// Creates a new [`SVMParams`] for fitting a [`BinarySVM`].
175    #[must_use]
176    pub fn params() -> SVMParams<D, K> {
177        SVMParams::new()
178    }
179
180    /// Predicts the class label for a given input vector.
181    pub fn predict(&self, x: &RowVector<D>) -> bool {
182        self.decision_function(x) >= 0.0
183    }
184
185    /// Computes the decision value for a given input vector.
186    ///
187    /// This is useful for getting a confidence score. A larger positive value indicates higher confidence in the positive class, while a larger negative value indicates higher confidence in the negative class.
188    pub fn decision_function(&self, x: &RowVector<D>) -> f32 {
189        // Compute decision function: f(x) = sum(alpha_i * y_i * K(x_i, x)) + bias
190        let mut decision_value = self.bias;
191
192        for point in &self.data_points {
193            if point.alpha > 0.0 {
194                let y = if point.y { 1.0 } else { -1.0 };
195                decision_value += point.alpha * y * self.kernel.compute(&point.x, x);
196            }
197        }
198
199        decision_value
200    }
201
202    /// Returns an iterator over the support vectors and their corresponding alpha values.
203    pub fn support_vectors(&self) -> impl Iterator<Item = &FittedSVMDataPoint<D>> {
204        self.data_points.iter().filter(|p| p.alpha > 0.0)
205    }
206
207    /// Returns the number of support vectors.
208    ///
209    /// Support vectors are data points with non-zero alpha values.
210    pub fn num_support_vectors(&self) -> usize {
211        self.support_vectors().count()
212    }
213
214    /// Returns the bias term of the SVM.
215    pub const fn bias(&self) -> f32 {
216        self.bias
217    }
218}