// fibertools_rs/subcommands/center.rs

1use crate::cli::CenterOptions;
2use crate::fiber::FiberseqData;
3use crate::utils::bamlift::*;
4use crate::utils::bio_io;
5use crate::*;
6use bio::alphabets::dna::revcomp;
7use indicatif::{style, ProgressBar};
8use rayon::prelude::*;
9use rust_htslib::bam::Read;
10use rust_htslib::{bam, bam::ext::BamRecordExtensions};
11use std::fmt::Write;
12use std::io::{self, prelude::*};
13
/// One position to center reads on, typically built from a single BED record.
///
/// `read_center_positions` fills `position` with the BED start for `+` records and with
/// `end - 1` for `-` records, and only ever produces `'+'` or `'-'` for `strand`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CenterPosition {
    /// Contig / chromosome name.
    pub chrom: String,
    /// Reference coordinate the data is centered on.
    pub position: i64,
    /// Strand of the center position.
    pub strand: char,
}
20pub struct CenteredFiberData {
21    fiber: FiberseqData,
22    pub dist: Option<i64>,
23    center_position: CenterPosition,
24    pub offset: i64,
25    pub reference: bool,
26    pub simplify: bool,
27}
28
29impl CenteredFiberData {
30    pub fn new(
31        fiber: FiberseqData,
32        center_position: CenterPosition,
33        dist: Option<i64>,
34        reference: bool,
35        simplify: bool,
36    ) -> Option<Self> {
37        let (ref_offset, mol_offset) =
38            CenteredFiberData::find_offsets(&fiber.record, &center_position);
39        let offset = if reference { ref_offset } else { mol_offset };
40
41        let fiber = fiber.center(&center_position)?;
42
43        Some(CenteredFiberData {
44            fiber,
45            dist,
46            center_position,
47            offset,
48            reference,
49            simplify,
50        })
51    }
52    /// find both the ref and mol offsets
53    pub fn find_offsets(record: &bam::Record, center_position: &CenterPosition) -> (i64, i64) {
54        let ref_offset = center_position.position;
55        let mol_offset =
56            CenteredFiberData::find_offset(record, center_position.position).unwrap_or(0);
57        (ref_offset, mol_offset)
58    }
59
60    /// find the query position that corresponds to the central reference position
61    pub fn find_offset(record: &bam::Record, reference_position: i64) -> Option<i64> {
62        let read_center: Vec<i64> = lift_query_positions_exact(record, &[reference_position])
63            .ok()?
64            .into_iter()
65            .flatten()
66            .collect();
67        log::debug!(
68            "{}, {}, {}, {:?}",
69            reference_position,
70            record.reference_start(),
71            record.reference_end(),
72            read_center
73        );
74        if read_center.is_empty() {
75            None
76        } else {
77            Some(read_center[0])
78        }
79    }
80
81    /// Get the sequence
82    pub fn subset_sequence(&self) -> String {
83        if self.simplify {
84            return "N".to_string();
85        }
86        let dist = self.dist.unwrap_or(0);
87        let seq = self.fiber.record.seq().as_bytes();
88
89        let mut out_seq: Vec<u8> = vec![];
90        let st = self.offset - dist; //(self.offset - dist)
91        for pos in st..(self.offset + dist + 1) {
92            if pos < 0 || pos as usize >= seq.len() {
93                out_seq.push(b'N');
94            } else {
95                out_seq.push(seq[pos as usize]);
96            }
97        }
98        if self.center_position.strand == '-' {
99            out_seq = revcomp(out_seq);
100        }
101        //assert_eq!(out_seq.len() as i64, dist * 2 + 1);
102        String::from_utf8_lossy(&out_seq).to_string()
103    }
104
105    pub fn get_sequence(&self) -> String {
106        let forward_bases = if self.center_position.strand == '+' {
107            self.fiber.record.seq().as_bytes()
108        } else {
109            revcomp(self.fiber.record.seq().as_bytes())
110        };
111        String::from_utf8_lossy(&forward_bases).to_string()
112    }
113
114    pub fn leading_columns(&self) -> String {
115        let (mut c_query_start, mut c_query_end) = if self.reference {
116            (
117                self.fiber.record.reference_start() - self.center_position.position,
118                self.fiber.record.reference_end() - self.center_position.position,
119            )
120        } else {
121            let query_length = self.fiber.record.seq_len() as i64;
122            (-self.offset, query_length - self.offset)
123        };
124
125        if self.center_position.strand == '-' {
126            // Convert end to inclusive before negation to correctly
127            // handle half-open interval [start, end) under negation.
128            // Negating [a, b) should give [-b+1, -a+1), not [-b, -a).
129            // This matches the logic in apply_offset_helper().
130            c_query_end -= 1;
131            c_query_start = -c_query_start;
132            c_query_end = -c_query_end;
133            if c_query_start > c_query_end {
134                std::mem::swap(&mut c_query_start, &mut c_query_end);
135            }
136            // Convert end back to exclusive
137            c_query_end += 1;
138        }
139        format!(
140            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
141            self.center_position.chrom,
142            self.center_position.position,
143            self.center_position.strand,
144            self.subset_sequence(),
145            self.fiber.record.reference_start(),
146            self.fiber.record.reference_end(),
147            std::str::from_utf8(self.fiber.record.qname()).unwrap(),
148            self.fiber.rg,
149            self.fiber.get_hp(),
150            c_query_start,
151            c_query_end,
152            self.fiber.record.seq_len()
153        )
154    }
155
156    #[allow(clippy::type_complexity)]
157    fn grab_data(
158        &self,
159    ) -> (
160        Vec<Option<i64>>,
161        Vec<u8>,
162        Vec<Option<i64>>,
163        Vec<u8>,
164        Vec<Option<i64>>,
165        Vec<Option<i64>>,
166        Vec<Option<i64>>,
167        Vec<Option<i64>>,
168        Vec<u8>,
169    ) {
170        if self.reference {
171            (
172                self.fiber.m6a.reference_starts(),
173                self.fiber.m6a.qual(),
174                self.fiber.cpg.reference_starts(),
175                self.fiber.cpg.qual(),
176                self.fiber.nuc.reference_starts(),
177                self.fiber.nuc.reference_ends(),
178                self.fiber.msp.reference_starts(),
179                self.fiber.msp.reference_ends(),
180                self.fiber.msp.qual(),
181            )
182        } else {
183            (
184                self.fiber.m6a.option_starts(),
185                self.fiber.m6a.qual(),
186                self.fiber.cpg.option_starts(),
187                self.fiber.cpg.qual(),
188                self.fiber.nuc.option_starts(),
189                self.fiber.nuc.option_ends(),
190                self.fiber.msp.option_starts(),
191                self.fiber.msp.option_ends(),
192                self.fiber.msp.qual(),
193            )
194        }
195    }
196
197    pub fn write(&self) -> String {
198        let (m6a, m6a_qual, cpg, cpg_qual, nuc_st, nuc_en, msp_st, msp_en, fire) = self.grab_data();
199        format!(
200            "{}{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
201            self.leading_columns(),
202            join_by_str_option(&m6a, ","),
203            join_by_str(&m6a_qual, ","),
204            join_by_str_option(&cpg, ","),
205            join_by_str(&cpg_qual, ","),
206            join_by_str_option(&nuc_st, ","),
207            join_by_str_option(&nuc_en, ","),
208            join_by_str_option(&msp_st, ","),
209            join_by_str_option(&msp_en, ","),
210            join_by_str(&fire, ","),
211            self.get_sequence(),
212        )
213    }
214
215    pub fn leading_header() -> String {
216        format!(
217            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t",
218            "chrom",
219            "centering_position",
220            "strand",
221            "subset_sequence",
222            "reference_start",
223            "reference_end",
224            "query_name",
225            "RG",
226            "HP",
227            "centered_query_start",
228            "centered_query_end",
229            "query_length",
230        )
231    }
232    pub fn header() -> String {
233        format!(
234            "{}{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
235            CenteredFiberData::leading_header(),
236            "centered_m6a_positions",
237            "m6a_qual",
238            "centered_5mC_positions",
239            "5mC_qual",
240            "centered_nuc_starts",
241            "centered_nuc_ends",
242            "centered_msp_starts",
243            "centered_msp_ends",
244            "fire_qual",
245            "query_sequence"
246        )
247    }
248
249    pub fn long_header() -> String {
250        format!(
251            "{}{}\t{}\t{}\t{}\n",
252            CenteredFiberData::leading_header(),
253            "centered_position_type",
254            "centered_start",
255            "centered_end",
256            "centered_qual",
257        )
258    }
259
260    pub fn write_long(&self) -> String {
261        let mut rtn = String::new();
262        let (m6a, m6a_qual, cpg, cpg_qual, nuc_st, nuc_en, msp_st, msp_en, fire) = self.grab_data();
263        for (t, vals) in [
264            ("m6a", (m6a, None, Some(m6a_qual))),
265            ("5mC", (cpg, None, Some(cpg_qual))),
266            ("nuc", (nuc_st, Some(nuc_en), None)),
267            ("msp", (msp_st, Some(msp_en), Some(fire))),
268        ] {
269            let starts = vals.0.iter().collect::<Vec<_>>();
270            let ends: Vec<Option<i64>> = match vals.1 {
271                Some(ends) => ends.to_vec(),
272                None => vec![None; starts.len()],
273            };
274            let quals = match vals.2 {
275                Some(quals) => quals.to_vec(),
276                None => vec![0; starts.len()],
277            };
278            let mut write_count = 0;
279            for ((&st, &en), &qual) in starts.iter().zip(ends.iter()).zip(quals.iter()) {
280                let Some(st) = st else {
281                    continue;
282                };
283                let st = *st;
284                let en = en.unwrap_or(st + 1);
285
286                let mut write = true;
287                if let Some(dist) = self.dist {
288                    // skip writing if we are outside the motif range
289                    if en <= -dist || st > dist {
290                        write = false;
291                    }
292                };
293                if write {
294                    // add the leading data
295                    rtn.push_str(&self.leading_columns());
296                    // add the long form data
297                    write!(
298                        &mut rtn,
299                        "{}",
300                        format_args!("{}\t{}\t{}\t{}\n", t, st, en, qual)
301                    )
302                    .unwrap();
303                    write_count += 1;
304                }
305            }
306            log::debug!("{t}: {write_count}");
307        }
308
309        rtn
310    }
311}
312
313#[allow(clippy::too_many_arguments)]
314pub fn center(
315    records: Vec<bam::Record>,
316    header_view: &rust_htslib::bam::HeaderView,
317    center_position: CenterPosition,
318    opts: &CenterOptions,
319    buffer: &mut Box<dyn std::io::Write>,
320) {
321    let fiber_data = FiberseqData::from_records(records, header_view, &opts.input.filters);
322    let total = fiber_data.len();
323    let mut seen = 0;
324
325    let to_write: Vec<String> = fiber_data
326        .into_par_iter()
327        .map(|fiber| {
328            match CenteredFiberData::new(
329                fiber,
330                center_position.clone(),
331                opts.dist,
332                opts.reference,
333                opts.simplify,
334            ) {
335                Some(centered_fiber) => {
336                    if opts.wide {
337                        centered_fiber.write()
338                    } else {
339                        centered_fiber.write_long()
340                    }
341                }
342                None => "".to_string(),
343            }
344        })
345        .filter(|x| !x.is_empty())
346        .collect::<Vec<_>>();
347
348    for line in to_write {
349        seen += 1;
350        write_to_file(&line, buffer);
351    }
352
353    log::debug!(
354        "centering {} records of {} on {}:{}:{}",
355        seen,
356        total,
357        center_position.chrom,
358        center_position.position,
359        center_position.strand
360    );
361
362    if total - seen > 1 {
363        log::warn!(
364            "Unable to exactly map {}/{} reads at position {}:{}",
365            total - seen,
366            total,
367            center_position.chrom.clone(),
368            center_position.position
369        );
370    }
371}
372
373pub fn center_fiberdata(center_opts: &mut CenterOptions) -> anyhow::Result<()> {
374    let mut bam = center_opts.input.indexed_bam_reader();
375    let center_positions = read_center_positions(&center_opts.bed)?;
376
377    // header needed for the contig name...
378    let header_view = center_opts.input.header_view();
379    // output buffer
380    let mut buffer = bio_io::writer("-").unwrap();
381
382    if center_opts.wide {
383        bio_io::write_to_file(&CenteredFiberData::header(), &mut buffer);
384    } else {
385        bio_io::write_to_file(&CenteredFiberData::long_header(), &mut buffer);
386    }
387
388    let pb = ProgressBar::new(center_positions.len() as u64);
389    pb.set_style(
390        style::ProgressStyle::with_template(PROGRESS_STYLE)
391            .unwrap()
392            .progress_chars("##-"),
393    );
394
395    for center_position in center_positions {
396        bam.fetch((
397            &center_position.chrom,
398            center_position.position,
399            center_position.position + 1,
400        ))
401        .unwrap_or_else(|_| {
402            panic!(
403                "Failed to fetch region: {}:{}-{}",
404                &center_position.chrom,
405                center_position.position,
406                center_position.position + 1
407            )
408        });
409
410        let records: Vec<bam::Record> = center_opts
411            .input
412            .filters
413            .filter_on_bit_flags(bam.records())
414            .collect();
415
416        center(
417            records,
418            &header_view,
419            center_position,
420            center_opts,
421            &mut buffer,
422        );
423        pb.inc(1);
424    }
425    buffer.flush().unwrap();
426    pb.finish_with_message("\ndone");
427    Ok(())
428}
429
430pub fn read_center_positions(infile: &str) -> io::Result<Vec<CenterPosition>> {
431    let reader = bio_io::buffer_from(infile).expect("Failed to open bed file");
432    let mut rtn = vec![];
433    for line in reader.lines() {
434        let line = line?;
435        if line.starts_with('#') {
436            continue;
437        }
438        let tokens = line.split('\t').collect::<Vec<_>>();
439        assert!(tokens.len() >= 3);
440        let st = tokens[1].parse::<i64>().unwrap();
441        let en = tokens[2].parse::<i64>().unwrap();
442        // get the strand for the 6 or 4th column
443        let strand =
444            if (tokens.len() >= 6 && tokens[5] == "-") || (tokens.len() >= 4 && tokens[3] == "-") {
445                '-'
446            } else {
447                '+'
448            };
449
450        let (strand, position) = if strand == '-' {
451            ('-', en - 1)
452        } else {
453            ('+', st)
454        };
455        rtn.push(CenterPosition {
456            chrom: tokens[0].to_string(),
457            position,
458            strand,
459        });
460    }
461    Ok(rtn)
462}