use crate::compat::Instant;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use crate::error::CorpFinanceError;
use crate::types::*;
use crate::CorpFinanceResult;
/// Inputs to [`calculate_waterfall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaterfallInput {
/// Total amount available for distribution; must not be negative.
pub total_proceeds: Money,
/// Capital contributed to the fund; must be strictly positive. Caps the
/// return-of-capital tier and is the base that preferred-return rates apply to.
pub total_invested: Money,
/// Distribution tiers, applied in the order listed.
pub tiers: Vec<WaterfallTier>,
/// GP's share of committed capital as a fraction in `[0, 1]`; the GP receives
/// this fraction of the return-of-capital and preferred-return tiers.
pub gp_commitment_pct: Rate,
}
/// One step of the distribution waterfall.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaterfallTier {
/// Label copied into the corresponding [`WaterfallTierResult::tier_name`].
pub name: String,
/// How this tier sizes its distribution and splits it between GP and LP.
pub tier_type: WaterfallTierType,
}
/// The kinds of tier understood by [`calculate_waterfall`]. Rates and shares are
/// expressed as fractions (e.g. `0.08` for 8%).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WaterfallTierType {
/// Returns up to `total_invested`, split between GP and LP by `gp_commitment_pct`.
ReturnOfCapital,
/// Pays up to `total_invested * rate` (a flat amount; no holding period is
/// applied), split between GP and LP by `gp_commitment_pct`.
PreferredReturn { rate: Rate },
/// GP catch-up: the GP receives `gp_share` of this tier, whose size is derived
/// from the carry rate and the LP portion of preferred returns paid so far.
CatchUp { gp_share: Rate },
/// Takes all remaining proceeds; the GP receives `gp_share`. The first
/// `CarriedInterest`/`Residual` tier's `gp_share` is also the carry rate used
/// to size catch-up tiers (20% if no such tier exists).
CarriedInterest { gp_share: Rate },
/// Treated exactly like `CarriedInterest` by `calculate_waterfall`.
Residual { gp_share: Rate },
}
/// Result of [`calculate_waterfall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaterfallOutput {
/// Per-tier results, in the same order as the input tiers.
pub tiers: Vec<WaterfallTierResult>,
/// Sum of `to_gp` over all tiers.
pub total_to_gp: Money,
/// Sum of `to_lp` over all tiers.
pub total_to_lp: Money,
/// `total_to_gp / total_proceeds`, or zero when proceeds are zero.
pub gp_pct_of_total: Rate,
/// `total_to_lp / total_proceeds`, or zero when proceeds are zero.
pub lp_pct_of_total: Rate,
/// GP receipts beyond its co-investment return: `total_to_gp - gp_co_invest_return`.
pub gp_carry: Money,
/// GP's `gp_commitment_pct` share of the return-of-capital and preferred-return tiers.
pub gp_co_invest_return: Money,
}
/// How much a single tier distributed and to whom.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaterfallTierResult {
/// Name of the input tier this row describes.
pub tier_name: String,
/// Total amount distributed by this tier.
pub amount: Money,
/// Portion of `amount` paid to the GP.
pub to_gp: Money,
/// Portion of `amount` paid to LPs (`amount - to_gp`).
pub to_lp: Money,
/// Proceeds still undistributed after this tier.
pub remaining: Money,
}
pub fn calculate_waterfall(
input: &WaterfallInput,
) -> CorpFinanceResult<ComputationOutput<WaterfallOutput>> {
let start = Instant::now();
let warnings: Vec<String> = Vec::new();
if input.total_proceeds < Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "total_proceeds".into(),
reason: "Total proceeds cannot be negative".into(),
});
}
if input.total_invested <= Decimal::ZERO {
return Err(CorpFinanceError::InvalidInput {
field: "total_invested".into(),
reason: "Total invested must be positive".into(),
});
}
if input.gp_commitment_pct < Decimal::ZERO || input.gp_commitment_pct > Decimal::ONE {
return Err(CorpFinanceError::InvalidInput {
field: "gp_commitment_pct".into(),
reason: "GP commitment percentage must be between 0 and 1".into(),
});
}
if input.tiers.is_empty() {
return Err(CorpFinanceError::InvalidInput {
field: "tiers".into(),
reason: "At least one waterfall tier is required".into(),
});
}
let gp_pct = input.gp_commitment_pct;
let mut remaining = input.total_proceeds;
let mut tier_results: Vec<WaterfallTierResult> = Vec::new();
let mut gp_co_invest_roc = Decimal::ZERO;
let mut gp_co_invest_pref = Decimal::ZERO;
let mut cumulative_lp_preferred = Decimal::ZERO;
let carry_rate = input
.tiers
.iter()
.find_map(|t| match &t.tier_type {
WaterfallTierType::CarriedInterest { gp_share } => Some(*gp_share),
WaterfallTierType::Residual { gp_share } => Some(*gp_share),
_ => None,
})
.unwrap_or(dec!(0.20));
for tier in &input.tiers {
let tier_result = match &tier.tier_type {
WaterfallTierType::ReturnOfCapital => {
let distributable = remaining.min(input.total_invested);
let to_gp = distributable * gp_pct;
let to_lp = distributable - to_gp;
remaining -= distributable;
gp_co_invest_roc = to_gp;
WaterfallTierResult {
tier_name: tier.name.clone(),
amount: distributable,
to_gp,
to_lp,
remaining,
}
}
WaterfallTierType::PreferredReturn { rate } => {
let preferred_total = input.total_invested * *rate;
let distributable = remaining.min(preferred_total);
let to_gp = distributable * gp_pct;
let to_lp = distributable - to_gp;
remaining -= distributable;
gp_co_invest_pref = to_gp;
cumulative_lp_preferred = to_lp;
WaterfallTierResult {
tier_name: tier.name.clone(),
amount: distributable,
to_gp,
to_lp,
remaining,
}
}
WaterfallTierType::CatchUp { gp_share } => {
let target_catchup = if carry_rate < Decimal::ONE {
(carry_rate / (Decimal::ONE - carry_rate)) * cumulative_lp_preferred
} else {
remaining };
let distributable = remaining.min(target_catchup).max(Decimal::ZERO);
let to_gp = distributable * *gp_share;
let to_lp = distributable - to_gp;
remaining -= distributable;
WaterfallTierResult {
tier_name: tier.name.clone(),
amount: distributable,
to_gp,
to_lp,
remaining,
}
}
WaterfallTierType::CarriedInterest { gp_share }
| WaterfallTierType::Residual { gp_share } => {
let distributable = remaining;
let to_gp = distributable * *gp_share;
let to_lp = distributable - to_gp;
remaining = Decimal::ZERO;
WaterfallTierResult {
tier_name: tier.name.clone(),
amount: distributable,
to_gp,
to_lp,
remaining,
}
}
};
tier_results.push(tier_result);
}
let total_to_gp: Money = tier_results.iter().map(|t| t.to_gp).sum();
let total_to_lp: Money = tier_results.iter().map(|t| t.to_lp).sum();
let (gp_pct_of_total, lp_pct_of_total) = if input.total_proceeds.is_zero() {
(Decimal::ZERO, Decimal::ZERO)
} else {
(
total_to_gp / input.total_proceeds,
total_to_lp / input.total_proceeds,
)
};
let gp_co_invest_return = gp_co_invest_roc + gp_co_invest_pref;
let gp_carry = total_to_gp - gp_co_invest_return;
let output = WaterfallOutput {
tiers: tier_results,
total_to_gp,
total_to_lp,
gp_pct_of_total,
lp_pct_of_total,
gp_carry,
gp_co_invest_return,
};
let elapsed = start.elapsed().as_micros() as u64;
Ok(with_metadata(
"PE Cash-Flow Waterfall (European)",
&serde_json::json!({
"total_proceeds": input.total_proceeds.to_string(),
"total_invested": input.total_invested.to_string(),
"gp_commitment_pct": input.gp_commitment_pct.to_string(),
"num_tiers": input.tiers.len(),
}),
warnings,
elapsed,
output,
))
}
// Unit tests for `calculate_waterfall`. Most scenarios use a standard European
// structure (ROC -> 8% pref -> 100% GP catch-up -> 20% carry) built by
// `european_waterfall`; the expected numbers are derived by hand in comments.
#[cfg(test)]
mod tests {
use super::*;
use rust_decimal_macros::dec;
/// Builds a four-tier European waterfall: return of capital, an 8% preferred
/// return, a 100% GP catch-up, and a 20% carried-interest split.
fn european_waterfall(
total_proceeds: Money,
total_invested: Money,
gp_commitment_pct: Rate,
) -> WaterfallInput {
WaterfallInput {
total_proceeds,
total_invested,
tiers: vec![
WaterfallTier {
name: "Return of Capital".into(),
tier_type: WaterfallTierType::ReturnOfCapital,
},
WaterfallTier {
name: "Preferred Return".into(),
tier_type: WaterfallTierType::PreferredReturn { rate: dec!(0.08) },
},
WaterfallTier {
name: "GP Catch-Up".into(),
tier_type: WaterfallTierType::CatchUp {
gp_share: dec!(1.0),
},
},
WaterfallTier {
name: "Carried Interest".into(),
tier_type: WaterfallTierType::CarriedInterest {
gp_share: dec!(0.20),
},
},
],
gp_commitment_pct,
}
}
// 200 of proceeds on 100 invested, GP commits 2%: every tier is reached.
#[test]
fn test_basic_european_waterfall() {
let input = european_waterfall(dec!(200), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
// Tier 0 (ROC): all 100 of capital returned, 2% of it to the GP's commitment.
assert_eq!(out.tiers[0].amount, dec!(100));
assert_eq!(out.tiers[0].to_gp, dec!(2)); assert_eq!(out.tiers[0].to_lp, dec!(98)); // 2% / 98%
assert_eq!(out.tiers[0].remaining, dec!(100));
// Tier 1 (pref): 8% of 100 = 8, again split 2% / 98%.
assert_eq!(out.tiers[1].amount, dec!(8));
assert_eq!(out.tiers[1].to_gp, dec!(0.16));
assert_eq!(out.tiers[1].to_lp, dec!(7.84));
assert_eq!(out.tiers[1].remaining, dec!(92));
// Tier 2 (100% catch-up): carry / (1 - carry) x LP pref = 0.25 x 7.84 = 1.96,
// all of it to the GP.
let expected_catchup = dec!(0.20) / dec!(0.80) * dec!(7.84);
assert_eq!(out.tiers[2].amount, expected_catchup);
assert_eq!(out.tiers[2].to_gp, expected_catchup);
assert_eq!(out.tiers[2].to_lp, dec!(0));
// Tier 3 (carry): everything left is split 20% GP / 80% LP.
let remaining_after_catchup = dec!(92) - expected_catchup;
assert_eq!(out.tiers[3].amount, remaining_after_catchup);
assert_eq!(out.tiers[3].to_gp, remaining_after_catchup * dec!(0.20));
assert_eq!(
out.tiers[3].to_lp,
remaining_after_catchup - remaining_after_catchup * dec!(0.20)
);
// Totals reconcile to proceeds; the GP co-invest return is 2 + 0.16 and the
// rest of the GP's take is carry.
assert_eq!(out.total_to_gp + out.total_to_lp, dec!(200));
assert!(out.gp_carry > Decimal::ZERO);
assert_eq!(out.gp_co_invest_return, dec!(2) + dec!(0.16));
assert_eq!(out.gp_carry, out.total_to_gp - out.gp_co_invest_return);
}
// Proceeds (60) below invested capital (100): everything is return of capital
// and no later tier receives anything.
#[test]
fn test_return_of_capital_only() {
let input = european_waterfall(dec!(60), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].amount, dec!(60));
assert_eq!(out.tiers[0].to_gp, dec!(1.2)); assert_eq!(out.tiers[0].to_lp, dec!(58.8)); // 2% of 60 to GP
for tier in &out.tiers[1..] {
assert_eq!(tier.amount, Decimal::ZERO);
}
assert_eq!(out.total_to_gp + out.total_to_lp, dec!(60));
assert_eq!(out.gp_carry, Decimal::ZERO);
}
// 105 on 100: only 5 of the 8 hurdle is paid, so the catch-up and carry tiers
// get nothing and the GP earns no carry.
#[test]
fn test_no_carry_below_hurdle() {
let input = european_waterfall(dec!(105), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].amount, dec!(100));
assert_eq!(out.tiers[1].amount, dec!(5));
assert_eq!(out.tiers[2].amount, Decimal::ZERO);
assert_eq!(out.tiers[3].amount, Decimal::ZERO);
assert_eq!(out.gp_carry, Decimal::ZERO);
}
// 110 on 100: after ROC (100) and pref (8), 2 is left; the catch-up target of
// 1.96 fits, leaving 0.04 for the carry split.
#[test]
fn test_full_catch_up() {
let input = european_waterfall(dec!(110), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].amount, dec!(100));
assert_eq!(out.tiers[1].amount, dec!(8));
assert_eq!(out.tiers[1].to_lp, dec!(7.84));
let expected_catchup = dec!(0.25) * dec!(7.84); assert_eq!(out.tiers[2].amount, expected_catchup); // = 1.96
assert_eq!(out.tiers[2].to_gp, expected_catchup);
let carry_remaining = dec!(2) - expected_catchup;
assert_eq!(out.tiers[3].amount, carry_remaining);
assert_eq!(out.total_to_gp + out.total_to_lp, dec!(110));
}
// A 5% GP commitment takes 5% of the ROC and preferred tiers; together those
// make up the GP co-invest return (5 + 0.40).
#[test]
fn test_gp_commitment_allocation() {
let input = european_waterfall(dec!(200), dec!(100), dec!(0.05));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].to_gp, dec!(5));
assert_eq!(out.tiers[0].to_lp, dec!(95));
assert_eq!(out.tiers[1].to_gp, dec!(0.40));
assert_eq!(out.tiers[1].to_lp, dec!(7.60));
assert_eq!(out.gp_co_invest_return, dec!(5.40));
}
// Zero proceeds: every tier is empty and the percentage fields fall back to
// zero instead of dividing by zero.
#[test]
fn test_zero_proceeds() {
let input = european_waterfall(dec!(0), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
for tier in &out.tiers {
assert_eq!(tier.amount, Decimal::ZERO);
assert_eq!(tier.to_gp, Decimal::ZERO);
assert_eq!(tier.to_lp, Decimal::ZERO);
}
assert_eq!(out.total_to_gp, Decimal::ZERO);
assert_eq!(out.total_to_lp, Decimal::ZERO);
assert_eq!(out.gp_pct_of_total, Decimal::ZERO);
assert_eq!(out.lp_pct_of_total, Decimal::ZERO);
assert_eq!(out.gp_carry, Decimal::ZERO);
}
// 300 on 100: 192 remains after ROC and pref; the catch-up takes 1.96 and the
// other 190.04 is split 20/80, so GP carry is 1.96 + 38.008 = 39.968.
#[test]
fn test_high_return_scenario() {
let input = european_waterfall(dec!(300), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].amount, dec!(100));
assert_eq!(out.tiers[1].amount, dec!(8));
let catchup = dec!(0.25) * dec!(7.84);
assert_eq!(out.tiers[2].amount, catchup);
let carry_base = dec!(192) - catchup;
assert_eq!(out.tiers[3].amount, carry_base);
assert_eq!(out.tiers[3].to_gp, carry_base * dec!(0.20));
assert_eq!(out.tiers[3].to_lp, carry_base - carry_base * dec!(0.20));
assert_eq!(out.total_to_gp + out.total_to_lp, dec!(300));
// GP carry is ~39.97, comfortably above 35.
assert!(out.gp_carry > dec!(35)); }
// Without a pref or catch-up tier, all profit above capital goes straight to a
// 20% carry split; the GP co-invest return is only its 2% of ROC.
#[test]
fn test_no_preferred_return() {
let input = WaterfallInput {
total_proceeds: dec!(200),
total_invested: dec!(100),
tiers: vec![
WaterfallTier {
name: "Return of Capital".into(),
tier_type: WaterfallTierType::ReturnOfCapital,
},
WaterfallTier {
name: "Profit Split".into(),
tier_type: WaterfallTierType::CarriedInterest {
gp_share: dec!(0.20),
},
},
],
gp_commitment_pct: dec!(0.02),
};
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[0].amount, dec!(100));
assert_eq!(out.tiers[1].amount, dec!(100));
assert_eq!(out.tiers[1].to_gp, dec!(20));
assert_eq!(out.tiers[1].to_lp, dec!(80));
assert_eq!(out.gp_co_invest_return, dec!(2));
assert_eq!(out.gp_carry, out.total_to_gp - dec!(2));
assert_eq!(out.total_to_gp + out.total_to_lp, dec!(200));
}
// Negative proceeds are rejected with an InvalidInput error naming the field.
#[test]
fn test_invalid_negative_proceeds() {
let input = european_waterfall(dec!(-50), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input);
assert!(result.is_err());
match result.unwrap_err() {
CorpFinanceError::InvalidInput { field, .. } => {
assert_eq!(field, "total_proceeds");
}
other => panic!("Expected InvalidInput, got: {other:?}"),
}
}
// Invested capital must be strictly positive.
#[test]
fn test_invalid_zero_invested() {
let input = WaterfallInput {
total_proceeds: dec!(100),
total_invested: dec!(0),
tiers: vec![WaterfallTier {
name: "ROC".into(),
tier_type: WaterfallTierType::ReturnOfCapital,
}],
gp_commitment_pct: dec!(0.02),
};
let result = calculate_waterfall(&input);
assert!(result.is_err());
}
// A GP commitment above 100% is rejected.
#[test]
fn test_invalid_gp_commitment_pct() {
let input = WaterfallInput {
total_proceeds: dec!(100),
total_invested: dec!(50),
tiers: vec![WaterfallTier {
name: "ROC".into(),
tier_type: WaterfallTierType::ReturnOfCapital,
}],
gp_commitment_pct: dec!(1.5), }; // outside [0, 1]
let result = calculate_waterfall(&input);
assert!(result.is_err());
}
// A Residual tier behaves like carried interest: it takes the remaining 50 and
// gives the GP its 30% share (15), the LPs the other 35.
#[test]
fn test_residual_tier() {
let input = WaterfallInput {
total_proceeds: dec!(150),
total_invested: dec!(100),
tiers: vec![
WaterfallTier {
name: "Return of Capital".into(),
tier_type: WaterfallTierType::ReturnOfCapital,
},
WaterfallTier {
name: "Residual Split".into(),
tier_type: WaterfallTierType::Residual {
gp_share: dec!(0.30),
},
},
],
gp_commitment_pct: dec!(0.01),
};
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
assert_eq!(out.tiers[1].amount, dec!(50));
assert_eq!(out.tiers[1].to_gp, dec!(15)); assert_eq!(out.tiers[1].to_lp, dec!(35)); assert_eq!(out.total_to_gp + out.total_to_lp, dec!(150));
}
// When all proceeds are distributed, the GP and LP shares of the total sum to 1.
#[test]
fn test_pct_of_total() {
let input = european_waterfall(dec!(200), dec!(100), dec!(0.02));
let result = calculate_waterfall(&input).unwrap();
let out = &result.result;
let sum = out.gp_pct_of_total + out.lp_pct_of_total;
assert_eq!(sum, Decimal::ONE);
}
}