use faer::Mat;
use gsem_sem::EstimationMethod;
use gsem_sem::syntax;
use super::gc_correction::GcMode;
use super::user_gwas::{self, SnpResult, UserGwasConfig};
/// How the scale of the common factor `F1` is identified in the auto-generated model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Identification {
    /// Frees the first loading (`NA*`) and fixes the factor variance to 1 (`F1 ~~ 1*F1`).
    #[default]
    FixedVariance,
    /// Emits the loadings without any prefix and adds no variance constraint.
    /// NOTE(review): relies on the model parser fixing the first loading by
    /// default (marker-indicator convention) — confirm against `gsem_sem::syntax`.
    MarkerIndicator,
}

impl Identification {
    /// Parses an identification strategy from a user-supplied string.
    ///
    /// Matching ignores surrounding whitespace and ASCII case. `"marker"`,
    /// `"marker_indicator"`, `"marker-indicator"` and `"mi"` select
    /// [`Identification::MarkerIndicator`]. Every other input, including the
    /// empty string, falls back to [`Identification::FixedVariance`]. The
    /// parse is deliberately lossy and never fails.
    pub fn from_str_lossy(s: &str) -> Self {
        const MARKER_ALIASES: [&str; 4] =
            ["marker", "marker_indicator", "marker-indicator", "mi"];

        // Trim first so padded input (e.g. from a CLI or config file) does not
        // silently fall through to the default.
        let s = s.trim();
        if MARKER_ALIASES
            .iter()
            .any(|alias| s.eq_ignore_ascii_case(alias))
        {
            Self::MarkerIndicator
        } else {
            Self::FixedVariance
        }
    }
}
/// Options for a common-factor GWAS run.
///
/// All fields except `identification` are copied verbatim into the
/// `UserGwasConfig` built by `run_common_factor_gwas`. `identification` only
/// decides which model syntax is generated.
pub struct CommonFactorGwasConfig {
    /// Estimator forwarded to `UserGwasConfig::estimation`.
    pub estimation: EstimationMethod,
    /// Correction mode (from `gc_correction`) forwarded to `UserGwasConfig::gc`.
    pub gc: GcMode,
    /// Forwarded to `UserGwasConfig::snp_se`; `None` by default.
    pub snp_se: Option<f64>,
    /// Forwarded to `UserGwasConfig::smooth_check`.
    pub smooth_check: bool,
    /// Scale-identification strategy for the generated factor model.
    pub identification: Identification,
    /// Forwarded to `UserGwasConfig::fix_measurement`.
    pub fix_measurement: bool,
    /// Forwarded to `UserGwasConfig::num_threads`; `None` by default.
    pub num_threads: Option<usize>,
}

impl Default for CommonFactorGwasConfig {
    /// Returns DWLS estimation, `GcMode::Standard`, the default
    /// [`Identification`], `fix_measurement = true`, `smooth_check = false`,
    /// and no `snp_se` / `num_threads` values.
    fn default() -> Self {
        let identification = Identification::default();
        CommonFactorGwasConfig {
            identification,
            estimation: EstimationMethod::Dwls,
            gc: GcMode::Standard,
            fix_measurement: true,
            smooth_check: false,
            snp_se: None,
            num_threads: None,
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn run_common_factor_gwas(
trait_names: &[String],
s_ld: &Mat<f64>,
v_ld: &Mat<f64>,
i_ld: &Mat<f64>,
beta_snp: &[&[f64]],
se_snp: &[&[f64]],
var_snp: &[f64],
cfg: &CommonFactorGwasConfig,
on_snp_done: Option<&(dyn Fn() + Sync)>,
) -> Vec<SnpResult> {
let model_str = match cfg.identification {
Identification::FixedVariance => {
let loading = std::iter::once(format!("NA*{}", trait_names[0]))
.chain(trait_names[1..].iter().cloned())
.collect::<Vec<_>>()
.join(" + ");
format!("F1 =~ {loading}\nF1 ~ SNP\nF1 ~~ 1*F1")
}
Identification::MarkerIndicator => {
let loading = trait_names.join(" + ");
format!("F1 =~ {loading}\nF1 ~ SNP")
}
};
let model =
syntax::parse_model(&model_str, false).expect("auto-generated model syntax is invalid");
let config = UserGwasConfig {
model,
estimation: cfg.estimation,
gc: cfg.gc,
max_iter: 500,
smooth_check: cfg.smooth_check,
snp_se: cfg.snp_se,
variant_label: user_gwas::VariantLabel::Snp,
q_snp: false,
fix_measurement: cfg.fix_measurement,
num_threads: cfg.num_threads,
};
user_gwas::run_user_gwas(
&config,
s_ld,
v_ld,
i_ld,
beta_snp,
se_snp,
var_snp,
on_snp_done,
)
}