use crate::error::SvmError;
use crate::types::{KernelType, SvmParameter, SvmType};
/// Fluent builder for [`SvmParameter`].
///
/// Every setter consumes the builder and returns it, so calls can be chained.
/// Nothing is checked until [`SvmParameterBuilder::build`] is called, which
/// validates the accumulated parameters.
#[derive(Debug, Clone)]
#[must_use = "builder setters return a new builder; call `build()` to obtain the SvmParameter"]
pub struct SvmParameterBuilder {
    /// Parameters accumulated so far; validated only when `build` is called.
    param: SvmParameter,
}
impl Default for SvmParameterBuilder {
fn default() -> Self {
Self::new()
}
}
impl SvmParameterBuilder {
    /// Creates a builder whose parameters start out as `SvmParameter::default()`.
    pub fn new() -> Self {
        Self {
            param: SvmParameter::default(),
        }
    }

    /// Sets the `svm_type` field.
    pub fn svm_type(mut self, svm_type: SvmType) -> Self {
        self.param.svm_type = svm_type;
        self
    }

    /// Sets the `kernel_type` field.
    pub fn kernel_type(mut self, kernel_type: KernelType) -> Self {
        self.param.kernel_type = kernel_type;
        self
    }

    /// Sets the kernel `degree`. Not validated until [`build`](Self::build).
    pub fn degree(mut self, degree: i32) -> Self {
        self.param.degree = degree;
        self
    }

    /// Sets the kernel `gamma`. Not validated until [`build`](Self::build).
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.param.gamma = gamma;
        self
    }

    /// Sets the kernel `coef0` term.
    pub fn coef0(mut self, coef0: f64) -> Self {
        self.param.coef0 = coef0;
        self
    }

    /// Sets the `c` parameter.
    pub fn c(mut self, c: f64) -> Self {
        self.param.c = c;
        self
    }

    /// Sets the `nu` parameter.
    pub fn nu(mut self, nu: f64) -> Self {
        self.param.nu = nu;
        self
    }

    /// Sets the `p` parameter.
    pub fn p(mut self, p: f64) -> Self {
        self.param.p = p;
        self
    }

    /// Sets `cache_size`. Not validated until [`build`](Self::build).
    pub fn cache_size(mut self, cache_size: f64) -> Self {
        self.param.cache_size = cache_size;
        self
    }

    /// Sets the `eps` tolerance. Not validated until [`build`](Self::build).
    pub fn eps(mut self, eps: f64) -> Self {
        self.param.eps = eps;
        self
    }

    /// Sets the `shrinking` flag.
    pub fn shrinking(mut self, shrinking: bool) -> Self {
        self.param.shrinking = shrinking;
        self
    }

    /// Sets the `probability` flag.
    pub fn probability(mut self, probability: bool) -> Self {
        self.param.probability = probability;
        self
    }

    /// Appends one `(label, weight)` pair to the weight list.
    ///
    /// Pairs are kept in call order. Calling this twice with the same label
    /// appends a second entry; it does not replace the first one.
    pub fn weight(mut self, label: i32, weight: f64) -> Self {
        self.param.weight.push((label, weight));
        self
    }

    /// Replaces the whole weight list with `weights`.
    ///
    /// Any pairs added earlier via [`weight`](Self::weight) are discarded.
    /// Accepts any iterable of `(label, weight)` pairs — a `Vec`, an array, or
    /// an iterator — so existing `Vec` callers keep working unchanged.
    pub fn weights<I>(mut self, weights: I) -> Self
    where
        I: IntoIterator<Item = (i32, f64)>,
    {
        // Collecting from `vec::IntoIter` reuses the Vec's allocation, so
        // passing a `Vec` costs no extra copy compared to plain assignment.
        self.param.weight = weights.into_iter().collect();
        self
    }

    /// Validates the accumulated parameters and returns them.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SvmParameter::validate`, for example
    /// `SvmError::InvalidParameter` for a negative `gamma`, a zero `eps` or
    /// `cache_size`, or a negative `degree` with [`KernelType::Polynomial`].
    pub fn build(self) -> Result<SvmParameter, SvmError> {
        self.param.validate()?;
        Ok(self.param)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_method_build_equals_parameter_default() {
        assert_eq!(
            SvmParameterBuilder::new().build().unwrap(),
            SvmParameter::default()
        );
    }

    #[test]
    fn default_builder_matches_new() {
        assert_eq!(
            SvmParameterBuilder::default().param,
            SvmParameterBuilder::new().param
        );
    }

    // The expected value is built by mutating a default instance field by
    // field, so it mirrors the builder chain line for line. Clippy prefers
    // struct-update syntax here, but that is only a readability preference,
    // and the one-assignment-per-setter layout is the point of this test.
    #[test]
    #[allow(clippy::field_reassign_with_default)]
    fn happy_path_equals_field_assignment() {
        let built = SvmParameterBuilder::new()
            .svm_type(SvmType::EpsilonSvr)
            .kernel_type(KernelType::Sigmoid)
            .degree(2)
            .gamma(0.25)
            .coef0(1.5)
            .c(2.0)
            .nu(0.25)
            .p(0.2)
            .cache_size(256.0)
            .eps(0.0001)
            .shrinking(false)
            .probability(true)
            .weight(1, 3.0)
            .weight(-1, 0.5)
            .build()
            .unwrap();
        let mut assigned = SvmParameter::default();
        assigned.svm_type = SvmType::EpsilonSvr;
        assigned.kernel_type = KernelType::Sigmoid;
        assigned.degree = 2;
        assigned.gamma = 0.25;
        assigned.coef0 = 1.5;
        assigned.c = 2.0;
        assigned.nu = 0.25;
        assigned.p = 0.2;
        assigned.cache_size = 256.0;
        assigned.eps = 0.0001;
        assigned.shrinking = false;
        assigned.probability = true;
        assigned.weight = vec![(1, 3.0), (-1, 0.5)];
        assert_eq!(built, assigned);
    }

    #[test]
    fn weight_appends_in_call_order_without_deduplicating() {
        // Inspect the builder directly so this does not depend on whatever
        // `validate` thinks about repeated labels.
        let builder = SvmParameterBuilder::new().weight(1, 2.0).weight(1, 3.0);
        assert_eq!(builder.param.weight, vec![(1, 2.0), (1, 3.0)]);
    }

    #[test]
    fn weights_replaces_weight_list() {
        let built = SvmParameterBuilder::new()
            .weight(1, 2.0)
            .weights(vec![(3, 4.0), (5, 6.0)])
            .build()
            .unwrap();
        assert_eq!(built.weight, vec![(3, 4.0), (5, 6.0)]);
    }

    #[test]
    fn negative_gamma_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().gamma(-1.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn non_positive_eps_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().eps(0.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn non_positive_cache_size_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().cache_size(0.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn negative_polynomial_degree_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new()
                .kernel_type(KernelType::Polynomial)
                .degree(-1)
                .build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }
}