use crate::compat::Instant;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use crate::error::CorpFinanceError;
use crate::types::{with_metadata, ComputationOutput, Money, Rate};
use crate::CorpFinanceResult;
/// Prepayment assumption applied to the pool each projection month.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PrepaymentModel {
/// Constant annual conditional prepayment rate (CPR); validated to lie in `[0, 1]`.
Cpr(Rate),
/// PSA speed in percent: 100 ramps CPR linearly to 6% at loan age 30 and holds it there.
/// Must be non-negative.
Psa(Decimal),
/// Constant single monthly mortality (SMM); validated to lie in `[0, 1]`.
Smm(Rate),
}
/// Default assumption applied to the pool each projection month.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DefaultModel {
/// Constant annual conditional default rate (CDR); validated to lie in `[0, 1]`.
Cdr(Rate),
/// SDA speed in percent: 100 ramps CDR to 0.6% at loan age 30, holds it to age 60,
/// then declines linearly to 0.03% at age 120. Must be non-negative.
Sda(Decimal),
/// No defaults are projected.
None,
}
/// Collateral pool description and behavioural assumptions for [`model_abs_cashflows`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbsMbsInput {
/// Outstanding pool balance at the start of the projection; must be positive.
pub pool_balance: Money,
/// Annual weighted-average coupon (e.g. `0.06` for 6%); accrued monthly as `WAC / 12`.
pub weighted_avg_coupon: Rate,
/// Remaining amortisation term in months used to size the level scheduled payment;
/// must be greater than zero.
pub weighted_avg_maturity_months: u32,
/// Pool seasoning in months; added to the projection month to get the loan age
/// used by the PSA and SDA curves.
pub weighted_avg_age_months: u32,
/// Number of loans in the pool.
/// NOTE(review): not read by the cash-flow projection in this module.
pub num_loans: u32,
/// Prepayment assumption.
pub prepayment_model: PrepaymentModel,
/// Default assumption.
pub default_model: DefaultModel,
/// Fraction of each defaulted amount that is lost, in `[0, 1]`; the remainder is recovered.
pub loss_severity: Rate,
/// Months between a default and receipt of its recovery; `0` means same-month recovery.
pub recovery_lag_months: u32,
/// Annual servicing fee rate, charged monthly as `rate / 12` on the beginning balance.
pub servicing_fee_rate: Rate,
/// Number of monthly periods to project; must be greater than zero.
pub projection_months: u32,
}
/// Cash flows and rates for one projected month.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbsPeriod {
/// 1-based projection month.
pub month: u32,
/// Pool balance at the start of the month.
pub beginning_balance: Money,
/// Principal portion of the level payment, clamped to `[0, beginning_balance]`.
pub scheduled_principal: Money,
/// Interest on the beginning balance at `WAC / 12`.
pub scheduled_interest: Money,
/// `smm` applied to the balance remaining after scheduled principal.
pub prepayment: Money,
/// `mdr` applied to the balance remaining after scheduled principal and prepayment.
pub defaults: Money,
/// `defaults * loss_severity`, recognised in the month of default.
pub loss: Money,
/// Recovery cash received this month on defaults from `recovery_lag_months` earlier.
pub recovery: Money,
/// Servicing fee on the beginning balance.
pub servicing_fee: Money,
/// Scheduled principal plus prepayment (excludes defaults and recoveries).
pub total_principal: Money,
/// Interest plus total principal, minus servicing fee, plus recovery.
pub total_cashflow: Money,
/// Beginning balance less scheduled principal, prepayment and defaults, floored at zero.
pub ending_balance: Money,
/// Single monthly mortality applied this month.
pub smm: Rate,
/// Annual CPR in effect this month.
pub cpr: Rate,
/// Monthly default rate applied this month.
pub mdr: Rate,
}
/// Aggregates over all projected periods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbsSummary {
/// Sum of `total_principal` (scheduled principal plus prepayments).
pub total_principal_collected: Money,
/// Sum of scheduled interest.
pub total_interest_collected: Money,
/// Sum of prepayments.
pub total_prepayments: Money,
/// Sum of defaults.
pub total_defaults: Money,
/// Sum of losses.
pub total_losses: Money,
/// Sum of recoveries received within the projection.
pub total_recoveries: Money,
/// Sum of servicing fees.
pub total_servicing_fees: Money,
/// Principal-weighted average time (in years) of `total_principal` receipts.
pub weighted_average_life_years: Decimal,
/// Final balance divided by the initial pool balance.
pub pool_factor_at_end: Rate,
/// Total losses divided by the initial pool balance.
pub cumulative_loss_rate: Rate,
/// Sum of `total_cashflow`.
pub total_cashflows: Money,
}
/// Result of [`model_abs_cashflows`]: one entry per projected month plus aggregates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbsMbsOutput {
/// One record per projection month, in order.
pub periods: Vec<AbsPeriod>,
/// Totals and ratios over `periods`.
pub summary: AbsSummary,
}
/// CPR reached at loan age 30 (and held thereafter) on the 100% PSA ramp.
const PSA_BASE_CPR_30: Decimal = dec!(0.06);
/// Peak CDR of the 100% SDA curve, reached at age 30 and held through age 60.
const SDA_PEAK_CDR_30: Decimal = dec!(0.006);
/// Terminal CDR of the 100% SDA curve from age 120 onwards.
const SDA_FLOOR_CDR: Decimal = dec!(0.0003);
/// Balance below which the pool is treated as paid down. Also the tolerance within which
/// a slightly negative ending balance is zeroed without a warning.
const BALANCE_EPSILON: Decimal = dec!(0.01);
/// Projects monthly cash flows for an amortising ABS/MBS collateral pool.
///
/// Each month the pool:
/// 1. pays a level (annuity) scheduled payment that amortises the current balance over the
///    remaining term;
/// 2. prepays `smm` of the post-amortisation balance;
/// 3. defaults `mdr` of what remains.
///
/// Losses are recognised in the month of default at `loss_severity`. The complementary
/// recovery is received `recovery_lag_months` later, or in the same month when the lag is
/// zero.
///
/// Once the balance falls below `BALANCE_EPSILON`, or the remaining term is exhausted, later
/// periods carry no scheduled flows. Recoveries on earlier defaults are still paid out as
/// they fall due. Recoveries due after `projection_months` are excluded, and a warning
/// reports their amount.
///
/// # Errors
///
/// Returns [`CorpFinanceError::InvalidInput`] when `input` fails validation.
pub fn model_abs_cashflows(
    input: &AbsMbsInput,
) -> CorpFinanceResult<ComputationOutput<AbsMbsOutput>> {
    let start = Instant::now();
    let mut warnings: Vec<String> = Vec::new();
    validate_input(input)?;

    let wac_monthly = input.weighted_avg_coupon / dec!(12);
    // Share of each defaulted amount that is eventually recovered.
    let recovery_rate = Decimal::ONE - input.loss_severity;
    let lag = input.recovery_lag_months;
    let horizon = input.projection_months as usize;

    let mut balance = input.pool_balance;
    let mut remaining_months = input.weighted_avg_maturity_months;
    let mut periods: Vec<AbsPeriod> = Vec::with_capacity(horizon);
    // defaults_history[i] holds the defaults of month i + 1; lagged recoveries index into it.
    let mut defaults_history: Vec<Money> = Vec::with_capacity(horizon);

    let mut total_principal_collected = Decimal::ZERO;
    let mut total_interest_collected = Decimal::ZERO;
    let mut total_prepayments = Decimal::ZERO;
    let mut total_defaults = Decimal::ZERO;
    let mut total_losses = Decimal::ZERO;
    let mut total_recoveries = Decimal::ZERO;
    let mut total_servicing_fees = Decimal::ZERO;
    let mut total_cashflows = Decimal::ZERO;
    let mut wal_numerator = Decimal::ZERO;

    for month_idx in 0..input.projection_months {
        let month = month_idx + 1;

        if balance < BALANCE_EPSILON || remaining_months == 0 {
            // The performing balance is gone, but defaults from earlier months may still be
            // in workout: pay their recoveries out rather than silently dropping them.
            let recovery = lagged_recovery(&defaults_history, month, lag, recovery_rate);
            defaults_history.push(Decimal::ZERO);
            let mut period = zero_period(month, balance);
            if recovery > Decimal::ZERO {
                period.recovery = recovery;
                period.total_cashflow = recovery;
                total_recoveries += recovery;
                total_cashflows += recovery;
            }
            periods.push(period);
            continue;
        }

        let age = input.weighted_avg_age_months.saturating_add(month);
        let beginning_balance = balance;
        let cpr_annual = compute_cpr(age, &input.prepayment_model);
        // An SMM assumption is applied verbatim; round-tripping it through CPR would only
        // add nth-root approximation error.
        let smm = match &input.prepayment_model {
            PrepaymentModel::Smm(rate) => *rate,
            PrepaymentModel::Cpr(_) | PrepaymentModel::Psa(_) => cpr_to_smm(cpr_annual),
        };
        let cdr_annual = compute_cdr(age, &input.default_model);
        let mdr = cdr_to_mdr(cdr_annual);

        let scheduled_interest = beginning_balance * wac_monthly;
        let scheduled_payment = level_payment(beginning_balance, wac_monthly, remaining_months);
        // Never amortise more than is outstanding, nor a negative amount.
        let raw_principal = scheduled_payment - scheduled_interest;
        let scheduled_principal = if raw_principal > beginning_balance {
            beginning_balance
        } else if raw_principal < Decimal::ZERO {
            Decimal::ZERO
        } else {
            raw_principal
        };

        // Balance decrements in order: scheduled amortisation, prepayment on the remainder,
        // then defaults on what is left after prepayment.
        let prepay_base = beginning_balance - scheduled_principal;
        let prepayment = if prepay_base > Decimal::ZERO {
            prepay_base * smm
        } else {
            Decimal::ZERO
        };
        let default_base = beginning_balance - scheduled_principal - prepayment;
        let defaults = if default_base > Decimal::ZERO {
            default_base * mdr
        } else {
            Decimal::ZERO
        };

        // Losses are recognised at default; the recovered share arrives `lag` months later.
        let loss = defaults * input.loss_severity;
        let recovery = if lag == 0 {
            defaults * recovery_rate
        } else {
            lagged_recovery(&defaults_history, month, lag, recovery_rate)
        };
        defaults_history.push(defaults);

        let servicing_fee = beginning_balance * input.servicing_fee_rate / dec!(12);
        let total_principal = scheduled_principal + prepayment;
        let total_cashflow = scheduled_interest + total_principal - servicing_fee + recovery;

        let mut ending_balance = beginning_balance - scheduled_principal - prepayment - defaults;
        if ending_balance < Decimal::ZERO {
            if ending_balance.abs() >= BALANCE_EPSILON {
                warnings.push(format!(
                    "Month {}: ending balance went negative ({}) — clamped to zero",
                    month, ending_balance
                ));
            }
            ending_balance = Decimal::ZERO;
        }

        total_principal_collected += total_principal;
        total_interest_collected += scheduled_interest;
        total_prepayments += prepayment;
        total_defaults += defaults;
        total_losses += loss;
        total_recoveries += recovery;
        total_servicing_fees += servicing_fee;
        total_cashflows += total_cashflow;
        wal_numerator += Decimal::from(month) * total_principal / dec!(12);

        periods.push(AbsPeriod {
            month,
            beginning_balance,
            scheduled_principal,
            scheduled_interest,
            prepayment,
            defaults,
            loss,
            recovery,
            servicing_fee,
            total_principal,
            total_cashflow,
            ending_balance,
            smm,
            cpr: cpr_annual,
            mdr,
        });

        balance = ending_balance;
        remaining_months = remaining_months.saturating_sub(1);
    }

    // Defaults in the final `lag` months recover after the horizon and are absent from the
    // totals; surface the amount instead of dropping it silently.
    if lag > 0 {
        let first_unrecovered = defaults_history.len().saturating_sub(lag as usize);
        let unrecovered_defaults = defaults_history[first_unrecovered..]
            .iter()
            .fold(Decimal::ZERO, |acc, d| acc + *d);
        let pending_recovery = unrecovered_defaults * recovery_rate;
        if pending_recovery > Decimal::ZERO {
            warnings.push(format!(
                "Recoveries of {} on defaults in the final {} month(s) fall beyond the projection horizon and are excluded",
                pending_recovery.round_dp(2),
                lag.min(input.projection_months)
            ));
        }
    }

    let weighted_average_life_years = if total_principal_collected > Decimal::ZERO {
        wal_numerator / total_principal_collected
    } else {
        Decimal::ZERO
    };
    let pool_factor_at_end = if input.pool_balance > Decimal::ZERO {
        balance / input.pool_balance
    } else {
        Decimal::ZERO
    };
    let cumulative_loss_rate = if input.pool_balance > Decimal::ZERO {
        total_losses / input.pool_balance
    } else {
        Decimal::ZERO
    };

    let summary = AbsSummary {
        total_principal_collected,
        total_interest_collected,
        total_prepayments,
        total_defaults,
        total_losses,
        total_recoveries,
        total_servicing_fees,
        weighted_average_life_years,
        pool_factor_at_end,
        cumulative_loss_rate,
        total_cashflows,
    };
    let output = AbsMbsOutput { periods, summary };
    let elapsed = start.elapsed().as_micros() as u64;
    Ok(with_metadata(
        "ABS/MBS Cash Flow Model — Amortisation with prepayment/default/recovery projections",
        input,
        warnings,
        elapsed,
        output,
    ))
}

/// Level monthly payment that fully amortises `balance` over `remaining_months` at
/// `monthly_rate`, or straight-line principal when the rate is zero.
///
/// Callers must pass `remaining_months > 0`.
fn level_payment(balance: Money, monthly_rate: Rate, remaining_months: u32) -> Money {
    if monthly_rate > Decimal::ZERO {
        let denom =
            Decimal::ONE - iterative_pow_recip(Decimal::ONE + monthly_rate, remaining_months);
        if denom > Decimal::ZERO {
            balance * monthly_rate / denom
        } else {
            balance
        }
    } else {
        balance / Decimal::from(remaining_months)
    }
}

/// Recovery cash due in `month` (1-based) from defaults recorded `lag` months earlier.
///
/// `defaults_history[i]` must hold the defaults of month `i + 1`. Returns zero in two
/// cases: when `lag` is zero (same-month recoveries are handled by the caller), and when
/// the originating month precedes the projection.
fn lagged_recovery(defaults_history: &[Money], month: u32, lag: u32, recovery_rate: Rate) -> Money {
    if lag == 0 || month <= lag {
        return Decimal::ZERO;
    }
    let source_idx = (month - lag - 1) as usize;
    defaults_history
        .get(source_idx)
        .map_or(Decimal::ZERO, |d| *d * recovery_rate)
}
fn validate_input(input: &AbsMbsInput) -> CorpFinanceResult<()> {
if input.pool_balance <= Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "pool_balance".into(),
reason: "Pool balance must be positive".into(),
});
}
if input.weighted_avg_coupon < Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "weighted_avg_coupon".into(),
reason: "WAC cannot be negative".into(),
});
}
if input.weighted_avg_maturity_months == 0 {
return Err(CorpFinanceError::InvalidInput {
field: "weighted_avg_maturity_months".into(),
reason: "WAM must be greater than zero".into(),
});
}
if input.loss_severity < Decimal::ZERO || input.loss_severity > Decimal::ONE {
return Err(CorpFinanceError::InvalidInput {
field: "loss_severity".into(),
reason: "Loss severity must be between 0 and 1".into(),
});
}
if input.servicing_fee_rate < Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "servicing_fee_rate".into(),
reason: "Servicing fee rate cannot be negative".into(),
});
}
if input.projection_months == 0 {
return Err(CorpFinanceError::InvalidInput {
field: "projection_months".into(),
reason: "Projection months must be greater than zero".into(),
});
}
match &input.prepayment_model {
PrepaymentModel::Cpr(rate) => {
if *rate < Decimal::ZERO || *rate > Decimal::ONE {
return Err(CorpFinanceError::InvalidInput {
field: "prepayment_model.Cpr".into(),
reason: "CPR must be between 0 and 1".into(),
});
}
}
PrepaymentModel::Psa(speed) => {
if *speed < Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "prepayment_model.Psa".into(),
reason: "PSA speed must be non-negative".into(),
});
}
}
PrepaymentModel::Smm(rate) => {
if *rate < Decimal::ZERO || *rate > Decimal::ONE {
return Err(CorpFinanceError::InvalidInput {
field: "prepayment_model.Smm".into(),
reason: "SMM must be between 0 and 1".into(),
});
}
}
}
match &input.default_model {
DefaultModel::Cdr(rate) => {
if *rate < Decimal::ZERO || *rate > Decimal::ONE {
return Err(CorpFinanceError::InvalidInput {
field: "default_model.Cdr".into(),
reason: "CDR must be between 0 and 1".into(),
});
}
}
DefaultModel::Sda(speed) => {
if *speed < Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "default_model.Sda".into(),
reason: "SDA speed must be non-negative".into(),
});
}
}
DefaultModel::None => {}
}
Ok(())
}
/// Annual CPR in effect for a loan of `age` months under `model`.
///
/// * `Cpr` — the constant rate itself.
/// * `Psa` — the PSA ramp, scaled by `speed / 100`. The ramp is
///   `PSA_BASE_CPR_30 * age / 30` up to age 30 and `PSA_BASE_CPR_30` afterwards.
/// * `Smm` — the monthly rate annualised via `smm_to_cpr`.
fn compute_cpr(age: u32, model: &PrepaymentModel) -> Rate {
    match *model {
        PrepaymentModel::Cpr(constant) => constant,
        PrepaymentModel::Smm(monthly) => smm_to_cpr(monthly),
        PrepaymentModel::Psa(speed) => {
            let ramp = match age {
                0..=30 => PSA_BASE_CPR_30 * Decimal::from(age) / dec!(30),
                _ => PSA_BASE_CPR_30,
            };
            ramp * speed / dec!(100)
        }
    }
}
/// Converts an annual CPR to single monthly mortality: `SMM = 1 - (1 - CPR)^(1/12)`.
///
/// Rates at or below 0 map to 0; rates at or above 1 map to 1.
fn cpr_to_smm(cpr: Rate) -> Rate {
    match cpr {
        c if c <= Decimal::ZERO => Decimal::ZERO,
        c if c >= Decimal::ONE => Decimal::ONE,
        c => {
            let annual_survival = Decimal::ONE - c;
            Decimal::ONE - nth_root(annual_survival, 12)
        }
    }
}
/// Annualises a single monthly mortality rate: `CPR = 1 - (1 - SMM)^12`.
///
/// Rates at or below 0 map to 0; rates at or above 1 map to 1.
fn smm_to_cpr(smm: Rate) -> Rate {
    if smm <= Decimal::ZERO {
        Decimal::ZERO
    } else if smm >= Decimal::ONE {
        Decimal::ONE
    } else {
        let monthly_survival = Decimal::ONE - smm;
        Decimal::ONE - iterative_pow(monthly_survival, 12)
    }
}
/// Annual CDR in effect for a loan of `age` months under `model`.
///
/// For `Sda` the 100% curve has four segments, and the result is scaled by `speed / 100`:
/// 1. ages up to 30: ramps linearly to `SDA_PEAK_CDR_30`;
/// 2. ages 31–60: held at the peak;
/// 3. ages 61–120: declines linearly to `SDA_FLOOR_CDR`;
/// 4. later ages: stays at the floor.
fn compute_cdr(age: u32, model: &DefaultModel) -> Rate {
    match *model {
        DefaultModel::None => Decimal::ZERO,
        DefaultModel::Cdr(constant) => constant,
        DefaultModel::Sda(speed) => {
            let curve = match age {
                0..=30 => SDA_PEAK_CDR_30 * Decimal::from(age) / dec!(30),
                31..=60 => SDA_PEAK_CDR_30,
                61..=120 => {
                    let elapsed = Decimal::from(age - 60);
                    let total_drop = SDA_PEAK_CDR_30 - SDA_FLOOR_CDR;
                    SDA_PEAK_CDR_30 - total_drop * elapsed / dec!(60)
                }
                _ => SDA_FLOOR_CDR,
            };
            curve * speed / dec!(100)
        }
    }
}
/// Converts an annual CDR to a monthly default rate: `MDR = 1 - (1 - CDR)^(1/12)`.
///
/// Rates at or below 0 map to 0; rates at or above 1 map to 1.
fn cdr_to_mdr(cdr: Rate) -> Rate {
    if cdr <= Decimal::ZERO {
        Decimal::ZERO
    } else if cdr >= Decimal::ONE {
        Decimal::ONE
    } else {
        let monthly_survival = nth_root(Decimal::ONE - cdr, 12);
        Decimal::ONE - monthly_survival
    }
}
/// Raises `base` to the integer power `n` by `n` successive multiplications
/// (so `n = 0` yields 1).
fn iterative_pow(base: Decimal, n: u32) -> Decimal {
    (0..n).fold(Decimal::ONE, |product, _| product * base)
}
/// Computes `1 / base^n` by repeated multiplication.
///
/// Returns zero in two cases:
/// - `base^n` is zero;
/// - `base^n` overflows `Decimal`. Its reciprocal would then be below the smallest
///   positive `Decimal` (1e-28), so zero is the correctly rounded answer.
///
/// The overflow guard matters for extreme coupon/term combinations, e.g.
/// `1.25^360 ≈ 7e34`; plain multiplication would panic on those.
fn iterative_pow_recip(base: Decimal, n: u32) -> Decimal {
    let mut power = Decimal::ONE;
    for _ in 0..n {
        match power.checked_mul(base) {
            Some(next) => power = next,
            // |base^n| exceeds Decimal::MAX (~7.9e28), so 1 / base^n rounds to zero.
            None => return Decimal::ZERO,
        }
    }
    if power.is_zero() {
        Decimal::ZERO
    } else {
        Decimal::ONE / power
    }
}
/// Approximates the `n`-th root of `x` with Newton–Raphson iteration.
///
/// Special cases are checked in this order: `x = 1` → 1, `x = 0` → 0, `n = 0` → 1,
/// `n = 1` → `x`. Otherwise the iteration starts from 1 and applies at most 40 updates
/// `g ← g − (gⁿ − x) / (n·gⁿ⁻¹)`. It stops early once a step is smaller than 1e-13, or if
/// `gⁿ⁻¹` becomes zero. The rate conversions in this file call it with `x` in `(0, 1)`.
fn nth_root(x: Decimal, n: u32) -> Decimal {
    if x == Decimal::ONE {
        return Decimal::ONE;
    }
    if x.is_zero() {
        return Decimal::ZERO;
    }
    match n {
        0 => return Decimal::ONE,
        1 => return x,
        _ => {}
    }

    const MAX_STEPS: u32 = 40;
    let tolerance = dec!(0.0000000000001);
    let degree = Decimal::from(n);
    let mut root = Decimal::ONE;
    let mut steps = 0;
    while steps < MAX_STEPS {
        steps += 1;
        let lower_power = iterative_pow(root, n - 1);
        if lower_power.is_zero() {
            break;
        }
        let step = (lower_power * root - x) / (degree * lower_power);
        root -= step;
        if step.abs() < tolerance {
            break;
        }
    }
    root
}
/// Builds an inactive period for `month`.
///
/// Every flow and rate is zero, and `balance` is used as both the beginning and ending
/// balance.
fn zero_period(month: u32, balance: Money) -> AbsPeriod {
    let nil = Decimal::ZERO;
    AbsPeriod {
        month,
        beginning_balance: balance,
        ending_balance: balance,
        scheduled_principal: nil,
        scheduled_interest: nil,
        prepayment: nil,
        defaults: nil,
        loss: nil,
        recovery: nil,
        servicing_fee: nil,
        total_principal: nil,
        total_cashflow: nil,
        smm: nil,
        cpr: nil,
        mdr: nil,
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the cash-flow projection, the PSA/SDA curves and the rate-conversion helpers.
use super::*;
use rust_decimal_macros::dec;
// Tolerance for the balance-monotonicity check in `test_combined_prepayment_and_default`.
const TOL: Decimal = dec!(0.01);
// Baseline pool: 1,000,000 at a 6% WAC over 360 months, unseasoned, with no prepayments
// or defaults. Individual tests override the fields they exercise.
fn standard_input() -> AbsMbsInput {
AbsMbsInput {
pool_balance: dec!(1_000_000),
weighted_avg_coupon: dec!(0.06),
weighted_avg_maturity_months: 360,
weighted_avg_age_months: 0,
num_loans: 1000,
prepayment_model: PrepaymentModel::Cpr(dec!(0.0)),
default_model: DefaultModel::None,
loss_severity: dec!(0.40),
recovery_lag_months: 6,
servicing_fee_rate: dec!(0.0025),
projection_months: 360,
}
}
// Asserts |actual - expected| <= tol, reporting both values on failure.
fn assert_close(actual: Decimal, expected: Decimal, tol: Decimal, msg: &str) {
let diff = (actual - expected).abs();
assert!(
diff <= tol,
"{}: expected ~{}, got {} (diff = {})",
msg,
expected,
actual,
diff
);
}
#[test]
fn test_basic_amortisation_no_prepay_no_defaults() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_eq!(out.periods[0].beginning_balance, dec!(1_000_000));
let last = out.periods.last().unwrap();
assert!(
last.ending_balance < dec!(1.0),
"Final balance should be near zero, got {}",
last.ending_balance
);
assert_close(
out.summary.total_principal_collected,
dec!(1_000_000),
dec!(1.0),
"Total principal collected should equal pool balance",
);
assert_eq!(out.summary.total_prepayments, Decimal::ZERO);
assert_eq!(out.summary.total_defaults, Decimal::ZERO);
}
#[test]
fn test_cpr_prepayment_model() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Cpr(dec!(0.06));
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(out.summary.total_prepayments > Decimal::ZERO);
let first_smm = out.periods[0].smm;
for p in &out.periods {
if p.beginning_balance > BALANCE_EPSILON {
assert_close(
p.smm,
first_smm,
dec!(0.000001),
"SMM should be constant for CPR",
);
}
}
assert_close(
out.periods[0].cpr,
dec!(0.06),
dec!(0.0001),
"CPR should be 6%",
);
}
#[test]
fn test_cpr_various_speeds() {
for cpr_val in [dec!(0.02), dec!(0.10), dec!(0.20)] {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Cpr(cpr_val);
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.total_prepayments > Decimal::ZERO,
"CPR {} should produce prepayments",
cpr_val
);
assert_close(
out.periods[0].cpr,
cpr_val,
dec!(0.0001),
&format!("CPR should be {}", cpr_val),
);
}
}
// 100% PSA: CPR = 6% x age / 30 up to age 30 (0.2% in month 1), then flat at 6%.
#[test]
fn test_psa_100_model() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Psa(dec!(100));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_close(
out.periods[0].cpr,
dec!(0.002),
dec!(0.0001),
"PSA 100 month 1 CPR",
);
assert_close(
out.periods[29].cpr,
dec!(0.06),
dec!(0.0001),
"PSA 100 month 30 CPR",
);
assert_close(
out.periods[39].cpr,
dec!(0.06),
dec!(0.0001),
"PSA 100 month 40 CPR",
);
}
#[test]
fn test_psa_200_model() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Psa(dec!(200));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_close(
out.periods[29].cpr,
dec!(0.12),
dec!(0.0001),
"PSA 200 month 30 CPR",
);
let mut input_100 = standard_input();
input_100.prepayment_model = PrepaymentModel::Psa(dec!(100));
input_100.projection_months = 60;
let result_100 = model_abs_cashflows(&input_100).unwrap();
assert!(
out.summary.total_prepayments > result_100.result.summary.total_prepayments,
"PSA 200 should have more prepayments than PSA 100"
);
}
#[test]
fn test_psa_50_model() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Psa(dec!(50));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_close(
out.periods[29].cpr,
dec!(0.03),
dec!(0.0001),
"PSA 50 month 30 CPR",
);
}
#[test]
fn test_sda_default_model() {
let mut input = standard_input();
input.default_model = DefaultModel::Sda(dec!(100));
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.total_defaults > Decimal::ZERO,
"SDA 100 should produce defaults"
);
assert!(
out.summary.total_losses > Decimal::ZERO,
"SDA 100 should produce losses"
);
if out.summary.total_defaults > Decimal::ZERO {
let avg_severity = out.summary.total_losses / out.summary.total_defaults;
assert_close(
avg_severity,
dec!(0.40),
dec!(0.01),
"Average loss severity should equal input loss severity",
);
}
}
#[test]
fn test_cdr_constant_default_model() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.02));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.total_defaults > Decimal::ZERO,
"CDR 2% should produce defaults"
);
let first_mdr = out.periods[0].mdr;
for p in &out.periods {
if p.beginning_balance > BALANCE_EPSILON {
assert_close(
p.mdr,
first_mdr,
dec!(0.000001),
"MDR should be constant for CDR",
);
}
}
}
#[test]
fn test_combined_prepayment_and_default() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Psa(dec!(150));
input.default_model = DefaultModel::Cdr(dec!(0.03));
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(out.summary.total_prepayments > Decimal::ZERO);
assert!(out.summary.total_defaults > Decimal::ZERO);
assert!(out.summary.total_losses > Decimal::ZERO);
for window in out.periods.windows(2) {
assert!(
window[1].beginning_balance <= window[0].beginning_balance + TOL,
"Balance should decrease: month {} = {}, month {} = {}",
window[0].month,
window[0].beginning_balance,
window[1].month,
window[1].beginning_balance,
);
}
}
#[test]
fn test_wal_calculation() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.weighted_average_life_years > dec!(10),
"WAL should be > 10 years, got {}",
out.summary.weighted_average_life_years
);
assert!(
out.summary.weighted_average_life_years < dec!(20),
"WAL should be < 20 years, got {}",
out.summary.weighted_average_life_years
);
let mut input_fast = standard_input();
input_fast.prepayment_model = PrepaymentModel::Cpr(dec!(0.15));
let result_fast = model_abs_cashflows(&input_fast).unwrap();
assert!(
result_fast.result.summary.weighted_average_life_years
< out.summary.weighted_average_life_years,
"Higher prepayment should produce shorter WAL"
);
}
#[test]
fn test_pool_factor() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.pool_factor_at_end < dec!(0.001),
"Pool factor at end should be near zero, got {}",
out.summary.pool_factor_at_end
);
}
#[test]
fn test_pool_factor_partial_projection() {
let mut input = standard_input();
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.pool_factor_at_end > dec!(0.5),
"Pool factor after 60 months should be > 50%, got {}",
out.summary.pool_factor_at_end
);
}
// With a 6-month lag, month-1 defaults are first recovered in month 7 (index 6).
#[test]
fn test_recovery_lag() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.05));
input.recovery_lag_months = 6;
input.projection_months = 24;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
for i in 0..6 {
assert_eq!(
out.periods[i].recovery,
Decimal::ZERO,
"Month {} should have zero recovery (lag = 6)",
i + 1
);
}
assert!(
out.periods[6].recovery > Decimal::ZERO,
"Month 7 should have non-zero recovery"
);
}
#[test]
fn test_recovery_lag_zero() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.05));
input.recovery_lag_months = 0;
input.projection_months = 12;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.periods[0].recovery > Decimal::ZERO,
"Month 1 should have immediate recovery with lag=0"
);
for p in &out.periods {
if p.defaults > Decimal::ZERO {
let expected_recovery = p.defaults * (Decimal::ONE - dec!(0.40));
assert_close(
p.recovery,
expected_recovery,
dec!(0.01),
&format!("Month {} recovery", p.month),
);
}
}
}
#[test]
fn test_servicing_fee() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
let expected_fee = dec!(1_000_000) * dec!(0.0025) / dec!(12);
assert_close(
out.periods[0].servicing_fee,
expected_fee,
dec!(0.01),
"First month servicing fee",
);
assert!(
out.summary.total_servicing_fees > Decimal::ZERO,
"Total servicing fees should be positive"
);
}
#[test]
fn test_zero_prepayment() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Cpr(dec!(0.0));
input.projection_months = 360;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_eq!(out.summary.total_prepayments, Decimal::ZERO);
for p in &out.periods {
assert_eq!(p.prepayment, Decimal::ZERO);
assert_eq!(p.smm, Decimal::ZERO);
}
}
#[test]
fn test_very_high_prepayment() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Cpr(dec!(0.95));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
let last_period = out.periods.last().unwrap();
assert!(
last_period.ending_balance < dec!(100),
"Very high CPR should exhaust pool quickly, got {}",
last_period.ending_balance
);
assert!(
out.summary.weighted_average_life_years < dec!(2.0),
"WAL with 95% CPR should be very short, got {}",
out.summary.weighted_average_life_years
);
}
#[test]
fn test_high_default_rate() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.20));
input.projection_months = 60;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert!(
out.summary.total_defaults > Decimal::ZERO,
"20% CDR should produce significant defaults"
);
assert!(
out.summary.cumulative_loss_rate > dec!(0.01),
"Cumulative loss rate should be significant"
);
}
#[test]
fn test_balance_never_negative() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Cpr(dec!(0.30));
input.default_model = DefaultModel::Cdr(dec!(0.10));
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
for p in &out.periods {
assert!(
p.ending_balance >= Decimal::ZERO,
"Month {}: ending balance should not be negative, got {}",
p.month,
p.ending_balance
);
assert!(
p.beginning_balance >= Decimal::ZERO,
"Month {}: beginning balance should not be negative, got {}",
p.month,
p.beginning_balance
);
}
}
// Balance identity: principal collected + defaults + remaining balance = original pool.
#[test]
fn test_principal_plus_defaults_equals_starting_balance() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Psa(dec!(150));
input.default_model = DefaultModel::Cdr(dec!(0.03));
input.projection_months = 360;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
let total_reduction = out.summary.total_principal_collected
+ out.summary.total_defaults
+ out.periods.last().unwrap().ending_balance;
assert_close(
total_reduction,
dec!(1_000_000),
dec!(1.0),
"Principal + defaults + remaining balance should equal pool balance",
);
}
#[test]
fn test_smm_prepayment_model() {
let mut input = standard_input();
input.prepayment_model = PrepaymentModel::Smm(dec!(0.005));
input.projection_months = 24;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_close(
out.periods[0].smm,
dec!(0.005),
dec!(0.0001),
"SMM should be 0.005",
);
assert!(out.summary.total_prepayments > Decimal::ZERO);
}
#[test]
fn test_cumulative_loss_rate() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.04));
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
let expected = out.summary.total_losses / dec!(1_000_000);
assert_close(
out.summary.cumulative_loss_rate,
expected,
dec!(0.0001),
"Cumulative loss rate",
);
}
#[test]
fn test_interest_first_month() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
let expected_interest = dec!(1_000_000) * dec!(0.06) / dec!(12);
assert_close(
out.periods[0].scheduled_interest,
expected_interest,
dec!(0.01),
"First month scheduled interest",
);
}
#[test]
fn test_total_cashflow_composition() {
let mut input = standard_input();
input.default_model = DefaultModel::Cdr(dec!(0.02));
input.recovery_lag_months = 0;
input.prepayment_model = PrepaymentModel::Cpr(dec!(0.05));
input.projection_months = 24;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
for p in &out.periods {
let expected_cf =
p.scheduled_interest + p.total_principal - p.servicing_fee + p.recovery;
assert_close(
p.total_cashflow,
expected_cf,
dec!(0.01),
&format!("Month {} total cashflow composition", p.month),
);
}
}
#[test]
fn test_validation_negative_pool_balance() {
let mut input = standard_input();
input.pool_balance = dec!(-100);
let result = model_abs_cashflows(&input);
assert!(result.is_err());
match result.unwrap_err() {
CorpFinanceError::InvalidInput { field, .. } => {
assert_eq!(field, "pool_balance");
}
other => panic!("Expected InvalidInput, got {:?}", other),
}
}
#[test]
fn test_validation_loss_severity_out_of_range() {
let mut input = standard_input();
input.loss_severity = dec!(1.5);
let result = model_abs_cashflows(&input);
assert!(result.is_err());
match result.unwrap_err() {
CorpFinanceError::InvalidInput { field, .. } => {
assert_eq!(field, "loss_severity");
}
other => panic!("Expected InvalidInput, got {:?}", other),
}
}
#[test]
fn test_metadata_populated() {
let input = standard_input();
let result = model_abs_cashflows(&input).unwrap();
assert!(!result.methodology.is_empty());
assert!(result.methodology.contains("ABS/MBS"));
assert_eq!(result.metadata.precision, "rust_decimal_128bit");
}
#[test]
fn test_correct_number_of_periods() {
let mut input = standard_input();
input.projection_months = 120;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_eq!(out.periods.len(), 120);
assert_eq!(out.periods[0].month, 1);
assert_eq!(out.periods[119].month, 120);
}
// Age 90 is halfway down the age-61..120 decline: 0.006 - 0.0057 x 30 / 60 = 0.00315.
#[test]
fn test_sda_curve_shape() {
let cdr_1 = compute_cdr(1, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_1, dec!(0.0002), dec!(0.00001), "SDA age 1 CDR");
let cdr_30 = compute_cdr(30, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_30, dec!(0.006), dec!(0.00001), "SDA age 30 CDR");
let cdr_45 = compute_cdr(45, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_45, dec!(0.006), dec!(0.00001), "SDA age 45 CDR");
let cdr_90 = compute_cdr(90, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_90, dec!(0.00315), dec!(0.0001), "SDA age 90 CDR");
let cdr_120 = compute_cdr(120, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_120, dec!(0.0003), dec!(0.00001), "SDA age 120 CDR");
let cdr_150 = compute_cdr(150, &DefaultModel::Sda(dec!(100)));
assert_close(cdr_150, dec!(0.0003), dec!(0.00001), "SDA age 150 CDR");
}
// 24 months of seasoning: projection month 1 is age 25 (6% x 25/30 = 5%), month 6 is age 30.
#[test]
fn test_wala_offset() {
let mut input = standard_input();
input.weighted_avg_age_months = 24;
input.prepayment_model = PrepaymentModel::Psa(dec!(100));
input.projection_months = 36;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_close(
out.periods[0].cpr,
dec!(0.05),
dec!(0.001),
"PSA CPR at age 25",
);
assert_close(
out.periods[5].cpr,
dec!(0.06),
dec!(0.001),
"PSA CPR at age 30",
);
}
#[test]
fn test_zero_coupon_pool() {
let mut input = standard_input();
input.weighted_avg_coupon = dec!(0.0);
input.projection_months = 360;
let result = model_abs_cashflows(&input).unwrap();
let out = &result.result;
assert_eq!(out.summary.total_interest_collected, Decimal::ZERO);
for p in &out.periods {
assert_eq!(p.scheduled_interest, Decimal::ZERO);
}
assert_close(
out.summary.total_principal_collected,
dec!(1_000_000),
dec!(1.0),
"Zero coupon pool should still amortise",
);
}
#[test]
fn test_validation_zero_projection_months() {
let mut input = standard_input();
input.projection_months = 0;
let result = model_abs_cashflows(&input);
assert!(result.is_err());
}
// nth_root is checked by raising its result back to the 12th power.
#[test]
fn test_nth_root_precision() {
let root = nth_root(dec!(0.94), 12);
let reconstructed = iterative_pow(root, 12);
assert_close(
reconstructed,
dec!(0.94),
dec!(0.000001),
"12th root of 0.94 reconstruction",
);
let root2 = nth_root(dec!(0.98), 12);
let reconstructed2 = iterative_pow(root2, 12);
assert_close(
reconstructed2,
dec!(0.98),
dec!(0.000001),
"12th root of 0.98 reconstruction",
);
}
#[test]
fn test_cpr_smm_round_trip() {
for cpr in [dec!(0.02), dec!(0.06), dec!(0.10), dec!(0.20)] {
let smm = cpr_to_smm(cpr);
let cpr_back = smm_to_cpr(smm);
assert_close(
cpr_back,
cpr,
dec!(0.0001),
&format!("CPR-SMM round trip for CPR={}", cpr),
);
}
}
}