exon_bam/
indexed_async_batch_stream.rs

1// Copyright 2023 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow::{
18    error::{ArrowError, Result as ArrowResult},
19    record_batch::RecordBatch,
20};
21use exon_common::ExonArrayBuilder;
22use futures::Stream;
23use noodles::{
24    core::{region::Interval, Position, Region},
25    sam::{alignment::RecordBuf, header::ReferenceSequences, Header},
26};
27use tokio::io::AsyncBufRead;
28
29use crate::ExonBAMError;
30
31use super::{array_builder::BAMArrayBuilder, BAMConfig};
32
33/// This is a semi-lazy record that can be used to filter on the region without
34/// having to decode the entire record or re-decode the cigar.
35pub(crate) struct SemiLazyRecord {
36    inner: RecordBuf,
37    alignment_end: Option<Position>,
38}
39
40impl TryFrom<RecordBuf> for SemiLazyRecord {
41    type Error = ExonBAMError;
42
43    fn try_from(record: RecordBuf) -> Result<Self, Self::Error> {
44        let alignment_end = record.alignment_end();
45
46        Ok(Self {
47            inner: record,
48            alignment_end,
49        })
50    }
51}
52
53impl SemiLazyRecord {
54    pub fn alignment_start(&self) -> Option<Position> {
55        self.inner.alignment_start()
56    }
57
58    pub fn alignment_end(&self) -> Option<Position> {
59        self.alignment_end
60    }
61
62    pub fn record(&self) -> &RecordBuf {
63        &self.inner
64    }
65
66    pub fn intersects(
67        &self,
68        region_sequence_id: usize,
69        region_interval: &Interval,
70    ) -> std::io::Result<bool> {
71        let reference_sequence_id = self.inner.reference_sequence_id();
72
73        let alignment_start = self.alignment_start();
74        let alignment_end = self.alignment_end();
75
76        match (reference_sequence_id, alignment_start, alignment_end) {
77            (Some(id), Some(start), Some(end)) => {
78                let alignment_interval = (start..=end).into();
79                let intersects = region_interval.intersects(alignment_interval);
80
81                let same_sequence = id == region_sequence_id;
82
83                Ok(intersects && same_sequence)
84            }
85            _ => Ok(false),
86        }
87    }
88}
89
90pub struct IndexedAsyncBatchStream<R>
91where
92    R: AsyncBufRead + Unpin,
93{
94    /// The underlying reader.
95    reader: noodles::bam::AsyncReader<noodles::bgzf::AsyncReader<R>>,
96
97    /// The BAM configuration.
98    config: Arc<BAMConfig>,
99
100    /// The reference sequences.
101    header: Arc<Header>,
102
103    /// The region reference sequence.
104    region_reference: usize,
105
106    /// The region interval.
107    region_interval: Interval,
108
109    /// The max uncompressed bytes read.
110    max_bytes: Option<u16>,
111}
112
113fn get_reference_sequence_for_region(
114    reference_sequences: &ReferenceSequences,
115    region: &Region,
116) -> std::io::Result<usize> {
117    reference_sequences
118        .get_index_of(region.name())
119        .ok_or_else(|| {
120            std::io::Error::new(
121                std::io::ErrorKind::InvalidData,
122                "invalid reference sequence", // format!("invalid reference sequence: {}", region.name()),
123            )
124        })
125}
126
127impl<R> IndexedAsyncBatchStream<R>
128where
129    R: AsyncBufRead + Unpin,
130{
131    pub fn try_new(
132        reader: noodles::bam::AsyncReader<noodles::bgzf::AsyncReader<R>>,
133        config: Arc<BAMConfig>,
134        header: Arc<Header>,
135        region: Arc<Region>,
136    ) -> std::io::Result<Self> {
137        let region_reference =
138            get_reference_sequence_for_region(header.reference_sequences(), &region)?;
139        let region_interval = region.interval();
140
141        Ok(Self {
142            reader,
143            config,
144            header,
145            region_reference,
146            region_interval,
147            max_bytes: None,
148        })
149    }
150
151    pub fn set_max_bytes(&mut self, max_bytes: u16) {
152        self.max_bytes = Some(max_bytes);
153    }
154
155    /// Stream the record batches from the VCF file.
156    pub fn into_stream(self) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
157        futures::stream::unfold(self, |mut reader| async move {
158            match reader.read_record_batch().await {
159                Ok(Some(batch)) => Some((Ok(batch), reader)),
160                Ok(None) => None,
161                Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
162            }
163        })
164    }
165
166    async fn read_record(&mut self, record: &mut RecordBuf) -> std::io::Result<Option<()>> {
167        if let Some(max_bytes) = self.max_bytes {
168            if self.reader.get_ref().virtual_position().uncompressed() >= max_bytes {
169                return Ok(None);
170            }
171        }
172
173        let bytes_read = self.reader.read_record_buf(&self.header, record).await?;
174
175        if bytes_read == 0 {
176            Ok(None)
177        } else {
178            Ok(Some(()))
179        }
180    }
181
182    async fn read_record_batch(&mut self) -> ArrowResult<Option<arrow::record_batch::RecordBatch>> {
183        let mut builder = BAMArrayBuilder::create(self.header.clone(), self.config.clone());
184        let mut record = RecordBuf::default();
185
186        for i in 0..self.config.batch_size {
187            if self.read_record(&mut record).await?.is_some() {
188                let semi_lazy_record = SemiLazyRecord::try_from(record.clone())?;
189
190                if semi_lazy_record.intersects(self.region_reference, &self.region_interval)? {
191                    builder.append(&semi_lazy_record)?;
192                }
193            } else if i == 0 {
194                return Ok(None);
195            } else {
196                break;
197            }
198        }
199
200        let schema = self.config.projected_schema()?;
201        let batch = builder.try_into_record_batch(schema)?;
202
203        Ok(Some(batch))
204    }
205}