use crate::error::{MathError, MathResult};
use crate::interpolation::Interpolator;
/// Flat-forward interpolator over a curve of zero rates.
///
/// Stores the pillar tenors, the zero rate at each pillar, and one
/// piecewise-constant forward rate per pillar.
#[derive(Debug, Clone)]
pub struct FlatForward {
/// Pillar times; both constructors require them to be strictly increasing.
tenors: Vec<f64>,
/// Zero rate at each pillar, index-aligned with `tenors`.
zero_rates: Vec<f64>,
/// Forward rate applying from pillar `i` onwards; the last entry repeats the
/// final segment's forward, so this has the same length as `tenors`.
forward_rates: Vec<f64>,
/// When `false`, queries outside `[tenors[0], tenors[last]]` return an error.
allow_extrapolation: bool,
}
impl FlatForward {
/// Builds a flat-forward interpolator from pillar tenors and zero rates.
///
/// Extrapolation is disabled on the returned value; enable it with
/// `with_extrapolation`.
///
/// # Errors
///
/// Returns an error when:
/// * fewer than two tenors are supplied,
/// * `tenors` and `zero_rates` differ in length,
/// * any tenor or zero rate is NaN or infinite,
/// * the first tenor is not strictly positive,
/// * the tenors are not strictly increasing.
pub fn new(tenors: Vec<f64>, zero_rates: Vec<f64>) -> MathResult<Self> {
    if tenors.len() < 2 {
        return Err(MathError::insufficient_data(2, tenors.len()));
    }
    if tenors.len() != zero_rates.len() {
        return Err(MathError::invalid_input(format!(
            "tenors and zero_rates must have same length: {} vs {}",
            tenors.len(),
            zero_rates.len()
        )));
    }
    // NaN compares false against everything, so it would slip through the
    // positivity and ordering checks below; reject it (and infinities) here.
    if tenors
        .iter()
        .chain(zero_rates.iter())
        .any(|value| !value.is_finite())
    {
        return Err(MathError::invalid_input(
            "Tenors and zero rates must be finite",
        ));
    }
    if tenors[0] <= 0.0 {
        return Err(MathError::invalid_input(
            "First tenor must be positive for flat forward interpolation",
        ));
    }
    if tenors.windows(2).any(|pair| pair[1] <= pair[0]) {
        return Err(MathError::invalid_input(
            "Tenors must be strictly increasing",
        ));
    }
    let forward_rates = Self::compute_forward_rates(&tenors, &zero_rates);
    Ok(Self {
        tenors,
        zero_rates,
        forward_rates,
        allow_extrapolation: false,
    })
}
/// Builds a flat-forward interpolator whose pillar set starts at the origin.
///
/// If the first tenor is positive, a pillar at `t = 0` carrying the first
/// zero rate is prepended before the forward rates are computed.
/// Extrapolation is disabled on the returned value.
///
/// # Errors
///
/// Returns an error when:
/// * `tenors` is empty,
/// * `tenors` and `zero_rates` differ in length,
/// * any tenor or zero rate is NaN or infinite,
/// * fewer than two pillars remain after the optional origin insertion,
/// * the tenors are not strictly increasing.
pub fn with_origin(mut tenors: Vec<f64>, mut zero_rates: Vec<f64>) -> MathResult<Self> {
    if tenors.is_empty() {
        return Err(MathError::insufficient_data(1, 0));
    }
    // Validate lengths before reading `zero_rates[0]`: an empty rate vector
    // must produce an error, not an out-of-bounds panic. Checking here also
    // reports the caller's lengths rather than the post-insertion ones.
    if tenors.len() != zero_rates.len() {
        return Err(MathError::invalid_input(format!(
            "tenors and zero_rates must have same length: {} vs {}",
            tenors.len(),
            zero_rates.len()
        )));
    }
    // NaN compares false against everything, so it would slip through the
    // ordering check below; reject it (and infinities) explicitly.
    if tenors
        .iter()
        .chain(zero_rates.iter())
        .any(|value| !value.is_finite())
    {
        return Err(MathError::invalid_input(
            "Tenors and zero rates must be finite",
        ));
    }
    if tenors[0] > 0.0 {
        let first_rate = zero_rates[0];
        tenors.insert(0, 0.0);
        zero_rates.insert(0, first_rate);
    }
    if tenors.len() < 2 {
        return Err(MathError::insufficient_data(2, tenors.len()));
    }
    if tenors.windows(2).any(|pair| pair[1] <= pair[0]) {
        return Err(MathError::invalid_input(
            "Tenors must be strictly increasing",
        ));
    }
    let forward_rates = Self::compute_forward_rates(&tenors, &zero_rates);
    Ok(Self {
        tenors,
        zero_rates,
        forward_rates,
        allow_extrapolation: false,
    })
}
#[must_use]
pub fn with_extrapolation(mut self) -> Self {
self.allow_extrapolation = true;
self
}
/// Derives the piecewise-constant forward rate for each pillar interval.
///
/// For consecutive pillars `(t0, r0)` and `(t1, r1)` the forward is
/// `(r1 * t1 - r0 * t0) / (t1 - t0)`; when `t0` is exactly zero the far
/// pillar's rate `r1` is used directly. The last forward is then repeated
/// once so the result has one entry per tenor. With a single pillar the
/// result is just that pillar's zero rate.
fn compute_forward_rates(tenors: &[f64], zero_rates: &[f64]) -> Vec<f64> {
    let pillar_count = tenors.len();
    let mut result: Vec<f64> = Vec::with_capacity(pillar_count);
    result.extend((0..pillar_count - 1).map(|left| {
        let right = left + 1;
        let (start, end) = (tenors[left], tenors[right]);
        let (near_rate, far_rate) = (zero_rates[left], zero_rates[right]);
        if start == 0.0 {
            far_rate
        } else {
            (far_rate * end - near_rate * start) / (end - start)
        }
    }));
    // Pad with a copy of the final forward (or the lone zero rate).
    let tail = match result.last() {
        Some(&final_forward) => final_forward,
        None => zero_rates[0],
    };
    result.push(tail);
    result
}
/// Returns the index of the pillar interval used for time `t`.
///
/// An exact pillar hit maps to the interval starting at that pillar; any
/// other `t` maps to the interval whose start precedes it. The result is
/// clamped to `[0, tenors.len() - 2]`, so times before the first pillar use
/// the first interval and times at or after the last pillar use the final one.
fn find_segment(&self, t: f64) -> usize {
    use std::cmp::Ordering;

    let located = self
        .tenors
        .binary_search_by(|probe| probe.partial_cmp(&t).unwrap_or(Ordering::Equal));
    let candidate = located.unwrap_or_else(|insert_at| insert_at.saturating_sub(1));
    let last_segment = self.tenors.len() - 2;
    candidate.min(last_segment)
}
/// Returns the forward rate in effect at time `t`.
///
/// Within the pillar range this is the constant forward of the segment
/// selected by `find_segment`. Past the last pillar the final segment's
/// forward is carried on. Before the first pillar the first zero rate is
/// returned: `interpolate` holds the zero rate flat at that value there, and
/// a flat zero rate implies an equal forward rate.
///
/// # Errors
///
/// Returns `MathError::ExtrapolationNotAllowed` when `t` lies outside the
/// pillar range and extrapolation is disabled.
pub fn forward_rate(&self, t: f64) -> MathResult<f64> {
    let min_t = self.tenors[0];
    let max_t = self.tenors[self.tenors.len() - 1];
    if !self.allow_extrapolation && (t < min_t || t > max_t) {
        return Err(MathError::ExtrapolationNotAllowed {
            x: t,
            min: min_t,
            max: max_t,
        });
    }
    // Left of the first pillar the zero curve is flat, so the forward equals
    // the first zero rate rather than the first segment's forward.
    if t < min_t {
        return Ok(self.zero_rates[0]);
    }
    Ok(self.forward_rates[self.find_segment(t)])
}
/// Borrows the pillar tenors.
pub fn tenors(&self) -> &[f64] {
    self.tenors.as_slice()
}
/// Borrows the zero rates, index-aligned with the tenors.
pub fn zero_rates(&self) -> &[f64] {
    self.zero_rates.as_slice()
}
/// Borrows the precomputed forward rates (one entry per tenor).
pub fn forward_rates_vec(&self) -> &[f64] {
    self.forward_rates.as_slice()
}
}
impl Interpolator for FlatForward {
/// Evaluates the zero rate at time `t`.
///
/// * `t <= 0`, or `t` before the first pillar: the first zero rate.
/// * `t` within `1e-12` of a pillar: that pillar's zero rate.
/// * otherwise: `(r_a * t_a + f_a * (t - t_a)) / t`, where the anchor `a` is
///   the last pillar when `t` lies beyond it and the start of the segment
///   chosen by `find_segment` otherwise.
///
/// # Errors
///
/// Returns `MathError::ExtrapolationNotAllowed` when `t` lies outside the
/// pillar range and extrapolation is disabled.
fn interpolate(&self, t: f64) -> MathResult<f64> {
    let first_pillar = self.tenors[0];
    let last_pillar = *self.tenors.last().unwrap();
    let out_of_range = t < first_pillar || t > last_pillar;
    if out_of_range && !self.allow_extrapolation {
        return Err(MathError::ExtrapolationNotAllowed {
            x: t,
            min: first_pillar,
            max: last_pillar,
        });
    }
    if t <= 0.0 {
        return Ok(self.zero_rates[0]);
    }
    // Return the quoted rate directly on a pillar instead of recomputing it.
    let pillar_hit = self
        .tenors
        .iter()
        .position(|&pillar| (pillar - t).abs() < 1e-12);
    match pillar_hit {
        Some(hit) => return Ok(self.zero_rates[hit]),
        None if t < first_pillar => return Ok(self.zero_rates[0]),
        None => {}
    }
    // Beyond the curve the last pillar is the anchor; inside it, the start of
    // the containing segment is.
    let anchor = if t > last_pillar {
        self.tenors.len() - 1
    } else {
        self.find_segment(t)
    };
    let anchor_time = self.tenors[anchor];
    let anchor_rate = self.zero_rates[anchor];
    let anchor_forward = self.forward_rates[anchor];
    Ok((anchor_rate * anchor_time + anchor_forward * (t - anchor_time)) / t)
}
/// Evaluates the derivative of the interpolated zero rate with respect to `t`.
///
/// For `t` at or after the first pillar the zero rate is
/// `f_i + (r_i - f_i) * t_i / t` on the active segment, whose derivative is
/// `(f_i - r_i) * t_i / t^2`. For `t <= 0` and for `t` before the first
/// pillar `interpolate` returns a constant, so the derivative is zero.
///
/// # Errors
///
/// Returns `MathError::ExtrapolationNotAllowed` when `t` lies outside the
/// pillar range and extrapolation is disabled.
fn derivative(&self, t: f64) -> MathResult<f64> {
    let min_t = self.tenors[0];
    let max_t = *self.tenors.last().unwrap();
    if !self.allow_extrapolation && (t < min_t || t > max_t) {
        return Err(MathError::ExtrapolationNotAllowed {
            x: t,
            min: min_t,
            max: max_t,
        });
    }
    // `interpolate` is flat (first zero rate) everywhere left of the first
    // pillar, so the slope there is zero; applying the first segment's
    // formula would disagree with the interpolated curve.
    if t <= 0.0 || t < min_t {
        return Ok(0.0);
    }
    let i = if t > max_t {
        self.tenors.len() - 2
    } else {
        self.find_segment(t)
    };
    let t_i = self.tenors[i];
    let r_i = self.zero_rates[i];
    let f_i = self.forward_rates[i];
    Ok((f_i - r_i) * t_i / (t * t))
}
fn allows_extrapolation(&self) -> bool {
self.allow_extrapolation
}
/// Returns the first pillar tenor, the lower bound of the pillar range.
fn min_x(&self) -> f64 {
    let pillars = self.tenors.as_slice();
    pillars[0]
}
/// Returns the last pillar tenor, the upper bound of the pillar range.
fn max_x(&self) -> f64 {
    self.tenors.last().copied().unwrap()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Interpolating exactly at a pillar returns that pillar's zero rate.
    #[test]
    fn test_flat_forward_through_pillars() {
        let pillars = vec![1.0, 2.0, 5.0, 10.0];
        let quotes = vec![0.02, 0.025, 0.03, 0.035];
        let curve = FlatForward::new(pillars.clone(), quotes.clone()).unwrap();
        for (time, quote) in pillars.iter().copied().zip(quotes.iter().copied()) {
            assert_relative_eq!(curve.interpolate(time).unwrap(), quote, epsilon = 1e-10);
        }
    }

    /// Forward rates inside each segment match the hand-computed values.
    #[test]
    fn test_flat_forward_rates() {
        let curve = FlatForward::new(vec![1.0, 2.0, 3.0], vec![0.02, 0.03, 0.04]).unwrap();
        for (time, expected) in [(1.5, 0.04), (2.5, 0.06)] {
            assert_relative_eq!(curve.forward_rate(time).unwrap(), expected, epsilon = 1e-10);
        }
    }

    /// The zero rate between two pillars follows the flat-forward formula.
    #[test]
    fn test_interpolation_between_pillars() {
        let curve = FlatForward::new(vec![1.0, 2.0], vec![0.02, 0.04]).unwrap();
        let midpoint_rate = curve.interpolate(1.5).unwrap();
        assert_relative_eq!(midpoint_rate, 0.05 / 1.5, epsilon = 1e-10);
    }

    /// The forward rate is constant across a segment and equals the closed form.
    #[test]
    fn test_forward_rate_consistency() {
        let curve =
            FlatForward::new(vec![1.0, 2.0, 5.0, 10.0], vec![0.02, 0.025, 0.03, 0.035]).unwrap();
        let reference = curve.forward_rate(1.5).unwrap();
        for time in [1.1, 1.9] {
            assert_relative_eq!(curve.forward_rate(time).unwrap(), reference, epsilon = 1e-10);
        }
        let closed_form = (0.025 * 2.0 - 0.02 * 1.0) / (2.0 - 1.0);
        assert_relative_eq!(reference, closed_form, epsilon = 1e-10);
    }

    /// An upward-sloping curve yields strictly positive forwards.
    #[test]
    fn test_positive_forward_rates() {
        let curve =
            FlatForward::new(vec![1.0, 2.0, 5.0, 10.0], vec![0.02, 0.025, 0.03, 0.035]).unwrap();
        for f in curve.forward_rates_vec().iter().copied() {
            assert!(f > 0.0, "Forward rate {} should be positive", f);
        }
    }

    /// The analytical derivative agrees with a central finite difference.
    #[test]
    fn test_derivative_numerical() {
        let curve = FlatForward::new(vec![1.0, 2.0, 5.0, 10.0], vec![0.02, 0.025, 0.03, 0.035])
            .unwrap()
            .with_extrapolation();
        let step = 1e-6;
        for time in [1.5, 2.5, 4.0, 7.0] {
            let above = curve.interpolate(time + step).unwrap();
            let below = curve.interpolate(time - step).unwrap();
            let finite_difference = (above - below) / (2.0 * step);
            let closed_form = curve.derivative(time).unwrap();
            assert_relative_eq!(closed_form, finite_difference, epsilon = 1e-5);
        }
    }

    /// With extrapolation enabled, out-of-range queries succeed.
    #[test]
    fn test_extrapolation() {
        let curve = FlatForward::new(vec![1.0, 2.0, 5.0], vec![0.02, 0.025, 0.03])
            .unwrap()
            .with_extrapolation();
        for time in [0.5, 7.0] {
            assert!(curve.interpolate(time).is_ok());
        }
        assert_relative_eq!(curve.interpolate(0.5).unwrap(), 0.02, epsilon = 1e-10);
    }

    /// Without extrapolation, out-of-range queries are rejected.
    #[test]
    fn test_no_extrapolation() {
        let curve = FlatForward::new(vec![1.0, 2.0, 5.0], vec![0.02, 0.025, 0.03]).unwrap();
        for time in [0.5, 7.0] {
            assert!(curve.interpolate(time).is_err());
        }
    }

    /// `with_origin` extends the usable range down to zero.
    #[test]
    fn test_with_origin() {
        let curve = FlatForward::with_origin(vec![1.0, 2.0, 5.0], vec![0.02, 0.025, 0.03]).unwrap();
        for time in [0.0, 0.5] {
            assert!(curve.interpolate(time).is_ok());
        }
    }

    /// A single pillar is not enough for `new`.
    #[test]
    fn test_insufficient_points() {
        let outcome = FlatForward::new(vec![1.0], vec![0.02]);
        assert!(outcome.is_err());
    }

    /// Tenor and rate vectors of different lengths are rejected.
    #[test]
    fn test_mismatched_lengths() {
        let outcome = FlatForward::new(vec![1.0, 2.0, 3.0], vec![0.02, 0.025]);
        assert!(outcome.is_err());
    }

    /// `new` refuses a first tenor of zero.
    #[test]
    fn test_non_positive_tenor() {
        let outcome = FlatForward::new(vec![0.0, 1.0, 2.0], vec![0.02, 0.025, 0.03]);
        assert!(outcome.is_err());
    }

    /// A flat zero curve gives flat forwards and flat interpolated rates.
    #[test]
    fn test_flat_curve() {
        let curve = FlatForward::new(vec![1.0, 2.0, 5.0, 10.0], vec![0.03; 4]).unwrap();
        for f in curve.forward_rates_vec().iter().copied() {
            assert_relative_eq!(f, 0.03, epsilon = 1e-10);
        }
        for time in [1.0, 1.5, 2.5, 4.0, 7.0, 10.0] {
            assert_relative_eq!(curve.interpolate(time).unwrap(), 0.03, epsilon = 1e-10);
        }
    }

    /// A downward-sloping curve can still be built and queried.
    #[test]
    fn test_inverted_curve() {
        let curve =
            FlatForward::new(vec![1.0, 2.0, 5.0, 10.0], vec![0.05, 0.04, 0.03, 0.025]).unwrap();
        for time in [1.5, 3.0] {
            assert!(curve.interpolate(time).is_ok());
        }
    }
}