use std::path::PathBuf;
use faer::Mat;
use serde_json::Value;
/// Returns the directory that holds the JSON fixtures.
///
/// The path starts at this crate's `CARGO_MANIFEST_DIR`, drops its last two
/// components, and then appends `tests/fixtures`.
fn fixtures_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    for _ in 0..2 {
        dir.pop();
    }
    dir.extend(["tests", "fixtures"]);
    dir
}
/// Reads `<fixtures_dir>/<name>.json` and parses it into a JSON value.
///
/// # Panics
/// Panics if the file cannot be read or does not contain valid JSON.
fn load_fixture(name: &str) -> Value {
    let file = fixtures_dir().join(format!("{name}.json"));
    let text = match std::fs::read_to_string(&file) {
        Ok(text) => text,
        Err(e) => panic!("failed to read fixture {}: {e}", file.display()),
    };
    match serde_json::from_str::<Value>(&text) {
        Ok(parsed) => parsed,
        Err(e) => panic!("invalid JSON in {name}.json: {e}"),
    }
}
/// Converts a JSON array of numeric arrays (one inner array per row) into a
/// dense `f64` matrix.
///
/// An empty outer array yields a 0x0 matrix; the column count is taken from
/// the first row.
///
/// # Panics
/// Panics if `val` is not an array, if any row is not an array, if any entry
/// is not a number, or if the rows do not all have the same length.
fn json_to_mat(val: &Value) -> Mat<f64> {
    let rows: Vec<Vec<f64>> = val
        .as_array()
        .expect("matrix fixture must be a JSON array of rows")
        .iter()
        .map(|row| {
            row.as_array()
                .expect("matrix fixture row must be a JSON array")
                .iter()
                .map(|v| v.as_f64().expect("matrix fixture entry must be a JSON number"))
                .collect()
        })
        .collect();
    let nrows = rows.len();
    // `first()` instead of `rows[0]` so an empty fixture does not index out of bounds.
    let ncols = rows.first().map_or(0, Vec::len);
    // Reject ragged input up front: a short row would otherwise panic with a bare
    // index error inside `from_fn`, and a long row would be silently truncated.
    if let Some(bad) = rows.iter().position(|r| r.len() != ncols) {
        panic!(
            "ragged matrix fixture: row {bad} has {} entries, expected {ncols}",
            rows[bad].len()
        );
    }
    Mat::from_fn(nrows, ncols, |i, j| rows[i][j])
}
/// Converts a JSON array of numbers into a `Vec<f64>`.
///
/// # Panics
/// Panics if `val` is not an array or any element is not a number.
fn json_to_vec(val: &Value) -> Vec<f64> {
    let items = val.as_array().unwrap();
    let mut numbers = Vec::with_capacity(items.len());
    for item in items {
        numbers.push(item.as_f64().unwrap());
    }
    numbers
}
/// Asserts that two matrices have the same shape and that every pair of
/// entries differs by strictly less than `tol` in absolute value.
///
/// `msg` is prefixed to every failure message. Entries are visited row by
/// row, so the first offending cell in that order is the one reported.
fn assert_mat_close(a: &Mat<f64>, b: &Mat<f64>, tol: f64, msg: &str) {
    let (n_rows, n_cols) = (a.nrows(), a.ncols());
    assert_eq!(n_rows, b.nrows(), "{msg}: row count mismatch");
    assert_eq!(n_cols, b.ncols(), "{msg}: col count mismatch");

    let cells = (0..n_rows).flat_map(|i| (0..n_cols).map(move |j| (i, j)));
    for (i, j) in cells {
        let (ours, theirs) = (a[(i, j)], b[(i, j)]);
        let diff = (ours - theirs).abs();
        assert!(
            diff < tol,
            "{msg}: [{i},{j}] Rust={} R={} diff={diff} (tol={tol})",
            ours,
            theirs
        );
    }
}
/// Asserts that two slices have equal length and that every pair of
/// corresponding elements differs by strictly less than `tol` in absolute
/// value.
///
/// `msg` is prefixed to every failure message. A NaN difference fails the
/// check because `NaN < tol` is false.
fn assert_vec_close(a: &[f64], b: &[f64], tol: f64, msg: &str) {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "{msg}: length mismatch {} vs {}",
        len_a, len_b
    );
    a.iter()
        .zip(b.iter())
        .enumerate()
        .for_each(|(i, (&av, &bv))| {
            let diff = (av - bv).abs();
            assert!(
                diff < tol,
                "{msg}: [{i}] Rust={av} R={bv} diff={diff} (tol={tol})"
            );
        });
}
/// Runs `nearest_pd` on the `near_pd` fixture input and compares the result
/// with the stored reference matrix to within 1e-7.
#[test]
fn test_near_pd_matches_r() {
    use gsem_matrix::near_pd::nearest_pd;

    let fixture = load_fixture("near_pd");
    let matrix = json_to_mat(&fixture["input"]);
    let reference = json_to_mat(&fixture["expected"]);

    let projected = nearest_pd(&matrix, true, 100, 1e-8).unwrap();
    assert_mat_close(&projected, &reference, 1e-7, "nearPD");
}
/// Checks `vech` and `vech_reverse` against the `vech_3x3` fixture
/// (tolerance 1e-15 for both directions).
#[test]
fn test_vech_3x3_matches_r() {
    use gsem_matrix::vech::{vech, vech_reverse};

    let fixture = load_fixture("vech_3x3");
    let matrix = json_to_mat(&fixture["input"]);
    let want_vech = json_to_vec(&fixture["vech"]);
    let want_matrix = json_to_mat(&fixture["reverse"]);

    let half = vech(&matrix).unwrap();
    assert_vec_close(&half, &want_vech, 1e-15, "vech 3x3");

    // Round-trip: rebuild the matrix from the half-vectorization just computed.
    let rebuilt = vech_reverse(&half, 3).unwrap();
    assert_mat_close(&rebuilt, &want_matrix, 1e-15, "vech_reverse 3x3");
}
/// Checks `vech` and `vech_reverse` against the `vech_4x4` fixture
/// (tolerance 1e-15 for both directions).
#[test]
fn test_vech_4x4_matches_r() {
    use gsem_matrix::vech::{vech, vech_reverse};

    let fixture = load_fixture("vech_4x4");
    let matrix = json_to_mat(&fixture["input"]);
    let want_vech = json_to_vec(&fixture["vech"]);
    let want_matrix = json_to_mat(&fixture["reverse"]);

    let half = vech(&matrix).unwrap();
    assert_vec_close(&half, &want_vech, 1e-15, "vech 4x4");

    // Round-trip: rebuild the matrix from the half-vectorization just computed.
    let rebuilt = vech_reverse(&half, 4).unwrap();
    assert_mat_close(&rebuilt, &want_matrix, 1e-15, "vech_reverse 4x4");
}
/// Compares `cov_to_cor` on the `cov_to_cor` fixture input with the stored
/// reference matrix to within 1e-12.
#[test]
fn test_cov_to_cor_matches_r() {
    use gsem_matrix::smooth::cov_to_cor;

    let fixture = load_fixture("cov_to_cor");
    let covariance = json_to_mat(&fixture["input"]);
    let reference = json_to_mat(&fixture["expected"]);

    let correlation = cov_to_cor(&covariance);
    assert_mat_close(&correlation, &reference, 1e-12, "cov_to_cor");
}
/// Extracts the inputs shared by the `build_v_snp` tests from a fixture.
///
/// Returns, in order, the values stored under the keys `se_snp` (vector),
/// `i_ld` (matrix), `var_snp` (float) and `k` (unsigned integer).
///
/// # Panics
/// Panics if any key is missing or has the wrong JSON type.
fn load_v_snp_inputs(fix: &Value) -> (Vec<f64>, Mat<f64>, f64, usize) {
    // Tuple fields are evaluated left to right, matching the read order of the keys.
    (
        json_to_vec(&fix["se_snp"]),
        json_to_mat(&fix["i_ld"]),
        fix["var_snp"].as_f64().unwrap(),
        fix["k"].as_u64().unwrap() as usize,
    )
}
/// Compares `build_v_snp` in `GcMode::Standard` with the reference matrix in
/// the `v_snp_standard` fixture (tolerance 1e-12).
#[test]
fn test_v_snp_standard_matches_r() {
    use gsem::gwas::gc_correction::{build_v_snp, GcMode};

    let fixture = load_fixture("v_snp_standard");
    let (se_vals, i_ld_mat, snp_var, k_val) = load_v_snp_inputs(&fixture);
    let reference = json_to_mat(&fixture["expected"]);

    let built = build_v_snp(&se_vals, &i_ld_mat, snp_var, GcMode::Standard, k_val);
    assert_mat_close(&built, &reference, 1e-12, "V_SNP standard");
}
/// Compares `build_v_snp` in `GcMode::Conservative` with the reference matrix
/// in the `v_snp_conservative` fixture (tolerance 1e-12).
#[test]
fn test_v_snp_conservative_matches_r() {
    use gsem::gwas::gc_correction::{build_v_snp, GcMode};

    let fixture = load_fixture("v_snp_conservative");
    let (se_vals, i_ld_mat, snp_var, k_val) = load_v_snp_inputs(&fixture);
    let reference = json_to_mat(&fixture["expected"]);

    let built = build_v_snp(&se_vals, &i_ld_mat, snp_var, GcMode::Conservative, k_val);
    assert_mat_close(&built, &reference, 1e-12, "V_SNP conservative");
}
/// Compares `build_v_snp` in `GcMode::None` with the reference matrix in the
/// `v_snp_none` fixture (tolerance 1e-12).
#[test]
fn test_v_snp_none_matches_r() {
    use gsem::gwas::gc_correction::{build_v_snp, GcMode};

    let fixture = load_fixture("v_snp_none");
    let (se_vals, i_ld_mat, snp_var, k_val) = load_v_snp_inputs(&fixture);
    let reference = json_to_mat(&fixture["expected"]);

    let built = build_v_snp(&se_vals, &i_ld_mat, snp_var, GcMode::None, k_val);
    assert_mat_close(&built, &reference, 1e-12, "V_SNP none");
}
/// Compares `build_s_full` with the reference matrix in the `s_full` fixture
/// (tolerance 1e-12). The dimension argument is the row count of `s_ld`.
#[test]
fn test_s_full_matches_r() {
    use gsem::gwas::add_snps::build_s_full;

    let fixture = load_fixture("s_full");
    let s_ld = json_to_mat(&fixture["s_ld"]);
    let betas = json_to_vec(&fixture["beta_snp"]);
    let snp_var = fixture["var_snp"].as_f64().unwrap();
    let reference = json_to_mat(&fixture["expected"]);

    let dim = s_ld.nrows();
    let built = build_s_full(&s_ld, &betas, snp_var, dim);
    assert_mat_close(&built, &reference, 1e-12, "S_Full");
}
/// Compares `build_v_full` (in `GcMode::Standard`) with the reference matrix
/// in the `v_full` fixture (tolerance 1e-10).
///
/// `v_ld`, `var_snp_se2`, `k` and the expected matrix come from `v_full`;
/// `se_snp`, `i_ld` and `var_snp` are read from the `v_snp_standard` fixture.
#[test]
fn test_v_full_matches_r() {
    use gsem::gwas::add_snps::build_v_full;
    use gsem::gwas::gc_correction::GcMode;

    let fixture = load_fixture("v_full");
    let v_ld = json_to_mat(&fixture["v_ld"]);
    let snp_var_se2 = fixture["var_snp_se2"].as_f64().unwrap();
    let k_val = fixture["k"].as_u64().unwrap() as usize;
    let reference = json_to_mat(&fixture["expected"]);

    let snp_fixture = load_fixture("v_snp_standard");
    let se_vals = json_to_vec(&snp_fixture["se_snp"]);
    let i_ld_mat = json_to_mat(&snp_fixture["i_ld"]);
    let snp_var = snp_fixture["var_snp"].as_f64().unwrap();

    let built = build_v_full(
        &v_ld,
        &se_vals,
        snp_var,
        snp_var_se2,
        &i_ld_mat,
        GcMode::Standard,
        k_val,
    );
    assert_mat_close(&built, &reference, 1e-10, "V_Full");
}
/// Compares `gc_adjusted_z` under all three `GcMode` variants with the
/// reference vectors in the `z_pre` fixture (tolerance 1e-10).
///
/// The dimension argument is the length of the `beta` vector.
#[test]
fn test_z_pre_matches_r() {
    use gsem::gwas::gc_correction::{gc_adjusted_z, GcMode};

    let fixture = load_fixture("z_pre");
    let beta = json_to_vec(&fixture["beta"]);
    let se = json_to_vec(&fixture["se"]);
    let i_ld = json_to_mat(&fixture["i_ld"]);
    let k = beta.len();

    // Only the mode varies between the three calls.
    let adjusted = |mode: GcMode| gc_adjusted_z(&beta, &se, &i_ld, mode, k);
    let z_std = adjusted(GcMode::Standard);
    let z_con = adjusted(GcMode::Conservative);
    let z_none = adjusted(GcMode::None);

    let want_std = json_to_vec(&fixture["standard"]);
    let want_con = json_to_vec(&fixture["conservative"]);
    let want_none = json_to_vec(&fixture["none"]);

    assert_vec_close(&z_std, &want_std, 1e-10, "Z_pre standard");
    assert_vec_close(&z_con, &want_con, 1e-10, "Z_pre conservative");
    assert_vec_close(&z_none, &want_none, 1e-10, "Z_pre none");
}
// Fits a one-factor model (F1 measured by V1, V2, V3) with `fit_dwls` on the
// `sem_1factor` fixture and compares, against the R reference values stored in
// the same fixture: the free-parameter estimates, the sandwich standard errors,
// the degrees of freedom, chi-square and SRMR. Finally checks that the implied
// covariance is close to S and that the objective is small.
//
// Statement order matters: `model` is passed as `&mut` to `fit_dwls` and
// `sandwich_se`, and `implied_cov` / `n_free` are read only afterwards.
#[test]
fn test_sem_estimates_match_r() {
let fix = load_fixture("sem_1factor");
let s = json_to_mat(&fix["s"]);
let v_diag = json_to_vec(&fix["v_diag"]);
// R reference estimates as (lhs, op, rhs, est) tuples.
let r_estimates: Vec<(String, String, String, f64)> = fix["estimates"]
.as_array()
.unwrap()
.iter()
.map(|e| {
(
e["lhs"].as_str().unwrap().to_string(),
e["op"].as_str().unwrap().to_string(),
e["rhs"].as_str().unwrap().to_string(),
e["est"].as_f64().unwrap(),
)
})
.collect();
let model_str = "F1 =~ NA*V1 + V2 + V3\nF1 ~~ 1*F1\nV1 ~~ V1\nV2 ~~ V2\nV3 ~~ V3";
let pt = gsem_sem::syntax::parse_model(model_str, false).unwrap();
let obs_names: Vec<String> = vec!["V1", "V2", "V3"]
.into_iter()
.map(String::from)
.collect();
let mut model = gsem_sem::model::Model::from_partable(&pt, &obs_names);
// NOTE(review): `1000` and `None` are presumably an iteration cap and an
// optional setting — confirm against the `fit_dwls` signature.
let fit = gsem_sem::estimator::fit_dwls(&mut model, &s, &v_diag, 1000, None);
assert!(fit.converged, "SEM should converge");
// Partable rows with `free > 0` are treated as the free parameters; the i-th
// such row is paired with `fit.params[i]`.
let free_rows: Vec<_> = pt.rows.iter().filter(|r| r.free > 0).collect();
assert_eq!(
free_rows.len(),
fit.params.len(),
"Number of free rows must match number of fitted params"
);
for (i, row) in free_rows.iter().enumerate() {
let est = fit.params[i];
// Look up the R estimate by (lhs, op, rhs); a free row without a match is a failure.
if let Some(r_est) = r_estimates
.iter()
.find(|(l, o, r, _)| *l == row.lhs && *o == row.op.to_string() && *r == row.rhs)
{
let diff = (est - r_est.3).abs();
assert!(
diff < 0.05,
"SEM param {} {} {}: Rust={est:.6} R={:.6} diff={diff:.6}",
row.lhs,
row.op,
row.rhs,
r_est.3
);
} else {
panic!(
"Free param {} {} {} not found in R reference estimates",
row.lhs, row.op, row.rhs
);
}
}
let r_sandwich_se = json_to_vec(&fix["sandwich_se"]);
// 3 * 4 / 2 = 6, i.e. k(k+1)/2 for the k = 3 observed variables listed above.
let kstar = 3 * 4 / 2;
let v = json_to_mat(&fix["v"]);
// Diagonal weight matrix: 1 / v_diag[i] on the diagonal where v_diag[i] > 1e-30,
// zero everywhere else. Indexes v_diag up to kstar - 1.
let w = faer::Mat::from_fn(kstar, kstar, |i, j| {
if i == j && v_diag[i] > 1e-30 {
1.0 / v_diag[i]
} else {
0.0
}
});
// Only the first element of the returned pair is checked here.
let (se_vec, _ohtt) = gsem_sem::sandwich::sandwich_se(&mut model, &w, &v);
assert_eq!(
se_vec.len(),
r_sandwich_se.len(),
"Sandwich SE count mismatch: Rust={} R={}",
se_vec.len(),
r_sandwich_se.len()
);
for (i, (&rust_se, &r_se)) in se_vec.iter().zip(r_sandwich_se.iter()).enumerate() {
let diff = (rust_se - r_se).abs();
assert!(
diff < 0.01,
"1-factor sandwich SE[{i}]: Rust={rust_se:.6} R={r_se:.6} diff={diff:.6}"
);
}
let r_fit = &fix["fit_indices"];
let r_chisq = r_fit["chisq"].as_f64().unwrap();
// df is read as a JSON float and truncated to usize.
let r_df = r_fit["df"].as_f64().unwrap() as usize;
let r_srmr = r_fit["srmr"].as_f64().unwrap();
let sigma_hat = model.implied_cov();
let n_free = model.n_free();
// Degrees of freedom: kstar minus the number of free parameters, floored at zero.
let df = kstar.saturating_sub(n_free);
assert_eq!(df, r_df, "1-factor df mismatch");
let fit_stats = gsem_sem::fit_indices::compute_fit(&s, &sigma_hat, &v, df, n_free, None, None);
let chisq_diff = (fit_stats.chisq - r_chisq).abs();
assert!(
chisq_diff < 1e-4,
"1-factor chisq: Rust={:.6} R={r_chisq:.6} diff={chisq_diff:.6}",
fit_stats.chisq
);
let srmr_diff = (fit_stats.srmr - r_srmr).abs();
assert!(
srmr_diff < 1e-6,
"1-factor SRMR: Rust={:.10} R={r_srmr:.10} diff={srmr_diff:.10}",
fit_stats.srmr
);
// Sanity checks that do not depend on the R reference values.
assert_mat_close(&sigma_hat, &s, 0.05, "SEM implied cov ≈ S");
assert!(
fit.objective < 0.1,
"SEM objective should be small: {}",
fit.objective
);
}
// Fits a two-factor model (F1 measured by V1, V2; F2 measured by V3, V4; F1 and
// F2 allowed to covary) with `fit_dwls` on the `sem_2factor` fixture and
// compares, against the R reference values stored in the same fixture: the
// free-parameter estimates, the sandwich standard errors, the degrees of
// freedom, chi-square and SRMR.
//
// Statement order matters: `model` is passed as `&mut` to `fit_dwls` and
// `sandwich_se`, and `implied_cov` / `n_free` are read only afterwards.
#[test]
fn test_sem_2factor_all_params_match_r() {
let fix = load_fixture("sem_2factor");
let s = json_to_mat(&fix["s"]);
let v_diag = json_to_vec(&fix["v_diag"]);
// R reference estimates as (lhs, op, rhs, est) tuples.
let r_estimates: Vec<(String, String, String, f64)> = fix["estimates"]
.as_array()
.unwrap()
.iter()
.map(|e| {
(
e["lhs"].as_str().unwrap().to_string(),
e["op"].as_str().unwrap().to_string(),
e["rhs"].as_str().unwrap().to_string(),
e["est"].as_f64().unwrap(),
)
})
.collect();
let model_str = "F1 =~ NA*V1 + V2\nF2 =~ NA*V3 + V4\n\
F1 ~~ 1*F1\nF2 ~~ 1*F2\nF1 ~~ F2\n\
V1 ~~ V1\nV2 ~~ V2\nV3 ~~ V3\nV4 ~~ V4";
let pt = gsem_sem::syntax::parse_model(model_str, false).unwrap();
let obs_names: Vec<String> = vec!["V1", "V2", "V3", "V4"]
.into_iter()
.map(String::from)
.collect();
let mut model = gsem_sem::model::Model::from_partable(&pt, &obs_names);
// NOTE(review): `1000` and `None` are presumably an iteration cap and an
// optional setting — confirm against the `fit_dwls` signature.
let fit = gsem_sem::estimator::fit_dwls(&mut model, &s, &v_diag, 1000, None);
assert!(fit.converged, "2-factor SEM should converge");
// Guard that the parser recorded the two fixed factor variances as non-free rows.
let n_fixed = pt.rows.iter().filter(|r| r.free == 0).count();
assert!(
n_fixed >= 2,
"Partable should have at least 2 fixed rows (F1~~1*F1, F2~~1*F2), got {n_fixed}"
);
// Partable rows with `free > 0` are treated as the free parameters; the i-th
// such row is paired with `fit.params[i]`.
let free_rows: Vec<_> = pt.rows.iter().filter(|r| r.free > 0).collect();
assert_eq!(
free_rows.len(),
fit.params.len(),
"Number of free rows ({}) must match number of fitted params ({})",
free_rows.len(),
fit.params.len()
);
for (i, row) in free_rows.iter().enumerate() {
let est = fit.params[i];
// Look up the R estimate by (lhs, op, rhs); a free row without a match is a failure.
if let Some(r_est) = r_estimates
.iter()
.find(|(l, o, r, _)| *l == row.lhs && *o == row.op.to_string() && *r == row.rhs)
{
let diff = (est - r_est.3).abs();
assert!(
diff < 0.05,
"2-factor param {} {} {}: Rust={est:.6} R={:.6} diff={diff:.6}",
row.lhs,
row.op,
row.rhs,
r_est.3
);
} else {
panic!(
"Free param {} {} {} not found in R reference estimates",
row.lhs, row.op, row.rhs
);
}
}
let r_sandwich_se = json_to_vec(&fix["sandwich_se"]);
// 4 * 5 / 2 = 10, i.e. k(k+1)/2 for the k = 4 observed variables listed above.
let kstar = 4 * 5 / 2; let v = json_to_mat(&fix["v"]);
// Diagonal weight matrix: 1 / v_diag[i] on the diagonal where v_diag[i] > 1e-30,
// zero everywhere else. Indexes v_diag up to kstar - 1.
let w = faer::Mat::from_fn(kstar, kstar, |i, j| {
if i == j && v_diag[i] > 1e-30 {
1.0 / v_diag[i]
} else {
0.0
}
});
// Only the first element of the returned pair is checked here.
let (se_vec, _ohtt) = gsem_sem::sandwich::sandwich_se(&mut model, &w, &v);
assert_eq!(
se_vec.len(),
r_sandwich_se.len(),
"2-factor sandwich SE count mismatch: Rust={} R={}",
se_vec.len(),
r_sandwich_se.len()
);
for (i, (&rust_se, &r_se)) in se_vec.iter().zip(r_sandwich_se.iter()).enumerate() {
let diff = (rust_se - r_se).abs();
assert!(
diff < 0.01,
"2-factor sandwich SE[{i}]: Rust={rust_se:.6} R={r_se:.6} diff={diff:.6}"
);
}
let r_fit = &fix["fit_indices"];
let r_chisq = r_fit["chisq"].as_f64().unwrap();
// df is read as a JSON float and truncated to usize.
let r_df = r_fit["df"].as_f64().unwrap() as usize;
let r_srmr = r_fit["srmr"].as_f64().unwrap();
let sigma_hat = model.implied_cov();
let n_free = model.n_free();
// Degrees of freedom: kstar minus the number of free parameters, floored at zero.
let df = kstar.saturating_sub(n_free);
assert_eq!(df, r_df, "2-factor df mismatch");
let fit_stats = gsem_sem::fit_indices::compute_fit(&s, &sigma_hat, &v, df, n_free, None, None);
let chisq_diff = (fit_stats.chisq - r_chisq).abs();
assert!(
chisq_diff < 1e-4,
"2-factor chisq: Rust={:.6} R={r_chisq:.6} diff={chisq_diff:.6}",
fit_stats.chisq
);
let srmr_diff = (fit_stats.srmr - r_srmr).abs();
assert!(
srmr_diff < 1e-6,
"2-factor SRMR: Rust={:.10} R={r_srmr:.10} diff={srmr_diff:.10}",
fit_stats.srmr
);
}
/// Compares `reorder_v` with the reference matrix stored under `v_reordered`
/// in the `reorder` fixture (tolerance 1e-12).
#[test]
fn test_v_reorder_matches_r() {
    use gsem_sem::reorder::reorder_v;

    let fixture = load_fixture("reorder");
    // Reads a JSON array of strings stored under `key`.
    let names_at = |key: &str| -> Vec<String> {
        fixture[key]
            .as_array()
            .unwrap()
            .iter()
            .map(|name| name.as_str().unwrap().to_string())
            .collect()
    };

    let v = json_to_mat(&fixture["v"]);
    let user_order = names_at("user_order");
    let model_order = names_at("model_order");
    let reference = json_to_mat(&fixture["v_reordered"]);

    let reordered = reorder_v(&v, &user_order, &model_order).unwrap();
    assert_mat_close(&reordered, &reference, 1e-12, "V reorder");
}
/// Runs `run_commonfactor` (DWLS) on the `commonfactor` fixture and compares
/// the result with the R reference values stored in the same fixture:
///
/// * parameters, matched by position: `lhs`, `op` and `rhs` must be equal and
///   the estimate must agree to within 0.02;
/// * standard errors, matched by position, to within 0.01;
/// * implied covariance to within 1e-6;
/// * `df` exactly, chi-square to 1e-4, CFI to 0.01, SRMR to 1e-6.
#[test]
fn test_commonfactor_matches_r() {
    let fix = load_fixture("commonfactor");
    let s = json_to_mat(&fix["s"]);
    let v = json_to_mat(&fix["v"]);
    // R reference parameters as (lhs, op, rhs, est) tuples.
    let r_params: Vec<(String, String, String, f64)> = fix["parameters"]
        .as_array()
        .unwrap()
        .iter()
        .map(|e| {
            (
                e["lhs"].as_str().unwrap().to_string(),
                e["op"].as_str().unwrap().to_string(),
                e["rhs"].as_str().unwrap().to_string(),
                e["est"].as_f64().unwrap(),
            )
        })
        .collect();
    let r_sandwich_se = json_to_vec(&fix["sandwich_se"]);
    let r_implied = json_to_mat(&fix["implied_cov"]);

    let result =
        gsem_sem::commonfactor::run_commonfactor(&s, &v, gsem_sem::EstimationMethod::Dwls).unwrap();

    assert_eq!(
        result.parameters.len(),
        r_params.len(),
        "parameter count mismatch"
    );
    for (rust_p, (r_lhs, r_op, r_rhs, r_est)) in result.parameters.iter().zip(r_params.iter()) {
        assert_eq!(rust_p.lhs, *r_lhs, "lhs mismatch");
        // The operator is part of a parameter's identity: without this check a
        // loading and a (co)variance sharing lhs/rhs would be accepted as the same row.
        assert_eq!(rust_p.op.to_string(), *r_op, "op mismatch");
        assert_eq!(rust_p.rhs, *r_rhs, "rhs mismatch");
        let diff = (rust_p.est - *r_est).abs();
        assert!(
            diff < 0.02,
            "commonfactor param {}.{}.{}: Rust={:.6} R={:.6} diff={diff:.6}",
            rust_p.lhs,
            rust_p.op,
            rust_p.rhs,
            rust_p.est,
            *r_est
        );
    }

    // `zip` stops at the shorter sequence, so only as many SEs are compared as
    // the fixture provides.
    for (i, (rust_p, &r_se)) in result
        .parameters
        .iter()
        .zip(r_sandwich_se.iter())
        .enumerate()
    {
        let diff = (rust_p.se - r_se).abs();
        assert!(
            diff < 0.01,
            "commonfactor SE[{i}]: Rust={:.6} R={r_se:.6} diff={diff:.6}",
            rust_p.se
        );
    }

    assert_mat_close(
        &result.implied_cov,
        &r_implied,
        1e-6,
        "commonfactor implied cov",
    );

    let r_chisq = fix["chisq"].as_f64().unwrap();
    // df is read as a JSON float and truncated to usize.
    let r_df = fix["df"].as_f64().unwrap() as usize;
    let r_cfi = fix["cfi"].as_f64().unwrap();
    let r_srmr = fix["srmr"].as_f64().unwrap();

    assert_eq!(result.fit.df, r_df, "commonfactor df mismatch");
    let chisq_diff = (result.fit.chisq - r_chisq).abs();
    assert!(
        chisq_diff < 1e-4,
        "commonfactor chisq: Rust={:.6} R={r_chisq:.6} diff={chisq_diff:.6}",
        result.fit.chisq
    );
    let cfi_diff = (result.fit.cfi - r_cfi).abs();
    assert!(
        cfi_diff < 0.01,
        "commonfactor CFI: Rust={:.6} R={r_cfi:.6} diff={cfi_diff:.6}",
        result.fit.cfi
    );
    let srmr_diff = (result.fit.srmr - r_srmr).abs();
    assert!(
        srmr_diff < 1e-6,
        "commonfactor SRMR: Rust={:.10} R={r_srmr:.10} diff={srmr_diff:.10}",
        result.fit.srmr
    );
}
/// Loads the `gwas_per_snp` fixture and unpacks the inputs used by the GWAS
/// tests.
///
/// Returns, in order: the matrices under `s`, `v` and `i_mat`; the strings
/// under `trait_names`; then, with one entry per element of `snps`, the
/// `beta` vectors, the `se` vectors, `2 * MAF * (1 - MAF)`, and the `SNP`
/// identifiers; and finally the whole fixture value.
///
/// # Panics
/// Panics if the fixture is missing or any field has an unexpected JSON type.
#[allow(clippy::type_complexity)]
fn load_gwas_inputs() -> (
    Mat<f64>,
    Mat<f64>,
    Mat<f64>,
    Vec<String>,
    Vec<Vec<f64>>,
    Vec<Vec<f64>>,
    Vec<f64>,
    Vec<String>,
    Value,
) {
    let fixture = load_fixture("gwas_per_snp");
    let s = json_to_mat(&fixture["s"]);
    let v = json_to_mat(&fixture["v"]);
    let i_mat = json_to_mat(&fixture["i_mat"]);

    let mut trait_names: Vec<String> = Vec::new();
    for name in fixture["trait_names"].as_array().unwrap() {
        trait_names.push(name.as_str().unwrap().to_string());
    }

    let records = fixture["snps"].as_array().unwrap();
    let n_snps = records.len();
    let mut beta_snp: Vec<Vec<f64>> = Vec::with_capacity(n_snps);
    let mut se_snp: Vec<Vec<f64>> = Vec::with_capacity(n_snps);
    let mut var_snp: Vec<f64> = Vec::with_capacity(n_snps);
    let mut snp_ids: Vec<String> = Vec::with_capacity(n_snps);
    for record in records {
        snp_ids.push(record["SNP"].as_str().unwrap().to_string());
        let maf = record["MAF"].as_f64().unwrap();
        var_snp.push(2.0 * maf * (1.0 - maf));
        beta_snp.push(json_to_vec(&record["beta"]));
        se_snp.push(json_to_vec(&record["se"]));
    }

    (
        s,
        v,
        i_mat,
        trait_names,
        beta_snp,
        se_snp,
        var_snp,
        snp_ids,
        fixture,
    )
}
/// Fits the common-factor model (DWLS) on the `s` / `v` matrices of the
/// `gwas_per_snp` fixture, prints the estimates to stderr for inspection, and
/// checks that the parameter with `lhs == "F1"` and `rhs == "V1"` exists and
/// has magnitude above 0.01.
#[test]
fn test_gwas_baseline_commonfactor_match_r() {
    let fix = load_fixture("gwas_per_snp");
    let s = json_to_mat(&fix["s"]);
    let v = json_to_mat(&fix["v"]);
    let result =
        gsem_sem::commonfactor::run_commonfactor(&s, &v, gsem_sem::EstimationMethod::Dwls).unwrap();

    eprintln!("\n==== Rust commonfactor baseline ====");
    for p in &result.parameters {
        eprintln!(" {} {} {} = {:.6}", p.lhs, p.op, p.rhs, p.est);
    }
    eprintln!("chisq = {:.6}", result.fit.chisq);
    // NOTE(review): this line prints `fit.chisq` again under the "objective" label.
    eprintln!("objective ≈ {:.6e}", result.fit.chisq);

    // Fail explicitly when the parameter is absent; mapping it to NaN would be
    // reported below as a "too small" loading, which hides the real cause.
    let anx_loading = result
        .parameters
        .iter()
        .find(|p| p.lhs == "F1" && p.rhs == "V1")
        .map(|p| p.est)
        .unwrap_or_else(|| panic!("Baseline fit has no parameter with lhs=F1 and rhs=V1"));
    eprintln!("F1 =~ V1 (ANX) = {anx_loading}");
    // A NaN estimate still fails here because `NaN > 0.01` is false.
    assert!(
        anx_loading.abs() > 0.01,
        "Baseline loading magnitude too small: {anx_loading}"
    );
}
/// Runs the common-factor GWAS with `Identification::FixedVariance` on the
/// `gwas_per_snp` inputs and compares each converged SNP's `F1 ~ SNP`
/// estimate, in absolute value, with the rows stored under `user_gwas` in the
/// fixture (tolerance 0.002). Also requires a finite, positive SE for every
/// compared SNP and at least 15 converged SNPs.
#[test]
fn test_commonfactor_gwas_per_snp_match_r() {
    use gsem::gwas::common_factor::{
        run_common_factor_gwas, CommonFactorGwasConfig, Identification,
    };
    use gsem_sem::syntax::Op;

    let (s, v, i_mat, trait_names, beta_snp, se_snp, var_snp, snp_ids, fix) = load_gwas_inputs();
    let cfg = CommonFactorGwasConfig {
        identification: Identification::FixedVariance,
        ..Default::default()
    };
    let beta_refs: Vec<&[f64]> = beta_snp.iter().map(|b| b.as_slice()).collect();
    let se_refs: Vec<&[f64]> = se_snp.iter().map(|e| e.as_slice()).collect();
    let rust_results = run_common_factor_gwas(
        &trait_names,
        &s,
        &v,
        &i_mat,
        &beta_refs,
        &se_refs,
        &var_snp,
        &cfg,
        None,
    );
    assert_eq!(
        rust_results.len(),
        snp_ids.len(),
        "SNP count mismatch: Rust={}, expected={}",
        rust_results.len(),
        snp_ids.len()
    );

    let r_cf = fix["user_gwas"].as_array().unwrap();
    assert_eq!(r_cf.len(), snp_ids.len(), "R result count mismatch");

    let mut n_compared = 0;
    // The two length assertions above make this zip cover every result.
    for (idx, (rust_res, r_row)) in rust_results.iter().zip(r_cf.iter()).enumerate() {
        let r_snp = r_row["SNP"].as_str().unwrap();
        assert_eq!(
            snp_ids[idx], r_snp,
            "SNP order mismatch at {idx}: {} vs {r_snp}",
            snp_ids[idx]
        );
        let r_est = r_row["est"].as_f64().unwrap();

        // The parameter must be present even for SNPs that did not converge.
        let lookup = rust_res
            .params
            .iter()
            .find(|p| p.op == Op::Regression && p.lhs == "F1" && p.rhs == "SNP");
        let snp_param = match lookup {
            Some(param) => param,
            None => panic!(
                "No F1~SNP parameter in Rust result for SNP {}",
                snp_ids[idx]
            ),
        };

        if rust_res.converged {
            n_compared += 1;
            // Compared on magnitude only, so a sign flip is accepted.
            let est_diff = (snp_param.est.abs() - r_est.abs()).abs();
            assert!(
                est_diff < 0.002,
                "commonfactorGWAS |est| for {}: Rust={:.6} R={r_est:.6} diff={est_diff:.6}",
                r_snp,
                snp_param.est
            );
            assert!(
                snp_param.se > 0.0 && snp_param.se.is_finite(),
                "commonfactorGWAS SE for {r_snp} should be finite and positive: {}",
                snp_param.se
            );
        }
    }
    assert!(
        n_compared >= 15,
        "Too few SNPs converged ({n_compared}/20); expected most to converge"
    );
}
/// Runs the common-factor GWAS with `Identification::MarkerIndicator` on the
/// `gwas_per_snp` inputs and compares the signed `F1 ~ SNP` estimate of each
/// converged SNP with the rows stored under `commonfactor_gwas`.
///
/// Passes when at least 15 SNPs converged and at least one of them is within
/// 0.01 of the reference estimate.
#[test]
fn test_commonfactor_gwas_marker_indicator_matches_r() {
    use gsem::gwas::common_factor::{
        run_common_factor_gwas, CommonFactorGwasConfig, Identification,
    };
    use gsem_sem::syntax::Op;

    let (s, v, i_mat, trait_names, beta_snp, se_snp, var_snp, snp_ids, fix) = load_gwas_inputs();
    let cfg = CommonFactorGwasConfig {
        identification: Identification::MarkerIndicator,
        ..Default::default()
    };
    let beta_refs: Vec<&[f64]> = beta_snp.iter().map(|b| b.as_slice()).collect();
    let se_refs: Vec<&[f64]> = se_snp.iter().map(|e| e.as_slice()).collect();
    let rust_results = run_common_factor_gwas(
        &trait_names,
        &s,
        &v,
        &i_mat,
        &beta_refs,
        &se_refs,
        &var_snp,
        &cfg,
        None,
    );

    let r_cf = fix["commonfactor_gwas"].as_array().unwrap();
    let mut n_compared = 0;
    let mut n_est_match = 0;
    // Non-converged SNPs are skipped entirely; `idx` still refers to the
    // position in the full result list.
    let converged = rust_results
        .iter()
        .enumerate()
        .filter(|(_, res)| res.converged);
    for (idx, rust_res) in converged {
        let r_row = &r_cf[idx];
        let r_snp = r_row["SNP"].as_str().unwrap();
        assert_eq!(snp_ids[idx], r_snp, "SNP order mismatch");
        let r_est = r_row["est"].as_f64().unwrap();

        let lookup = rust_res
            .params
            .iter()
            .find(|p| p.op == Op::Regression && p.lhs == "F1" && p.rhs == "SNP");
        let snp_param = match lookup {
            Some(param) => param,
            None => panic!("No F1~SNP param for SNP {r_snp}"),
        };

        n_compared += 1;
        let est_diff = (snp_param.est - r_est).abs();
        if est_diff < 0.01 {
            n_est_match += 1;
        }
    }

    assert!(
        n_compared >= 15,
        "MarkerIndicator mode: too few SNPs converged ({n_compared}/20)"
    );
    assert!(
        n_est_match >= 1,
        "MarkerIndicator mode: at least 1 SNP should match R's signed est \
         within 0.01 — got {n_est_match}/{n_compared}"
    );
}
/// Runs the user-model GWAS on the `gwas_per_snp` inputs with a one-factor
/// model over the first three trait names plus an `F1 ~ SNP` regression, and
/// compares each converged SNP with the rows stored under `user_gwas`:
/// the `F1 ~ SNP` estimate in absolute value (tolerance 0.01) and, when the
/// reference row has a numeric `chisq`, the chi-square (tolerance 0.5).
/// Requires at least 15 converged SNPs.
#[test]
fn test_user_gwas_per_snp_match_r() {
    use gsem::gwas::gc_correction::GcMode;
    use gsem::gwas::user_gwas::{run_user_gwas, UserGwasConfig, VariantLabel};
    use gsem_sem::syntax::{parse_model, Op};
    use gsem_sem::EstimationMethod;

    let (s, v, i_mat, trait_names, beta_snp, se_snp, var_snp, snp_ids, fix) = load_gwas_inputs();
    let model_str = format!(
        "F1 =~ NA*{} + {} + {}\nF1 ~ SNP\nF1 ~~ 1*F1",
        trait_names[0], trait_names[1], trait_names[2]
    );
    let cfg = UserGwasConfig {
        model: parse_model(&model_str, false).unwrap(),
        estimation: EstimationMethod::Dwls,
        gc: GcMode::Standard,
        max_iter: 500,
        smooth_check: false,
        snp_se: None,
        variant_label: VariantLabel::Snp,
        q_snp: false,
        fix_measurement: true,
        num_threads: None,
    };
    let beta_refs: Vec<&[f64]> = beta_snp.iter().map(|b| b.as_slice()).collect();
    let se_refs: Vec<&[f64]> = se_snp.iter().map(|e| e.as_slice()).collect();
    let rust_results = run_user_gwas(
        &cfg, &s, &v, &i_mat, &beta_refs, &se_refs, &var_snp, None,
    );
    assert_eq!(rust_results.len(), snp_ids.len(), "SNP count mismatch");

    let r_user = fix["user_gwas"].as_array().unwrap();
    let mut n_compared = 0;
    // Non-converged SNPs are skipped entirely; `idx` still refers to the
    // position in the full result list.
    let converged = rust_results
        .iter()
        .enumerate()
        .filter(|(_, res)| res.converged);
    for (idx, rust_res) in converged {
        let r_row = &r_user[idx];
        let r_snp = r_row["SNP"].as_str().unwrap();
        assert_eq!(snp_ids[idx], r_snp, "SNP order mismatch at {idx}");
        let r_est = r_row["est"].as_f64().unwrap();

        let lookup = rust_res
            .params
            .iter()
            .find(|p| p.op == Op::Regression && p.lhs == "F1" && p.rhs == "SNP");
        let snp_param = match lookup {
            Some(param) => param,
            None => panic!("No F1~SNP parameter for SNP {r_snp}"),
        };

        // Compared on magnitude only, so a sign flip is accepted.
        let est_diff = (snp_param.est.abs() - r_est.abs()).abs();
        assert!(
            est_diff < 0.01,
            "userGWAS |est| for {r_snp}: Rust={:.6} R={r_est:.6} diff={est_diff:.6}",
            snp_param.est
        );

        // The chi-square check is skipped when the reference value is not a number.
        if let Some(r_chisq) = r_row["chisq"].as_f64() {
            let chisq_diff = (rust_res.chisq - r_chisq).abs();
            assert!(
                chisq_diff < 0.5,
                "userGWAS chisq for {r_snp}: Rust={:.4} R={r_chisq:.4} diff={chisq_diff:.4}",
                rust_res.chisq
            );
        }
        n_compared += 1;
    }
    assert!(n_compared >= 15, "Too few SNPs converged ({n_compared}/20)");
}