exon_cram/
indexed_async_batch_stream.rs

1// Copyright 2024 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow::{
18    array::RecordBatch,
19    error::{ArrowError, Result as ArrowResult},
20};
21use coitrees::{BasicCOITree, Interval, IntervalTree};
22use exon_common::{ExonArrayBuilder, DEFAULT_BATCH_SIZE};
23use futures::Stream;
24use noodles::cram::{
25    crai::{self, Record},
26    AsyncReader,
27};
28use tokio::io::{AsyncBufRead, AsyncSeek};
29
30use crate::{array_builder::CRAMArrayBuilder, CRAMConfig, ObjectStoreFastaRepositoryAdapter};
31
/// A stream that reads CRAM data containers and yields Arrow record batches
/// containing only the records that overlap the supplied index intervals.
pub struct IndexedAsyncBatchStream<R>
where
    R: AsyncBufRead + AsyncSeek + Unpin,
{
    /// The underlying CRAM reader the batches are decoded from.
    reader: AsyncReader<R>,

    /// The SAM header associated with the CRAM file.
    header: noodles::sam::Header,

    /// The CRAM config (object store, projected schema, optional FASTA reference).
    config: Arc<CRAMConfig>,

    /// The reference sequence repository used to resolve record sequences.
    reference_sequence_repository: noodles::fasta::Repository,

    /// Interval tree built from the CRAM index records; used to filter
    /// decoded records down to those overlapping an indexed region.
    ranges: BasicCOITree<crai::Record, u32>,
}
52
53impl<R> IndexedAsyncBatchStream<R>
54where
55    R: AsyncBufRead + AsyncSeek + Unpin,
56{
57    pub async fn try_new(
58        reader: AsyncReader<R>,
59        header: noodles::sam::Header,
60        config: Arc<CRAMConfig>,
61        index_records: Vec<Record>,
62    ) -> ArrowResult<Self> {
63        let reference_sequence_repository = match &config.fasta_reference {
64            Some(reference) => {
65                let object_store_repo = ObjectStoreFastaRepositoryAdapter::try_new(
66                    config.object_store.clone(),
67                    reference.to_string(),
68                )
69                .await?;
70
71                noodles::fasta::Repository::new(object_store_repo)
72            }
73            None => noodles::fasta::Repository::default(),
74        };
75
76        let ranges = index_records
77            .iter()
78            .map(|r| {
79                let start = r.alignment_start().unwrap().get();
80                let end = start + r.alignment_span();
81
82                Interval::new(start as i32, end as i32, r.clone())
83            })
84            .collect::<Vec<_>>();
85
86        let trees = BasicCOITree::new(&ranges);
87
88        Ok(Self {
89            reader,
90            header,
91            config,
92            reference_sequence_repository,
93            ranges: trees,
94        })
95    }
96
97    async fn read_batch(&mut self) -> ArrowResult<Option<RecordBatch>> {
98        let mut array_builder =
99            CRAMArrayBuilder::new(self.header.clone(), DEFAULT_BATCH_SIZE, &self.config);
100
101        let container = if let Some(container) = self.reader.read_data_container().await? {
102            container
103        } else {
104            return Ok(None);
105        };
106
107        let records = container
108            .slices()
109            .iter()
110            .map(|slice| {
111                let compression_header = container.compression_header();
112
113                slice.records(compression_header).and_then(|mut records| {
114                    slice.resolve_records(
115                        &self.reference_sequence_repository,
116                        &self.header,
117                        compression_header,
118                        &mut records,
119                    )?;
120
121                    Ok(records)
122                })
123            })
124            .collect::<Result<Vec<_>, _>>()?
125            .into_iter()
126            .flatten()
127            .filter(|record| {
128                let start = record.alignment_start().unwrap().get();
129                let end = start + record.alignment_end().unwrap().get();
130
131                self.ranges.query_count(start as i32, end as i32) > 0
132            });
133
134        for record in records {
135            array_builder.append(record)?;
136        }
137
138        let schema = self.config.projected_schema();
139        let batch = array_builder.try_into_record_batch(schema)?;
140
141        Ok(Some(batch))
142    }
143
144    pub fn into_stream(self) -> impl Stream<Item = ArrowResult<RecordBatch>> {
145        futures::stream::unfold(self, |mut reader| async move {
146            match reader.read_batch().await {
147                Ok(Some(batch)) => Some((Ok(batch), reader)),
148                Ok(None) => None,
149                Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
150            }
151        })
152    }
153}