use crate::prelude::*;
use std::iter::zip;
/// A binary perceptron classifier trained with the perceptron learning rule.
///
/// Class labels are anything convertible into `&'static str`. During
/// [`Perceptron::fit`] the first distinct label seen is encoded as class `0` and the
/// second as class `1`. At prediction time a non-negative net input selects the
/// label encoded as `1`, and a negative one selects the label encoded as `0`.
#[derive(Debug, Clone)]
pub struct Perceptron {
    /// Learning rate applied to every weight/bias update.
    eta: f64,
    /// Number of full passes over the training data performed by `fit`.
    epochs: u32,
    /// One weight per input feature.
    weights: Vec<f64>,
    /// Bias term added to the weighted sum.
    bias: f64,
    /// Numeric class code (0 or 1) of each sample from the most recent `fit`.
    labels_as_nums: Vec<i32>,
    /// Original label of each sample from the most recent `fit`.
    original_labels: Vec<&'static str>,
    /// The distinct class codes, in encoding order.
    unique_labels_as_nums: Vec<i32>,
    /// The distinct class labels; index `i` holds the label encoded as class `i`.
    unique_original_labels: Vec<&'static str>,
}

impl Perceptron {
    /// Creates an untrained perceptron with all weights and the bias set to zero.
    ///
    /// * `eta` - learning rate.
    /// * `epochs` - number of passes over the training data made by [`Perceptron::fit`].
    /// * `x_features` - number of features in every input row.
    pub fn new(eta: f64, epochs: u32, x_features: usize) -> Self {
        Perceptron {
            eta,
            epochs,
            weights: vec![0.0; x_features],
            bias: 0.0,
            labels_as_nums: Vec::new(),
            original_labels: Vec::new(),
            unique_labels_as_nums: Vec::new(),
            unique_original_labels: Vec::new(),
        }
    }

    /// Trains the perceptron on the rows of `x` with class labels `y`.
    ///
    /// Labels are encoded by first appearance: the first distinct label becomes class
    /// `0` and the second becomes class `1`, regardless of how samples are ordered.
    /// The recorded labels are replaced on every call. Weights and bias are *not*
    /// reset, so repeated calls continue training from the current state.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` have different lengths, or if `y` contains more than
    /// two distinct labels (a perceptron is a binary classifier).
    pub fn fit<X, Y>(&mut self, x: &[Vec<X>], y: &[Y])
    where
        X: Copy + Into<f64>,
        Y: Copy + Eq + Into<&'static str>,
    {
        assert_eq!(
            x.len(),
            y.len(),
            "Perceptron::fit: x has {} rows but y has {} labels",
            x.len(),
            y.len()
        );
        self.encode_labels(y);

        // Counting with the native u32 avoids the wrap-around an `as i32` cast would
        // cause for very large epoch counts.
        for _ in 0..self.epochs {
            for (row, &target) in zip(x, &self.labels_as_nums) {
                // Zero when the prediction is already correct, so only misclassified
                // samples move the decision boundary.
                let update = self.eta * f64::from(target - self.predict_num(row));
                // Update the weights in place rather than building temporary vectors.
                for (weight, &value) in self.weights.iter_mut().zip(row.iter()) {
                    let value: f64 = value.into();
                    *weight += update * value;
                }
                self.bias += update;
            }
        }
    }

    /// Rebuilds the per-sample and distinct-label bookkeeping for `y`. Each distinct
    /// label is encoded by the order of its first appearance.
    ///
    /// # Panics
    ///
    /// Panics when a third distinct label is encountered.
    fn encode_labels<Y>(&mut self, y: &[Y])
    where
        Y: Copy + Into<&'static str>,
    {
        self.labels_as_nums.clear();
        self.original_labels.clear();
        self.unique_original_labels.clear();

        for &label in y {
            let name: &'static str = label.into();
            let code = match self.unique_original_labels.iter().position(|&known| known == name) {
                Some(index) => index,
                None => {
                    assert!(
                        self.unique_original_labels.len() < 2,
                        "Perceptron::fit: expected at most two distinct labels, but found {:?}, {:?} and {:?}",
                        self.unique_original_labels[0],
                        self.unique_original_labels[1],
                        name
                    );
                    self.unique_original_labels.push(name);
                    self.unique_original_labels.len() - 1
                }
            };
            // `code` is 0 or 1 (guarded by the assertion above), so the cast is lossless.
            self.labels_as_nums.push(code as i32);
            self.original_labels.push(name);
        }

        self.unique_labels_as_nums = (0..self.unique_original_labels.len() as i32).collect();
    }

    /// Returns the weighted sum of `x_row` with the model weights, plus the bias.
    ///
    /// `x_row` is expected to contain exactly one value per weight. Debug builds
    /// assert this. In release builds any surplus values are ignored.
    pub fn net_input<X>(&self, x_row: &[X]) -> f64
    where
        X: Copy + Into<f64>,
    {
        debug_assert_eq!(
            x_row.len(),
            self.weights.len(),
            "Perceptron::net_input: row has {} features but the model expects {}",
            x_row.len(),
            self.weights.len()
        );
        let weighted_sum: f64 = zip(x_row, &self.weights)
            .map(|(&value, &weight)| Into::<f64>::into(value) * weight)
            .sum();
        weighted_sum + self.bias
    }

    /// Returns the numeric class for `x_row`: `1` when the net input is non-negative,
    /// otherwise `0`.
    pub fn predict_num<X>(&self, x_row: &[X]) -> i32
    where
        X: Copy + Into<f64>,
    {
        i32::from(self.net_input(x_row) >= 0.0)
    }

    /// Returns the original label predicted for `x_row`.
    ///
    /// If training saw only one distinct label, that label is always returned.
    ///
    /// # Panics
    ///
    /// Panics if called before [`Perceptron::fit`] has recorded any labels.
    pub fn predict<X>(&self, x_row: &[X]) -> &'static str
    where
        X: Copy + Into<f64>,
    {
        let code = usize::from(self.net_input(x_row) >= 0.0);
        self.unique_original_labels
            .get(code)
            .or_else(|| self.unique_original_labels.first())
            .copied()
            .expect("Perceptron::predict called before fit (no labels recorded)")
    }

    /// Predicts the original label of every row in `x_rows`, in order.
    ///
    /// # Panics
    ///
    /// Panics if `x_rows` is non-empty and [`Perceptron::fit`] has not recorded any labels.
    pub fn predict_multiple<X>(&self, x_rows: &[Vec<X>]) -> Vec<&'static str>
    where
        X: Copy + Into<f64>,
    {
        x_rows.iter().map(|row| self.predict(row)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-feature samples from the Iris data set: fifteen setosa followed by
    /// fifteen versicolor.
    const SAMPLES: [(f64, f64, &str); 30] = [
        (5.1, 3.5, "setosa"),
        (4.9, 3.0, "setosa"),
        (4.7, 3.2, "setosa"),
        (4.6, 3.1, "setosa"),
        (5.0, 3.6, "setosa"),
        (5.4, 3.9, "setosa"),
        (4.6, 3.4, "setosa"),
        (5.0, 3.4, "setosa"),
        (4.4, 2.9, "setosa"),
        (4.9, 3.1, "setosa"),
        (5.4, 3.7, "setosa"),
        (4.8, 3.4, "setosa"),
        (4.8, 3.0, "setosa"),
        (4.3, 3.0, "setosa"),
        (5.8, 4.0, "setosa"),
        (7.0, 3.2, "versicolor"),
        (6.4, 3.2, "versicolor"),
        (6.9, 3.1, "versicolor"),
        (5.5, 2.3, "versicolor"),
        (6.5, 2.8, "versicolor"),
        (5.7, 2.8, "versicolor"),
        (6.3, 3.3, "versicolor"),
        (4.9, 2.4, "versicolor"),
        (6.6, 2.9, "versicolor"),
        (5.2, 2.7, "versicolor"),
        (5.0, 2.0, "versicolor"),
        (5.9, 3.0, "versicolor"),
        (6.0, 2.2, "versicolor"),
        (6.1, 2.9, "versicolor"),
        (5.6, 2.9, "versicolor"),
    ];

    /// After training, the perceptron must reproduce every training label, both one
    /// row at a time and in a single batch.
    #[test]
    fn perceptron_learns_iris_subset() {
        let (features, labels): (Vec<Vec<f64>>, Vec<&'static str>) = SAMPLES
            .iter()
            .map(|&(first, second, species)| (vec![first, second], species))
            .unzip();

        let mut model = Perceptron::new(0.1, 100, 2);
        model.fit(&features, &labels);

        features
            .iter()
            .zip(labels.iter())
            .for_each(|(row, expected)| assert_eq!(expected, &model.predict(row)));

        let batch = model.predict_multiple(&features);
        assert_eq!(&batch, &labels);
    }
}