pub struct SupportVectorMachine { /* private fields */ }
Implementations§
Source§impl SupportVectorMachine
impl SupportVectorMachine
Sourcepub fn new(
class_num: usize,
feature_num: usize,
lr: f32,
bsz: Option<usize>,
lambda: f32,
) -> Self
pub fn new( class_num: usize, feature_num: usize, lr: f32, bsz: Option<usize>, lambda: f32, ) -> Self
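Support Vector Machine for binary classification.
Only a linear kernel is supported; the model is optimized with stochastic gradient descent (SGD) on the hinge loss.
Each update minimizes `lambda * R(w) + (1 / batch_size) * hinge_loss`, where `R(w)` is the regularization term on the weights (also called the margin-width term or the so-called structural loss).
- `bsz`: batch size; when `None`, it defaults to `len(dataset)` (i.e. full-batch updates).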
Examples found in repository?
examples/05_svm.rs (line 22)
6fn main() -> std::io::Result<()> {
7 let datas = vec![
8 vec![1.0, 3.0],
9 vec![2.0, 3.0],
10 vec![1.0, 2.0],
11 vec![4.0, 0.0],
12 vec![3.0, 0.0],
13 vec![3.0, -1.0],
14 vec![3.0, 0.5],
15 ];
16 let labels = vec![0usize, 0, 0, 1, 1, 1, 1];
17 let mut label_map = HashMap::new();
18 label_map.insert(0, "up left".to_string());
19 label_map.insert(1, "down right".to_string());
20 let dataset = Dataset::new(datas, labels, Some(label_map));
21
22 let mut model = SupportVectorMachine::new(2, dataset.feature_len(), 1e-1, None, 1.0);
23 model.train(dataset.clone(), 1000, SVMLoss::Hinge, true, true)?;
24
25 let (correct, acc) = evaluate(&dataset, &model);
26 println!("correct {correct}/{} acc {acc}", dataset.len());
27
28 Ok(())
29}
Sourcepub fn train(
&mut self,
dataset: Dataset<usize>,
epoch: usize,
loss: SVMLoss,
early_stop: bool,
verbose: bool,
) -> Result<()>
pub fn train( &mut self, dataset: Dataset<usize>, epoch: usize, loss: SVMLoss, early_stop: bool, verbose: bool, ) -> Result<()>
Trains the SVM model.
- `verbose`: whether to print training progress information.
Examples found in repository?
examples/05_svm.rs (line 23)
6fn main() -> std::io::Result<()> {
7 let datas = vec![
8 vec![1.0, 3.0],
9 vec![2.0, 3.0],
10 vec![1.0, 2.0],
11 vec![4.0, 0.0],
12 vec![3.0, 0.0],
13 vec![3.0, -1.0],
14 vec![3.0, 0.5],
15 ];
16 let labels = vec![0usize, 0, 0, 1, 1, 1, 1];
17 let mut label_map = HashMap::new();
18 label_map.insert(0, "up left".to_string());
19 label_map.insert(1, "down right".to_string());
20 let dataset = Dataset::new(datas, labels, Some(label_map));
21
22 let mut model = SupportVectorMachine::new(2, dataset.feature_len(), 1e-1, None, 1.0);
23 model.train(dataset.clone(), 1000, SVMLoss::Hinge, true, true)?;
24
25 let (correct, acc) = evaluate(&dataset, &model);
26 println!("correct {correct}/{} acc {acc}", dataset.len());
27
28 Ok(())
29}
Trait Implementations§
Auto Trait Implementations§
impl Freeze for SupportVectorMachine
impl RefUnwindSafe for SupportVectorMachine
impl Send for SupportVectorMachine
impl Sync for SupportVectorMachine
impl Unpin for SupportVectorMachine
impl UnwindSafe for SupportVectorMachine
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more