use greeners::{ModelSelection, PanelDiagnostics, SummaryStats};
use ndarray::Array1;
/// Checks that `compare_models` yields one row per candidate, that every row
/// is well-formed, and that rows come back ordered by ascending AIC with the
/// first row carrying AIC rank 1.
#[test]
fn test_model_selection_compare_models() {
    // Each candidate is (label, log-likelihood, parameter count, sample size).
    let candidates = vec![
        ("Model 1", -100.0, 3, 100),
        ("Model 2", -95.0, 5, 100),
        ("Model 3", -98.0, 4, 100),
    ];

    let table = ModelSelection::compare_models(candidates);
    assert_eq!(table.len(), 3);

    // Row layout: (name, aic, bic, rank_aic, rank_bic).
    for row in &table {
        let (name, aic, bic, rank_aic, rank_bic) = row;
        assert!(!name.is_empty());
        assert!(aic.is_finite());
        assert!(bic.is_finite());
        assert!((1..=3).contains(rank_aic));
        assert!((1..=3).contains(rank_bic));
    }

    // Adjacent rows must be non-decreasing in AIC.
    for idx in 1..table.len() {
        assert!(table[idx - 1].1 <= table[idx].1);
    }

    // The first row is the AIC winner.
    assert_eq!(table[0].3, 1);
}
/// Verifies that the AIC reported for a single model equals
/// `-2 * loglik + 2 * k`.
#[test]
fn test_model_selection_aic_calculation() {
    let (loglik, k, n) = (-50.0, 3, 100);

    let comparison = ModelSelection::compare_models(vec![("Test Model", loglik, k, n)]);

    let fit_term = -2.0 * loglik;
    let penalty = 2.0 * (k as f64);
    let expected_aic = fit_term + penalty;

    let deviation = (comparison[0].1 - expected_aic).abs();
    assert!(deviation < 1e-10);
}
/// Verifies that the BIC reported for a single model equals
/// `-2 * loglik + k * ln(n)`.
#[test]
fn test_model_selection_bic_calculation() {
    let (loglik, k, n) = (-50.0, 3, 100);

    let comparison = ModelSelection::compare_models(vec![("Test Model", loglik, k, n)]);

    let fit_term = -2.0 * loglik;
    let penalty = (k as f64) * (n as f64).ln();
    let expected_bic = fit_term + penalty;

    let deviation = (comparison[0].2 - expected_bic).abs();
    assert!(deviation < 1e-10);
}
/// Checks `akaike_weights` on three distinct AIC values: deltas are measured
/// from the smallest AIC, weights form a probability distribution, and a
/// lower AIC receives a strictly larger weight.
#[test]
fn test_akaike_weights_basic() {
    let aic_values = vec![100.0, 102.0, 105.0];
    let (delta_aic, weights) = ModelSelection::akaike_weights(&aic_values);

    // Deltas relative to the best (lowest) AIC of 100.
    let expected_deltas = [0.0, 2.0, 5.0];
    assert_eq!(delta_aic.len(), 3);
    for (idx, want) in expected_deltas.iter().enumerate() {
        assert!((delta_aic[idx] - *want).abs() < 1e-10);
    }

    // Weights must sum to one and each must lie in [0, 1].
    assert_eq!(weights.len(), 3);
    let total: f64 = weights.iter().sum();
    assert!((total - 1.0).abs() < 1e-10);
    for &w in &weights {
        assert!((0.0..=1.0).contains(&w));
    }

    // Inputs are in ascending AIC order, so weights must strictly decrease.
    for idx in 1..weights.len() {
        assert!(weights[idx - 1] > weights[idx]);
    }
}
/// With identical AIC values every delta should be zero and the weight
/// should be split evenly across the three models.
#[test]
fn test_akaike_weights_equal_models() {
    let aic_values = vec![100.0, 100.0, 100.0];
    let (delta_aic, weights) = ModelSelection::akaike_weights(&aic_values);

    for delta in &delta_aic {
        assert!(delta.abs() < 1e-10);
    }

    let equal_share = 1.0 / 3.0;
    for weight in &weights {
        assert!((*weight - equal_share).abs() < 1e-10);
    }
}
/// Runs the Breusch-Pagan LM test on a balanced panel (three entities, five
/// observations each) and sanity-checks the returned statistic and p-value.
#[test]
fn test_panel_diagnostics_breusch_pagan_basic() {
    let residuals = Array1::from(vec![
        0.1, -0.2, 0.15, -0.1, 0.05, // entity 0
        0.3, -0.25, 0.2, -0.15, 0.1, // entity 1
        -0.1, 0.2, -0.15, 0.1, -0.05, // entity 2
    ]);
    let entity_ids = vec![
        0, 0, 0, 0, 0, //
        1, 1, 1, 1, 1, //
        2, 2, 2, 2, 2, //
    ];

    let outcome = PanelDiagnostics::breusch_pagan_lm(&residuals, &entity_ids);
    assert!(outcome.is_ok());

    let (lm_stat, p_value) = outcome.unwrap();
    assert!(lm_stat >= 0.0);
    assert!((0.0..=1.0).contains(&p_value));
}
/// Runs the Breusch-Pagan LM test on small residuals with no obvious
/// entity-level pattern and checks that both returned values are valid.
///
/// No significance threshold is asserted; only the non-negativity of the
/// statistic and the [0, 1] range of the p-value are checked.
#[test]
fn test_panel_diagnostics_bp_no_panel_effect() {
    let residuals = Array1::from(vec![
        0.05, -0.03, 0.02, -0.04, 0.01, // entity 0
        -0.02, 0.04, -0.01, 0.03, -0.05, // entity 1
        0.03, -0.05, 0.04, -0.02, 0.01, // entity 2
    ]);
    let entity_ids = vec![
        0, 0, 0, 0, 0, //
        1, 1, 1, 1, 1, //
        2, 2, 2, 2, 2, //
    ];

    let (lm_stat, p_value) =
        PanelDiagnostics::breusch_pagan_lm(&residuals, &entity_ids).unwrap();

    assert!(lm_stat >= 0.0);
    // Previously the p-value was bound but never inspected.
    assert!((0.0..=1.0).contains(&p_value));
}
/// Fixed-effects F-test where the pooled SSR (150) is well above the
/// fixed-effects SSR (100): expects a positive statistic and a p-value
/// below 0.05.
#[test]
fn test_panel_diagnostics_f_test_basic() {
    let (ssr_pooled, ssr_fe) = (150.0, 100.0);
    let (n, n_entities, k) = (100, 10, 3);

    let outcome = PanelDiagnostics::f_test_fixed_effects(ssr_pooled, ssr_fe, n, n_entities, k);
    assert!(outcome.is_ok());

    let (f_stat, p_value) = outcome.unwrap();
    assert!(f_stat > 0.0);
    assert!((0.0..=1.0).contains(&p_value));
    assert!(p_value < 0.05);
}
/// Fixed-effects F-test where both models have the same SSR: the statistic
/// should be (numerically) zero and the p-value above 0.95.
#[test]
fn test_panel_diagnostics_f_test_no_effect() {
    let (ssr_pooled, ssr_fe) = (100.0, 100.0);
    let (n, n_entities, k) = (100, 10, 3);

    let (f_stat, p_value) =
        PanelDiagnostics::f_test_fixed_effects(ssr_pooled, ssr_fe, n, n_entities, k).unwrap();

    assert!(f_stat.abs() < 1e-10);
    assert!(p_value > 0.95);
}
/// Describes the values 1..=10 and checks the count, mean, extremes, a
/// plausible median, a positive standard deviation, and the ordering
/// min <= q25 <= median <= q75 <= max.
#[test]
fn test_summary_stats_basic() {
    let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

    let (mean, std, min, q25, median, q75, max, n) = SummaryStats::describe(&data);

    assert_eq!(n, 10);
    assert!((mean - 5.5).abs() < 1e-10);
    assert_eq!(min, 1.0);
    assert_eq!(max, 10.0);
    // Any value between the two middle observations is accepted.
    assert!((5.0..=6.0).contains(&median));
    assert!(std > 0.0);

    // The five-number summary must be non-decreasing.
    let five_numbers = [min, q25, median, q75, max];
    for pair in five_numbers.windows(2) {
        assert!(pair[0] <= pair[1]);
    }
}
/// Describes the values 1..=5 and checks the exact extremes and median,
/// allowing the quartiles to fall within 0.5 of 2 and 4 respectively.
#[test]
fn test_summary_stats_quartiles() {
    let sample = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

    let (_, _, min, q25, median, q75, max, _) = SummaryStats::describe(&sample);

    assert_eq!(min, 1.0);
    assert_eq!(max, 5.0);
    assert_eq!(median, 3.0);

    // Quartiles get a tolerance rather than an exact expected value.
    let q25_error = (q25 - 2.0).abs();
    let q75_error = (q75 - 4.0).abs();
    assert!(q25_error < 0.5);
    assert!(q75_error < 0.5);
}
/// Describes a single observation: count is 1, the mean, extremes and median
/// all equal that observation, and the standard deviation is zero.
#[test]
fn test_summary_stats_single_value() {
    let data = Array1::from(vec![5.0]);

    // The quartile slots (4th and 6th) are not checked here, so they are
    // discarded instead of being bound to unused variables.
    let (mean, std, min, _, median, _, max, n) = SummaryStats::describe(&data);

    assert_eq!(n, 1);
    assert_eq!(mean, 5.0);
    assert_eq!(min, 5.0);
    assert_eq!(max, 5.0);
    assert_eq!(median, 5.0);
    assert_eq!(std, 0.0);
}
/// Describes five identical observations: every location statistic equals
/// the constant and the standard deviation is exactly zero.
#[test]
fn test_summary_stats_constant_values() {
    let sample = Array1::from(vec![3.0, 3.0, 3.0, 3.0, 3.0]);

    let (mean, std, min, q25, median, q75, max, _) = SummaryStats::describe(&sample);

    assert_eq!(std, 0.0);
    for statistic in [mean, min, max, median, q25, q75] {
        assert_eq!(statistic, 3.0);
    }
}
/// Checks that `compare_models` orders rows by AIC and assigns AIC ranks
/// 1, 2, 3 accordingly, even when the input is already in that order.
#[test]
fn test_model_selection_ranking() {
    // Each candidate is (label, log-likelihood, parameter count, sample size).
    // With AIC = -2 * loglik + 2 * k the values are 108, 110 and 112.
    let models = vec![
        ("Best", -52.0, 2, 100),
        ("Middle", -51.0, 4, 100),
        ("Worst", -50.0, 6, 100),
    ];

    let comparison = ModelSelection::compare_models(models);

    // Row layout: field 0 is the name, field 3 is the AIC rank.
    assert_eq!(comparison[0].0, "Best");
    assert_eq!(comparison[0].3, 1);
    assert_eq!(comparison[1].0, "Middle");
    assert_eq!(comparison[1].3, 2);
    assert_eq!(comparison[2].0, "Worst");
    assert_eq!(comparison[2].3, 3);
}
/// Passing three residuals with only two entity ids must be rejected.
#[test]
fn test_panel_diagnostics_entity_mismatch() {
    let residuals = Array1::from(vec![0.1, 0.2, 0.3]);
    // Deliberately one id short of the residual count.
    let entity_ids = vec![0, 0];

    let outcome = PanelDiagnostics::breusch_pagan_lm(&residuals, &entity_ids);

    assert!(outcome.is_err());
}
/// The fixed-effects F-test must return an error for n = 10 observations,
/// 8 entities and k = 5 regressors.
#[test]
fn test_panel_diagnostics_insufficient_df() {
    let (ssr_pooled, ssr_fe) = (100.0, 90.0);
    // NOTE(review): presumably rejected because n - n_entities - k is negative
    // (10 - 8 - 5); confirm against the implementation's df formula.
    let (n, n_entities, k) = (10, 8, 5);

    let outcome = PanelDiagnostics::f_test_fixed_effects(ssr_pooled, ssr_fe, n, n_entities, k);

    assert!(outcome.is_err());
}