Skip to main content

integration/
integration.rs

1use trellis_runner::{
2    CancellationGuard, GenerateBuilder, MaxIterationPolicy, Procedure, Progress, StagnationPolicy,
3    UserState,
4};
5
/// A definite integral of `f` over the interval `[a, b]`.
///
/// All fields are plain values (two `f64` limits and a function pointer), so the
/// problem is cheap to copy and can be printed with `{:?}` for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct QuadratureProblem {
    /// Lower integration limit.
    pub a: f64,
    /// Upper integration limit.
    pub b: f64,
    /// Integrand. This is a plain function pointer, so only non-capturing closures
    /// (which coerce to `fn(f64) -> f64`) or free functions can be stored here.
    pub f: fn(f64) -> f64,
}
12
/// Mutable iteration state carried between steps of the trapezoidal integrator.
///
/// The `Default` value has zero panels and an estimate of `0.0`, which is the
/// state before any step has run.
#[derive(Clone, Debug, Default)]
pub struct TrapezoidalState {
    /// Number of equal-width panels used for the most recent estimate.
    n: usize,
    /// Most recent trapezoidal approximation of the integral.
    estimate: f64,
}
18
19impl UserState for TrapezoidalState {
20    type Float = f64;
21
22    fn progress(&self) -> Progress<Self::Float> {
23        Progress::Measure(self.estimate)
24    }
25}
26
/// Numerical integration of a [`QuadratureProblem`] by the composite trapezoidal rule.
///
/// The integrand and the limits of integration are taken from the problem, not
/// from this type. Each step refines the estimate by using one more equal-width
/// panel than the previous step.
#[derive(Debug, Clone, Copy, Default)]
pub struct TrapezoidalIntegration;
29
30impl Procedure<QuadratureProblem> for TrapezoidalIntegration {
31    type State = TrapezoidalState;
32
33    type Output = f64;
34
35    const NAME: &'static str = "Trapezoidal Integrator";
36
37    fn step(
38        &self,
39        problem: &mut QuadratureProblem,
40        state: &mut Self::State,
41        _guard: CancellationGuard<'_>,
42    ) {
43        state.n += 1;
44
45        let h = (problem.b - problem.a) / state.n as f64;
46
47        let mut sum = 0.0;
48
49        for i in 0..state.n {
50            let x0 = problem.a + i as f64 * h;
51            let x1 = problem.a + (i + 1) as f64 * h;
52
53            let f0 = (problem.f)(x0);
54            let f1 = (problem.f)(x1);
55
56            sum += 0.5 * (f0 + f1) * h;
57        }
58
59        state.estimate = sum;
60    }
61
62    fn finalise(&self, _: &mut QuadratureProblem, state: &Self::State) -> Self::Output {
63        state.estimate
64    }
65}
66
67fn main() {
68    let problem = QuadratureProblem {
69        a: 0.0,
70        b: 1.0,
71        f: |x| x * x,
72    };
73    let result = TrapezoidalIntegration
74        .build_for(problem)
75        .with_initial_state(TrapezoidalState::default())
76        .and_policy(MaxIterationPolicy::new(3000))
77        .and_policy(StagnationPolicy::new(10))
78        .finalise()
79        .run();
80
81    println!("{result:?}");
82}