// fibertools_rs/subcommands/predict_m6a.rs
use crate::cli::PredictM6AOptions;
use crate::utils::basemods;
use crate::utils::bio_io;
use crate::utils::nucleosome;
use crate::*;
use bio::alphabets::dna::revcomp;
use burn::tensor::backend::Backend;
use fiber::FiberseqData;
use ordered_float::OrderedFloat;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::IntoParallelRefMutIterator;
use rust_htslib::{bam, bam::Read};
use serde::Deserialize;
use std::collections::BTreeMap;

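// Model input dimensions: each candidate A/T site is encoded as a WINDOW x
// LAYERS matrix (4 one-hot DNA rows plus normalized IPD and PW kinetics rows).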
pub const WINDOW: usize = 15;
pub const LAYERS: usize = 6;
pub const MIN_F32_PRED: f32 = 1.0e-46;
pub static SEMI_JSON_2_0: &str = include_str!("../../models/2.0_semi_torch.json");
pub static SEMI_JSON_2_2: &str = include_str!("../../models/2.2_semi_torch.json");
pub static SEMI_JSON_3_2: &str = include_str!("../../models/3.2_semi_torch.json");
pub static SEMI_JSON_REVIO: &str = include_str!("../../models/Revio_semi_torch.json");

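/// CNN-score-to-precision lookup table deserialized from the bundled
/// `*_semi_torch.json` files; `data` maps a CNN score to an ML tag value.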
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    pub columns: Vec<String>,
    pub data: Vec<(f32, u8)>,
}

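/// Runtime options for m6A prediction: the chemistry-specific CNN, the
/// CNN-score-to-ML lookup map, batching, and nucleosome-calling parameters.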
#[derive(Debug, Clone)]
pub struct PredictOptions<B>
where
    B: Backend<Device = m6a_burn::BurnDevice>,
{
    pub keep: bool,
    pub min_ml_score: Option<u8>,
    pub all_calls: bool,
    pub polymerase: PbChem,
    pub batch_size: usize,
    map: BTreeMap<OrderedFloat<f32>, u8>,
    pub model: Vec<u8>,
    pub min_ml: u8,
    pub nuc_opts: cli::NucleosomeParameters,
    pub burn_models: m6a_burn::BurnModels<B>,
    pub fake: bool,
}

impl<B> PredictOptions<B>
where
    B: Backend<Device = m6a_burn::BurnDevice>,
{
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        keep: bool,
        min_ml_score: Option<u8>,
        all_calls: bool,
        polymerase: PbChem,
        batch_size: usize,
        nuc_opts: cli::NucleosomeParameters,
        fake: bool,
    ) -> Self {
        let mut map = BTreeMap::new();
        map.insert(OrderedFloat(0.0), 0);

        let mut options = PredictOptions {
            keep,
            min_ml_score,
            all_calls,
            polymerase: polymerase.clone(),
            batch_size,
            map,
            model: vec![],
            min_ml: 0,
            nuc_opts,
            burn_models: m6a_burn::BurnModels::new(&polymerase),
            fake,
        };
        options.add_model().expect("Error loading model");
        options
    }

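    /// Choose the bundled precision table and default minimum ML score for the
    /// detected chemistry. `FT_JSON` supplies a custom precision table,
    /// `FT_MODEL` changes the default minimum ML score, and a user-provided
    /// `min_ml_score` overrides either default.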
    fn get_precision_table_and_ml(&self) -> Result<(Option<PrecisionTable>, u8)> {
        let mut precision_json = "".to_string();
        let min_ml = if let Ok(_file) = std::env::var("FT_MODEL") {
            244
        } else {
            log::info!("Using semi-supervised CNN m6A model.");
            match self.polymerase {
                PbChem::Two => {
                    precision_json = SEMI_JSON_2_0.to_string();
                    230
                }
                PbChem::TwoPointTwo => {
                    precision_json = SEMI_JSON_2_2.to_string();
                    244
                }
                PbChem::ThreePointTwo => {
                    precision_json = SEMI_JSON_3_2.to_string();
                    244
                }
                PbChem::Revio => {
                    precision_json = SEMI_JSON_REVIO.to_string();
                    254
                }
            }
        };

        if let Ok(json) = std::env::var("FT_JSON") {
            log::info!("Loading precision table from environment variable.");
            precision_json =
                std::fs::read_to_string(json).expect("Unable to read file specified by FT_JSON");
        }

        let precision_table: Option<PrecisionTable> = Some(
            serde_json::from_str(&precision_json)
                .expect("Precision table JSON was not well-formatted"),
        );

        let final_min_ml = match self.min_ml_score {
            Some(x) => {
                log::info!("Using provided minimum ML tag score: {}", x);
                x
            }
            None => min_ml,
        };
        Ok((precision_table, final_min_ml))
    }

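    /// Populate the CNN-score-to-ML lookup map from the precision table and
    /// record the minimum ML score used to filter calls.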
    fn add_model(&mut self) -> Result<()> {
        self.model = vec![];

        let (precision_table, min_ml) = self.get_precision_table_and_ml()?;

        if let Some(precision_table) = precision_table {
            for (cnn_score, precision) in precision_table.data {
                self.map.insert(OrderedFloat(cnn_score), precision);
            }
        }

        self.min_ml = min_ml;
        Ok(())
    }

    pub fn progress_style(&self) -> &str {
        "[PREDICTING m6A] [Elapsed {elapsed:.yellow} ETA {eta:.yellow}] {bar:30.cyan/blue} {human_pos:>5.cyan}/{human_len:.blue} (batches/s {per_sec:.green})"
    }

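    /// Convert a raw CNN score into an ML tag value by snapping to the nearest
    /// score in the precision lookup table.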
    pub fn precision_from_float(&self, value: f32) -> u8 {
        let key = OrderedFloat(value);
        let (less_key, less_val) = self
            .map
            .range(..key)
            .next_back()
            .unwrap_or((&OrderedFloat(0.0), &0));
        let (more_key, more_val) = self
            .map
            .range(key..)
            .next()
            .unwrap_or((&OrderedFloat(1.0), &255));
        if (more_key - key).abs() < (less_key - key).abs() {
            *more_val
        } else {
            *less_val
        }
    }

    pub fn min_ml_value(&self) -> u8 {
        if self.all_calls {
            0
        } else {
            self.min_ml
        }
    }

    pub fn float_to_u8(&self, x: f32) -> u8 {
        self.precision_from_float(x)
    }

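    /// Predict m6A for a batch of records: build per-base feature windows, run
    /// the CNN over all windows at once, write MM/ML tags, call nucleosomes,
    /// and (unless `keep` is set) drop the raw kinetics tags. Returns the
    /// number of records that received predictions.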
    pub fn predict_m6a_on_records(
        opts: &Self,
        records: Vec<&mut rust_htslib::bam::Record>,
    ) -> usize {
        let data: Vec<Option<(DataWidows, DataWidows)>> = records
            .iter()
            .map(|rec| get_m6a_data_windows(rec))
            .collect();
        let mut all_ml_data = vec![];
        let mut all_count = 0;
        data.iter().flatten().for_each(|(a, t)| {
            all_ml_data.extend(a.windows.clone());
            all_count += a.count;
            all_ml_data.extend(t.windows.clone());
            all_count += t.count;
        });
        let predictions = opts.apply_model(&all_ml_data, all_count);
        assert_eq!(predictions.len(), all_count);
        assert_eq!(data.len(), records.len());
        let mut cur_predict_st = 0;
        for (option_data, record) in data.iter().zip(records) {
            let mut cur_basemods = basemods::BaseMods::new(record, 0);
            cur_basemods.drop_m6a();
            log::trace!("Number of base mod types {}", cur_basemods.base_mods.len());
            let (a_data, t_data) = match option_data {
                Some((a_data, t_data)) => (a_data, t_data),
                None => continue,
            };
            for data in &[a_data, t_data] {
                let cur_predict_en = cur_predict_st + data.count;
                let cur_predictions = &predictions[cur_predict_st..cur_predict_en];

                cur_predict_st += data.count;
                cur_basemods.base_mods.push(opts.basemod_from_ml(
                    record,
                    cur_predictions,
                    &data.positions,
                    &data.base_mod,
                ));
            }
            cur_basemods.add_mm_and_ml_tags(record);

            let modified_bases_forward = cur_basemods.m6a().get_forward_starts();

            nucleosome::add_nucleosomes_to_record(record, &modified_bases_forward, &opts.nuc_opts);

            if !opts.keep {
                record.remove_aux(b"fp").unwrap_or(());
                record.remove_aux(b"fi").unwrap_or(());
                record.remove_aux(b"rp").unwrap_or(());
                record.remove_aux(b"ri").unwrap_or(());
            }
        }
        assert_eq!(cur_predict_st, predictions.len());
        data.iter().flatten().count()
    }

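    /// Turn per-position CNN scores into a `BaseMod`: scores are mapped to ML
    /// values, calls below the minimum ML or within half a window of the read
    /// ends are dropped, and the modification string (e.g. "A+a") is split into
    /// base, strand, and modification type.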
    pub fn basemod_from_ml(
        &self,
        record: &mut bam::Record,
        predictions: &[f32],
        positions: &[usize],
        base_mod: &str,
    ) -> basemods::BaseMod {
        let min_pos = (WINDOW / 2) as i64;
        let max_pos = (record.seq_len() - WINDOW / 2) as i64;
        let (modified_probabilities_forward, full_probabilities_forward, modified_bases_forward): (
            Vec<u8>,
            Vec<f32>,
            Vec<i64>,
        ) = predictions
            .iter()
            .zip(positions.iter())
            .map(|(&x, &pos)| (self.float_to_u8(x), x, pos as i64))
            .filter(|(ml, _, pos)| *ml >= self.min_ml_value() && *pos >= min_pos && *pos < max_pos)
            .multiunzip();

        log::debug!(
            "Low but non-zero values: {:?}\tZero values: {:?}\tlength:{:?}",
            full_probabilities_forward
                .iter()
                .filter(|&x| *x <= 1.0 / 255.0)
                .filter(|&x| *x > 0.0)
                .count(),
            full_probabilities_forward
                .iter()
                .filter(|&x| *x <= 0.0)
                .filter(|&x| *x > -0.00000001)
                .count(),
            predictions.len()
        );

        let base_mod = base_mod.as_bytes();
        let modified_base = base_mod[0];
        let strand = base_mod[1] as char;
        let modification_type = base_mod[2] as char;

        basemods::BaseMod::new(
            record,
            modified_base,
            strand,
            modification_type,
            modified_bases_forward,
            modified_probabilities_forward,
        )
    }

    pub fn apply_model(&self, windows: &[f32], count: usize) -> Vec<f32> {
        self.burn_models.forward(self, windows, count)
    }

    fn _fake_apply_model(&self, _: &[f32], count: usize) -> Vec<f32> {
        vec![0.0; count]
    }
}

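/// One-hot encode a DNA sequence as a flat, row-major matrix with four rows
/// (A, C, G, T) and `seq.len()` columns; bases other than ACGT are all zeros.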
pub fn hot_one_dna(seq: &[u8]) -> Vec<f32> {
    let len = seq.len() * 4;
    let mut out = vec![0.0; len];
    for (row, base) in [b'A', b'C', b'G', b'T'].into_iter().enumerate() {
        let already_done = seq.len() * row;
        for i in 0..seq.len() {
            if seq[i] == base {
                out[already_done + i] = 1.0;
            }
        }
    }
    out
}

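/// Model input for one strand of a read: the flattened feature windows, the
/// query position of each candidate base, the number of candidates, and the
/// base-modification string ("A+a" or "T-a").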
struct DataWidows {
    pub windows: Vec<f32>,
    pub positions: Vec<usize>,
    pub count: usize,
    pub base_mod: String,
}

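/// Build feature windows for every A and T in a read: A sites use the
/// reverse-complemented sequence window and reverse-strand kinetics, T sites
/// use the forward strand. Returns `None` for secondary alignments or reads
/// without HiFi kinetics tags.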
fn get_m6a_data_windows(record: &bam::Record) -> Option<(DataWidows, DataWidows)> {
    if record.is_secondary() {
        log::warn!(
            "Skipping secondary alignment of {}",
            String::from_utf8_lossy(record.qname())
        );
        return None;
    }

    let extend = WINDOW / 2;
    let mut f_ip = bio_io::get_u8_tag(record, b"fi");
    let r_ip;
    let f_pw;
    let r_pw;
    if f_ip.is_empty() {
        f_ip = bio_io::get_pb_u16_tag_as_u8(record, b"fi");
        if f_ip.is_empty() {
            r_ip = vec![];
            f_pw = vec![];
            r_pw = vec![];
        } else {
            r_ip = bio_io::get_pb_u16_tag_as_u8(record, b"ri");
            f_pw = bio_io::get_pb_u16_tag_as_u8(record, b"fp");
            r_pw = bio_io::get_pb_u16_tag_as_u8(record, b"rp");
        }
    } else {
        r_ip = bio_io::get_u8_tag(record, b"ri");
        f_pw = bio_io::get_u8_tag(record, b"fp");
        r_pw = bio_io::get_u8_tag(record, b"rp");
    }
    if f_ip.is_empty() || r_ip.is_empty() || f_pw.is_empty() || r_pw.is_empty() {
        log::debug!(
            "HiFi kinetics are missing for: {}",
            String::from_utf8_lossy(record.qname())
        );
        return None;
    }
    let r_ip = r_ip.into_iter().rev().collect::<Vec<_>>();
    let r_pw = r_pw.into_iter().rev().collect::<Vec<_>>();

    let mut seq = record.seq().as_bytes();
    if record.is_reverse() {
        seq = revcomp(seq);
    }

    assert_eq!(f_ip.len(), seq.len());
    let mut a_count = 0;
    let mut t_count = 0;
    let mut a_windows = vec![];
    let mut t_windows = vec![];
    let mut a_positions = vec![];
    let mut t_positions = vec![];
    for (pos, base) in seq.iter().enumerate() {
        if !((*base == b'A') || (*base == b'T')) {
            continue;
        }
        let data_window = if (pos < extend) || (pos + extend + 1 > record.seq_len()) {
            vec![0.0; WINDOW * LAYERS]
        } else {
            let start = pos - extend;
            let end = pos + extend + 1;
            let ip: Vec<f32>;
            let pw: Vec<f32>;
            let hot_one;
            if *base == b'A' {
                let w_seq = &revcomp(&seq[start..end]);
                hot_one = hot_one_dna(w_seq);
                ip = (r_ip[start..end])
                    .iter()
                    .copied()
                    .rev()
                    .map(|x| x as f32 / 255.0)
                    .collect();
                pw = (r_pw[start..end])
                    .iter()
                    .copied()
                    .rev()
                    .map(|x| x as f32 / 255.0)
                    .collect();
            } else {
                let w_seq = &seq[start..end];
                hot_one = hot_one_dna(w_seq);
                ip = (f_ip[start..end])
                    .iter()
                    .copied()
                    .map(|x| x as f32 / 255.0)
                    .collect();
                pw = (f_pw[start..end])
                    .iter()
                    .copied()
                    .map(|x| x as f32 / 255.0)
                    .collect();
            }
            let mut data_window = vec![];
            data_window.extend(hot_one);
            data_window.extend(ip);
            data_window.extend(pw);
            data_window
        };

        if *base == b'A' {
            a_windows.extend(data_window);
            a_count += 1;
            a_positions.push(pos);
        } else {
            t_windows.extend(data_window);
            t_count += 1;
            t_positions.push(pos);
        }
    }
    let a_data = DataWidows {
        windows: a_windows,
        positions: a_positions,
        count: a_count,
        base_mod: "A+a".to_string(),
    };
    let t_data = DataWidows {
        windows: t_windows,
        positions: t_positions,
        count: t_count,
        base_mod: "T-a".to_string(),
    };
    Some((a_data, t_data))
}

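/// Driver for m6A prediction: stream the input BAM in chunks, predict m6A on
/// each chunk in parallel batches, add FIRE scores, and write the annotated
/// records to the output BAM.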
pub fn read_bam_into_fiberdata(opts: &mut PredictM6AOptions) {
    let mut bam = opts.input.bam_reader();
    let mut out = opts.input.bam_writer(&opts.out);
    let header = bam::Header::from_template(bam.header());
    log::info!(
        "{} reads included at once in batch prediction.",
        opts.batch_size
    );

    #[cfg(feature = "tch")]
    type MlBackend = burn::backend::LibTorch;
    #[cfg(feature = "tch")]
    log::info!("Using LibTorch for ML backend.");

    #[cfg(not(feature = "tch"))]
    type MlBackend = burn::backend::Candle;
    #[cfg(not(feature = "tch"))]
    log::info!("Using Candle for ML backend.");

    let predict_options: PredictOptions<MlBackend> = PredictOptions::new(
        opts.keep,
        opts.force_min_ml_score,
        opts.all_calls,
        find_pb_polymerase(&header),
        opts.batch_size,
        opts.nuc.clone(),
        opts.fake,
    );
    let fire_opts = crate::cli::FireOptions::default();
    let (model, precision_table) = crate::utils::fire::get_model(&fire_opts);

    let bam_chunk_iter = BamChunk::new(bam.records(), None);
    for mut chunk in bam_chunk_iter {
        let number_of_reads_with_predictions = chunk
            .par_iter_mut()
            .chunks(predict_options.batch_size)
            .map(|records| PredictOptions::predict_m6a_on_records(&predict_options, records))
            .sum::<usize>() as f32;

        let frac_called = number_of_reads_with_predictions / chunk.len() as f32;
        if frac_called < 0.95 {
            log::warn!("More than 5% ({:.2}%) of reads were not predicted on. Are HiFi kinetics missing from this file? Enable Debug logging level to show which reads lack kinetics.", 100.0 - 100.0 * frac_called);
        }

        let mut fd_recs =
            FiberseqData::from_records(chunk, &opts.input.header_view(), &opts.input.filters);
        fd_recs.par_iter_mut().for_each(|fd| {
            crate::subcommands::fire::add_fire_to_rec(fd, &fire_opts, &model, &precision_table);
        });

        fd_recs.iter().for_each(|fd| out.write(&fd.record).unwrap());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_precision_json_validity() {
        for file in [SEMI_JSON_2_0, SEMI_JSON_2_2, SEMI_JSON_3_2, SEMI_JSON_REVIO] {
            let _p: PrecisionTable =
                serde_json::from_str(file).expect("Precision table JSON was not well-formatted");
        }
    }
}