mathhook_core/calculus/integrals/numerical/
romberg.rs1use super::{IntegrationConfig, IntegrationResult, NumericalIntegrator};
7use crate::error::MathError;
8
/// Romberg integrator: repeated trapezoidal refinement combined with Richardson
/// extrapolation.
///
/// `max_order` bounds how many tableau rows (refinement levels) `integrate` builds; the
/// per-call `IntegrationConfig::max_iterations` bounds it as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RombergIntegration {
    /// Upper bound on the number of Romberg tableau rows.
    max_order: usize,
}
13
14impl RombergIntegration {
15 pub fn new(max_order: usize) -> Self {
29 Self { max_order }
30 }
31
32 fn trapezoidal_refinement<F>(&self, f: &F, a: f64, b: f64, k: usize, prev: f64) -> f64
34 where
35 F: Fn(f64) -> f64,
36 {
37 if k == 0 {
38 return 0.5 * (b - a) * (f(a) + f(b));
39 }
40
41 let n = 1_usize << (k - 1);
42 let h = (b - a) / (n as f64);
43
44 let mut sum = 0.0;
45 for i in 0..n {
46 let x = a + (2 * i + 1) as f64 * h / 2.0;
47 sum += f(x);
48 }
49
50 0.5 * prev + h * sum / 2.0
51 }
52
53 fn richardson_extrapolation(&self, tableau: &[Vec<f64>], row: usize, col: usize) -> f64 {
55 let r1 = tableau[row][col - 1];
56 let r2 = tableau[row - 1][col - 1];
57 let power = 4_f64.powi(col as i32);
58
59 (power * r1 - r2) / (power - 1.0)
60 }
61}
62
63impl NumericalIntegrator for RombergIntegration {
64 fn integrate<F>(
65 &self,
66 f: F,
67 a: f64,
68 b: f64,
69 config: &IntegrationConfig,
70 ) -> Result<IntegrationResult, MathError>
71 where
72 F: Fn(f64) -> f64,
73 {
74 if a >= b {
75 return Err(MathError::InvalidInterval { lower: a, upper: b });
76 }
77
78 let max_iterations = self.max_order.min(config.max_iterations);
79 let mut tableau = vec![vec![0.0; max_iterations]; max_iterations];
80
81 tableau[0][0] = self.trapezoidal_refinement(&f, a, b, 0, 0.0);
82
83 for i in 1..max_iterations {
84 tableau[i][0] = self.trapezoidal_refinement(&f, a, b, i, tableau[i - 1][0]);
85
86 for j in 1..=i {
87 tableau[i][j] = self.richardson_extrapolation(&tableau, i, j);
88 }
89
90 if i > 1 {
91 let error = (tableau[i][i] - tableau[i - 1][i - 1]).abs();
92 if error < config.tolerance {
93 return Ok(IntegrationResult {
94 value: tableau[i][i],
95 error_estimate: error,
96 iterations: i + 1,
97 subdivisions: 1 << i,
98 });
99 }
100 }
101 }
102
103 let final_value = tableau[max_iterations - 1][max_iterations - 1];
104 let error_estimate = if max_iterations > 1 {
105 (final_value - tableau[max_iterations - 2][max_iterations - 2]).abs()
106 } else {
107 0.0
108 };
109
110 Ok(IntegrationResult {
111 value: final_value,
112 error_estimate,
113 iterations: max_iterations,
114 subdivisions: 1 << (max_iterations - 1),
115 })
116 }
117}
118
#[cfg(test)]
mod tests {
    // Unit tests for `RombergIntegration`: accuracy on smooth integrands, convergence
    // bookkeeping (iterations / subdivisions), the row caps, and interval validation.
    use super::*;

    #[test]
    fn test_romberg_polynomial() {
        let integrator = RombergIntegration::new(8);
        let config = IntegrationConfig {
            tolerance: 1e-12,
            ..Default::default()
        };

        let result = integrator.integrate(|x| x * x, 0.0, 1.0, &config).unwrap();

        assert!((result.value - 1.0 / 3.0).abs() < 1e-11);
    }

    #[test]
    fn test_romberg_sine() {
        let integrator = RombergIntegration::new(10);
        let config = IntegrationConfig {
            tolerance: 1e-10,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x.sin(), 0.0, std::f64::consts::PI, &config)
            .unwrap();

        assert!((result.value - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_romberg_exponential() {
        let integrator = RombergIntegration::new(8);
        let config = IntegrationConfig {
            tolerance: 1e-12,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x.exp(), 0.0, 1.0, &config)
            .unwrap();

        let expected = std::f64::consts::E - 1.0;
        assert!((result.value - expected).abs() < 1e-11);
    }

    #[test]
    fn test_romberg_convergence() {
        let integrator = RombergIntegration::new(10);
        let config = IntegrationConfig {
            tolerance: 1e-10,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x * x * x * x, 0.0, 1.0, &config)
            .unwrap();

        assert!((result.value - 0.2).abs() < 1e-10);
        assert!(result.error_estimate < 1e-10);
    }

    #[test]
    fn test_romberg_high_accuracy() {
        let integrator = RombergIntegration::new(12);
        let config = IntegrationConfig {
            tolerance: 1e-14,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| (x * std::f64::consts::PI).cos(), 0.0, 1.0, &config)
            .unwrap();

        // Antiderivative sin(pi x) / pi evaluated from 0 to 1: sin(pi) / pi, which is 0
        // up to rounding.
        let expected = 1.0 / std::f64::consts::PI * (std::f64::consts::PI).sin();
        assert!((result.value - expected).abs() < 1e-12);
    }

    #[test]
    fn test_romberg_invalid_interval() {
        let integrator = RombergIntegration::new(8);
        let config = IntegrationConfig::default();

        let result = integrator.integrate(|x| x, 1.0, 0.0, &config);
        assert!(matches!(
            result,
            Err(MathError::InvalidInterval { lower, upper }) if lower == 1.0 && upper == 0.0
        ));
    }

    #[test]
    fn test_romberg_empty_interval() {
        let integrator = RombergIntegration::new(8);
        let config = IntegrationConfig::default();

        let result = integrator.integrate(|x| x, 1.0, 1.0, &config);
        assert!(matches!(result, Err(MathError::InvalidInterval { .. })));
    }

    #[test]
    fn test_romberg_cubic_converges_early() {
        // Column 1 (Simpson) is already exact for cubics, so the first convergence test at
        // row 2 succeeds: three rows, four panels.
        let integrator = RombergIntegration::new(8);
        let config = IntegrationConfig {
            tolerance: 1e-12,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x * x * x, 0.0, 1.0, &config)
            .unwrap();

        assert!((result.value - 0.25).abs() < 1e-15);
        assert_eq!(result.iterations, 3);
        assert_eq!(result.subdivisions, 4);
    }

    #[test]
    fn test_romberg_respects_max_order() {
        // A zero tolerance can never be met, so every permitted row is built.
        let integrator = RombergIntegration::new(3);
        let config = IntegrationConfig {
            tolerance: 0.0,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x.exp(), 0.0, 1.0, &config)
            .unwrap();

        assert_eq!(result.iterations, 3);
        assert_eq!(result.subdivisions, 4);
    }

    #[test]
    fn test_romberg_respects_config_max_iterations() {
        let integrator = RombergIntegration::new(10);
        let config = IntegrationConfig {
            tolerance: 0.0,
            max_iterations: 3,
            ..Default::default()
        };

        let result = integrator
            .integrate(|x| x.exp(), 0.0, 1.0, &config)
            .unwrap();

        assert_eq!(result.iterations, 3);
        assert_eq!(result.subdivisions, 4);
    }
}