// fibertools_rs/subcommands/predict_m6a.rs

1use crate::cli::PredictM6AOptions;
2use crate::utils::basemods;
3use crate::utils::bio_io;
4use crate::utils::nucleosome;
5use crate::*;
6use bio::alphabets::dna::revcomp;
7use burn::tensor::backend::Backend;
8use fiber::FiberseqData;
9use ordered_float::OrderedFloat;
10use rayon::iter::IndexedParallelIterator;
11use rayon::iter::ParallelIterator;
12use rayon::prelude::IntoParallelRefMutIterator;
13use rust_htslib::{bam, bam::Read};
14use serde::Deserialize;
15use std::collections::BTreeMap;
16use std::sync::Once;
17
/// Width, in bases, of every CNN input window: the target base plus `WINDOW / 2` bases on
/// each side.
pub const WINDOW: usize = 15;

/// Feature rows per window position: four one-hot nucleotide rows, then IPD and pulse width.
pub const LAYERS: usize = 6;

/// NOTE(review): `1.0e-46` is below the smallest positive `f32` subnormal (~1.4e-45), so this
/// constant is exactly `0.0`. It is not read in this file; confirm the intended value.
pub const MIN_F32_PRED: f32 = 1.0e-46;

/// Ensures the "Using semi-supervised CNN m6A model." message is logged only once per process.
static LOG_ONCE: Once = Once::new();
22// json precision tables
23pub static SEMI_JSON_2_0: &str = include_str!("../../models/2.0_semi_torch.json");
24pub static SEMI_JSON_2_2: &str = include_str!("../../models/2.2_semi_torch.json");
25pub static SEMI_JSON_3_2: &str = include_str!("../../models/3.2_semi_torch.json");
26pub static SEMI_JSON_REVIO: &str = include_str!("../../models/Revio_semi_torch.json");
27
28#[derive(Debug, Deserialize)]
29pub struct PrecisionTable {
30    pub columns: Vec<String>,
31    pub data: Vec<(f32, u8)>,
32}
33
34#[derive(Debug, Clone)]
35pub struct PredictOptions<B>
36where
37    B: Backend<Device = m6a_burn::BurnDevice>,
38{
39    pub keep: bool,
40    pub min_ml_score: Option<u8>,
41    pub all_calls: bool,
42    pub polymerase: PbChem,
43    pub batch_size: usize,
44    map: BTreeMap<OrderedFloat<f32>, u8>,
45    pub model: Vec<u8>,
46    pub min_ml: u8,
47    pub nuc_opts: cli::NucleosomeParameters,
48    pub burn_models: m6a_burn::BurnModels<B>,
49    pub fake: bool,
50}
51
52impl<B> PredictOptions<B>
53where
54    B: Backend<Device = m6a_burn::BurnDevice>,
55{
56    #[allow(clippy::too_many_arguments)]
57    pub fn new(
58        keep: bool,
59        min_ml_score: Option<u8>,
60        all_calls: bool,
61        polymerase: PbChem,
62        batch_size: usize,
63        nuc_opts: cli::NucleosomeParameters,
64        fake: bool,
65    ) -> Self {
66        // set up a precision table
67        let mut map = BTreeMap::new();
68        map.insert(OrderedFloat(0.0), 0);
69
70        // return prediction options
71        let mut options = PredictOptions {
72            keep,
73            min_ml_score,
74            all_calls,
75            polymerase: polymerase.clone(),
76            batch_size,
77            map,
78            model: vec![],
79            min_ml: 0,
80            nuc_opts,
81            burn_models: m6a_burn::BurnModels::new(&polymerase),
82            fake,
83        };
84        options.add_model().expect("Error loading model");
85        options
86    }
87
88    fn get_precision_table_and_ml(&self) -> Result<(Option<PrecisionTable>, u8)> {
89        let mut precision_json = "".to_string();
90        let min_ml = if let Ok(_file) = std::env::var("FT_MODEL") {
91            244
92        } else {
93            LOG_ONCE.call_once(|| {
94                log::info!("Using semi-supervised CNN m6A model.");
95            });
96            match self.polymerase {
97                PbChem::Two => {
98                    precision_json = SEMI_JSON_2_0.to_string();
99                    230
100                }
101                PbChem::TwoPointTwo => {
102                    precision_json = SEMI_JSON_2_2.to_string();
103                    244
104                }
105                PbChem::ThreePointTwo => {
106                    precision_json = SEMI_JSON_3_2.to_string();
107                    244
108                }
109                PbChem::Revio => {
110                    precision_json = SEMI_JSON_REVIO.to_string();
111                    254
112                }
113            }
114        };
115
116        // load precision json from env var if needed
117        if let Ok(json) = std::env::var("FT_JSON") {
118            log::info!("Loading precision table from environment variable.");
119            precision_json =
120                std::fs::read_to_string(json).expect("Unable to read file specified by FT_JSON");
121        }
122
123        // load the precision table
124        let precision_table: Option<PrecisionTable> = Some(
125            serde_json::from_str(&precision_json)
126                .expect("Precision table JSON was not well-formatted"),
127        );
128
129        // set the variables for ML
130        let final_min_ml = match self.min_ml_score {
131            Some(x) => {
132                log::info!("Using provided minimum ML tag score: {x}");
133                x
134            }
135            None => min_ml,
136        };
137        Ok((precision_table, final_min_ml))
138    }
139
140    fn add_model(&mut self) -> Result<()> {
141        self.model = vec![];
142
143        let (precision_table, min_ml) = self.get_precision_table_and_ml()?;
144
145        // load precision table into map if not None
146        if let Some(precision_table) = precision_table {
147            for (cnn_score, precision) in precision_table.data {
148                self.map.insert(OrderedFloat(cnn_score), precision);
149            }
150        }
151
152        self.min_ml = min_ml;
153        Ok(())
154    }
155
156    pub fn progress_style(&self) -> &str {
157        // {percent:>3.green}%
158        "[PREDICTING m6A] [Elapsed {elapsed:.yellow} ETA {eta:.yellow}] {bar:30.cyan/blue} {human_pos:>5.cyan}/{human_len:.blue} (batches/s {per_sec:.green})"
159    }
160
161    /// function to find closest value in a btree based on precision
162    pub fn precision_from_float(&self, value: f32) -> u8 {
163        let key = OrderedFloat(value);
164        // maximum in map less than key
165        let (less_key, less_val) = self
166            .map
167            .range(..key)
168            .next_back()
169            .unwrap_or((&OrderedFloat(0.0), &0));
170        // minimum in map greater than or equal to key
171        let (more_key, more_val) = self
172            .map
173            .range(key..)
174            .next()
175            .unwrap_or((&OrderedFloat(1.0), &255));
176        if (more_key - key).abs() < (less_key - key).abs() {
177            *more_val
178        } else {
179            *less_val
180        }
181    }
182
183    pub fn min_ml_value(&self) -> u8 {
184        if self.all_calls {
185            0
186        } else {
187            self.min_ml
188        }
189    }
190
191    pub fn float_to_u8(&self, x: f32) -> u8 {
192        self.precision_from_float(x)
193    }
194
195    /// group reads together for predictions so we have to move data to the GPU less often
196    pub fn predict_m6a_on_records(
197        opts: &Self,
198        records: Vec<&mut rust_htslib::bam::Record>,
199        //records: &mut [rust_htslib::bam::Record],
200    ) -> usize {
201        // data windows for all the records in this chunk
202        let data: Vec<Option<(DataWidows, DataWidows)>> = records
203            .iter()
204            .map(|rec| get_m6a_data_windows(rec))
205            .collect();
206        // collect ml windows into one vector
207        let mut all_ml_data = vec![];
208        let mut all_count = 0;
209        data.iter().flatten().for_each(|(a, t)| {
210            all_ml_data.extend(a.windows.clone());
211            all_count += a.count;
212            all_ml_data.extend(t.windows.clone());
213            all_count += t.count;
214        });
215        let predictions = opts.apply_model(&all_ml_data, all_count);
216        assert_eq!(predictions.len(), all_count);
217        // split ml results back to all the records and modify the MM ML tags
218        assert_eq!(data.len(), records.len());
219        let mut cur_predict_st = 0;
220        for (option_data, record) in data.iter().zip(records) {
221            // base mods in the exiting record
222            let mut cur_basemods = basemods::BaseMods::new(record, 0);
223            cur_basemods.drop_m6a();
224            log::trace!("Number of base mod types {}", cur_basemods.base_mods.len());
225            // check if there is any data
226            let (a_data, t_data) = match option_data {
227                Some((a_data, t_data)) => (a_data, t_data),
228                None => continue,
229            };
230            // iterate over A and then T basemods
231            for data in &[a_data, t_data] {
232                let cur_predict_en = cur_predict_st + data.count;
233                let cur_predictions = &predictions[cur_predict_st..cur_predict_en];
234
235                cur_predict_st += data.count;
236                cur_basemods.base_mods.push(opts.basemod_from_ml(
237                    record,
238                    cur_predictions,
239                    &data.positions,
240                    &data.base_mod,
241                ));
242            }
243            // write the ml and mm tags
244            cur_basemods.add_mm_and_ml_tags(record);
245
246            //let modified_bases_forward = cur_basemods.forward_m6a().0;
247            let modified_bases_forward = cur_basemods.m6a().get_forward_starts();
248
249            // adding the nucleosomes
250            nucleosome::add_nucleosomes_to_record(record, &modified_bases_forward, &opts.nuc_opts);
251
252            // clear the existing data
253            if !opts.keep {
254                record.remove_aux(b"fp").unwrap_or(());
255                record.remove_aux(b"fi").unwrap_or(());
256                record.remove_aux(b"rp").unwrap_or(());
257                record.remove_aux(b"ri").unwrap_or(());
258            }
259        }
260        assert_eq!(cur_predict_st, predictions.len());
261        data.iter().flatten().count()
262    }
263
264    /// Create a basemod object form our predictions
265    pub fn basemod_from_ml(
266        &self,
267        record: &mut bam::Record,
268        predictions: &[f32],
269        positions: &[usize],
270        base_mod: &str,
271    ) -> basemods::BaseMod {
272        // do not report predictions for the first and last 7 bases
273        let min_pos = (WINDOW / 2) as i64;
274        let max_pos = (record.seq_len() - WINDOW / 2) as i64;
275        let (modified_probabilities_forward, full_probabilities_forward, modified_bases_forward): (
276            Vec<u8>,
277            Vec<f32>,
278            Vec<i64>,
279        ) = predictions
280            .iter()
281            .zip(positions.iter())
282            .map(|(&x, &pos)| (self.float_to_u8(x), x, pos as i64))
283            .filter(|(ml, _, pos)| *ml >= self.min_ml_value() && *pos >= min_pos && *pos < max_pos)
284            .multiunzip();
285
286        log::debug!(
287            "Low but non zero values: {:?}\tZero values: {:?}\tlength:{:?}",
288            full_probabilities_forward
289                .iter()
290                .filter(|&x| *x <= 1.0 / 255.0)
291                .filter(|&x| *x > 0.0)
292                .count(),
293            full_probabilities_forward
294                .iter()
295                .filter(|&x| *x <= 0.0)
296                .filter(|&x| *x > -0.00000001)
297                .count(),
298            predictions.len()
299        );
300
301        let base_mod = base_mod.as_bytes();
302        let modified_base = base_mod[0];
303        let strand = base_mod[1] as char;
304        let modification_type = base_mod[2] as char;
305
306        basemods::BaseMod::new(
307            record,
308            modified_base,
309            strand,
310            modification_type,
311            modified_bases_forward,
312            modified_probabilities_forward,
313        )
314    }
315
316    pub fn apply_model(&self, windows: &[f32], count: usize) -> Vec<f32> {
317        self.burn_models.forward(self, windows, count)
318    }
319
320    fn _fake_apply_model(&self, _: &[f32], count: usize) -> Vec<f32> {
321        vec![0.0; count]
322    }
323}
324
/// One-hot encode a DNA sequence into four stacked rows, in `A`, `C`, `G`, `T` order.
///
/// The result has `4 * seq.len()` values. Row `r` occupies `seq.len()` consecutive entries and
/// holds `1.0` wherever the base equals the `r`-th nucleotide. Any other byte (for example `N`
/// or a lowercase base) contributes zeros in every row.
///
/// ```
/// use fibertools_rs::subcommands::predict_m6a::hot_one_dna;
/// let x: Vec<u8> = vec![b'A', b'G', b'T', b'C', b'A'];
/// let ho = hot_one_dna(&x);
/// let e: Vec<f32> = vec![
///                          1.0, 0.0, 0.0, 0.0, 1.0,
///                          0.0, 0.0, 0.0, 1.0, 0.0,
///                          0.0, 1.0, 0.0, 0.0, 0.0,
///                          0.0, 0.0, 1.0, 0.0, 0.0
///                         ];
/// assert_eq!(ho, e);
/// ```
pub fn hot_one_dna(seq: &[u8]) -> Vec<f32> {
    const ROWS: [u8; 4] = [b'A', b'C', b'G', b'T'];
    ROWS.iter()
        .flat_map(|&nucleotide| {
            seq.iter()
                .map(move |&base| if base == nucleotide { 1.0 } else { 0.0 })
        })
        .collect()
}
350
/// CNN inputs for the candidate bases of one type (`A` or `T`) in a single read.
struct DataWidows {
    /// Model input: `count` windows stored back to back, each `WINDOW * LAYERS` values long.
    pub windows: Vec<f32>,
    /// Index of each window's centre base in the read sequence (reverse-complemented back to
    /// sequencing orientation for reverse-strand alignments).
    pub positions: Vec<usize>,
    /// Number of windows held in `windows` (equal to `positions.len()`).
    pub count: usize,
    /// Base-modification code reported for these calls, e.g. `A+a` or `T-a`.
    pub base_mod: String,
}
357
358fn get_m6a_data_windows(record: &bam::Record) -> Option<(DataWidows, DataWidows)> {
359    // skip invalid or redundant records
360    if record.is_secondary() {
361        log::warn!(
362            "Skipping secondary alignment of {}",
363            String::from_utf8_lossy(record.qname())
364        );
365        return None;
366    }
367
368    let extend = WINDOW / 2;
369    let mut f_ip = bio_io::get_u8_tag(record, b"fi");
370    let r_ip;
371    let f_pw;
372    let r_pw;
373    // check if we maybe are getting u16 input instead of u8
374    if f_ip.is_empty() {
375        f_ip = bio_io::get_pb_u16_tag_as_u8(record, b"fi");
376        if f_ip.is_empty() {
377            // missing u16 as well, set all to empty arrays
378            r_ip = vec![];
379            f_pw = vec![];
380            r_pw = vec![];
381        } else {
382            r_ip = bio_io::get_pb_u16_tag_as_u8(record, b"ri");
383            f_pw = bio_io::get_pb_u16_tag_as_u8(record, b"fp");
384            r_pw = bio_io::get_pb_u16_tag_as_u8(record, b"rp");
385        }
386    } else {
387        r_ip = bio_io::get_u8_tag(record, b"ri");
388        f_pw = bio_io::get_u8_tag(record, b"fp");
389        r_pw = bio_io::get_u8_tag(record, b"rp");
390    }
391    // return if missing kinetics
392    if f_ip.is_empty() || r_ip.is_empty() || f_pw.is_empty() || r_pw.is_empty() {
393        log::debug!(
394            "Hifi kinetics are missing for: {}",
395            String::from_utf8_lossy(record.qname())
396        );
397        return None;
398    }
399    // reverse for reverse strand
400    let r_ip = r_ip.into_iter().rev().collect::<Vec<_>>();
401    let r_pw = r_pw.into_iter().rev().collect::<Vec<_>>();
402
403    let mut seq = record.seq().as_bytes();
404    if record.is_reverse() {
405        seq = revcomp(seq);
406    }
407
408    assert_eq!(f_ip.len(), seq.len());
409    let mut a_count = 0;
410    let mut t_count = 0;
411    let mut a_windows = vec![];
412    let mut t_windows = vec![];
413    let mut a_positions = vec![];
414    let mut t_positions = vec![];
415    for (pos, base) in seq.iter().enumerate() {
416        if !((*base == b'A') || (*base == b'T')) {
417            continue;
418        }
419        // get the data window
420        let data_window = if (pos < extend) || (pos + extend + 1 > record.seq_len()) {
421            // make fake data for leading and trailing As
422            vec![0.0; WINDOW * LAYERS]
423        } else {
424            let start = pos - extend;
425            let end = pos + extend + 1;
426            let ip: Vec<f32>;
427            let pw: Vec<f32>;
428            let hot_one;
429            if *base == b'A' {
430                let w_seq = &revcomp(&seq[start..end]);
431                hot_one = hot_one_dna(w_seq);
432                ip = (r_ip[start..end])
433                    .iter()
434                    .copied()
435                    .rev()
436                    .map(|x| x as f32 / 255.0)
437                    .collect();
438                pw = (r_pw[start..end])
439                    .iter()
440                    .copied()
441                    .rev()
442                    .map(|x| x as f32 / 255.0)
443                    .collect();
444            } else {
445                let w_seq = &seq[start..end];
446                hot_one = hot_one_dna(w_seq);
447                ip = (f_ip[start..end])
448                    .iter()
449                    .copied()
450                    .map(|x| x as f32 / 255.0)
451                    .collect();
452                pw = (f_pw[start..end])
453                    .iter()
454                    .copied()
455                    .map(|x| x as f32 / 255.0)
456                    .collect();
457            }
458            let mut data_window = vec![];
459            data_window.extend(hot_one);
460            data_window.extend(ip);
461            data_window.extend(pw);
462            data_window
463        };
464
465        // add to data windows and record positions
466        if *base == b'A' {
467            a_windows.extend(data_window);
468            a_count += 1;
469            a_positions.push(pos);
470        } else {
471            t_windows.extend(data_window);
472            t_count += 1;
473            t_positions.push(pos);
474        }
475    }
476    let a_data = DataWidows {
477        windows: a_windows,
478        positions: a_positions,
479        count: a_count,
480        base_mod: "A+a".to_string(),
481    };
482    let t_data = DataWidows {
483        windows: t_windows,
484        positions: t_positions,
485        count: t_count,
486        base_mod: "T-a".to_string(),
487    };
488    Some((a_data, t_data))
489}
490
491pub fn read_bam_into_fiberdata(opts: &mut PredictM6AOptions) {
492    let mut bam = opts.input.bam_reader();
493    let mut out = opts.input.bam_writer(&opts.out);
494    let header = bam::Header::from_template(bam.header());
495    // log the options
496    log::info!(
497        "{} reads included at once in batch prediction.",
498        opts.batch_size
499    );
500
501    #[cfg(feature = "tch")]
502    type MlBackend = burn::backend::LibTorch;
503    #[cfg(feature = "tch")]
504    log::info!("Using LibTorch for ML backend.");
505
506    #[cfg(not(feature = "tch"))]
507    type MlBackend = burn::backend::Candle;
508    #[cfg(not(feature = "tch"))]
509    log::info!("Using Candle for ML backend.");
510
511    // switch to the internal predict options
512    let predict_options: PredictOptions<MlBackend> = PredictOptions::new(
513        opts.keep,
514        opts.force_min_ml_score,
515        opts.all_calls,
516        find_pb_polymerase(&header),
517        opts.batch_size,
518        opts.nuc.clone(),
519        opts.fake,
520    );
521    // get default fire options
522    let fire_opts = crate::cli::FireOptions::default();
523    let (model, precision_table) = crate::utils::fire::get_model(&fire_opts);
524
525    // read in bam data
526    let bam_chunk_iter = BamChunk::new(bam.records(), None);
527    // iterate over chunks
528    for mut chunk in bam_chunk_iter {
529        // add m6a calls
530
531        let number_of_reads_with_predictions = chunk
532            .par_iter_mut()
533            .chunks(predict_options.batch_size)
534            .map(|records| {
535                // Create a fresh PredictOptions instance for this thread
536                let thread_opts = PredictOptions::<MlBackend>::new(
537                    predict_options.all_calls,
538                    predict_options.min_ml_score,
539                    predict_options.all_calls,
540                    predict_options.polymerase.clone(),
541                    predict_options.batch_size,
542                    predict_options.nuc_opts.clone(),
543                    predict_options.fake,
544                );
545                PredictOptions::predict_m6a_on_records(&thread_opts, records)
546            })
547            .sum::<usize>() as f32;
548
549        let frac_called = number_of_reads_with_predictions / chunk.len() as f32;
550        if frac_called < 0.05 {
551            log::warn!("More than 5% ({:.2}%) of reads were not predicted on. Are HiFi kinetics missing from this file? Enable Debug logging level to show which reads lack kinetics.", 100.0-100.0*frac_called);
552        }
553
554        // covert to FiberData and do FIRE predictions
555        let mut fd_recs =
556            FiberseqData::from_records(chunk, &opts.input.header_view(), &opts.input.filters);
557        fd_recs.par_iter_mut().for_each(|fd| {
558            crate::subcommands::fire::add_fire_to_rec(fd, &fire_opts, &model, &precision_table);
559        });
560
561        // write to output
562        fd_recs.iter().for_each(|fd| out.write(&fd.record).unwrap());
563    }
564}
565
/// tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Every bundled precision table must deserialize into a `PrecisionTable`.
    #[test]
    fn test_precision_json_validity() {
        let bundled = [SEMI_JSON_2_0, SEMI_JSON_2_2, SEMI_JSON_3_2, SEMI_JSON_REVIO];
        for json in bundled.iter() {
            let _table: PrecisionTable = serde_json::from_str(json)
                .expect("Precision table JSON was not well-formatted");
        }
    }
}