least_squares/
least_squares.rs1use trellis_runner::{
2 CancellationGuard, GenerateBuilder, MaxIterationPolicy, Procedure, Progress,
3 ProgressDiagnostics, StagnationPolicy, UserState,
4};
5
/// A simple linear-regression problem: fit a line `y ≈ a·x + b` to a set of samples.
#[derive(Clone, Debug)]
pub struct LinearRegressionProblem {
    /// Observations as `(x, y)` pairs.
    pub data: Vec<(f64, f64)>,
}
10
11#[derive(Clone, Debug)]
12pub struct LSState {
13 a: f64,
14 b: f64,
15 loss: f64,
16}
17
18impl Default for LSState {
19 fn default() -> Self {
20 Self {
21 a: 0.0,
22 b: 0.0,
23 loss: f64::INFINITY,
24 }
25 }
26}
27
28impl UserState for LSState {
29 type Float = f64;
30
31 fn progress(&self) -> Progress<Self::Float> {
32 Progress::Report {
33 measure: self.loss,
34 diagnostics: ProgressDiagnostics {
35 gradient_norm: Some((self.a.powi(2) + self.b.powi(2)).sqrt()),
36 step_size: Some(0.01),
37 ..Default::default()
38 },
39 }
40 }
41}
42
43pub struct LeastSquares;
45
46impl Procedure<LinearRegressionProblem> for LeastSquares {
47 type State = LSState;
48 type Output = (f64, f64);
49
50 const NAME: &'static str = "Least Squares Optimisation";
51
52 fn step(
53 &self,
54 problem: &mut LinearRegressionProblem,
55 state: &mut Self::State,
56 _guard: CancellationGuard<'_>,
57 ) {
58 let lr = 0.01;
59 let mut da = 0.0;
60 let mut db = 0.0;
61 let mut loss = 0.0;
62
63 for (x, y) in &problem.data {
64 let pred = state.a * *x + state.b;
65 let err = pred - *y;
66
67 loss += err * err;
68 da += err * *x;
69 db += err;
70 }
71
72 state.a -= lr * da;
73 state.b -= lr * db;
74 state.loss = loss;
75 }
76
77 fn finalise(&self, _: &mut LinearRegressionProblem, state: &Self::State) -> Self::Output {
78 (state.a, state.b)
79 }
80}
81
82fn main() {
83 let problem = LinearRegressionProblem {
84 data: vec![(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)],
85 };
86
87 let result = LeastSquares
88 .build_for(problem)
89 .with_initial_state(LSState::default())
90 .and_policy(MaxIterationPolicy::new(3000))
91 .and_policy(StagnationPolicy::new(10))
92 .finalise()
93 .run();
94
95 println!("fit: {:?}", result);
96}