Skip to main content

fibertools_rs/
fiber.rs

1use super::subcommands::center::CenterPosition;
2use super::subcommands::center::CenteredFiberData;
3use super::utils::input_bam::FiberFilters;
4use super::*;
5use crate::utils::bamannotations::*;
6use crate::utils::basemods::BaseMods;
7use crate::utils::bio_io::*;
8use crate::utils::ftexpression::apply_filter_fsd;
9use rayon::prelude::*;
10use rust_htslib::bam::Read;
11use rust_htslib::{bam, bam::ext::BamRecordExtensions, bam::record::Aux, bam::HeaderView};
12use std::collections::HashMap;
13use std::fmt::Write;
14
15#[derive(Debug, Clone, PartialEq)]
16pub struct FiberseqData {
17    pub record: bam::Record,
18    pub msp: Ranges,
19    pub nuc: Ranges,
20    pub m6a: Ranges,
21    pub cpg: Ranges,
22    pub base_mods: BaseMods,
23    pub ec: f32,
24    pub target_name: String,
25    pub rg: String,
26    pub center_position: Option<CenterPosition>,
27}
28
29impl FiberseqData {
30    pub fn new(record: bam::Record, target_name: Option<&String>, filters: &FiberFilters) -> Self {
31        // read group
32        let rg = if let Ok(Aux::String(f)) = record.aux(b"RG") {
33            log::trace!("{f}");
34            f
35        } else {
36            "."
37        }
38        .to_string();
39
40        let nuc_starts = get_u32_tag(&record, b"ns");
41        let msp_starts = get_u32_tag(&record, b"as");
42        let nuc_length = get_u32_tag(&record, b"nl");
43        let msp_length = get_u32_tag(&record, b"al");
44        let nuc = Ranges::new(&record, nuc_starts, None, Some(nuc_length));
45        let mut msp = Ranges::new(&record, msp_starts, None, Some(msp_length));
46        let msp_qual = get_u8_tag(&record, b"aq");
47        if !msp_qual.is_empty() {
48            msp.set_qual(msp_qual);
49        }
50
51        // get the number of passes
52        let ec = if let Ok(Aux::Float(f)) = record.aux(b"ec") {
53            log::trace!("{f}");
54            f
55        } else {
56            0.0
57        };
58
59        let target_name = match target_name {
60            Some(t) => t.clone(),
61            None => ".".to_string(),
62        };
63
64        // get fiberseq basemods
65        let mut base_mods = BaseMods::new(&record, filters.min_ml_score);
66        base_mods.filter_at_read_ends(filters.strip_starting_basemods);
67
68        //let (m6a, cpg) = FiberMods::new(&base_mods);
69        let m6a = base_mods.m6a();
70        let cpg = base_mods.cpg();
71
72        let mut fsd = FiberseqData {
73            record,
74            msp,
75            nuc,
76            m6a,
77            base_mods,
78            cpg,
79            ec,
80            target_name,
81            rg,
82            center_position: None,
83        };
84
85        apply_filter_fsd(&mut fsd, filters).expect("Failed to apply filter to FiberseqData");
86        fsd
87    }
88
89    pub fn dict_from_head_view(head_view: &HeaderView) -> HashMap<i32, String> {
90        if head_view.target_count() == 0 {
91            return HashMap::new();
92        }
93        let target_u8s = head_view.target_names();
94        let tids = target_u8s
95            .iter()
96            .map(|t| head_view.tid(t).expect("Unable to get tid"));
97        let target_names = target_u8s
98            .iter()
99            .map(|&a| String::from_utf8_lossy(a).to_string());
100
101        tids.zip(target_names)
102            .map(|(id, t)| (id as i32, t))
103            .collect()
104    }
105
106    pub fn target_name_from_tid(tid: i32, target_dict: &HashMap<i32, String>) -> Option<&String> {
107        target_dict.get(&tid)
108    }
109
110    pub fn from_records(
111        records: Vec<bam::Record>,
112        head_view: &HeaderView,
113        filters: &FiberFilters,
114    ) -> Vec<Self> {
115        let target_dict = Self::dict_from_head_view(head_view);
116        records
117            .into_par_iter()
118            .map(|r| {
119                let tid = r.tid();
120                (r, Self::target_name_from_tid(tid, &target_dict))
121            })
122            .map(|(r, target_name)| Self::new(r, target_name, filters))
123            .collect::<Vec<_>>()
124    }
125
126    //
127    // GET FUNCTIONS
128    //
129
130    pub fn get_qname(&self) -> String {
131        String::from_utf8_lossy(self.record.qname()).to_string()
132    }
133
134    pub fn get_rq(&self) -> Option<f32> {
135        if let Ok(Aux::Float(f)) = self.record.aux(b"rq") {
136            Some(f)
137        } else {
138            None
139        }
140    }
141
142    pub fn get_hp(&self) -> String {
143        if let Ok(Aux::U8(f)) = self.record.aux(b"HP") {
144            format!("H{f}")
145        } else {
146            "UNK".to_string()
147        }
148    }
149
150    /// Center all coordinates on the read using the offset attribute.
151    pub fn center(&self, center_position: &CenterPosition) -> Option<Self> {
152        // setup new fiberseq data object to return
153        let mut new = self.clone();
154        let (ref_offset, mol_offset) =
155            CenteredFiberData::find_offsets(&self.record, center_position);
156
157        // Apply offsets to all annotations using the new methods
158        new.m6a
159            .apply_offset(mol_offset, ref_offset, center_position.strand);
160        new.cpg
161            .apply_offset(mol_offset, ref_offset, center_position.strand);
162        new.msp
163            .apply_offset(mol_offset, ref_offset, center_position.strand);
164        new.nuc
165            .apply_offset(mol_offset, ref_offset, center_position.strand);
166
167        // Validate that MSPs still start and end on m6A marks after centering
168        new.validate_msp_m6a_alignment();
169
170        Some(new)
171    }
172
173    /// Validate that all MSP boundaries align with m6A positions after centering
174    fn validate_msp_m6a_alignment(&self) {
175        let m6a_positions = self.m6a.starts();
176        let msp_boundaries: Vec<i64> = self
177            .msp
178            .starts()
179            .into_iter()
180            .chain(self.msp.ends().into_iter().map(|x| x - 1))
181            .collect();
182
183        if m6a_positions.is_empty() || msp_boundaries.is_empty() {
184            return; // Skip validation if no data
185        }
186
187        for msp_pos in &msp_boundaries {
188            if !m6a_positions.contains(msp_pos) {
189                log::warn!(
190                    "MSP boundary at position {} does not align with m6A mark after centering in read {}",
191                    msp_pos,
192                    String::from_utf8_lossy(self.record.qname())
193                );
194            }
195        }
196    }
197
198    //
199    //  WRITE BED12 FUNCTIONS
200    //
201    pub fn write_msp(&self, reference: bool) -> String {
202        let (starts, _ends, lengths) = if reference {
203            (
204                self.msp.reference_starts(),
205                self.msp.reference_ends(),
206                self.msp.reference_lengths(),
207            )
208        } else {
209            (
210                self.msp.option_starts(),
211                self.msp.option_ends(),
212                self.msp.option_lengths(),
213            )
214        };
215        self.to_bed12(reference, &starts, &lengths, LINKER_COLOR)
216    }
217
218    pub fn write_nuc(&self, reference: bool) -> String {
219        let (starts, _ends, lengths) = if reference {
220            (
221                self.nuc.reference_starts(),
222                self.nuc.reference_ends(),
223                self.nuc.reference_lengths(),
224            )
225        } else {
226            (
227                self.nuc.option_starts(),
228                self.nuc.option_ends(),
229                self.nuc.option_lengths(),
230            )
231        };
232        self.to_bed12(reference, &starts, &lengths, NUC_COLOR)
233    }
234
235    pub fn write_m6a(&self, reference: bool) -> String {
236        let starts = if reference {
237            self.m6a.reference_starts()
238        } else {
239            self.m6a.option_starts()
240        };
241        let lengths = vec![Some(1); starts.len()];
242        self.to_bed12(reference, &starts, &lengths, M6A_COLOR)
243    }
244
245    pub fn write_cpg(&self, reference: bool) -> String {
246        let starts = if reference {
247            self.cpg.reference_starts()
248        } else {
249            self.cpg.option_starts()
250        };
251        let lengths = vec![Some(1); starts.len()];
252        self.to_bed12(reference, &starts, &lengths, CPG_COLOR)
253    }
254
255    pub fn to_bed12(
256        &self,
257        reference: bool,
258        starts: &[Option<i64>],
259        lengths: &[Option<i64>],
260        color: &str,
261    ) -> String {
262        if starts.is_empty() {
263            return "".to_string();
264        }
265        // skip if no alignments are here
266        if self.record.is_unmapped() && reference {
267            return "".to_string();
268        }
269
270        let ct;
271        let start;
272        let end;
273        let name = String::from_utf8_lossy(self.record.qname()).to_string();
274        let mut rtn: String = String::with_capacity(0);
275        if reference {
276            ct = &self.target_name;
277            start = self.record.reference_start();
278            end = self.record.reference_end();
279        } else {
280            ct = &name;
281            start = 0;
282            end = self.record.seq_len() as i64;
283        }
284        let score = self.ec.round() as i64;
285        let strand = if self.record.is_reverse() { '-' } else { '+' };
286        // filter out positions that do not have an exact liftover
287        let (filtered_starts, filtered_lengths): (Vec<i64>, Vec<i64>) = starts
288            .iter()
289            .flatten()
290            .zip(lengths.iter().flatten())
291            .unzip();
292        // skip empty ones
293        if filtered_lengths.is_empty() || filtered_starts.is_empty() {
294            return "".to_string();
295        }
296        let b_ct = filtered_starts.len() + 2;
297        let b_ln: String = filtered_lengths
298            .iter()
299            .map(|&ln| ln.to_string() + ",")
300            .collect();
301        let b_st: String = filtered_starts
302            .iter()
303            .map(|&st| (st - start).to_string() + ",")
304            .collect();
305        assert_eq!(filtered_lengths.len(), filtered_starts.len());
306
307        rtn.push_str(ct);
308        rtn.push('\t');
309        rtn.push_str(&start.to_string());
310        rtn.push('\t');
311        rtn.push_str(&end.to_string());
312        rtn.push('\t');
313        rtn.push_str(&name);
314        rtn.push('\t');
315        rtn.push_str(&score.to_string());
316        rtn.push('\t');
317        rtn.push(strand);
318        rtn.push('\t');
319        rtn.push_str(&start.to_string());
320        rtn.push('\t');
321        rtn.push_str(&end.to_string());
322        rtn.push('\t');
323        rtn.push_str(color);
324        rtn.push('\t');
325        rtn.push_str(&b_ct.to_string());
326        rtn.push_str("\t0,"); // add a zero length start
327        rtn.push_str(&b_ln);
328        rtn.push_str("1\t0,"); // add a 1 base length and a 0 start point
329        rtn.push_str(&b_st);
330        write!(&mut rtn, "{}", format_args!("{}\n", end - start - 1)).unwrap();
331        rtn
332    }
333
334    //
335    // WRITE ALL FUNCTIONS
336    //
337
338    pub fn all_header(simplify: bool, quality: bool) -> String {
339        let mut x = format!(
340            "#{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
341            "ct", "st", "en", "fiber", "score", "strand", "sam_flag", "HP", "RG", "fiber_length",
342        );
343        if !simplify {
344            x.push_str("fiber_sequence\t")
345        }
346        if quality {
347            x.push_str("fiber_qual\t")
348        }
349        x.push_str(&format!(
350            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
351            "ec",
352            "rq",
353            "total_AT_bp",
354            "total_m6a_bp",
355            "total_nuc_bp",
356            "total_msp_bp",
357            "total_5mC_bp",
358            "nuc_starts",
359            "nuc_lengths",
360            "ref_nuc_starts",
361            "ref_nuc_lengths",
362            "msp_starts",
363            "msp_lengths",
364            "fire",
365            "ref_msp_starts",
366            "ref_msp_lengths",
367            "m6a",
368            "ref_m6a",
369            "m6a_qual",
370            "5mC",
371            "ref_5mC",
372            "5mC_qual"
373        ));
374        x
375    }
376
377    pub fn write_all(&self, simplify: bool, quality: bool) -> String {
378        // PB features
379        let name = std::str::from_utf8(self.record.qname()).unwrap();
380        let score = self.ec.round() as i64;
381        let q_len = self.record.seq_len() as i64;
382        let rq = match self.get_rq() {
383            Some(x) => format!("{x}"),
384            None => ".".to_string(),
385        };
386        // reference features
387        let ct;
388        let start;
389        let end;
390        let strand;
391        if self.record.is_unmapped() {
392            ct = ".";
393            start = 0;
394            end = 0;
395            strand = '.';
396        } else {
397            ct = &self.target_name;
398            start = self.record.reference_start();
399            end = self.record.reference_end();
400            strand = if self.record.is_reverse() { '-' } else { '+' };
401        }
402        let sam_flag = self.record.flags();
403        let hp = self.get_hp();
404
405        let at_count = self
406            .record
407            .seq()
408            .as_bytes()
409            .iter()
410            .filter(|&x| *x == b'A' || *x == b'T')
411            .count() as i64;
412
413        // get the info
414        let m6a_count = self.m6a.annotations.len();
415        let m6a_qual = self.m6a.qual().iter().map(|a| Some(*a as i64)).collect();
416        let cpg_count = self.cpg.annotations.len();
417        let cpg_qual = self.cpg.qual().iter().map(|a| Some(*a as i64)).collect();
418        let fire = self.msp.qual().iter().map(|a| Some(*a as i64)).collect();
419
420        // write the features
421        let mut rtn = String::with_capacity(0);
422        // add first things 7
423        rtn.write_fmt(format_args!(
424            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
425            ct, start, end, name, score, strand, sam_flag, hp, self.rg, q_len
426        ))
427        .unwrap();
428        // add sequence
429        if !simplify {
430            rtn.write_fmt(format_args!(
431                "{}\t",
432                String::from_utf8_lossy(&self.record.seq().as_bytes()),
433            ))
434            .unwrap();
435        }
436        if quality {
437            // TODO add quality offset
438            rtn.write_fmt(format_args!(
439                "{}\t",
440                String::from_utf8_lossy(
441                    &self
442                        .record
443                        .qual()
444                        .iter()
445                        .map(|x| x + 33)
446                        .collect::<Vec<u8>>()
447                ),
448            ))
449            .unwrap();
450        }
451        // add PB features
452        let total_nuc_bp = self.nuc.lengths().iter().sum::<i64>();
453        let total_msp_bp = self.msp.lengths().iter().sum::<i64>();
454        rtn.write_fmt(format_args!(
455            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
456            self.ec, rq, at_count, m6a_count, total_nuc_bp, total_msp_bp, cpg_count
457        ))
458        .unwrap();
459        // add fiber features
460        let vecs = [
461            self.nuc.option_starts(),
462            self.nuc.option_lengths(),
463            self.nuc.reference_starts(),
464            self.nuc.reference_lengths(),
465            self.msp.option_starts(),
466            self.msp.option_lengths(),
467            fire,
468            self.msp.reference_starts(),
469            self.msp.reference_lengths(),
470            self.m6a.option_starts(),
471            self.m6a.reference_starts(),
472            m6a_qual,
473            self.cpg.option_starts(),
474            self.cpg.reference_starts(),
475            cpg_qual,
476        ];
477        for vec in &vecs {
478            if vec.is_empty() {
479                rtn.push('.');
480                rtn.push('\t');
481            } else {
482                let z: String = vec
483                    .iter()
484                    .map(|x| match x {
485                        Some(y) => *y,
486                        None => -1,
487                    })
488                    .map(|x| x.to_string() + ",")
489                    .collect();
490                rtn.write_fmt(format_args!("{z}\t")).unwrap();
491            }
492        }
493        // replace the last tab with a newline
494        let len = rtn.len();
495        rtn.replace_range(len - 1..len, "\n");
496
497        rtn
498    }
499}
500
501pub struct FiberseqRecords<'a> {
502    bam_chunk: BamChunk<'a>,
503    header: HeaderView,
504    filters: FiberFilters,
505    cur_chunk: Vec<FiberseqData>,
506}
507
508impl<'a> FiberseqRecords<'a> {
509    pub fn new(bam: &'a mut bam::Reader, filters: FiberFilters) -> Self {
510        let header = bam.header().clone();
511        let bam_recs = bam.records();
512        let mut bam_chunk = BamChunk::new(bam_recs, None);
513        bam_chunk.set_bit_flag_filter(filters.bit_flag);
514        let cur_chunk: Vec<FiberseqData> = vec![];
515        FiberseqRecords {
516            bam_chunk,
517            header,
518            filters,
519            cur_chunk,
520        }
521    }
522}
523
524impl Iterator for FiberseqRecords<'_> {
525    type Item = FiberseqData;
526
527    fn next(&mut self) -> Option<Self::Item> {
528        // if we are out of data check for another chunk in the bam
529        if self.cur_chunk.is_empty() {
530            match self.bam_chunk.next() {
531                Some(recs) => {
532                    self.cur_chunk = FiberseqData::from_records(recs, &self.header, &self.filters);
533                    // we will be popping from this list so we want to remove the first element first, not the last
534                    self.cur_chunk.reverse();
535                }
536                None => return None,
537            }
538        }
539        self.cur_chunk.pop()
540    }
541}