exon_sam/
schema_builder.rs

1// Copyright 2024 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use arrow::datatypes::{DataType, Field, Fields, Schema};
19use arrow::error::{ArrowError, Result};
20use exon_common::TableSchema;
21use noodles::sam::alignment::record_buf::data::field::{value::Array, Value};
22use noodles::sam::alignment::record_buf::Data;
23
/// Expands to an `Err` wrapping `ArrowError::InvalidArgumentError` whose message reports that a
/// tag has conflicting types.
///
/// Takes three expressions in order: the tag (formatted with `Display`) followed by the two
/// data types being compared (both formatted with `Debug`).
macro_rules! arrow_error {
    ($tag_name:expr, $existing_type:expr, $expected_type:expr) => {
        Err(arrow::error::ArrowError::InvalidArgumentError(format!(
            "tag {} has conflicting types: {:?} and {:?}",
            $tag_name, $existing_type, $expected_type
        )))
    };
}
35
/// Builds a table schema for SAM alignment data.
///
/// The schema produced by `build` consists of the file fields, an optional `tags` column, and
/// the partition fields, in that order.
pub struct SAMSchemaBuilder {
    /// Fields that make up the leading columns of the schema.
    file_fields: Vec<Field>,
    /// Fields appended after the file fields (and after the `tags` column, if any).
    partition_fields: Vec<Field>,
    /// Data type of the `tags` column; when `None`, `build` adds no `tags` column.
    tags_data_type: Option<DataType>,
}
42
43impl SAMSchemaBuilder {
44    /// Creates a new SAM schema builder.
45    pub fn new(file_fields: Vec<Field>, partition_fields: Vec<Field>) -> Self {
46        Self {
47            file_fields,
48            partition_fields,
49            tags_data_type: None,
50        }
51    }
52
53    /// Set the partition fields.
54    pub fn with_partition_fields(self, partition_fields: Vec<Field>) -> Self {
55        Self {
56            partition_fields,
57            ..self
58        }
59    }
60
61    /// Sets the data type for the tags field.
62    pub fn with_tags_data_type(self, tags_data_type: DataType) -> Self {
63        Self {
64            tags_data_type: Some(tags_data_type),
65            ..self
66        }
67    }
68
69    /// Sets the data type for the tags field from the data.
70    pub fn with_tags_data_type_from_data(self, data: &Data) -> Result<Self> {
71        let mut fields = HashMap::new();
72
73        for (tag, value) in data.iter() {
74            let tag_name = std::str::from_utf8(tag.as_ref())?;
75
76            match value {
77                Value::Character(_) | Value::String(_) | Value::Hex(_) => {
78                    let field = fields.entry(tag).or_insert_with(|| {
79                        Field::new(tag_name, arrow::datatypes::DataType::Utf8, true)
80                    });
81                    if field.data_type() != &arrow::datatypes::DataType::Utf8 {
82                        return arrow_error!(
83                            tag_name,
84                            field.data_type(),
85                            arrow::datatypes::DataType::Utf8
86                        );
87                    }
88                }
89                Value::Int8(_) => {
90                    let field = fields.entry(tag).or_insert_with(|| {
91                        Field::new(tag_name, arrow::datatypes::DataType::Int8, true)
92                    });
93                    if field.data_type() != &arrow::datatypes::DataType::Int8 {
94                        return arrow_error!(
95                            tag_name,
96                            field.data_type(),
97                            arrow::datatypes::DataType::Int8
98                        );
99                    }
100                }
101                Value::Int16(_) => {
102                    let field = fields.entry(tag).or_insert_with(|| {
103                        Field::new(tag_name, arrow::datatypes::DataType::Int16, true)
104                    });
105                    if field.data_type() != &arrow::datatypes::DataType::Int16 {
106                        return arrow_error!(
107                            tag_name,
108                            field.data_type(),
109                            arrow::datatypes::DataType::Int16
110                        );
111                    }
112                }
113                Value::Int32(_) => {
114                    let field = fields.entry(tag).or_insert_with(|| {
115                        Field::new(tag_name, arrow::datatypes::DataType::Int32, true)
116                    });
117                    if field.data_type() != &arrow::datatypes::DataType::Int32 {
118                        return arrow_error!(
119                            tag_name,
120                            field.data_type(),
121                            arrow::datatypes::DataType::Int32
122                        );
123                    }
124                }
125                Value::UInt8(_) => {
126                    let field = fields.entry(tag).or_insert_with(|| {
127                        Field::new(tag_name, arrow::datatypes::DataType::UInt8, true)
128                    });
129                    if field.data_type() != &arrow::datatypes::DataType::UInt8 {
130                        return arrow_error!(
131                            tag_name,
132                            field.data_type(),
133                            arrow::datatypes::DataType::UInt8
134                        );
135                    }
136                }
137                Value::UInt16(_) => {
138                    let field = fields.entry(tag).or_insert_with(|| {
139                        Field::new(tag_name, arrow::datatypes::DataType::UInt16, true)
140                    });
141                    if field.data_type() != &arrow::datatypes::DataType::UInt16 {
142                        return arrow_error!(
143                            tag_name,
144                            field.data_type(),
145                            arrow::datatypes::DataType::UInt16
146                        );
147                    }
148                }
149                Value::UInt32(_) => {
150                    let field = fields.entry(tag).or_insert_with(|| {
151                        Field::new(tag_name, arrow::datatypes::DataType::UInt32, true)
152                    });
153                    if field.data_type() != &arrow::datatypes::DataType::UInt32 {
154                        return arrow_error!(
155                            tag_name,
156                            field.data_type(),
157                            arrow::datatypes::DataType::UInt16
158                        );
159                    }
160                }
161                Value::Float(_) => {
162                    let field = fields.entry(tag).or_insert_with(|| {
163                        Field::new(tag_name, arrow::datatypes::DataType::Float32, true)
164                    });
165
166                    if field.data_type() != &arrow::datatypes::DataType::Float32 {
167                        return arrow_error!(
168                            tag_name,
169                            field.data_type(),
170                            arrow::datatypes::DataType::Float32
171                        );
172                    }
173                }
174                Value::Array(array) => {
175                    match array {
176                        Array::Int32(_) => {
177                            let field = fields.entry(tag).or_insert_with(|| {
178                                Field::new(
179                                    tag_name,
180                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
181                                        "item",
182                                        arrow::datatypes::DataType::Int32,
183                                        true,
184                                    ))),
185                                    false,
186                                )
187                            });
188
189                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
190                                Field::new("item", arrow::datatypes::DataType::Int32, true),
191                            ));
192
193                            if field.data_type() != &expected_type {
194                                return arrow_error!(tag_name, field.data_type(), expected_type);
195                            }
196                        }
197                        Array::Int16(_) => {
198                            let field = fields.entry(tag).or_insert_with(|| {
199                                Field::new(
200                                    tag_name,
201                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
202                                        "item",
203                                        arrow::datatypes::DataType::Int16,
204                                        true,
205                                    ))),
206                                    true,
207                                )
208                            });
209
210                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
211                                Field::new("item", arrow::datatypes::DataType::Int16, true),
212                            ));
213
214                            if field.data_type() != &expected_type {
215                                return arrow_error!(tag_name, field.data_type(), expected_type);
216                            }
217                        }
218                        Array::Int8(_) => {
219                            let field = fields.entry(tag).or_insert_with(|| {
220                                Field::new(
221                                    tag_name,
222                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
223                                        "item",
224                                        arrow::datatypes::DataType::Int8,
225                                        true,
226                                    ))),
227                                    true,
228                                )
229                            });
230
231                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
232                                Field::new("item", arrow::datatypes::DataType::Int8, true),
233                            ));
234
235                            if field.data_type() != &expected_type {
236                                return arrow_error!(tag_name, field.data_type(), expected_type);
237                            }
238                        }
239                        Array::UInt8(_) => {
240                            let field = fields.entry(tag).or_insert_with(|| {
241                                Field::new(
242                                    tag_name,
243                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
244                                        "item",
245                                        arrow::datatypes::DataType::UInt8,
246                                        true,
247                                    ))),
248                                    true,
249                                )
250                            });
251
252                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
253                                Field::new("item", arrow::datatypes::DataType::UInt8, true),
254                            ));
255
256                            if field.data_type() != &expected_type {
257                                return arrow_error!(tag_name, field.data_type(), expected_type);
258                            }
259                        }
260                        Array::UInt16(_) => {
261                            let field = fields.entry(tag).or_insert_with(|| {
262                                Field::new(
263                                    tag_name,
264                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
265                                        "item",
266                                        arrow::datatypes::DataType::UInt16,
267                                        true,
268                                    ))),
269                                    true,
270                                )
271                            });
272
273                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
274                                Field::new("item", arrow::datatypes::DataType::UInt16, true),
275                            ));
276
277                            if field.data_type() != &expected_type {
278                                return arrow_error!(tag_name, field.data_type(), expected_type);
279                            }
280                        }
281                        Array::UInt32(_) => {
282                            let field = fields.entry(tag).or_insert_with(|| {
283                                Field::new(
284                                    tag_name,
285                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
286                                        "item",
287                                        arrow::datatypes::DataType::UInt32,
288                                        true,
289                                    ))),
290                                    true,
291                                )
292                            });
293
294                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
295                                Field::new("item", arrow::datatypes::DataType::UInt32, true),
296                            ));
297
298                            if field.data_type() != &expected_type {
299                                return arrow_error!(tag_name, field.data_type(), expected_type);
300                            }
301                        }
302                        Array::Float(_) => {
303                            let field = fields.entry(tag).or_insert_with(|| {
304                                Field::new(
305                                    tag_name,
306                                    arrow::datatypes::DataType::List(Arc::new(Field::new(
307                                        "item",
308                                        arrow::datatypes::DataType::Float32,
309                                        true,
310                                    ))),
311                                    true,
312                                )
313                            });
314
315                            let expected_type = arrow::datatypes::DataType::List(Arc::new(
316                                Field::new("item", arrow::datatypes::DataType::Float32, true),
317                            ));
318
319                            if field.data_type() != &expected_type {
320                                return arrow_error!(tag_name, field.data_type(), expected_type);
321                            }
322                        }
323                    }
324                }
325            }
326        }
327
328        // Make sure there are fields, otherwise return an error
329        if fields.is_empty() {
330            return Err(ArrowError::InvalidArgumentError(
331                "No fields found in the data".into(),
332            ));
333        }
334
335        let data_type = DataType::Struct(Fields::from(
336            fields
337                .into_iter()
338                .map(|(tag, field)| {
339                    let data_type = field.data_type().clone();
340
341                    // TODO: remove unwrap
342                    let tag_name = std::str::from_utf8(tag.as_ref()).unwrap();
343
344                    Field::new(tag_name, data_type, true)
345                })
346                .collect::<Vec<_>>(),
347        ));
348
349        Ok(self.with_tags_data_type(data_type))
350    }
351
352    /// Builds a schema for the BAM file.
353    pub fn build(self) -> TableSchema {
354        let mut fields = self.file_fields;
355
356        if let Some(tags_data_type) = self.tags_data_type.clone() {
357            let tags_field = Field::new("tags", tags_data_type, true);
358            fields.push(tags_field);
359        }
360
361        let file_projection = (0..fields.len()).collect::<Vec<_>>();
362
363        fields.extend_from_slice(&self.partition_fields);
364
365        let file_schema = Schema::new(fields);
366
367        TableSchema::new(Arc::new(file_schema), file_projection)
368    }
369}
370
371impl Default for SAMSchemaBuilder {
372    fn default() -> Self {
373        let tags_data_type = DataType::List(Arc::new(Field::new(
374            "item",
375            DataType::Struct(Fields::from(vec![
376                Field::new("tag", DataType::Utf8, false),
377                Field::new("value", DataType::Utf8, true),
378            ])),
379            true,
380        )));
381
382        let quality_score_list =
383            DataType::List(Arc::new(Field::new("item", DataType::Int64, true)));
384
385        Self::new(
386            vec![
387                Field::new("name", DataType::Utf8, false),
388                Field::new("flag", DataType::Int32, false),
389                Field::new("reference", DataType::Utf8, true),
390                Field::new("start", DataType::Int64, true),
391                Field::new("end", DataType::Int64, true),
392                Field::new("mapping_quality", DataType::Utf8, true),
393                Field::new("cigar", DataType::Utf8, false),
394                Field::new("mate_reference", DataType::Utf8, true),
395                Field::new("sequence", DataType::Utf8, false),
396                Field::new("quality_score", quality_score_list, false),
397            ],
398            vec![],
399        )
400        .with_tags_data_type(tags_data_type)
401    }
402}
403
#[cfg(test)]
mod tests {
    use noodles::sam::alignment::record::data::field::Tag;

    use super::*;

    /// The default builder produces ten file columns plus the `tags` column.
    #[test]
    fn test_build() -> Result<()> {
        let table_schema = SAMSchemaBuilder::default().build();
        let field_count = table_schema.fields().len();

        assert_eq!(field_count, 11);

        Ok(())
    }

    /// Inferring the tags data type from data without any fields is an error.
    #[test]
    fn test_build_from_empty_data_errors() -> Result<()> {
        let empty = Data::default();

        let result = SAMSchemaBuilder::default().with_tags_data_type_from_data(&empty);
        assert!(result.is_err());

        Ok(())
    }

    /// Each tag in the data becomes a struct field with the matching name and data type.
    #[test]
    fn test_parsing_data() -> Result<()> {
        let mut data = Data::default();

        data.insert(Tag::ALIGNMENT_HIT_COUNT, Value::from(1));
        data.insert(Tag::CELL_BARCODE_ID, Value::from("AA"));
        data.insert(Tag::ORIGINAL_UMI_QUALITY_SCORES, Value::from(vec![1, 2, 3]));

        let builder = SAMSchemaBuilder::default().with_tags_data_type_from_data(&data)?;

        // Listed in tag-name order, which is how the actual fields are sorted below.
        let int32_list = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
        let expected = vec![
            Field::new("BZ", int32_list, false),
            Field::new("CB", DataType::Utf8, false),
            Field::new("NH", DataType::UInt8, false),
        ];

        let struct_fields = match builder.tags_data_type {
            Some(DataType::Struct(struct_fields)) => struct_fields,
            Some(_) => {
                return Err(ArrowError::InvalidArgumentError(
                    "tags_data_type is not a struct".into(),
                ))
            }
            None => {
                return Err(ArrowError::InvalidArgumentError(
                    "tags_data_type is None".into(),
                ))
            }
        };

        let mut actual = struct_fields.iter().collect::<Vec<_>>();
        actual.sort_by(|a, b| a.name().cmp(b.name()));

        assert_eq!(actual.len(), 3);

        for (actual_field, expected_field) in actual.iter().zip(expected.iter()) {
            assert_eq!(actual_field.name(), expected_field.name());
            assert_eq!(actual_field.data_type(), expected_field.data_type());
        }

        Ok(())
    }
}