1#[cfg(feature = "ml")]
25use ndarray::Array2;
26#[cfg(feature = "ml")]
27use ort::{
28 session::{builder::GraphOptimizationLevel, Session},
29 value::TensorRef,
30};
31
32use std::collections::HashMap;
33use std::path::Path;
34
35use kam_call::caller::VariantCall;
36#[cfg(feature = "ml")]
37use kam_call::caller::VariantType;
38
/// Model metadata deserialized from the JSON file given to [`MlScorer::load`].
#[derive(Debug, serde::Deserialize)]
pub struct ModelMeta {
    /// Feature names, in the order the model input row is assembled: the scorer
    /// emits exactly one value per entry, in this order.
    pub feature_names: Vec<String>,
    /// NOTE(review): presumably the minimum ML score for a call to pass; this
    /// field is not read anywhere in this file, so confirm against the callers.
    pub ml_pass_threshold: f64,
    /// Maps a variant-type label (e.g. `"SNV"`, `"Insertion"`) to the integer
    /// code used for the `variant_class_enc` feature. Labels that are missing
    /// from the map are encoded as `0` by the scorer.
    pub variant_class_map: HashMap<String, i32>,
}
49
/// Scores variant calls with a loaded model and its accompanying metadata.
pub struct MlScorer {
    /// `ort` inference session; the field only exists when the crate is built
    /// with the `ml` feature.
    #[cfg(feature = "ml")]
    session: Session,
    /// Metadata describing the model's feature layout and class encoding.
    pub meta: ModelMeta,
}
56
57impl MlScorer {
58 pub fn load(model_path: &Path, meta_path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
64 let meta_bytes = std::fs::read(meta_path)?;
65 #[cfg(feature = "ml")]
66 let meta: ModelMeta = serde_json::from_slice(&meta_bytes)?;
67 #[cfg(feature = "ml")]
68 {
69 let session = Session::builder()?
70 .with_optimization_level(GraphOptimizationLevel::Level3)?
71 .with_intra_threads(1)?
72 .commit_from_file(model_path)?;
73 Ok(Self { session, meta })
74 }
75
76 #[cfg(not(feature = "ml"))]
77 {
78 let _ = (model_path, meta_bytes);
79 Err("kam-ml was built without the 'ml' feature — recompile with --features ml".into())
80 }
81 }
82
83 #[cfg(feature = "ml")]
87 pub fn score(&mut self, call: &VariantCall) -> Option<f32> {
88 let features = self.extract_features(call);
89 let n = self.meta.feature_names.len();
90 if features.len() != n {
91 return None;
92 }
93
94 let array = Array2::from_shape_vec((1, n), features).ok()?;
95 let tensor = TensorRef::from_array_view(array.view()).ok()?;
96 let outputs = self.session.run(ort::inputs![tensor]).ok()?;
97
98 let probs = outputs
101 .get("probabilities")
102 .or_else(|| outputs.get("output_probability"))?;
103 let (_shape, data) = probs.try_extract_tensor::<f32>().ok()?;
104 data.get(1).copied()
106 }
107
108 #[cfg(not(feature = "ml"))]
110 pub fn score(&mut self, _call: &VariantCall) -> Option<f32> {
111 None
112 }
113
114 #[cfg(feature = "ml")]
116 fn extract_features(&self, call: &VariantCall) -> Vec<f32> {
117 let vaf = call.vaf as f32;
118 let nref = call.n_molecules_ref as f32;
119 let nalt = call.n_molecules_alt as f32;
120 let ndupalt = call.n_duplex_alt as f32;
121 let nsimalt = call.n_simplex_alt as f32;
122 let sbp = call.strand_bias_p as f32;
123 let conf = call.confidence as f32;
124 let ref_len = call.ref_sequence.len() as f32;
125 let alt_len = call.alt_sequence.len() as f32;
126
127 let duplex_frac = ndupalt / (nalt + 1e-9);
128 let has_duplex = if ndupalt > 0.0 { 1.0_f32 } else { 0.0_f32 };
129 let ci_width = (call.vaf_ci_high - call.vaf_ci_low) as f32;
130 let alt_depth = nref + nalt;
131
132 let log_nalt = (nalt + 1.0).ln();
133 let log_nref = (nref + 1.0).ln();
134 let log_alt_depth = (alt_depth + 1.0).ln();
135 let log_vaf = (vaf + 1e-6).ln();
136
137 let vaf_times_conf = vaf * conf;
138 let vaf_times_nalt = vaf * nalt;
139 let nalt_over_conf = nalt / (conf + 1e-9);
140 let ci_width_rel = ci_width / (vaf + 1e-9);
141 let snr = nalt / (nref + 1.0);
142
143 let conf_sq = conf * conf;
144 let nalt_sq = nalt * nalt;
145 let vaf_sq = vaf * vaf;
146
147 let ref_alt_len_ratio = ref_len / (alt_len + 1.0);
148 let indel_size = (ref_len - alt_len).abs();
149
150 let duplex_enrichment = ndupalt / (vaf * alt_depth + 1e-9);
151 let simplex_only_frac = nsimalt / (nalt + 1e-9);
152
153 let conf_above_99 = if conf > 0.99 { 1.0_f32 } else { 0.0_f32 };
154 let conf_above_999 = if conf > 0.999 { 1.0_f32 } else { 0.0_f32 };
155 let sbp_above_05 = if sbp > 0.05 { 1.0_f32 } else { 0.0_f32 };
156
157 let variant_class_enc = *self
158 .meta
159 .variant_class_map
160 .get(variant_type_str(call.variant_type))
161 .unwrap_or(&0) as f32;
162
163 let lookup: HashMap<&str, f32> = [
164 ("vaf", vaf),
165 ("nref", nref),
166 ("nalt", nalt),
167 ("ndupalt", ndupalt),
168 ("nsimalt", nsimalt),
169 ("sbp", sbp),
170 ("conf", conf),
171 ("ref_len", ref_len),
172 ("alt_len", alt_len),
173 ("duplex_frac", duplex_frac),
174 ("has_duplex", has_duplex),
175 ("ci_width", ci_width),
176 ("alt_depth", alt_depth),
177 ("log_nalt", log_nalt),
178 ("log_nref", log_nref),
179 ("log_alt_depth", log_alt_depth),
180 ("log_vaf", log_vaf),
181 ("vaf_times_conf", vaf_times_conf),
182 ("vaf_times_nalt", vaf_times_nalt),
183 ("nalt_over_conf", nalt_over_conf),
184 ("ci_width_rel", ci_width_rel),
185 ("snr", snr),
186 ("conf_sq", conf_sq),
187 ("nalt_sq", nalt_sq),
188 ("vaf_sq", vaf_sq),
189 ("ref_alt_len_ratio", ref_alt_len_ratio),
190 ("indel_size", indel_size),
191 ("duplex_enrichment", duplex_enrichment),
192 ("simplex_only_frac", simplex_only_frac),
193 ("conf_above_99", conf_above_99),
194 ("conf_above_999", conf_above_999),
195 ("sbp_above_05", sbp_above_05),
196 ("variant_class_enc", variant_class_enc),
197 ]
198 .into_iter()
199 .collect();
200
201 self.meta
202 .feature_names
203 .iter()
204 .map(|name| *lookup.get(name.as_str()).unwrap_or(&0.0))
205 .collect()
206 }
207}
208
/// Returns the label under which a variant type is looked up in the model's
/// `variant_class_map`.
///
/// The match is exhaustive on purpose: adding a `VariantType` variant must
/// fail to compile here until it is given a label.
#[cfg(feature = "ml")]
fn variant_type_str(kind: VariantType) -> &'static str {
    match kind {
        // Upper-case abbreviations.
        VariantType::Snv => "SNV",
        VariantType::Mnv => "MNV",

        // Labels identical to the variant name.
        VariantType::Insertion => "Insertion",
        VariantType::Deletion => "Deletion",
        VariantType::Complex => "Complex",
        VariantType::LargeDeletion => "LargeDeletion",
        VariantType::TandemDuplication => "TandemDuplication",
        VariantType::Inversion => "Inversion",
        VariantType::InvDel => "InvDel",
        VariantType::Fusion => "Fusion",
        VariantType::NovelInsertion => "NovelInsertion",
    }
}