1use std::{str::FromStr, sync::Arc};
16
17use arrow::{error::ArrowError, record_batch::RecordBatch};
18
19use exon_common::ExonArrayBuilder;
20use futures::Stream;
21use noodles::{
22 bed::feature::{record::Strand, record_buf::OtherFields, RecordBuf},
23 core::Position,
24};
25use tokio::io::{AsyncBufRead, AsyncBufReadExt};
26
27use super::{array_builder::BEDArrayBuilder, bed_record_builder::BEDRecord, config::BEDConfig};
28
29pub struct BatchReader<R> {
31 reader: R,
33
34 config: Arc<BEDConfig>,
36}
37
38impl<R> BatchReader<R>
39where
40 R: AsyncBufRead + Unpin + Send,
41{
42 pub fn new(inner: R, config: Arc<BEDConfig>) -> Self {
43 Self {
44 reader: inner,
45 config,
46 }
47 }
48
49 pub fn into_stream(self) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
50 futures::stream::unfold(self, |mut reader| async move {
51 match reader.read_batch().await {
52 Ok(Some(batch)) => Some((Ok(batch), reader)),
53 Ok(None) => None,
54 Err(e) => Some((Err(ArrowError::ExternalError(Box::new(e))), reader)),
55 }
56 })
57 }
58
59 async fn read_batch(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
60 let mut array_builder = BEDArrayBuilder::create(
61 self.config.file_schema.clone(),
62 Some(self.config.projection()),
63 );
64
65 for _ in 0..self.config.batch_size {
66 match self.read_record().await? {
67 Some(record) => array_builder.append(record).map_err(|e| {
68 std::io::Error::new(
69 std::io::ErrorKind::InvalidData,
70 format!("invalid record: {e}"),
71 )
72 })?,
73 None => break,
74 }
75 }
76
77 if array_builder.is_empty() {
78 return Ok(None);
79 }
80
81 let schema = self.config.projected_schema().map_err(|e| {
82 std::io::Error::new(
83 std::io::ErrorKind::InvalidData,
84 format!("invalid schema: {e}"),
85 )
86 })?;
87 let batch = array_builder.try_into_record_batch(schema)?;
88
89 Ok(Some(batch))
90 }
91
92 pub async fn read_record(&mut self) -> std::io::Result<Option<BEDRecord>> {
93 let mut buf = String::new();
94 if self.reader.read_line(&mut buf).await? == 0 {
95 return Ok(None);
96 }
97
98 while buf.starts_with('#') {
100 buf.clear();
101 if self.reader.read_line(&mut buf).await? == 0 {
102 return Ok(None);
103 }
104 }
105
106 buf.pop();
108
109 #[cfg(target_os = "windows")]
111 if buf.ends_with('\r') {
112 buf.pop();
113 }
114
115 let split = buf.split('\t').collect::<Vec<&str>>();
117 let num_fields = split.len();
118
119 let bed_record = match num_fields {
120 12 => {
121 let buf_builder = RecordBuf::<6>::builder();
122
123 let other_fields = OtherFields::default();
124
125 let mut record = buf_builder
126 .set_reference_sequence_name(split[0].as_bytes().to_vec())
127 .set_feature_start(Position::from_str(split[1]).unwrap())
128 .set_feature_end(Position::from_str(split[2]).unwrap())
129 .set_name(split[3].as_bytes().to_vec())
130 .set_score(split[4].parse().unwrap());
131
132 match split[5] {
133 "+" => {
134 record = record.set_strand(Strand::Forward);
135 }
136 "-" => {
137 record = record.set_strand(Strand::Reverse);
138 }
139 "." => {}
140 _ => {
141 return Err(std::io::Error::new(
142 std::io::ErrorKind::InvalidData,
143 format!("invalid strand: {}", split[5]),
144 ))
145 }
146 };
147
148 let record = record.set_other_fields(other_fields).build();
149
150 BEDRecord::from(record)
151 }
152 6 => {
156 let mut buf_builder = RecordBuf::<6>::builder()
158 .set_reference_sequence_name(split[0].as_bytes().to_vec())
159 .set_feature_start(Position::from_str(split[1]).unwrap())
160 .set_feature_end(Position::from_str(split[2]).unwrap())
161 .set_name(split[3].as_bytes().to_vec())
162 .set_score(split[4].parse().unwrap());
163
164 match split[5] {
165 "+" => {
166 buf_builder = buf_builder.set_strand(Strand::Forward);
167 }
168 "-" => {
169 buf_builder = buf_builder.set_strand(Strand::Reverse);
170 }
171 "." => {}
172 _ => {
173 return Err(std::io::Error::new(
174 std::io::ErrorKind::InvalidData,
175 format!("invalid strand: {}", split[5]),
176 ))
177 }
178 };
179
180 let record = buf_builder.build();
181
182 BEDRecord::from(record)
183 }
184 5 => {
185 let buf_builder = RecordBuf::<5>::builder();
186
187 let record = buf_builder
188 .set_reference_sequence_name(split[0].as_bytes().to_vec())
189 .set_feature_start(Position::from_str(split[1]).unwrap())
190 .set_feature_end(Position::from_str(split[2]).unwrap())
191 .set_name(split[3].as_bytes().to_vec())
192 .set_score(split[4].parse().unwrap())
193 .build();
194
195 BEDRecord::from(record)
196 }
197 4 => {
198 let buf_builder = RecordBuf::<4>::builder();
199
200 let record = buf_builder
201 .set_reference_sequence_name(split[0].as_bytes().to_vec())
202 .set_feature_start(Position::from_str(split[1]).unwrap())
203 .set_feature_end(Position::from_str(split[2]).unwrap())
204 .set_name(split[3].as_bytes().to_vec())
205 .build();
206
207 BEDRecord::from(record)
208 }
209 3 => {
210 let buf_builder = RecordBuf::<3>::builder();
211
212 let record = buf_builder
213 .set_reference_sequence_name(split[0].as_bytes().to_vec())
214 .set_feature_start(Position::from_str(split[1]).unwrap())
215 .set_feature_end(Position::from_str(split[2]).unwrap())
216 .build();
217
218 BEDRecord::from(record)
219 }
220 _ => {
221 return Err(std::io::Error::new(
222 std::io::ErrorKind::InvalidData,
223 format!("invalid number of fields: {num_fields}"),
224 ));
225 }
226 };
227
228 buf.clear();
229
230 Ok(Some(bed_record))
231 }
232}