use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::Float;

use crate::svm::Kernel;

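/// One-class SVM for unsupervised outlier/novelty detection.
///
/// Learns a boundary around the training data; `predict` returns `1` for
/// inliers and `-1` for outliers.
///
/// # Examples
///
/// ```ignore
/// // Illustrative sketch; adjust imports to your crate layout.
/// let model = OneClassSVM::new(RbfKernel::with_gamma(10.0)).with_nu(0.1);
/// let fitted = model.fit(&x, &())?;
/// let labels = fitted.predict(&x)?;
/// ```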
#[derive(Debug, Clone)]
pub struct OneClassSVM<F, K> {
    /// Upper bound on the fraction of training errors and lower bound on the
    /// fraction of support vectors. Must lie in `(0, 1]`.
    pub nu: F,
    /// Kernel used to compare samples.
    pub kernel: K,
    /// Stopping tolerance for the solver.
    pub tol: F,
    /// Maximum number of solver iterations.
    pub max_iter: usize,
    /// Kernel cache size hint; the current solver does not use it.
    pub cache_size: usize,
}

impl<F: Float, K: Kernel<F>> OneClassSVM<F, K> {
    /// Creates a model with the given kernel and defaults:
    /// `nu = 0.5`, `tol = 1e-3`, `max_iter = 10000`, `cache_size = 1024`.
    #[must_use]
    pub fn new(kernel: K) -> Self {
        Self {
            nu: F::from(0.5).unwrap(),
            kernel,
            tol: F::from(1e-3).unwrap(),
            max_iter: 10000,
            cache_size: 1024,
        }
    }

    /// Sets `nu`, the bound on the fraction of outliers and support vectors.
    #[must_use]
    pub fn with_nu(mut self, nu: F) -> Self {
        self.nu = nu;
        self
    }

    /// Sets the solver stopping tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the maximum number of solver iterations.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the kernel cache size hint.
    #[must_use]
    pub fn with_cache_size(mut self, cache_size: usize) -> Self {
        self.cache_size = cache_size;
        self
    }
}

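/// A fitted one-class SVM produced by [`OneClassSVM::fit`].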
#[derive(Debug, Clone)]
pub struct FittedOneClassSVM<F, K> {
    /// Kernel carried over from the unfitted model.
    kernel: K,
    /// Training samples with nonzero dual coefficients.
    support_vectors: Vec<Vec<F>>,
    /// Dual coefficients of the support vectors.
    dual_coefs: Vec<F>,
    /// Offset of the decision function.
    rho: F,
}

impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static>
    Fit<Array2<F>, ()> for OneClassSVM<F, K>
{
    type Fitted = FittedOneClassSVM<F, K>;
    type Error = FerroError;

    /// Fits the model to `x` with a simplified SMO-style pairwise solver.
    ///
    /// # Errors
    ///
    /// Returns an error if `nu` lies outside `(0, 1]` or if `x` has no rows.
    fn fit(
        &self,
        x: &Array2<F>,
        _y: &(),
    ) -> Result<FittedOneClassSVM<F, K>, FerroError> {
        if self.nu <= F::zero() || self.nu > F::one() {
            return Err(FerroError::InvalidParameter {
                name: "nu".into(),
                reason: "must be in (0, 1]".into(),
            });
        }

        let n_samples = x.nrows();

        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "OneClassSVM requires at least one sample".into(),
            });
        }

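        // One-class nu-SVM dual (Schoelkopf et al., 2001): minimize
        // (1/2) * a^T K a subject to 0 <= a_i <= C and sum(a_i) = 1,
        // with C = 1 / (n_samples * nu).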
        let c = F::one() / (F::from(n_samples).unwrap() * self.nu);
        let data: Vec<Vec<F>> = (0..n_samples).map(|i| x.row(i).to_vec()).collect();

        // Uniform start: a_i = 1/n satisfies sum(a_i) = 1, and 1/n <= C
        // whenever nu <= 1, so the box constraint holds too.
        let init_alpha = F::one() / F::from(n_samples).unwrap();
        let mut alphas = vec![init_alpha.min(c); n_samples];

        // If rounding left the alphas summing below one, spread the deficit
        // uniformly (still capped at C).
        let alpha_sum: F = alphas.iter().copied().fold(F::zero(), |a, b| a + b);
        if alpha_sum < F::one() {
            let remaining = F::one() - alpha_sum;
            let per_sample = remaining / F::from(n_samples).unwrap();
            for alpha in &mut alphas {
                *alpha = (*alpha + per_sample).min(c);
            }
        }

        let eps = F::from(1e-12).unwrap_or(F::epsilon());
        let two = F::one() + F::one();

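        // Gradient of the dual objective at the initial point:
        // g_i = sum_j a_j * K(x_i, x_j).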
        let mut grad = vec![F::zero(); n_samples];
        for i in 0..n_samples {
            for j in 0..n_samples {
                grad[i] = grad[i] + alphas[j] * self.kernel.compute(&data[i], &data[j]);
            }
        }

        for _iter in 0..self.max_iter {
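            // Working-set selection: pick i with the largest gradient among
            // alphas that can decrease (a_i > 0), and j with the smallest
            // gradient among alphas that can increase (a_j < C).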
            let mut i_best = None;
            let mut i_max_grad = F::neg_infinity();
            let mut j_best = None;
            let mut j_min_grad = F::infinity();

            for k in 0..n_samples {
                if alphas[k] > eps && grad[k] > i_max_grad {
                    i_max_grad = grad[k];
                    i_best = Some(k);
                }
                if alphas[k] < c - eps && grad[k] < j_min_grad {
                    j_min_grad = grad[k];
                    j_best = Some(k);
                }
            }

            // Converged: no feasible pair remains, or the largest KKT
            // violation is within tolerance.
            if i_best.is_none() || j_best.is_none() || i_max_grad - j_min_grad < self.tol {
                break;
            }

            let i = i_best.unwrap();
            let j = j_best.unwrap();

            if i == j {
                break;
            }

            let kii = self.kernel.compute(&data[i], &data[i]);
            let kjj = self.kernel.compute(&data[j], &data[j]);
            let kij = self.kernel.compute(&data[i], &data[j]);
            // Curvature along the (i, j) direction. Clamp to eps (LIBSVM-style)
            // rather than skipping: with a plain `continue`, the deterministic
            // selection would pick the same degenerate pair again and spin
            // until max_iter.
            let eta = (kii + kjj - two * kij).max(eps);

            // Unconstrained Newton step, clipped so both alphas stay in [0, C].
            let delta = (grad[i] - grad[j]) / eta;
            let delta = delta.min(alphas[i]).min(c - alphas[j]);

            // A vanishing step on the most violating pair means no further
            // progress is possible; stop rather than re-select it forever.
            if delta.abs() < eps {
                break;
            }

            alphas[i] = alphas[i] - delta;
            alphas[j] = alphas[j] + delta;

            // Incremental gradient update: only columns i and j changed.
            for k in 0..n_samples {
                let kki = self.kernel.compute(&data[k], &data[i]);
                let kkj = self.kernel.compute(&data[k], &data[j]);
                grad[k] = grad[k] - delta * kki + delta * kkj;
            }
        }

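        // Recover rho from the KKT conditions: free support vectors
        // (0 < a_i < C) satisfy f(x_i) = 0, so rho equals the gradient there;
        // average over all free SVs for numerical stability.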
        let mut rho_sum = F::zero();
        let mut rho_count = 0usize;

        for i in 0..n_samples {
            if alphas[i] > eps && alphas[i] < c - eps {
                rho_sum = rho_sum + grad[i];
                rho_count += 1;
            }
        }

        let rho = if rho_count > 0 {
            rho_sum / F::from(rho_count).unwrap()
        } else {
            // No free support vectors: every alpha sits at a bound, so take
            // the midpoint of the gradient range over the bounded SVs instead.
            let sv_grads: Vec<F> = (0..n_samples)
                .filter(|&i| alphas[i] > eps)
                .map(|i| grad[i])
                .collect();

            if sv_grads.is_empty() {
                F::zero()
            } else {
                let min_g = sv_grads.iter().fold(F::infinity(), |a, &b| a.min(b));
                let max_g = sv_grads.iter().fold(F::neg_infinity(), |a, &b| a.max(b));
                (min_g + max_g) / two
            }
        };

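        // Collect the support set: samples whose dual coefficient is
        // effectively nonzero.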
        let mut support_vectors = Vec::new();
        let mut dual_coefs = Vec::new();

        for (i, &alpha) in alphas.iter().enumerate() {
            if alpha > eps {
                support_vectors.push(data[i].clone());
                dual_coefs.push(alpha);
            }
        }

        // Degenerate fallback: if no alpha survived the threshold, keep all
        // samples with uniform weights so the fitted model stays usable.
        if support_vectors.is_empty() {
            let weight = F::one() / F::from(n_samples).unwrap();
            for row in &data {
                support_vectors.push(row.clone());
                dual_coefs.push(weight);
            }
        }

        Ok(FittedOneClassSVM {
            kernel: self.kernel.clone(),
            support_vectors,
            dual_coefs,
            rho,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static>
    FittedOneClassSVM<F, K>
{
    /// Evaluates the decision function for one sample:
    /// `f(x) = sum_i coef_i * K(sv_i, x) - rho`.
    fn decision_value(&self, x: &[F]) -> F {
        let mut val = F::zero();
        for (sv, &coef) in self.support_vectors.iter().zip(self.dual_coefs.iter()) {
            val = val + coef * self.kernel.compute(sv, x);
        }
        val - self.rho
    }

    /// Returns the signed decision value for each row of `x`.
    /// Nonnegative values correspond to predicted inliers.
    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let n_samples = x.nrows();
        let mut result = Array1::<F>::zeros(n_samples);
        for s in 0..n_samples {
            let xi: Vec<F> = x.row(s).to_vec();
            result[s] = self.decision_value(&xi);
        }
        Ok(result)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static> Predict<Array2<F>>
    for FittedOneClassSVM<F, K>
{
    type Output = Array1<isize>;
    type Error = FerroError;

    /// Labels each row of `x`: `1` for inliers, `-1` for outliers.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<isize>, FerroError> {
        let n_samples = x.nrows();
        let mut predictions = Array1::<isize>::zeros(n_samples);

        for s in 0..n_samples {
            let xi: Vec<F> = x.row(s).to_vec();
            let val = self.decision_value(&xi);
            predictions[s] = if val >= F::zero() { 1 } else { -1 };
        }

        Ok(predictions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::svm::{LinearKernel, RbfKernel};
    use ndarray::Array2;

    /// Eight points tightly clustered around (1, 1).
    fn make_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.0, 1.0, 1.1, 1.1, 1.1,
                0.9, 0.9, 1.0, 0.9, 0.9, 1.0, 1.05, 1.05,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_one_class_svm_fit() {
        let x = make_cluster_data();
        let model = OneClassSVM::<f64, RbfKernel<f64>>::new(RbfKernel::with_gamma(10.0));
        let result = model.fit(&x, &());
        assert!(result.is_ok());
    }

    #[test]
    fn test_one_class_svm_inliers() {
        let x = make_cluster_data();
        let model = OneClassSVM::new(RbfKernel::with_gamma(10.0)).with_nu(0.1);
        let fitted = model.fit(&x, &()).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // With nu = 0.1, most of the tight cluster should be inliers.
        let inliers: usize = preds.iter().filter(|&&p| p == 1).count();
        assert!(inliers >= 6, "Expected at least 6 inliers, got {inliers}");
    }

    #[test]
    fn test_one_class_svm_outlier_detection() {
        let x_train = Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 0.1, 0.1,
                -0.1, 0.0, 0.0, -0.1, 0.05, 0.05, -0.05, -0.05,
            ],
        )
        .unwrap();

        let model = OneClassSVM::new(RbfKernel::with_gamma(10.0)).with_nu(0.1);
        let fitted = model.fit(&x_train, &()).unwrap();

        // A point far from the training cluster must be flagged as an outlier.
        let x_outlier = Array2::from_shape_vec((1, 2), vec![100.0, 100.0]).unwrap();
        let preds = fitted.predict(&x_outlier).unwrap();
        assert_eq!(preds[0], -1, "Far-away point should be an outlier");
    }

    #[test]
    fn test_one_class_svm_decision_function() {
        let x = make_cluster_data();
        let model = OneClassSVM::new(RbfKernel::with_gamma(10.0)).with_nu(0.1);
        let fitted = model.fit(&x, &()).unwrap();

        let df = fitted.decision_function(&x).unwrap();
        assert_eq!(df.len(), 8);

        // Decision values should be nonnegative for most training points.
        let positive: usize = df.iter().filter(|&&v| v >= 0.0).count();
        assert!(positive >= 6, "Expected at least 6 positive df, got {positive}");
    }

    #[test]
    fn test_one_class_svm_invalid_nu() {
        let x = Array2::from_shape_vec((4, 2), vec![1.0; 8]).unwrap();

        let model = OneClassSVM::new(RbfKernel::<f64>::new()).with_nu(0.0);
        assert!(model.fit(&x, &()).is_err());

        let model2 = OneClassSVM::new(RbfKernel::<f64>::new()).with_nu(1.5);
        assert!(model2.fit(&x, &()).is_err());
    }

    #[test]
    fn test_one_class_svm_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let model = OneClassSVM::new(RbfKernel::<f64>::new());
        assert!(model.fit(&x, &()).is_err());
    }

    #[test]
    fn test_one_class_svm_builder_pattern() {
        let model = OneClassSVM::<f64, LinearKernel>::new(LinearKernel)
            .with_nu(0.3)
            .with_tol(1e-4)
            .with_max_iter(5000)
            .with_cache_size(2048);

        assert!((model.nu - 0.3).abs() < 1e-10);
        assert!((model.tol - 1e-4).abs() < 1e-10);
        assert_eq!(model.max_iter, 5000);
        assert_eq!(model.cache_size, 2048);
    }

    #[test]
    fn test_one_class_svm_linear_kernel() {
        let x = make_cluster_data();
        let model = OneClassSVM::new(LinearKernel).with_nu(0.5);
        let fitted = model.fit(&x, &()).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_one_class_svm_single_sample() {
        let x = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
        let model = OneClassSVM::new(RbfKernel::with_gamma(1.0)).with_nu(0.5);
        let result = model.fit(&x, &());
        assert!(result.is_ok());
    }
}