/// Problem formulation solved by the trainer.
///
/// The explicit `i32` discriminants match libsvm's `svm_type` codes
/// (`C_SVC` = 0 through `NU_SVR` = 4), so a value can be passed across a
/// libsvm-compatible boundary with a plain `as i32` cast.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum SvmType {
    /// C-support vector classification; uses the penalty `c`.
    CSvc = 0,
    /// nu-support vector classification; uses `nu`.
    NuSvc = 1,
    /// One-class SVM; uses `nu`.
    OneClass = 2,
    /// epsilon-support vector regression; uses `c` and `p`.
    EpsilonSvr = 3,
    /// nu-support vector regression; uses `c` and `nu`.
    NuSvr = 4,
}
19
/// Kernel function used by the model.
///
/// The explicit `i32` discriminants match libsvm's `kernel_type` codes
/// (`LINEAR` = 0 through `PRECOMPUTED` = 4). The formulas noted on each
/// variant are libsvm's definitions.
/// NOTE(review): confirm this crate's kernel code computes the same values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum KernelType {
    /// libsvm: `u·v`.
    Linear = 0,
    /// libsvm: `(gamma·u·v + coef0)^degree`.
    /// `SvmParameter::validate` rejects a negative `gamma` or `degree`.
    Polynomial = 1,
    /// libsvm: `exp(-gamma·|u - v|^2)`.
    /// `SvmParameter::validate` rejects a negative `gamma`.
    Rbf = 2,
    /// libsvm: `tanh(gamma·u·v + coef0)`.
    /// `SvmParameter::validate` rejects a negative `gamma`.
    Sigmoid = 3,
    /// Kernel values are supplied by the caller rather than computed from
    /// feature vectors (libsvm semantics; assumed here).
    Precomputed = 4,
}
38
/// One sparse feature entry: a feature `index` paired with its `value`.
///
/// A sample is a `Vec<SvmNode>` (see `SvmProblem::instances` and
/// `SvmModel::sv`). NOTE(review): libsvm uses 1-based, ascending indices;
/// confirm whether this crate relies on the same ordering.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Feature index.
    pub index: i32,
    /// Feature value at `index`.
    pub value: f64,
}
52
53#[derive(Debug, Clone, PartialEq)]
55pub struct SvmProblem {
56 pub labels: Vec<f64>,
58 pub instances: Vec<Vec<SvmNode>>,
60}
61
62#[derive(Debug, Clone, PartialEq)]
66pub struct SvmParameter {
67 pub svm_type: SvmType,
69 pub kernel_type: KernelType,
71 pub degree: i32,
73 pub gamma: f64,
76 pub coef0: f64,
78 pub cache_size: f64,
80 pub eps: f64,
82 pub c: f64,
84 pub weight: Vec<(i32, f64)>,
86 pub nu: f64,
88 pub p: f64,
90 pub shrinking: bool,
92 pub probability: bool,
94}
95
96impl Default for SvmParameter {
97 fn default() -> Self {
98 Self {
99 svm_type: SvmType::CSvc,
100 kernel_type: KernelType::Rbf,
101 degree: 3,
102 gamma: 0.0, coef0: 0.0,
104 cache_size: 100.0,
105 eps: 0.001,
106 c: 1.0,
107 weight: Vec::new(),
108 nu: 0.5,
109 p: 0.1,
110 shrinking: true,
111 probability: false,
112 }
113 }
114}
115
116impl SvmParameter {
117 pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123 use crate::error::SvmError;
124
125 if matches!(
127 self.kernel_type,
128 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129 ) && self.gamma < 0.0
130 {
131 return Err(SvmError::InvalidParameter("gamma < 0".into()));
132 }
133
134 if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136 return Err(SvmError::InvalidParameter(
137 "degree of polynomial kernel < 0".into(),
138 ));
139 }
140
141 if self.cache_size <= 0.0 {
142 return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143 }
144
145 if self.eps <= 0.0 {
146 return Err(SvmError::InvalidParameter("eps <= 0".into()));
147 }
148
149 if matches!(
151 self.svm_type,
152 SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153 ) && self.c <= 0.0
154 {
155 return Err(SvmError::InvalidParameter("C <= 0".into()));
156 }
157
158 if matches!(
160 self.svm_type,
161 SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162 ) && (self.nu <= 0.0 || self.nu > 1.0)
163 {
164 return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165 }
166
167 if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169 return Err(SvmError::InvalidParameter("p < 0".into()));
170 }
171
172 Ok(())
173 }
174}
175
176pub fn check_parameter(
180 problem: &SvmProblem,
181 param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183 use crate::error::SvmError;
184
185 param.validate()?;
187
188 if param.svm_type == SvmType::NuSvc {
191 let mut class_counts: Vec<(i32, usize)> = Vec::new();
192 for &y in &problem.labels {
193 let label = y as i32;
194 if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
195 entry.1 += 1;
196 } else {
197 class_counts.push((label, 1));
198 }
199 }
200
201 for (i, &(_, n1)) in class_counts.iter().enumerate() {
202 for &(_, n2) in &class_counts[i + 1..] {
203 if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
204 return Err(SvmError::InvalidParameter(
205 "specified nu is infeasible".into(),
206 ));
207 }
208 }
209 }
210 }
211
212 Ok(())
213}
214
215#[derive(Debug, Clone, PartialEq)]
219pub struct SvmModel {
220 pub param: SvmParameter,
222 pub nr_class: usize,
224 pub sv: Vec<Vec<SvmNode>>,
226 pub sv_coef: Vec<Vec<f64>>,
229 pub rho: Vec<f64>,
231 pub prob_a: Vec<f64>,
234 pub prob_b: Vec<f64>,
237 pub prob_density_marks: Vec<f64>,
239 pub sv_indices: Vec<usize>,
241 pub label: Vec<i32>,
243 pub n_sv: Vec<usize>,
245}
246
#[cfg(test)]
mod tests {
    //! Unit tests for parameter validation and nu-SVC feasibility.

    use super::*;

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn c_ignored_for_nu_svc() {
        // C is only checked for C-SVC, epsilon-SVR and nu-SVR.
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            c: 0.0,
            ..Default::default()
        };
        p.validate().unwrap();
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two balanced classes of 3 samples each.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // Exact boundary: nu * (3 + 3) / 2 == min(3, 3) == 3. Only a strictly
        // greater value is infeasible, so nu = 1.0 must still be accepted.
        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.0,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Class sizes 5 and 1: 0.5 * 6 / 2 = 1.5 > 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("infeasible"));
    }

    #[test]
    fn feasibility_check_only_applies_to_nu_svc() {
        // Same skewed class sizes as above, but C-SVC has no nu bound.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &param).unwrap();
    }
}