Skip to main content

kam_ml/
lib.rs

1//! Optional ML scoring via ONNX model.
2//!
3//! When an ONNX model and companion `model_meta.json` are provided,
4//! each [`VariantCall`] is scored by a gradient-boosted classifier
5//! trained on simulated training data. The resulting probability
6//! is stored in [`VariantCall::ml_prob`].
7//!
8//! Compile with `--features ml` to enable ONNX inference. Without the
9//! feature, [`MlScorer::load`] returns an error and [`MlScorer::score`]
10//! always returns `None`.
11//!
12//! # Example
13//!
14//! ```no_run
15//! use std::path::Path;
16//! use kam_ml::MlScorer;
17//!
18//! let scorer = MlScorer::load(
19//!     Path::new("model.onnx"),
20//!     Path::new("model_meta.json"),
21//! ).expect("load model");
22//! ```
23
24#[cfg(feature = "ml")]
25use ndarray::Array2;
26#[cfg(feature = "ml")]
27use ort::{
28    session::{builder::GraphOptimizationLevel, Session},
29    value::TensorRef,
30};
31
32use std::collections::HashMap;
33use std::path::Path;
34
35use kam_call::caller::VariantCall;
36#[cfg(feature = "ml")]
37use kam_call::caller::VariantType;
38
39/// Metadata loaded alongside the ONNX model.
40#[derive(Debug, serde::Deserialize)]
41pub struct ModelMeta {
42    /// Feature names in the exact order the model expects.
43    pub feature_names: Vec<String>,
44    /// Probability threshold above which a call is labelled ML_PASS.
45    pub ml_pass_threshold: f64,
46    /// Mapping from variant type string to integer encoding.
47    pub variant_class_map: HashMap<String, i32>,
48}
49
50/// Holds a loaded ONNX session and associated metadata.
51pub struct MlScorer {
52    #[cfg(feature = "ml")]
53    session: Session,
54    pub meta: ModelMeta,
55}
56
57impl MlScorer {
58    /// Load an ONNX model and its companion metadata file.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if either file cannot be read or parsed.
63    pub fn load(model_path: &Path, meta_path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
64        let meta_bytes = std::fs::read(meta_path)?;
65        #[cfg(feature = "ml")]
66        let meta: ModelMeta = serde_json::from_slice(&meta_bytes)?;
67        #[cfg(feature = "ml")]
68        {
69            let session = Session::builder()?
70                .with_optimization_level(GraphOptimizationLevel::Level3)?
71                .with_intra_threads(1)?
72                .commit_from_file(model_path)?;
73            Ok(Self { session, meta })
74        }
75
76        #[cfg(not(feature = "ml"))]
77        {
78            let _ = (model_path, meta_bytes);
79            Err("kam-ml was built without the 'ml' feature — recompile with --features ml".into())
80        }
81    }
82
83    /// Score a single [`VariantCall`] and return the ML probability (class 1).
84    ///
85    /// Returns `None` if feature extraction or inference fails.
86    #[cfg(feature = "ml")]
87    pub fn score(&mut self, call: &VariantCall) -> Option<f32> {
88        let features = self.extract_features(call);
89        let n = self.meta.feature_names.len();
90        if features.len() != n {
91            return None;
92        }
93
94        let array = Array2::from_shape_vec((1, n), features).ok()?;
95        let tensor = TensorRef::from_array_view(array.view()).ok()?;
96        let outputs = self.session.run(ort::inputs![tensor]).ok()?;
97
98        // LightGBM ONNX output: "label" (i64) and "probabilities" (shape [N, 2]).
99        // Index 1 of the second axis is P(positive).
100        let probs = outputs
101            .get("probabilities")
102            .or_else(|| outputs.get("output_probability"))?;
103        let (_shape, data) = probs.try_extract_tensor::<f32>().ok()?;
104        // data has layout [n_samples * n_classes]; index 1 is P(class 1) for sample 0.
105        data.get(1).copied()
106    }
107
108    /// Score a single [`VariantCall`]. Returns `None` when the ml feature is not compiled in.
109    #[cfg(not(feature = "ml"))]
110    pub fn score(&mut self, _call: &VariantCall) -> Option<f32> {
111        None
112    }
113
114    /// Build the feature vector for a call in the order `meta.feature_names` specifies.
115    #[cfg(feature = "ml")]
116    fn extract_features(&self, call: &VariantCall) -> Vec<f32> {
117        let vaf = call.vaf as f32;
118        let nref = call.n_molecules_ref as f32;
119        let nalt = call.n_molecules_alt as f32;
120        let ndupalt = call.n_duplex_alt as f32;
121        let nsimalt = call.n_simplex_alt as f32;
122        let sbp = call.strand_bias_p as f32;
123        let conf = call.confidence as f32;
124        let ref_len = call.ref_sequence.len() as f32;
125        let alt_len = call.alt_sequence.len() as f32;
126
127        let duplex_frac = ndupalt / (nalt + 1e-9);
128        let has_duplex = if ndupalt > 0.0 { 1.0_f32 } else { 0.0_f32 };
129        let ci_width = (call.vaf_ci_high - call.vaf_ci_low) as f32;
130        let alt_depth = nref + nalt;
131
132        let log_nalt = (nalt + 1.0).ln();
133        let log_nref = (nref + 1.0).ln();
134        let log_alt_depth = (alt_depth + 1.0).ln();
135        let log_vaf = (vaf + 1e-6).ln();
136
137        let vaf_times_conf = vaf * conf;
138        let vaf_times_nalt = vaf * nalt;
139        let nalt_over_conf = nalt / (conf + 1e-9);
140        let ci_width_rel = ci_width / (vaf + 1e-9);
141        let snr = nalt / (nref + 1.0);
142
143        let conf_sq = conf * conf;
144        let nalt_sq = nalt * nalt;
145        let vaf_sq = vaf * vaf;
146
147        let ref_alt_len_ratio = ref_len / (alt_len + 1.0);
148        let indel_size = (ref_len - alt_len).abs();
149
150        let duplex_enrichment = ndupalt / (vaf * alt_depth + 1e-9);
151        let simplex_only_frac = nsimalt / (nalt + 1e-9);
152
153        let conf_above_99 = if conf > 0.99 { 1.0_f32 } else { 0.0_f32 };
154        let conf_above_999 = if conf > 0.999 { 1.0_f32 } else { 0.0_f32 };
155        let sbp_above_05 = if sbp > 0.05 { 1.0_f32 } else { 0.0_f32 };
156
157        let variant_class_enc = *self
158            .meta
159            .variant_class_map
160            .get(variant_type_str(call.variant_type))
161            .unwrap_or(&0) as f32;
162
163        let lookup: HashMap<&str, f32> = [
164            ("vaf", vaf),
165            ("nref", nref),
166            ("nalt", nalt),
167            ("ndupalt", ndupalt),
168            ("nsimalt", nsimalt),
169            ("sbp", sbp),
170            ("conf", conf),
171            ("ref_len", ref_len),
172            ("alt_len", alt_len),
173            ("duplex_frac", duplex_frac),
174            ("has_duplex", has_duplex),
175            ("ci_width", ci_width),
176            ("alt_depth", alt_depth),
177            ("log_nalt", log_nalt),
178            ("log_nref", log_nref),
179            ("log_alt_depth", log_alt_depth),
180            ("log_vaf", log_vaf),
181            ("vaf_times_conf", vaf_times_conf),
182            ("vaf_times_nalt", vaf_times_nalt),
183            ("nalt_over_conf", nalt_over_conf),
184            ("ci_width_rel", ci_width_rel),
185            ("snr", snr),
186            ("conf_sq", conf_sq),
187            ("nalt_sq", nalt_sq),
188            ("vaf_sq", vaf_sq),
189            ("ref_alt_len_ratio", ref_alt_len_ratio),
190            ("indel_size", indel_size),
191            ("duplex_enrichment", duplex_enrichment),
192            ("simplex_only_frac", simplex_only_frac),
193            ("conf_above_99", conf_above_99),
194            ("conf_above_999", conf_above_999),
195            ("sbp_above_05", sbp_above_05),
196            ("variant_class_enc", variant_class_enc),
197        ]
198        .into_iter()
199        .collect();
200
201        self.meta
202            .feature_names
203            .iter()
204            .map(|name| *lookup.get(name.as_str()).unwrap_or(&0.0))
205            .collect()
206    }
207}
208
/// Return the string key under which `vt` is looked up in
/// [`ModelMeta::variant_class_map`].
///
/// The match is exhaustive on purpose: adding a variant to [`VariantType`]
/// must fail to compile here until a key is chosen for it.
#[cfg(feature = "ml")]
fn variant_type_str(vt: VariantType) -> &'static str {
    match vt {
        // Keys written as upper-case abbreviations.
        VariantType::Snv => "SNV",
        VariantType::Mnv => "MNV",

        // Keys that repeat the variant's own name.
        VariantType::Insertion => "Insertion",
        VariantType::Deletion => "Deletion",
        VariantType::Complex => "Complex",
        VariantType::LargeDeletion => "LargeDeletion",
        VariantType::TandemDuplication => "TandemDuplication",
        VariantType::Inversion => "Inversion",
        VariantType::Fusion => "Fusion",
        VariantType::InvDel => "InvDel",
        VariantType::NovelInsertion => "NovelInsertion",
    }
}