#![allow(non_snake_case)]
use crate::xml::MortXML;
use polars::prelude::*;
/// Fractional-age assumption selector.
///
/// NOTE(review): the variant names look like the standard actuarial
/// fractional-age assumptions (UDD = uniform distribution of deaths,
/// CFM = constant force of mortality, HPB = hyperbolic / Balducci) —
/// confirm, since no code in this file actually consumes the chosen
/// variant yet (see `MortTableConfig::assumption`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AssumptionEnum {
    UDD,
    CFM,
    HPB,
}
/// Configuration for generating a mortality table from a parsed `MortXML`.
#[derive(Debug, Clone)]
pub struct MortTableConfig {
    // Parsed mortality XML; `gen_mort_table` requires exactly one table.
    pub xml: MortXML,
    // Starting cohort size l(x0); a default of 100_000 is used when `None`.
    pub radix: Option<i32>,
    // Multiplicative adjustment applied to the raw qx rates (e.g. 0.5 for
    // 50% of table mortality); defaults to 1.0 when `None`.
    pub pct: Option<f64>,
    // Annual interest rate for the commutation columns; required by
    // `gen_mort_table` for detail levels 2-4.
    pub int_rate: Option<f64>,
    // Fractional-age assumption; currently not consumed by any code in
    // this file.
    pub assumption: Option<AssumptionEnum>,
}
impl MortTableConfig {
    /// Generates a mortality table at the requested level of detail.
    ///
    /// Levels build on each other:
    /// 1. demographic movement columns (age, qx, px, lx, dx)
    /// 2. level 1 + commutation columns Cx, Dx (requires `int_rate`)
    /// 3. level 2 + Nx, Mx, Px
    /// 4. level 3 + Rx, Sx
    ///
    /// # Errors
    /// Returns `PolarsError::ComputeError` when the XML does not contain
    /// exactly one table, when `detail_level` is outside 1-4, or when
    /// `int_rate` is missing for levels 2-4.
    pub fn gen_mort_table(&self, detail_level: i32) -> PolarsResult<DataFrame> {
        let tables_count = self.xml.tables.len();
        if tables_count < 1 {
            return Err(PolarsError::ComputeError(
                "No tables found in MortXML".into(),
            ));
        }
        if tables_count > 1 {
            return Err(PolarsError::ComputeError(
                "MortXML must contain exactly one table".into(),
            ));
        }
        // BUG FIX: validate the level before the interest rate. Previously an
        // out-of-range level (e.g. 5) with no interest rate reported the
        // misleading "Interest rate is required" error instead of this one.
        if !(1..=4).contains(&detail_level) {
            return Err(PolarsError::ComputeError(
                "Invalid detail level specified. Valid levels are 1-4.".into(),
            ));
        }
        if detail_level > 1 && self.int_rate.is_none() {
            return Err(PolarsError::ComputeError(
                "Interest rate is required for detail level 2.".into(),
            ));
        }
        // Each level extends the previous level's columns, so build
        // incrementally instead of duplicating the pipeline per match arm.
        let mut df = gen_demographic_movement_level_1(self.clone())?;
        if detail_level >= 2 {
            // unwrap is safe: checked `int_rate.is_none()` above.
            df = gen_commutation_level_2(df, self.int_rate.unwrap())?;
        }
        if detail_level >= 3 {
            df = gen_commutation_level_3(df)?;
        }
        if detail_level >= 4 {
            df = gen_commutation_level_4(df)?;
        }
        Ok(df)
    }
}
/// Dispatches to the appropriate level-1 builder: tables that already carry
/// an `lx` column are treated as full life tables; otherwise the table is
/// assumed to provide mortality rates (`qx`) only.
fn gen_demographic_movement_level_1(config: MortTableConfig) -> PolarsResult<DataFrame> {
    let has_lx = config.xml.tables[0]
        .values
        .get_column_names()
        .iter()
        .any(|name| name.as_str() == "lx");
    if has_lx {
        _gen_demographic_movement_life_table_content(config)
    } else {
        _gen_demographic_movement_other_content(config)
    }
}
fn _gen_demographic_movement_life_table_content(
config: MortTableConfig,
) -> PolarsResult<DataFrame> {
let df = config.xml.tables[0].values.clone();
let pct = config.pct.unwrap_or(1.0);
let age: Vec<u32> = df
.column("age")?
.u32()?
.into_iter()
.map(|v| v.unwrap())
.collect();
let lx: Vec<f64> = df
.column("lx")?
.f64()?
.into_iter()
.map(|v| v.unwrap())
.collect();
let mut qx: Vec<f64> = Vec::with_capacity(age.len());
let mut px: Vec<f64> = Vec::with_capacity(age.len());
let mut dx: Vec<f64> = Vec::with_capacity(age.len());
for i in 0..age.len() - 1 {
let dx_value = (lx[i] - lx[i + 1]).round();
dx.push(dx_value);
let qx_value = dx_value / lx[i] * pct;
qx.push(qx_value);
px.push(1.0 - qx_value);
}
qx.push(1.0);
px.push(0.0);
let result = DataFrame::new(vec![
Series::new("age".into(), age).into_column(),
Series::new("qx".into(), qx).into_column(),
Series::new("px".into(), px).into_column(),
Series::new("lx".into(), lx).into_column(),
Series::new("dx".into(), dx).into_column(),
])?;
Ok(result)
}
fn _gen_demographic_movement_other_content(config: MortTableConfig) -> PolarsResult<DataFrame> {
let df = config.xml.tables[0].values.clone();
let pct = config.pct.unwrap_or(1.0);
let radix = config.radix;
let age: Vec<u32> = df
.column("age")?
.u32()?
.into_iter()
.map(|v| v.unwrap())
.collect();
let qx: Vec<f64> = df
.column("qx")?
.f64()?
.into_iter()
.map(|v| v.unwrap())
.collect();
let mut qx_new: Vec<f64> = Vec::with_capacity(age.len());
let mut px: Vec<f64> = Vec::with_capacity(age.len());
let mut lx: Vec<f64> = Vec::with_capacity(age.len());
let mut dx: Vec<f64> = Vec::with_capacity(age.len());
if let Some(radix_val) = radix {
lx.push(radix_val as f64);
} else {
lx.push(100_000.0);
}
for i in 0..age.len() {
let qx_val = qx[i] * pct; qx_new.push(qx_val);
px.push(1.0 - qx_val);
if i > 0 {
let lx_value = lx[i - 1] - dx[i - 1];
lx.push(lx_value);
}
let dx_value = lx[i] * qx_val;
dx.push(dx_value);
}
let result = DataFrame::new(vec![
Series::new("age".into(), age).into_column(),
Series::new("qx".into(), qx).into_column(),
Series::new("px".into(), px).into_column(),
Series::new("lx".into(), lx).into_column(),
Series::new("dx".into(), dx).into_column(),
])?;
Ok(result)
}
/// Appends the commutation columns `Cx` and `Dx` to a level-1 table:
/// Dx = lx / (1+i)^x and Cx = dx / (1+i)^(x+1).
fn gen_commutation_level_2(
    df: DataFrame,
    int_rate: f64,
) -> PolarsResult<DataFrame> {
    let age: Vec<u32> = df
        .column("age")?
        .u32()?
        .into_iter()
        .map(|v| v.unwrap())
        .collect();
    let lx: Vec<f64> = df
        .column("lx")?
        .f64()?
        .into_iter()
        .map(|v| v.unwrap())
        .collect();
    let dx: Vec<f64> = df
        .column("dx")?
        .f64()?
        .into_iter()
        .map(|v| v.unwrap())
        .collect();
    let growth = 1.0 + int_rate;
    let mut Cx: Vec<f64> = Vec::with_capacity(age.len());
    let mut Dx: Vec<f64> = Vec::with_capacity(age.len());
    for ((&x, &l), &d) in age.iter().zip(&lx).zip(&dx) {
        let x = x as f64;
        Cx.push(d / growth.powf(x + 1.0));
        Dx.push(l / growth.powf(x));
    }
    let new_df = DataFrame::new(vec![
        Series::new("Cx".into(), Cx).into_column(),
        Series::new("Dx".into(), Dx).into_column(),
    ])?;
    df.hstack(new_df.get_columns())
}
/// Appends the higher-order commutation columns to a level-2 table:
/// Nx = sum of Dx from age x onward, Mx = sum of Cx from age x onward,
/// and Px = Mx / Nx (0.0 when Nx is not positive).
fn gen_commutation_level_3(df: DataFrame) -> PolarsResult<DataFrame> {
    let cx = df.column("Cx")?.f64()?.to_vec();
    let dx = df.column("Dx")?.f64()?.to_vec();
    let n = cx.len();
    let mut Nx: Vec<f64> = vec![0.0; n];
    let mut Mx: Vec<f64> = vec![0.0; n];
    let mut Px: Vec<f64> = vec![0.0; n];
    // PERF FIX: build the suffix sums with one reverse pass (O(n)) instead
    // of re-summing the tail slice for every row (O(n^2)). Null entries are
    // treated as 0.0, matching the previous `filter_map` behavior.
    let mut nx_acc = 0.0;
    let mut mx_acc = 0.0;
    for i in (0..n).rev() {
        nx_acc += dx[i].unwrap_or(0.0);
        mx_acc += cx[i].unwrap_or(0.0);
        Nx[i] = nx_acc;
        Mx[i] = mx_acc;
        Px[i] = if nx_acc > 0.0 { mx_acc / nx_acc } else { 0.0 };
    }
    let new_df = DataFrame::new(vec![
        Series::new("Nx".into(), Nx).into_column(),
        Series::new("Mx".into(), Mx).into_column(),
        Series::new("Px".into(), Px).into_column(),
    ])?;
    let result = df.hstack(new_df.get_columns())?;
    Ok(result)
}
/// Appends the final commutation columns to a level-3 table:
/// Rx = sum of Mx from age x onward, Sx = sum of Nx from age x onward.
fn gen_commutation_level_4(df: DataFrame) -> PolarsResult<DataFrame> {
    let mx = df.column("Mx")?.f64()?.to_vec();
    let nx = df.column("Nx")?.f64()?.to_vec();
    let n = mx.len();
    let mut Rx: Vec<f64> = vec![0.0; n];
    let mut Sx: Vec<f64> = vec![0.0; n];
    // PERF FIX: one reverse pass of running suffix sums (O(n)) replaces the
    // per-row tail re-summation (O(n^2)). Null entries count as 0.0, as in
    // the previous `filter_map` version.
    let mut rx_acc = 0.0;
    let mut sx_acc = 0.0;
    for i in (0..n).rev() {
        rx_acc += mx[i].unwrap_or(0.0);
        sx_acc += nx[i].unwrap_or(0.0);
        Rx[i] = rx_acc;
        Sx[i] = sx_acc;
    }
    let new_df = DataFrame::new(vec![
        Series::new("Rx".into(), Rx).into_column(),
        Series::new("Sx".into(), Sx).into_column(),
    ])?;
    let result = df.hstack(new_df.get_columns())?;
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::xml::MortXML;

    // NOTE(review): these tests fetch real XML tables by id over the network
    // (`MortXML::from_url_id`), so they are integration tests; several are
    // `#[ignore]`d for that reason.

    /// Level-1 table: shape, column order, and dtypes.
    #[test]
    fn test_basic_mortality_table_generation() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(1.0),
            int_rate: None,
            assumption: None,
        };
        let result = config
            .gen_mort_table(1)
            .expect("Failed to generate mortality table");
        assert!(result.height() > 0, "Result DataFrame should not be empty");
        assert_eq!(result.width(), 5, "Basic table should have 5 columns");
        let expected_columns = vec!["age", "qx", "px", "lx", "dx"];
        let actual_columns = result.get_column_names();
        assert_eq!(
            actual_columns, expected_columns,
            "Column names don't match expected"
        );
        assert!(
            result.column("age").unwrap().dtype().is_integer(),
            "Age should be integer"
        );
        assert!(
            result.column("qx").unwrap().dtype().is_float(),
            "qx should be float"
        );
        assert!(
            result.column("lx").unwrap().dtype().is_float(),
            "lx should be float"
        );
        assert!(
            result.column("dx").unwrap().dtype().is_float(),
            "dx should be float"
        );
        println!("✓ Basic mortality table generated successfully");
        println!(
            "Table dimensions: {} rows × {} columns",
            result.height(),
            result.width()
        );
    }

    /// Full commutation table: all columns present, in hstack order.
    #[test]
    #[ignore]
    fn test_mortality_table_with_commutation() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(1.0),
            int_rate: Some(0.03),
            assumption: Some(AssumptionEnum::UDD),
        };
        // BUG FIX: this test previously requested detail level 2 (which only
        // adds Cx and Dx, i.e. 7 columns) yet asserted 10 columns including
        // Mx/Nx/Px — in the wrong order — and additionally read Rx and Sx,
        // which only level 4 produces. Level 4 yields every column checked.
        let result = config
            .gen_mort_table(4)
            .expect("Failed to generate commutation table");
        assert!(result.height() > 0, "Result DataFrame should not be empty");
        assert_eq!(
            result.width(),
            12,
            "Commutation table should have 12 columns (age, qx, px, lx, dx, Cx, Dx, Nx, Mx, Px, Rx, Sx)"
        );
        let expected_columns = vec![
            "age", "qx", "px", "lx", "dx", "Cx", "Dx", "Nx", "Mx", "Px", "Rx", "Sx",
        ];
        let actual_columns = result.get_column_names();
        assert_eq!(
            actual_columns, expected_columns,
            "Commutation column names don't match"
        );
        assert!(
            result.column("Cx").unwrap().dtype().is_float(),
            "Cx should be float"
        );
        assert!(
            result.column("Dx").unwrap().dtype().is_float(),
            "Dx should be float"
        );
        assert!(
            result.column("Mx").unwrap().dtype().is_float(),
            "Mx should be float"
        );
        assert!(
            result.column("Nx").unwrap().dtype().is_float(),
            "Nx should be float"
        );
        assert!(
            result.column("Px").unwrap().dtype().is_float(),
            "Px should be float"
        );
        assert!(
            result.column("Rx").unwrap().dtype().is_float(),
            "Rx should be float"
        );
        assert!(
            result.column("Sx").unwrap().dtype().is_float(),
            "Sx should be float"
        );
        println!("✓ Commutation table generated successfully");
        println!(
            "Table with interest rate: {} rows × {} columns",
            result.height(),
            result.width()
        );
    }

    /// The `pct` multiplier must scale qx linearly and move lx accordingly.
    #[test]
    #[ignore]
    fn test_percentage_adjustment() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config_50 = MortTableConfig {
            xml: xml.clone(),
            radix: Some(100_000),
            pct: Some(0.5),
            int_rate: None,
            assumption: None,
        };
        let config_100 = MortTableConfig {
            xml: xml.clone(),
            radix: Some(100_000),
            pct: Some(1.0),
            int_rate: None,
            assumption: None,
        };
        let config_150 = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(1.5),
            int_rate: None,
            assumption: None,
        };
        let table_50 = config_50.gen_mort_table(1).expect("Failed with 50% rates");
        let table_100 = config_100
            .gen_mort_table(1)
            .expect("Failed with 100% rates");
        let table_150 = config_150
            .gen_mort_table(1)
            .expect("Failed with 150% rates");
        let qx_50 = table_50
            .column("qx")
            .unwrap()
            .f64()
            .unwrap()
            .get(5)
            .unwrap();
        let qx_100 = table_100
            .column("qx")
            .unwrap()
            .f64()
            .unwrap()
            .get(5)
            .unwrap();
        let qx_150 = table_150
            .column("qx")
            .unwrap()
            .f64()
            .unwrap()
            .get(5)
            .unwrap();
        assert!(
            (qx_50 * 2.0 - qx_100).abs() < 1e-10,
            "50% should be half of 100%"
        );
        assert!(
            (qx_150 / 1.5 - qx_100).abs() < 1e-10,
            "150% should be 1.5 times 100%"
        );
        let lx_30_50 = table_50
            .column("lx")
            .unwrap()
            .f64()
            .unwrap()
            .get(30)
            .unwrap_or(0.0);
        let lx_30_100 = table_100
            .column("lx")
            .unwrap()
            .f64()
            .unwrap()
            .get(30)
            .unwrap_or(0.0);
        let lx_30_150 = table_150
            .column("lx")
            .unwrap()
            .f64()
            .unwrap()
            .get(30)
            .unwrap_or(0.0);
        assert!(
            lx_30_50 > lx_30_100,
            "Lower mortality should result in higher survival"
        );
        assert!(
            lx_30_100 > lx_30_150,
            "Higher mortality should result in lower survival"
        );
        println!("✓ Percentage adjustment working correctly");
        println!("qx at index 5: 50%={qx_50:.6}, 100%={qx_100:.6}, 150%={qx_150:.6}");
    }

    /// Core life-table identities: dx = lx * qx, qx in [0, 1], lx
    /// non-increasing, Nx strictly decreasing.
    #[test]
    fn test_actuarial_relationships() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(1.0),
            int_rate: Some(0.04),
            assumption: Some(AssumptionEnum::CFM),
        };
        let result = config.gen_mort_table(3).expect("Failed to generate table");
        let lx = result.column("lx").unwrap().f64().unwrap();
        let dx = result.column("dx").unwrap().f64().unwrap();
        let qx = result.column("qx").unwrap().f64().unwrap();
        for i in 0..std::cmp::min(10, result.height()) {
            let lx_val = lx.get(i).unwrap();
            let dx_val = dx.get(i).unwrap();
            let qx_val = qx.get(i).unwrap();
            let expected_dx = lx_val * qx_val;
            assert!(
                (dx_val - expected_dx).abs() < 1.0,
                "dx calculation incorrect at index {i}: expected {expected_dx}, got {dx_val}"
            );
            assert!(
                (0.0..=1.0).contains(&qx_val),
                "qx should be a probability at index {i}"
            );
            if i > 0 {
                let prev_lx = lx.get(i - 1).unwrap();
                assert!(
                    lx_val <= prev_lx,
                    "lx should be non-increasing at index {i}"
                );
            }
        }
        let _Dx = result.column("Dx").unwrap().f64().unwrap();
        let Nx = result.column("Nx").unwrap().f64().unwrap();
        let _Cx = result.column("Cx").unwrap().f64().unwrap();
        let _Mx = result.column("Mx").unwrap().f64().unwrap();
        for i in 1..std::cmp::min(10, result.height()) {
            let nx_curr = Nx.get(i).unwrap();
            let nx_prev = Nx.get(i - 1).unwrap();
            assert!(nx_curr < nx_prev, "Nx should be decreasing at index {i}");
        }
        println!("✓ Actuarial relationships verified");
    }

    /// The first lx entry must equal whatever radix was configured.
    #[test]
    fn test_different_radix_values() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let radix_values = vec![100_000, 1_000_000, 10_000_000];
        for &radix in &radix_values {
            let config = MortTableConfig {
                xml: xml.clone(),
                radix: Some(radix),
                pct: Some(1.0),
                int_rate: None,
                assumption: None,
            };
            let result = config
                .gen_mort_table(1)
                .unwrap_or_else(|_| panic!("Failed with radix {radix}"));
            let first_lx = result.column("lx").unwrap().f64().unwrap().get(0).unwrap();
            assert_eq!(
                first_lx, radix as f64,
                "First lx should equal radix for {radix}"
            );
            println!("✓ Radix {radix} working correctly");
        }
    }

    /// A well-formed configuration must succeed at level 1.
    #[test]
    fn test_error_handling() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(1.0),
            int_rate: None,
            assumption: None,
        };
        let result = config.gen_mort_table(1);
        assert!(result.is_ok(), "Valid config should succeed");
        println!("✓ Error handling tests completed");
    }

    /// End-to-end sanity checks on a level-2 table with non-default pct
    /// and interest rate.
    #[test]
    fn test_comprehensive_table_validation() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(100_000),
            pct: Some(0.75),
            int_rate: Some(0.035),
            assumption: Some(AssumptionEnum::HPB),
        };
        let result = config
            .gen_mort_table(2)
            .expect("Failed to generate comprehensive table");
        println!("\n=== COMPREHENSIVE TABLE VALIDATION ===");
        println!(
            "Table dimensions: {} rows × {} columns",
            result.height(),
            result.width()
        );
        println!("Configuration: 75% mortality, 3.5% interest, HPB assumption");
        if result.height() >= 5 {
            println!("\nFirst 5 rows:");
            println!("{}", result.head(Some(5)));
        }
        if result.height() >= 5 {
            println!("\nLast 5 rows:");
            println!("{}", result.tail(Some(5)));
        }
        let lx_col = result.column("lx").unwrap().f64().unwrap();
        let dx_col = result.column("dx").unwrap().f64().unwrap();
        let qx_col = result.column("qx").unwrap().f64().unwrap();
        assert_eq!(
            lx_col.get(0).unwrap(),
            100_000.0,
            "Should start with 100,000 lives"
        );
        for i in 0..result.height() {
            let qx = qx_col.get(i).unwrap();
            assert!(
                (0.0..=1.0).contains(&qx),
                "Mortality rate out of bounds at row {i}: {qx}"
            );
        }
        for i in 0..result.height() {
            let lx = lx_col.get(i).unwrap();
            let dx = dx_col.get(i).unwrap();
            assert!(dx <= lx, "Deaths exceed lives at row {i}: dx={dx}, lx={lx}");
        }
        if let Ok(dx_comm) = result.column("Dx") {
            let dx_values = dx_comm.f64().unwrap();
            for i in 0..std::cmp::min(10, result.height()) {
                let dx_val = dx_values.get(i).unwrap();
                assert!(dx_val > 0.0, "Dx should be positive at row {i}: {dx_val}");
            }
        }
        println!("✓ All comprehensive validations passed");
        println!("✓ Table generation working correctly with all features");
    }

    /// Smoke-checks the XML ids referenced by the crate's doctests.
    #[test]
    #[ignore]
    fn test_check_doctest_xml_ids() {
        let xml_28001 = MortXML::from_url_id(28001).expect("Failed to load XML 28001");
        println!("XML 28001 - Number of tables: {}", xml_28001.tables.len());
        println!(
            "XML 28001 - Table name: {}",
            xml_28001.content_classification.table_name
        );
        println!(
            "XML 28001 - Content type: {}",
            xml_28001.content_classification.content_type
        );
        let xml_1705 = MortXML::from_url_id(1705).expect("Failed to load XML 1705");
        println!("XML 1705 - Number of tables: {}", xml_1705.tables.len());
        println!(
            "XML 1705 - Table name: {}",
            xml_1705.content_classification.table_name
        );
        println!(
            "XML 1705 - Content type: {}",
            xml_1705.content_classification.content_type
        );
        assert_eq!(
            xml_28001.tables.len(),
            1,
            "XML 28001 should have exactly 1 table"
        );
        assert_eq!(
            xml_1705.tables.len(),
            1,
            "XML 1705 should have exactly 1 table"
        );
    }

    /// l(x+1) must equal lx - dx to within floating-point tolerance.
    #[test]
    fn test_mathematical_precision() {
        let xml = MortXML::from_url_id(1704).expect("Failed to load XML");
        let config = MortTableConfig {
            xml,
            radix: Some(1_000_000),
            pct: Some(1.0),
            int_rate: Some(0.03),
            assumption: Some(AssumptionEnum::UDD),
        };
        let result = config
            .gen_mort_table(2)
            .expect("Failed to generate high precision table");
        let lx = result.column("lx").unwrap().f64().unwrap();
        let dx = result.column("dx").unwrap().f64().unwrap();
        let _qx = result.column("qx").unwrap().f64().unwrap();
        for i in 0..std::cmp::min(result.height() - 1, 50) {
            let lx_curr = lx.get(i).unwrap();
            let dx_curr = dx.get(i).unwrap();
            let lx_next = lx.get(i + 1).unwrap();
            let expected_lx_next = lx_curr - dx_curr;
            assert!(
                (lx_next - expected_lx_next).abs() < 1e-6,
                "Life table relationship violated at age {}: l(x+1)={}, lx-dx={}",
                i,
                lx_next,
                expected_lx_next
            );
        }
        println!("✓ Mathematical precision verified with high-precision calculations");
    }
}