fibertools_rs/subcommands/predict_m6a.rs

use crate::cli::PredictM6AOptions;
use crate::utils::basemods;
use crate::utils::bio_io;
use crate::utils::nucleosome;
use crate::*;
use bio::alphabets::dna::revcomp;
use burn::tensor::backend::Backend;
use fiber::FiberseqData;
use ordered_float::OrderedFloat;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::IntoParallelRefMutIterator;
use rust_htslib::{bam, bam::Read};
use serde::Deserialize;
use std::collections::BTreeMap;

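// Feature-window constants: each candidate A/T gets a WINDOW-wide window with LAYERS
// feature rows (four one-hot base rows plus an inter-pulse-duration row and a pulse-width row).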
pub const WINDOW: usize = 15;
pub const LAYERS: usize = 6;
pub const MIN_F32_PRED: f32 = 1.0e-46;
// precision tables (JSON), one per PacBio chemistry
pub static SEMI_JSON_2_0: &str = include_str!("../../models/2.0_semi_torch.json");
pub static SEMI_JSON_2_2: &str = include_str!("../../models/2.2_semi_torch.json");
pub static SEMI_JSON_3_2: &str = include_str!("../../models/3.2_semi_torch.json");
pub static SEMI_JSON_REVIO: &str = include_str!("../../models/Revio_semi_torch.json");

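/// Calibration table mapping raw CNN scores to `u8` precision values used for the
/// ML tag. Deserialized from the bundled `*_semi_torch.json` files; the expected
/// JSON shape (column names and values here are illustrative only) is
/// `{"columns": ["cnn_score", "precision"], "data": [[0.90, 230], [0.99, 244]]}`.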
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    pub columns: Vec<String>,
    pub data: Vec<(f32, u8)>,
}

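/// Options shared across m6A prediction: the detected PacBio chemistry, batch size,
/// nucleosome-calling parameters, the burn CNN models, and the CNN-score-to-precision
/// map built from the bundled JSON tables.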
#[derive(Debug, Clone)]
pub struct PredictOptions<B>
where
    B: Backend<Device = m6a_burn::BurnDevice>,
{
    pub keep: bool,
    pub min_ml_score: Option<u8>,
    pub all_calls: bool,
    pub polymerase: PbChem,
    pub batch_size: usize,
    map: BTreeMap<OrderedFloat<f32>, u8>,
    pub model: Vec<u8>,
    pub min_ml: u8,
    pub nuc_opts: cli::NucleosomeParameters,
    pub burn_models: m6a_burn::BurnModels<B>,
    pub fake: bool,
}

impl<B> PredictOptions<B>
where
    B: Backend<Device = m6a_burn::BurnDevice>,
{
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        keep: bool,
        min_ml_score: Option<u8>,
        all_calls: bool,
        polymerase: PbChem,
        batch_size: usize,
        nuc_opts: cli::NucleosomeParameters,
        fake: bool,
    ) -> Self {
        // set up a precision table
        let mut map = BTreeMap::new();
        map.insert(OrderedFloat(0.0), 0);

        // return prediction options
        let mut options = PredictOptions {
            keep,
            min_ml_score,
            all_calls,
            polymerase: polymerase.clone(),
            batch_size,
            map,
            model: vec![],
            min_ml: 0,
            nuc_opts,
            burn_models: m6a_burn::BurnModels::new(&polymerase),
            fake,
        };
        options.add_model().expect("Error loading model");
        options
    }

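    /// Select the bundled precision table and default minimum ML score for the
    /// detected chemistry. If the `FT_MODEL` environment variable is set, no bundled
    /// table is selected and `FT_JSON` must supply a matching precision table;
    /// `FT_JSON` alone can also override the bundled table. A user-supplied
    /// `min_ml_score` takes precedence over the chemistry default.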
    fn get_precision_table_and_ml(&self) -> Result<(Option<PrecisionTable>, u8)> {
        let mut precision_json = "".to_string();
        let min_ml = if let Ok(_file) = std::env::var("FT_MODEL") {
            244
        } else {
            log::info!("Using semi-supervised CNN m6A model.");
            match self.polymerase {
                PbChem::Two => {
                    precision_json = SEMI_JSON_2_0.to_string();
                    230
                }
                PbChem::TwoPointTwo => {
                    precision_json = SEMI_JSON_2_2.to_string();
                    244
                }
                PbChem::ThreePointTwo => {
                    precision_json = SEMI_JSON_3_2.to_string();
                    244
                }
                PbChem::Revio => {
                    precision_json = SEMI_JSON_REVIO.to_string();
                    254
                }
            }
        };

        // load precision json from env var if needed
        if let Ok(json) = std::env::var("FT_JSON") {
            log::info!("Loading precision table from environment variable.");
            precision_json =
                std::fs::read_to_string(json).expect("Unable to read file specified by FT_JSON");
        }

        // load the precision table
        let precision_table: Option<PrecisionTable> = Some(
            serde_json::from_str(&precision_json)
                .expect("Precision table JSON was not well-formatted"),
        );

        // set the variables for ML
        let final_min_ml = match self.min_ml_score {
            Some(x) => {
                log::info!("Using provided minimum ML tag score: {}", x);
                x
            }
            None => min_ml,
        };
        Ok((precision_table, final_min_ml))
    }

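    /// Populate the CNN-score-to-precision map from the selected precision table
    /// and record the minimum ML score to report.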
    fn add_model(&mut self) -> Result<()> {
        self.model = vec![];

        let (precision_table, min_ml) = self.get_precision_table_and_ml()?;

        // load precision table into map if not None
        if let Some(precision_table) = precision_table {
            for (cnn_score, precision) in precision_table.data {
                self.map.insert(OrderedFloat(cnn_score), precision);
            }
        }

        self.min_ml = min_ml;
        Ok(())
    }

    pub fn progress_style(&self) -> &str {
        // {percent:>3.green}%
        "[PREDICTING m6A] [Elapsed {elapsed:.yellow} ETA {eta:.yellow}] {bar:30.cyan/blue} {human_pos:>5.cyan}/{human_len:.blue} (batches/s {per_sec:.green})"
    }

    /// Find the precision value whose key (CNN score) is closest to `value`.
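    /// The map stores (CNN score -> precision) pairs from the precision table, plus a
    /// 0.0 -> 0 sentinel. The lookup takes the nearest key on either side of `value`
    /// and returns the value of the closer one (ties go to the smaller key). For
    /// example, with hypothetical entries 0.90 -> 230 and 0.95 -> 240, a query of
    /// 0.92 returns 230 and a query of 0.93 returns 240.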
    pub fn precision_from_float(&self, value: f32) -> u8 {
        let key = OrderedFloat(value);
        // maximum in map less than key
        let (less_key, less_val) = self
            .map
            .range(..key)
            .next_back()
            .unwrap_or((&OrderedFloat(0.0), &0));
        // minimum in map greater than or equal to key
        let (more_key, more_val) = self
            .map
            .range(key..)
            .next()
            .unwrap_or((&OrderedFloat(1.0), &255));
        if (more_key - key).abs() < (less_key - key).abs() {
            *more_val
        } else {
            *less_val
        }
    }

    pub fn min_ml_value(&self) -> u8 {
        if self.all_calls {
            0
        } else {
            self.min_ml
        }
    }

    pub fn float_to_u8(&self, x: f32) -> u8 {
        self.precision_from_float(x)
    }

    /// Group reads into batches for prediction so data is moved to the GPU less often.
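    /// For each record the A- and T-centred feature windows are collected, all windows
    /// in the batch are pushed through the model in one call, and the predictions are
    /// then split back per record: m6A MM/ML tags are rewritten, nucleosomes are
    /// recalled from the new m6A positions, and the HiFi kinetics tags (fi/ri/fp/rp)
    /// are dropped unless `keep` is set. Returns the number of records that had
    /// kinetics and therefore received predictions.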
    pub fn predict_m6a_on_records(
        opts: &Self,
        records: Vec<&mut rust_htslib::bam::Record>,
        //records: &mut [rust_htslib::bam::Record],
    ) -> usize {
        // data windows for all the records in this chunk
        let data: Vec<Option<(DataWidows, DataWidows)>> = records
            .iter()
            .map(|rec| get_m6a_data_windows(rec))
            .collect();
        // collect ml windows into one vector
        let mut all_ml_data = vec![];
        let mut all_count = 0;
        data.iter().flatten().for_each(|(a, t)| {
            all_ml_data.extend(a.windows.clone());
            all_count += a.count;
            all_ml_data.extend(t.windows.clone());
            all_count += t.count;
        });
        let predictions = opts.apply_model(&all_ml_data, all_count);
        assert_eq!(predictions.len(), all_count);
        // split ML results back to all the records and modify the MM/ML tags
        assert_eq!(data.len(), records.len());
        let mut cur_predict_st = 0;
        for (option_data, record) in data.iter().zip(records) {
            // base mods in the existing record
            let mut cur_basemods = basemods::BaseMods::new(record, 0);
            cur_basemods.drop_m6a();
            log::trace!("Number of base mod types {}", cur_basemods.base_mods.len());
            // check if there is any data
            let (a_data, t_data) = match option_data {
                Some((a_data, t_data)) => (a_data, t_data),
                None => continue,
            };
            // iterate over A and then T basemods
            for data in &[a_data, t_data] {
                let cur_predict_en = cur_predict_st + data.count;
                let cur_predictions = &predictions[cur_predict_st..cur_predict_en];

                cur_predict_st += data.count;
                cur_basemods.base_mods.push(opts.basemod_from_ml(
                    record,
                    cur_predictions,
                    &data.positions,
                    &data.base_mod,
                ));
            }
            // write the ml and mm tags
            cur_basemods.add_mm_and_ml_tags(record);

            //let modified_bases_forward = cur_basemods.forward_m6a().0;
            let modified_bases_forward = cur_basemods.m6a().get_forward_starts();

            // add the nucleosomes
            nucleosome::add_nucleosomes_to_record(record, &modified_bases_forward, &opts.nuc_opts);

            // clear the existing data
            if !opts.keep {
                record.remove_aux(b"fp").unwrap_or(());
                record.remove_aux(b"fi").unwrap_or(());
                record.remove_aux(b"rp").unwrap_or(());
                record.remove_aux(b"ri").unwrap_or(());
            }
        }
        assert_eq!(cur_predict_st, predictions.len());
        data.iter().flatten().count()
    }

    /// Create a basemod object from our predictions
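    /// Converts each prediction to a `u8` precision via the score-to-precision map,
    /// keeps only calls at or above the minimum ML score (or all calls when
    /// `all_calls` is set), and drops positions within WINDOW / 2 bases of either
    /// read end, where the windows were zero-padded.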
    pub fn basemod_from_ml(
        &self,
        record: &mut bam::Record,
        predictions: &[f32],
        positions: &[usize],
        base_mod: &str,
    ) -> basemods::BaseMod {
        // do not report predictions for the first and last 7 bases
        let min_pos = (WINDOW / 2) as i64;
        let max_pos = (record.seq_len() - WINDOW / 2) as i64;
        let (modified_probabilities_forward, full_probabilities_forward, modified_bases_forward): (
            Vec<u8>,
            Vec<f32>,
            Vec<i64>,
        ) = predictions
            .iter()
            .zip(positions.iter())
            .map(|(&x, &pos)| (self.float_to_u8(x), x, pos as i64))
            .filter(|(ml, _, pos)| *ml >= self.min_ml_value() && *pos >= min_pos && *pos < max_pos)
            .multiunzip();

        log::debug!(
            "Low but non zero values: {:?}\tZero values: {:?}\tlength:{:?}",
            full_probabilities_forward
                .iter()
                .filter(|&x| *x <= 1.0 / 255.0)
                .filter(|&x| *x > 0.0)
                .count(),
            full_probabilities_forward
                .iter()
                .filter(|&x| *x <= 0.0)
                .filter(|&x| *x > -0.00000001)
                .count(),
            predictions.len()
        );

        let base_mod = base_mod.as_bytes();
        let modified_base = base_mod[0];
        let strand = base_mod[1] as char;
        let modification_type = base_mod[2] as char;

        basemods::BaseMod::new(
            record,
            modified_base,
            strand,
            modification_type,
            modified_bases_forward,
            modified_probabilities_forward,
        )
    }

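    /// Run the CNN forward pass over the flattened feature windows; returns one m6A
    /// score per window (`count` scores in total).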
    pub fn apply_model(&self, windows: &[f32], count: usize) -> Vec<f32> {
        self.burn_models.forward(self, windows, count)
    }

    fn _fake_apply_model(&self, _: &[f32], count: usize) -> Vec<f32> {
        vec![0.0; count]
    }
}

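/// One-hot encode a DNA sequence, row-major: four rows in A, C, G, T order, each
/// `seq.len()` entries wide, so the output holds `4 * seq.len()` values.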
/// ```
/// use fibertools_rs::subcommands::predict_m6a::hot_one_dna;
/// let x: Vec<u8> = vec![b'A', b'G', b'T', b'C', b'A'];
/// let ho = hot_one_dna(&x);
/// let e: Vec<f32> = vec![
///                          1.0, 0.0, 0.0, 0.0, 1.0,
///                          0.0, 0.0, 0.0, 1.0, 0.0,
///                          0.0, 1.0, 0.0, 0.0, 0.0,
///                          0.0, 0.0, 1.0, 0.0, 0.0
///                         ];
/// assert_eq!(ho, e);
/// ```
pub fn hot_one_dna(seq: &[u8]) -> Vec<f32> {
    let len = seq.len() * 4;
    let mut out = vec![0.0; len];
    for (row, base) in [b'A', b'C', b'G', b'T'].into_iter().enumerate() {
        let already_done = seq.len() * row;
        for i in 0..seq.len() {
            if seq[i] == base {
                out[already_done + i] = 1.0;
            }
        }
    }
    out
}

struct DataWidows {
    pub windows: Vec<f32>,
    pub positions: Vec<usize>,
    pub count: usize,
    pub base_mod: String,
}

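/// Build per-base feature windows for m6A prediction. For every A and T in the read
/// sequence, a WINDOW-wide window is flattened into WINDOW * LAYERS f32 values: four
/// one-hot base rows followed by one row of inter-pulse durations and one of pulse
/// widths, each scaled by 1/255. A-centred windows use the reverse-complemented
/// sequence and reverse-strand kinetics (ri/rp); T-centred windows use the forward
/// sequence and kinetics (fi/fp). Returns `None` for secondary alignments or reads
/// without HiFi kinetics.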
fn get_m6a_data_windows(record: &bam::Record) -> Option<(DataWidows, DataWidows)> {
    // skip invalid or redundant records
    if record.is_secondary() {
        log::warn!(
            "Skipping secondary alignment of {}",
            String::from_utf8_lossy(record.qname())
        );
        return None;
    }

    let extend = WINDOW / 2;
    let mut f_ip = bio_io::get_u8_tag(record, b"fi");
    let r_ip;
    let f_pw;
    let r_pw;
    // check whether the kinetics are stored as u16 instead of u8
    if f_ip.is_empty() {
        f_ip = bio_io::get_pb_u16_tag_as_u8(record, b"fi");
        if f_ip.is_empty() {
            // missing u16 as well, set all to empty arrays
            r_ip = vec![];
            f_pw = vec![];
            r_pw = vec![];
        } else {
            r_ip = bio_io::get_pb_u16_tag_as_u8(record, b"ri");
            f_pw = bio_io::get_pb_u16_tag_as_u8(record, b"fp");
            r_pw = bio_io::get_pb_u16_tag_as_u8(record, b"rp");
        }
    } else {
        r_ip = bio_io::get_u8_tag(record, b"ri");
        f_pw = bio_io::get_u8_tag(record, b"fp");
        r_pw = bio_io::get_u8_tag(record, b"rp");
    }
    // return if missing kinetics
    if f_ip.is_empty() || r_ip.is_empty() || f_pw.is_empty() || r_pw.is_empty() {
        log::debug!(
            "HiFi kinetics are missing for: {}",
            String::from_utf8_lossy(record.qname())
        );
        return None;
    }
    // reverse for reverse strand
    let r_ip = r_ip.into_iter().rev().collect::<Vec<_>>();
    let r_pw = r_pw.into_iter().rev().collect::<Vec<_>>();

    let mut seq = record.seq().as_bytes();
    if record.is_reverse() {
        seq = revcomp(seq);
    }

    assert_eq!(f_ip.len(), seq.len());
    let mut a_count = 0;
    let mut t_count = 0;
    let mut a_windows = vec![];
    let mut t_windows = vec![];
    let mut a_positions = vec![];
    let mut t_positions = vec![];
    for (pos, base) in seq.iter().enumerate() {
        if !((*base == b'A') || (*base == b'T')) {
            continue;
        }
        // get the data window
        let data_window = if (pos < extend) || (pos + extend + 1 > record.seq_len()) {
            // zero-pad windows for A/T bases within WINDOW/2 of the read ends
            vec![0.0; WINDOW * LAYERS]
        } else {
            let start = pos - extend;
            let end = pos + extend + 1;
            let ip: Vec<f32>;
            let pw: Vec<f32>;
            let hot_one;
            if *base == b'A' {
                let w_seq = &revcomp(&seq[start..end]);
                hot_one = hot_one_dna(w_seq);
                ip = (r_ip[start..end])
                    .iter()
                    .copied()
                    .rev()
                    .map(|x| x as f32 / 255.0)
                    .collect();
                pw = (r_pw[start..end])
                    .iter()
                    .copied()
                    .rev()
                    .map(|x| x as f32 / 255.0)
                    .collect();
            } else {
                let w_seq = &seq[start..end];
                hot_one = hot_one_dna(w_seq);
                ip = (f_ip[start..end])
                    .iter()
                    .copied()
                    .map(|x| x as f32 / 255.0)
                    .collect();
                pw = (f_pw[start..end])
                    .iter()
                    .copied()
                    .map(|x| x as f32 / 255.0)
                    .collect();
            }
            let mut data_window = vec![];
            data_window.extend(hot_one);
            data_window.extend(ip);
            data_window.extend(pw);
            data_window
        };

        // add to data windows and record positions
        if *base == b'A' {
            a_windows.extend(data_window);
            a_count += 1;
            a_positions.push(pos);
        } else {
            t_windows.extend(data_window);
            t_count += 1;
            t_positions.push(pos);
        }
    }
    let a_data = DataWidows {
        windows: a_windows,
        positions: a_positions,
        count: a_count,
        base_mod: "A+a".to_string(),
    };
    let t_data = DataWidows {
        windows: t_windows,
        positions: t_positions,
        count: t_count,
        base_mod: "T-a".to_string(),
    };
    Some((a_data, t_data))
}

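/// Entry point for the predict-m6a subcommand: stream the input BAM in chunks, run
/// batched m6A (and FIRE) prediction on each chunk in parallel, and write the updated
/// records to the output BAM.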
pub fn read_bam_into_fiberdata(opts: &mut PredictM6AOptions) {
    let mut bam = opts.input.bam_reader();
    let mut out = opts.input.bam_writer(&opts.out);
    let header = bam::Header::from_template(bam.header());
    // log the batch size
    log::info!(
        "{} reads included at once in batch prediction.",
        opts.batch_size
    );

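    // Select the burn backend at compile time: LibTorch when the `tch` feature is
    // enabled, otherwise the pure-Rust Candle backend.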
    #[cfg(feature = "tch")]
    type MlBackend = burn::backend::LibTorch;
    #[cfg(feature = "tch")]
    log::info!("Using LibTorch for ML backend.");

    #[cfg(not(feature = "tch"))]
    type MlBackend = burn::backend::Candle;
    #[cfg(not(feature = "tch"))]
    log::info!("Using Candle for ML backend.");

    // build the internal prediction options
    let predict_options: PredictOptions<MlBackend> = PredictOptions::new(
        opts.keep,
        opts.force_min_ml_score,
        opts.all_calls,
        find_pb_polymerase(&header),
        opts.batch_size,
        opts.nuc.clone(),
        opts.fake,
    );
    // get default fire options
    let fire_opts = crate::cli::FireOptions::default();
    let (model, precision_table) = crate::utils::fire::get_model(&fire_opts);

    // read in bam data
    let bam_chunk_iter = BamChunk::new(bam.records(), None);
    // iterate over chunks
    for mut chunk in bam_chunk_iter {
        // add m6a calls
        let number_of_reads_with_predictions = chunk
            .par_iter_mut()
            .chunks(predict_options.batch_size)
            .map(|records| PredictOptions::predict_m6a_on_records(&predict_options, records))
            .sum::<usize>() as f32;

        let frac_called = number_of_reads_with_predictions / chunk.len() as f32;
        if frac_called < 0.95 {
            log::warn!("More than 5% ({:.2}%) of reads were not predicted on. Are HiFi kinetics missing from this file? Enable Debug logging level to show which reads lack kinetics.", 100.0-100.0*frac_called);
        }

        // convert to FiberseqData and do FIRE predictions
        let mut fd_recs =
            FiberseqData::from_records(chunk, &opts.input.header_view(), &opts.input.filters);
        fd_recs.par_iter_mut().for_each(|fd| {
            crate::subcommands::fire::add_fire_to_rec(fd, &fire_opts, &model, &precision_table);
        });

        // write to output
        fd_recs.iter().for_each(|fd| out.write(&fd.record).unwrap());
    }
}

/// tests
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_precision_json_validity() {
        for file in [SEMI_JSON_2_0, SEMI_JSON_2_2, SEMI_JSON_3_2, SEMI_JSON_REVIO] {
            let _p: PrecisionTable =
                serde_json::from_str(file).expect("Precision table JSON was not well-formatted");
        }
    }
}