#![cfg(feature = "machine_learning")]
use ndarray::{arr1, arr2};
use rustyml::KernelType;
use rustyml::machine_learning::svc::SVC;
/// Builds an `SVC` with explicit hyperparameters and checks that every getter
/// reports back exactly the value that was passed to `SVC::new`.
#[test]
fn test_svc_constructor() {
    let model = SVC::new(KernelType::Linear, 1.0, 0.001, 100).unwrap();

    assert_eq!(model.get_regularization_parameter(), 1.0);
    assert_eq!(model.get_tolerance(), 0.001);
    assert_eq!(model.get_max_iterations(), 100);

    // Any kernel other than `Linear` means the constructor lost its argument.
    assert!(
        matches!(model.get_kernel(), KernelType::Linear),
        "Expected linear kernel"
    );
}
/// Verifies the hyperparameters produced by `SVC::default()`: an RBF kernel
/// with gamma of 0.1, regularization parameter 1.0, tolerance 0.001,
/// 1000 max iterations and epsilon 1e-8.
#[test]
fn test_svc_default() {
    let svc = SVC::default();
    match svc.get_kernel() {
        KernelType::RBF { gamma } => {
            // Tolerance-based comparison to avoid exact float equality on gamma.
            assert!((gamma - 0.1).abs() < 1e-10);
        }
        // The failure message names the same gamma that the assertion above checks.
        _ => panic!("Expected RBF kernel with gamma=0.1"),
    }
    assert_eq!(svc.get_regularization_parameter(), 1.0);
    assert_eq!(svc.get_tolerance(), 0.001);
    assert_eq!(svc.get_max_iterations(), 1000);
    assert_eq!(svc.get_epsilon(), 1e-8);
}
/// A freshly constructed model has not been trained, so every getter that
/// exposes fitted state is expected to return `None`.
#[test]
fn test_getters_before_fit() {
    let unfitted = SVC::default();

    assert!(unfitted.get_alphas().is_none());
    assert!(unfitted.get_support_vectors().is_none());
    assert!(unfitted.get_support_vector_labels().is_none());
    assert!(unfitted.get_bias().is_none());
}
/// Trains a linear-kernel `SVC` on eight 2-D points: four with all-positive
/// coordinates labelled `1.0` and four with all-negative coordinates labelled
/// `-1.0`. After fitting, the fitted-state getters must be populated and at
/// least 4 of the 8 training points must be predicted with their own label.
#[test]
fn test_fit_and_predict_linear() {
    let features = arr2(&[
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 3.0],
        [2.0, 1.0],
        [-1.0, -2.0],
        [-2.0, -3.0],
        [-3.0, -2.0],
        [-2.0, -1.0],
    ]);
    let labels = arr1(&[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0]);

    let mut model = SVC::new(KernelType::Linear, 10.0, 0.001, 10000).unwrap();
    assert!(model.fit(&features.view(), &labels.view()).is_ok());

    // A successful fit must leave the learned state available.
    assert!(model.get_alphas().is_some());
    assert!(model.get_support_vectors().is_some());
    assert!(model.get_support_vector_labels().is_some());
    assert!(model.get_bias().is_some());

    // Count training points whose prediction equals the true label.
    let predicted = model.predict(&features.view()).unwrap();
    let hits = predicted
        .iter()
        .zip(labels.iter())
        .filter(|(guess, truth)| guess == truth)
        .count();
    assert!(hits >= 4);
}
/// Trains an RBF-kernel `SVC` (gamma 10.0) on eight 2-D points where only
/// `[0, 1]` and `[1, 0]` are labelled `1.0`; the remaining points, including
/// the other two unit-square corners, are labelled `-1.0`. At least 5 of the
/// 8 training points must be predicted with their own label.
#[test]
fn test_fit_and_predict_rbf() {
    let features = arr2(&[
        [0.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
        [1.0, 1.0],
        [0.5, 0.5],
        [1.5, 1.5],
        [-0.5, -0.5],
        [-1.0, -1.0],
    ]);
    let labels = arr1(&[-1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0]);

    let mut model = SVC::new(KernelType::RBF { gamma: 10.0 }, 10.0, 0.001, 10000).unwrap();
    assert!(model.fit(&features.view(), &labels.view()).is_ok());

    // Count training points whose prediction equals the true label.
    let predicted = model.predict(&features.view()).unwrap();
    let hits = predicted
        .iter()
        .zip(labels.iter())
        .filter(|(guess, truth)| guess == truth)
        .count();
    assert!(hits >= 5);
}
/// Exercises the error paths: fitting with 2 feature rows but 3 labels must
/// return `Err`, and after that failed fit both `predict` and
/// `decision_function` are expected to return `Err` as well.
#[test]
fn test_error_handling() {
    let mut model = SVC::default();

    // Two samples against three labels: the shapes disagree.
    let two_samples = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
    let three_labels = arr1(&[1.0, -1.0, 1.0]);
    assert!(model.fit(&two_samples.view(), &three_labels.view()).is_err());

    // The model was never fitted successfully, so inference must be rejected.
    let query = arr2(&[[1.0, 2.0]]);
    assert!(model.predict(&query.view()).is_err());
    assert!(model.decision_function(&query.view()).is_err());
}
/// Constructs models with polynomial, sigmoid and cosine kernels and checks
/// that `get_kernel` returns the same variant and parameters for each.
#[test]
fn test_different_kernels() {
    let poly_kernel = KernelType::Poly {
        degree: 3,
        gamma: 0.1,
        coef0: 0.0,
    };
    let sigmoid_kernel = KernelType::Sigmoid {
        gamma: 0.1,
        coef0: 0.0,
    };

    // Build all three models first; any constructor failure aborts before the checks.
    let poly_model = SVC::new(poly_kernel, 1.0, 0.001, 100).unwrap();
    let sigmoid_model = SVC::new(sigmoid_kernel, 1.0, 0.001, 100).unwrap();
    let cosine_model = SVC::new(KernelType::Cosine, 1.0, 0.001, 100).unwrap();

    if let KernelType::Poly {
        degree,
        gamma,
        coef0,
    } = poly_model.get_kernel()
    {
        assert_eq!(degree, 3);
        assert!((gamma - 0.1).abs() < 1e-10);
        assert!((coef0 - 0.0).abs() < 1e-10);
    } else {
        panic!("Expected polynomial kernel");
    }

    if let KernelType::Sigmoid { gamma, coef0 } = sigmoid_model.get_kernel() {
        assert!((gamma - 0.1).abs() < 1e-10);
        assert!((coef0 - 0.0).abs() < 1e-10);
    } else {
        panic!("Expected sigmoid kernel");
    }

    assert!(
        matches!(cosine_model.get_kernel(), KernelType::Cosine),
        "Expected cosine kernel"
    );
}
/// Checks that a model built with `KernelType::Cosine` reports the cosine
/// kernel back through `get_kernel`.
#[test]
fn test_cosine_kernel_constructor() {
    let model = SVC::new(KernelType::Cosine, 1.0, 0.001, 100).unwrap();

    assert!(
        matches!(model.get_kernel(), KernelType::Cosine),
        "Expected cosine kernel"
    );
}
/// Fits a linear-kernel `SVC` on four labelled points and checks that
/// `decision_function` yields exactly one value for a single query row.
#[test]
fn test_decision_function() {
    let features = arr2(&[[1.0, 2.0], [2.0, 3.0], [-1.0, -2.0], [-2.0, -3.0]]);
    let labels = arr1(&[1.0, 1.0, -1.0, -1.0]);

    let mut model = SVC::new(KernelType::Linear, 1.0, 0.001, 100).unwrap();
    model.fit(&features.view(), &labels.view()).unwrap();

    // One input row in, one decision value out.
    let origin = arr2(&[[0.0, 0.0]]);
    let scores = model.decision_function(&origin.view()).unwrap();
    assert_eq!(scores.len(), 1);
}