use crate::derivatives::pv_prime_r;
use crate::tvm::{npv, xnpv};
use rust_decimal::prelude::*;
use rust_decimal_macros::*;
pub fn irr(
cash_flows: &[Decimal],
guess: Option<Decimal>,
tolerance: Option<Decimal>,
) -> Result<Decimal, (Decimal, Decimal)> {
const MAX_ITER: u8 = 20;
let tolerance = tolerance.unwrap_or(dec!(1e-5));
let mut rate = guess.unwrap_or(dec!(0.1));
for _ in 0..MAX_ITER {
let npv_value = npv(rate, cash_flows);
if npv_value.abs() < tolerance {
return Ok(rate);
}
let drate: Decimal = cash_flows
.iter()
.enumerate()
.map(|(i, &cf)| pv_prime_r(rate, i.into(), cf))
.sum();
if drate.is_zero() {
return Err((rate, npv_value));
}
rate -= npv_value / drate;
}
Err((rate, npv(rate, cash_flows)))
}
/// Computes the internal rate of return for irregularly dated cash flows
/// (XIRR) using Newton's method.
///
/// Each entry of `flow_table` is `(cash_flow, date)`, where `date` is an
/// integer day count. Every flow is timed in years from the first entry's
/// date as `(date - first_date) / 365`. This is presumably the same convention
/// `xnpv` uses; confirm against `crate::tvm::xnpv`. Each iteration evaluates
/// `xnpv(rate, flow_table)`, sums `pv_prime_r` over every flow (presumably
/// the rate-derivative of each discounted flow), and takes one Newton step.
///
/// # Arguments
///
/// * `flow_table` - `(cash_flow, date)` pairs; the first entry sets the reference date.
/// * `guess` - Starting rate; defaults to `0.1`.
/// * `tolerance` - Success threshold on `|XNPV|`; defaults to `1e-5`.
///
/// # Errors
///
/// Returns `Err((rate, npv))`, holding the last rate reached and its XNPV, when:
/// * `flow_table` is empty (no reference date or solvable rate; NPV is reported as zero),
/// * the derivative is zero or the Newton step overflows `Decimal`,
/// * the next step would move the rate to -100% or below, or
/// * `|XNPV|` is still not below `tolerance` after 20 iterations.
pub fn xirr(
    flow_table: &[(Decimal, i32)],
    guess: Option<Decimal>,
    tolerance: Option<Decimal>,
) -> Result<Decimal, (Decimal, Decimal)> {
    const MAX_ITER: u8 = 20;
    let tolerance = tolerance.unwrap_or(dec!(1e-5));
    let mut rate = guess.unwrap_or(dec!(0.1));
    // An empty table has no reference date; fail instead of panicking.
    let init_date = match flow_table.first() {
        Some(&(_, date)) => i64::from(date),
        None => return Err((rate, Decimal::ZERO)),
    };
    for _ in 0..MAX_ITER {
        let npv_value = xnpv(rate, flow_table);
        if npv_value.abs() < tolerance {
            return Ok(rate);
        }
        let drate: Decimal = flow_table
            .iter()
            .map(|&(cf, date)| {
                // Widen to i64 so the day difference cannot overflow `i32`.
                let years = Decimal::from(i64::from(date) - init_date) / dec!(365);
                pv_prime_r(rate, years, cf)
            })
            .sum();
        // `checked_div` yields `None` for a zero derivative (no Newton step
        // exists) and for a near-zero one whose quotient would overflow and
        // panic with plain `/`. Either way, stop and report the current state.
        let next_rate = match npv_value
            .checked_div(drate)
            .and_then(|step| rate.checked_sub(step))
        {
            Some(next) => next,
            None => return Err((rate, npv_value)),
        };
        // Rates at or below -100% are not meaningful IRRs, and (1 + r) is zero or
        // negative there, so stop instead of evaluating fractional powers of it.
        if next_rate <= dec!(-1) {
            return Err((rate, npv_value));
        }
        rate = next_rate;
    }
    Err((rate, xnpv(rate, flow_table)))
}
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "std"))]
    extern crate std;
    use super::*;
    #[cfg(not(feature = "std"))]
    use std::prelude::v1::*;
    #[cfg(not(feature = "std"))]
    use std::{assert, vec};

    /// `irr` must return a positive rate whose NPV is zero at 1e-20 precision.
    #[test]
    fn test_irr() {
        let cash_flows = vec![dec!(-100), dec!(50), dec!(40), dec!(30), dec!(1000)];
        match irr(&cash_flows, None, Some(dec!(1e-20))) {
            Ok(rate) => {
                // Check the returned rate itself rather than trusting the `Ok` variant.
                let residual = npv(rate, &cash_flows);
                assert!(
                    residual.abs() < dec!(1e-20),
                    "Rate {} leaves a non-zero NPV: {}",
                    rate,
                    residual
                );
                // All inflows follow a single outflow, so the IRR must be positive.
                assert!(rate > Decimal::ZERO, "Expected a positive IRR, got {}", rate);
            }
            Err((rate, npv_value)) => {
                // `irr` reports `Err` after its last iteration even if that final
                // step reached the tolerance, so accept a converged NPV here.
                assert!(
                    npv_value.abs() < dec!(1e-20),
                    "Failed to converge at 1e-20 precision. Last rate: {}, NPV: {}",
                    rate,
                    npv_value
                );
            }
        }
    }

    /// `xirr` on dated flows must land on the expected rate of about 0.20084.
    #[test]
    fn test_xirr() {
        let flow_table = vec![
            (dec!(-100), 0),
            (dec!(50), 359),
            (dec!(40), 400),
            (dec!(30), 1000),
            (dec!(20), 2000),
        ];
        match xirr(&flow_table, None, Some(dec!(1e-20))) {
            Ok(rate) => {
                let expected = dec!(0.20084);
                assert!(
                    (rate - expected).abs() < dec!(1e-5),
                    "Failed on case: {}. Expected: {}, Result: {}",
                    "Cash flows of -100, 50, 40, 30, 20",
                    expected,
                    rate
                );
                let residual = xnpv(rate, &flow_table);
                assert!(
                    residual.abs() < dec!(1e-20),
                    "Rate {} leaves a non-zero XNPV: {}",
                    rate,
                    residual
                );
            }
            Err((rate, npv_value)) => {
                // Accept convergence reached exactly on the final iteration.
                assert!(
                    npv_value.abs() < dec!(1e-20),
                    "Failed to converge at 1e-20 precision. Last rate: {}, NPV: {}",
                    rate,
                    npv_value
                );
            }
        }
    }
}