1use crate::error::SvmError;
24use crate::types::{KernelType, SvmParameter, SvmType};
25
26#[derive(Debug, Clone)]
33pub struct SvmParameterBuilder {
34 param: SvmParameter,
35}
36
37impl Default for SvmParameterBuilder {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl SvmParameterBuilder {
44 pub fn new() -> Self {
46 Self {
47 param: SvmParameter::default(),
48 }
49 }
50
51 pub fn svm_type(mut self, svm_type: SvmType) -> Self {
53 self.param.svm_type = svm_type;
54 self
55 }
56
57 pub fn kernel_type(mut self, kernel_type: KernelType) -> Self {
59 self.param.kernel_type = kernel_type;
60 self
61 }
62
63 pub fn degree(mut self, degree: i32) -> Self {
65 self.param.degree = degree;
66 self
67 }
68
69 pub fn gamma(mut self, gamma: f64) -> Self {
72 self.param.gamma = gamma;
73 self
74 }
75
76 pub fn coef0(mut self, coef0: f64) -> Self {
78 self.param.coef0 = coef0;
79 self
80 }
81
82 pub fn c(mut self, c: f64) -> Self {
84 self.param.c = c;
85 self
86 }
87
88 pub fn nu(mut self, nu: f64) -> Self {
90 self.param.nu = nu;
91 self
92 }
93
94 pub fn p(mut self, p: f64) -> Self {
96 self.param.p = p;
97 self
98 }
99
100 pub fn cache_size(mut self, cache_size: f64) -> Self {
102 self.param.cache_size = cache_size;
103 self
104 }
105
106 pub fn eps(mut self, eps: f64) -> Self {
108 self.param.eps = eps;
109 self
110 }
111
112 pub fn shrinking(mut self, shrinking: bool) -> Self {
114 self.param.shrinking = shrinking;
115 self
116 }
117
118 pub fn probability(mut self, probability: bool) -> Self {
120 self.param.probability = probability;
121 self
122 }
123
124 pub fn weight(mut self, label: i32, weight: f64) -> Self {
126 self.param.weight.push((label, weight));
127 self
128 }
129
130 pub fn weights(mut self, weights: Vec<(i32, f64)>) -> Self {
132 self.param.weight = weights;
133 self
134 }
135
136 pub fn build(self) -> Result<SvmParameter, SvmError> {
142 self.param.validate()?;
143 Ok(self.param)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_method_build_equals_parameter_default() {
        assert_eq!(
            SvmParameterBuilder::new().build().unwrap(),
            SvmParameter::default()
        );
    }

    #[test]
    fn default_builder_equals_new_builder() {
        assert_eq!(
            SvmParameterBuilder::default().build().unwrap(),
            SvmParameterBuilder::new().build().unwrap()
        );
    }

    #[test]
    // Deliberately spelled as field-by-field assignment: each line below pairs
    // one builder setter with the field it must write, which is the property
    // under test. Clippy's suggested struct-literal form would hide that.
    #[allow(clippy::field_reassign_with_default)]
    fn happy_path_equals_field_assignment() {
        let built = SvmParameterBuilder::new()
            .svm_type(SvmType::EpsilonSvr)
            .kernel_type(KernelType::Sigmoid)
            .degree(2)
            .gamma(0.25)
            .coef0(1.5)
            .c(2.0)
            .nu(0.25)
            .p(0.2)
            .cache_size(256.0)
            .eps(0.0001)
            .shrinking(false)
            .probability(true)
            .weight(1, 3.0)
            .weight(-1, 0.5)
            .build()
            .unwrap();

        let mut assigned = SvmParameter::default();
        assigned.svm_type = SvmType::EpsilonSvr;
        assigned.kernel_type = KernelType::Sigmoid;
        assigned.degree = 2;
        assigned.gamma = 0.25;
        assigned.coef0 = 1.5;
        assigned.c = 2.0;
        assigned.nu = 0.25;
        assigned.p = 0.2;
        assigned.cache_size = 256.0;
        assigned.eps = 0.0001;
        assigned.shrinking = false;
        assigned.probability = true;
        assigned.weight = vec![(1, 3.0), (-1, 0.5)];

        assert_eq!(built, assigned);
    }

    #[test]
    fn weights_replaces_weight_list() {
        let built = SvmParameterBuilder::new()
            .weight(1, 2.0)
            .weights(vec![(3, 4.0), (5, 6.0)])
            .build()
            .unwrap();

        assert_eq!(built.weight, vec![(3, 4.0), (5, 6.0)]);
    }

    #[test]
    fn weight_appends_after_weights() {
        let built = SvmParameterBuilder::new()
            .weights(vec![(3, 4.0)])
            .weight(5, 6.0)
            .build()
            .unwrap();

        assert_eq!(built.weight, vec![(3, 4.0), (5, 6.0)]);
    }

    #[test]
    fn later_setter_call_overrides_earlier_one() {
        let built = SvmParameterBuilder::new().c(1.0).c(3.0).build().unwrap();

        assert_eq!(built.c, 3.0);
    }

    #[test]
    fn negative_gamma_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().gamma(-1.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn non_positive_eps_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().eps(0.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn non_positive_cache_size_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new().cache_size(0.0).build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }

    #[test]
    fn negative_polynomial_degree_rejected_by_build() {
        assert!(matches!(
            SvmParameterBuilder::new()
                .kernel_type(KernelType::Polynomial)
                .degree(-1)
                .build(),
            Err(SvmError::InvalidParameter(_))
        ));
    }
}