1use std::io::Write;
34use std::num::NonZero;
35use std::path::Path;
36
37use noodles::bam;
38use noodles::sam;
39use rsomics_bamio::raw;
40use rsomics_common::{Result, RsomicsError};
41
// Bit masks for the 16-bit FLAG word of an alignment record. Each mask is
// spelled as a shift so the bit position is visible at a glance; the values
// follow the FLAG table of the SAM specification.

/// Bit 0 (0x1): the read is part of a pair.
const FLAG_PAIRED: u16 = 1;
/// Bit 1 (0x2): the pair was aligned as a proper pair.
const FLAG_PROPER_PAIR: u16 = 1 << 1;
/// Bit 2 (0x4): the read itself is unmapped.
const FLAG_UNMAP: u16 = 1 << 2;
/// Bit 3 (0x8): the mate is unmapped.
const FLAG_MUNMAP: u16 = 1 << 3;
/// Bit 4 (0x10): the read is stored reverse-complemented.
const FLAG_REVERSE: u16 = 1 << 4;
/// Bit 5 (0x20): the mate is on the reverse strand.
const FLAG_MREVERSE: u16 = 1 << 5;
/// Bit 8 (0x100): secondary alignment.
const FLAG_SECONDARY: u16 = 1 << 8;
/// Bit 10 (0x400): duplicate.
const FLAG_DUP: u16 = 1 << 10;
/// Bit 11 (0x800): supplementary alignment.
const FLAG_SUPPLEMENTARY: u16 = 1 << 11;
51
// Byte offsets of the fixed-width fields at the start of a record payload
// (the bytes that follow the 4-byte block-size prefix). Every offset is
// derived from the previous field's offset plus the widths in between, so the
// layout can be read top to bottom.

/// Reference id (i32). A second i32 sits at `REF_ID + 4`; in the BAM record
/// layout that is the alignment position.
const REF_ID: usize = 0;
/// Read-name length in bytes (u8); follows the two i32 fields above.
const L_READ_NAME: usize = REF_ID + 4 + 4;
/// Mapping quality (u8).
const MAPQ: usize = L_READ_NAME + 1;
/// Number of CIGAR operations (u16). The two bytes between `MAPQ` and this
/// field are the index bin in the BAM record layout.
const N_CIGAR: usize = MAPQ + 1 + 2;
/// FLAG word (u16).
const FLAG: usize = N_CIGAR + 2;
/// Number of bases in the read (u32).
const L_SEQ: usize = FLAG + 2;
/// Mate reference id (i32). A second i32 sits at `NEXT_REF_ID + 4`; in the
/// BAM record layout that is the mate position.
const NEXT_REF_ID: usize = L_SEQ + 4;
/// Template length (i32).
const TLEN: usize = NEXT_REF_ID + 4 + 4;
/// Size of the fixed-width head; the read name starts at this offset.
const FIXED_HEAD: usize = TLEN + 4;
62
/// Indexed by a 4-bit base code: the printable symbol of that base's
/// complement (e.g. index 1, `A`, yields `T`).
const SEQ_NT16_COMP_STR: &[u8; 16] = b"=TGKCYSBAWRDMHVN";

/// Indexed by a 4-bit base code: the printable symbol of that base.
const SEQ_NT16_STR: &[u8; 16] = b"=ACMGRSVTWYHKDBN";

/// Maps a printable base symbol back to its 4-bit code.
///
/// Symbols that do not appear in [`SEQ_NT16_STR`] map to 15, the code of `N`.
fn nt16_code(printable: u8) -> u8 {
    for (code, &symbol) in (0u8..).zip(SEQ_NT16_STR.iter()) {
        if symbol == printable {
            return code;
        }
    }
    15
}
80
/// Aux tags that are always stripped in remove mode: `build_aux_filter`
/// appends each of them to the caller's remove list unless already present.
/// They are ignored when an explicit keep list is supplied.
const DEFAULT_REMOVE_TAGS: [[u8; 2]; 15] = [
    *b"AS", *b"CC", *b"CG", *b"CP", *b"H1",
    *b"H2", *b"HI", *b"H0", *b"IH", *b"MC",
    *b"MD", *b"MQ", *b"NM", *b"SA", *b"TS",
];
86
/// Decides which aux tags are copied to the output record.
#[derive(Debug, Clone)]
enum AuxFilter {
    /// Deny-list: every tag survives except the ones listed.
    Remove(Vec<[u8; 2]>),
    /// Allow-list: only the listed tags survive.
    Keep(Vec<[u8; 2]>),
}

impl AuxFilter {
    /// Returns `true` when `tag` should be copied to the output.
    fn survives(&self, tag: [u8; 2]) -> bool {
        // A tag survives exactly when "it is listed" agrees with "listing
        // means keep".
        let (listed_tags, listed_means_keep) = match self {
            AuxFilter::Remove(tags) => (tags, false),
            AuxFilter::Keep(tags) => (tags, true),
        };
        let is_listed = listed_tags.iter().any(|candidate| *candidate == tag);
        is_listed == listed_means_keep
    }
}
107
/// Caller-supplied options controlling how `reset` rewrites the header and
/// the aux tags of every record. The `Default` value strips only the built-in
/// default tag set and appends this tool's own `@PG` line.
#[derive(Debug, Clone, Default)]
pub struct ResetOpts {
    /// Extra aux tags to strip, on top of the built-in default list. Must be
    /// empty when `keep_tags` is non-empty.
    pub remove_tags: Vec<[u8; 2]>,
    /// When non-empty, only these aux tags are kept and everything else is
    /// dropped. Must be empty when `remove_tags` is non-empty.
    pub keep_tags: Vec<[u8; 2]>,
    /// Drop the `@RG` header lines and the `RG` aux tag (the latter even if it
    /// is listed in `keep_tags`).
    pub no_rg: bool,
    /// Do not append this tool's own `@PG` line to the output header.
    pub no_pg: bool,
    /// `@PG` ID at which the copied `@PG` lines are cut off: the first line
    /// carrying this ID and every `@PG` line after it are dropped.
    pub reject_pg: Option<String>,
    /// Leave the duplicate FLAG bit as it is instead of clearing it.
    pub keep_dupflag: bool,
}
124
125fn build_aux_filter(opts: &ResetOpts) -> Result<AuxFilter> {
130 if !opts.keep_tags.is_empty() && !opts.remove_tags.is_empty() {
131 return Err(RsomicsError::InvalidInput(
132 "--keep-tag and -x/--remove-tag are mutually exclusive".to_owned(),
133 ));
134 }
135 if !opts.keep_tags.is_empty() {
136 let mut keep = opts.keep_tags.clone();
137 if opts.no_rg {
138 keep.retain(|t| t != b"RG");
139 }
140 Ok(AuxFilter::Keep(keep))
141 } else {
142 let mut remove = opts.remove_tags.clone();
143 if opts.no_rg && !remove.contains(b"RG") {
144 remove.push(*b"RG");
145 }
146 for t in DEFAULT_REMOVE_TAGS {
147 if !remove.contains(&t) {
148 remove.push(t);
149 }
150 }
151 Ok(AuxFilter::Remove(remove))
152 }
153}
154
155fn rebuild_header(input_text: &str, opts: &ResetOpts, args_cl: &str) -> String {
160 let mut out: Vec<String> = vec!["@HD\tVN:1.6".to_owned()];
161
162 if !opts.no_rg {
163 for line in input_text.lines().filter(|l| l.starts_with("@RG\t")) {
164 out.push(line.to_owned());
165 }
166 }
167
168 let pg_lines: Vec<&str> = input_text
169 .lines()
170 .filter(|l| l.starts_with("@PG\t"))
171 .collect();
172 let kept_pg: Vec<&str> = match &opts.reject_pg {
173 Some(id) => {
174 let needle = format!("ID:{id}");
175 pg_lines
176 .into_iter()
177 .take_while(|l| !l.split('\t').any(|f| f == needle))
178 .collect()
179 }
180 None => pg_lines,
181 };
182 let last_pg_id = kept_pg.last().and_then(|l| {
183 l.split('\t')
184 .find_map(|f| f.strip_prefix("ID:").map(str::to_owned))
185 });
186 for line in &kept_pg {
187 out.push((*line).to_owned());
188 }
189
190 if !opts.no_pg {
191 let id = unique_pg_id("rsomics-bam-reset", &kept_pg);
192 let mut pg = format!("@PG\tID:{id}\tPN:rsomics-bam-reset");
193 if let Some(pp) = &last_pg_id {
194 pg.push_str(&format!("\tPP:{pp}"));
195 }
196 pg.push_str(&format!("\tVN:{}\tCL:{args_cl}", env!("CARGO_PKG_VERSION")));
197 out.push(pg);
198 }
199
200 let mut text = out.join("\n");
201 text.push('\n');
202 text
203}
204
/// Picks a `@PG` ID that does not clash with the IDs of `existing` lines.
///
/// Each element of `existing` is a tab-separated header line; its ID is the
/// first field starting with `ID:`. Returns `base` itself when it is free,
/// otherwise the first free `base.N` counting up from `N = 1`.
fn unique_pg_id(base: &str, existing: &[&str]) -> String {
    let mut taken: Vec<&str> = Vec::with_capacity(existing.len());
    for line in existing {
        if let Some(id) = line.split('\t').find_map(|field| field.strip_prefix("ID:")) {
            taken.push(id);
        }
    }
    let is_free = |candidate: &str| !taken.iter().any(|&id| id == candidate);

    if is_free(base) {
        return base.to_owned();
    }

    let mut suffix = 0;
    loop {
        suffix += 1;
        let candidate = format!("{base}.{suffix}");
        if is_free(&candidate) {
            break candidate;
        }
    }
}
224
225fn header_to_text(header: &sam::Header) -> Result<String> {
226 let mut buf: Vec<u8> = Vec::new();
227 let mut writer = sam::io::Writer::new(&mut buf);
228 writer.write_header(header).map_err(RsomicsError::Io)?;
229 String::from_utf8(buf)
230 .map_err(|e| RsomicsError::InvalidInput(format!("header contains non-UTF-8: {e}")))
231}
232
233fn reparse_header(text: &str) -> Result<sam::Header> {
234 let mut reader = sam::io::Reader::new(text.as_bytes());
235 reader.read_header().map_err(RsomicsError::Io)
236}
237
238fn revert_record(
247 input: &raw::RecordRef<'_>,
248 filter: &AuxFilter,
249 keep_dupflag: bool,
250 out: &mut Vec<u8>,
251) -> bool {
252 let flag = input.flags();
253 if flag & FLAG_SECONDARY != 0 || flag & FLAG_SUPPLEMENTARY != 0 {
254 return false;
255 }
256
257 let mut new_flag = flag & !FLAG_PROPER_PAIR;
258 new_flag |= FLAG_UNMAP;
259 if flag & FLAG_PAIRED != 0 {
260 new_flag |= FLAG_MUNMAP;
261 }
262 new_flag &= !FLAG_MREVERSE;
263 if !keep_dupflag {
264 new_flag &= !FLAG_DUP;
265 }
266 let reverse = flag & FLAG_REVERSE != 0;
267 if reverse {
268 new_flag &= !FLAG_REVERSE;
269 }
270
271 let payload = input.payload();
272 let name = &payload[FIXED_HEAD..FIXED_HEAD + name_len(payload)];
273 let l_seq = base_count(payload);
274 let seq_in_start = FIXED_HEAD + name_len(payload) + cigar_op_count(payload) * 4;
275 let seq_in = &payload[seq_in_start..seq_in_start + l_seq.div_ceil(2)];
276 let qual_in_start = seq_in_start + l_seq.div_ceil(2);
277 let qual_in = &payload[qual_in_start..qual_in_start + l_seq];
278
279 out.clear();
280 out.resize(FIXED_HEAD, 0);
281 write_i32(out, REF_ID, -1);
282 write_i32(out, REF_ID + 4, -1); out[L_READ_NAME] = payload[L_READ_NAME];
284 out[MAPQ] = 0;
285 write_u16(out, N_CIGAR, 0);
286 write_u16(out, FLAG, new_flag);
287 write_u32(out, L_SEQ, l_seq as u32);
288 write_i32(out, NEXT_REF_ID, -1);
289 write_i32(out, NEXT_REF_ID + 4, -1); write_i32(out, TLEN, 0);
291
292 out.extend_from_slice(name);
293
294 if reverse {
295 let mut packed = vec![0u8; l_seq.div_ceil(2)];
299 for (out_i, in_i) in (0..l_seq).rev().enumerate() {
300 let code = seq_nibble(seq_in, in_i);
301 let comp = nt16_code(SEQ_NT16_COMP_STR[code as usize]);
302 set_seq_nibble(&mut packed, out_i, comp);
303 }
304 out.extend_from_slice(&packed);
305 out.extend(qual_in.iter().rev());
306 } else {
307 out.extend_from_slice(seq_in);
308 out.extend_from_slice(qual_in);
309 }
310
311 let aux_in = &payload[qual_in_start + l_seq..];
312 copy_filtered_aux(aux_in, filter, out);
313
314 true
315}
316
317fn write_payload<W: Write>(writer: &mut W, payload: &[u8]) -> Result<()> {
321 let block_size = u32::try_from(payload.len())
322 .map_err(|e| RsomicsError::InvalidInput(format!("record too large: {e}")))?;
323 writer
324 .write_all(&block_size.to_le_bytes())
325 .map_err(RsomicsError::Io)?;
326 writer.write_all(payload).map_err(RsomicsError::Io)?;
327 Ok(())
328}
329
330fn copy_filtered_aux(aux_in: &[u8], filter: &AuxFilter, out: &mut Vec<u8>) {
332 let mut pos = 0;
333 while pos + 3 <= aux_in.len() {
334 let tag = [aux_in[pos], aux_in[pos + 1]];
335 let type_code = aux_in[pos + 2];
336 let value_len = aux_value_len(aux_in, pos + 3, type_code);
337 let field_end = pos + 3 + value_len;
338 if filter.survives(tag) {
339 out.extend_from_slice(&aux_in[pos..field_end]);
340 }
341 pos = field_end;
342 }
343}
344
345fn name_len(p: &[u8]) -> usize {
346 usize::from(p[L_READ_NAME])
347}
348
349fn cigar_op_count(p: &[u8]) -> usize {
350 usize::from(u16::from_le_bytes([p[N_CIGAR], p[N_CIGAR + 1]]))
351}
352
353fn base_count(p: &[u8]) -> usize {
354 u32::from_le_bytes(p[L_SEQ..L_SEQ + 4].try_into().unwrap()) as usize
355}
356
/// Returns the 4-bit code of base `i` from a packed sequence, where each byte
/// holds two bases: even indices in the high nibble, odd ones in the low.
fn seq_nibble(seq: &[u8], i: usize) -> u8 {
    let shift = if (i & 1) == 0 { 4 } else { 0 };
    (seq[i >> 1] >> shift) & 0x0f
}
365
/// Stores the low four bits of `code` as base `i` of a packed sequence,
/// leaving the other base in the same byte untouched. Even indices occupy the
/// high nibble, odd ones the low nibble.
fn set_seq_nibble(seq: &mut [u8], i: usize, code: u8) {
    let shift = if (i & 1) == 0 { 4 } else { 0 };
    let slot = &mut seq[i >> 1];
    *slot = (*slot & !(0x0f << shift)) | ((code & 0x0f) << shift);
}
373
/// Returns the byte length of the aux value that starts at `bytes[pos]` and
/// has the given `type_code`.
///
/// * `A`, `c`, `C` are 1 byte; `s`, `S` are 2; `i`, `I`, `f` are 4.
/// * `Z` and `H` are NUL-terminated; the length includes the terminator.
/// * `B` is a 1-byte element subtype, a little-endian u32 element count, then
///   the elements. Subtypes `c`/`C` are 1 byte wide, `s`/`S` 2 bytes, and
///   every other subtype is taken to be 4 bytes wide.
///
/// # Panics
///
/// Panics with a message starting with `malformed aux field` when the type
/// code is unknown, a `Z`/`H` value has no NUL terminator, a `B` header is
/// truncated, or a `B` length does not fit in `usize`.
fn aux_value_len(bytes: &[u8], pos: usize, type_code: u8) -> usize {
    match type_code {
        b'A' | b'c' | b'C' => 1,
        b's' | b'S' => 2,
        b'i' | b'I' | b'f' => 4,
        b'Z' | b'H' => {
            // A start past the end is handled like a missing terminator.
            let value = bytes.get(pos..).unwrap_or(&[]);
            match value.iter().position(|&b| b == 0) {
                Some(nul_at) => nul_at + 1,
                None => panic!(
                    "malformed aux field: type {} value has no NUL terminator",
                    char::from(type_code)
                ),
            }
        }
        b'B' => {
            let header = pos.checked_add(5).and_then(|end| bytes.get(pos..end));
            let Some(&[sub, c0, c1, c2, c3]) = header else {
                panic!("malformed aux field: truncated array header");
            };
            let count = u32::from_le_bytes([c0, c1, c2, c3]) as usize;
            let elem: usize = match sub {
                b'c' | b'C' => 1,
                b's' | b'S' => 2,
                _ => 4,
            };
            // Checked so an absurd count cannot wrap into a small length.
            count
                .checked_mul(elem)
                .and_then(|body| body.checked_add(1 + 4))
                .unwrap_or_else(|| panic!("malformed aux field: array length overflows"))
        }
        _ => panic!("malformed aux field: unknown type code {type_code}"),
    }
}
393
/// Overwrites `bytes[off..off + 4]` with `v` in little-endian order.
///
/// Panics if that range is out of bounds.
fn write_i32(bytes: &mut [u8], off: usize, v: i32) {
    let dst = &mut bytes[off..off + 4];
    for (slot, byte) in dst.iter_mut().zip(v.to_le_bytes()) {
        *slot = byte;
    }
}
397
/// Overwrites `bytes[off..off + 4]` with `v` in little-endian order.
///
/// Panics if that range is out of bounds.
fn write_u32(bytes: &mut [u8], off: usize, v: u32) {
    let dst = &mut bytes[off..off + 4];
    for (slot, byte) in dst.iter_mut().zip(v.to_le_bytes()) {
        *slot = byte;
    }
}
401
/// Overwrites `bytes[off..off + 2]` with `v` in little-endian order.
///
/// Panics if that range is out of bounds.
fn write_u16(bytes: &mut [u8], off: usize, v: u16) {
    let dst = &mut bytes[off..off + 2];
    for (slot, byte) in dst.iter_mut().zip(v.to_le_bytes()) {
        *slot = byte;
    }
}
405
406pub fn reset(
410 input: &Path,
411 output_path: Option<&Path>,
412 opts: &ResetOpts,
413 args_cl: &str,
414 workers: NonZero<usize>,
415) -> Result<u64> {
416 let filter = build_aux_filter(opts)?;
417
418 let mut reader = rsomics_bamio::open_with_workers(input, workers)?;
419 let header = reader.read_header().map_err(RsomicsError::Io)?;
420 let input_text = header_to_text(&header)?;
421 let out_text = rebuild_header(&input_text, opts, args_cl);
422 let out_header = reparse_header(&out_text)?;
423
424 match output_path {
425 Some(path) => {
426 let mut writer = rsomics_bamio::create_with_workers(path, workers)?;
427 run(
428 &mut reader,
429 &mut writer,
430 &out_header,
431 &filter,
432 opts.keep_dupflag,
433 )
434 }
435 None => {
436 let mut writer = bam::io::Writer::new(std::io::stdout().lock());
437 run(
438 &mut reader,
439 &mut writer,
440 &out_header,
441 &filter,
442 opts.keep_dupflag,
443 )
444 }
445 }
446}
447
448fn run<W: Write>(
449 reader: &mut rsomics_bamio::ParallelBamReader,
450 writer: &mut bam::io::Writer<W>,
451 header: &sam::Header,
452 filter: &AuxFilter,
453 keep_dupflag: bool,
454) -> Result<u64> {
455 writer.write_header(header).map_err(RsomicsError::Io)?;
456
457 let mut out: Vec<u8> = Vec::new();
458 let mut count: u64 = 0;
459 let mut scanner = raw::RecordReader::new(reader.get_mut());
460 while let Some(rec) = scanner.next()? {
461 if revert_record(&rec, filter, keep_dupflag, &mut out) {
462 write_payload(writer.get_mut(), &out)?;
463 count += 1;
464 }
465 }
466 Ok(count)
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the filter for `opts` and checks that every tag in `kept`
    /// survives it while every tag in `dropped` does not.
    fn assert_filter(opts: &ResetOpts, kept: &[[u8; 2]], dropped: &[[u8; 2]]) {
        let filter = build_aux_filter(opts).unwrap();
        for &tag in kept {
            assert!(filter.survives(tag));
        }
        for &tag in dropped {
            assert!(!filter.survives(tag));
        }
    }

    #[test]
    fn default_filter_includes_defaults_not_rg() {
        assert_filter(
            &ResetOpts::default(),
            &[*b"RG", *b"XS"],
            &[*b"NM", *b"MD", *b"AS"],
        );
    }

    #[test]
    fn no_rg_removes_rg() {
        let opts = ResetOpts {
            no_rg: true,
            ..Default::default()
        };
        assert_filter(&opts, &[], &[*b"RG"]);
    }

    #[test]
    fn remove_tag_adds_to_defaults() {
        let opts = ResetOpts {
            remove_tags: vec![*b"XS"],
            ..Default::default()
        };
        assert_filter(&opts, &[*b"RG"], &[*b"XS", *b"NM"]);
    }

    #[test]
    fn keep_tag_drops_everything_else() {
        let opts = ResetOpts {
            keep_tags: vec![*b"RG"],
            ..Default::default()
        };
        assert_filter(&opts, &[*b"RG"], &[*b"NM", *b"XS"]);
    }

    #[test]
    fn keep_tag_drops_rg_with_no_rg() {
        let opts = ResetOpts {
            keep_tags: vec![*b"RG", *b"BC"],
            no_rg: true,
            ..Default::default()
        };
        assert_filter(&opts, &[*b"BC"], &[*b"RG"]);
    }

    #[test]
    fn rebuild_header_drops_sq_keeps_rg_pg() {
        let input = concat!(
            "@HD\tVN:1.6\tSO:coordinate\n",
            "@SQ\tSN:chr1\tLN:100\n",
            "@RG\tID:rg1\tSM:s1\n",
            "@PG\tID:bwa\tPN:bwa\n",
        );
        let text = rebuild_header(input, &ResetOpts::default(), "rsomics-bam-reset in.bam");
        assert!(text.starts_with("@HD\tVN:1.6\n"));
        assert!(!text.contains("@SQ"));
        let expected_fragments = [
            "@RG\tID:rg1\tSM:s1",
            "@PG\tID:bwa\tPN:bwa",
            "PN:rsomics-bam-reset",
            "PP:bwa",
        ];
        for fragment in expected_fragments {
            assert!(text.contains(fragment));
        }
    }

    #[test]
    fn reject_pg_truncates_chain() {
        let input = concat!(
            "@HD\tVN:1.6\n",
            "@PG\tID:a\tPN:a\n",
            "@PG\tID:b\tPN:b\tPP:a\n",
            "@PG\tID:c\tPN:c\tPP:b\n",
        );
        let opts = ResetOpts {
            reject_pg: Some("b".to_owned()),
            no_pg: true,
            ..Default::default()
        };
        let text = rebuild_header(input, &opts, "");
        assert!(text.contains("ID:a"));
        for gone in ["ID:b", "ID:c"] {
            assert!(!text.contains(gone));
        }
    }

    #[test]
    fn unique_pg_id_suffixes_on_collision() {
        let existing = ["@PG\tID:rsomics-bam-reset\tPN:x"];
        let id = unique_pg_id("rsomics-bam-reset", &existing);
        assert_eq!(id, "rsomics-bam-reset.1");
    }

    #[test]
    fn nt16_roundtrip() {
        for code in 0..16u8 {
            let symbol = SEQ_NT16_STR[usize::from(code)];
            assert_eq!(nt16_code(symbol), code);
        }
    }
}