exon_gtf/
batch_reader.rs

1// Copyright 2023 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{str::FromStr, sync::Arc};
16
17use arrow::{error::ArrowError, error::Result as ArrowResult, record_batch::RecordBatch};
18
19use futures::Stream;
20use tokio::io::{AsyncBufRead, AsyncBufReadExt};
21
22use super::{array_builder::GTFArrayBuilder, GTFConfig};
23
24/// Reads a GTF file into arrow record batches.
25pub struct BatchReader<R> {
26    /// The reader to read from.
27    reader: R,
28
29    /// The configuration for this reader.
30    config: Arc<GTFConfig>,
31}
32
33impl<R> BatchReader<R>
34where
35    R: AsyncBufRead + Unpin + Send,
36{
37    pub fn new(reader: R, config: Arc<GTFConfig>) -> Self {
38        Self { reader, config }
39    }
40
41    pub fn into_stream(self) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
42        futures::stream::unfold(self, |mut reader| async move {
43            match reader.read_batch().await {
44                Ok(Some(batch)) => Some((Ok(batch), reader)),
45                Ok(None) => None,
46                Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
47            }
48        })
49    }
50
51    async fn read_line(&mut self) -> std::io::Result<Option<noodles::gtf::Line>> {
52        let mut buf = String::new();
53        match self.reader.read_line(&mut buf).await {
54            Ok(0) => Ok(None),
55            Ok(_) => {
56                // remove the new line character, needs to work for both windows and unix
57                buf.pop();
58
59                // remove the carriage return character if it exists on windows
60                #[cfg(target_os = "windows")]
61                if buf.ends_with('\r') {
62                    buf.pop();
63                }
64
65                let line = match noodles::gtf::Line::from_str(&buf) {
66                    Ok(line) => line,
67                    Err(e) => {
68                        return Err(std::io::Error::new(
69                            std::io::ErrorKind::InvalidData,
70                            format!("invalid line: {buf} error: {e}"),
71                        ));
72                    }
73                };
74                buf.clear();
75                Ok(Some(line))
76            }
77            Err(e) => Err(e),
78        }
79    }
80
81    async fn read_batch(&mut self) -> ArrowResult<Option<RecordBatch>> {
82        let mut gtf_array_builder = GTFArrayBuilder::new();
83
84        for _ in 0..self.config.batch_size {
85            match self.read_line().await? {
86                None => break,
87                Some(line) => match line {
88                    noodles::gtf::Line::Comment(_) => {}
89                    noodles::gtf::Line::Record(record) => {
90                        gtf_array_builder.append(&record)?;
91                    }
92                },
93            }
94        }
95
96        if gtf_array_builder.is_empty() {
97            return Ok(None);
98        }
99
100        let batch =
101            RecordBatch::try_new(self.config.file_schema.clone(), gtf_array_builder.finish())?;
102
103        match &self.config.projection {
104            Some(projection) => Ok(Some(batch.project(projection)?)),
105            None => Ok(Some(batch)),
106        }
107    }
108}