// convex_math/interpolation/flat_forward.rs

//! Flat forward interpolation.
//!
//! Flat forward interpolation assumes constant forward rates between pillar points.
//! This is a common choice for yield curve construction as it:
//! - Keeps forward rates positive whenever `r(t) * t` increases from one pillar to
//!   the next (i.e. discount factors decrease); positive zero rates alone are not
//!   enough, e.g. 5% at 1Y and 2% at 2Y imply a -1% forward between them
//! - Produces step-function forward rate curves
//! - Is computationally efficient
//!
//! # Mathematical Background
//!
//! Given zero rates r(t) at pillar points t_i, the forward rate f_i between
//! t_i and t_{i+1} is:
//!
//! ```text
//! f_i = (r_{i+1} * t_{i+1} - r_i * t_i) / (t_{i+1} - t_i)
//! ```
//!
//! For t between t_i and t_{i+1}, the interpolated zero rate is:
//!
//! ```text
//! r(t) = (r_i * t_i + f_i * (t - t_i)) / t
//! ```
//!
//! This ensures the forward rate is constant (flat) within each segment.
//! When extrapolation is enabled, the zero rate is held flat at r_0 below the
//! first pillar, and the last segment's forward rate is extended beyond the
//! last pillar.
25
26use crate::error::{MathError, MathResult};
27use crate::interpolation::Interpolator;
28
/// Flat forward interpolation for zero rate curves.
///
/// Interpolates zero rates such that forward rates are constant (flat)
/// between pillar points. This produces a step-function forward curve.
///
/// Build instances with [`FlatForward::new`] (all tenors > 0) or
/// [`FlatForward::with_origin`] (curve anchored at t = 0). Extrapolation is
/// disabled until [`FlatForward::with_extrapolation`] is called.
///
/// # Example
///
/// ```rust
/// use convex_math::interpolation::{FlatForward, Interpolator};
///
/// // Zero rates at 1Y, 2Y, 5Y, 10Y
/// let tenors = vec![1.0, 2.0, 5.0, 10.0];
/// let zero_rates = vec![0.02, 0.025, 0.03, 0.035];
///
/// let interp = FlatForward::new(tenors, zero_rates).unwrap();
///
/// // Interpolate at 3Y - forward rate is flat between 2Y and 5Y
/// let rate_3y = interp.interpolate(3.0).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct FlatForward {
    /// Pillar times in years, strictly increasing (checked by the constructors).
    tenors: Vec<f64>,
    /// Zero rate at each pillar; always the same length as `tenors`.
    zero_rates: Vec<f64>,
    /// Pre-computed forward rates, one per tenor: entry `i` (for `i < len - 1`)
    /// is the forward on `[tenors[i], tenors[i + 1]]`, and the final entry
    /// repeats the previous one so the last forward can be extended flat.
    forward_rates: Vec<f64>,
    /// Whether queries outside `[tenors[0], tenors[len - 1]]` are accepted;
    /// `false` until `with_extrapolation` is called.
    allow_extrapolation: bool,
}
59
60impl FlatForward {
61    /// Creates a new flat forward interpolator from zero rates.
62    ///
63    /// # Arguments
64    ///
65    /// * `tenors` - Time points in years (must be strictly increasing, > 0)
66    /// * `zero_rates` - Zero rates at each tenor (as decimals, e.g., 0.05 for 5%)
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if:
71    /// - Fewer than 2 points are provided
72    /// - Tenors and zero_rates have different lengths
73    /// - Tenors are not strictly increasing
74    /// - Any tenor is <= 0
75    pub fn new(tenors: Vec<f64>, zero_rates: Vec<f64>) -> MathResult<Self> {
76        if tenors.len() < 2 {
77            return Err(MathError::insufficient_data(2, tenors.len()));
78        }
79        if tenors.len() != zero_rates.len() {
80            return Err(MathError::invalid_input(format!(
81                "tenors and zero_rates must have same length: {} vs {}",
82                tenors.len(),
83                zero_rates.len()
84            )));
85        }
86
87        // Check that tenors are positive and strictly increasing
88        if tenors[0] <= 0.0 {
89            return Err(MathError::invalid_input(
90                "First tenor must be positive for flat forward interpolation",
91            ));
92        }
93        for i in 1..tenors.len() {
94            if tenors[i] <= tenors[i - 1] {
95                return Err(MathError::invalid_input(
96                    "Tenors must be strictly increasing",
97                ));
98            }
99        }
100
101        // Pre-compute forward rates for each segment
102        let forward_rates = Self::compute_forward_rates(&tenors, &zero_rates);
103
104        Ok(Self {
105            tenors,
106            zero_rates,
107            forward_rates,
108            allow_extrapolation: false,
109        })
110    }
111
112    /// Creates a flat forward interpolator with an initial point at t=0.
113    ///
114    /// This variant allows interpolation from t=0 by assuming the first
115    /// zero rate extends back to the origin.
116    ///
117    /// # Arguments
118    ///
119    /// * `tenors` - Time points in years (must be strictly increasing, >= 0)
120    /// * `zero_rates` - Zero rates at each tenor
121    pub fn with_origin(mut tenors: Vec<f64>, mut zero_rates: Vec<f64>) -> MathResult<Self> {
122        if tenors.is_empty() {
123            return Err(MathError::insufficient_data(1, 0));
124        }
125
126        // If first tenor is not 0, prepend origin point
127        if tenors[0] > 0.0 {
128            // Use first zero rate for the origin (flat from 0 to first pillar)
129            tenors.insert(0, 0.0);
130            zero_rates.insert(0, zero_rates[0]);
131        }
132
133        // Now call the standard constructor with tenors starting at 0
134        // We need to handle t=0 specially
135        if tenors.len() < 2 {
136            return Err(MathError::insufficient_data(2, tenors.len()));
137        }
138        if tenors.len() != zero_rates.len() {
139            return Err(MathError::invalid_input(format!(
140                "tenors and zero_rates must have same length: {} vs {}",
141                tenors.len(),
142                zero_rates.len()
143            )));
144        }
145
146        for i in 1..tenors.len() {
147            if tenors[i] <= tenors[i - 1] {
148                return Err(MathError::invalid_input(
149                    "Tenors must be strictly increasing",
150                ));
151            }
152        }
153
154        let forward_rates = Self::compute_forward_rates(&tenors, &zero_rates);
155
156        Ok(Self {
157            tenors,
158            zero_rates,
159            forward_rates,
160            allow_extrapolation: false,
161        })
162    }
163
164    /// Enables extrapolation beyond the data range.
165    ///
166    /// When extrapolating:
167    /// - Below first tenor: uses first forward rate
168    /// - Above last tenor: uses last forward rate (flat forward extension)
169    #[must_use]
170    pub fn with_extrapolation(mut self) -> Self {
171        self.allow_extrapolation = true;
172        self
173    }
174
175    /// Computes forward rates for each segment.
176    ///
177    /// Forward rate from t_i to t_{i+1}:
178    /// f_i = (r_{i+1} * t_{i+1} - r_i * t_i) / (t_{i+1} - t_i)
179    fn compute_forward_rates(tenors: &[f64], zero_rates: &[f64]) -> Vec<f64> {
180        let n = tenors.len();
181        let mut forwards = Vec::with_capacity(n);
182
183        for i in 0..n - 1 {
184            let t0 = tenors[i];
185            let t1 = tenors[i + 1];
186            let r0 = zero_rates[i];
187            let r1 = zero_rates[i + 1];
188
189            // Handle t0 = 0 case
190            let fwd = if t0 == 0.0 {
191                // Forward from 0 to t1 is just r1 (since r0 * 0 = 0)
192                r1
193            } else {
194                (r1 * t1 - r0 * t0) / (t1 - t0)
195            };
196
197            forwards.push(fwd);
198        }
199
200        // Last segment uses flat forward extension
201        if !forwards.is_empty() {
202            forwards.push(*forwards.last().unwrap());
203        } else {
204            forwards.push(zero_rates[0]);
205        }
206
207        forwards
208    }
209
210    /// Finds the segment index for a given tenor.
211    ///
212    /// Returns i such that tenors[i] <= t < tenors[i+1]
213    fn find_segment(&self, t: f64) -> usize {
214        match self
215            .tenors
216            .binary_search_by(|probe| probe.partial_cmp(&t).unwrap_or(std::cmp::Ordering::Equal))
217        {
218            Ok(i) => i.min(self.tenors.len() - 2),
219            Err(i) => (i.saturating_sub(1)).min(self.tenors.len() - 2),
220        }
221    }
222
223    /// Returns the forward rate at tenor t.
224    ///
225    /// Since forward rates are flat between pillars, this returns the
226    /// constant forward rate for the segment containing t.
227    pub fn forward_rate(&self, t: f64) -> MathResult<f64> {
228        if !self.allow_extrapolation && (t < self.tenors[0] || t > *self.tenors.last().unwrap()) {
229            return Err(MathError::ExtrapolationNotAllowed {
230                x: t,
231                min: self.tenors[0],
232                max: *self.tenors.last().unwrap(),
233            });
234        }
235
236        let i = self.find_segment(t);
237        Ok(self.forward_rates[i])
238    }
239
240    /// Returns the tenors.
241    pub fn tenors(&self) -> &[f64] {
242        &self.tenors
243    }
244
245    /// Returns the zero rates.
246    pub fn zero_rates(&self) -> &[f64] {
247        &self.zero_rates
248    }
249
250    /// Returns the pre-computed forward rates.
251    pub fn forward_rates_vec(&self) -> &[f64] {
252        &self.forward_rates
253    }
254}
255
256impl Interpolator for FlatForward {
257    fn interpolate(&self, t: f64) -> MathResult<f64> {
258        let min_t = self.tenors[0];
259        let max_t = *self.tenors.last().unwrap();
260
261        // Check bounds
262        if !self.allow_extrapolation && (t < min_t || t > max_t) {
263            return Err(MathError::ExtrapolationNotAllowed {
264                x: t,
265                min: min_t,
266                max: max_t,
267            });
268        }
269
270        // Handle t <= 0 edge case
271        if t <= 0.0 {
272            // Return first zero rate for t=0 or negative (if extrapolating)
273            return Ok(self.zero_rates[0]);
274        }
275
276        // Handle exact pillar hits
277        if let Some(idx) = self.tenors.iter().position(|&x| (x - t).abs() < 1e-12) {
278            return Ok(self.zero_rates[idx]);
279        }
280
281        // Handle extrapolation below first pillar
282        if t < min_t {
283            // Use first forward rate to extrapolate backward
284            // r(t) = r_0 * t_0 / t + f_0 * (t - 0) / t = (r_0 * t_0 + f_0 * t) / t
285            // But simpler: just use first zero rate (flat backward)
286            return Ok(self.zero_rates[0]);
287        }
288
289        // Handle extrapolation above last pillar
290        if t > max_t {
291            // Flat forward extension: use last forward rate
292            let n = self.tenors.len();
293            let t_n = self.tenors[n - 1];
294            let r_n = self.zero_rates[n - 1];
295            let f_n = self.forward_rates[n - 1];
296
297            // r(t) = (r_n * t_n + f_n * (t - t_n)) / t
298            return Ok((r_n * t_n + f_n * (t - t_n)) / t);
299        }
300
301        // Normal interpolation between pillars
302        let i = self.find_segment(t);
303        let t_i = self.tenors[i];
304        let r_i = self.zero_rates[i];
305        let f_i = self.forward_rates[i];
306
307        // r(t) = (r_i * t_i + f_i * (t - t_i)) / t
308        Ok((r_i * t_i + f_i * (t - t_i)) / t)
309    }
310
311    fn derivative(&self, t: f64) -> MathResult<f64> {
312        let min_t = self.tenors[0];
313        let max_t = *self.tenors.last().unwrap();
314
315        // Check bounds
316        if !self.allow_extrapolation && (t < min_t || t > max_t) {
317            return Err(MathError::ExtrapolationNotAllowed {
318                x: t,
319                min: min_t,
320                max: max_t,
321            });
322        }
323
324        if t <= 0.0 {
325            return Ok(0.0); // Derivative at origin
326        }
327
328        // Find segment
329        let i = if t > max_t {
330            self.tenors.len() - 2
331        } else if t < min_t {
332            0
333        } else {
334            self.find_segment(t)
335        };
336
337        let t_i = self.tenors[i];
338        let r_i = self.zero_rates[i];
339        let f_i = self.forward_rates[i];
340
341        // r(t) = (r_i * t_i + f_i * (t - t_i)) / t
342        //      = r_i * t_i / t + f_i * (t - t_i) / t
343        //      = r_i * t_i / t + f_i - f_i * t_i / t
344        //      = (r_i * t_i - f_i * t_i) / t + f_i
345        //      = (r_i - f_i) * t_i / t + f_i
346        //
347        // dr/dt = -(r_i - f_i) * t_i / t^2
348        //       = (f_i - r_i) * t_i / t^2
349
350        Ok((f_i - r_i) * t_i / (t * t))
351    }
352
353    fn allows_extrapolation(&self) -> bool {
354        self.allow_extrapolation
355    }
356
357    fn min_x(&self) -> f64 {
358        self.tenors[0]
359    }
360
361    fn max_x(&self) -> f64 {
362        *self.tenors.last().unwrap()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_flat_forward_through_pillars() {
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.02, 0.025, 0.03, 0.035];

        let interp = FlatForward::new(tenors.clone(), zero_rates.clone()).unwrap();

        // Should pass through all pillar points
        for (t, r) in tenors.iter().zip(zero_rates.iter()) {
            assert_relative_eq!(interp.interpolate(*t).unwrap(), *r, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_flat_forward_rates() {
        let tenors = vec![1.0, 2.0, 3.0];
        let zero_rates = vec![0.02, 0.03, 0.04];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // Forward rate from 1Y to 2Y:
        // f = (0.03 * 2 - 0.02 * 1) / (2 - 1) = (0.06 - 0.02) / 1 = 0.04
        assert_relative_eq!(interp.forward_rate(1.5).unwrap(), 0.04, epsilon = 1e-10);

        // Forward rate from 2Y to 3Y:
        // f = (0.04 * 3 - 0.03 * 2) / (3 - 2) = (0.12 - 0.06) / 1 = 0.06
        assert_relative_eq!(interp.forward_rate(2.5).unwrap(), 0.06, epsilon = 1e-10);
    }

    #[test]
    fn test_interpolation_between_pillars() {
        let tenors = vec![1.0, 2.0];
        let zero_rates = vec![0.02, 0.04];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // Forward rate = (0.04 * 2 - 0.02 * 1) / 1 = 0.06
        // At t = 1.5:
        // r(1.5) = (0.02 * 1 + 0.06 * 0.5) / 1.5 = (0.02 + 0.03) / 1.5 = 0.05 / 1.5 = 0.0333...
        let r_mid = interp.interpolate(1.5).unwrap();
        assert_relative_eq!(r_mid, 0.05 / 1.5, epsilon = 1e-10);
    }

    #[test]
    fn test_forward_rate_consistency() {
        // Verify that interpolated zero rates produce the expected forwards
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.02, 0.025, 0.03, 0.035];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // Check that forward rates are indeed constant within segments
        let f_segment_1 = interp.forward_rate(1.5).unwrap();
        assert_relative_eq!(
            interp.forward_rate(1.1).unwrap(),
            f_segment_1,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            interp.forward_rate(1.9).unwrap(),
            f_segment_1,
            epsilon = 1e-10
        );

        // Verify forward rate calculation from zero rates
        // f(1Y, 2Y) = (r_2Y * 2 - r_1Y * 1) / (2 - 1)
        let expected_f = (0.025 * 2.0 - 0.02 * 1.0) / (2.0 - 1.0);
        assert_relative_eq!(f_segment_1, expected_f, epsilon = 1e-10);
    }

    #[test]
    fn test_positive_forward_rates() {
        // Here r(t) * t increases between every pair of pillars, so every
        // forward is positive.
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.02, 0.025, 0.03, 0.035];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // One forward per tenor (the last repeats the final segment).
        assert_eq!(interp.forward_rates_vec().len(), 4);
        for &f in interp.forward_rates_vec() {
            assert!(f > 0.0, "Forward rate {} should be positive", f);
        }
    }

    #[test]
    fn test_negative_forward_from_positive_zero_rates() {
        // Positive zero rates do not guarantee positive forwards: r(t) * t falls
        // from 0.05 at 1Y to 0.04 at 2Y, so the 1Y-2Y forward is -1%.
        let interp = FlatForward::new(vec![1.0, 2.0], vec![0.05, 0.02]).unwrap();
        assert_relative_eq!(interp.forward_rate(1.5).unwrap(), -0.01, epsilon = 1e-12);
    }

    #[test]
    fn test_derivative_numerical() {
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.02, 0.025, 0.03, 0.035];

        let interp = FlatForward::new(tenors, zero_rates)
            .unwrap()
            .with_extrapolation();

        // Test derivative at several points
        for t in [1.5, 2.5, 4.0, 7.0] {
            let h = 1e-6;
            let r_plus = interp.interpolate(t + h).unwrap();
            let r_minus = interp.interpolate(t - h).unwrap();
            let numerical = (r_plus - r_minus) / (2.0 * h);
            let analytical = interp.derivative(t).unwrap();

            assert_relative_eq!(analytical, numerical, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_extrapolation() {
        let tenors = vec![1.0, 2.0, 5.0];
        let zero_rates = vec![0.02, 0.025, 0.03];

        let interp = FlatForward::new(tenors, zero_rates)
            .unwrap()
            .with_extrapolation();

        // Should extrapolate beyond range
        assert!(interp.interpolate(0.5).is_ok());
        assert!(interp.interpolate(7.0).is_ok());

        // Extrapolation below first pillar uses first zero rate
        assert_relative_eq!(interp.interpolate(0.5).unwrap(), 0.02, epsilon = 1e-10);

        // Above the last pillar, the last segment's forward is extended flat from 5Y.
        let f_last = (0.03 * 5.0 - 0.025 * 2.0) / (5.0 - 2.0);
        let expected = (0.03 * 5.0 + f_last * (7.0 - 5.0)) / 7.0;
        assert_relative_eq!(interp.interpolate(7.0).unwrap(), expected, epsilon = 1e-12);
        assert_relative_eq!(interp.forward_rate(7.0).unwrap(), f_last, epsilon = 1e-12);
    }

    #[test]
    fn test_no_extrapolation() {
        let tenors = vec![1.0, 2.0, 5.0];
        let zero_rates = vec![0.02, 0.025, 0.03];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // Should fail outside range
        assert!(interp.interpolate(0.5).is_err());
        assert!(interp.interpolate(7.0).is_err());
        assert!(interp.forward_rate(0.5).is_err());
        assert!(interp.derivative(7.0).is_err());
    }

    #[test]
    fn test_with_origin() {
        let tenors = vec![1.0, 2.0, 5.0];
        let zero_rates = vec![0.02, 0.025, 0.03];

        let interp = FlatForward::with_origin(tenors, zero_rates).unwrap();

        // An origin pillar carrying the first zero rate is prepended.
        assert_eq!(interp.tenors(), &[0.0, 1.0, 2.0, 5.0]);
        assert_eq!(interp.zero_rates(), &[0.02, 0.02, 0.025, 0.03]);

        // The curve is flat at 2% between the origin and the first original pillar.
        assert_relative_eq!(interp.interpolate(0.0).unwrap(), 0.02, epsilon = 1e-12);
        assert_relative_eq!(interp.interpolate(0.5).unwrap(), 0.02, epsilon = 1e-12);
        assert_relative_eq!(interp.interpolate(2.0).unwrap(), 0.025, epsilon = 1e-12);
    }

    #[test]
    fn test_insufficient_points() {
        let tenors = vec![1.0];
        let zero_rates = vec![0.02];

        assert!(FlatForward::new(tenors, zero_rates).is_err());
    }

    #[test]
    fn test_mismatched_lengths() {
        let tenors = vec![1.0, 2.0, 3.0];
        let zero_rates = vec![0.02, 0.025];

        assert!(FlatForward::new(tenors, zero_rates).is_err());
    }

    #[test]
    fn test_non_positive_tenor() {
        let tenors = vec![0.0, 1.0, 2.0];
        let zero_rates = vec![0.02, 0.025, 0.03];

        // Should fail because first tenor is 0 (use with_origin instead)
        assert!(FlatForward::new(tenors, zero_rates).is_err());
    }

    #[test]
    fn test_flat_curve() {
        // A flat zero curve has a flat forward curve at the same level.
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.03, 0.03, 0.03, 0.03];

        let interp = FlatForward::new(tenors, zero_rates).unwrap();

        // Forward rates should all be 0.03 (same as zero rate for flat curve)
        for &f in interp.forward_rates_vec() {
            assert_relative_eq!(f, 0.03, epsilon = 1e-10);
        }

        // Interpolated values should all be 0.03
        for t in [1.0, 1.5, 2.5, 4.0, 7.0, 10.0] {
            assert_relative_eq!(interp.interpolate(t).unwrap(), 0.03, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_inverted_curve() {
        // Inverted curve (downward sloping)
        let tenors = vec![1.0, 2.0, 5.0, 10.0];
        let zero_rates = vec![0.05, 0.04, 0.03, 0.025];

        let interp = FlatForward::new(tenors.clone(), zero_rates.clone()).unwrap();

        // The curve still reproduces every pillar.
        for (t, r) in tenors.iter().zip(&zero_rates) {
            assert_relative_eq!(interp.interpolate(*t).unwrap(), *r, epsilon = 1e-10);
        }

        // Each segment's forward follows f_i = (r_{i+1} t_{i+1} - r_i t_i) / (t_{i+1} - t_i).
        let forwards = interp.forward_rates_vec();
        for ((t, r), f) in tenors
            .windows(2)
            .zip(zero_rates.windows(2))
            .zip(forwards)
        {
            let expected = (r[1] * t[1] - r[0] * t[0]) / (t[1] - t[0]);
            assert_relative_eq!(*f, expected, epsilon = 1e-12);
        }

        assert!(interp.interpolate(1.5).is_ok());
        assert!(interp.interpolate(3.0).is_ok());
    }
}