use crate::eval::coercion::to_number;
use crate::eval::functions::check_arity;
use crate::types::{ErrorKind, Value};
/// Formula function `IRR`: the internal rate of return of a series of cash flows.
///
/// Arity is validated by `check_arity(args, 1, 256)`, and the cash flows plus
/// the starting guess are gathered by `collect_cashflows_with_guess`. The rate
/// is first sought with Newton's method from the guess. If that fails, a
/// bracketing Brent search over a fixed grid of rates is tried.
///
/// Returns:
/// * the error value produced by `check_arity` or by argument coercion, unchanged;
/// * `Value::Error(ErrorKind::NA)` when fewer than two cash flows were collected;
/// * `Value::Error(ErrorKind::Num)` when the flows are not a mix of positive
///   and negative amounts, or when neither solver finds a rate;
/// * `Value::Number(rate)` otherwise.
pub fn irr_fn(args: &[Value]) -> Value {
    if let Some(arity_error) = check_arity(args, 1, 256) {
        return arity_error;
    }

    let (cashflows, guess) = match collect_cashflows_with_guess(args) {
        Ok(collected) => collected,
        Err(error) => return error,
    };

    if cashflows.len() < 2 {
        return Value::Error(ErrorKind::NA);
    }

    // NPV can only cross zero when money flows both in and out.
    let mixed_signs =
        cashflows.iter().any(|&cf| cf > 0.0) && cashflows.iter().any(|&cf| cf < 0.0);
    if !mixed_signs {
        return Value::Error(ErrorKind::Num);
    }

    match irr_newton(&cashflows, guess).or_else(|| irr_brent(&cashflows)) {
        Some(rate) => Value::Number(rate),
        None => Value::Error(ErrorKind::Num),
    }
}
/// Newton–Raphson search for a rate at which the NPV of `cfs` is zero,
/// starting from `guess`.
///
/// Takes at most 100 steps and reports convergence as soon as one step moves
/// the rate by less than `1e-7`. Returns `None` if any of the following happens:
/// * the NPV or its derivative is non-finite;
/// * the derivative is exactly zero;
/// * an iterate is non-finite or drops to `-1` or below;
/// * the step budget runs out.
fn irr_newton(cfs: &[f64], guess: f64) -> Option<f64> {
    const MAX_STEPS: usize = 100;
    const STEP_TOLERANCE: f64 = 1e-7;

    let mut rate = guess;
    for _ in 0..MAX_STEPS {
        let (value, slope) = npv_and_derivative(cfs, rate);
        let step_is_defined = value.is_finite() && slope.is_finite() && slope != 0.0;
        if !step_is_defined {
            return None;
        }

        let next = rate - value / slope;
        // At rate <= -1 the discount base (1 + rate) is no longer positive.
        if !next.is_finite() || next <= -1.0 {
            return None;
        }
        if (next - rate).abs() < STEP_TOLERANCE {
            return Some(next);
        }
        rate = next;
    }
    None
}
/// Fallback IRR solver, used when Newton's method fails.
///
/// Scans a fixed grid of candidate rates (all greater than -1) for a sign
/// change in NPV and refines each bracket it finds with [`brent_root`]. The
/// first root obtained is returned.
///
/// Grid points where NPV is not finite (e.g. overflow close to -1 with many
/// periods) are never used as bracket endpoints. A bracket whose refinement
/// fails does not abort the scan; later brackets are still tried. Returns
/// `None` when no bracket yields a root.
fn irr_brent(cfs: &[f64]) -> Option<f64> {
    const CANDIDATES: &[f64] = &[
        -0.999, -0.99, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2,
        -0.15, -0.1, -0.05, 0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 1.0,
        2.0, 5.0, 10.0, 50.0, 100.0,
    ];

    // NPV with cash flow `t` discounted by (1 + r)^t.
    let npv = |r: f64| -> f64 {
        cfs.iter()
            .enumerate()
            .fold(0.0, |acc, (t, &cf)| acc + cf / (1.0 + r).powf(t as f64))
    };

    // Most recent grid point with a finite NPV, as (rate, npv).
    let mut prev: Option<(f64, f64)> = None;
    for &r in CANDIDATES {
        let f_r = npv(r);
        if !f_r.is_finite() {
            // A non-finite NPV cannot serve as a bracket endpoint; restart
            // bracketing from the next finite grid point.
            prev = None;
            continue;
        }
        if let Some((prev_r, prev_f)) = prev {
            if brackets_root(prev_f, f_r) {
                if let Some(root) = brent_root(&npv, prev_r, r, 1e-10) {
                    return Some(root);
                }
                // Refinement failed on this bracket; keep scanning.
            }
        }
        prev = Some((r, f_r));
    }
    None
}

/// Finds a root of `f` between `a` and `b` with Brent's method.
///
/// Each step uses inverse quadratic interpolation or a secant step, and falls
/// back to bisection whenever the interpolated point fails Brent's safeguards.
///
/// Returns `Some(b)` once `|f(b)| < tol` or the bracket is narrower than `tol`,
/// or after 200 iterations. Returns `None` when:
/// * `f(a)` or `f(b)` is not finite;
/// * `f(a)` and `f(b)` are both nonzero with the same sign, so no root is
///   bracketed;
/// * `f` yields a non-finite value during the search.
fn brent_root<F: Fn(f64) -> f64>(f: F, mut a: f64, mut b: f64, tol: f64) -> Option<f64> {
    let mut fa = f(a);
    let mut fb = f(b);
    if !fa.is_finite() || !fb.is_finite() || !brackets_root(fa, fb) {
        return None;
    }
    // Keep `b` as the endpoint with the smaller residual (the current best estimate).
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }
    // `c` holds the previous `b`. `d` holds the one before that; it is only
    // read once `mflag` is cleared, which happens after `d` is first assigned.
    let mut c = a;
    let mut fc = fa;
    let mut mflag = true;
    let mut d = 0.0_f64;
    for _ in 0..200 {
        if fb.abs() < tol || (b - a).abs() < tol {
            return Some(b);
        }
        let s = if fa != fc && fb != fc {
            // Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
            a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb))
        } else {
            // Secant step through (a, fa) and (b, fb).
            b - fb * (b - a) / (fb - fa)
        };
        let mid = (a + b) / 2.0;
        let quarter = (3.0 * a + b) / 4.0;
        // Brent's safeguards: bisect when the interpolated point lies outside
        // (quarter, b) or the steps are not shrinking fast enough.
        let use_bisect = !((quarter < s && s < b) || (b < s && s < quarter))
            || (mflag && (s - b).abs() >= (b - c).abs() / 2.0)
            || (!mflag && (s - b).abs() >= (c - d).abs() / 2.0)
            || (mflag && (b - c).abs() < tol)
            || (!mflag && (c - d).abs() < tol);
        let s = if use_bisect { mid } else { s };
        mflag = use_bisect;
        let fs = f(s);
        if !fs.is_finite() {
            // NaN/inf would silently break every later comparison; give up
            // rather than return an unconverged estimate.
            return None;
        }
        d = c;
        c = b;
        fc = fb;
        // Keep the root bracketed between `a` and `b`.
        if opposite_signs(fa, fs) {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }
        if fa.abs() < fb.abs() {
            std::mem::swap(&mut a, &mut b);
            std::mem::swap(&mut fa, &mut fb);
        }
    }
    Some(b)
}

/// True when `x` and `y` are both nonzero and of opposite sign.
///
/// Unlike `x * y < 0.0`, this cannot be fooled by the product underflowing to zero.
fn opposite_signs(x: f64, y: f64) -> bool {
    (x < 0.0 && y > 0.0) || (x > 0.0 && y < 0.0)
}

/// True when endpoint values `x` and `y` bracket a root: at least one of them
/// is exactly zero, or they have opposite signs.
fn brackets_root(x: f64, y: f64) -> bool {
    x == 0.0 || y == 0.0 || opposite_signs(x, y)
}
/// Gathers the IRR cash flows and the starting guess from the call arguments.
///
/// * If the first argument is an array, its recursively flattened contents
///   (via `flatten_values`) are the cash flows. The second argument, when
///   present, is coerced with `to_number` and used as the guess. Further
///   arguments are not read.
/// * Otherwise every argument is a cash flow: booleans are skipped, everything
///   else goes through `to_number`, and the guess is fixed at `0.1`.
///
/// The first coercion error is returned as `Err`. Panics if `args` is empty;
/// callers are expected to check arity first.
fn collect_cashflows_with_guess(args: &[Value]) -> Result<(Vec<f64>, f64), Value> {
    const DEFAULT_GUESS: f64 = 0.1;

    match &args[0] {
        Value::Array(items) => {
            let cashflows = flatten_values(items.clone())?;
            let guess = match args.get(1) {
                Some(raw_guess) => to_number(raw_guess.clone())?,
                None => DEFAULT_GUESS,
            };
            Ok((cashflows, guess))
        }
        _ => {
            let cashflows = args
                .iter()
                .filter(|arg| !matches!(arg, Value::Bool(_)))
                .map(|arg| to_number(arg.clone()))
                .collect::<Result<Vec<f64>, _>>()?;
            Ok((cashflows, DEFAULT_GUESS))
        }
    }
}
/// Recursively flattens `items` into the numbers they contain, preserving order.
///
/// Nested arrays are descended into, and booleans and text are skipped. Every
/// other value is converted with `to_number`, and the first conversion error
/// is returned as `Err`.
fn flatten_values(items: Vec<Value>) -> Result<Vec<f64>, Value> {
    // Appends the numeric contents of `items` to `sink`, depth-first.
    fn flatten_into(items: Vec<Value>, sink: &mut Vec<f64>) -> Result<(), Value> {
        for item in items {
            match item {
                Value::Array(nested) => flatten_into(nested, sink)?,
                Value::Bool(_) | Value::Text(_) => continue,
                scalar => sink.push(to_number(scalar)?),
            }
        }
        Ok(())
    }

    let mut numbers = Vec::new();
    flatten_into(items, &mut numbers)?;
    Ok(numbers)
}
/// Evaluates the NPV of `cfs` at `rate` together with its derivative with
/// respect to `rate`, in a single pass.
///
/// Cash flow `i` is discounted by `(1 + rate)^i`, giving
/// `npv = Σ cf_i / (1 + rate)^i` and
/// `dnpv = Σ -i · cf_i / (1 + rate)^(i + 1)`.
/// Returns `(npv, dnpv)`, which is `(0.0, 0.0)` for an empty slice. Either
/// component may be non-finite, for example at `rate == -1.0`.
fn npv_and_derivative(cfs: &[f64], rate: f64) -> (f64, f64) {
    let growth = 1.0 + rate;
    cfs.iter()
        .enumerate()
        .fold((0.0_f64, 0.0_f64), |(value, slope), (period, &cf)| {
            let t = period as f64;
            (
                value + cf / growth.powf(t),
                slope - t * cf / growth.powf(t + 1.0),
            )
        })
}
#[cfg(test)]
mod tests;