1use crate::cli::FireOptions;
2use crate::fiber::FiberseqData;
3use crate::*;
4use anyhow;
5use derive_builder::Builder;
6use gbdt::decision_tree::{Data, DataVec};
7use gbdt::gradient_boost::GBDT;
8use itertools::Itertools;
9use ordered_float::OrderedFloat;
10use serde::Deserialize;
11use std::collections::BTreeMap;
12use std::fs;
13use tempfile::NamedTempFile;
14
/// Built-in FIRE GBDT model (XGBoost dump JSON), embedded into the binary at compile time.
/// Used by `get_model` when no custom model/FDR table pair is supplied.
pub static FIRE_MODEL: &str = include_str!("../../models/FIRE.gbdt.json");
/// Built-in precision table for `FIRE_MODEL`, embedded at compile time; `get_model`
/// deserializes it into a `PrecisionTable`.
pub static FIRE_CONF_JSON: &str = include_str!("../../models/FIRE.conf.json");
17
18pub fn get_model(fire_opts: &FireOptions) -> (GBDT, MapPrecisionValues) {
19 let mut remove_temp_file = false;
20 let (model_file, fdr_table_file) = match (&fire_opts.model, &fire_opts.fdr_table) {
22 (Some(model_file), Some(b)) => {
23 let fdr_table = fs::read_to_string(b).expect("Unable to read file");
24 (model_file.clone(), fdr_table)
25 }
26 _ => {
27 let temp_file = NamedTempFile::new().expect("Unable to make a temp file");
28 let (mut temp_file, path) = temp_file.keep().expect("Unable to keep temp file");
29 let temp_file_name = path
30 .as_os_str()
31 .to_str()
32 .expect("Unable to convert the path of the named temp file to an &str.");
33 temp_file
34 .write_all(FIRE_MODEL.as_bytes())
35 .expect("Unable to write file");
36 remove_temp_file = true;
38 (temp_file_name.to_string(), FIRE_CONF_JSON.to_string())
39 }
40 };
41 log::info!("Using model: {}", model_file);
42 let model =
44 GBDT::from_xgboost_dump(&model_file, "binary:logistic").expect("failed to load FIRE model");
45 if remove_temp_file {
46 fs::remove_file(model_file).expect("Unable to remove temp file");
47 }
48
49 let precision_table: PrecisionTable =
51 serde_json::from_str(&fdr_table_file).expect("Precision table JSON was not well-formatted");
52 let precision_converter = MapPrecisionValues::new(&precision_table);
53
54 (model, precision_converter)
56}
57
/// Integer midpoint of `start` and `end` (`(start + end) / 2`, truncating toward zero).
fn get_mid_point(start: i64, end: i64) -> i64 {
    let span_sum = start + end;
    span_sum / 2
}
61
/// Lay out `bin_num` adjacent bins of width `bin_width` centred on `mid_point`.
///
/// Bin `i` starts at `mid_point - (bin_num / 2 - i) * bin_width - bin_width / 2`. Each bin is
/// then clamped to the sequence:
/// - Starts are raised to 0, and a start beyond `max_end` becomes `max_end - 1`.
/// - Ends are raised to 0 and capped at `max_end`.
///
/// Returns `(start, end)` pairs in order.
pub fn get_bins(mid_point: i64, bin_num: i64, bin_width: i64, max_end: i64) -> Vec<(i64, i64)> {
    let half_count = bin_num / 2;
    let half_width = bin_width / 2;
    (0..bin_num)
        .map(|idx| {
            let raw_start = mid_point - (half_count - idx) * bin_width - half_width;
            let raw_end = raw_start + bin_width;

            let clamped_start = raw_start.max(0);
            let clamped_start = if clamped_start > max_end {
                max_end - 1
            } else {
                clamped_start
            };
            let clamped_end = raw_end.max(0).min(max_end);

            (clamped_start, clamped_end)
        })
        .collect()
}
88
89fn get_m6a_rle_data(rec: &FiberseqData, start: i64, end: i64) -> (f32, f32) {
91 let mut m6a_rles = vec![];
92 let mut max = 0;
93 let mut _max_pos = 0;
94 let mut weighted_rle = 0.0;
96 for (m6a_1, m6a_2) in rec.m6a.starts.iter().flatten().tuple_windows() {
97 if *m6a_1 < start || *m6a_1 > end || *m6a_2 < start || *m6a_2 > end {
99 continue;
100 }
101 let rle = (m6a_2 - m6a_1).abs();
103 m6a_rles.push(rle);
104 if rle > max {
106 max = rle;
107 _max_pos = ((*m6a_1 + *m6a_2) / 2 - (end + start) / 2).abs();
108 }
109 weighted_rle += (rle * rle) as f32;
111 }
112 weighted_rle /= (end - start) as f32;
113
114 if m6a_rles.is_empty() {
115 return (-1.0, -1.0);
116 }
117 let mid_length = m6a_rles.len() / 2;
118 let (_, median, _) = m6a_rles.select_nth_unstable(mid_length);
119 (weighted_rle, *median as f32)
120}
121
/// Per-window features written into the model feature vector, in column order.
///
/// `FireFeatsInRange::header` derives column names from this list, and
/// `FireFeats::msp_get_fire_features` pushes `m6a_count`, `frac_m6a` and `m6a_fc` in this same
/// order, so the two must stay in sync.
const FEATS_IN_USE: [&str; 3] = ["m6a_count", "frac_m6a", "m6a_fc"];
/// Features computed over one window of a read by `FireFeats::feats_in_range`.
///
/// Only the fields listed in `FEATS_IN_USE` (`m6a_count`, `frac_m6a`, `m6a_fc`) are written
/// into the model feature vector; the remaining fields are computed alongside them.
#[derive(Debug, Clone, Builder)]
pub struct FireFeatsInRange {
    /// Number of m6A calls in `[start, end)`. In ONT mode this includes calls imputed for the
    /// unsequenced base (see `FireFeats::get_m6a_count`).
    pub m6a_count: f32,
    /// Number of `A` plus `T` bases in the window.
    pub at_count: f32,
    /// Number of calls from `rec.cpg.starts` in `[start, end)`.
    // Not part of the model feature vector.
    #[allow(unused)]
    pub count_5mc: f32,
    /// `m6a_count / at_count`, or 0 when the window has no A/T bases.
    pub frac_m6a: f32,
    /// log2 of observed over expected m6A count. The expected count is the read-wide m6A
    /// fraction times `at_count`. The value is 0 when either count is zero.
    pub m6a_fc: f32,
    /// Sum of squared gaps between consecutive m6A calls in the window, divided by the window
    /// length; -1 when there are no such gaps.
    pub weighted_m6a_rle: f32,
    /// Upper-median gap between consecutive m6A calls in the window; -1 when there are none.
    pub median_m6a_rle: f32,
}
144
145impl FireFeatsInRange {
146 pub fn header(tag: &str) -> String {
147 let mut out = "".to_string();
148 for col in FEATS_IN_USE.iter() {
149 out += &format!("\t{}_{}", tag, col);
150 }
151 out
152 }
153}
154
155#[derive(Debug)]
156pub struct FireFeats<'a> {
157 rec: &'a FiberseqData,
158 #[allow(unused)]
159 at_count: usize,
160 m6a_count: usize,
161 frac_m6a: f32,
162 fire_opts: &'a FireOptions,
164 seq: Vec<u8>,
165 fire_feats: Vec<(i64, i64, Vec<f32>)>,
166}
167
168impl<'a> FireFeats<'a> {
169 pub fn new(rec: &'a FiberseqData, fire_opts: &'a FireOptions) -> Self {
170 let seq_len = rec.record.seq_len();
171 let seq = rec.record.seq().as_bytes();
172
173 let mut rtn = Self {
174 rec,
175 at_count: 0,
176 m6a_count: 0,
177 frac_m6a: 0.0,
178 fire_opts,
180 seq,
181 fire_feats: vec![],
182 };
183
184 rtn.at_count = rtn.get_at_count(0, seq_len as i64);
186 rtn.m6a_count = rtn.get_m6a_count(0, seq_len as i64);
187 rtn.frac_m6a = if rtn.at_count > 0 {
188 rtn.m6a_count as f32 / rtn.at_count as f32
189 } else {
190 0.0
191 };
192
193 rtn.get_fire_features();
194 if rtn.fire_opts.ont {
195 rtn.validate_that_ont_is_single_strand();
196 }
197 rtn
198 }
199
200 fn validate_that_ont_is_single_strand(&self) {
201 let sequenced_bp = if self.rec.record.is_reverse() {
202 b'T'
203 } else {
204 b'A'
205 };
206 for m6a_st in self.rec.m6a.starts.iter().flatten() {
207 let m6a_bp = self.seq[*m6a_st as usize];
208 if m6a_bp != sequenced_bp {
209 log::warn!(
210 "m6A site at {} is not the same as the sequenced base {}",
211 m6a_st,
212 sequenced_bp as char
213 );
214 }
215 }
216 }
217
218 fn get_bp_count(&self, start: i64, end: i64, bp: u8) -> usize {
219 let subseq = &self.seq[start as usize..end as usize];
220 subseq.iter().filter(|&&b| b == bp).count()
221 }
222
223 fn get_at_count(&self, start: i64, end: i64) -> usize {
224 self.get_bp_count(start, end, b'A') + self.get_bp_count(start, end, b'T')
225 }
226
227 fn get_5mc_count(&self, start: i64, end: i64) -> usize {
228 self.rec
229 .cpg
230 .starts
231 .iter()
232 .flatten()
233 .filter(|&&pos| pos >= start && pos < end)
234 .count()
235 }
236
237 fn get_m6a_count(&self, start: i64, end: i64) -> usize {
238 let mut m6a_count = self
239 .rec
240 .m6a
241 .starts
242 .iter()
243 .flatten()
244 .filter(|&&pos| pos >= start && pos < end)
245 .count();
246
247 if self.fire_opts.ont {
249 let mut sequenced_bp = self.get_bp_count(start, end, b'A');
250 let mut un_sequenced_bp = self.get_bp_count(start, end, b'T');
251 if self.rec.record.is_reverse() {
252 std::mem::swap(&mut sequenced_bp, &mut un_sequenced_bp);
254 }
255 let m6a_frac = if sequenced_bp > 0 {
256 m6a_count as f32 / sequenced_bp as f32
257 } else {
258 0.0
259 };
260 m6a_count += (un_sequenced_bp as f32 * m6a_frac).round() as usize;
261 }
262 m6a_count
263 }
264
265 fn m6a_fc_over_expected(&self, m6a_count: usize, at_count: usize) -> f32 {
266 let expected = self.frac_m6a * at_count as f32;
269 let observed = m6a_count as f32;
270 if expected == 0.0 || observed == 0.0 {
271 return 0.0;
272 }
273 let fc = observed / expected;
274 fc.log2()
275 }
276
277 fn feats_in_range(&self, start: i64, end: i64) -> FireFeatsInRange {
278 let m6a_count = self.get_m6a_count(start, end);
279 let at_count = self.get_at_count(start, end);
280 let count_5mc = self.get_5mc_count(start, end);
281 let frac_m6a = if at_count > 0 {
282 m6a_count as f32 / at_count as f32
283 } else {
284 0.0
285 };
286 let m6a_fc = self.m6a_fc_over_expected(m6a_count, at_count);
287 let (weighted_m6a_rle, median_m6a_rle) = get_m6a_rle_data(self.rec, start, end);
288
289 FireFeatsInRange {
290 m6a_count: m6a_count as f32,
291 at_count: at_count as f32,
292 count_5mc: count_5mc as f32,
293 frac_m6a,
294 m6a_fc,
295 weighted_m6a_rle,
296 median_m6a_rle,
297 }
298 }
299
300 pub fn fire_feats_header(fire_opts: &FireOptions) -> String {
301 let mut out = "#chrom\tstart\tend\tfiber".to_string();
302 out += "\tmsp_len\tmsp_len_times_m6a_fc\tccs_passes";
303 out += "\tfiber_m6a_count\tfiber_m6a_frac";
304 out += &FireFeatsInRange::header("msp");
305 out += &FireFeatsInRange::header("best");
306 out += &FireFeatsInRange::header("worst");
307 for bin_num in 0..fire_opts.bin_num {
308 out += &FireFeatsInRange::header(&format!("bin_{}", bin_num));
309 }
310 out += "\n";
311 out
312 }
313
314 fn msp_get_fire_features(&self, start: i64, end: i64) -> Vec<f32> {
315 let msp_len = end - start;
316 if msp_len < self.fire_opts.min_msp_length_for_positive_fire_call {
318 return vec![];
319 }
320 let ccs_passes = if self.fire_opts.ont { 4.0 } else { self.rec.ec };
321
322 let mut max_m6a_count = 0;
324 let mut max_m6a_start = 0;
325 let mut max_m6a_end = 0;
326 let mut min_m6a_count = usize::MAX;
327 let mut min_m6a_start = 0;
328 let mut min_m6a_end = 0;
329 let mut centering_pos = get_mid_point(start, end);
330 for st_idx in start..end {
331 let en_idx = (st_idx + self.fire_opts.best_window_size).min(end);
332 if (end - start) < (2 * self.fire_opts.best_window_size) {
334 log::trace!("MSP window is not large enough for best and worst window analysis");
335 break;
336 }
337 let m6a_count = self.get_m6a_count(st_idx, en_idx);
338 if m6a_count > max_m6a_count {
339 max_m6a_count = m6a_count;
340 max_m6a_start = st_idx;
341 max_m6a_end = en_idx;
342 centering_pos = get_mid_point(st_idx, en_idx);
344 }
345 if m6a_count < min_m6a_count {
346 min_m6a_count = m6a_count;
347 min_m6a_start = st_idx;
348 min_m6a_end = en_idx;
349 }
350 if en_idx == end {
351 break;
352 }
353 }
354 let best_fire_feats = self.feats_in_range(max_m6a_start, max_m6a_end);
355 let worst_fire_feats = self.feats_in_range(min_m6a_start, min_m6a_end);
356
357 let msp_feats = self.feats_in_range(start, end);
358 let bins = get_bins(
359 centering_pos,
360 self.fire_opts.bin_num,
361 self.fire_opts.width_bin,
362 self.rec.record.seq_len() as i64,
363 );
364 let bin_feats = bins
365 .into_iter()
366 .map(|(start, end)| self.feats_in_range(start, end))
367 .collect::<Vec<FireFeatsInRange>>();
368 let msp_len_times_m6a_fc = msp_feats.m6a_fc * (msp_len as f32);
369 let mut rtn = vec![
370 msp_len as f32,
371 msp_len_times_m6a_fc,
372 ccs_passes,
373 self.m6a_count as f32,
374 self.frac_m6a,
375 ];
376 let feat_sets = vec![&msp_feats, &best_fire_feats, &worst_fire_feats]
377 .into_iter()
378 .chain(bin_feats.iter());
379 for feat_set in feat_sets {
380 rtn.push(feat_set.m6a_count);
381 rtn.push(feat_set.frac_m6a);
384 rtn.push(feat_set.m6a_fc);
385 }
388 rtn
389 }
390
391 pub fn get_fire_features(&mut self) {
392 let msp_data = self.rec.msp.into_iter().collect_vec();
393 self.fire_feats = msp_data
394 .into_iter()
395 .map(|(s, e, _l, _q, refs)| {
396 let (rs, re, _rl) = refs.unwrap_or((0, 0, 0));
397 (rs, re, self.msp_get_fire_features(s, e))
398 })
399 .collect();
400 }
401
402 pub fn dump_fire_feats(&self, out_buffer: &mut Box<dyn Write>) -> Result<(), anyhow::Error> {
403 for (s, e, row) in self.fire_feats.iter() {
404 if row.is_empty() {
405 continue;
406 }
407 let lead_feats = format!(
408 "{}\t{}\t{}\t{}\t",
409 self.rec.target_name,
410 s,
411 e,
412 String::from_utf8_lossy(self.rec.record.qname())
413 );
414 out_buffer.write_all(lead_feats.as_bytes())?;
415 out_buffer.write_all(row.iter().join("\t").as_bytes())?;
416 out_buffer.write_all(b"\n")?;
417 }
418 Ok(())
419 }
420
421 pub fn predict_with_xgb(
422 &self,
423 gbdt_model: &GBDT,
424 precision_converter: &MapPrecisionValues,
425 ) -> Vec<u8> {
426 let count = self.fire_feats.len();
427 if count == 0 {
428 return vec![];
429 }
430 let mut gbdt_data: DataVec = Vec::new();
432 for (_st, _en, window) in self.fire_feats.iter() {
433 if window.is_empty() {
434 continue;
435 }
436 let d = Data::new_test_data(window.to_vec(), None);
437 gbdt_data.push(d);
438 }
439 let predictions_without_short_ones = gbdt_model.predict(&gbdt_data);
440
441 let mut precisions = Vec::with_capacity(count);
443 let mut cur_pos = 0;
444 for (_st, _en, window) in self.fire_feats.iter() {
445 if window.is_empty() {
446 precisions.push(0);
447 } else {
448 let precision = precision_converter
449 .precision_from_float(predictions_without_short_ones[cur_pos]);
450 precisions.push(precision);
451 cur_pos += 1;
452 }
453 }
454 assert_eq!(cur_pos, predictions_without_short_ones.len());
456 assert_eq!(precisions.len(), count);
457 precisions
458 }
459}
460
/// Score-to-q-value table deserialized from JSON (for example `FIRE_CONF_JSON`).
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    /// Column names from the JSON. Not read by the code in view; presumably they name the two
    /// `data` columns (TODO confirm).
    pub columns: Vec<String>,
    /// `(score, q_value)` rows. `MapPrecisionValues::new` turns each q-value into the precision
    /// `round((1 - q_value) * 255)`.
    pub data: Vec<(f32, f32)>,
}
467
468pub struct MapPrecisionValues {
469 pub map: BTreeMap<OrderedFloat<f32>, u8>,
470}
471
472impl MapPrecisionValues {
473 pub fn new(pt: &PrecisionTable) -> Self {
474 let mut map = BTreeMap::new();
476
477 for (mokapot_score, mokapot_q_value) in pt.data.iter() {
478 let precision = ((1.0 - mokapot_q_value) * 255.0).round() as u8;
479 map.insert(OrderedFloat(*mokapot_score), precision);
480 }
481 map.insert(
483 OrderedFloat(0.0),
484 *map.get(&OrderedFloat(0.0)).unwrap_or(&0),
485 );
486 Self { map }
487 }
488
489 pub fn precision_from_float(&self, value: f32) -> u8 {
491 let key = OrderedFloat(value);
492 let (less_key, less_val) = self
494 .map
495 .range(..key)
496 .next_back()
497 .unwrap_or((&OrderedFloat(0.0), &0));
498 let (more_key, more_val) = self
500 .map
501 .range(key..)
502 .next()
503 .unwrap_or((&OrderedFloat(1.0), &255));
504 if (more_key - key).abs() < (less_key - key).abs() {
505 *more_val
506 } else {
507 *less_val
508 }
509 }
510}