use faer::Mat;
use gsem_sem::EstimationMethod;
use gsem_sem::estimator;
use gsem_sem::model::Model;
use gsem_sem::sandwich;
use gsem_sem::syntax;
#[derive(Debug, Clone)]
pub struct MultiSnpConfig {
pub model: syntax::ParTable,
pub estimation: EstimationMethod,
pub max_iter: usize,
pub snp_var_se: Option<f64>,
}
#[derive(Debug, Clone)]
pub struct MultiSnpResult {
pub params: Vec<super::user_gwas::SnpParamResult>,
pub chisq: f64,
pub chisq_df: usize,
pub converged: bool,
}
#[allow(clippy::too_many_arguments)]
pub fn run_multi_snp(
config: &MultiSnpConfig,
s_ld: &Mat<f64>,
v_ld: &Mat<f64>,
beta: &[&[f64]], se: &[&[f64]], var_snp: &[f64], ld_matrix: &Mat<f64>, snp_names: &[String],
) -> MultiSnpResult {
let k = s_ld.nrows();
let n_snps = snp_names.len();
let total = n_snps + k;
let mut s_full = Mat::zeros(total, total);
for i in 0..n_snps {
s_full[(i, i)] = var_snp[i];
for j in (i + 1)..n_snps {
let cov = ld_matrix[(i, j)] * (var_snp[i] * var_snp[j]).sqrt();
s_full[(i, j)] = cov;
s_full[(j, i)] = cov;
}
}
for snp_i in 0..n_snps {
for t in 0..k {
let cov = beta[snp_i][t] * var_snp[snp_i];
s_full[(snp_i, n_snps + t)] = cov;
s_full[(n_snps + t, snp_i)] = cov;
}
}
for i in 0..k {
for j in 0..k {
s_full[(n_snps + i, n_snps + j)] = s_ld[(i, j)];
}
}
let kstar_full = total * (total + 1) / 2;
let mut v_full = Mat::zeros(kstar_full, kstar_full);
let kstar_trait = k * (k + 1) / 2;
let trait_vech_offset = kstar_full - kstar_trait;
for i in 0..kstar_trait {
for j in 0..kstar_trait {
if trait_vech_offset + i < kstar_full && trait_vech_offset + j < kstar_full {
v_full[(trait_vech_offset + i, trait_vech_offset + j)] = v_ld[(i, j)];
}
}
}
let snp_var_se = config.snp_var_se.unwrap_or(0.0005_f64).powi(2);
for i in 0..n_snps {
let idx = vech_index(i, i, total);
if idx < kstar_full {
v_full[(idx, idx)] = snp_var_se;
}
}
for (snp_i, var_snp_i) in var_snp.iter().enumerate().take(n_snps) {
for t in 0..k {
let row_idx = n_snps + t;
let col_idx = snp_i;
let (r, c) = if row_idx >= col_idx {
(row_idx, col_idx)
} else {
(col_idx, row_idx)
};
let idx = vech_index(r, c, total);
if idx < kstar_full {
let se_val = se
.get(snp_i)
.and_then(|s| s.get(t))
.copied()
.unwrap_or(f64::NAN);
v_full[(idx, idx)] = (se_val * var_snp_i).powi(2);
}
}
}
for i in 0..n_snps {
for j in (i + 1)..n_snps {
let idx = vech_index(j, i, total);
if idx < kstar_full {
v_full[(idx, idx)] = snp_var_se;
}
}
}
let mut obs_names: Vec<String> = snp_names.to_vec();
obs_names.extend((0..k).map(|i| format!("V{}", i + 1)));
let pt = &config.model;
let mut model = Model::from_partable(pt, &obs_names);
let v_diag: Vec<f64> = (0..kstar_full).map(|i| v_full[(i, i)]).collect();
let fit = match config.estimation {
EstimationMethod::Ml => estimator::fit_ml(&mut model, &s_full, config.max_iter, None),
EstimationMethod::Dwls => {
estimator::fit_dwls(&mut model, &s_full, &v_diag, config.max_iter, None)
}
};
let w_diag = Mat::from_fn(kstar_full, kstar_full, |i, j| {
if i == j && v_diag[i] > 1e-30 {
1.0 / v_diag[i]
} else {
0.0
}
});
let (se_vec, _) = sandwich::sandwich_se(&mut model, &w_diag, &v_full);
use statrs::distribution::ContinuousCDF;
let params: Vec<super::user_gwas::SnpParamResult> = pt
.rows
.iter()
.enumerate()
.filter(|(_, row)| row.free > 0)
.enumerate()
.map(|(free_idx, (_, row))| {
let est = fit.params.get(free_idx).copied().unwrap_or(f64::NAN);
let se_val = se_vec.get(free_idx).copied().unwrap_or(f64::NAN);
let z = est / se_val;
let p = if z.is_finite() {
2.0 * statrs::distribution::Normal::standard().cdf(-z.abs())
} else {
f64::NAN
};
super::user_gwas::SnpParamResult {
lhs: row.lhs.clone(),
op: row.op,
rhs: row.rhs.clone(),
est,
se: se_val,
z_stat: z,
p_value: p,
}
})
.collect();
MultiSnpResult {
params,
chisq: fit.objective,
chisq_df: model.df(),
converged: fit.converged,
}
}
fn vech_index(row: usize, col: usize, p: usize) -> usize {
debug_assert!(row >= col, "vech_index requires row >= col");
let col_offset = if col == 0 {
0
} else {
col * p - col * (col - 1) / 2
};
col_offset + (row - col)
}
pub fn read_ld_matrix(path: &std::path::Path) -> anyhow::Result<(Mat<f64>, Option<Vec<String>>)> {
use std::io::BufRead;
let file = std::fs::File::open(path)
.map_err(|e| anyhow::anyhow!("failed to open LD matrix {}: {e}", path.display()))?;
let reader = std::io::BufReader::new(file);
let mut lines: Vec<String> = Vec::new();
for line in reader.lines() {
let line = line?;
let trimmed = line.trim();
if !trimmed.is_empty() {
lines.push(trimmed.to_string());
}
}
if lines.is_empty() {
anyhow::bail!("LD matrix file is empty");
}
let first_fields: Vec<&str> = lines[0].split('\t').collect();
let has_header = first_fields[0].parse::<f64>().is_err();
let (header, data_lines) = if has_header {
let names: Vec<String> = first_fields.iter().map(|s| s.to_string()).collect();
(Some(names), &lines[1..])
} else {
(None, &lines[..])
};
let n = data_lines.len();
let mut mat = Mat::zeros(n, n);
for (i, line) in data_lines.iter().enumerate() {
let fields: Vec<&str> = line.split('\t').collect();
if fields.len() != n {
anyhow::bail!(
"LD matrix row {} has {} columns, expected {n}",
i + 1,
fields.len()
);
}
for (j, field) in fields.iter().enumerate() {
mat[(i, j)] = field
.parse::<f64>()
.map_err(|e| anyhow::anyhow!("LD matrix [{i},{j}]: {e}"))?;
}
}
Ok((mat, header))
}
#[cfg(test)]
mod tests {
use super::*;
use faer::Mat;
use std::io::Write;
#[test]
fn test_vech_index_known_values() {
assert_eq!(vech_index(0, 0, 3), 0);
assert_eq!(vech_index(1, 0, 3), 1);
assert_eq!(vech_index(2, 0, 3), 2);
assert_eq!(vech_index(1, 1, 3), 3);
assert_eq!(vech_index(2, 1, 3), 4);
assert_eq!(vech_index(2, 2, 3), 5);
}
#[test]
fn test_vech_index_4x4() {
assert_eq!(vech_index(0, 0, 4), 0);
assert_eq!(vech_index(3, 0, 4), 3);
assert_eq!(vech_index(1, 1, 4), 4);
assert_eq!(vech_index(3, 3, 4), 9);
}
#[test]
fn test_read_ld_matrix_no_header() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("ld.txt");
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "1.0\t0.3").unwrap();
writeln!(f, "0.3\t1.0").unwrap();
drop(f);
let (mat, header) = read_ld_matrix(&path).unwrap();
assert!(header.is_none());
assert_eq!(mat.nrows(), 2);
assert_eq!(mat.ncols(), 2);
assert!((mat[(0, 0)] - 1.0).abs() < 1e-10);
assert!((mat[(0, 1)] - 0.3).abs() < 1e-10);
assert!((mat[(1, 0)] - 0.3).abs() < 1e-10);
}
#[test]
fn test_read_ld_matrix_with_header() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("ld_h.txt");
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "SNP1\tSNP2").unwrap();
writeln!(f, "1.0\t0.5").unwrap();
writeln!(f, "0.5\t1.0").unwrap();
drop(f);
let (mat, header) = read_ld_matrix(&path).unwrap();
assert!(header.is_some());
let names = header.unwrap();
assert_eq!(names, vec!["SNP1", "SNP2"]);
assert_eq!(mat.nrows(), 2);
assert!((mat[(0, 1)] - 0.5).abs() < 1e-10);
}
#[test]
fn test_read_ld_matrix_empty_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("empty.txt");
std::fs::File::create(&path).unwrap();
assert!(read_ld_matrix(&path).is_err());
}
#[test]
fn test_run_multi_snp_basic() {
let s_ld = faer::mat![[0.5, 0.2], [0.2, 0.4]];
let v_ld = Mat::from_fn(3, 3, |i, j| if i == j { 0.001 } else { 0.0 });
let beta = vec![vec![0.1, 0.05], vec![0.08, 0.12]];
let se = vec![vec![0.02, 0.02], vec![0.02, 0.02]];
let var_snp = vec![0.3, 0.25];
let ld_matrix = faer::mat![[1.0, 0.3], [0.3, 1.0]];
let snp_names = vec!["SNP1".to_string(), "SNP2".to_string()];
let config = MultiSnpConfig {
model: syntax::parse_model("F1 =~ NA*V1 + V2\nF1 ~~ 1*F1\nV1 ~~ V1\nV2 ~~ V2\nF1 ~ SNP1 + SNP2\nSNP1 ~~ SNP1\nSNP2 ~~ SNP2", false).unwrap(),
estimation: EstimationMethod::Dwls,
max_iter: 500,
snp_var_se: None,
};
let beta_refs: Vec<&[f64]> = beta.iter().map(Vec::as_slice).collect();
let se_refs: Vec<&[f64]> = se.iter().map(Vec::as_slice).collect();
let result = run_multi_snp(
&config, &s_ld, &v_ld, &beta_refs, &se_refs, &var_snp, &ld_matrix, &snp_names,
);
assert!(!result.params.is_empty(), "should have parameter estimates");
assert!(
result.chisq.is_finite() || result.chisq.is_nan(),
"chisq should be a number"
);
}
}