Skip to main content

fibertools_rs/utils/
fire.rs

1use crate::cli::FireOptions;
2use crate::fiber::FiberseqData;
3use crate::*;
4use anyhow;
5use derive_builder::Builder;
6use gbdt::decision_tree::{Data, DataVec};
7use gbdt::gradient_boost::GBDT;
8use itertools::Itertools;
9use ordered_float::OrderedFloat;
10use serde::Deserialize;
11use std::collections::BTreeMap;
12use std::fs;
13use tempfile::NamedTempFile;
14
/// Contents of the bundled FIRE model file (`models/FIRE.gbdt.json`), embedded at compile time.
/// Used by `get_model` when no model / FDR table pair is supplied in the options.
pub static FIRE_MODEL: &str = include_str!("../../models/FIRE.gbdt.json");
/// Contents of the bundled precision table (`models/FIRE.conf.json`), embedded at compile time.
/// Parsed by `get_model` as a `PrecisionTable` when the bundled model is used.
pub static FIRE_CONF_JSON: &str = include_str!("../../models/FIRE.conf.json");
17
18pub fn get_model(fire_opts: &FireOptions) -> (GBDT, MapPrecisionValues) {
19    let mut remove_temp_file = false;
20    // load defaults or passed in options
21    let (model_file, fdr_table_file) = match (&fire_opts.model, &fire_opts.fdr_table) {
22        (Some(model_file), Some(b)) => {
23            let fdr_table = fs::read_to_string(b).expect("Unable to read file");
24            (model_file.clone(), fdr_table)
25        }
26        _ => {
27            let temp_file = NamedTempFile::new().expect("Unable to make a temp file");
28            let (mut temp_file, path) = temp_file.keep().expect("Unable to keep temp file");
29            let temp_file_name = path
30                .as_os_str()
31                .to_str()
32                .expect("Unable to convert the path of the named temp file to an &str.");
33            temp_file
34                .write_all(FIRE_MODEL.as_bytes())
35                .expect("Unable to write file");
36            //fs::write(temp, FIRE_MODEL).expect("Unable to write file");
37            remove_temp_file = true;
38            (temp_file_name.to_string(), FIRE_CONF_JSON.to_string())
39        }
40    };
41    log::info!("Using model: {model_file}");
42    // load model
43    let model =
44        GBDT::from_xgboost_dump(&model_file, "binary:logistic").expect("failed to load FIRE model");
45    if remove_temp_file {
46        fs::remove_file(model_file).expect("Unable to remove temp file");
47    }
48
49    // load precision table
50    let precision_table: PrecisionTable =
51        serde_json::from_str(&fdr_table_file).expect("Precision table JSON was not well-formatted");
52    let precision_converter = MapPrecisionValues::new(&precision_table);
53
54    // return
55    (model, precision_converter)
56}
57
/// Midpoint of `start` and `end` using integer division (the quotient truncates toward zero).
fn get_mid_point(start: i64, end: i64) -> i64 {
    let total = start + end;
    total / 2
}
61
/// Build `bin_num` consecutive bins of `bin_width` bases laid out around `mid_point`, clamped
/// to the read coordinates `0..=max_end`.
///
/// Negative coordinates are raised to `0` and ends beyond `max_end` are lowered to `max_end`.
/// A bin that would start past `max_end` is reported as starting at `max_end - 1`.
///
/// ```
/// use fibertools_rs::utils::fire::get_bins;
/// let bins = get_bins(50, 5, 20, 200);
/// assert_eq!(bins, vec![(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]);
/// ```
pub fn get_bins(mid_point: i64, bin_num: i64, bin_width: i64, max_end: i64) -> Vec<(i64, i64)> {
    (0..bin_num)
        .map(|i| {
            // unclamped coordinates of the i-th bin
            let raw_start = mid_point - (bin_num / 2 - i) * bin_width - bin_width / 2;
            let raw_end = raw_start + bin_width;

            let bin_start = match raw_start.max(0) {
                s if s > max_end => max_end - 1,
                s => s,
            };
            let bin_end = match raw_end.max(0) {
                e if e > max_end => max_end,
                e => e,
            };
            (bin_start, bin_end)
        })
        .collect()
}
88
89/// get the maximum and median rle of m6a in a window
90fn get_m6a_rle_data(rec: &FiberseqData, start: i64, end: i64) -> (f32, f32) {
91    let mut m6a_rles = vec![];
92    let mut max = 0;
93    let mut _max_pos = 0;
94    // if you are a position, on average you will be in an rle length of weighted_rle
95    let mut weighted_rle = 0.0;
96    for (m6a_1, m6a_2) in rec.m6a.starts().iter().tuple_windows() {
97        // we only want m6a in the window
98        if *m6a_1 < start || *m6a_1 > end || *m6a_2 < start || *m6a_2 > end {
99            continue;
100        }
101        // distance between m6a sites
102        let rle = (m6a_2 - m6a_1).abs();
103        m6a_rles.push(rle);
104        // update max
105        if rle > max {
106            max = rle;
107            _max_pos = ((*m6a_1 + *m6a_2) / 2 - (end + start) / 2).abs();
108        }
109        // update weighted rle
110        weighted_rle += (rle * rle) as f32;
111    }
112    weighted_rle /= (end - start) as f32;
113
114    if m6a_rles.is_empty() {
115        return (-1.0, -1.0);
116    }
117    let mid_length = m6a_rles.len() / 2;
118    let (_, median, _) = m6a_rles.select_nth_unstable(mid_length);
119    (weighted_rle, *median as f32)
120}
121
/// Names of the per-range features that are written out, in output order.
/// `FireFeatsInRange::header` builds column names from this list, and the order matches the
/// values pushed for each feature set in `FireFeats::msp_get_fire_features`.
/// The commented-out entries are features that are computed but currently not emitted.
const FEATS_IN_USE: [&str; 3] = [
    "m6a_count",
    //"at_count",
    //"count_5mc",
    "frac_m6a",
    "m6a_fc",
    //"max_m6a_rle",
    //"max_m6a_rle_pos",
    // "weighted_m6a_rle",
    //"median_m6a_rle",
];
/// Features computed over a single coordinate range of a read
/// (see `FireFeats::feats_in_range`). All values are stored as `f32`.
#[derive(Debug, Clone, Builder)]
pub struct FireFeatsInRange {
    /// Number of m6A calls in the range (for ONT data this includes the estimated count for
    /// the strand that was not sequenced).
    pub m6a_count: f32,
    /// Number of `A` plus `T` bases in the range.
    pub at_count: f32,
    /// Number of CpG (5mC) calls that start in the range.
    #[allow(unused)]
    pub count_5mc: f32,
    /// `m6a_count / at_count`, or `0.0` when the range has no A/T bases.
    pub frac_m6a: f32,
    /// log2 of observed over expected m6A count, where the expectation uses the m6A fraction
    /// of the whole read; `0.0` when either the observed or the expected count is zero.
    pub m6a_fc: f32,
    /// Sum of squared distances between adjacent m6A calls divided by the range length,
    /// or `-1.0` when the range holds no pair of m6A calls.
    pub weighted_m6a_rle: f32,
    /// Median distance between adjacent m6A calls, or `-1.0` when the range holds no pair.
    pub median_m6a_rle: f32,
}
144
145impl FireFeatsInRange {
146    pub fn header(tag: &str) -> String {
147        let mut out = "".to_string();
148        for col in FEATS_IN_USE.iter() {
149            out += &format!("\t{tag}_{col}");
150        }
151        out
152    }
153}
154
/// FIRE features for every MSP of one read, plus the read-level counts needed to compute them.
#[derive(Debug)]
pub struct FireFeats<'a> {
    /// The read the features are computed for.
    rec: &'a FiberseqData,
    /// Number of `A` plus `T` bases over the whole read sequence.
    #[allow(unused)]
    at_count: usize,
    /// Number of m6A calls over the whole read (see `get_m6a_count` for the ONT adjustment).
    m6a_count: usize,
    /// `m6a_count / at_count` over the whole read, or `0.0` when there are no A/T bases.
    frac_m6a: f32,
    //frac_m6a_in_msps: f32,
    /// Options that control window sizes, binning and ONT handling.
    fire_opts: &'a FireOptions,
    /// Bases of the read sequence, taken from `rec.record.seq().as_bytes()`.
    seq: Vec<u8>,
    /// One entry per MSP: (reference start, reference end, feature row). The row is empty for
    /// MSPs shorter than `min_msp_length_for_positive_fire_call`.
    fire_feats: Vec<(i64, i64, Vec<f32>)>,
}
167
168impl<'a> FireFeats<'a> {
169    pub fn new(rec: &'a FiberseqData, fire_opts: &'a FireOptions) -> Self {
170        let seq_len = rec.record.seq_len();
171        let seq = rec.record.seq().as_bytes();
172
173        let mut rtn = Self {
174            rec,
175            at_count: 0,
176            m6a_count: 0,
177            frac_m6a: 0.0,
178            //frac_m6a_in_msps,
179            fire_opts,
180            seq,
181            fire_feats: vec![],
182        };
183
184        // add in the m6a and AT counts
185        rtn.at_count = rtn.get_at_count(0, seq_len as i64);
186        rtn.m6a_count = rtn.get_m6a_count(0, seq_len as i64);
187        rtn.frac_m6a = if rtn.at_count > 0 {
188            rtn.m6a_count as f32 / rtn.at_count as f32
189        } else {
190            0.0
191        };
192
193        rtn.get_fire_features();
194        if rtn.fire_opts.ont {
195            rtn.validate_that_ont_is_single_strand();
196        }
197        rtn
198    }
199
200    fn validate_that_ont_is_single_strand(&self) {
201        let sequenced_bp = if self.rec.record.is_reverse() {
202            b'T'
203        } else {
204            b'A'
205        };
206        for m6a_st in self.rec.m6a.starts().iter() {
207            let m6a_bp = self.seq[*m6a_st as usize];
208            if m6a_bp != sequenced_bp {
209                log::warn!(
210                    "m6A site at {} is not the same as the sequenced base {}",
211                    m6a_st,
212                    sequenced_bp as char
213                );
214            }
215        }
216    }
217
218    fn get_bp_count(&self, start: i64, end: i64, bp: u8) -> usize {
219        let subseq = &self.seq[start as usize..end as usize];
220        subseq.iter().filter(|&&b| b == bp).count()
221    }
222
223    fn get_at_count(&self, start: i64, end: i64) -> usize {
224        self.get_bp_count(start, end, b'A') + self.get_bp_count(start, end, b'T')
225    }
226
227    fn get_5mc_count(&self, start: i64, end: i64) -> usize {
228        self.rec
229            .cpg
230            .starts()
231            .iter()
232            .filter(|&&pos| pos >= start && pos < end)
233            .count()
234    }
235
236    fn get_m6a_count(&self, start: i64, end: i64) -> usize {
237        let mut m6a_count = self
238            .rec
239            .m6a
240            .starts()
241            .iter()
242            .filter(|&&pos| pos >= start && pos < end)
243            .count();
244
245        // estimate what the count would be if we sequenced the other strand
246        if self.fire_opts.ont {
247            let mut sequenced_bp = self.get_bp_count(start, end, b'A');
248            let mut un_sequenced_bp = self.get_bp_count(start, end, b'T');
249            if self.rec.record.is_reverse() {
250                // swap the counts
251                std::mem::swap(&mut sequenced_bp, &mut un_sequenced_bp);
252            }
253            let m6a_frac = if sequenced_bp > 0 {
254                m6a_count as f32 / sequenced_bp as f32
255            } else {
256                0.0
257            };
258            m6a_count += (un_sequenced_bp as f32 * m6a_frac).round() as usize;
259        }
260        m6a_count
261    }
262
263    fn m6a_fc_over_expected(&self, m6a_count: usize, at_count: usize) -> f32 {
264        //let expected = self.frac_m6a_in_msps * at_count as f32;
265        // ^ this didn't work well
266        let expected = self.frac_m6a * at_count as f32;
267        let observed = m6a_count as f32;
268        if expected == 0.0 || observed == 0.0 {
269            return 0.0;
270        }
271        let fc = observed / expected;
272        fc.log2()
273    }
274
275    fn feats_in_range(&self, start: i64, end: i64) -> FireFeatsInRange {
276        let m6a_count = self.get_m6a_count(start, end);
277        let at_count = self.get_at_count(start, end);
278        let count_5mc = self.get_5mc_count(start, end);
279        let frac_m6a = if at_count > 0 {
280            m6a_count as f32 / at_count as f32
281        } else {
282            0.0
283        };
284        let m6a_fc = self.m6a_fc_over_expected(m6a_count, at_count);
285        let (weighted_m6a_rle, median_m6a_rle) = get_m6a_rle_data(self.rec, start, end);
286
287        FireFeatsInRange {
288            m6a_count: m6a_count as f32,
289            at_count: at_count as f32,
290            count_5mc: count_5mc as f32,
291            frac_m6a,
292            m6a_fc,
293            weighted_m6a_rle,
294            median_m6a_rle,
295        }
296    }
297
298    pub fn fire_feats_header(fire_opts: &FireOptions) -> String {
299        let mut out = "#chrom\tstart\tend\tfiber".to_string();
300        out += "\tmsp_len\tmsp_len_times_m6a_fc\tccs_passes";
301        out += "\tfiber_m6a_count\tfiber_m6a_frac";
302        out += &FireFeatsInRange::header("msp");
303        out += &FireFeatsInRange::header("best");
304        out += &FireFeatsInRange::header("worst");
305        for bin_num in 0..fire_opts.bin_num {
306            out += &FireFeatsInRange::header(&format!("bin_{bin_num}"));
307        }
308        out += "\n";
309        out
310    }
311
312    fn msp_get_fire_features(&self, start: i64, end: i64) -> Vec<f32> {
313        let msp_len = end - start;
314        // skip predicting (or outputting) on short windows
315        if msp_len < self.fire_opts.min_msp_length_for_positive_fire_call {
316            return vec![];
317        }
318        let ccs_passes = if self.fire_opts.ont { 4.0 } else { self.rec.ec };
319
320        // find the 100bp window within the range with the most m6a
321        let mut max_m6a_count = 0;
322        let mut max_m6a_start = 0;
323        let mut max_m6a_end = 0;
324        let mut min_m6a_count = usize::MAX;
325        let mut min_m6a_start = 0;
326        let mut min_m6a_end = 0;
327        let mut centering_pos = get_mid_point(start, end);
328        for st_idx in start..end {
329            let en_idx = (st_idx + self.fire_opts.best_window_size).min(end);
330            // this analysis is only interesting if we have a larger msp that could have two ~ distinct windows. Thus we need to check that we have a window larger than 2X the best window size
331            if (end - start) < (2 * self.fire_opts.best_window_size) {
332                log::trace!("MSP window is not large enough for best and worst window analysis");
333                break;
334            }
335            let m6a_count = self.get_m6a_count(st_idx, en_idx);
336            if m6a_count > max_m6a_count {
337                max_m6a_count = m6a_count;
338                max_m6a_start = st_idx;
339                max_m6a_end = en_idx;
340                // center my bins around the highest density m6A region instead of the middle of the MSP
341                centering_pos = get_mid_point(st_idx, en_idx);
342            }
343            if m6a_count < min_m6a_count {
344                min_m6a_count = m6a_count;
345                min_m6a_start = st_idx;
346                min_m6a_end = en_idx;
347            }
348            if en_idx == end {
349                break;
350            }
351        }
352        let best_fire_feats = self.feats_in_range(max_m6a_start, max_m6a_end);
353        let worst_fire_feats = self.feats_in_range(min_m6a_start, min_m6a_end);
354
355        let msp_feats = self.feats_in_range(start, end);
356        let bins = get_bins(
357            centering_pos,
358            self.fire_opts.bin_num,
359            self.fire_opts.width_bin,
360            self.rec.record.seq_len() as i64,
361        );
362        let bin_feats = bins
363            .into_iter()
364            .map(|(start, end)| self.feats_in_range(start, end))
365            .collect::<Vec<FireFeatsInRange>>();
366        let msp_len_times_m6a_fc = msp_feats.m6a_fc * (msp_len as f32);
367        let mut rtn = vec![
368            msp_len as f32,
369            msp_len_times_m6a_fc,
370            ccs_passes,
371            self.m6a_count as f32,
372            self.frac_m6a,
373        ];
374        let feat_sets = vec![&msp_feats, &best_fire_feats, &worst_fire_feats]
375            .into_iter()
376            .chain(bin_feats.iter());
377        for feat_set in feat_sets {
378            rtn.push(feat_set.m6a_count);
379            //rtn.push(feat_set.at_count);
380            //rtn.push(feat_set.count_5mc);
381            rtn.push(feat_set.frac_m6a);
382            rtn.push(feat_set.m6a_fc);
383            //rtn.push(feat_set.weighted_m6a_rle);
384            //rtn.push(feat_set.median_m6a_rle);
385        }
386        rtn
387    }
388
389    pub fn get_fire_features(&mut self) {
390        let msp_data = self.rec.msp.into_iter().collect_vec();
391        self.fire_feats = msp_data
392            .into_iter()
393            .map(|annotation| {
394                let s = annotation.start;
395                let e = annotation.end;
396                let (rs, re, _rl) = match (
397                    annotation.reference_start,
398                    annotation.reference_end,
399                    annotation.reference_length,
400                ) {
401                    (Some(rs), Some(re), Some(rl)) => (rs, re, rl),
402                    _ => (0, 0, 0),
403                };
404                (rs, re, self.msp_get_fire_features(s, e))
405            })
406            .collect();
407    }
408
409    pub fn dump_fire_feats(&self, out_buffer: &mut Box<dyn Write>) -> Result<(), anyhow::Error> {
410        for (s, e, row) in self.fire_feats.iter() {
411            if row.is_empty() {
412                continue;
413            }
414            let lead_feats = format!(
415                "{}\t{}\t{}\t{}\t",
416                self.rec.target_name,
417                s,
418                e,
419                String::from_utf8_lossy(self.rec.record.qname())
420            );
421            out_buffer.write_all(lead_feats.as_bytes())?;
422            out_buffer.write_all(row.iter().join("\t").as_bytes())?;
423            out_buffer.write_all(b"\n")?;
424        }
425        Ok(())
426    }
427
428    pub fn predict_with_xgb(
429        &self,
430        gbdt_model: &GBDT,
431        precision_converter: &MapPrecisionValues,
432    ) -> Vec<u8> {
433        let count = self.fire_feats.len();
434        if count == 0 {
435            return vec![];
436        }
437        // predict on windows of sufficient length
438        let mut gbdt_data: DataVec = Vec::new();
439        for (_st, _en, window) in self.fire_feats.iter() {
440            if window.is_empty() {
441                continue;
442            }
443            let d = Data::new_test_data(window.to_vec(), None);
444            gbdt_data.push(d);
445        }
446        let predictions_without_short_ones = gbdt_model.predict(&gbdt_data);
447
448        // convert predictions to precision values, restoring empty windows
449        let mut precisions = Vec::with_capacity(count);
450        let mut cur_pos = 0;
451        for (_st, _en, window) in self.fire_feats.iter() {
452            if window.is_empty() {
453                precisions.push(0);
454            } else {
455                let precision = precision_converter
456                    .precision_from_float(predictions_without_short_ones[cur_pos]);
457                precisions.push(precision);
458                cur_pos += 1;
459            }
460        }
461        // check outputs
462        assert_eq!(cur_pos, predictions_without_short_ones.len());
463        assert_eq!(precisions.len(), count);
464        precisions
465    }
466}
467
/// Precision (FDR) table as deserialized from JSON by `get_model`.
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    /// Column names stored alongside `data`; not read by the code in this file.
    pub columns: Vec<String>,
    /// vec of (mokapot score, mokapot q-value)
    pub data: Vec<(f32, f32)>,
}
474
/// Lookup from a model score to a precision value scaled to `0..=255`.
pub struct MapPrecisionValues {
    /// Score -> precision, where precision is `round((1 - q_value) * 255)`.
    /// Always contains an entry for the score `0.0`.
    pub map: BTreeMap<OrderedFloat<f32>, u8>,
}
478
479impl MapPrecisionValues {
480    pub fn new(pt: &PrecisionTable) -> Self {
481        // set up a precision table
482        let mut map = BTreeMap::new();
483
484        for (mokapot_score, mokapot_q_value) in pt.data.iter() {
485            let precision = ((1.0 - mokapot_q_value) * 255.0).round() as u8;
486            map.insert(OrderedFloat(*mokapot_score), precision);
487        }
488        // if we dont have a zero value insert one
489        map.insert(
490            OrderedFloat(0.0),
491            *map.get(&OrderedFloat(0.0)).unwrap_or(&0),
492        );
493        Self { map }
494    }
495
496    /// function to find closest value in a btree based on precision
497    pub fn precision_from_float(&self, value: f32) -> u8 {
498        let key = OrderedFloat(value);
499        // maximum in map less than key
500        let (less_key, less_val) = self
501            .map
502            .range(..key)
503            .next_back()
504            .unwrap_or((&OrderedFloat(0.0), &0));
505        // minimum in map greater than or equal to key
506        let (more_key, more_val) = self
507            .map
508            .range(key..)
509            .next()
510            .unwrap_or((&OrderedFloat(1.0), &255));
511        if (more_key - key).abs() < (less_key - key).abs() {
512            *more_val
513        } else {
514            *less_val
515        }
516    }
517}