pr_ml/svm/
mod.rs

1//! Support vector machine module.
2
3mod binary;
4mod kernel;
5mod multi;
6
7use super::RowVector;
8pub use binary::BinarySVM;
9pub use kernel::{Kernel, LinearKernel, PolynomialKernel, GaussianKernel};
10pub use multi::MultiClassSVM;
11
12/// A fitted data point for SVM.
13#[derive(Debug, Clone, Copy, PartialEq)]
14pub struct FittedSVMDataPoint<const D: usize> {
15    /// Feature vector.
16    pub x: RowVector<D>,
17    /// Class label (true for positive class, false for negative class).
18    pub y: bool,
19    /// Lagrange multiplier (alpha) for this data point.
20    pub alpha: f32,
21}
22
23/// Parameters for fitting a [`BinarySVM`] or [`MultiClassSVM`].
24///
25/// # Type Parameters
26///
27/// - `D` - The dimension or number of features.
28/// - `K` - The type of the kernel function.
29///
30/// # Examples
31///
32/// ```
33/// use pr_ml::svm::{BinarySVM, SVMParams, LinearKernel};
34///
35/// let params = SVMParams::new()
36///     .c(0.5)
37///     .tol(1e-4)
38///     .max_iter(500)
39///     .kernel(LinearKernel);
40/// let svm: BinarySVM<3, LinearKernel> = params.fit_binary([]);
41/// ```
42#[derive(Debug, Clone, PartialEq)]
43pub struct SVMParams<const D: usize, K>
44where
45    K: Kernel<D>,
46{
47    /// Regularization parameter.
48    c: f32,
49    /// Tolerance for stopping criterion.
50    tol: f32,
51    /// Maximum number of iterations.
52    max_iter: usize,
53    /// Kernel function.
54    kernel: K,
55}
56
57impl<const D: usize, K> SVMParams<D, K>
58where
59    K: Kernel<D>,
60{
61    /// Creates a new [`SVMParams`] with default values.
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            c: 1.0,
66            tol: 1e-3,
67            max_iter: 1000,
68            kernel: K::default(),
69        }
70    }
71
72    /// Creates a new [`SVMParams`] with the specified kernel.
73    pub const fn new_with_kernel(kernel: K) -> Self {
74        Self {
75            c: 1.0,
76            tol: 1e-3,
77            max_iter: 1000,
78            kernel,
79        }
80    }
81
82    /// Sets the regularization parameter.
83    #[must_use]
84    pub const fn c(mut self, c: f32) -> Self {
85        self.c = c;
86        self
87    }
88
89    /// Sets the tolerance for the stopping criterion.
90    #[must_use]
91    pub const fn tol(mut self, tol: f32) -> Self {
92        self.tol = tol;
93        self
94    }
95
96    /// Sets the maximum number of iterations.
97    #[must_use]
98    pub const fn max_iter(mut self, max_iter: usize) -> Self {
99        self.max_iter = max_iter;
100        self
101    }
102
103    /// Sets the kernel function.
104    #[must_use]
105    pub fn kernel(mut self, kernel: K) -> Self {
106        self.kernel = kernel;
107        self
108    }
109}
110
111impl<const D: usize, K> Default for SVMParams<D, K>
112where
113    K: Kernel<D>,
114{
115    fn default() -> Self {
116        Self::new()
117    }
118}