// RustQuant_math/rootfinding/brent.rs
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// RustQuant: A Rust library for quantitative finance tools.
// Copyright (C) 2024 https://github.com/avhz
// Dual licensed under Apache 2.0 and MIT.
// See:
//      - LICENSE-APACHE.md
//      - LICENSE-MIT.md
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

use crate::rootfinder::{Rootfinder, RootfinderData};

12/// Brent root-finding algorithm.
13pub struct Brent<F>
14where
15    F: Fn(f64) -> f64,
16{
17    function: F,
18    guess: f64,
19    data: RootfinderData,
20}
21
22impl<F> Brent<F>
23where
24    F: Fn(f64) -> f64,
25{
26    /// Create a new Brent solver.
27    pub fn new(function: F, guess: f64, data: RootfinderData) -> Self {
28        Self {
29            function,
30            guess,
31            data,
32        }
33    }
34}
35
36impl<F> Rootfinder<F> for Brent<F>
37where
38    F: Fn(f64) -> f64,
39{
40    fn value(&self, x: f64) -> f64 {
41        (self.function)(x)
42    }
43
44    fn derivative(&self, _: f64) -> f64 {
45        0.0
46    }
47
48    fn solve_impl(&mut self) -> f64 {
49        let mut min1: f64;
50        let mut min2: f64;
51
52        let mut p: f64;
53        let mut q: f64;
54        let mut r: f64;
55        let mut s: f64;
56        let mut x_acc1: f64;
57        let mut x_mid: f64;
58
59        let mut d: f64 = 0.0;
60        let mut e: f64 = 0.0;
61
62        // let mut root = self.data.x_max;
63        self.data.root = self.data.x_max;
64        let mut froot = self.data.y_max;
65
66        while self.data.iteration_count <= Self::MAX_ITERATIONS {
67            if (froot > 0.0 && self.data.y_max > 0.0) || (froot < 0.0 && self.data.y_max < 0.0) {
68                self.data.x_max = self.data.x_min;
69                self.data.y_max = self.data.y_min;
70                e = self.data.root - self.data.x_min;
71                d = e;
72            }
73
74            if self.data.y_max.abs() < froot.abs() {
75                // Adjust x's
76                self.data.x_min = self.data.root;
77                self.data.root = self.data.x_max;
78                self.data.x_max = self.data.x_min;
79
80                // Adjust f(x)'s
81                self.data.y_min = froot;
82                froot = self.data.y_max;
83                self.data.y_max = self.data.y_min;
84            }
85
86            x_acc1 = 2.0 * f64::EPSILON * self.data.root.abs() + 0.5 * self.data.accuracy;
87            x_mid = 0.5 * (self.data.x_max - self.data.root);
88
89            // if x_mid.abs() <= x_acc1 || close(froot, 0.0) {
90            if x_mid.abs() <= x_acc1 || RootfinderData::close(froot, 0.0) {
91                return self.data.root;
92            }
93
94            if e.abs() >= x_acc1 && self.data.y_min > froot.abs() {
95                s = froot / self.data.y_min;
96
97                if RootfinderData::close(self.data.x_min, self.data.x_max) {
98                    p = 2.0 * x_mid * s;
99                    q = 1.0 - s;
100                } else {
101                    q = self.data.y_min / self.data.y_max;
102                    r = froot / self.data.y_max;
103                    p = s
104                        * (2.0 * x_mid * q * (q - r)
105                            - (self.data.root - self.data.x_min) * (r - 1.0));
106                    q = (q - 1.0) * (r - 1.0) * (s - 1.0);
107                }
108
109                if p > 0.0 {
110                    q = -q;
111                }
112
113                p = p.abs();
114
115                min1 = 3.0 * x_mid * q - (x_acc1 * q).abs();
116                min2 = (e * q).abs();
117
118                let ternary = if min1 < min2 { min1 } else { min2 };
119
120                if 2.0 * p < ternary {
121                    e = d;
122                    d = p / q;
123                } else {
124                    d = x_mid;
125                    e = d;
126                }
127            } else {
128                d = x_mid;
129                e = d;
130            }
131
132            self.data.x_min = self.data.root;
133            self.data.y_min = froot;
134
135            if d.abs() > x_acc1 {
136                self.data.root += d;
137            } else {
138                self.data.root += RootfinderData::nrsign(x_acc1, x_mid); // x_acc1.abs() * x_mid.signum();
139            }
140            froot = self.value(self.data.root);
141            self.data.increment_evaluation_count();
142        }
143
144        0.0
145    }
146
147    fn solve(&mut self) -> f64 {
148        assert!(self.data.accuracy > 0., "accuracy must be positive");
149
150        self.data.accuracy = f64::max(self.data.accuracy, f64::EPSILON);
151
152        let growth_factor = 1.6;
153        let mut flipflop = -1;
154
155        self.data.root = self.guess;
156        self.data.y_max = self.value(self.data.root);
157
158        if RootfinderData::close(self.data.y_max, 0.0) {
159            return self.data.root;
160        } else if self.data.y_max > 0.0 {
161            self.data.x_min = self
162                .data
163                .enforce_bounds(self.data.root - self.data.stepsize);
164            self.data.y_min = self.value(self.data.x_min);
165            self.data.x_max = self.data.root;
166        } else {
167            self.data.x_min = self.data.root;
168            self.data.y_min = self.data.y_max;
169            self.data.x_max = self
170                .data
171                .enforce_bounds(self.data.root + self.data.stepsize);
172            self.data.y_max = self.value(self.data.x_max);
173        }
174
175        self.data.iteration_count = 2;
176
177        while self.data.iteration_count <= Self::MAX_ITERATIONS {
178            // Check if we can solve.
179            if self.data.y_min * self.data.y_max <= 0.0 {
180                if RootfinderData::close(self.data.y_min, 0.0) {
181                    return self.data.x_min;
182                }
183                if RootfinderData::close(self.data.y_max, 0.0) {
184                    return self.data.x_max;
185                }
186                self.data.root = 0.5 * (self.data.x_max + self.data.x_min);
187
188                return self.solve_impl();
189            }
190
191            // If we can't solve, adjust.
192            if self.data.y_min.abs() < self.data.y_max.abs() {
193                self.data.x_min = self.data.enforce_bounds(
194                    self.data.x_min + growth_factor * (self.data.x_min - self.data.x_max),
195                );
196                self.data.y_min = self.value(self.data.x_min);
197            } else if self.data.y_min.abs() > self.data.y_max.abs() {
198                self.data.x_max = self.data.enforce_bounds(
199                    self.data.x_max + growth_factor * (self.data.x_max - self.data.x_min),
200                );
201                self.data.y_max = self.value(self.data.x_max);
202            } else if flipflop == -1 {
203                self.data.x_min = self.data.enforce_bounds(
204                    self.data.x_min + growth_factor * (self.data.x_min - self.data.x_max),
205                );
206                self.data.y_min = self.value(self.data.x_min);
207                self.data.increment_evaluation_count();
208                flipflop = 1;
209            } else if flipflop == 1 {
210                self.data.x_max = self.data.enforce_bounds(
211                    self.data.x_max + growth_factor * (self.data.x_max - self.data.x_min),
212                );
213                self.data.y_max = self.value(self.data.x_max);
214                flipflop = -1;
215            }
216
217            self.data.increment_evaluation_count();
218        }
219
220        0.0
221    }
222}
223
#[cfg(test)]
mod TESTS_brent_solver {

    use super::*;
    use std::f64::consts::SQRT_2;

    #[test]
    fn test_brent_solver() {
        // Define the objective function:
        // f(x) = x^2 - 2
        let f = |x: f64| x.powi(2) - 2.0;

        // Create a new Brent solver with:
        //      - Objective function: f(x) = x^2 - 2
        //      - Initial guess: 1.0
        //      - Root-finder data:
        //          - Accuracy: 1e-15
        //          - Step size: 1e-5
        //          - Lower bound: 0.0
        //          - Upper bound: 2.0
        //          - Interval enforced: true
        let data = RootfinderData::new(1e-15, 1e-5, 0.0, 2.0, true);
        let mut solver = Brent::new(f, 1.0, data);
        let root = solver.solve();

        // Compare the absolute error. A signed check would accept any root
        // below sqrt(2), including the 0.0 failure value. The solver stops when
        // the bracket half-width is <= 2 * EPSILON * |root| + accuracy / 2
        // (about 1.1e-15 here), so the error is bounded by a few times that.
        assert!((root - SQRT_2).abs() < 1e-14, "root: {}", root);
    }

    #[test]
    fn test_implied_volatility() {
        use errorfunctions::RealErrorFunctions;

        let phi = |x: f64| 0.5 * (1.0 + (x / SQRT_2).erf());

        let d1 = |s: f64, k: f64, v: f64, r: f64, q: f64, t: f64| {
            ((s / k).ln() + (r - q + (v * v) / 2.0) * t) / (v * t.sqrt())
        };

        let d2 =
            |s: f64, k: f64, v: f64, r: f64, q: f64, t: f64| d1(s, k, v, r, q, t) - v * t.sqrt();

        let black_scholes_call = |s: f64, k: f64, v: f64, r: f64, q: f64, t: f64| {
            let d1 = d1(s, k, v, r, q, t);
            let d2 = d2(s, k, v, r, q, t);

            s * (-q * t).exp() * phi(d1) - k * (-r * t).exp() * phi(d2)
        };

        println!(
            "Black-Scholes Call: {}",
            black_scholes_call(100.0, 100.0, 0.2, 0.05, 0.0, 1.0)
        );

        let price = 10.4505835721856;
        let expected_vol = 0.2;

        // f(x) = Black-Scholes-Call(vol) - price
        let f = |v: f64| black_scholes_call(100.0, 100.0, v, 0.05, 0.0, 1.0) - price;

        let mut solver = Brent::new(f, 0.5, RootfinderData::default());
        let root = solver.solve();
        assert!((root - expected_vol).abs() < 1e-10, "Impl. Vol.: {}", root);
    }
}