pr_ml/svm/
multi.rs

1//! Multi-class classification support vector machine.
2
3use std::collections::HashMap;
4
5use rand::seq::SliceRandom;
6
7use super::{BinarySVM, Kernel, RowVector, SVMParams};
8
9/// A fitted multi-class classification support vector machine using One-vs-One strategy.
10///
11/// # Type Parameters
12///
13/// - `D` - The dimension or number of features.
14/// - `K` - The type of the kernel function.
15/// - `C` - The type of the class label.
16///
17/// # Strategy
18///
19/// Uses One-vs-One (OvO) approach: trains a binary classifier for each pair of classes.
20/// For N classes, trains N*(N-1)/2 binary classifiers. Prediction is done by majority voting.
21#[allow(clippy::doc_markdown)]
22#[derive(Debug, Clone, PartialEq)]
23pub struct MultiClassSVM<const D: usize, K, C>
24where
25    K: Kernel<D>,
26    C: Eq + std::hash::Hash + Clone,
27{
28    /// Binary classifiers for each pair of classes.
29    /// Key is (class1, class2) where class1 < class2 (in terms of internal ordering).
30    classifiers: HashMap<(usize, usize), BinarySVM<D, K>>,
31    /// Mapping from class label to internal index.
32    class_to_index: HashMap<C, usize>,
33    /// Mapping from internal index to class label.
34    index_to_class: Vec<C>,
35}
36
37impl<const D: usize, K> SVMParams<D, K>
38where
39    K: Kernel<D>,
40{
41    /// Fits a multi-class SVM using One-vs-One strategy.
42    ///
43    /// # Arguments
44    ///
45    /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
46    ///
47    /// # Returns
48    ///
49    /// A fitted [`MultiClassSVM`] that can predict class labels.
50    ///
51    /// # Notes
52    ///
53    /// Calling this method with large datasets may consume significant memory and time, consider using [`fit_multiclass_with_options`](Self::fit_multiclass_with_options) to limit samples per binary classifier.
54    ///
55    /// # Examples
56    ///
57    /// ```
58    /// use pr_ml::{RowVector, svm::{SVMParams, LinearKernel, MultiClassSVM}};
59    ///
60    /// let data = vec![
61    ///     (RowVector::from([1.0, 2.0]), 0),
62    ///     (RowVector::from([2.0, 3.0]), 0),
63    ///     (RowVector::from([-1.0, -2.0]), 1),
64    ///     (RowVector::from([-2.0, -3.0]), 1),
65    /// ];
66    ///
67    /// let svm: MultiClassSVM<2, LinearKernel, _> = SVMParams::new().fit_multiclass(data);
68    /// let prediction = svm.predict(&RowVector::from([1.5, 2.5]));
69    /// assert_eq!(prediction, 0);
70    /// ```
71    pub fn fit_multiclass<I, C>(self, data: I) -> MultiClassSVM<D, K, C>
72    where
73        I: IntoIterator<Item = (RowVector<D>, C)>,
74        C: Eq + std::hash::Hash + Clone + Ord,
75    {
76        self.fit_multiclass_with_options(data, |_, _, _| {}, 0)
77    }
78
79    /// Fits a multi-class SVM using One-vs-One strategy with progress callback.
80    ///
81    /// # Arguments
82    ///
83    /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
84    /// - `callback` - Called after each binary classifier is trained with (`current`, `total`, `sample_count`).
85    ///
86    /// # Returns
87    ///
88    /// A fitted [`MultiClassSVM`] that can predict class labels.
89    ///
90    /// # Notes
91    ///
92    /// Calling this method with large datasets may consume significant memory and time, consider using [`fit_multiclass_with_options`](Self::fit_multiclass_with_options) to limit samples per binary classifier.
93    pub fn fit_multiclass_with_callback<I, C, F>(
94        self,
95        data: I,
96        callback: F,
97    ) -> MultiClassSVM<D, K, C>
98    where
99        I: IntoIterator<Item = (RowVector<D>, C)>,
100        C: Eq + std::hash::Hash + Clone + Ord,
101        F: FnMut(usize, usize, usize),
102    {
103        self.fit_multiclass_with_options(data, callback, 0)
104    }
105
106    /// Fits a multi-class SVM using One-vs-One strategy with progress callback and optional subsampling.
107    ///
108    /// # Arguments
109    ///
110    /// - `data` - Training data as an iterator of (`feature_vector`, `class_label`) tuples.
111    /// - `callback` - Called after each binary classifier is trained with (`current`, `total`, `sample_count`).
112    /// - `max_samples_per_pair` - Maximum samples to use per binary classifier (0 = use all).
113    ///
114    /// # Returns
115    ///
116    /// A fitted [`MultiClassSVM`] that can predict class labels.
117    pub fn fit_multiclass_with_options<I, C, F>(
118        self,
119        data: I,
120        mut callback: F,
121        max_samples_per_pair: usize,
122    ) -> MultiClassSVM<D, K, C>
123    where
124        I: IntoIterator<Item = (RowVector<D>, C)>,
125        C: Eq + std::hash::Hash + Clone + Ord,
126        F: FnMut(usize, usize, usize),
127    {
128        // Collect all data points
129        let data_vec: Vec<_> = data.into_iter().collect();
130
131        // Build class mappings
132        let mut unique_classes: Vec<C> = data_vec.iter().map(|(_, c)| c.clone()).collect();
133        unique_classes.sort_unstable();
134        unique_classes.dedup();
135
136        let class_to_index: HashMap<C, usize> = unique_classes
137            .iter()
138            .enumerate()
139            .map(|(i, c)| (c.clone(), i))
140            .collect();
141
142        let index_to_class = unique_classes;
143        let num_classes = index_to_class.len();
144
145        // Train binary classifiers for each pair of classes
146        let mut classifiers = HashMap::new();
147        let total_classifiers = num_classes * (num_classes - 1) / 2;
148        let mut current_classifier = 0;
149
150        for i in 0..num_classes {
151            for j in (i + 1)..num_classes {
152                // Filter data for classes i and j
153                let mut binary_data: Vec<_> = data_vec
154                    .iter()
155                    .filter_map(|(x, c)| {
156                        let class_idx = class_to_index[c];
157                        if class_idx == i {
158                            Some((*x, true)) // class i is positive
159                        } else if class_idx == j {
160                            Some((*x, false)) // class j is negative
161                        } else {
162                            None
163                        }
164                    })
165                    .collect();
166
167                // Subsample if requested
168                if max_samples_per_pair > 0 && binary_data.len() > max_samples_per_pair {
169                    let mut rng = rand::rng();
170                    binary_data.shuffle(&mut rng);
171                    binary_data.truncate(max_samples_per_pair);
172                }
173
174                // Train binary classifier
175                if !binary_data.is_empty() {
176                    let sample_count = binary_data.len();
177                    let binary_svm = self.clone().fit_binary(binary_data);
178                    classifiers.insert((i, j), binary_svm);
179                    current_classifier += 1;
180                    callback(current_classifier, total_classifiers, sample_count);
181                }
182            }
183        }
184
185        MultiClassSVM {
186            classifiers,
187            class_to_index,
188            index_to_class,
189        }
190    }
191}
192
193impl<const D: usize, K, C> MultiClassSVM<D, K, C>
194where
195    K: Kernel<D>,
196    C: Eq + std::hash::Hash + Clone,
197{
198    /// Predicts the class label for a given input vector using majority voting.
199    ///
200    /// # Arguments
201    ///
202    /// - `x` - The input feature vector.
203    ///
204    /// # Returns
205    ///
206    /// The predicted class label.
207    #[must_use]
208    pub fn predict(&self, x: &RowVector<D>) -> C {
209        let num_classes = self.index_to_class.len();
210        let mut votes = vec![0; num_classes];
211
212        // Vote using all binary classifiers
213        for (&(i, j), classifier) in &self.classifiers {
214            if classifier.predict(x) {
215                votes[i] += 1; // Vote for class i
216            } else {
217                votes[j] += 1; // Vote for class j
218            }
219        }
220
221        // Find class with maximum votes
222        let max_votes_idx = votes
223            .iter()
224            .enumerate()
225            .max_by_key(|(_, v)| *v)
226            .map_or(0, |(idx, _)| idx);
227
228        self.index_to_class[max_votes_idx].clone()
229    }
230
231    /// Predicts the class label and returns vote counts for all classes.
232    ///
233    /// # Arguments
234    ///
235    /// - `x` - The input feature vector.
236    ///
237    /// # Returns
238    ///
239    /// A tuple of (`predicted_class`, `votes_per_class`).
240    #[must_use]
241    pub fn predict_with_votes(&self, x: &RowVector<D>) -> (C, Vec<(C, usize)>) {
242        let num_classes = self.index_to_class.len();
243        let mut votes = vec![0; num_classes];
244
245        // Vote using all binary classifiers
246        for (&(i, j), classifier) in &self.classifiers {
247            if classifier.predict(x) {
248                votes[i] += 1;
249            } else {
250                votes[j] += 1;
251            }
252        }
253
254        // Find class with maximum votes
255        let max_votes_idx = votes
256            .iter()
257            .enumerate()
258            .max_by_key(|(_, v)| *v)
259            .map_or(0, |(idx, _)| idx);
260
261        let predicted_class = self.index_to_class[max_votes_idx].clone();
262
263        // Build vote summary
264        let vote_summary: Vec<_> = self
265            .index_to_class
266            .iter()
267            .enumerate()
268            .map(|(idx, class)| (class.clone(), votes[idx]))
269            .collect();
270
271        (predicted_class, vote_summary)
272    }
273
274    /// Returns the number of classes.
275    #[must_use]
276    pub const fn num_classes(&self) -> usize {
277        self.index_to_class.len()
278    }
279
280    /// Returns the number of binary classifiers.
281    #[must_use]
282    pub fn num_classifiers(&self) -> usize {
283        self.classifiers.len()
284    }
285
286    /// Returns a reference to the list of class labels.
287    #[must_use]
288    pub fn classes(&self) -> &[C] {
289        &self.index_to_class
290    }
291}