1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
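//! Fast online learning with an averaged perceptron.
//!
//! The model is trained one example at a time as data arrives. Predictions use
//! the average of the weight vectors the model held during training, which
//! behaves like an ensemble of all the intermediate models.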
//!
//! # Background
//! The [perceptron](https://en.wikipedia.org/wiki/Perceptron) is a
//! relatively simple and fast machine learning algorithm. It has
//! * no hyperparameters to tune,
//! * good generalization when its weights are averaged (an implicit ensemble), and
//! * good classification accuracy on linearly separable problems. It can only
//!   learn a linear decision boundary (e.g. AND or OR, but not XOR).
//!
//!
//! # Usage
//!
//! Train on the AND truth table for 99 epochs, then check every prediction:
//!
//! ```rust
//! use perceptron::*;
//!
//! let examples = vec!((vec!(-1.0, 1.0), false), (vec!(-1.0, -1.0), false), (vec!(1.0, -1.0), false), (vec!(1.0, 1.0), true));
//! let perceptron = (1..100).fold(Perceptron::new(2),
//!     |pepoch, _epoch| examples.iter().fold(pepoch,
//!     |pexample, example| pexample.train(example.0.clone(), example.1).unwrap()));
//!
//! println!("{:?}", perceptron);
//!
//! assert_eq!(perceptron.predict(examples[0].0.clone()), examples[0].1);
//! assert_eq!(perceptron.predict(examples[1].0.clone()), examples[1].1);
//! assert_eq!(perceptron.predict(examples[2].0.clone()), examples[2].1);
//! assert_eq!(perceptron.predict(examples[3].0.clone()), examples[3].1);
//! ```
//!

use serde::{Serialize, Deserialize};

/// A binary perceptron with a bias term whose predictions use averaged weights.
///
/// Training does not mutate the model: [`Perceptron::train`] takes `&self` and
/// returns a new `Perceptron`. The type derives serde `Serialize`/`Deserialize`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Perceptron {
    /// Current weight vector. The last entry is the bias weight; its input is
    /// always the constant 1.0 appended to each feature vector.
    weights: Vec<f32>,
    /// Running sum of every weight update, each scaled by the number of examples
    /// seen before that update. `predict` combines it with `weights` and
    /// `n_examples` to obtain the averaged weights.
    s_weights: Vec<f32>,
    /// Number of training examples processed so far, counting both correctly
    /// classified examples and mistakes.
    n_examples: u32,
}

impl Perceptron {
    /// Creates an untrained perceptron for inputs with `n_feats` features.
    ///
    /// One extra weight is reserved for the bias term, so both weight vectors
    /// have length `n_feats + 1`. All weights start at zero.
    pub fn new(n_feats: usize) -> Perceptron {
        Perceptron {
            weights: vec![0.0; n_feats + 1],
            s_weights: vec![0.0; n_feats + 1],
            n_examples: 0,
        }
    }

    /// Runs one online training step on a single labelled example.
    ///
    /// `self` is not modified; the updated model is returned. If the current
    /// weights already classify `example` correctly with a strictly positive
    /// margin, only the example counter advances. Otherwise the perceptron
    /// update `w += y * x` is applied, where `y` is `+1.0` for `true` and
    /// `-1.0` for `false` and `x` is `example` with a trailing 1.0 bias input.
    /// The same update is also folded into the running sum used for averaging.
    ///
    /// Returns `None` if `example.len()` does not match the number of features
    /// the model was created with.
    pub fn train(&self, example: Vec<f32>, label: bool) -> Option<Perceptron> {
        if example.len() + 1 != self.weights.len() {
            return None;
        }

        let target: f32 = if label { 1.0 } else { -1.0 };
        let n_examples = self.n_examples + 1;

        if target * Self::biased_dot(self.weights.iter().copied(), &example) > 0.0 {
            return Some(Perceptron {
                weights: self.weights.clone(),
                s_weights: self.s_weights.clone(),
                n_examples,
            });
        }

        // Each update is scaled by the number of examples seen *before* it.
        // This lets `predict` recover the mean of the per-example weight
        // vectors as `weights - s_weights / n_examples`, without storing every
        // intermediate weight vector.
        let seen = self.n_examples as f32;
        let inputs = example.iter().copied().chain(std::iter::once(1.0));
        let weights = self
            .weights
            .iter()
            .zip(inputs.clone())
            .map(|(w, x)| w + target * x)
            .collect();
        let s_weights = self
            .s_weights
            .iter()
            .zip(inputs)
            .map(|(s, x)| s + seen * (target * x))
            .collect();

        Some(Perceptron { weights, s_weights, n_examples })
    }

    /// Predicts the label of `feats` using the averaged weights.
    ///
    /// The averaged weights are `weights - s_weights / n_examples`, the mean of
    /// the weight vectors held after each training example. A score of exactly
    /// zero is classified as `true`.
    ///
    /// For a model that has not processed any example, the average is undefined:
    /// it would divide by zero and yield NaN. In that case the current weights
    /// are used instead.
    ///
    /// # Panics
    ///
    /// Panics if `feats.len()` does not match the number of features the model
    /// was created with.
    pub fn predict(&self, feats: Vec<f32>) -> bool {
        assert_eq!(
            feats.len() + 1,
            self.weights.len(),
            "feature vector length does not match the model"
        );

        let score = if self.n_examples == 0 {
            Self::biased_dot(self.weights.iter().copied(), &feats)
        } else {
            let n = self.n_examples as f32;
            Self::biased_dot(
                self.weights.iter().zip(&self.s_weights).map(|(w, s)| w - s / n),
                &feats,
            )
        };
        score >= 0.0
    }

    /// Returns the dot product of `weights` with `feats` extended by a trailing
    /// constant 1.0 bias input.
    fn biased_dot(weights: impl Iterator<Item = f32>, feats: &[f32]) -> f32 {
        weights
            .zip(feats.iter().copied().chain(std::iter::once(1.0)))
            .map(|(w, x)| w * x)
            .sum()
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    /// Passes over the training set (the same 99 epochs as `1..100`).
    const EPOCHS: usize = 99;

    /// Trains a fresh two-feature perceptron on `examples`, presenting them in
    /// order once per epoch.
    fn train_epochs(examples: &[(Vec<f32>, bool)]) -> Perceptron {
        (0..EPOCHS).fold(Perceptron::new(2), |model, _epoch| {
            examples.iter().fold(model, |model, (feats, label)| {
                model
                    .train(feats.clone(), *label)
                    .expect("every example has two features")
            })
        })
    }

    /// Trains on `examples` and asserts that the model reproduces every label.
    fn assert_learns(examples: &[(Vec<f32>, bool)]) {
        let perceptron = train_epochs(examples);
        println!("{:?}", perceptron);

        for (feats, label) in examples {
            assert_eq!(perceptron.predict(feats.clone()), *label, "input {:?}", feats);
        }
    }

    #[test]
    fn and_works() {
        assert_learns(&[
            (vec![-1.0, 1.0], false),
            (vec![-1.0, -1.0], false),
            (vec![1.0, -1.0], false),
            (vec![1.0, 1.0], true),
        ]);
    }

    #[test]
    fn or_works() {
        assert_learns(&[
            (vec![-1.0, 1.0], true),
            (vec![-1.0, -1.0], false),
            (vec![1.0, -1.0], true),
            (vec![1.0, 1.0], true),
        ]);
    }

    #[test]
    fn train_rejects_wrong_length() {
        let perceptron = Perceptron::new(2);
        assert!(perceptron.train(vec![1.0], true).is_none());
        assert!(perceptron.train(vec![1.0, -1.0, 1.0], false).is_none());
    }
}