use mathru::algebra::linear::{Vector, Matrix};
use mathru::statistics::distrib::{Normal, Uniform, Distribution};
use mathru::optimization::Gradient;
use tinguely::classification::LogisticRegression;
use tinguely::{SupervisedLearn};
use plotters::prelude::*;
/// Generates a synthetic one-feature binary-classification data set.
///
/// For each of `samples` draws, a feature `x` is sampled from `Uniform::new(0.0, 5.0)`
/// and a noisy linear response `a + b * x + e` is computed, with `e` sampled from
/// `Normal::new(0.0, sigma)`. The response is then thresholded: values strictly
/// greater than `2.5` become label `1.0`, everything else becomes `0.0`.
///
/// NOTE(review): confirm in the mathru docs whether `Normal::new`'s second
/// argument is a standard deviation or a variance; the name `sigma` suggests
/// the former.
///
/// # Returns
/// A `(samples x 1)` feature matrix and a column vector of `samples` labels,
/// with row `i` of the matrix corresponding to entry `i` of the vector.
///
/// # Panics
/// Panics if `samples < 2` or if `sigma` is negative.
fn generate_data(a: f64, b: f64, samples: usize, sigma: f64) -> (Matrix<f64>, Vector<f64>) {
    // The message states ">= 2"; the check previously used `> 2` and wrongly rejected 2.
    assert!(samples >= 2, "Number of samples must be greater or equal to two.");
    assert!(sigma >= 0.0f64, "Noise must be non-negative.");

    let uniform: Uniform<f64> = Uniform::new(0.0, 5.0);
    let normal: Normal<f64> = Normal::new(0.0f64, sigma);

    // Draw the feature first and then the noise on every iteration, matching the
    // original sampling order.
    let (features, responses): (Vec<f64>, Vec<f64>) = (0..samples)
        .map(|_| {
            let x: f64 = uniform.random();
            let e: f64 = normal.random();
            (x, a + b * x + e)
        })
        .unzip();

    let x: Matrix<f64> = Matrix::new(samples, 1, features);

    // Turn the continuous response into a binary class label.
    let labels: Vector<f64> = Vector::new_column(samples, responses)
        .apply(&|y| -> f64 { if *y > 2.5 { 1.0 } else { 0.0 } });

    (x, labels)
}
/// Trains a logistic-regression classifier on synthetic data and plots the result.
///
/// The chart is written to `./figures/logistic_regression.png`. It shows the
/// model's predictions on the training features as red crosses and the true
/// labels as blue crosses, both plotted against the single feature `x`.
fn main()
{
    let (x, y): (Matrix<f64>, Vector<f64>) = generate_data(1.0, 3.0 / 5.0, 200, 0.05);

    // NOTE(review): presumably (learning rate, iteration count); confirm against
    // mathru's `Gradient::new`.
    let optimizer: Gradient<f64> = Gradient::new(0.2, 100);
    let mut model: LogisticRegression<f64> = LogisticRegression::new(optimizer);
    model.train(&x, &y);
    let y_hat: Vector<f64> = model.predict(&x).unwrap();

    // Make sure the output folder exists so the PNG can be written on a fresh checkout.
    std::fs::create_dir_all("./figures").expect("failed to create ./figures output directory");

    let root_area = BitMapBackend::new("./figures/logistic_regression.png", (600, 400))
        .into_drawing_area();
    root_area.fill(&WHITE).unwrap();

    let mut ctx = ChartBuilder::on(&root_area)
        .margin(20)
        .set_label_area_size(LabelAreaPosition::Left, 40)
        .set_label_area_size(LabelAreaPosition::Bottom, 40)
        .build_cartesian_2d(0.0..5.0, -0.2..1.2)
        .unwrap();

    ctx.configure_mesh()
        .x_desc("x")
        .y_desc("y")
        .axis_desc_style(("sans-serif", 15).into_font())
        .draw()
        .unwrap();

    // Model predictions (red).
    ctx.draw_series(
        x.row_into_iter()
            .zip(y_hat.iter())
            .map(|(row, pred)| Cross::new((*row.get(0), *pred), 2, RED.filled())),
    )
    .unwrap();

    // Ground-truth labels (blue).
    ctx.draw_series(
        x.row_into_iter()
            .zip(y.iter())
            .map(|(row, label)| Cross::new((*row.get(0), *label), 2, BLUE.filled())),
    )
    .unwrap();

    // Flush the image explicitly. Relying on Drop would discard any I/O error
    // raised while writing the file.
    root_area
        .present()
        .expect("failed to write ./figures/logistic_regression.png");
}