exon_bam/
indexed_async_batch_stream.rs1use std::sync::Arc;
16
17use arrow::{
18 error::{ArrowError, Result as ArrowResult},
19 record_batch::RecordBatch,
20};
21use exon_common::ExonArrayBuilder;
22use futures::Stream;
23use noodles::{
24 core::{region::Interval, Position, Region},
25 sam::{alignment::RecordBuf, header::ReferenceSequences, Header},
26};
27use tokio::io::AsyncBufRead;
28
29use crate::ExonBAMError;
30
31use super::{array_builder::BAMArrayBuilder, BAMConfig};
32
33pub(crate) struct SemiLazyRecord {
36 inner: RecordBuf,
37 alignment_end: Option<Position>,
38}
39
40impl TryFrom<RecordBuf> for SemiLazyRecord {
41 type Error = ExonBAMError;
42
43 fn try_from(record: RecordBuf) -> Result<Self, Self::Error> {
44 let alignment_end = record.alignment_end();
45
46 Ok(Self {
47 inner: record,
48 alignment_end,
49 })
50 }
51}
52
53impl SemiLazyRecord {
54 pub fn alignment_start(&self) -> Option<Position> {
55 self.inner.alignment_start()
56 }
57
58 pub fn alignment_end(&self) -> Option<Position> {
59 self.alignment_end
60 }
61
62 pub fn record(&self) -> &RecordBuf {
63 &self.inner
64 }
65
66 pub fn intersects(
67 &self,
68 region_sequence_id: usize,
69 region_interval: &Interval,
70 ) -> std::io::Result<bool> {
71 let reference_sequence_id = self.inner.reference_sequence_id();
72
73 let alignment_start = self.alignment_start();
74 let alignment_end = self.alignment_end();
75
76 match (reference_sequence_id, alignment_start, alignment_end) {
77 (Some(id), Some(start), Some(end)) => {
78 let alignment_interval = (start..=end).into();
79 let intersects = region_interval.intersects(alignment_interval);
80
81 let same_sequence = id == region_sequence_id;
82
83 Ok(intersects && same_sequence)
84 }
85 _ => Ok(false),
86 }
87 }
88}
89
90pub struct IndexedAsyncBatchStream<R>
91where
92 R: AsyncBufRead + Unpin,
93{
94 reader: noodles::bam::AsyncReader<noodles::bgzf::AsyncReader<R>>,
96
97 config: Arc<BAMConfig>,
99
100 header: Arc<Header>,
102
103 region_reference: usize,
105
106 region_interval: Interval,
108
109 max_bytes: Option<u16>,
111}
112
113fn get_reference_sequence_for_region(
114 reference_sequences: &ReferenceSequences,
115 region: &Region,
116) -> std::io::Result<usize> {
117 reference_sequences
118 .get_index_of(region.name())
119 .ok_or_else(|| {
120 std::io::Error::new(
121 std::io::ErrorKind::InvalidData,
122 "invalid reference sequence", )
124 })
125}
126
127impl<R> IndexedAsyncBatchStream<R>
128where
129 R: AsyncBufRead + Unpin,
130{
131 pub fn try_new(
132 reader: noodles::bam::AsyncReader<noodles::bgzf::AsyncReader<R>>,
133 config: Arc<BAMConfig>,
134 header: Arc<Header>,
135 region: Arc<Region>,
136 ) -> std::io::Result<Self> {
137 let region_reference =
138 get_reference_sequence_for_region(header.reference_sequences(), ®ion)?;
139 let region_interval = region.interval();
140
141 Ok(Self {
142 reader,
143 config,
144 header,
145 region_reference,
146 region_interval,
147 max_bytes: None,
148 })
149 }
150
151 pub fn set_max_bytes(&mut self, max_bytes: u16) {
152 self.max_bytes = Some(max_bytes);
153 }
154
155 pub fn into_stream(self) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
157 futures::stream::unfold(self, |mut reader| async move {
158 match reader.read_record_batch().await {
159 Ok(Some(batch)) => Some((Ok(batch), reader)),
160 Ok(None) => None,
161 Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
162 }
163 })
164 }
165
166 async fn read_record(&mut self, record: &mut RecordBuf) -> std::io::Result<Option<()>> {
167 if let Some(max_bytes) = self.max_bytes {
168 if self.reader.get_ref().virtual_position().uncompressed() >= max_bytes {
169 return Ok(None);
170 }
171 }
172
173 let bytes_read = self.reader.read_record_buf(&self.header, record).await?;
174
175 if bytes_read == 0 {
176 Ok(None)
177 } else {
178 Ok(Some(()))
179 }
180 }
181
182 async fn read_record_batch(&mut self) -> ArrowResult<Option<arrow::record_batch::RecordBatch>> {
183 let mut builder = BAMArrayBuilder::create(self.header.clone(), self.config.clone());
184 let mut record = RecordBuf::default();
185
186 for i in 0..self.config.batch_size {
187 if self.read_record(&mut record).await?.is_some() {
188 let semi_lazy_record = SemiLazyRecord::try_from(record.clone())?;
189
190 if semi_lazy_record.intersects(self.region_reference, &self.region_interval)? {
191 builder.append(&semi_lazy_record)?;
192 }
193 } else if i == 0 {
194 return Ok(None);
195 } else {
196 break;
197 }
198 }
199
200 let schema = self.config.projected_schema()?;
201 let batch = builder.try_into_record_batch(schema)?;
202
203 Ok(Some(batch))
204 }
205}