pr_ml/svm/mod.rs
1//! Support vector machine module.
2
3mod binary;
4mod kernel;
5mod multi;
6
7use super::RowVector;
8pub use binary::BinarySVM;
9pub use kernel::{Kernel, LinearKernel, PolynomialKernel, GaussianKernel};
10pub use multi::MultiClassSVM;
11
12/// A fitted data point for SVM.
13#[derive(Debug, Clone, Copy, PartialEq)]
14pub struct FittedSVMDataPoint<const D: usize> {
15 /// Feature vector.
16 pub x: RowVector<D>,
17 /// Class label (true for positive class, false for negative class).
18 pub y: bool,
19 /// Lagrange multiplier (alpha) for this data point.
20 pub alpha: f32,
21}
22
/// Parameters for fitting a [`BinarySVM`] or [`MultiClassSVM`].
///
/// Create one with [`SVMParams::new`] (default kernel) or
/// [`SVMParams::new_with_kernel`], then adjust it with the builder-style setters.
///
/// # Type Parameters
///
/// - `D` - The dimension or number of features. It is not stored in any field;
///   it lets the `impl` blocks require `K: Kernel<D>`.
/// - `K` - The type of the kernel function. The [`Kernel`] bound is imposed by
///   the `impl` blocks that need it rather than by the struct definition.
///
/// # Examples
///
/// ```
/// use pr_ml::svm::{BinarySVM, SVMParams, LinearKernel};
///
/// let params = SVMParams::new()
///     .c(0.5)
///     .tol(1e-4)
///     .max_iter(500)
///     .kernel(LinearKernel);
/// let svm: BinarySVM<3, LinearKernel> = params.fit_binary([]);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct SVMParams<const D: usize, K> {
    /// Regularization parameter (`1.0` unless overridden).
    c: f32,
    /// Tolerance for the stopping criterion (`1e-3` unless overridden).
    tol: f32,
    /// Maximum number of iterations (`1000` unless overridden).
    max_iter: usize,
    /// Kernel function.
    kernel: K,
}
56
57impl<const D: usize, K> SVMParams<D, K>
58where
59 K: Kernel<D>,
60{
61 /// Creates a new [`SVMParams`] with default values.
62 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 c: 1.0,
66 tol: 1e-3,
67 max_iter: 1000,
68 kernel: K::default(),
69 }
70 }
71
72 /// Creates a new [`SVMParams`] with the specified kernel.
73 pub const fn new_with_kernel(kernel: K) -> Self {
74 Self {
75 c: 1.0,
76 tol: 1e-3,
77 max_iter: 1000,
78 kernel,
79 }
80 }
81
82 /// Sets the regularization parameter.
83 #[must_use]
84 pub const fn c(mut self, c: f32) -> Self {
85 self.c = c;
86 self
87 }
88
89 /// Sets the tolerance for the stopping criterion.
90 #[must_use]
91 pub const fn tol(mut self, tol: f32) -> Self {
92 self.tol = tol;
93 self
94 }
95
96 /// Sets the maximum number of iterations.
97 #[must_use]
98 pub const fn max_iter(mut self, max_iter: usize) -> Self {
99 self.max_iter = max_iter;
100 self
101 }
102
103 /// Sets the kernel function.
104 #[must_use]
105 pub fn kernel(mut self, kernel: K) -> Self {
106 self.kernel = kernel;
107 self
108 }
109}
110
111impl<const D: usize, K> Default for SVMParams<D, K>
112where
113 K: Kernel<D>,
114{
115 fn default() -> Self {
116 Self::new()
117 }
118}