fibertools_rs/subcommands/
predict_m6a.rs1use crate::cli::PredictM6AOptions;
2use crate::utils::basemods;
3use crate::utils::bio_io;
4use crate::utils::nucleosome;
5use crate::*;
6use bio::alphabets::dna::revcomp;
7use burn::tensor::backend::Backend;
8use fiber::FiberseqData;
9use ordered_float::OrderedFloat;
10use rayon::iter::IndexedParallelIterator;
11use rayon::iter::ParallelIterator;
12use rayon::prelude::IntoParallelRefMutIterator;
13use rust_htslib::{bam, bam::Read};
14use serde::Deserialize;
15use std::collections::BTreeMap;
16use std::sync::Once;
17
/// Number of bases in each feature window; windows are centred on the
/// candidate base and extend `WINDOW / 2` bases to either side.
pub const WINDOW: usize = 15;
/// Number of feature rows per window: four one-hot base rows plus one row each
/// for the two kinetics tracks appended in `get_m6a_data_windows`.
pub const LAYERS: usize = 6;
/// Lower bound for an `f32` prediction; not referenced in this part of the file.
// NOTE(review): 1.0e-46 is smaller than the smallest positive subnormal `f32`
// (~1.4e-45), so this literal presumably rounds to 0.0 — confirm this is intended.
pub const MIN_F32_PRED: f32 = 1.0e-46;
/// Guards the "which model is in use" log line so it is emitted only once.
static LOG_ONCE: Once = Once::new();
/// Precision table JSON embedded at compile time, selected for `PbChem::Two`.
pub static SEMI_JSON_2_0: &str = include_str!("../../models/2.0_semi_torch.json");
/// Precision table JSON embedded at compile time, selected for `PbChem::TwoPointTwo`.
pub static SEMI_JSON_2_2: &str = include_str!("../../models/2.2_semi_torch.json");
/// Precision table JSON embedded at compile time, selected for `PbChem::ThreePointTwo`.
pub static SEMI_JSON_3_2: &str = include_str!("../../models/3.2_semi_torch.json");
/// Precision table JSON embedded at compile time, selected for `PbChem::Revio`.
pub static SEMI_JSON_REVIO: &str = include_str!("../../models/Revio_semi_torch.json");
27
/// Precision lookup table as deserialized from one of the model JSON files.
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    /// Column names stored alongside the data; not read anywhere in this file.
    pub columns: Vec<String>,
    /// Rows of `(cnn_score, precision)` pairs; `add_model` inserts each pair
    /// into the score-to-precision map.
    pub data: Vec<(f32, u8)>,
}
33
/// Settings and loaded model state used while predicting m6A on BAM records.
#[derive(Debug, Clone)]
pub struct PredictOptions<B>
where
    B: Backend<Device = m6a_burn::BurnDevice>,
{
    /// When `false`, the `fp`, `fi`, `rp` and `ri` tags are removed from each
    /// record after prediction.
    pub keep: bool,
    /// Caller-supplied minimum ML score; when set it replaces the per-chemistry default.
    pub min_ml_score: Option<u8>,
    /// When `true`, `min_ml_value` returns 0 so no call is filtered by score.
    pub all_calls: bool,
    /// Sequencing chemistry; selects the embedded precision table and default threshold.
    pub polymerase: PbChem,
    /// Number of records grouped into one prediction batch by the caller.
    pub batch_size: usize,
    /// Model score -> precision (as `u8`) lookup, seeded with `0.0 -> 0`.
    map: BTreeMap<OrderedFloat<f32>, u8>,
    /// Raw model bytes; this file only ever sets it to an empty vector.
    pub model: Vec<u8>,
    /// Effective minimum ML score, filled in by `add_model`.
    pub min_ml: u8,
    /// Options forwarded to `nucleosome::add_nucleosomes_to_record`.
    pub nuc_opts: cli::NucleosomeParameters,
    /// Models used by `apply_model` to score windows.
    pub burn_models: m6a_burn::BurnModels<B>,
    /// Stored and copied into per-batch options; not read by the methods in this
    /// file. NOTE(review): presumably consumed by `m6a_burn` — confirm.
    pub fake: bool,
}
51
52impl<B> PredictOptions<B>
53where
54 B: Backend<Device = m6a_burn::BurnDevice>,
55{
56 #[allow(clippy::too_many_arguments)]
57 pub fn new(
58 keep: bool,
59 min_ml_score: Option<u8>,
60 all_calls: bool,
61 polymerase: PbChem,
62 batch_size: usize,
63 nuc_opts: cli::NucleosomeParameters,
64 fake: bool,
65 ) -> Self {
66 let mut map = BTreeMap::new();
68 map.insert(OrderedFloat(0.0), 0);
69
70 let mut options = PredictOptions {
72 keep,
73 min_ml_score,
74 all_calls,
75 polymerase: polymerase.clone(),
76 batch_size,
77 map,
78 model: vec![],
79 min_ml: 0,
80 nuc_opts,
81 burn_models: m6a_burn::BurnModels::new(&polymerase),
82 fake,
83 };
84 options.add_model().expect("Error loading model");
85 options
86 }
87
88 fn get_precision_table_and_ml(&self) -> Result<(Option<PrecisionTable>, u8)> {
89 let mut precision_json = "".to_string();
90 let min_ml = if let Ok(_file) = std::env::var("FT_MODEL") {
91 244
92 } else {
93 LOG_ONCE.call_once(|| {
94 log::info!("Using semi-supervised CNN m6A model.");
95 });
96 match self.polymerase {
97 PbChem::Two => {
98 precision_json = SEMI_JSON_2_0.to_string();
99 230
100 }
101 PbChem::TwoPointTwo => {
102 precision_json = SEMI_JSON_2_2.to_string();
103 244
104 }
105 PbChem::ThreePointTwo => {
106 precision_json = SEMI_JSON_3_2.to_string();
107 244
108 }
109 PbChem::Revio => {
110 precision_json = SEMI_JSON_REVIO.to_string();
111 254
112 }
113 }
114 };
115
116 if let Ok(json) = std::env::var("FT_JSON") {
118 log::info!("Loading precision table from environment variable.");
119 precision_json =
120 std::fs::read_to_string(json).expect("Unable to read file specified by FT_JSON");
121 }
122
123 let precision_table: Option<PrecisionTable> = Some(
125 serde_json::from_str(&precision_json)
126 .expect("Precision table JSON was not well-formatted"),
127 );
128
129 let final_min_ml = match self.min_ml_score {
131 Some(x) => {
132 log::info!("Using provided minimum ML tag score: {x}");
133 x
134 }
135 None => min_ml,
136 };
137 Ok((precision_table, final_min_ml))
138 }
139
140 fn add_model(&mut self) -> Result<()> {
141 self.model = vec![];
142
143 let (precision_table, min_ml) = self.get_precision_table_and_ml()?;
144
145 if let Some(precision_table) = precision_table {
147 for (cnn_score, precision) in precision_table.data {
148 self.map.insert(OrderedFloat(cnn_score), precision);
149 }
150 }
151
152 self.min_ml = min_ml;
153 Ok(())
154 }
155
156 pub fn progress_style(&self) -> &str {
157 "[PREDICTING m6A] [Elapsed {elapsed:.yellow} ETA {eta:.yellow}] {bar:30.cyan/blue} {human_pos:>5.cyan}/{human_len:.blue} (batches/s {per_sec:.green})"
159 }
160
161 pub fn precision_from_float(&self, value: f32) -> u8 {
163 let key = OrderedFloat(value);
164 let (less_key, less_val) = self
166 .map
167 .range(..key)
168 .next_back()
169 .unwrap_or((&OrderedFloat(0.0), &0));
170 let (more_key, more_val) = self
172 .map
173 .range(key..)
174 .next()
175 .unwrap_or((&OrderedFloat(1.0), &255));
176 if (more_key - key).abs() < (less_key - key).abs() {
177 *more_val
178 } else {
179 *less_val
180 }
181 }
182
183 pub fn min_ml_value(&self) -> u8 {
184 if self.all_calls {
185 0
186 } else {
187 self.min_ml
188 }
189 }
190
    /// Converts a model score to its `u8` precision; a thin alias for
    /// `precision_from_float`.
    pub fn float_to_u8(&self, x: f32) -> u8 {
        self.precision_from_float(x)
    }
194
195 pub fn predict_m6a_on_records(
197 opts: &Self,
198 records: Vec<&mut rust_htslib::bam::Record>,
199 ) -> usize {
201 let data: Vec<Option<(DataWidows, DataWidows)>> = records
203 .iter()
204 .map(|rec| get_m6a_data_windows(rec))
205 .collect();
206 let mut all_ml_data = vec![];
208 let mut all_count = 0;
209 data.iter().flatten().for_each(|(a, t)| {
210 all_ml_data.extend(a.windows.clone());
211 all_count += a.count;
212 all_ml_data.extend(t.windows.clone());
213 all_count += t.count;
214 });
215 let predictions = opts.apply_model(&all_ml_data, all_count);
216 assert_eq!(predictions.len(), all_count);
217 assert_eq!(data.len(), records.len());
219 let mut cur_predict_st = 0;
220 for (option_data, record) in data.iter().zip(records) {
221 let mut cur_basemods = basemods::BaseMods::new(record, 0);
223 cur_basemods.drop_m6a();
224 log::trace!("Number of base mod types {}", cur_basemods.base_mods.len());
225 let (a_data, t_data) = match option_data {
227 Some((a_data, t_data)) => (a_data, t_data),
228 None => continue,
229 };
230 for data in &[a_data, t_data] {
232 let cur_predict_en = cur_predict_st + data.count;
233 let cur_predictions = &predictions[cur_predict_st..cur_predict_en];
234
235 cur_predict_st += data.count;
236 cur_basemods.base_mods.push(opts.basemod_from_ml(
237 record,
238 cur_predictions,
239 &data.positions,
240 &data.base_mod,
241 ));
242 }
243 cur_basemods.add_mm_and_ml_tags(record);
245
246 let modified_bases_forward = cur_basemods.m6a().get_forward_starts();
248
249 nucleosome::add_nucleosomes_to_record(record, &modified_bases_forward, &opts.nuc_opts);
251
252 if !opts.keep {
254 record.remove_aux(b"fp").unwrap_or(());
255 record.remove_aux(b"fi").unwrap_or(());
256 record.remove_aux(b"rp").unwrap_or(());
257 record.remove_aux(b"ri").unwrap_or(());
258 }
259 }
260 assert_eq!(cur_predict_st, predictions.len());
261 data.iter().flatten().count()
262 }
263
264 pub fn basemod_from_ml(
266 &self,
267 record: &mut bam::Record,
268 predictions: &[f32],
269 positions: &[usize],
270 base_mod: &str,
271 ) -> basemods::BaseMod {
272 let min_pos = (WINDOW / 2) as i64;
274 let max_pos = (record.seq_len() - WINDOW / 2) as i64;
275 let (modified_probabilities_forward, full_probabilities_forward, modified_bases_forward): (
276 Vec<u8>,
277 Vec<f32>,
278 Vec<i64>,
279 ) = predictions
280 .iter()
281 .zip(positions.iter())
282 .map(|(&x, &pos)| (self.float_to_u8(x), x, pos as i64))
283 .filter(|(ml, _, pos)| *ml >= self.min_ml_value() && *pos >= min_pos && *pos < max_pos)
284 .multiunzip();
285
286 log::debug!(
287 "Low but non zero values: {:?}\tZero values: {:?}\tlength:{:?}",
288 full_probabilities_forward
289 .iter()
290 .filter(|&x| *x <= 1.0 / 255.0)
291 .filter(|&x| *x > 0.0)
292 .count(),
293 full_probabilities_forward
294 .iter()
295 .filter(|&x| *x <= 0.0)
296 .filter(|&x| *x > -0.00000001)
297 .count(),
298 predictions.len()
299 );
300
301 let base_mod = base_mod.as_bytes();
302 let modified_base = base_mod[0];
303 let strand = base_mod[1] as char;
304 let modification_type = base_mod[2] as char;
305
306 basemods::BaseMod::new(
307 record,
308 modified_base,
309 strand,
310 modification_type,
311 modified_bases_forward,
312 modified_probabilities_forward,
313 )
314 }
315
    /// Scores the flattened feature `windows` by delegating to
    /// `self.burn_models.forward`, passing these options along.
    ///
    /// `count` is the number of windows contained in `windows`; the caller in
    /// this file asserts that the returned vector has exactly `count` entries.
    pub fn apply_model(&self, windows: &[f32], count: usize) -> Vec<f32> {
        self.burn_models.forward(self, windows, count)
    }
319
320 fn _fake_apply_model(&self, _: &[f32], count: usize) -> Vec<f32> {
321 vec![0.0; count]
322 }
323}
324
/// One-hot encodes a DNA sequence into a flat, row-major `4 x seq.len()` matrix.
///
/// Rows are ordered `A`, `C`, `G`, `T`; element `row * seq.len() + i` is `1.0`
/// when `seq[i]` is that row's base and `0.0` otherwise. Any byte other than
/// upper-case `A`, `C`, `G` or `T` leaves its whole column at zero.
pub fn hot_one_dna(seq: &[u8]) -> Vec<f32> {
    let n = seq.len();
    let mut out = vec![0.0; n * 4];
    // Single pass over the sequence instead of one indexed scan per base.
    for (i, &base) in seq.iter().enumerate() {
        let row = match base {
            b'A' => 0,
            b'C' => 1,
            b'G' => 2,
            b'T' => 3,
            _ => continue,
        };
        out[row * n + i] = 1.0;
    }
    out
}
350
/// Feature windows collected for one base type of a single read.
// NOTE(review): the name looks like a typo for "DataWindows"; it is left as is
// because the type is referenced by that name elsewhere in this file.
struct DataWidows {
    /// All windows concatenated; each window contributes `WINDOW * LAYERS` values.
    pub windows: Vec<f32>,
    /// Index of each window's centre base in the sequence that was scanned.
    pub positions: Vec<usize>,
    /// Number of windows stored in `windows`.
    pub count: usize,
    /// Three-character base modification code, `"A+a"` or `"T-a"`.
    pub base_mod: String,
}
357
358fn get_m6a_data_windows(record: &bam::Record) -> Option<(DataWidows, DataWidows)> {
359 if record.is_secondary() {
361 log::warn!(
362 "Skipping secondary alignment of {}",
363 String::from_utf8_lossy(record.qname())
364 );
365 return None;
366 }
367
368 let extend = WINDOW / 2;
369 let mut f_ip = bio_io::get_u8_tag(record, b"fi");
370 let r_ip;
371 let f_pw;
372 let r_pw;
373 if f_ip.is_empty() {
375 f_ip = bio_io::get_pb_u16_tag_as_u8(record, b"fi");
376 if f_ip.is_empty() {
377 r_ip = vec![];
379 f_pw = vec![];
380 r_pw = vec![];
381 } else {
382 r_ip = bio_io::get_pb_u16_tag_as_u8(record, b"ri");
383 f_pw = bio_io::get_pb_u16_tag_as_u8(record, b"fp");
384 r_pw = bio_io::get_pb_u16_tag_as_u8(record, b"rp");
385 }
386 } else {
387 r_ip = bio_io::get_u8_tag(record, b"ri");
388 f_pw = bio_io::get_u8_tag(record, b"fp");
389 r_pw = bio_io::get_u8_tag(record, b"rp");
390 }
391 if f_ip.is_empty() || r_ip.is_empty() || f_pw.is_empty() || r_pw.is_empty() {
393 log::debug!(
394 "Hifi kinetics are missing for: {}",
395 String::from_utf8_lossy(record.qname())
396 );
397 return None;
398 }
399 let r_ip = r_ip.into_iter().rev().collect::<Vec<_>>();
401 let r_pw = r_pw.into_iter().rev().collect::<Vec<_>>();
402
403 let mut seq = record.seq().as_bytes();
404 if record.is_reverse() {
405 seq = revcomp(seq);
406 }
407
408 assert_eq!(f_ip.len(), seq.len());
409 let mut a_count = 0;
410 let mut t_count = 0;
411 let mut a_windows = vec![];
412 let mut t_windows = vec![];
413 let mut a_positions = vec![];
414 let mut t_positions = vec![];
415 for (pos, base) in seq.iter().enumerate() {
416 if !((*base == b'A') || (*base == b'T')) {
417 continue;
418 }
419 let data_window = if (pos < extend) || (pos + extend + 1 > record.seq_len()) {
421 vec![0.0; WINDOW * LAYERS]
423 } else {
424 let start = pos - extend;
425 let end = pos + extend + 1;
426 let ip: Vec<f32>;
427 let pw: Vec<f32>;
428 let hot_one;
429 if *base == b'A' {
430 let w_seq = &revcomp(&seq[start..end]);
431 hot_one = hot_one_dna(w_seq);
432 ip = (r_ip[start..end])
433 .iter()
434 .copied()
435 .rev()
436 .map(|x| x as f32 / 255.0)
437 .collect();
438 pw = (r_pw[start..end])
439 .iter()
440 .copied()
441 .rev()
442 .map(|x| x as f32 / 255.0)
443 .collect();
444 } else {
445 let w_seq = &seq[start..end];
446 hot_one = hot_one_dna(w_seq);
447 ip = (f_ip[start..end])
448 .iter()
449 .copied()
450 .map(|x| x as f32 / 255.0)
451 .collect();
452 pw = (f_pw[start..end])
453 .iter()
454 .copied()
455 .map(|x| x as f32 / 255.0)
456 .collect();
457 }
458 let mut data_window = vec![];
459 data_window.extend(hot_one);
460 data_window.extend(ip);
461 data_window.extend(pw);
462 data_window
463 };
464
465 if *base == b'A' {
467 a_windows.extend(data_window);
468 a_count += 1;
469 a_positions.push(pos);
470 } else {
471 t_windows.extend(data_window);
472 t_count += 1;
473 t_positions.push(pos);
474 }
475 }
476 let a_data = DataWidows {
477 windows: a_windows,
478 positions: a_positions,
479 count: a_count,
480 base_mod: "A+a".to_string(),
481 };
482 let t_data = DataWidows {
483 windows: t_windows,
484 positions: t_positions,
485 count: t_count,
486 base_mod: "T-a".to_string(),
487 };
488 Some((a_data, t_data))
489}
490
491pub fn read_bam_into_fiberdata(opts: &mut PredictM6AOptions) {
492 let mut bam = opts.input.bam_reader();
493 let mut out = opts.input.bam_writer(&opts.out);
494 let header = bam::Header::from_template(bam.header());
495 log::info!(
497 "{} reads included at once in batch prediction.",
498 opts.batch_size
499 );
500
501 #[cfg(feature = "tch")]
502 type MlBackend = burn::backend::LibTorch;
503 #[cfg(feature = "tch")]
504 log::info!("Using LibTorch for ML backend.");
505
506 #[cfg(not(feature = "tch"))]
507 type MlBackend = burn::backend::Candle;
508 #[cfg(not(feature = "tch"))]
509 log::info!("Using Candle for ML backend.");
510
511 let predict_options: PredictOptions<MlBackend> = PredictOptions::new(
513 opts.keep,
514 opts.force_min_ml_score,
515 opts.all_calls,
516 find_pb_polymerase(&header),
517 opts.batch_size,
518 opts.nuc.clone(),
519 opts.fake,
520 );
521 let fire_opts = crate::cli::FireOptions::default();
523 let (model, precision_table) = crate::utils::fire::get_model(&fire_opts);
524
525 let bam_chunk_iter = BamChunk::new(bam.records(), None);
527 for mut chunk in bam_chunk_iter {
529 let number_of_reads_with_predictions = chunk
532 .par_iter_mut()
533 .chunks(predict_options.batch_size)
534 .map(|records| {
535 let thread_opts = PredictOptions::<MlBackend>::new(
537 predict_options.all_calls,
538 predict_options.min_ml_score,
539 predict_options.all_calls,
540 predict_options.polymerase.clone(),
541 predict_options.batch_size,
542 predict_options.nuc_opts.clone(),
543 predict_options.fake,
544 );
545 PredictOptions::predict_m6a_on_records(&thread_opts, records)
546 })
547 .sum::<usize>() as f32;
548
549 let frac_called = number_of_reads_with_predictions / chunk.len() as f32;
550 if frac_called < 0.05 {
551 log::warn!("More than 5% ({:.2}%) of reads were not predicted on. Are HiFi kinetics missing from this file? Enable Debug logging level to show which reads lack kinetics.", 100.0-100.0*frac_called);
552 }
553
554 let mut fd_recs =
556 FiberseqData::from_records(chunk, &opts.input.header_view(), &opts.input.filters);
557 fd_recs.par_iter_mut().for_each(|fd| {
558 crate::subcommands::fire::add_fire_to_rec(fd, &fire_opts, &model, &precision_table);
559 });
560
561 fd_recs.iter().for_each(|fd| out.write(&fd.record).unwrap());
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Every embedded precision table must deserialize into a `PrecisionTable`.
    #[test]
    fn test_precision_json_validity() {
        [SEMI_JSON_2_0, SEMI_JSON_2_2, SEMI_JSON_3_2, SEMI_JSON_REVIO]
            .iter()
            .for_each(|json| {
                serde_json::from_str::<PrecisionTable>(json)
                    .expect("Precision table JSON was not well-formatted");
            });
    }
}
577}