use anyhow::Result;
use fgumi_raw_bam::RawRecord;
use std::io;
use crate::unified_pipeline::{BatchWeight, DecodedRecord, MemoryEstimate};
use crate::unified_pipeline::Grouper;
use std::collections::VecDeque;
/// A run of consecutive records that share one grouping key.
///
/// The key is derived from the configured tag of each record (optionally
/// transformed, and optionally followed by a tab and the cell-tag value).
#[derive(Debug, Clone)]
pub struct MiGroup {
/// Grouping key shared by every record in `records`.
pub mi: String,
/// Records belonging to this group, in the order they were received.
pub records: Vec<RawRecord>,
}
impl MiGroup {
#[must_use]
pub fn new(mi: String, records: Vec<RawRecord>) -> Self {
Self { mi, records }
}
}
impl BatchWeight for MiGroup {
    /// A group weighs one unit per record it holds.
    fn batch_weight(&self) -> usize {
        let Self { records, .. } = self;
        records.len()
    }
}
impl MemoryEstimate for MiGroup {
    /// Estimates the heap bytes owned by this group.
    ///
    /// The estimate is the sum of three parts: the capacity of the key string,
    /// the capacity reported by every record, and the space reserved by the
    /// `records` vector for its `RawRecord` slots (capacity, not length).
    fn estimate_heap_size(&self) -> usize {
        let key_bytes = self.mi.capacity();

        let mut payload_bytes: usize = 0;
        for record in &self.records {
            payload_bytes += record.capacity();
        }

        let slot_bytes = self.records.capacity() * std::mem::size_of::<RawRecord>();

        key_bytes + payload_bytes + slot_bytes
    }
}
/// An ordered collection of [`MiGroup`]s handed downstream as one unit.
#[derive(Default)]
pub struct MiGroupBatch {
/// The groups in this batch, in the order they were completed.
pub groups: Vec<MiGroup>,
}
impl MiGroupBatch {
    /// Creates a batch that holds no groups.
    #[must_use]
    pub fn new() -> Self {
        let groups = Vec::new();
        Self { groups }
    }

    /// Creates an empty batch with space reserved for `capacity` groups.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let groups = Vec::with_capacity(capacity);
        Self { groups }
    }

    /// Returns how many groups the batch currently holds.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { groups } = self;
        groups.len()
    }

    /// Returns `true` when the batch holds no groups.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every group from the batch.
    pub fn clear(&mut self) {
        Vec::clear(&mut self.groups);
    }
}
impl BatchWeight for MiGroupBatch {
    /// A batch weighs the total number of records across all of its groups.
    fn batch_weight(&self) -> usize {
        let mut total_records = 0;
        for group in &self.groups {
            total_records += group.records.len();
        }
        total_records
    }
}
impl MemoryEstimate for MiGroupBatch {
    /// Estimates the heap bytes owned by this batch.
    ///
    /// Adds up the estimate of every contained group, then adds the space the
    /// `groups` vector has reserved for its `MiGroup` slots (capacity, not length).
    fn estimate_heap_size(&self) -> usize {
        let mut contents_bytes: usize = 0;
        for group in &self.groups {
            contents_bytes += group.estimate_heap_size();
        }
        let slot_bytes = self.groups.capacity() * std::mem::size_of::<MiGroup>();
        contents_bytes + slot_bytes
    }
}
/// Boxed callback that maps the raw bytes of a tag value to a grouping key.
type MiTransformFn = Box<dyn Fn(&[u8]) -> String + Send + Sync>;
/// Boxed predicate over the raw bytes of a record; returning `true` keeps the record.
type RecordFilterFn = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;
/// Streaming grouper that collects consecutive records sharing a tag-derived key
/// into [`MiGroup`]s and releases them in batches of `batch_size` groups.
pub struct MiGrouper {
/// Two-byte name of the tag whose value forms the grouping key.
tag: [u8; 2],
/// Number of groups per emitted batch; the constructors clamp it to at least 1.
batch_size: usize,
/// Key of the group currently being accumulated, if any.
current_mi: Option<String>,
/// Records accumulated so far for `current_mi`.
current_records: Vec<RawRecord>,
/// Completed groups waiting to be drained into batches.
pending_groups: VecDeque<MiGroup>,
/// Set by the first call to `finish`; later calls then return nothing.
finished: bool,
/// Optional mapping from the raw tag value to the grouping key.
mi_transform: Option<MiTransformFn>,
/// Optional predicate over the raw record bytes; rejected records are skipped.
record_filter: Option<RecordFilterFn>,
/// Optional second tag whose value is appended to the key after a tab.
cell_tag: Option<[u8; 2]>,
}
impl MiGrouper {
    /// Validates a tag name and returns it as a two-byte array.
    ///
    /// Shared by every constructor so the length check and its panic message
    /// live in exactly one place.
    ///
    /// # Panics
    ///
    /// Panics if `tag_name` is not exactly 2 bytes long.
    fn tag_bytes(tag_name: &str) -> [u8; 2] {
        assert!(tag_name.len() == 2, "Tag name must be exactly 2 characters");
        let bytes = tag_name.as_bytes();
        [bytes[0], bytes[1]]
    }

    /// Creates a grouper keyed on the tag `tag_name` that emits batches of
    /// `batch_size` groups.
    ///
    /// A `batch_size` of 0 is treated as 1. No record filter, key transform or
    /// cell tag is configured.
    ///
    /// # Panics
    ///
    /// Panics if `tag_name` is not exactly 2 bytes long.
    #[must_use]
    pub fn new(tag_name: &str, batch_size: usize) -> Self {
        Self {
            tag: Self::tag_bytes(tag_name),
            batch_size: batch_size.max(1),
            current_mi: None,
            current_records: Vec::new(),
            pending_groups: VecDeque::new(),
            finished: false,
            mi_transform: None,
            record_filter: None,
            cell_tag: None,
        }
    }

    /// Creates a grouper like [`MiGrouper::new`] that additionally filters
    /// records and transforms tag values.
    ///
    /// * `record_filter` receives the raw record bytes; records for which it
    ///   returns `false` are skipped.
    /// * `mi_transform` receives the raw tag value and returns the grouping key.
    ///
    /// # Panics
    ///
    /// Panics if `tag_name` is not exactly 2 bytes long.
    #[must_use]
    pub fn with_filter_and_transform<F, T>(
        tag_name: &str,
        batch_size: usize,
        record_filter: F,
        mi_transform: T,
    ) -> Self
    where
        F: Fn(&[u8]) -> bool + Send + Sync + 'static,
        T: Fn(&[u8]) -> String + Send + Sync + 'static,
    {
        // Delegate so that the common state is initialised in a single place.
        let mut grouper = Self::new(tag_name, batch_size);
        grouper.mi_transform = Some(Box::new(mi_transform));
        grouper.record_filter = Some(Box::new(record_filter));
        grouper
    }

    /// Sets (or clears, with `None`) the cell tag whose value is appended to
    /// every grouping key.
    #[must_use]
    pub fn with_cell_tag(mut self, cell_tag: Option<[u8; 2]>) -> Self {
        self.cell_tag = cell_tag;
        self
    }

    /// Builds the grouping key for one record.
    ///
    /// Returns `None` when the record does not carry the configured tag.
    /// Otherwise the key is the tag value (passed through the transform when
    /// one is set, lossily decoded as UTF-8 when not). When a cell tag is
    /// configured, a tab is always appended, followed by the cell-tag value if
    /// the record has one.
    fn get_mi_tag(&self, bam: &[u8]) -> Option<String> {
        use crate::sort::bam_fields;
        let value = bam_fields::find_string_tag_in_record(bam, &self.tag)?;
        let mut key = match &self.mi_transform {
            Some(transform) => transform(value),
            None => String::from_utf8_lossy(value).into_owned(),
        };
        if let Some(cell_tag) = &self.cell_tag {
            // The separator is pushed even when the cell tag is absent, so a
            // record without a cell value never shares a key with one that has it.
            key.push('\t');
            if let Some(cell_value) = bam_fields::find_string_tag_in_record(bam, cell_tag) {
                key.push_str(&String::from_utf8_lossy(cell_value));
            }
        }
        Some(key)
    }

    /// Applies the record filter; every record is kept when no filter is set.
    fn should_keep(&self, bam: &[u8]) -> bool {
        match &self.record_filter {
            Some(filter) => filter(bam),
            None => true,
        }
    }

    /// Moves the group being accumulated onto the pending queue.
    ///
    /// The current key is always cleared; a group is only queued when at least
    /// one record was collected for it.
    fn flush_current_group(&mut self) {
        if let Some(mi) = self.current_mi.take() {
            if !self.current_records.is_empty() {
                let records = std::mem::take(&mut self.current_records);
                self.pending_groups.push_back(MiGroup::new(mi, records));
            }
        }
    }

    /// Removes as many full batches as the pending queue can supply.
    ///
    /// Each returned batch holds exactly `batch_size` groups; any remainder
    /// stays queued.
    fn drain_batches(&mut self) -> Vec<MiGroupBatch> {
        let mut batches = Vec::new();
        while self.pending_groups.len() >= self.batch_size {
            let groups: Vec<MiGroup> = self.pending_groups.drain(..self.batch_size).collect();
            batches.push(MiGroupBatch { groups });
        }
        batches
    }
}
impl Grouper for MiGrouper {
    type Group = MiGroupBatch;

    /// Feeds a chunk of records into the grouper and returns every batch that
    /// became full as a result.
    ///
    /// Records rejected by the filter, and records without the grouping tag,
    /// are dropped. A change of key closes the group being accumulated.
    fn add_records(&mut self, records: Vec<DecodedRecord>) -> io::Result<Vec<Self::Group>> {
        for raw in records.into_iter().map(|decoded| decoded.into_raw_bytes()) {
            // The filter runs first; the key is only computed for kept records.
            let key = if self.should_keep(&raw) { self.get_mi_tag(&raw) } else { None };
            let Some(key) = key else {
                continue;
            };

            let same_key = self.current_mi.as_deref() == Some(key.as_str());
            if !same_key {
                // Flushing is a no-op when nothing is being accumulated yet, so
                // the "first record" and "key changed" cases share this path.
                self.flush_current_group();
                self.current_mi = Some(key);
            }
            self.current_records.push(raw);
        }
        Ok(self.drain_batches())
    }

    /// Closes the group in progress and returns everything still queued as one
    /// final batch, or `None` when nothing is left.
    ///
    /// Only the first call can return a batch; later calls return `Ok(None)`.
    fn finish(&mut self) -> io::Result<Option<Self::Group>> {
        let already_finished = std::mem::replace(&mut self.finished, true);
        if already_finished {
            return Ok(None);
        }

        self.flush_current_group();
        let groups: Vec<MiGroup> = self.pending_groups.drain(..).collect();
        if groups.is_empty() {
            return Ok(None);
        }
        Ok(Some(MiGroupBatch { groups }))
    }

    /// Reports whether a group is being accumulated or completed groups are
    /// still queued.
    fn has_pending(&self) -> bool {
        self.current_mi.is_some() || !self.pending_groups.is_empty()
    }
}
/// Iterator adaptor that turns a stream of records into `(key, records)` groups
/// of consecutive records sharing the same tag-derived key.
#[allow(clippy::type_complexity)]
pub struct MiGroupIterator<I>
where
I: Iterator<Item = Result<RawRecord>>,
{
/// Underlying source of records (or read errors).
record_iter: I,
/// Two-byte name of the tag whose value forms the grouping key.
tag: [u8; 2],
/// Optional second tag whose value is appended to the key after a tab.
cell_tag: Option<[u8; 2]>,
/// Key of the group currently being accumulated, if any.
current_mi: Option<String>,
/// Records accumulated so far for `current_mi`.
current_group: Vec<RawRecord>,
/// Set once the source is exhausted or an error has been returned; the
/// iterator yields nothing afterwards.
done: bool,
/// Error held back so the group accumulated before it can be yielded first.
pending_error: Option<anyhow::Error>,
/// Optional mapping from the raw tag value to the grouping key.
mi_transform: Option<Box<dyn Fn(&[u8]) -> String>>,
}
impl<I> MiGroupIterator<I>
where
    I: Iterator<Item = Result<RawRecord>>,
{
    /// Wraps `record_iter`, grouping its records by the value of tag `tag_name`.
    ///
    /// # Panics
    ///
    /// Panics if `tag_name` is not exactly 2 bytes long.
    pub fn new(record_iter: I, tag_name: &str) -> Self {
        assert!(tag_name.len() == 2, "Tag name must be exactly 2 characters");
        let name = tag_name.as_bytes();
        let tag = [name[0], name[1]];
        Self {
            record_iter,
            tag,
            cell_tag: None,
            current_mi: None,
            current_group: Vec::new(),
            done: false,
            pending_error: None,
            mi_transform: None,
        }
    }

    /// Like [`MiGroupIterator::new`], but every tag value is passed through
    /// `mi_transform` to obtain the grouping key.
    ///
    /// # Panics
    ///
    /// Panics if `tag_name` is not exactly 2 bytes long.
    pub fn with_transform<F>(record_iter: I, tag_name: &str, mi_transform: F) -> Self
    where
        F: Fn(&[u8]) -> String + 'static,
    {
        let mut grouped = Self::new(record_iter, tag_name);
        grouped.mi_transform = Some(Box::new(mi_transform));
        grouped
    }

    /// Sets (or clears, with `None`) the cell tag whose value is appended to
    /// every grouping key.
    #[must_use]
    pub fn with_cell_tag(mut self, cell_tag: Option<[u8; 2]>) -> Self {
        self.cell_tag = cell_tag;
        self
    }

    /// Builds the grouping key for one record, or `None` when the record does
    /// not carry the configured tag.
    ///
    /// With a cell tag configured, the key is followed by a tab and then the
    /// cell-tag value (nothing after the tab if the record lacks that tag).
    fn get_mi(&self, bam: &[u8]) -> Option<String> {
        use crate::sort::bam_fields;

        let raw_value = bam_fields::find_string_tag_in_record(bam, &self.tag)?;
        let mut key = match &self.mi_transform {
            Some(transform) => transform(raw_value),
            None => String::from_utf8_lossy(raw_value).into_owned(),
        };

        if let Some(cell_tag) = self.cell_tag.as_ref() {
            key.push('\t');
            let cell = bam_fields::find_string_tag_in_record(bam, cell_tag);
            if let Some(cell_value) = cell {
                key.push_str(&String::from_utf8_lossy(cell_value));
            }
        }
        Some(key)
    }
}
impl<I> Iterator for MiGroupIterator<I>
where
    I: Iterator<Item = Result<RawRecord>>,
{
    type Item = Result<(String, Vec<RawRecord>)>;

    /// Yields the next `(key, records)` group.
    ///
    /// Records without a key are skipped. When the source reports an error
    /// while a group is being accumulated, that group is yielded first and the
    /// error on the following call; after an error or exhaustion the iterator
    /// only returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        /// Moves the buffered key and records out, leaving both empty.
        fn take_buffered(
            key: &mut Option<String>,
            records: &mut Vec<RawRecord>,
        ) -> (String, Vec<RawRecord>) {
            let key = key.take().unwrap_or_default();
            (key, std::mem::take(records))
        }

        if self.done {
            return None;
        }
        if let Some(deferred) = self.pending_error.take() {
            self.done = true;
            return Some(Err(deferred));
        }

        loop {
            let raw = match self.record_iter.next() {
                Some(Ok(raw)) => raw,
                Some(Err(err)) => {
                    if self.current_group.is_empty() {
                        self.done = true;
                        return Some(Err(err));
                    }
                    // Hand out the records read so far; the error follows on
                    // the next call.
                    self.pending_error = Some(err);
                    let buffered = take_buffered(&mut self.current_mi, &mut self.current_group);
                    return Some(Ok(buffered));
                }
                None => {
                    self.done = true;
                    if self.current_group.is_empty() {
                        return None;
                    }
                    let buffered = take_buffered(&mut self.current_mi, &mut self.current_group);
                    return Some(Ok(buffered));
                }
            };

            let Some(mi) = self.get_mi(&raw) else {
                continue;
            };

            let key_changed =
                !self.current_group.is_empty() && self.current_mi.as_ref() != Some(&mi);
            if key_changed {
                let finished = take_buffered(&mut self.current_mi, &mut self.current_group);
                self.current_mi = Some(mi);
                self.current_group.push(raw);
                return Some(Ok(finished));
            }

            if self.current_group.is_empty() {
                self.current_mi = Some(mi);
            }
            self.current_group.push(raw);
        }
    }
}
// Unit tests for `MiGroupIterator`, `MiGrouper`, and the `MiGroup` / `MiGroupBatch` containers.
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
use super::*;
use crate::sam::SamTag;
use crate::umi::extract_mi_base;
// Builds a raw record named `read` with a 4-base sequence and one `Z`-typed aux tag
// `tag_name` = `tag_value`. The layout is a 32-byte fixed header, the NUL-terminated
// name, the packed sequence, one quality byte per base, then the aux data.
// NOTE(review): the header offsets written below presumably follow the BAM record
// layout (bytes 0..8 and 20..28 set to -1) — confirm against the BAM specification.
#[allow(clippy::cast_possible_truncation)]
fn make_raw_bam_with_tag(tag_name: &str, tag_value: &str) -> RawRecord {
let name = b"read";
let l_read_name: u8 = (name.len() + 1) as u8; let seq_len: u32 = 4; let seq_bytes = seq_len.div_ceil(2) as usize;
let tag_bytes = tag_name.as_bytes();
// Aux entry: two tag bytes, the type byte `Z`, the value, and a NUL terminator.
let aux: Vec<u8> =
[&[tag_bytes[0], tag_bytes[1], b'Z'], tag_value.as_bytes(), &[0u8]].concat();
let total = 32 + l_read_name as usize + seq_bytes + seq_len as usize + aux.len();
let mut buf = vec![0u8; total];
buf[0..4].copy_from_slice(&(-1i32).to_le_bytes());
buf[4..8].copy_from_slice(&(-1i32).to_le_bytes());
buf[8] = l_read_name;
buf[12..14].copy_from_slice(&0u16.to_le_bytes());
buf[16..20].copy_from_slice(&seq_len.to_le_bytes());
buf[20..24].copy_from_slice(&(-1i32).to_le_bytes());
buf[24..28].copy_from_slice(&(-1i32).to_le_bytes());
let name_start = 32;
buf[name_start..name_start + name.len()].copy_from_slice(name);
buf[name_start + name.len()] = 0;
let aux_start = 32 + l_read_name as usize + seq_bytes + seq_len as usize;
buf[aux_start..aux_start + aux.len()].copy_from_slice(&aux);
RawRecord::from(buf)
}
// Same record as `make_raw_bam_with_tag`, but with no aux data at all.
#[allow(clippy::cast_possible_truncation)]
fn make_raw_bam_without_tag() -> RawRecord {
let name = b"read";
let l_read_name: u8 = (name.len() + 1) as u8;
let seq_len: u32 = 4;
let seq_bytes = seq_len.div_ceil(2) as usize;
let total = 32 + l_read_name as usize + seq_bytes + seq_len as usize;
let mut buf = vec![0u8; total];
buf[0..4].copy_from_slice(&(-1i32).to_le_bytes());
buf[4..8].copy_from_slice(&(-1i32).to_le_bytes());
buf[8] = l_read_name;
buf[12..14].copy_from_slice(&0u16.to_le_bytes());
buf[16..20].copy_from_slice(&seq_len.to_le_bytes());
buf[20..24].copy_from_slice(&(-1i32).to_le_bytes());
buf[24..28].copy_from_slice(&(-1i32).to_le_bytes());
let name_start = 32;
buf[name_start..name_start + name.len()].copy_from_slice(name);
buf[name_start + name.len()] = 0;
RawRecord::from(buf)
}
// ---- MiGroupIterator: basic grouping ----
#[test]
fn test_empty_iterator() {
let records: Vec<Result<RawRecord>> = vec![];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
assert!(iter.next().is_none());
}
#[test]
fn test_single_group() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_with_tag("MI", "0")),
];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 3);
assert!(iter.next().is_none());
}
#[test]
fn test_multiple_groups() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_with_tag("MI", "1")),
Ok(make_raw_bam_with_tag("MI", "1")),
Ok(make_raw_bam_with_tag("MI", "1")),
Ok(make_raw_bam_with_tag("MI", "2")),
];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 2);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1");
assert_eq!(result.1.len(), 3);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "2");
assert_eq!(result.1.len(), 1);
assert!(iter.next().is_none());
}
// Untagged records are dropped and do not split the surrounding group.
#[test]
fn test_skips_records_without_mi_tag() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_without_tag()),
Ok(make_raw_bam_with_tag("MI", "0")),
Ok(make_raw_bam_without_tag()),
Ok(make_raw_bam_with_tag("MI", "1")),
];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 2);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1");
assert_eq!(result.1.len(), 1);
assert!(iter.next().is_none());
}
// ---- MiGroupIterator: error handling ----
// The group buffered before the error is yielded first, then the error, then nothing.
#[test]
fn test_error_propagation() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0")),
Err(anyhow::anyhow!("test error")),
Ok(make_raw_bam_with_tag("MI", "1")),
];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 1);
let err = iter.next().expect("iterator should yield item");
assert!(err.is_err());
assert!(iter.next().is_none());
}
#[test]
fn test_error_with_no_pending_group() {
let records: Vec<Result<RawRecord>> =
vec![Err(anyhow::anyhow!("immediate error")), Ok(make_raw_bam_with_tag("MI", "0"))];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI");
let result = iter.next().expect("iterator should yield item");
assert!(result.is_err());
assert!(iter.next().is_none());
}
// ---- MiGroupIterator: tag selection and validation ----
#[test]
fn test_custom_tag() {
let records: Vec<Result<RawRecord>> =
vec![Ok(make_raw_bam_with_tag("RX", "ACGT")), Ok(make_raw_bam_with_tag("RX", "ACGT"))];
let mut iter = MiGroupIterator::new(records.into_iter(), "RX");
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "ACGT");
assert_eq!(result.1.len(), 2);
assert!(iter.next().is_none());
}
#[test]
#[should_panic(expected = "Tag name must be exactly 2 characters")]
fn test_invalid_tag_length() {
let records: Vec<Result<RawRecord>> = vec![];
let _ = MiGroupIterator::new(records.into_iter(), "M");
}
// ---- MiGroupIterator: key transform ----
// With the base-extracting transform, `1/A` and `1/B` fall into the same group.
#[test]
fn test_with_transform() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "1/A")),
Ok(make_raw_bam_with_tag("MI", "1/A")),
Ok(make_raw_bam_with_tag("MI", "1/B")),
Ok(make_raw_bam_with_tag("MI", "1/B")),
Ok(make_raw_bam_with_tag("MI", "2/A")),
Ok(make_raw_bam_with_tag("MI", "2/B")),
];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
let s = String::from_utf8_lossy(raw);
extract_mi_base(&s).to_string()
});
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1");
assert_eq!(result.1.len(), 4);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "2");
assert_eq!(result.1.len(), 2);
assert!(iter.next().is_none());
}
#[test]
fn test_with_transform_empty() {
let records: Vec<Result<RawRecord>> = vec![];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
String::from_utf8_lossy(raw).to_uppercase()
});
assert!(iter.next().is_none());
}
#[test]
#[should_panic(expected = "Tag name must be exactly 2 characters")]
fn test_with_transform_invalid_tag_length() {
let records: Vec<Result<RawRecord>> = vec![];
let _ = MiGroupIterator::with_transform(records.into_iter(), "ABC", |raw| {
String::from_utf8_lossy(raw).into_owned()
});
}
// ---- MiGroupIterator::get_mi ----
#[test]
fn test_get_mi_without_transform() {
let iter = MiGroupIterator::new(std::iter::empty::<Result<RawRecord>>(), "MI");
let bam = make_raw_bam_with_tag("MI", "42");
assert_eq!(iter.get_mi(&bam), Some("42".to_string()));
}
#[test]
fn test_get_mi_with_transform() {
let iter =
MiGroupIterator::with_transform(std::iter::empty::<Result<RawRecord>>(), "MI", |raw| {
let s = String::from_utf8_lossy(raw);
s.to_uppercase()
});
let bam = make_raw_bam_with_tag("MI", "abc");
assert_eq!(iter.get_mi(&bam), Some("ABC".to_string()));
}
#[test]
fn test_get_mi_missing_tag() {
let iter = MiGroupIterator::new(std::iter::empty::<Result<RawRecord>>(), "MI");
let bam = make_raw_bam_without_tag();
assert_eq!(iter.get_mi(&bam), None);
}
#[test]
fn test_get_mi_wrong_tag() {
let iter = MiGroupIterator::new(std::iter::empty::<Result<RawRecord>>(), "MI");
let bam = make_raw_bam_with_tag("RX", "ACGT");
assert_eq!(iter.get_mi(&bam), None);
}
// Same layout as `make_raw_bam_with_tag`, but with two `Z`-typed aux tags written
// back to back (`tag1` first, then `tag2`).
#[allow(clippy::cast_possible_truncation)]
fn make_raw_bam_with_two_tags(tag1: &str, val1: &str, tag2: &str, val2: &str) -> RawRecord {
let name = b"read";
let l_read_name: u8 = (name.len() + 1) as u8;
let seq_len: u32 = 4;
let seq_bytes = seq_len.div_ceil(2) as usize;
let t1 = tag1.as_bytes();
let t2 = tag2.as_bytes();
let aux: Vec<u8> = [
&[t1[0], t1[1], b'Z'],
val1.as_bytes(),
&[0u8],
&[t2[0], t2[1], b'Z'],
val2.as_bytes(),
&[0u8],
]
.concat();
let total = 32 + l_read_name as usize + seq_bytes + seq_len as usize + aux.len();
let mut buf = vec![0u8; total];
buf[0..4].copy_from_slice(&(-1i32).to_le_bytes());
buf[4..8].copy_from_slice(&(-1i32).to_le_bytes());
buf[8] = l_read_name;
buf[12..14].copy_from_slice(&0u16.to_le_bytes());
buf[16..20].copy_from_slice(&seq_len.to_le_bytes());
buf[20..24].copy_from_slice(&(-1i32).to_le_bytes());
buf[24..28].copy_from_slice(&(-1i32).to_le_bytes());
let name_start = 32;
buf[name_start..name_start + name.len()].copy_from_slice(name);
buf[name_start + name.len()] = 0;
let aux_start = 32 + l_read_name as usize + seq_bytes + seq_len as usize;
buf[aux_start..aux_start + aux.len()].copy_from_slice(&aux);
RawRecord::from(buf)
}
// ---- MiGroupIterator: cell-tag composite keys ----
// The same MI with different cell values yields separate groups keyed `MI<TAB>cell`.
#[test]
fn test_cell_tag_composite_grouping() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "ACGT")),
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "ACGT")),
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "TGCA")),
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "TGCA")),
];
let mut iter =
MiGroupIterator::new(records.into_iter(), "MI").with_cell_tag(Some(*SamTag::CB));
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\tACGT");
assert_eq!(result.1.len(), 2);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\tTGCA");
assert_eq!(result.1.len(), 2);
assert!(iter.next().is_none());
}
#[test]
fn test_cell_tag_none_groups_by_mi_only() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "ACGT")),
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "TGCA")),
];
let mut iter = MiGroupIterator::new(records.into_iter(), "MI").with_cell_tag(None);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1");
assert_eq!(result.1.len(), 2);
assert!(iter.next().is_none());
}
// A record lacking the cell tag gets a key ending in the bare tab separator.
#[test]
fn test_cell_tag_missing_cell_value() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "1")),
Ok(make_raw_bam_with_tag("MI", "1")),
Ok(make_raw_bam_with_two_tags("MI", "1", "CB", "ACGT")),
];
let mut iter =
MiGroupIterator::new(records.into_iter(), "MI").with_cell_tag(Some(*SamTag::CB));
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\t");
assert_eq!(result.1.len(), 2);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\tACGT");
assert_eq!(result.1.len(), 1);
assert!(iter.next().is_none());
}
#[test]
fn test_cell_tag_with_transform() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_two_tags("MI", "1/A", "CB", "ACGT")),
Ok(make_raw_bam_with_two_tags("MI", "1/B", "CB", "ACGT")),
Ok(make_raw_bam_with_two_tags("MI", "1/A", "CB", "TGCA")),
];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
let s = String::from_utf8_lossy(raw);
extract_mi_base(&s).to_string()
})
.with_cell_tag(Some(*SamTag::CB));
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\tACGT");
assert_eq!(result.1.len(), 2);
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "1\tTGCA");
assert_eq!(result.1.len(), 1);
assert!(iter.next().is_none());
}
// ---- MiGrouper ----
// Wraps a single-tag record in a `DecodedRecord` with an all-zero group key.
fn make_raw_decoded_record(tag_name: &str, tag_value: &str) -> DecodedRecord {
let raw = make_raw_bam_with_tag(tag_name, tag_value);
let key = crate::unified_pipeline::GroupKey::single(0, 0, 0, 0, 0, 0);
DecodedRecord::from_raw_bytes(raw, key)
}
// Wraps an untagged record in a `DecodedRecord` with an all-zero group key.
fn make_raw_decoded_record_no_tag() -> DecodedRecord {
let raw = make_raw_bam_without_tag();
let key = crate::unified_pipeline::GroupKey::single(0, 0, 0, 0, 0, 0);
DecodedRecord::from_raw_bytes(raw, key)
}
#[test]
fn test_grouper_single_mi_group() {
let mut grouper = MiGrouper::new("MI", 10);
let records = vec![
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record("MI", "0"),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert!(batches.is_empty());
assert!(grouper.has_pending());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 1);
assert_eq!(final_batch.groups[0].mi, "0");
assert_eq!(final_batch.groups[0].records.len(), 3);
}
#[test]
fn test_grouper_multiple_mi_groups() {
let mut grouper = MiGrouper::new("MI", 10);
let records = vec![
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record("MI", "1"),
make_raw_decoded_record("MI", "1"),
make_raw_decoded_record("MI", "1"),
make_raw_decoded_record("MI", "2"),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert!(batches.is_empty());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 3);
assert_eq!(final_batch.groups[0].mi, "0");
assert_eq!(final_batch.groups[0].records.len(), 2);
assert_eq!(final_batch.groups[1].mi, "1");
assert_eq!(final_batch.groups[1].records.len(), 3);
assert_eq!(final_batch.groups[2].mi, "2");
assert_eq!(final_batch.groups[2].records.len(), 1);
}
// Five single-record groups with batch size 2: two full batches are emitted by
// `add_records`, and the last group (still open) only comes out of `finish`.
#[test]
fn test_grouper_batch_size_triggers() {
let mut grouper = MiGrouper::new("MI", 2);
let records = vec![
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record("MI", "1"),
make_raw_decoded_record("MI", "2"),
make_raw_decoded_record("MI", "3"),
make_raw_decoded_record("MI", "4"),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].groups.len(), 2);
assert_eq!(batches[0].groups[0].mi, "0");
assert_eq!(batches[0].groups[1].mi, "1");
assert_eq!(batches[1].groups.len(), 2);
assert_eq!(batches[1].groups[0].mi, "2");
assert_eq!(batches[1].groups[1].mi, "3");
assert!(grouper.has_pending());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 1);
assert_eq!(final_batch.groups[0].mi, "4");
}
#[test]
fn test_grouper_skips_records_without_mi_tag() {
let mut grouper = MiGrouper::new("MI", 10);
let records = vec![
make_raw_decoded_record("MI", "0"),
make_raw_decoded_record_no_tag(),
make_raw_decoded_record("MI", "0"),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert!(batches.is_empty());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 1);
assert_eq!(final_batch.groups[0].mi, "0");
assert_eq!(final_batch.groups[0].records.len(), 2);
}
#[test]
fn test_grouper_finish_empty() {
let mut grouper = MiGrouper::new("MI", 10);
assert!(!grouper.has_pending());
let final_batch = grouper.finish().expect("finish should succeed");
assert!(final_batch.is_none());
}
// A second `finish` call must not return the final batch again.
#[test]
fn test_grouper_finish_idempotent() {
let mut grouper = MiGrouper::new("MI", 10);
let records = vec![make_raw_decoded_record("MI", "0")];
grouper.add_records(records).expect("add_records should succeed");
let batch1 = grouper.finish().expect("finish should succeed");
assert!(batch1.is_some());
let batch2 = grouper.finish().expect("finish should succeed");
assert!(batch2.is_none());
}
#[test]
fn test_grouper_has_pending_states() {
let mut grouper = MiGrouper::new("MI", 10);
assert!(!grouper.has_pending());
let records = vec![make_raw_decoded_record("MI", "0")];
grouper.add_records(records).expect("add_records should succeed");
assert!(grouper.has_pending());
let records = vec![make_raw_decoded_record("MI", "1")];
grouper.add_records(records).expect("add_records should succeed");
assert!(grouper.has_pending());
}
// The filter rejects records whose flags include SECONDARY; the transform reduces
// the MI value to its base, so `1/A` and `1/B` share the key `1`.
#[test]
fn test_grouper_with_filter_and_transform() {
let mut grouper = MiGrouper::with_filter_and_transform(
"MI",
10,
|bam: &[u8]| {
let flag = fgumi_raw_bam::RawRecordView::new(bam).flags();
flag & fgumi_raw_bam::flags::SECONDARY == 0 },
|raw: &[u8]| {
let s = String::from_utf8_lossy(raw);
extract_mi_base(&s).to_string()
},
);
let rec_primary = make_raw_bam_with_tag("MI", "1/A");
let mut rec_secondary = make_raw_bam_with_tag("MI", "1/A");
// Overwrite bytes 14..16 with 0x100 so that the filter above drops this record
// (the assertions below expect only two of the three records to survive).
rec_secondary[14..16].copy_from_slice(&0x100u16.to_le_bytes());
let rec_b = make_raw_bam_with_tag("MI", "1/B");
let key = crate::unified_pipeline::GroupKey::single(0, 0, 0, 0, 0, 0);
let records = vec![
DecodedRecord::from_raw_bytes(rec_primary, key),
DecodedRecord::from_raw_bytes(rec_secondary, key),
DecodedRecord::from_raw_bytes(rec_b, key),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert!(batches.is_empty());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 1);
assert_eq!(final_batch.groups[0].mi, "1");
assert_eq!(final_batch.groups[0].records.len(), 2); }
// The composite key only affects grouping; the MI tag stored in each record is untouched.
#[test]
fn test_grouper_cell_tag_composite_grouping() {
fn make_two_tag_decoded(mi: &str, cb: &str) -> DecodedRecord {
let raw = make_raw_bam_with_two_tags("MI", mi, "CB", cb);
let key = crate::unified_pipeline::GroupKey::single(0, 0, 0, 0, 0, 0);
DecodedRecord::from_raw_bytes(raw, key)
}
let mut grouper = MiGrouper::new("MI", 10).with_cell_tag(Some(*SamTag::CB));
let records = vec![
make_two_tag_decoded("1", "ACGT"),
make_two_tag_decoded("1", "ACGT"),
make_two_tag_decoded("1", "TGCA"),
make_two_tag_decoded("1", "TGCA"),
];
let batches = grouper.add_records(records).expect("add_records should succeed");
assert!(batches.is_empty());
let final_batch =
grouper.finish().expect("finish should succeed").expect("should return final batch");
assert_eq!(final_batch.groups.len(), 2);
assert_eq!(final_batch.groups[0].mi, "1\tACGT");
assert_eq!(final_batch.groups[0].records.len(), 2);
assert_eq!(final_batch.groups[1].mi, "1\tTGCA");
assert_eq!(final_batch.groups[1].records.len(), 2);
for group in &final_batch.groups {
for raw in &group.records {
use crate::sort::bam_fields;
let mi_val = bam_fields::find_string_tag_in_record(raw, b"MI");
assert_eq!(mi_val, Some(b"1".as_ref()), "MI tag must retain original value");
}
}
}
// ---- MiGroup / MiGroupBatch containers ----
#[test]
fn test_mi_group_new() {
let raw = make_raw_bam_with_tag("MI", "42");
let group = MiGroup::new("42".to_string(), vec![raw]);
assert_eq!(group.mi, "42");
assert_eq!(group.records.len(), 1);
}
#[test]
fn test_mi_group_batch_weight() {
let group = MiGroup::new(
"0".to_string(),
vec![make_raw_bam_with_tag("MI", "0"), make_raw_bam_with_tag("MI", "0")],
);
assert_eq!(group.batch_weight(), 2);
}
#[test]
fn test_mi_group_memory_estimate() {
let group = MiGroup::new("0".to_string(), vec![make_raw_bam_with_tag("MI", "0")]);
let size = group.estimate_heap_size();
assert!(size > 0);
}
#[test]
fn test_mi_group_batch_new() {
let batch = MiGroupBatch::new();
assert!(batch.groups.is_empty());
}
#[test]
fn test_mi_group_batch_default() {
let batch = MiGroupBatch::default();
assert!(batch.groups.is_empty());
}
#[test]
fn test_mi_group_batch_weight_method() {
let mut batch = MiGroupBatch::new();
assert_eq!(batch.batch_weight(), 0);
batch.groups.push(MiGroup::new(
"0".to_string(),
vec![make_raw_bam_with_tag("MI", "0"), make_raw_bam_with_tag("MI", "0")],
));
assert_eq!(batch.batch_weight(), 2);
batch.groups.push(MiGroup::new("1".to_string(), vec![make_raw_bam_with_tag("MI", "1")]));
assert_eq!(batch.batch_weight(), 3);
}
#[test]
fn test_mi_group_batch_memory_estimate() {
// The estimate for an empty batch is only required not to panic.
let batch = MiGroupBatch::new();
let _ = batch.estimate_heap_size();
let mut batch = MiGroupBatch::new();
batch.groups.push(MiGroup::new("0".to_string(), vec![make_raw_bam_with_tag("MI", "0")]));
let size = batch.estimate_heap_size();
assert!(size > 0);
}
#[test]
fn test_mi_group_batch_with_capacity() {
let batch = MiGroupBatch::with_capacity(16);
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
}
#[test]
fn test_mi_group_batch_len_and_is_empty() {
let mut batch = MiGroupBatch::new();
assert!(batch.is_empty());
batch.groups.push(MiGroup::new("0".to_string(), vec![make_raw_bam_with_tag("MI", "0")]));
assert!(!batch.is_empty());
assert_eq!(batch.len(), 1);
batch.groups.push(MiGroup::new(
"1".to_string(),
vec![make_raw_bam_with_tag("MI", "1"), make_raw_bam_with_tag("MI", "1")],
));
assert_eq!(batch.len(), 2);
}
#[test]
fn test_mi_group_batch_clear() {
let mut batch = MiGroupBatch::new();
batch.groups.push(MiGroup::new("0".to_string(), vec![make_raw_bam_with_tag("MI", "0")]));
assert!(!batch.is_empty());
batch.clear();
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
}
#[test]
fn test_mi_group_batch_weight_single() {
let group = MiGroup::new(
"0".to_string(),
vec![
make_raw_bam_with_tag("MI", "0"),
make_raw_bam_with_tag("MI", "0"),
make_raw_bam_with_tag("MI", "0"),
],
);
assert_eq!(group.batch_weight(), 3);
}
// ---- MiGroupIterator: additional transform cases ----
#[test]
fn test_with_transform_single_group() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0/A")),
Ok(make_raw_bam_with_tag("MI", "0/B")),
Ok(make_raw_bam_with_tag("MI", "0/A")),
];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
let s = String::from_utf8_lossy(raw);
extract_mi_base(&s).to_string()
});
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 3);
assert!(iter.next().is_none());
}
#[test]
fn test_with_transform_error_propagation() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "0/A")),
Err(anyhow::anyhow!("test error")),
Ok(make_raw_bam_with_tag("MI", "1/B")),
];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
let s = String::from_utf8_lossy(raw);
extract_mi_base(&s).to_string()
});
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "0");
assert_eq!(result.1.len(), 1);
let result = iter.next().expect("iterator should yield item");
assert!(result.is_err());
assert!(iter.next().is_none());
}
// An upper-casing transform merges values that differ only in letter case.
#[test]
fn test_with_transform_custom_function() {
let records: Vec<Result<RawRecord>> = vec![
Ok(make_raw_bam_with_tag("MI", "abc")),
Ok(make_raw_bam_with_tag("MI", "ABC")),
Ok(make_raw_bam_with_tag("MI", "Abc")),
];
let mut iter = MiGroupIterator::with_transform(records.into_iter(), "MI", |raw| {
String::from_utf8_lossy(raw).to_uppercase()
});
let result = iter.next().expect("iterator should yield item").expect("item should be Ok");
assert_eq!(result.0, "ABC");
assert_eq!(result.1.len(), 3);
assert!(iter.next().is_none());
}
}