fibertools_rs/
fiber.rs

1use super::subcommands::center::CenterPosition;
2use super::subcommands::center::CenteredFiberData;
3use super::utils::input_bam::FiberFilters;
4use super::*;
5use crate::utils::bamranges::*;
6use crate::utils::basemods::BaseMods;
7use crate::utils::bio_io::*;
8use crate::utils::ftexpression::apply_filter_fsd;
9use rayon::prelude::*;
10use rust_htslib::bam::Read;
11use rust_htslib::{bam, bam::ext::BamRecordExtensions, bam::record::Aux, bam::HeaderView};
12use std::collections::HashMap;
13use std::fmt::Write;
14
15#[derive(Debug, Clone, PartialEq)]
16pub struct FiberseqData {
17    pub record: bam::Record,
18    pub msp: Ranges,
19    pub nuc: Ranges,
20    pub m6a: Ranges,
21    pub cpg: Ranges,
22    pub base_mods: BaseMods,
23    pub ec: f32,
24    pub target_name: String,
25    pub rg: String,
26    pub center_position: Option<CenterPosition>,
27}
28
29impl FiberseqData {
30    pub fn new(record: bam::Record, target_name: Option<&String>, filters: &FiberFilters) -> Self {
31        // read group
32        let rg = if let Ok(Aux::String(f)) = record.aux(b"RG") {
33            log::trace!("{f}");
34            f
35        } else {
36            "."
37        }
38        .to_string();
39
40        let nuc_starts = get_u32_tag(&record, b"ns");
41        let msp_starts = get_u32_tag(&record, b"as");
42        let nuc_length = get_u32_tag(&record, b"nl");
43        let msp_length = get_u32_tag(&record, b"al");
44        let nuc = Ranges::new(&record, nuc_starts, None, Some(nuc_length));
45        let mut msp = Ranges::new(&record, msp_starts, None, Some(msp_length));
46        let msp_qual = get_u8_tag(&record, b"aq");
47        if !msp_qual.is_empty() {
48            msp.set_qual(msp_qual);
49        }
50
51        // get the number of passes
52        let ec = if let Ok(Aux::Float(f)) = record.aux(b"ec") {
53            log::trace!("{f}");
54            f
55        } else {
56            0.0
57        };
58
59        let target_name = match target_name {
60            Some(t) => t.clone(),
61            None => ".".to_string(),
62        };
63
64        // get fiberseq basemods
65        let mut base_mods = BaseMods::new(&record, filters.min_ml_score);
66        base_mods.filter_at_read_ends(filters.strip_starting_basemods);
67
68        //let (m6a, cpg) = FiberMods::new(&base_mods);
69        let m6a = base_mods.m6a();
70        let cpg = base_mods.cpg();
71
72        let mut fsd = FiberseqData {
73            record,
74            msp,
75            nuc,
76            m6a,
77            base_mods,
78            cpg,
79            ec,
80            target_name,
81            rg,
82            center_position: None,
83        };
84
85        apply_filter_fsd(&mut fsd, filters).expect("Failed to apply filter to FiberseqData");
86        fsd
87    }
88
89    pub fn dict_from_head_view(head_view: &HeaderView) -> HashMap<i32, String> {
90        if head_view.target_count() == 0 {
91            return HashMap::new();
92        }
93        let target_u8s = head_view.target_names();
94        let tids = target_u8s
95            .iter()
96            .map(|t| head_view.tid(t).expect("Unable to get tid"));
97        let target_names = target_u8s
98            .iter()
99            .map(|&a| String::from_utf8_lossy(a).to_string());
100
101        tids.zip(target_names)
102            .map(|(id, t)| (id as i32, t))
103            .collect()
104    }
105
106    pub fn target_name_from_tid(tid: i32, target_dict: &HashMap<i32, String>) -> Option<&String> {
107        target_dict.get(&tid)
108    }
109
110    pub fn from_records(
111        records: Vec<bam::Record>,
112        head_view: &HeaderView,
113        filters: &FiberFilters,
114    ) -> Vec<Self> {
115        let target_dict = Self::dict_from_head_view(head_view);
116        records
117            .into_par_iter()
118            .map(|r| {
119                let tid = r.tid();
120                (r, Self::target_name_from_tid(tid, &target_dict))
121            })
122            .map(|(r, target_name)| Self::new(r, target_name, filters))
123            .collect::<Vec<_>>()
124    }
125
126    //
127    // GET FUNCTIONS
128    //
129
130    pub fn get_qname(&self) -> String {
131        String::from_utf8_lossy(self.record.qname()).to_string()
132    }
133
134    pub fn get_rq(&self) -> Option<f32> {
135        if let Ok(Aux::Float(f)) = self.record.aux(b"rq") {
136            Some(f)
137        } else {
138            None
139        }
140    }
141
142    pub fn get_hp(&self) -> String {
143        if let Ok(Aux::U8(f)) = self.record.aux(b"HP") {
144            format!("H{f}")
145        } else {
146            "UNK".to_string()
147        }
148    }
149
150    //
151    //  CENTERING FUNCTIONS
152    //
153
154    /// Center positions on the read around the reference position.
155    fn apply_offset(positions: &mut [Option<i64>], offset: i64, strand: char) {
156        for pos in positions.iter_mut().flatten() {
157            // bp is unaligned
158            if *pos == -1 {
159                *pos = i64::MIN;
160                continue;
161            }
162            // else
163            *pos -= offset;
164            if strand == '-' {
165                *pos = -*pos;
166            }
167        }
168        if strand == '-' {
169            positions.reverse();
170        }
171    }
172
173    /// Center ranges on the read around the reference position.
174    fn offset_range(
175        starts: &mut [Option<i64>],
176        ends: &mut [Option<i64>],
177        offset: i64,
178        strand: char,
179    ) {
180        FiberseqData::apply_offset(starts, offset, strand);
181        FiberseqData::apply_offset(ends, offset, strand);
182        for (start, end) in starts.iter_mut().zip(ends.iter_mut()) {
183            if start > end {
184                std::mem::swap(start, end);
185            }
186        }
187    }
188
189    /// Center all coordinates on the read using the offset attribute.
190    pub fn center(&self, center_position: &CenterPosition) -> Option<Self> {
191        // setup new fiberseq data object to return
192        let mut new = self.clone();
193        let (ref_offset, mol_offset) =
194            CenteredFiberData::find_offsets(&self.record, center_position);
195
196        // move basemods
197        FiberseqData::apply_offset(&mut new.m6a.starts, mol_offset, center_position.strand);
198        FiberseqData::apply_offset(
199            &mut new.m6a.reference_starts,
200            ref_offset,
201            center_position.strand,
202        );
203        FiberseqData::apply_offset(&mut new.cpg.starts, mol_offset, center_position.strand);
204        FiberseqData::apply_offset(
205            &mut new.cpg.reference_starts,
206            ref_offset,
207            center_position.strand,
208        );
209        // move ranges
210        FiberseqData::offset_range(
211            &mut new.msp.starts,
212            &mut new.msp.ends,
213            mol_offset,
214            center_position.strand,
215        );
216        FiberseqData::offset_range(
217            &mut new.msp.reference_starts,
218            &mut new.msp.reference_ends,
219            ref_offset,
220            center_position.strand,
221        );
222        FiberseqData::offset_range(
223            &mut new.nuc.starts,
224            &mut new.nuc.ends,
225            mol_offset,
226            center_position.strand,
227        );
228        FiberseqData::offset_range(
229            &mut new.nuc.reference_starts,
230            &mut new.nuc.reference_ends,
231            ref_offset,
232            center_position.strand,
233        );
234        // correct orientations
235        if center_position.strand == '-' {
236            new.m6a.qual.reverse();
237            new.cpg.qual.reverse();
238            new.msp.lengths.reverse();
239            new.msp.reference_lengths.reverse();
240            new.msp.qual.reverse();
241            new.nuc.lengths.reverse();
242            new.nuc.reference_lengths.reverse();
243        }
244        // TODO update start and end
245        // TODO update aligned block pairs
246        Some(new)
247    }
248
249    //
250    //  WRITE BED12 FUNCTIONS
251    //
252    pub fn write_msp(&self, reference: bool) -> String {
253        let (starts, _ends, lengths) = if reference {
254            (
255                &self.msp.reference_starts,
256                &self.msp.reference_ends,
257                &self.msp.reference_lengths,
258            )
259        } else {
260            (&self.msp.starts, &self.msp.ends, &self.msp.lengths)
261        };
262        self.to_bed12(reference, starts, lengths, LINKER_COLOR)
263    }
264
265    pub fn write_nuc(&self, reference: bool) -> String {
266        let (starts, _ends, lengths) = if reference {
267            (
268                &self.nuc.reference_starts,
269                &self.nuc.reference_ends,
270                &self.nuc.reference_lengths,
271            )
272        } else {
273            (&self.nuc.starts, &self.nuc.ends, &self.nuc.lengths)
274        };
275        self.to_bed12(reference, starts, lengths, NUC_COLOR)
276    }
277
278    pub fn write_m6a(&self, reference: bool) -> String {
279        let starts = if reference {
280            &self.m6a.reference_starts
281        } else {
282            &self.m6a.starts
283        };
284        let lengths = vec![Some(1); starts.len()];
285        self.to_bed12(reference, starts, &lengths, M6A_COLOR)
286    }
287
288    pub fn write_cpg(&self, reference: bool) -> String {
289        let starts = if reference {
290            &self.cpg.reference_starts
291        } else {
292            &self.cpg.starts
293        };
294        let lengths = vec![Some(1); starts.len()];
295        self.to_bed12(reference, starts, &lengths, CPG_COLOR)
296    }
297
298    pub fn to_bed12(
299        &self,
300        reference: bool,
301        starts: &[Option<i64>],
302        lengths: &[Option<i64>],
303        color: &str,
304    ) -> String {
305        if starts.is_empty() {
306            return "".to_string();
307        }
308        // skip if no alignments are here
309        if self.record.is_unmapped() && reference {
310            return "".to_string();
311        }
312
313        let ct;
314        let start;
315        let end;
316        let name = String::from_utf8_lossy(self.record.qname()).to_string();
317        let mut rtn: String = String::with_capacity(0);
318        if reference {
319            ct = &self.target_name;
320            start = self.record.reference_start();
321            end = self.record.reference_end();
322        } else {
323            ct = &name;
324            start = 0;
325            end = self.record.seq_len() as i64;
326        }
327        let score = self.ec.round() as i64;
328        let strand = if self.record.is_reverse() { '-' } else { '+' };
329        // filter out positions that do not have an exact liftover
330        let (filtered_starts, filtered_lengths): (Vec<i64>, Vec<i64>) = starts
331            .iter()
332            .flatten()
333            .zip(lengths.iter().flatten())
334            .unzip();
335        // skip empty ones
336        if filtered_lengths.is_empty() || filtered_starts.is_empty() {
337            return "".to_string();
338        }
339        let b_ct = filtered_starts.len() + 2;
340        let b_ln: String = filtered_lengths
341            .iter()
342            .map(|&ln| ln.to_string() + ",")
343            .collect();
344        let b_st: String = filtered_starts
345            .iter()
346            .map(|&st| (st - start).to_string() + ",")
347            .collect();
348        assert_eq!(filtered_lengths.len(), filtered_starts.len());
349
350        rtn.push_str(ct);
351        rtn.push('\t');
352        rtn.push_str(&start.to_string());
353        rtn.push('\t');
354        rtn.push_str(&end.to_string());
355        rtn.push('\t');
356        rtn.push_str(&name);
357        rtn.push('\t');
358        rtn.push_str(&score.to_string());
359        rtn.push('\t');
360        rtn.push(strand);
361        rtn.push('\t');
362        rtn.push_str(&start.to_string());
363        rtn.push('\t');
364        rtn.push_str(&end.to_string());
365        rtn.push('\t');
366        rtn.push_str(color);
367        rtn.push('\t');
368        rtn.push_str(&b_ct.to_string());
369        rtn.push_str("\t0,"); // add a zero length start
370        rtn.push_str(&b_ln);
371        rtn.push_str("1\t0,"); // add a 1 base length and a 0 start point
372        rtn.push_str(&b_st);
373        write!(&mut rtn, "{}", format_args!("{}\n", end - start - 1)).unwrap();
374        rtn
375    }
376
377    //
378    // WRITE ALL FUNCTIONS
379    //
380
381    pub fn all_header(simplify: bool, quality: bool) -> String {
382        let mut x = format!(
383            "#{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
384            "ct", "st", "en", "fiber", "score", "strand", "sam_flag", "HP", "RG", "fiber_length",
385        );
386        if !simplify {
387            x.push_str("fiber_sequence\t")
388        }
389        if quality {
390            x.push_str("fiber_qual\t")
391        }
392        x.push_str(&format!(
393            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
394            "ec",
395            "rq",
396            "total_AT_bp",
397            "total_m6a_bp",
398            "total_nuc_bp",
399            "total_msp_bp",
400            "total_5mC_bp",
401            "nuc_starts",
402            "nuc_lengths",
403            "ref_nuc_starts",
404            "ref_nuc_lengths",
405            "msp_starts",
406            "msp_lengths",
407            "fire",
408            "ref_msp_starts",
409            "ref_msp_lengths",
410            "m6a",
411            "ref_m6a",
412            "m6a_qual",
413            "5mC",
414            "ref_5mC",
415            "5mC_qual"
416        ));
417        x
418    }
419
420    pub fn write_all(&self, simplify: bool, quality: bool) -> String {
421        // PB features
422        let name = std::str::from_utf8(self.record.qname()).unwrap();
423        let score = self.ec.round() as i64;
424        let q_len = self.record.seq_len() as i64;
425        let rq = match self.get_rq() {
426            Some(x) => format!("{x}"),
427            None => ".".to_string(),
428        };
429        // reference features
430        let ct;
431        let start;
432        let end;
433        let strand;
434        if self.record.is_unmapped() {
435            ct = ".";
436            start = 0;
437            end = 0;
438            strand = '.';
439        } else {
440            ct = &self.target_name;
441            start = self.record.reference_start();
442            end = self.record.reference_end();
443            strand = if self.record.is_reverse() { '-' } else { '+' };
444        }
445        let sam_flag = self.record.flags();
446        let hp = self.get_hp();
447
448        let at_count = self
449            .record
450            .seq()
451            .as_bytes()
452            .iter()
453            .filter(|&x| *x == b'A' || *x == b'T')
454            .count() as i64;
455
456        // get the info
457        let m6a_count = self.m6a.starts.len();
458        let m6a_qual = self.m6a.qual.iter().map(|a| Some(*a as i64)).collect();
459        let cpg_count = self.cpg.starts.len();
460        let cpg_qual = self.cpg.qual.iter().map(|a| Some(*a as i64)).collect();
461        let fire = self.msp.qual.iter().map(|a| Some(*a as i64)).collect();
462
463        // write the features
464        let mut rtn = String::with_capacity(0);
465        // add first things 7
466        rtn.write_fmt(format_args!(
467            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
468            ct, start, end, name, score, strand, sam_flag, hp, self.rg, q_len
469        ))
470        .unwrap();
471        // add sequence
472        if !simplify {
473            rtn.write_fmt(format_args!(
474                "{}\t",
475                String::from_utf8_lossy(&self.record.seq().as_bytes()),
476            ))
477            .unwrap();
478        }
479        if quality {
480            // TODO add quality offset
481            rtn.write_fmt(format_args!(
482                "{}\t",
483                String::from_utf8_lossy(
484                    &self
485                        .record
486                        .qual()
487                        .iter()
488                        .map(|x| x + 33)
489                        .collect::<Vec<u8>>()
490                ),
491            ))
492            .unwrap();
493        }
494        // add PB features
495        let total_nuc_bp = self.nuc.lengths.iter().flatten().sum::<i64>();
496        let total_msp_bp = self.msp.lengths.iter().flatten().sum::<i64>();
497        rtn.write_fmt(format_args!(
498            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
499            self.ec, rq, at_count, m6a_count, total_nuc_bp, total_msp_bp, cpg_count
500        ))
501        .unwrap();
502        // add fiber features
503        for vec in &[
504            &self.nuc.starts,
505            &self.nuc.lengths,
506            &self.nuc.reference_starts,
507            &self.nuc.reference_lengths,
508            &self.msp.starts,
509            &self.msp.lengths,
510            &fire,
511            &self.msp.reference_starts,
512            &self.msp.reference_lengths,
513            &self.m6a.starts,
514            &self.m6a.reference_starts,
515            &m6a_qual,
516            &self.cpg.starts,
517            &self.cpg.reference_starts,
518            &cpg_qual,
519        ] {
520            if vec.is_empty() {
521                rtn.push('.');
522                rtn.push('\t');
523            } else {
524                let z: String = vec
525                    .iter()
526                    .map(|x| match x {
527                        Some(y) => *y,
528                        None => -1,
529                    })
530                    .map(|x| x.to_string() + ",")
531                    .collect();
532                rtn.write_fmt(format_args!("{z}\t")).unwrap();
533            }
534        }
535        // replace the last tab with a newline
536        let len = rtn.len();
537        rtn.replace_range(len - 1..len, "\n");
538
539        rtn
540    }
541}
542
543pub struct FiberseqRecords<'a> {
544    bam_chunk: BamChunk<'a>,
545    header: HeaderView,
546    filters: FiberFilters,
547    cur_chunk: Vec<FiberseqData>,
548}
549
550impl<'a> FiberseqRecords<'a> {
551    pub fn new(bam: &'a mut bam::Reader, filters: FiberFilters) -> Self {
552        let header = bam.header().clone();
553        let bam_recs = bam.records();
554        let mut bam_chunk = BamChunk::new(bam_recs, None);
555        bam_chunk.set_bit_flag_filter(filters.bit_flag);
556        let cur_chunk: Vec<FiberseqData> = vec![];
557        FiberseqRecords {
558            bam_chunk,
559            header,
560            filters,
561            cur_chunk,
562        }
563    }
564}
565
566impl Iterator for FiberseqRecords<'_> {
567    type Item = FiberseqData;
568
569    fn next(&mut self) -> Option<Self::Item> {
570        // if we are out of data check for another chunk in the bam
571        if self.cur_chunk.is_empty() {
572            match self.bam_chunk.next() {
573                Some(recs) => {
574                    self.cur_chunk = FiberseqData::from_records(recs, &self.header, &self.filters);
575                    // we will be popping from this list so we want to remove the first element first, not the last
576                    self.cur_chunk.reverse();
577                }
578                None => return None,
579            }
580        }
581        self.cur_chunk.pop()
582    }
583}