fibertools_rs/utils/fire.rs

1use crate::cli::FireOptions;
2use crate::fiber::FiberseqData;
3use crate::*;
4use anyhow;
5use derive_builder::Builder;
6use gbdt::decision_tree::{Data, DataVec};
7use gbdt::gradient_boost::GBDT;
8use itertools::Itertools;
9use ordered_float::OrderedFloat;
10use serde::Deserialize;
11use std::collections::BTreeMap;
12use std::fs;
13use tempfile::NamedTempFile;
14
/// Built-in FIRE model, embedded at compile time; `get_model` loads it with
/// `GBDT::from_xgboost_dump` when no custom model is configured.
pub static FIRE_MODEL: &str = include_str!("../../models/FIRE.gbdt.json");
/// Built-in precision (FDR) table in JSON form, deserialized as a
/// `PrecisionTable` and used together with `FIRE_MODEL`.
pub static FIRE_CONF_JSON: &str = include_str!("../../models/FIRE.conf.json");
17
18pub fn get_model(fire_opts: &FireOptions) -> (GBDT, MapPrecisionValues) {
19    let mut remove_temp_file = false;
20    // load defaults or passed in options
21    let (model_file, fdr_table_file) = match (&fire_opts.model, &fire_opts.fdr_table) {
22        (Some(model_file), Some(b)) => {
23            let fdr_table = fs::read_to_string(b).expect("Unable to read file");
24            (model_file.clone(), fdr_table)
25        }
26        _ => {
27            let temp_file = NamedTempFile::new().expect("Unable to make a temp file");
28            let (mut temp_file, path) = temp_file.keep().expect("Unable to keep temp file");
29            let temp_file_name = path
30                .as_os_str()
31                .to_str()
32                .expect("Unable to convert the path of the named temp file to an &str.");
33            temp_file
34                .write_all(FIRE_MODEL.as_bytes())
35                .expect("Unable to write file");
36            //fs::write(temp, FIRE_MODEL).expect("Unable to write file");
37            remove_temp_file = true;
38            (temp_file_name.to_string(), FIRE_CONF_JSON.to_string())
39        }
40    };
41    log::info!("Using model: {}", model_file);
42    // load model
43    let model =
44        GBDT::from_xgboost_dump(&model_file, "binary:logistic").expect("failed to load FIRE model");
45    if remove_temp_file {
46        fs::remove_file(model_file).expect("Unable to remove temp file");
47    }
48
49    // load precision table
50    let precision_table: PrecisionTable =
51        serde_json::from_str(&fdr_table_file).expect("Precision table JSON was not well-formatted");
52    let precision_converter = MapPrecisionValues::new(&precision_table);
53
54    // return
55    (model, precision_converter)
56}
57
/// Integer midpoint of `start` and `end`; the division truncates toward zero.
fn get_mid_point(start: i64, end: i64) -> i64 {
    let total = start + end;
    total / 2
}
61
/// Compute `bin_num` consecutive windows of `bin_width` bp laid out around
/// `mid_point`, clamped to the read bounds `[0, max_end]`.
///
/// Bin `i` nominally starts at `mid_point - (bin_num / 2 - i) * bin_width - bin_width / 2`
/// (so for odd `bin_num` the middle bin is centered on `mid_point`). Negative
/// starts/ends become `0`; ends past `max_end` become `max_end`; a start past
/// `max_end` becomes `max_end - 1`. Bins near the read edges can therefore be
/// shorter than `bin_width` or empty.
///
/// ```
/// use fibertools_rs::utils::fire::get_bins;
/// let bins = get_bins(50, 5, 20, 200);
/// assert_eq!(bins, vec![(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]);
/// ```
pub fn get_bins(mid_point: i64, bin_num: i64, bin_width: i64, max_end: i64) -> Vec<(i64, i64)> {
    let half_bins = bin_num / 2;
    let half_width = bin_width / 2;
    (0..bin_num)
        .map(|idx| {
            let raw_start = mid_point - (half_bins - idx) * bin_width - half_width;
            let raw_end = raw_start + bin_width;
            // clamp the start at 0 first, then pull it back inside the read
            let start = raw_start.max(0);
            let start = if start > max_end { max_end - 1 } else { start };
            let end = raw_end.max(0).min(max_end);
            (start, end)
        })
        .collect()
}
88
89/// get the maximum and median rle of m6a in a window
90fn get_m6a_rle_data(rec: &FiberseqData, start: i64, end: i64) -> (f32, f32) {
91    let mut m6a_rles = vec![];
92    let mut max = 0;
93    let mut _max_pos = 0;
94    // if you are a position, on average you will be in an rle length of weighted_rle
95    let mut weighted_rle = 0.0;
96    for (m6a_1, m6a_2) in rec.m6a.starts.iter().flatten().tuple_windows() {
97        // we only want m6a in the window
98        if *m6a_1 < start || *m6a_1 > end || *m6a_2 < start || *m6a_2 > end {
99            continue;
100        }
101        // distance between m6a sites
102        let rle = (m6a_2 - m6a_1).abs();
103        m6a_rles.push(rle);
104        // update max
105        if rle > max {
106            max = rle;
107            _max_pos = ((*m6a_1 + *m6a_2) / 2 - (end + start) / 2).abs();
108        }
109        // update weighted rle
110        weighted_rle += (rle * rle) as f32;
111    }
112    weighted_rle /= (end - start) as f32;
113
114    if m6a_rles.is_empty() {
115        return (-1.0, -1.0);
116    }
117    let mid_length = m6a_rles.len() / 2;
118    let (_, median, _) = m6a_rles.select_nth_unstable(mid_length);
119    (weighted_rle, *median as f32)
120}
121
/// Per-window feature names used for output column headers.
///
/// The order must match the order `msp_get_fire_features` pushes values for each
/// feature set. The other `FireFeatsInRange` fields (`at_count`, `count_5mc`,
/// `weighted_m6a_rle`, `median_m6a_rle`) and the disabled `max_m6a_rle` /
/// `max_m6a_rle_pos` features are not currently included.
const FEATS_IN_USE: [&str; 3] = ["m6a_count", "frac_m6a", "m6a_fc"];
/// Features computed over one window of a read (see `FireFeats::feats_in_range`).
///
/// Only the fields named in `FEATS_IN_USE` are currently written out as model
/// features; the remaining fields are computed but not emitted.
#[derive(Debug, Clone, Builder)]
pub struct FireFeatsInRange {
    /// m6A calls in the window (ONT-adjusted when the `ont` option is set)
    pub m6a_count: f32,
    /// A + T bases in the window
    pub at_count: f32,
    /// 5mC calls (from the record's `cpg` starts) in the window
    // NOTE(review): this field is `pub`, so the allow below looks unnecessary — confirm
    #[allow(unused)]
    pub count_5mc: f32,
    /// `m6a_count / at_count`, or 0.0 when the window has no A/T bases
    pub frac_m6a: f32,
    /// log2(observed / expected) m6A, or 0.0 when either is zero
    pub m6a_fc: f32,
    /// sum of squared m6A run lengths divided by the window width (-1.0 when none)
    pub weighted_m6a_rle: f32,
    /// median m6A run length in the window (-1.0 when none)
    pub median_m6a_rle: f32,
}
144
145impl FireFeatsInRange {
146    pub fn header(tag: &str) -> String {
147        let mut out = "".to_string();
148        for col in FEATS_IN_USE.iter() {
149            out += &format!("\t{}_{}", tag, col);
150        }
151        out
152    }
153}
154
/// Per-read FIRE feature extractor: the read, its sequence, read-level m6A
/// statistics, and one feature row per MSP.
#[derive(Debug)]
pub struct FireFeats<'a> {
    /// the fiber-seq record whose MSPs are featurized
    rec: &'a FiberseqData,
    /// A + T bases across the whole read
    #[allow(unused)]
    at_count: usize,
    /// m6A calls across the whole read (ONT-adjusted when the `ont` option is set)
    m6a_count: usize,
    /// read-level `m6a_count / at_count` (0.0 when the read has no A/T bases)
    frac_m6a: f32,
    //frac_m6a_in_msps: f32,
    fire_opts: &'a FireOptions,
    /// read sequence bytes as returned by `rec.record.seq()`
    seq: Vec<u8>,
    /// one `(ref_start, ref_end, features)` entry per MSP; `features` is empty for
    /// MSPs too short to score, and the reference coords are (0, 0) when the MSP has none
    fire_feats: Vec<(i64, i64, Vec<f32>)>,
}
167
168impl<'a> FireFeats<'a> {
169    pub fn new(rec: &'a FiberseqData, fire_opts: &'a FireOptions) -> Self {
170        let seq_len = rec.record.seq_len();
171        let seq = rec.record.seq().as_bytes();
172
173        let mut rtn = Self {
174            rec,
175            at_count: 0,
176            m6a_count: 0,
177            frac_m6a: 0.0,
178            //frac_m6a_in_msps,
179            fire_opts,
180            seq,
181            fire_feats: vec![],
182        };
183
184        // add in the m6a and AT counts
185        rtn.at_count = rtn.get_at_count(0, seq_len as i64);
186        rtn.m6a_count = rtn.get_m6a_count(0, seq_len as i64);
187        rtn.frac_m6a = if rtn.at_count > 0 {
188            rtn.m6a_count as f32 / rtn.at_count as f32
189        } else {
190            0.0
191        };
192
193        rtn.get_fire_features();
194        if rtn.fire_opts.ont {
195            rtn.validate_that_ont_is_single_strand();
196        }
197        rtn
198    }
199
200    fn validate_that_ont_is_single_strand(&self) {
201        let sequenced_bp = if self.rec.record.is_reverse() {
202            b'T'
203        } else {
204            b'A'
205        };
206        for m6a_st in self.rec.m6a.starts.iter().flatten() {
207            let m6a_bp = self.seq[*m6a_st as usize];
208            if m6a_bp != sequenced_bp {
209                log::warn!(
210                    "m6A site at {} is not the same as the sequenced base {}",
211                    m6a_st,
212                    sequenced_bp as char
213                );
214            }
215        }
216    }
217
218    fn get_bp_count(&self, start: i64, end: i64, bp: u8) -> usize {
219        let subseq = &self.seq[start as usize..end as usize];
220        subseq.iter().filter(|&&b| b == bp).count()
221    }
222
223    fn get_at_count(&self, start: i64, end: i64) -> usize {
224        self.get_bp_count(start, end, b'A') + self.get_bp_count(start, end, b'T')
225    }
226
227    fn get_5mc_count(&self, start: i64, end: i64) -> usize {
228        self.rec
229            .cpg
230            .starts
231            .iter()
232            .flatten()
233            .filter(|&&pos| pos >= start && pos < end)
234            .count()
235    }
236
237    fn get_m6a_count(&self, start: i64, end: i64) -> usize {
238        let mut m6a_count = self
239            .rec
240            .m6a
241            .starts
242            .iter()
243            .flatten()
244            .filter(|&&pos| pos >= start && pos < end)
245            .count();
246
247        // estimate what the count would be if we sequenced the other strand
248        if self.fire_opts.ont {
249            let mut sequenced_bp = self.get_bp_count(start, end, b'A');
250            let mut un_sequenced_bp = self.get_bp_count(start, end, b'T');
251            if self.rec.record.is_reverse() {
252                // swap the counts
253                std::mem::swap(&mut sequenced_bp, &mut un_sequenced_bp);
254            }
255            let m6a_frac = if sequenced_bp > 0 {
256                m6a_count as f32 / sequenced_bp as f32
257            } else {
258                0.0
259            };
260            m6a_count += (un_sequenced_bp as f32 * m6a_frac).round() as usize;
261        }
262        m6a_count
263    }
264
265    fn m6a_fc_over_expected(&self, m6a_count: usize, at_count: usize) -> f32 {
266        //let expected = self.frac_m6a_in_msps * at_count as f32;
267        // ^ this didn't work well
268        let expected = self.frac_m6a * at_count as f32;
269        let observed = m6a_count as f32;
270        if expected == 0.0 || observed == 0.0 {
271            return 0.0;
272        }
273        let fc = observed / expected;
274        fc.log2()
275    }
276
277    fn feats_in_range(&self, start: i64, end: i64) -> FireFeatsInRange {
278        let m6a_count = self.get_m6a_count(start, end);
279        let at_count = self.get_at_count(start, end);
280        let count_5mc = self.get_5mc_count(start, end);
281        let frac_m6a = if at_count > 0 {
282            m6a_count as f32 / at_count as f32
283        } else {
284            0.0
285        };
286        let m6a_fc = self.m6a_fc_over_expected(m6a_count, at_count);
287        let (weighted_m6a_rle, median_m6a_rle) = get_m6a_rle_data(self.rec, start, end);
288
289        FireFeatsInRange {
290            m6a_count: m6a_count as f32,
291            at_count: at_count as f32,
292            count_5mc: count_5mc as f32,
293            frac_m6a,
294            m6a_fc,
295            weighted_m6a_rle,
296            median_m6a_rle,
297        }
298    }
299
300    pub fn fire_feats_header(fire_opts: &FireOptions) -> String {
301        let mut out = "#chrom\tstart\tend\tfiber".to_string();
302        out += "\tmsp_len\tmsp_len_times_m6a_fc\tccs_passes";
303        out += "\tfiber_m6a_count\tfiber_m6a_frac";
304        out += &FireFeatsInRange::header("msp");
305        out += &FireFeatsInRange::header("best");
306        out += &FireFeatsInRange::header("worst");
307        for bin_num in 0..fire_opts.bin_num {
308            out += &FireFeatsInRange::header(&format!("bin_{}", bin_num));
309        }
310        out += "\n";
311        out
312    }
313
314    fn msp_get_fire_features(&self, start: i64, end: i64) -> Vec<f32> {
315        let msp_len = end - start;
316        // skip predicting (or outputting) on short windows
317        if msp_len < self.fire_opts.min_msp_length_for_positive_fire_call {
318            return vec![];
319        }
320        let ccs_passes = if self.fire_opts.ont { 4.0 } else { self.rec.ec };
321
322        // find the 100bp window within the range with the most m6a
323        let mut max_m6a_count = 0;
324        let mut max_m6a_start = 0;
325        let mut max_m6a_end = 0;
326        let mut min_m6a_count = usize::MAX;
327        let mut min_m6a_start = 0;
328        let mut min_m6a_end = 0;
329        let mut centering_pos = get_mid_point(start, end);
330        for st_idx in start..end {
331            let en_idx = (st_idx + self.fire_opts.best_window_size).min(end);
332            // this analysis is only interesting if we have a larger msp that could have two ~ distinct windows. Thus we need to check that we have a window larger than 2X the best window size
333            if (end - start) < (2 * self.fire_opts.best_window_size) {
334                log::trace!("MSP window is not large enough for best and worst window analysis");
335                break;
336            }
337            let m6a_count = self.get_m6a_count(st_idx, en_idx);
338            if m6a_count > max_m6a_count {
339                max_m6a_count = m6a_count;
340                max_m6a_start = st_idx;
341                max_m6a_end = en_idx;
342                // center my bins around the highest density m6A region instead of the middle of the MSP
343                centering_pos = get_mid_point(st_idx, en_idx);
344            }
345            if m6a_count < min_m6a_count {
346                min_m6a_count = m6a_count;
347                min_m6a_start = st_idx;
348                min_m6a_end = en_idx;
349            }
350            if en_idx == end {
351                break;
352            }
353        }
354        let best_fire_feats = self.feats_in_range(max_m6a_start, max_m6a_end);
355        let worst_fire_feats = self.feats_in_range(min_m6a_start, min_m6a_end);
356
357        let msp_feats = self.feats_in_range(start, end);
358        let bins = get_bins(
359            centering_pos,
360            self.fire_opts.bin_num,
361            self.fire_opts.width_bin,
362            self.rec.record.seq_len() as i64,
363        );
364        let bin_feats = bins
365            .into_iter()
366            .map(|(start, end)| self.feats_in_range(start, end))
367            .collect::<Vec<FireFeatsInRange>>();
368        let msp_len_times_m6a_fc = msp_feats.m6a_fc * (msp_len as f32);
369        let mut rtn = vec![
370            msp_len as f32,
371            msp_len_times_m6a_fc,
372            ccs_passes,
373            self.m6a_count as f32,
374            self.frac_m6a,
375        ];
376        let feat_sets = vec![&msp_feats, &best_fire_feats, &worst_fire_feats]
377            .into_iter()
378            .chain(bin_feats.iter());
379        for feat_set in feat_sets {
380            rtn.push(feat_set.m6a_count);
381            //rtn.push(feat_set.at_count);
382            //rtn.push(feat_set.count_5mc);
383            rtn.push(feat_set.frac_m6a);
384            rtn.push(feat_set.m6a_fc);
385            //rtn.push(feat_set.weighted_m6a_rle);
386            //rtn.push(feat_set.median_m6a_rle);
387        }
388        rtn
389    }
390
391    pub fn get_fire_features(&mut self) {
392        let msp_data = self.rec.msp.into_iter().collect_vec();
393        self.fire_feats = msp_data
394            .into_iter()
395            .map(|(s, e, _l, _q, refs)| {
396                let (rs, re, _rl) = refs.unwrap_or((0, 0, 0));
397                (rs, re, self.msp_get_fire_features(s, e))
398            })
399            .collect();
400    }
401
402    pub fn dump_fire_feats(&self, out_buffer: &mut Box<dyn Write>) -> Result<(), anyhow::Error> {
403        for (s, e, row) in self.fire_feats.iter() {
404            if row.is_empty() {
405                continue;
406            }
407            let lead_feats = format!(
408                "{}\t{}\t{}\t{}\t",
409                self.rec.target_name,
410                s,
411                e,
412                String::from_utf8_lossy(self.rec.record.qname())
413            );
414            out_buffer.write_all(lead_feats.as_bytes())?;
415            out_buffer.write_all(row.iter().join("\t").as_bytes())?;
416            out_buffer.write_all(b"\n")?;
417        }
418        Ok(())
419    }
420
421    pub fn predict_with_xgb(
422        &self,
423        gbdt_model: &GBDT,
424        precision_converter: &MapPrecisionValues,
425    ) -> Vec<u8> {
426        let count = self.fire_feats.len();
427        if count == 0 {
428            return vec![];
429        }
430        // predict on windows of sufficient length
431        let mut gbdt_data: DataVec = Vec::new();
432        for (_st, _en, window) in self.fire_feats.iter() {
433            if window.is_empty() {
434                continue;
435            }
436            let d = Data::new_test_data(window.to_vec(), None);
437            gbdt_data.push(d);
438        }
439        let predictions_without_short_ones = gbdt_model.predict(&gbdt_data);
440
441        // convert predictions to precision values, restoring empty windows
442        let mut precisions = Vec::with_capacity(count);
443        let mut cur_pos = 0;
444        for (_st, _en, window) in self.fire_feats.iter() {
445            if window.is_empty() {
446                precisions.push(0);
447            } else {
448                let precision = precision_converter
449                    .precision_from_float(predictions_without_short_ones[cur_pos]);
450                precisions.push(precision);
451                cur_pos += 1;
452            }
453        }
454        // check outputs
455        assert_eq!(cur_pos, predictions_without_short_ones.len());
456        assert_eq!(precisions.len(), count);
457        precisions
458    }
459}
460
/// Precision (FDR) lookup table deserialized from JSON (e.g. `FIRE_CONF_JSON`).
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    /// column names from the JSON; not read by `MapPrecisionValues::new`
    pub columns: Vec<String>,
    /// vec of (mokapot score, mokapot q-value)
    pub data: Vec<(f32, f32)>,
}
467
/// Lookup from model score to precision encoded as a `u8`
/// (`round((1 - q_value) * 255)`).
pub struct MapPrecisionValues {
    /// score -> precision; when built with `MapPrecisionValues::new` it always has a 0.0 key
    pub map: BTreeMap<OrderedFloat<f32>, u8>,
}
471
472impl MapPrecisionValues {
473    pub fn new(pt: &PrecisionTable) -> Self {
474        // set up a precision table
475        let mut map = BTreeMap::new();
476
477        for (mokapot_score, mokapot_q_value) in pt.data.iter() {
478            let precision = ((1.0 - mokapot_q_value) * 255.0).round() as u8;
479            map.insert(OrderedFloat(*mokapot_score), precision);
480        }
481        // if we dont have a zero value insert one
482        map.insert(
483            OrderedFloat(0.0),
484            *map.get(&OrderedFloat(0.0)).unwrap_or(&0),
485        );
486        Self { map }
487    }
488
489    /// function to find closest value in a btree based on precision
490    pub fn precision_from_float(&self, value: f32) -> u8 {
491        let key = OrderedFloat(value);
492        // maximum in map less than key
493        let (less_key, less_val) = self
494            .map
495            .range(..key)
496            .next_back()
497            .unwrap_or((&OrderedFloat(0.0), &0));
498        // minimum in map greater than or equal to key
499        let (more_key, more_val) = self
500            .map
501            .range(key..)
502            .next()
503            .unwrap_or((&OrderedFloat(1.0), &255));
504        if (more_key - key).abs() < (less_key - key).abs() {
505            *more_val
506        } else {
507            *less_val
508        }
509    }
510}