1use crate::cli::FireOptions;
2use crate::fiber::FiberseqData;
3use crate::*;
4use anyhow;
5use derive_builder::Builder;
6use gbdt::decision_tree::{Data, DataVec};
7use gbdt::gradient_boost::GBDT;
8use itertools::Itertools;
9use ordered_float::OrderedFloat;
10use serde::Deserialize;
11use std::collections::BTreeMap;
12use std::fs;
13use tempfile::NamedTempFile;
14
/// Built-in FIRE model, embedded at compile time; `get_model` loads it with
/// `GBDT::from_xgboost_dump` when no custom model/FDR table pair is supplied.
pub static FIRE_MODEL: &str = include_str!("../../models/FIRE.gbdt.json");
/// Built-in precision table for [`FIRE_MODEL`], embedded at compile time; `get_model`
/// parses it as a [`PrecisionTable`] when no custom model/FDR table pair is supplied.
pub static FIRE_CONF_JSON: &str = include_str!("../../models/FIRE.conf.json");
17
18pub fn get_model(fire_opts: &FireOptions) -> (GBDT, MapPrecisionValues) {
19 let mut remove_temp_file = false;
20 let (model_file, fdr_table_file) = match (&fire_opts.model, &fire_opts.fdr_table) {
22 (Some(model_file), Some(b)) => {
23 let fdr_table = fs::read_to_string(b).expect("Unable to read file");
24 (model_file.clone(), fdr_table)
25 }
26 _ => {
27 let temp_file = NamedTempFile::new().expect("Unable to make a temp file");
28 let (mut temp_file, path) = temp_file.keep().expect("Unable to keep temp file");
29 let temp_file_name = path
30 .as_os_str()
31 .to_str()
32 .expect("Unable to convert the path of the named temp file to an &str.");
33 temp_file
34 .write_all(FIRE_MODEL.as_bytes())
35 .expect("Unable to write file");
36 remove_temp_file = true;
38 (temp_file_name.to_string(), FIRE_CONF_JSON.to_string())
39 }
40 };
41 log::info!("Using model: {model_file}");
42 let model =
44 GBDT::from_xgboost_dump(&model_file, "binary:logistic").expect("failed to load FIRE model");
45 if remove_temp_file {
46 fs::remove_file(model_file).expect("Unable to remove temp file");
47 }
48
49 let precision_table: PrecisionTable =
51 serde_json::from_str(&fdr_table_file).expect("Precision table JSON was not well-formatted");
52 let precision_converter = MapPrecisionValues::new(&precision_table);
53
54 (model, precision_converter)
56}
57
/// Integer midpoint of `start` and `end`: `(start + end) / 2`, truncated toward zero.
fn get_mid_point(start: i64, end: i64) -> i64 {
    let total = start + end;
    total / 2
}
61
/// Lay out `bin_num` consecutive bins of width `bin_width` around `mid_point`.
///
/// Bin `bin_num / 2` is centred on `mid_point`; the others tile outward on either side.
/// Each `(start, end)` pair is clamped to the read:
/// - a start below 0 becomes 0, and an end below 0 becomes 0;
/// - a start beyond `max_end` becomes `max_end - 1`;
/// - an end beyond `max_end` becomes `max_end`.
///
/// A non-positive `bin_num` yields no bins.
pub fn get_bins(mid_point: i64, bin_num: i64, bin_width: i64, max_end: i64) -> Vec<(i64, i64)> {
    let half_width = bin_width / 2;
    let center_bin = bin_num / 2;
    (0..bin_num)
        .map(|idx| {
            let raw_start = mid_point - (center_bin - idx) * bin_width - half_width;
            // The end is derived from the unclamped start, then clamped on its own.
            let raw_end = raw_start + bin_width;
            let start = raw_start.max(0);
            let start = if start > max_end { max_end - 1 } else { start };
            let end = raw_end.max(0).min(max_end);
            (start, end)
        })
        .collect()
}
88
89fn get_m6a_rle_data(rec: &FiberseqData, start: i64, end: i64) -> (f32, f32) {
91 let mut m6a_rles = vec![];
92 let mut max = 0;
93 let mut _max_pos = 0;
94 let mut weighted_rle = 0.0;
96 for (m6a_1, m6a_2) in rec.m6a.starts().iter().tuple_windows() {
97 if *m6a_1 < start || *m6a_1 > end || *m6a_2 < start || *m6a_2 > end {
99 continue;
100 }
101 let rle = (m6a_2 - m6a_1).abs();
103 m6a_rles.push(rle);
104 if rle > max {
106 max = rle;
107 _max_pos = ((*m6a_1 + *m6a_2) / 2 - (end + start) / 2).abs();
108 }
109 weighted_rle += (rle * rle) as f32;
111 }
112 weighted_rle /= (end - start) as f32;
113
114 if m6a_rles.is_empty() {
115 return (-1.0, -1.0);
116 }
117 let mid_length = m6a_rles.len() / 2;
118 let (_, median, _) = m6a_rles.select_nth_unstable(mid_length);
119 (weighted_rle, *median as f32)
120}
121
/// Per-window features fed to the FIRE model, in the order
/// `FireFeats::msp_get_fire_features` pushes them for every window.
const FEATS_IN_USE: [&str; 3] = ["m6a_count", "frac_m6a", "m6a_fc"];
133#[derive(Debug, Clone, Builder)]
134pub struct FireFeatsInRange {
135 pub m6a_count: f32,
136 pub at_count: f32,
137 #[allow(unused)]
138 pub count_5mc: f32,
139 pub frac_m6a: f32,
140 pub m6a_fc: f32,
141 pub weighted_m6a_rle: f32,
142 pub median_m6a_rle: f32,
143}
144
145impl FireFeatsInRange {
146 pub fn header(tag: &str) -> String {
147 let mut out = "".to_string();
148 for col in FEATS_IN_USE.iter() {
149 out += &format!("\t{tag}_{col}");
150 }
151 out
152 }
153}
154
155#[derive(Debug)]
156pub struct FireFeats<'a> {
157 rec: &'a FiberseqData,
158 #[allow(unused)]
159 at_count: usize,
160 m6a_count: usize,
161 frac_m6a: f32,
162 fire_opts: &'a FireOptions,
164 seq: Vec<u8>,
165 fire_feats: Vec<(i64, i64, Vec<f32>)>,
166}
167
168impl<'a> FireFeats<'a> {
169 pub fn new(rec: &'a FiberseqData, fire_opts: &'a FireOptions) -> Self {
170 let seq_len = rec.record.seq_len();
171 let seq = rec.record.seq().as_bytes();
172
173 let mut rtn = Self {
174 rec,
175 at_count: 0,
176 m6a_count: 0,
177 frac_m6a: 0.0,
178 fire_opts,
180 seq,
181 fire_feats: vec![],
182 };
183
184 rtn.at_count = rtn.get_at_count(0, seq_len as i64);
186 rtn.m6a_count = rtn.get_m6a_count(0, seq_len as i64);
187 rtn.frac_m6a = if rtn.at_count > 0 {
188 rtn.m6a_count as f32 / rtn.at_count as f32
189 } else {
190 0.0
191 };
192
193 rtn.get_fire_features();
194 if rtn.fire_opts.ont {
195 rtn.validate_that_ont_is_single_strand();
196 }
197 rtn
198 }
199
200 fn validate_that_ont_is_single_strand(&self) {
201 let sequenced_bp = if self.rec.record.is_reverse() {
202 b'T'
203 } else {
204 b'A'
205 };
206 for m6a_st in self.rec.m6a.starts().iter() {
207 let m6a_bp = self.seq[*m6a_st as usize];
208 if m6a_bp != sequenced_bp {
209 log::warn!(
210 "m6A site at {} is not the same as the sequenced base {}",
211 m6a_st,
212 sequenced_bp as char
213 );
214 }
215 }
216 }
217
218 fn get_bp_count(&self, start: i64, end: i64, bp: u8) -> usize {
219 let subseq = &self.seq[start as usize..end as usize];
220 subseq.iter().filter(|&&b| b == bp).count()
221 }
222
223 fn get_at_count(&self, start: i64, end: i64) -> usize {
224 self.get_bp_count(start, end, b'A') + self.get_bp_count(start, end, b'T')
225 }
226
227 fn get_5mc_count(&self, start: i64, end: i64) -> usize {
228 self.rec
229 .cpg
230 .starts()
231 .iter()
232 .filter(|&&pos| pos >= start && pos < end)
233 .count()
234 }
235
236 fn get_m6a_count(&self, start: i64, end: i64) -> usize {
237 let mut m6a_count = self
238 .rec
239 .m6a
240 .starts()
241 .iter()
242 .filter(|&&pos| pos >= start && pos < end)
243 .count();
244
245 if self.fire_opts.ont {
247 let mut sequenced_bp = self.get_bp_count(start, end, b'A');
248 let mut un_sequenced_bp = self.get_bp_count(start, end, b'T');
249 if self.rec.record.is_reverse() {
250 std::mem::swap(&mut sequenced_bp, &mut un_sequenced_bp);
252 }
253 let m6a_frac = if sequenced_bp > 0 {
254 m6a_count as f32 / sequenced_bp as f32
255 } else {
256 0.0
257 };
258 m6a_count += (un_sequenced_bp as f32 * m6a_frac).round() as usize;
259 }
260 m6a_count
261 }
262
263 fn m6a_fc_over_expected(&self, m6a_count: usize, at_count: usize) -> f32 {
264 let expected = self.frac_m6a * at_count as f32;
267 let observed = m6a_count as f32;
268 if expected == 0.0 || observed == 0.0 {
269 return 0.0;
270 }
271 let fc = observed / expected;
272 fc.log2()
273 }
274
275 fn feats_in_range(&self, start: i64, end: i64) -> FireFeatsInRange {
276 let m6a_count = self.get_m6a_count(start, end);
277 let at_count = self.get_at_count(start, end);
278 let count_5mc = self.get_5mc_count(start, end);
279 let frac_m6a = if at_count > 0 {
280 m6a_count as f32 / at_count as f32
281 } else {
282 0.0
283 };
284 let m6a_fc = self.m6a_fc_over_expected(m6a_count, at_count);
285 let (weighted_m6a_rle, median_m6a_rle) = get_m6a_rle_data(self.rec, start, end);
286
287 FireFeatsInRange {
288 m6a_count: m6a_count as f32,
289 at_count: at_count as f32,
290 count_5mc: count_5mc as f32,
291 frac_m6a,
292 m6a_fc,
293 weighted_m6a_rle,
294 median_m6a_rle,
295 }
296 }
297
298 pub fn fire_feats_header(fire_opts: &FireOptions) -> String {
299 let mut out = "#chrom\tstart\tend\tfiber".to_string();
300 out += "\tmsp_len\tmsp_len_times_m6a_fc\tccs_passes";
301 out += "\tfiber_m6a_count\tfiber_m6a_frac";
302 out += &FireFeatsInRange::header("msp");
303 out += &FireFeatsInRange::header("best");
304 out += &FireFeatsInRange::header("worst");
305 for bin_num in 0..fire_opts.bin_num {
306 out += &FireFeatsInRange::header(&format!("bin_{bin_num}"));
307 }
308 out += "\n";
309 out
310 }
311
312 fn msp_get_fire_features(&self, start: i64, end: i64) -> Vec<f32> {
313 let msp_len = end - start;
314 if msp_len < self.fire_opts.min_msp_length_for_positive_fire_call {
316 return vec![];
317 }
318 let ccs_passes = if self.fire_opts.ont { 4.0 } else { self.rec.ec };
319
320 let mut max_m6a_count = 0;
322 let mut max_m6a_start = 0;
323 let mut max_m6a_end = 0;
324 let mut min_m6a_count = usize::MAX;
325 let mut min_m6a_start = 0;
326 let mut min_m6a_end = 0;
327 let mut centering_pos = get_mid_point(start, end);
328 for st_idx in start..end {
329 let en_idx = (st_idx + self.fire_opts.best_window_size).min(end);
330 if (end - start) < (2 * self.fire_opts.best_window_size) {
332 log::trace!("MSP window is not large enough for best and worst window analysis");
333 break;
334 }
335 let m6a_count = self.get_m6a_count(st_idx, en_idx);
336 if m6a_count > max_m6a_count {
337 max_m6a_count = m6a_count;
338 max_m6a_start = st_idx;
339 max_m6a_end = en_idx;
340 centering_pos = get_mid_point(st_idx, en_idx);
342 }
343 if m6a_count < min_m6a_count {
344 min_m6a_count = m6a_count;
345 min_m6a_start = st_idx;
346 min_m6a_end = en_idx;
347 }
348 if en_idx == end {
349 break;
350 }
351 }
352 let best_fire_feats = self.feats_in_range(max_m6a_start, max_m6a_end);
353 let worst_fire_feats = self.feats_in_range(min_m6a_start, min_m6a_end);
354
355 let msp_feats = self.feats_in_range(start, end);
356 let bins = get_bins(
357 centering_pos,
358 self.fire_opts.bin_num,
359 self.fire_opts.width_bin,
360 self.rec.record.seq_len() as i64,
361 );
362 let bin_feats = bins
363 .into_iter()
364 .map(|(start, end)| self.feats_in_range(start, end))
365 .collect::<Vec<FireFeatsInRange>>();
366 let msp_len_times_m6a_fc = msp_feats.m6a_fc * (msp_len as f32);
367 let mut rtn = vec![
368 msp_len as f32,
369 msp_len_times_m6a_fc,
370 ccs_passes,
371 self.m6a_count as f32,
372 self.frac_m6a,
373 ];
374 let feat_sets = vec![&msp_feats, &best_fire_feats, &worst_fire_feats]
375 .into_iter()
376 .chain(bin_feats.iter());
377 for feat_set in feat_sets {
378 rtn.push(feat_set.m6a_count);
379 rtn.push(feat_set.frac_m6a);
382 rtn.push(feat_set.m6a_fc);
383 }
386 rtn
387 }
388
389 pub fn get_fire_features(&mut self) {
390 let msp_data = self.rec.msp.into_iter().collect_vec();
391 self.fire_feats = msp_data
392 .into_iter()
393 .map(|annotation| {
394 let s = annotation.start;
395 let e = annotation.end;
396 let (rs, re, _rl) = match (
397 annotation.reference_start,
398 annotation.reference_end,
399 annotation.reference_length,
400 ) {
401 (Some(rs), Some(re), Some(rl)) => (rs, re, rl),
402 _ => (0, 0, 0),
403 };
404 (rs, re, self.msp_get_fire_features(s, e))
405 })
406 .collect();
407 }
408
409 pub fn dump_fire_feats(&self, out_buffer: &mut Box<dyn Write>) -> Result<(), anyhow::Error> {
410 for (s, e, row) in self.fire_feats.iter() {
411 if row.is_empty() {
412 continue;
413 }
414 let lead_feats = format!(
415 "{}\t{}\t{}\t{}\t",
416 self.rec.target_name,
417 s,
418 e,
419 String::from_utf8_lossy(self.rec.record.qname())
420 );
421 out_buffer.write_all(lead_feats.as_bytes())?;
422 out_buffer.write_all(row.iter().join("\t").as_bytes())?;
423 out_buffer.write_all(b"\n")?;
424 }
425 Ok(())
426 }
427
428 pub fn predict_with_xgb(
429 &self,
430 gbdt_model: &GBDT,
431 precision_converter: &MapPrecisionValues,
432 ) -> Vec<u8> {
433 let count = self.fire_feats.len();
434 if count == 0 {
435 return vec![];
436 }
437 let mut gbdt_data: DataVec = Vec::new();
439 for (_st, _en, window) in self.fire_feats.iter() {
440 if window.is_empty() {
441 continue;
442 }
443 let d = Data::new_test_data(window.to_vec(), None);
444 gbdt_data.push(d);
445 }
446 let predictions_without_short_ones = gbdt_model.predict(&gbdt_data);
447
448 let mut precisions = Vec::with_capacity(count);
450 let mut cur_pos = 0;
451 for (_st, _en, window) in self.fire_feats.iter() {
452 if window.is_empty() {
453 precisions.push(0);
454 } else {
455 let precision = precision_converter
456 .precision_from_float(predictions_without_short_ones[cur_pos]);
457 precisions.push(precision);
458 cur_pos += 1;
459 }
460 }
461 assert_eq!(cur_pos, predictions_without_short_ones.len());
463 assert_eq!(precisions.len(), count);
464 precisions
465 }
466}
467
/// Score -> q-value table deserialized from JSON (`get_model` parses the user's FDR table
/// or `FIRE_CONF_JSON` into this type).
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
    /// Column names from the JSON; not read by `MapPrecisionValues::new`.
    pub columns: Vec<String>,
    /// Rows of `(mokapot score, mokapot q-value)`, as consumed by `MapPrecisionValues::new`.
    pub data: Vec<(f32, f32)>,
}
474
475pub struct MapPrecisionValues {
476 pub map: BTreeMap<OrderedFloat<f32>, u8>,
477}
478
479impl MapPrecisionValues {
480 pub fn new(pt: &PrecisionTable) -> Self {
481 let mut map = BTreeMap::new();
483
484 for (mokapot_score, mokapot_q_value) in pt.data.iter() {
485 let precision = ((1.0 - mokapot_q_value) * 255.0).round() as u8;
486 map.insert(OrderedFloat(*mokapot_score), precision);
487 }
488 map.insert(
490 OrderedFloat(0.0),
491 *map.get(&OrderedFloat(0.0)).unwrap_or(&0),
492 );
493 Self { map }
494 }
495
496 pub fn precision_from_float(&self, value: f32) -> u8 {
498 let key = OrderedFloat(value);
499 let (less_key, less_val) = self
501 .map
502 .range(..key)
503 .next_back()
504 .unwrap_or((&OrderedFloat(0.0), &0));
505 let (more_key, more_val) = self
507 .map
508 .range(key..)
509 .next()
510 .unwrap_or((&OrderedFloat(1.0), &255));
511 if (more_key - key).abs() < (less_key - key).abs() {
512 *more_val
513 } else {
514 *less_val
515 }
516 }
517}