exon_bed/
batch_reader.rs

1// Copyright 2024 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{str::FromStr, sync::Arc};
16
17use arrow::{error::ArrowError, record_batch::RecordBatch};
18
19use exon_common::ExonArrayBuilder;
20use futures::Stream;
21use noodles::{
22    bed::feature::{record::Strand, record_buf::OtherFields, RecordBuf},
23    core::Position,
24};
25use tokio::io::{AsyncBufRead, AsyncBufReadExt};
26
27use super::{array_builder::BEDArrayBuilder, bed_record_builder::BEDRecord, config::BEDConfig};
28
29/// A batch reader for BED files.
30pub struct BatchReader<R> {
31    /// The underlying BED reader.
32    reader: R,
33
34    /// The BED configuration.
35    config: Arc<BEDConfig>,
36}
37
38impl<R> BatchReader<R>
39where
40    R: AsyncBufRead + Unpin + Send,
41{
42    pub fn new(inner: R, config: Arc<BEDConfig>) -> Self {
43        Self {
44            reader: inner,
45            config,
46        }
47    }
48
49    pub fn into_stream(self) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
50        futures::stream::unfold(self, |mut reader| async move {
51            match reader.read_batch().await {
52                Ok(Some(batch)) => Some((Ok(batch), reader)),
53                Ok(None) => None,
54                Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
55            }
56        })
57    }
58
59    async fn read_batch(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
60        let mut array_builder = BEDArrayBuilder::create(
61            self.config.file_schema.clone(),
62            Some(self.config.projection()),
63        );
64
65        for _ in 0..self.config.batch_size {
66            match self.read_record().await? {
67                Some(record) => array_builder.append(record).map_err(|e| {
68                    std::io::Error::new(
69                        std::io::ErrorKind::InvalidData,
70                        format!("invalid record: {e}"),
71                    )
72                })?,
73                None => break,
74            }
75        }
76
77        if array_builder.is_empty() {
78            return Ok(None);
79        }
80
81        let schema = self.config.projected_schema().map_err(|e| {
82            std::io::Error::new(
83                std::io::ErrorKind::InvalidData,
84                format!("invalid schema: {e}"),
85            )
86        })?;
87        let batch = array_builder.try_into_record_batch(schema)?;
88
89        Ok(Some(batch))
90    }
91
92    pub async fn read_record(&mut self) -> std::io::Result<Option<BEDRecord>> {
93        let mut buf = String::new();
94        if self.reader.read_line(&mut buf).await? == 0 {
95            return Ok(None);
96        }
97
98        // Skip commented lines
99        while buf.starts_with('#') {
100            buf.clear();
101            if self.reader.read_line(&mut buf).await? == 0 {
102                return Ok(None);
103            }
104        }
105
106        // Remove the newline
107        buf.pop();
108
109        // Remove the carriage return if present and on windows
110        #[cfg(target_os = "windows")]
111        if buf.ends_with('\r') {
112            buf.pop();
113        }
114
115        // Get the number of tab separated fields
116        let split = buf.split('\t').collect::<Vec<&str>>();
117        let num_fields = split.len();
118
119        let bed_record = match num_fields {
120            12 => {
121                let buf_builder = RecordBuf::<6>::builder();
122
123                let other_fields = OtherFields::default();
124
125                let mut record = buf_builder
126                    .set_reference_sequence_name(split[0].as_bytes().to_vec())
127                    .set_feature_start(Position::from_str(split[1]).unwrap())
128                    .set_feature_end(Position::from_str(split[2]).unwrap())
129                    .set_name(split[3].as_bytes().to_vec())
130                    .set_score(split[4].parse().unwrap());
131
132                match split[5] {
133                    "+" => {
134                        record = record.set_strand(Strand::Forward);
135                    }
136                    "-" => {
137                        record = record.set_strand(Strand::Reverse);
138                    }
139                    "." => {}
140                    _ => {
141                        return Err(std::io::Error::new(
142                            std::io::ErrorKind::InvalidData,
143                            format!("invalid strand: {}", split[5]),
144                        ))
145                    }
146                };
147
148                let record = record.set_other_fields(other_fields).build();
149
150                BEDRecord::from(record)
151            }
152            // 9 => extract_record!(buf, 9),
153            // 8 => extract_record!(buf, 8),
154            // 7 => extract_record!(buf, 7),
155            6 => {
156                // let record = read_record_6(buf.as_bytes())?;
157                let mut buf_builder = RecordBuf::<6>::builder()
158                    .set_reference_sequence_name(split[0].as_bytes().to_vec())
159                    .set_feature_start(Position::from_str(split[1]).unwrap())
160                    .set_feature_end(Position::from_str(split[2]).unwrap())
161                    .set_name(split[3].as_bytes().to_vec())
162                    .set_score(split[4].parse().unwrap());
163
164                match split[5] {
165                    "+" => {
166                        buf_builder = buf_builder.set_strand(Strand::Forward);
167                    }
168                    "-" => {
169                        buf_builder = buf_builder.set_strand(Strand::Reverse);
170                    }
171                    "." => {}
172                    _ => {
173                        return Err(std::io::Error::new(
174                            std::io::ErrorKind::InvalidData,
175                            format!("invalid strand: {}", split[5]),
176                        ))
177                    }
178                };
179
180                let record = buf_builder.build();
181
182                BEDRecord::from(record)
183            }
184            5 => {
185                let buf_builder = RecordBuf::<5>::builder();
186
187                let record = buf_builder
188                    .set_reference_sequence_name(split[0].as_bytes().to_vec())
189                    .set_feature_start(Position::from_str(split[1]).unwrap())
190                    .set_feature_end(Position::from_str(split[2]).unwrap())
191                    .set_name(split[3].as_bytes().to_vec())
192                    .set_score(split[4].parse().unwrap())
193                    .build();
194
195                BEDRecord::from(record)
196            }
197            4 => {
198                let buf_builder = RecordBuf::<4>::builder();
199
200                let record = buf_builder
201                    .set_reference_sequence_name(split[0].as_bytes().to_vec())
202                    .set_feature_start(Position::from_str(split[1]).unwrap())
203                    .set_feature_end(Position::from_str(split[2]).unwrap())
204                    .set_name(split[3].as_bytes().to_vec())
205                    .build();
206
207                BEDRecord::from(record)
208            }
209            3 => {
210                let buf_builder = RecordBuf::<3>::builder();
211
212                let record = buf_builder
213                    .set_reference_sequence_name(split[0].as_bytes().to_vec())
214                    .set_feature_start(Position::from_str(split[1]).unwrap())
215                    .set_feature_end(Position::from_str(split[2]).unwrap())
216                    .build();
217
218                BEDRecord::from(record)
219            }
220            _ => {
221                return Err(std::io::Error::new(
222                    std::io::ErrorKind::InvalidData,
223                    format!("invalid number of fields: {num_fields}"),
224                ));
225            }
226        };
227
228        buf.clear();
229
230        Ok(Some(bed_record))
231    }
232}