1use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41
42use crate::svm::{FittedSVC, FittedSVR, Kernel, SVC, SVR};
43
/// Configuration for a nu-parameterised support vector classifier.
///
/// The classifier is trained by the `Fit` implementation for this type, which
/// turns `nu` into a regularisation constant `C = 1 / (nu * n_samples)` and
/// hands the actual optimisation to the crate's `SVC`.
#[derive(Clone, Debug)]
pub struct NuSVC<F, K> {
    /// The `nu` hyper-parameter; fitting rejects values outside `(0, 1]`.
    pub nu: F,
    /// Kernel instance, cloned into the underlying `SVC` when fitting.
    pub kernel: K,
    /// Tolerance forwarded unchanged to the underlying `SVC`.
    pub tol: F,
    /// Iteration cap forwarded unchanged to the underlying `SVC`.
    pub max_iter: usize,
    /// Kernel cache size forwarded unchanged to the underlying `SVC`
    /// (units are whatever `SVC::with_cache_size` expects).
    pub cache_size: usize,
}
70
71impl<F: Float, K: Kernel<F>> NuSVC<F, K> {
72 #[must_use]
76 pub fn new(kernel: K) -> Self {
77 Self {
78 nu: F::from(0.5).unwrap(),
79 kernel,
80 tol: F::from(1e-3).unwrap(),
81 max_iter: 10000,
82 cache_size: 1024,
83 }
84 }
85
86 #[must_use]
88 pub fn with_nu(mut self, nu: F) -> Self {
89 self.nu = nu;
90 self
91 }
92
93 #[must_use]
95 pub fn with_tol(mut self, tol: F) -> Self {
96 self.tol = tol;
97 self
98 }
99
100 #[must_use]
102 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
103 self.max_iter = max_iter;
104 self
105 }
106
107 #[must_use]
109 pub fn with_cache_size(mut self, cache_size: usize) -> Self {
110 self.cache_size = cache_size;
111 self
112 }
113}
114
115#[derive(Debug, Clone)]
117pub struct FittedNuSVC<F, K>(FittedSVC<F, K>);
118
119impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static>
120 Fit<Array2<F>, Array1<usize>> for NuSVC<F, K>
121{
122 type Fitted = FittedNuSVC<F, K>;
123 type Error = FerroError;
124
125 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedNuSVC<F, K>, FerroError> {
132 if self.nu <= F::zero() || self.nu > F::one() {
133 return Err(FerroError::InvalidParameter {
134 name: "nu".into(),
135 reason: "must be in (0, 1]".into(),
136 });
137 }
138
139 let n_samples = x.nrows();
140 if n_samples == 0 {
141 return Err(FerroError::InsufficientSamples {
142 required: 1,
143 actual: 0,
144 context: "NuSVC requires at least one sample".into(),
145 });
146 }
147
148 let n_f = F::from(n_samples).unwrap();
149 let c = F::one() / (self.nu * n_f);
150
151 let svc = SVC::new(self.kernel.clone())
152 .with_c(c)
153 .with_tol(self.tol)
154 .with_max_iter(self.max_iter)
155 .with_cache_size(self.cache_size);
156
157 let fitted = svc.fit(x, y)?;
158 Ok(FittedNuSVC(fitted))
159 }
160}
161
162impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static> Predict<Array2<F>>
163 for FittedNuSVC<F, K>
164{
165 type Output = Array1<usize>;
166 type Error = FerroError;
167
168 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
175 self.0.predict(x)
176 }
177}
178
179impl<F: Float, K: Kernel<F>> FittedNuSVC<F, K> {
180 pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
188 self.0.decision_function(x)
189 }
190}
191
/// Configuration for a nu-parameterised support vector regressor.
///
/// The regressor is trained by the `Fit` implementation for this type, which
/// turns `nu` into a regularisation constant `C = 1 / (nu * n_samples)` and
/// hands the actual optimisation to the crate's `SVR` with epsilon set to zero.
#[derive(Clone, Debug)]
pub struct NuSVR<F, K> {
    /// The `nu` hyper-parameter; fitting rejects values outside `(0, 1]`.
    pub nu: F,
    /// Kernel instance, cloned into the underlying `SVR` when fitting.
    pub kernel: K,
    /// Tolerance forwarded unchanged to the underlying `SVR`.
    pub tol: F,
    /// Iteration cap forwarded unchanged to the underlying `SVR`.
    pub max_iter: usize,
    /// Kernel cache size forwarded unchanged to the underlying `SVR`
    /// (units are whatever `SVR::with_cache_size` expects).
    pub cache_size: usize,
}
218
219impl<F: Float, K: Kernel<F>> NuSVR<F, K> {
220 #[must_use]
224 pub fn new(kernel: K) -> Self {
225 Self {
226 nu: F::from(0.5).unwrap(),
227 kernel,
228 tol: F::from(1e-3).unwrap(),
229 max_iter: 10000,
230 cache_size: 1024,
231 }
232 }
233
234 #[must_use]
236 pub fn with_nu(mut self, nu: F) -> Self {
237 self.nu = nu;
238 self
239 }
240
241 #[must_use]
243 pub fn with_tol(mut self, tol: F) -> Self {
244 self.tol = tol;
245 self
246 }
247
248 #[must_use]
250 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
251 self.max_iter = max_iter;
252 self
253 }
254
255 #[must_use]
257 pub fn with_cache_size(mut self, cache_size: usize) -> Self {
258 self.cache_size = cache_size;
259 self
260 }
261}
262
263#[derive(Debug, Clone)]
265pub struct FittedNuSVR<F, K>(FittedSVR<F, K>);
266
267impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static>
268 Fit<Array2<F>, Array1<F>> for NuSVR<F, K>
269{
270 type Fitted = FittedNuSVR<F, K>;
271 type Error = FerroError;
272
273 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedNuSVR<F, K>, FerroError> {
280 if self.nu <= F::zero() || self.nu > F::one() {
281 return Err(FerroError::InvalidParameter {
282 name: "nu".into(),
283 reason: "must be in (0, 1]".into(),
284 });
285 }
286
287 let n_samples = x.nrows();
288 if n_samples == 0 {
289 return Err(FerroError::InsufficientSamples {
290 required: 1,
291 actual: 0,
292 context: "NuSVR requires at least one sample".into(),
293 });
294 }
295
296 let n_f = F::from(n_samples).unwrap();
297 let c = F::one() / (self.nu * n_f);
298
299 let svr = SVR::new(self.kernel.clone())
300 .with_c(c)
301 .with_epsilon(F::zero())
302 .with_tol(self.tol)
303 .with_max_iter(self.max_iter)
304 .with_cache_size(self.cache_size);
305
306 let fitted = svr.fit(x, y)?;
307 Ok(FittedNuSVR(fitted))
308 }
309}
310
311impl<F: Float + Send + Sync + ScalarOperand + 'static, K: Kernel<F> + 'static> Predict<Array2<F>>
312 for FittedNuSVR<F, K>
313{
314 type Output = Array1<F>;
315 type Error = FerroError;
316
317 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
323 self.0.predict(x)
324 }
325}
326
327impl<F: Float, K: Kernel<F>> FittedNuSVR<F, K> {
328 pub fn decision_function(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
336 self.0.decision_function(x)
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::svm::{LinearKernel, RbfKernel};
    use ndarray::array;

    /// Two 2-D clusters with four points each, labelled 0 and 1.
    fn clusters_of_four() -> (Array2<f64>, Array1<usize>) {
        let features = vec![
            1.0, 1.0, 1.5, 1.0, 1.0, 1.5, 1.5, 1.5, 5.0, 5.0, 5.5, 5.0, 5.0, 5.5, 5.5, 5.5,
        ];
        let x = Array2::from_shape_vec((8, 2), features).unwrap();
        (x, array![0usize, 0, 0, 0, 1, 1, 1, 1])
    }

    /// Two 2-D clusters with three points each, labelled 0 and 1.
    fn clusters_of_three() -> (Array2<f64>, Array1<usize>) {
        let features = vec![1.0, 1.0, 1.5, 1.0, 1.0, 1.5, 5.0, 5.0, 5.5, 5.0, 5.0, 5.5];
        let x = Array2::from_shape_vec((6, 2), features).unwrap();
        (x, array![0usize, 0, 0, 1, 1, 1])
    }

    /// A single feature column holding 1.0 through 4.0.
    fn column_one_to_four() -> Array2<f64> {
        Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap()
    }

    /// Number of positions at which prediction and ground truth agree.
    fn matching(preds: &Array1<usize>, truth: &Array1<usize>) -> usize {
        preds
            .iter()
            .zip(truth.iter())
            .filter(|(p, a)| p == a)
            .count()
    }

    #[test]
    fn test_nusvc_linear_separable() {
        let (x, y) = clusters_of_four();

        let fitted = NuSVC::<f64, LinearKernel>::new(LinearKernel)
            .with_nu(0.5)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = matching(&preds, &y);
        assert!(correct >= 6, "Expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_nusvc_rbf() {
        let (x, y) = clusters_of_four();

        let model = NuSVC::new(RbfKernel::with_gamma(0.5)).with_nu(0.5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = matching(&preds, &y);
        assert!(correct >= 6, "Expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_nusvc_decision_function() {
        let (x, y) = clusters_of_three();

        let fitted = NuSVC::<f64, LinearKernel>::new(LinearKernel)
            .with_nu(0.5)
            .fit(&x, &y)
            .unwrap();
        let df = fitted.decision_function(&x).unwrap();
        assert_eq!(df.nrows(), 6);
    }

    #[test]
    fn test_nusvc_invalid_nu_zero() {
        let x = column_one_to_four();
        let y = array![0usize, 0, 1, 1];

        let model = NuSVC::<f64, LinearKernel>::new(LinearKernel).with_nu(0.0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_nusvc_invalid_nu_above_one() {
        let x = column_one_to_four();
        let y = array![0usize, 0, 1, 1];

        let model = NuSVC::<f64, LinearKernel>::new(LinearKernel).with_nu(1.5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_nusvc_nu_one() {
        let (x, y) = clusters_of_three();

        // nu == 1.0 is the inclusive upper bound and must be accepted.
        let model = NuSVC::<f64, LinearKernel>::new(LinearKernel).with_nu(1.0);
        assert!(model.fit(&x, &y).is_ok());
    }

    #[test]
    fn test_nusvr_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let fitted = NuSVR::new(LinearKernel)
            .with_nu(0.5)
            .with_max_iter(50000)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, actual) in preds.iter().copied().zip(y.iter().copied()) {
            assert!(
                (p - actual).abs() < 3.0,
                "NuSVR prediction {p} too far from actual {actual}"
            );
        }
    }

    #[test]
    fn test_nusvr_decision_function() {
        let x = column_one_to_four();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = NuSVR::new(LinearKernel)
            .with_nu(0.5)
            .with_max_iter(50000)
            .fit(&x, &y)
            .unwrap();

        let df = fitted.decision_function(&x).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // The decision function and the prediction must coincide per sample.
        assert!((0..4).all(|i| (df[i] - preds[i]).abs() < 1e-10));
    }

    #[test]
    fn test_nusvr_invalid_nu() {
        let x = column_one_to_four();
        let y = array![1.0, 2.0, 3.0, 4.0];

        for bad_nu in [0.0, -0.5] {
            let model = NuSVR::new(LinearKernel).with_nu(bad_nu);
            assert!(model.fit(&x, &y).is_err());
        }
    }

    #[test]
    fn test_nusvc_builder_pattern() {
        let model = NuSVC::<f64, LinearKernel>::new(LinearKernel)
            .with_nu(0.3)
            .with_tol(1e-4)
            .with_max_iter(5000)
            .with_cache_size(2048);

        assert!((model.nu - 0.3).abs() < 1e-10);
        assert!((model.tol - 1e-4).abs() < 1e-10);
        assert_eq!((model.max_iter, model.cache_size), (5000, 2048));
    }

    #[test]
    fn test_nusvr_builder_pattern() {
        let model = NuSVR::<f64, LinearKernel>::new(LinearKernel)
            .with_nu(0.8)
            .with_tol(1e-5)
            .with_max_iter(20000)
            .with_cache_size(512);

        assert!((model.nu - 0.8).abs() < 1e-10);
        assert!((model.tol - 1e-5).abs() < 1e-10);
        assert_eq!((model.max_iter, model.cache_size), (20000, 512));
    }
}