1use std::collections::HashMap;
16use std::sync::Arc;
17
18use arrow::datatypes::{DataType, Field, Fields, Schema};
19use arrow::error::{ArrowError, Result};
20use exon_common::TableSchema;
21use noodles::sam::alignment::record_buf::data::field::{value::Array, Value};
22use noodles::sam::alignment::record_buf::Data;
23
24macro_rules! arrow_error {
25 ($tag:expr, $field_type:expr, $expected_type:expr) => {
26 Err(arrow::error::ArrowError::InvalidArgumentError(
27 format!(
28 "tag {} has conflicting types: {:?} and {:?}",
29 $tag, $field_type, $expected_type,
30 )
31 .into(),
32 ))
33 };
34}
35
36pub struct SAMSchemaBuilder {
38 file_fields: Vec<Field>,
39 partition_fields: Vec<Field>,
40 tags_data_type: Option<DataType>,
41}
42
43impl SAMSchemaBuilder {
44 pub fn new(file_fields: Vec<Field>, partition_fields: Vec<Field>) -> Self {
46 Self {
47 file_fields,
48 partition_fields,
49 tags_data_type: None,
50 }
51 }
52
53 pub fn with_partition_fields(self, partition_fields: Vec<Field>) -> Self {
55 Self {
56 partition_fields,
57 ..self
58 }
59 }
60
61 pub fn with_tags_data_type(self, tags_data_type: DataType) -> Self {
63 Self {
64 tags_data_type: Some(tags_data_type),
65 ..self
66 }
67 }
68
69 pub fn with_tags_data_type_from_data(self, data: &Data) -> Result<Self> {
71 let mut fields = HashMap::new();
72
73 for (tag, value) in data.iter() {
74 let tag_name = std::str::from_utf8(tag.as_ref())?;
75
76 match value {
77 Value::Character(_) | Value::String(_) | Value::Hex(_) => {
78 let field = fields.entry(tag).or_insert_with(|| {
79 Field::new(tag_name, arrow::datatypes::DataType::Utf8, true)
80 });
81 if field.data_type() != &arrow::datatypes::DataType::Utf8 {
82 return arrow_error!(
83 tag_name,
84 field.data_type(),
85 arrow::datatypes::DataType::Utf8
86 );
87 }
88 }
89 Value::Int8(_) => {
90 let field = fields.entry(tag).or_insert_with(|| {
91 Field::new(tag_name, arrow::datatypes::DataType::Int8, true)
92 });
93 if field.data_type() != &arrow::datatypes::DataType::Int8 {
94 return arrow_error!(
95 tag_name,
96 field.data_type(),
97 arrow::datatypes::DataType::Int8
98 );
99 }
100 }
101 Value::Int16(_) => {
102 let field = fields.entry(tag).or_insert_with(|| {
103 Field::new(tag_name, arrow::datatypes::DataType::Int16, true)
104 });
105 if field.data_type() != &arrow::datatypes::DataType::Int16 {
106 return arrow_error!(
107 tag_name,
108 field.data_type(),
109 arrow::datatypes::DataType::Int16
110 );
111 }
112 }
113 Value::Int32(_) => {
114 let field = fields.entry(tag).or_insert_with(|| {
115 Field::new(tag_name, arrow::datatypes::DataType::Int32, true)
116 });
117 if field.data_type() != &arrow::datatypes::DataType::Int32 {
118 return arrow_error!(
119 tag_name,
120 field.data_type(),
121 arrow::datatypes::DataType::Int32
122 );
123 }
124 }
125 Value::UInt8(_) => {
126 let field = fields.entry(tag).or_insert_with(|| {
127 Field::new(tag_name, arrow::datatypes::DataType::UInt8, true)
128 });
129 if field.data_type() != &arrow::datatypes::DataType::UInt8 {
130 return arrow_error!(
131 tag_name,
132 field.data_type(),
133 arrow::datatypes::DataType::UInt8
134 );
135 }
136 }
137 Value::UInt16(_) => {
138 let field = fields.entry(tag).or_insert_with(|| {
139 Field::new(tag_name, arrow::datatypes::DataType::UInt16, true)
140 });
141 if field.data_type() != &arrow::datatypes::DataType::UInt16 {
142 return arrow_error!(
143 tag_name,
144 field.data_type(),
145 arrow::datatypes::DataType::UInt16
146 );
147 }
148 }
149 Value::UInt32(_) => {
150 let field = fields.entry(tag).or_insert_with(|| {
151 Field::new(tag_name, arrow::datatypes::DataType::UInt32, true)
152 });
153 if field.data_type() != &arrow::datatypes::DataType::UInt32 {
154 return arrow_error!(
155 tag_name,
156 field.data_type(),
157 arrow::datatypes::DataType::UInt16
158 );
159 }
160 }
161 Value::Float(_) => {
162 let field = fields.entry(tag).or_insert_with(|| {
163 Field::new(tag_name, arrow::datatypes::DataType::Float32, true)
164 });
165
166 if field.data_type() != &arrow::datatypes::DataType::Float32 {
167 return arrow_error!(
168 tag_name,
169 field.data_type(),
170 arrow::datatypes::DataType::Float32
171 );
172 }
173 }
174 Value::Array(array) => {
175 match array {
176 Array::Int32(_) => {
177 let field = fields.entry(tag).or_insert_with(|| {
178 Field::new(
179 tag_name,
180 arrow::datatypes::DataType::List(Arc::new(Field::new(
181 "item",
182 arrow::datatypes::DataType::Int32,
183 true,
184 ))),
185 false,
186 )
187 });
188
189 let expected_type = arrow::datatypes::DataType::List(Arc::new(
190 Field::new("item", arrow::datatypes::DataType::Int32, true),
191 ));
192
193 if field.data_type() != &expected_type {
194 return arrow_error!(tag_name, field.data_type(), expected_type);
195 }
196 }
197 Array::Int16(_) => {
198 let field = fields.entry(tag).or_insert_with(|| {
199 Field::new(
200 tag_name,
201 arrow::datatypes::DataType::List(Arc::new(Field::new(
202 "item",
203 arrow::datatypes::DataType::Int16,
204 true,
205 ))),
206 true,
207 )
208 });
209
210 let expected_type = arrow::datatypes::DataType::List(Arc::new(
211 Field::new("item", arrow::datatypes::DataType::Int16, true),
212 ));
213
214 if field.data_type() != &expected_type {
215 return arrow_error!(tag_name, field.data_type(), expected_type);
216 }
217 }
218 Array::Int8(_) => {
219 let field = fields.entry(tag).or_insert_with(|| {
220 Field::new(
221 tag_name,
222 arrow::datatypes::DataType::List(Arc::new(Field::new(
223 "item",
224 arrow::datatypes::DataType::Int8,
225 true,
226 ))),
227 true,
228 )
229 });
230
231 let expected_type = arrow::datatypes::DataType::List(Arc::new(
232 Field::new("item", arrow::datatypes::DataType::Int8, true),
233 ));
234
235 if field.data_type() != &expected_type {
236 return arrow_error!(tag_name, field.data_type(), expected_type);
237 }
238 }
239 Array::UInt8(_) => {
240 let field = fields.entry(tag).or_insert_with(|| {
241 Field::new(
242 tag_name,
243 arrow::datatypes::DataType::List(Arc::new(Field::new(
244 "item",
245 arrow::datatypes::DataType::UInt8,
246 true,
247 ))),
248 true,
249 )
250 });
251
252 let expected_type = arrow::datatypes::DataType::List(Arc::new(
253 Field::new("item", arrow::datatypes::DataType::UInt8, true),
254 ));
255
256 if field.data_type() != &expected_type {
257 return arrow_error!(tag_name, field.data_type(), expected_type);
258 }
259 }
260 Array::UInt16(_) => {
261 let field = fields.entry(tag).or_insert_with(|| {
262 Field::new(
263 tag_name,
264 arrow::datatypes::DataType::List(Arc::new(Field::new(
265 "item",
266 arrow::datatypes::DataType::UInt16,
267 true,
268 ))),
269 true,
270 )
271 });
272
273 let expected_type = arrow::datatypes::DataType::List(Arc::new(
274 Field::new("item", arrow::datatypes::DataType::UInt16, true),
275 ));
276
277 if field.data_type() != &expected_type {
278 return arrow_error!(tag_name, field.data_type(), expected_type);
279 }
280 }
281 Array::UInt32(_) => {
282 let field = fields.entry(tag).or_insert_with(|| {
283 Field::new(
284 tag_name,
285 arrow::datatypes::DataType::List(Arc::new(Field::new(
286 "item",
287 arrow::datatypes::DataType::UInt32,
288 true,
289 ))),
290 true,
291 )
292 });
293
294 let expected_type = arrow::datatypes::DataType::List(Arc::new(
295 Field::new("item", arrow::datatypes::DataType::UInt32, true),
296 ));
297
298 if field.data_type() != &expected_type {
299 return arrow_error!(tag_name, field.data_type(), expected_type);
300 }
301 }
302 Array::Float(_) => {
303 let field = fields.entry(tag).or_insert_with(|| {
304 Field::new(
305 tag_name,
306 arrow::datatypes::DataType::List(Arc::new(Field::new(
307 "item",
308 arrow::datatypes::DataType::Float32,
309 true,
310 ))),
311 true,
312 )
313 });
314
315 let expected_type = arrow::datatypes::DataType::List(Arc::new(
316 Field::new("item", arrow::datatypes::DataType::Float32, true),
317 ));
318
319 if field.data_type() != &expected_type {
320 return arrow_error!(tag_name, field.data_type(), expected_type);
321 }
322 }
323 }
324 }
325 }
326 }
327
328 if fields.is_empty() {
330 return Err(ArrowError::InvalidArgumentError(
331 "No fields found in the data".into(),
332 ));
333 }
334
335 let data_type = DataType::Struct(Fields::from(
336 fields
337 .into_iter()
338 .map(|(tag, field)| {
339 let data_type = field.data_type().clone();
340
341 let tag_name = std::str::from_utf8(tag.as_ref()).unwrap();
343
344 Field::new(tag_name, data_type, true)
345 })
346 .collect::<Vec<_>>(),
347 ));
348
349 Ok(self.with_tags_data_type(data_type))
350 }
351
352 pub fn build(self) -> TableSchema {
354 let mut fields = self.file_fields;
355
356 if let Some(tags_data_type) = self.tags_data_type.clone() {
357 let tags_field = Field::new("tags", tags_data_type, true);
358 fields.push(tags_field);
359 }
360
361 let file_projection = (0..fields.len()).collect::<Vec<_>>();
362
363 fields.extend_from_slice(&self.partition_fields);
364
365 let file_schema = Schema::new(fields);
366
367 TableSchema::new(Arc::new(file_schema), file_projection)
368 }
369}
370
371impl Default for SAMSchemaBuilder {
372 fn default() -> Self {
373 let tags_data_type = DataType::List(Arc::new(Field::new(
374 "item",
375 DataType::Struct(Fields::from(vec![
376 Field::new("tag", DataType::Utf8, false),
377 Field::new("value", DataType::Utf8, true),
378 ])),
379 true,
380 )));
381
382 let quality_score_list =
383 DataType::List(Arc::new(Field::new("item", DataType::Int64, true)));
384
385 Self::new(
386 vec![
387 Field::new("name", DataType::Utf8, false),
388 Field::new("flag", DataType::Int32, false),
389 Field::new("reference", DataType::Utf8, true),
390 Field::new("start", DataType::Int64, true),
391 Field::new("end", DataType::Int64, true),
392 Field::new("mapping_quality", DataType::Utf8, true),
393 Field::new("cigar", DataType::Utf8, false),
394 Field::new("mate_reference", DataType::Utf8, true),
395 Field::new("sequence", DataType::Utf8, false),
396 Field::new("quality_score", quality_score_list, false),
397 ],
398 vec![],
399 )
400 .with_tags_data_type(tags_data_type)
401 }
402}
403
404#[cfg(test)]
405mod tests {
406 use noodles::sam::alignment::record::data::field::Tag;
407
408 use super::*;
409
410 #[test]
411 fn test_build() -> Result<()> {
412 let schema = SAMSchemaBuilder::default().build();
413
414 assert_eq!(schema.fields().len(), 11);
415
416 Ok(())
417 }
418
419 #[test]
420 fn test_build_from_empty_data_errors() -> Result<()> {
421 let data = Data::default();
422
423 let schema = SAMSchemaBuilder::default().with_tags_data_type_from_data(&data);
424 assert!(schema.is_err());
425
426 Ok(())
427 }
428
429 #[test]
430 fn test_parsing_data() -> Result<()> {
431 let mut data = Data::default();
432
433 data.insert(Tag::ALIGNMENT_HIT_COUNT, Value::from(1));
434 data.insert(Tag::CELL_BARCODE_ID, Value::from("AA"));
435 data.insert(Tag::ORIGINAL_UMI_QUALITY_SCORES, Value::from(vec![1, 2, 3]));
436
437 let schema = SAMSchemaBuilder::default().with_tags_data_type_from_data(&data)?;
438
439 let expected_fields = vec![
440 Field::new(
441 "BZ",
442 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
443 false,
444 ),
445 Field::new("CB", DataType::Utf8, false),
446 Field::new("NH", DataType::UInt8, false),
447 ];
448
449 let tags_type = schema
450 .tags_data_type
451 .ok_or(ArrowError::InvalidArgumentError(
452 "tags_data_type is None".into(),
453 ))?;
454
455 match tags_type {
456 DataType::Struct(fields) => {
457 let mut fields = fields.iter().collect::<Vec<_>>();
458 fields.sort_by(|a, b| a.name().cmp(b.name()));
459
460 assert_eq!(fields.len(), 3);
461
462 fields
463 .iter()
464 .zip(expected_fields.iter())
465 .for_each(|(actual, expected)| {
466 assert_eq!(actual.name(), expected.name());
467 assert_eq!(actual.data_type(), expected.data_type());
468 });
469 }
470 _ => {
471 return Err(ArrowError::InvalidArgumentError(
472 "tags_data_type is not a struct".into(),
473 ))
474 }
475 };
476
477 Ok(())
478 }
479}