pr_ml/svm/multi.rs
1//! Multi-class classification support vector machine.
2
3use std::collections::HashMap;
4
5use rand::seq::SliceRandom;
6
7use super::{BinarySVM, Kernel, RowVector, SVMParams};
8
9/// A fitted multi-class classification support vector machine using One-vs-One strategy.
10///
11/// # Type Parameters
12///
13/// - `D` - The dimension or number of features.
14/// - `K` - The type of the kernel function.
15/// - `C` - The type of the class label.
16///
17/// # Strategy
18///
19/// Uses One-vs-One (OvO) approach: trains a binary classifier for each pair of classes.
20/// For N classes, trains N*(N-1)/2 binary classifiers. Prediction is done by majority voting.
21#[allow(clippy::doc_markdown)]
22#[derive(Debug, Clone, PartialEq)]
23pub struct MultiClassSVM<const D: usize, K, C>
24where
25 K: Kernel<D>,
26 C: Eq + std::hash::Hash + Clone,
27{
28 /// Binary classifiers for each pair of classes.
29 /// Key is (class1, class2) where class1 < class2 (in terms of internal ordering).
30 classifiers: HashMap<(usize, usize), BinarySVM<D, K>>,
31 /// Mapping from class label to internal index.
32 class_to_index: HashMap<C, usize>,
33 /// Mapping from internal index to class label.
34 index_to_class: Vec<C>,
35}
36
37impl<const D: usize, K> SVMParams<D, K>
38where
39 K: Kernel<D>,
40{
41 /// Fits a multi-class SVM using One-vs-One strategy.
42 ///
43 /// # Arguments
44 ///
45 /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
46 ///
47 /// # Returns
48 ///
49 /// A fitted [`MultiClassSVM`] that can predict class labels.
50 ///
51 /// # Notes
52 ///
53 /// Calling this method with large datasets may consume significant memory and time, consider using [`fit_multiclass_with_options`](Self::fit_multiclass_with_options) to limit samples per binary classifier.
54 ///
55 /// # Examples
56 ///
57 /// ```
58 /// use pr_ml::{RowVector, svm::{SVMParams, LinearKernel, MultiClassSVM}};
59 ///
60 /// let data = vec![
61 /// (RowVector::from([1.0, 2.0]), 0),
62 /// (RowVector::from([2.0, 3.0]), 0),
63 /// (RowVector::from([-1.0, -2.0]), 1),
64 /// (RowVector::from([-2.0, -3.0]), 1),
65 /// ];
66 ///
67 /// let svm: MultiClassSVM<2, LinearKernel, _> = SVMParams::new().fit_multiclass(data);
68 /// let prediction = svm.predict(&RowVector::from([1.5, 2.5]));
69 /// assert_eq!(prediction, 0);
70 /// ```
71 pub fn fit_multiclass<I, C>(self, data: I) -> MultiClassSVM<D, K, C>
72 where
73 I: IntoIterator<Item = (RowVector<D>, C)>,
74 C: Eq + std::hash::Hash + Clone + Ord,
75 {
76 self.fit_multiclass_with_options(data, |_, _, _| {}, 0)
77 }
78
79 /// Fits a multi-class SVM using One-vs-One strategy with progress callback.
80 ///
81 /// # Arguments
82 ///
83 /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
84 /// - `callback` - Called after each binary classifier is trained with (`current`, `total`, `sample_count`).
85 ///
86 /// # Returns
87 ///
88 /// A fitted [`MultiClassSVM`] that can predict class labels.
89 ///
90 /// # Notes
91 ///
92 /// Calling this method with large datasets may consume significant memory and time, consider using [`fit_multiclass_with_options`](Self::fit_multiclass_with_options) to limit samples per binary classifier.
93 pub fn fit_multiclass_with_callback<I, C, F>(
94 self,
95 data: I,
96 callback: F,
97 ) -> MultiClassSVM<D, K, C>
98 where
99 I: IntoIterator<Item = (RowVector<D>, C)>,
100 C: Eq + std::hash::Hash + Clone + Ord,
101 F: FnMut(usize, usize, usize),
102 {
103 self.fit_multiclass_with_options(data, callback, 0)
104 }
105
106 /// Fits a multi-class SVM using One-vs-One strategy with progress callback and optional subsampling.
107 ///
108 /// # Arguments
109 ///
110 /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
111 /// - `callback` - Called after each binary classifier is trained with (`current`, `total`, `sample_count`).
112 /// - `max_samples_per_pair` - Maximum samples to use per binary classifier (0 = use all).
113 ///
114 /// # Returns
115 ///
116 /// A fitted [`MultiClassSVM`] that can predict class labels.
117 pub fn fit_multiclass_with_options<I, C, F>(
118 self,
119 data: I,
120 mut callback: F,
121 max_samples_per_pair: usize,
122 ) -> MultiClassSVM<D, K, C>
123 where
124 I: IntoIterator<Item = (RowVector<D>, C)>,
125 C: Eq + std::hash::Hash + Clone + Ord,
126 F: FnMut(usize, usize, usize),
127 {
128 // Collect all data points
129 let data_vec: Vec<_> = data.into_iter().collect();
130
131 // Build class mappings
132 let mut unique_classes: Vec<C> = data_vec.iter().map(|(_, c)| c.clone()).collect();
133 unique_classes.sort_unstable();
134 unique_classes.dedup();
135
136 let class_to_index: HashMap<C, usize> = unique_classes
137 .iter()
138 .enumerate()
139 .map(|(i, c)| (c.clone(), i))
140 .collect();
141
142 let index_to_class = unique_classes;
143 let num_classes = index_to_class.len();
144
145 // Train binary classifiers for each pair of classes
146 let mut classifiers = HashMap::new();
147 let total_classifiers = num_classes * (num_classes - 1) / 2;
148 let mut current_classifier = 0;
149
150 for i in 0..num_classes {
151 for j in (i + 1)..num_classes {
152 // Filter data for classes i and j
153 let mut binary_data: Vec<_> = data_vec
154 .iter()
155 .filter_map(|(x, c)| {
156 let class_idx = class_to_index[c];
157 if class_idx == i {
158 Some((*x, true)) // class i is positive
159 } else if class_idx == j {
160 Some((*x, false)) // class j is negative
161 } else {
162 None
163 }
164 })
165 .collect();
166
167 // Subsample if requested
168 if max_samples_per_pair > 0 && binary_data.len() > max_samples_per_pair {
169 let mut rng = rand::rng();
170 binary_data.shuffle(&mut rng);
171 binary_data.truncate(max_samples_per_pair);
172 }
173
174 // Train binary classifier
175 if !binary_data.is_empty() {
176 let sample_count = binary_data.len();
177 let binary_svm = self.clone().fit_binary(binary_data);
178 classifiers.insert((i, j), binary_svm);
179 current_classifier += 1;
180 callback(current_classifier, total_classifiers, sample_count);
181 }
182 }
183 }
184
185 MultiClassSVM {
186 classifiers,
187 class_to_index,
188 index_to_class,
189 }
190 }
191}
192
193impl<const D: usize, K, C> MultiClassSVM<D, K, C>
194where
195 K: Kernel<D>,
196 C: Eq + std::hash::Hash + Clone,
197{
198 /// Predicts the class label for a given input vector using majority voting.
199 ///
200 /// # Arguments
201 ///
202 /// - `x` - The input feature vector.
203 ///
204 /// # Returns
205 ///
206 /// The predicted class label.
207 #[must_use]
208 pub fn predict(&self, x: &RowVector<D>) -> C {
209 let num_classes = self.index_to_class.len();
210 let mut votes = vec![0; num_classes];
211
212 // Vote using all binary classifiers
213 for (&(i, j), classifier) in &self.classifiers {
214 if classifier.predict(x) {
215 votes[i] += 1; // Vote for class i
216 } else {
217 votes[j] += 1; // Vote for class j
218 }
219 }
220
221 // Find class with maximum votes
222 let max_votes_idx = votes
223 .iter()
224 .enumerate()
225 .max_by_key(|(_, v)| *v)
226 .map_or(0, |(idx, _)| idx);
227
228 self.index_to_class[max_votes_idx].clone()
229 }
230
231 /// Predicts the class label and returns vote counts for all classes.
232 ///
233 /// # Arguments
234 ///
235 /// - `x` - The input feature vector.
236 ///
237 /// # Returns
238 ///
239 /// A tuple of (`predicted_class`, `votes_per_class`).
240 #[must_use]
241 pub fn predict_with_votes(&self, x: &RowVector<D>) -> (C, Vec<(C, usize)>) {
242 let num_classes = self.index_to_class.len();
243 let mut votes = vec![0; num_classes];
244
245 // Vote using all binary classifiers
246 for (&(i, j), classifier) in &self.classifiers {
247 if classifier.predict(x) {
248 votes[i] += 1;
249 } else {
250 votes[j] += 1;
251 }
252 }
253
254 // Find class with maximum votes
255 let max_votes_idx = votes
256 .iter()
257 .enumerate()
258 .max_by_key(|(_, v)| *v)
259 .map_or(0, |(idx, _)| idx);
260
261 let predicted_class = self.index_to_class[max_votes_idx].clone();
262
263 // Build vote summary
264 let vote_summary: Vec<_> = self
265 .index_to_class
266 .iter()
267 .enumerate()
268 .map(|(idx, class)| (class.clone(), votes[idx]))
269 .collect();
270
271 (predicted_class, vote_summary)
272 }
273
274 /// Returns the number of classes.
275 #[must_use]
276 pub const fn num_classes(&self) -> usize {
277 self.index_to_class.len()
278 }
279
280 /// Returns the number of binary classifiers.
281 #[must_use]
282 pub fn num_classifiers(&self) -> usize {
283 self.classifiers.len()
284 }
285
286 /// Returns a reference to the list of class labels.
287 #[must_use]
288 pub fn classes(&self) -> &[C] {
289 &self.index_to_class
290 }
291}