1use crate::constants as afconst;
10use crate::eq_class::IndexedEqList;
11use anyhow::{anyhow, Context};
12use bstr::io::BufReadExt;
13use core::fmt;
14use dashmap::DashMap;
15use libradicl::utils::SPLICE_MASK_U32;
16use needletail::bitkmer::*;
17use std::collections::{HashMap, HashSet};
18use std::error::Error;
19use std::fs::File;
20use std::io::{BufReader, BufWriter, Write};
21use std::path::Path;
22use std::path::PathBuf;
23use std::str::FromStr;
24use thiserror::Error;
25
26pub(crate) fn remove_file_if_exists(fname: &Path) -> anyhow::Result<()> {
45 if fname.exists() {
46 std::fs::remove_file(fname)
47 .with_context(|| format!("could not remove {}", fname.display()))?;
48 }
49 Ok(())
50}
51
52pub(super) fn count_diff_2_bit_packed(a: u64, b: u64) -> usize {
55 let bit_diffs = a ^ b;
56 let two_bit_diffs = (bit_diffs | bit_diffs >> 1) & 0x5555555555555555;
57 two_bit_diffs.count_ones() as usize
58}
59
/// Returns the label used for the unspliced form of the gene whose spliced
/// label is `gid`: the id immediately following it.
#[inline(always)]
fn unspliced_of(gid: u32) -> u32 {
    1 + gid
}

/// Returns the label used for the spliced form of the gene whose base label
/// is `gid`; the spliced label is the base label itself.
#[inline(always)]
fn spliced_of(gid: u32) -> u32 {
    gid
}
70
71#[inline(always)]
74fn spliced_id(gid: u32) -> u32 {
75 gid & SPLICE_MASK_U32
76}
77
78#[inline(always)]
79pub fn same_gene(g1: u32, g2: u32, with_unspliced: bool) -> bool {
80 (g1 == g2) || (with_unspliced && (spliced_id(g1) == spliced_id(g2)))
81}
82
/// Returns `true` for spliced labels, which are the even ids.
#[inline(always)]
pub fn is_spliced(gid: u32) -> bool {
    gid % 2 == 0
}

/// Returns `true` for unspliced labels, which are the odd ids.
#[inline(always)]
pub fn is_unspliced(gid: u32) -> bool {
    gid % 2 == 1
}
94
95pub fn write_permit_list_freq(
97 o_path: &std::path::Path,
98 bclen: u16,
99 permit_freq_map: &HashMap<u64, u64, ahash::RandomState>,
100) -> Result<(), Box<dyn std::error::Error>> {
101 let output = std::fs::File::create(o_path)?;
102 let mut writer = BufWriter::new(&output);
103
104 {
105 writer
107 .write_all(&afconst::PERMIT_FILE_VER.to_le_bytes())
108 .unwrap();
109
110 writer.write_all(&(u64::from(bclen)).to_le_bytes()).unwrap();
112
113 bincode::serialize_into(&mut writer, &permit_freq_map)?;
115 }
116 Ok(())
117}
118
119pub fn write_permit_list_freq_dashmap(
121 o_path: &std::path::Path,
122 bclen: u16,
123 permit_freq_map: &DashMap<u64, u64, ahash::RandomState>,
124) -> Result<(), Box<dyn std::error::Error>> {
125 let output = std::fs::File::create(o_path)?;
126 let mut writer = BufWriter::new(&output);
127
128 {
129 writer
131 .write_all(&afconst::PERMIT_FILE_VER.to_le_bytes())
132 .unwrap();
133
134 writer.write_all(&(u64::from(bclen)).to_le_bytes()).unwrap();
136
137 bincode::serialize_into(&mut writer, &permit_freq_map)?;
139 }
140 Ok(())
141}
142
143fn parse_tg_spliced_unspliced(
153 rdr: &mut csv::Reader<File>,
154 ref_count: usize,
155 rname_to_id: &HashMap<String, u32, ahash::RandomState>,
156 gene_names: &mut Vec<String>,
157 gene_name_to_id: &mut HashMap<String, u32, ahash::RandomState>,
158) -> anyhow::Result<(Vec<u32>, bool)> {
159 let mut tid_to_gid = vec![u32::MAX; ref_count];
164
165 type TsvRec = (String, String, String);
167
168 let mut found = 0usize;
170
171 let mut next_gid = 0u32;
175 for result in rdr.deserialize() {
181 let record: TsvRec = result?;
182 let gene_id = *gene_name_to_id.entry(record.1.clone()).or_insert_with(|| {
184 let cur_gid = next_gid;
187 next_gid += 2;
188 gene_names.push(record.1.clone());
191 cur_gid
192 });
193
194 if let Some(transcript_id) = rname_to_id.get(&record.0) {
196 found += 1;
197 if record.2.eq_ignore_ascii_case("U") {
198 tid_to_gid[*transcript_id as usize] = unspliced_of(gene_id);
201 } else if record.2.eq_ignore_ascii_case("S") {
202 tid_to_gid[*transcript_id as usize] = spliced_of(gene_id);
205 } else {
206 return Err(anyhow!(
207 "Third column in 3 column txp-to-gene file must be S or U"
208 ));
209 }
210 }
211 }
212
213 assert_eq!(
214 found, ref_count,
215 "The tg-map must contain a gene mapping for all transcripts in the header"
216 );
217
218 Ok((tid_to_gid, true))
219}
220
221fn parse_tg_spliced(
222 rdr: &mut csv::Reader<File>,
223 ref_count: usize,
224 rname_to_id: &HashMap<String, u32, ahash::RandomState>,
225 gene_names: &mut Vec<String>,
226 gene_name_to_id: &mut HashMap<String, u32, ahash::RandomState>,
227) -> anyhow::Result<(Vec<u32>, bool)> {
228 let mut tid_to_gid = vec![u32::MAX; ref_count];
233 type TsvRec = (String, String);
235 let mut found = 0usize;
237 for result in rdr.deserialize() {
243 match result {
244 Ok(record_in) => {
245 let record: TsvRec = record_in;
246 let next_id = gene_name_to_id.len() as u32;
249 let gene_id = *gene_name_to_id.entry(record.1.clone()).or_insert(next_id);
250 if gene_id == next_id {
253 gene_names.push(record.1.clone());
254 }
255 if let Some(transcript_id) = rname_to_id.get(&record.0) {
257 found += 1;
258 tid_to_gid[*transcript_id as usize] = gene_id;
259 }
260 }
261 Err(e) => {
262 return Err(anyhow!(
270 "failed to parse the transcript-to-gene map : {}.",
271 e
272 ));
273 }
274 }
275 }
276
277 assert_eq!(
278 found, ref_count,
279 "The tg-map must contain a gene mapping for all transcripts in the header"
280 );
281
282 Ok((tid_to_gid, false))
283}
284
285pub fn parse_tg_map(
286 tg_map: &PathBuf,
287 ref_count: usize,
288 rname_to_id: &HashMap<String, u32, ahash::RandomState>,
289 gene_names: &mut Vec<String>,
290 gene_name_to_id: &mut HashMap<String, u32, ahash::RandomState>,
291) -> anyhow::Result<(Vec<u32>, bool)> {
292 let t2g_file = std::fs::File::open(tg_map).context("couldn't open file")?;
293 let mut rdr = csv::ReaderBuilder::new()
294 .has_headers(false)
295 .delimiter(b'\t')
296 .from_reader(t2g_file);
297
298 let headers = rdr.headers()?;
299 match headers.len() {
300 2 => {
301 parse_tg_spliced(
303 &mut rdr,
304 ref_count,
305 rname_to_id,
306 gene_names,
307 gene_name_to_id,
308 )
309 }
310 3 => {
311 parse_tg_spliced_unspliced(
313 &mut rdr,
314 ref_count,
315 rname_to_id,
316 gene_names,
317 gene_name_to_id,
318 )
319 }
320 _ => {
321 Err(anyhow!(
323 "Transcript-gene mapping must have either 2 or 3 columns."
324 ))
325 }
326 }
327}
328
329pub fn extract_counts(
339 gene_eqc: &HashMap<Vec<u32>, u32, ahash::RandomState>,
340 num_counts: usize,
341) -> Vec<f32> {
342 let unspliced_offset = num_counts / 3;
345 let ambig_offset = 2 * unspliced_offset;
346 let mut counts = vec![0_f32; num_counts];
347
348 for (labels, count) in gene_eqc {
349 match labels.len() {
353 1 => {
354 if let Some(gid) = labels.first() {
356 let idx = if is_spliced(*gid) {
357 (*gid >> 1) as usize
358 } else {
359 unspliced_offset + (*gid >> 1) as usize
360 };
361 counts[idx] += *count as f32;
362 }
363 }
364 2 => {
365 if let (Some(g1), Some(g2)) = (labels.first(), labels.last()) {
367 if same_gene(*g1, *g2, true) {
368 let idx = ambig_offset + (*g1 >> 1) as usize;
369 counts[idx] += *count as f32;
371 } else {
372 match (is_spliced(*g1), is_spliced(*g2)) {
374 (true, false) => {
375 counts[(*g1 >> 1) as usize] += *count as f32;
376 }
377 (false, true) => {
378 counts[(*g2 >> 1) as usize] += *count as f32;
379 }
380 _ => { }
381 }
382 }
383 }
384 }
385 3..=10 => {
386 let mut iter = labels.iter();
392 if let Some(sidx) = iter.position(|&x| is_spliced(x)) {
394 if let Some(_sidx2) = iter.position(|&x| is_spliced(x)) {
396 } else {
399 if let Some(sg) = labels.get(sidx) {
404 if let Some(ng) = labels.get(sidx + 1) {
405 if same_gene(*sg, *ng, true) {
406 let idx = ambig_offset + (*sg >> 1) as usize;
407 counts[idx] += *count as f32;
408 continue;
409 }
410 }
411 counts[(*sg >> 1) as usize] += *count as f32;
412 }
413 }
414 }
415 }
416 _ => {}
417 }
418 }
419 counts
420}
421
422pub fn extract_counts_mm_uniform(
427 gene_eqc: &HashMap<Vec<u32>, u32, ahash::RandomState>,
428 num_counts: usize,
429) -> Vec<f32> {
430 let unspliced_offset = num_counts / 3;
433 let ambig_offset = 2 * unspliced_offset;
434 let mut counts = vec![0_f32; num_counts];
435 let mut tvec = Vec::<usize>::with_capacity(16);
436
437 for (labels, count) in gene_eqc {
438 match labels.len() {
442 1 => {
443 if let Some(gid) = labels.first() {
445 let idx = if is_spliced(*gid) {
446 (*gid >> 1) as usize
447 } else {
448 unspliced_offset + (*gid >> 1) as usize
449 };
450 counts[idx] += *count as f32;
451 }
452 }
453 _ => {
454 let mut iter = labels.iter().peekable();
456 tvec.clear();
457 while let Some(gn) = iter.next() {
458 let mut idx = (gn >> 1) as usize;
460 if is_spliced(*gn) {
464 if let Some(ng) = iter.peek() {
465 if same_gene(*gn, **ng, true) {
469 idx += ambig_offset;
470 iter.next();
473 }
474 }
478 } else {
479 idx += unspliced_offset;
484 }
485 tvec.push(idx)
486 }
487 let fcount = (*count as f32) / (tvec.len() as f32);
488 for g in &tvec {
489 counts[*g] += fcount;
490 }
491 }
492 }
493 }
494 counts
495}
496
497pub fn extract_usa_eqmap(
506 gene_eqc: &HashMap<Vec<u32>, u32, ahash::RandomState>,
507 num_counts: usize,
508 idx_eq_list: &mut IndexedEqList,
509 eq_id_count: &mut Vec<(u32, u32)>,
510) {
511 idx_eq_list.clear();
521 eq_id_count.clear();
522
523 let unspliced_offset = num_counts / 3;
525 let ambig_offset = 2 * unspliced_offset;
526 let mut tvec = Vec::<u32>::with_capacity(16);
527
528 for (ctr, (labels, count)) in gene_eqc.iter().enumerate() {
529 match labels.len() {
533 1 => {
534 if let Some(gid) = labels.first() {
536 let idx = if is_spliced(*gid) {
537 (*gid >> 1) as usize
538 } else {
539 unspliced_offset + (*gid >> 1) as usize
540 };
541 idx_eq_list.add_single_label(idx as u32);
542 eq_id_count.push((ctr as u32, *count));
543 }
544 }
545 _ => {
546 let mut iter = labels.iter().peekable();
548 tvec.clear();
549 while let Some(gn) = iter.next() {
550 let mut idx = (gn >> 1) as usize;
552 if is_spliced(*gn) {
556 if let Some(ng) = iter.peek() {
557 if same_gene(*gn, **ng, true) {
561 idx += ambig_offset;
562 iter.next();
565 }
566 }
570 } else {
571 idx += unspliced_offset;
576 }
577 tvec.push(idx as u32);
578 }
579 idx_eq_list.add_label_vec(tvec.as_slice());
584 eq_id_count.push((ctr as u32, *count));
585 }
586 }
587 }
588}
589
/// Returns `fill_with` shifted into the 2-bit slot of nucleotide `nt_index`.
///
/// Slots are numbered from 1 at the least-significant bit pair, so
/// `get_bit_mask(1, 3) == 0b11` and `get_bit_mask(2, 1) == 0b0100`.
/// `nt_index` must be at least 1.
pub fn get_bit_mask(nt_index: usize, fill_with: u64) -> u64 {
    fill_with << (2 * (nt_index - 1))
}

/// Panics unless `bc_length <= 32` and `bc` fits in `2 * bc_length` bits.
fn assert_barcode_fits(bc: u64, bc_length: usize) {
    // Check the length first so the range check below cannot overflow.
    assert!(
        bc_length <= 32,
        "barcode length greater than 32 not supported"
    );
    // `checked_shr` yields `None` for a 64-bit shift (length 32), where every
    // u64 is a valid barcode; otherwise no bits may lie above the barcode.
    assert!(
        bc.checked_shr(2 * bc_length as u32).unwrap_or(0) == 0,
        "the barcode id is larger than possible (based on barcode length)"
    );
}

/// Returns all barcodes at Hamming distance one from `bc`.
///
/// For each slot `1..=bc_length`, the slot is set to each of the three other
/// 2-bit values, giving `3 * bc_length` results.
///
/// # Panics
/// If `bc_length > 32` or `bc` does not fit in `2 * bc_length` bits.
pub fn get_all_snps(bc: u64, bc_length: usize) -> Vec<u64> {
    assert_barcode_fits(bc, bc_length);

    let mut snps: Vec<u64> = Vec::with_capacity(3 * bc_length);
    for nt_index in 1..=bc_length {
        let cleared = bc & !get_bit_mask(nt_index, 3);
        snps.extend(
            (0..=3u64)
                .map(|nt| cleared | get_bit_mask(nt_index, nt))
                .filter(|&candidate| candidate != bc),
        );
    }
    snps
}

/// Returns length-preserving single-insertion and single-deletion variants of
/// `bc`.
///
/// For each `nt_index` in `1..bc_length` and each base `i`:
/// - Insertion keeps the slots above `nt_index` and writes `i` into slot
///   `nt_index`. The slots below shift down by one, and slot 1's base drops
///   off.
/// - Deletion removes the base in slot `nt_index + 1`. The slots below it
///   shift up by one, and `i` fills slot 1.
///
/// Candidates equal to `bc` are skipped. The result may contain duplicates.
///
/// # Panics
/// If `bc_length > 32` or `bc` does not fit in `2 * bc_length` bits.
pub fn get_all_indels(bc: u64, bc_length: usize) -> Vec<u64> {
    assert_barcode_fits(bc, bc_length);

    // saturating_sub: a zero-length barcode has no indels (and must not underflow).
    let mut indels: Vec<u64> = Vec::with_capacity(8 * bc_length.saturating_sub(1));

    for nt_index in 1..bc_length {
        // Selects the low `nt_index` slots; nt_index <= 31 so the shift is in range.
        let low_mask: u64 = (1u64 << (2 * nt_index)) - 1;
        let upper_half = bc & !low_mask;
        let lower_half = bc & low_mask;
        // The deleted base must be cleared from the upper half; otherwise the
        // shifted-up lower half is OR-ed on top of it, corrupting that slot.
        let upper_after_deletion = upper_half & !get_bit_mask(nt_index + 1, 3);

        for i in 0..=3 {
            let new_insertion_bc = upper_half | get_bit_mask(nt_index, i) | (lower_half >> 2);
            let new_deletion_bc = upper_after_deletion | get_bit_mask(1, i) | (lower_half << 2);

            if new_insertion_bc != bc {
                indels.push(new_insertion_bc);
            }
            if new_deletion_bc != bc {
                indels.push(new_deletion_bc);
            }
        }
    }

    indels
}
661
662pub fn get_all_one_edit_neighbors(
663 bc: u64,
664 bc_length: usize,
665 neighbors: &mut HashSet<u64>,
666) -> Result<(), Box<dyn Error>> {
667 neighbors.clear();
668
669 let snps: Vec<u64> = get_all_snps(bc, bc_length);
670 let indels: Vec<u64> = get_all_indels(bc, bc_length);
671
672 neighbors.extend(&snps);
673 neighbors.extend(&indels);
674
675 Ok(())
676}
677
678pub fn generate_whitelist_set(
679 whitelist_bcs: &[u64],
680 bc_length: usize,
681) -> Result<HashSet<u64>, Box<dyn Error>> {
682 let num_bcs = whitelist_bcs.len();
683
684 let mut one_edit_barcode_hash: HashSet<u64> = HashSet::new();
685 let mut neighbors: HashSet<u64> = HashSet::new();
686 one_edit_barcode_hash.reserve(10 * num_bcs);
687 neighbors.reserve(3 * bc_length + 8 * (bc_length - 1));
691
692 for bc in whitelist_bcs {
693 get_all_one_edit_neighbors(*bc, bc_length, &mut neighbors)?;
694 one_edit_barcode_hash.extend(&neighbors);
695 }
696
697 Ok(one_edit_barcode_hash)
698}
699
700pub fn generate_permitlist_map(
706 permit_bcs: &[u64],
707 bc_length: usize,
708) -> Result<HashMap<u64, u64>, Box<dyn Error>> {
709 let num_bcs = permit_bcs.len();
710
711 let mut one_edit_barcode_map: HashMap<u64, u64> = HashMap::with_capacity(10 * num_bcs);
712 for bc in permit_bcs {
714 one_edit_barcode_map.insert(*bc, *bc);
715 }
716
717 let mut neighbors: HashSet<u64> = HashSet::with_capacity(3 * bc_length + 8 * (bc_length - 1));
721
722 for bc in permit_bcs {
723 get_all_one_edit_neighbors(*bc, bc_length, &mut neighbors)?;
724 for n in &neighbors {
725 one_edit_barcode_map.entry(*n).or_insert(*bc);
726 }
727 }
728
729 Ok(one_edit_barcode_map)
730}
731
732pub fn read_filter_list(
737 flist: &PathBuf,
738 bclen: u16,
739) -> anyhow::Result<HashSet<u64, ahash::RandomState>> {
740 let s = ahash::RandomState::with_seeds(2u64, 7u64, 1u64, 8u64);
741 let mut fset = HashSet::<u64, ahash::RandomState>::with_hasher(s);
742
743 let filt_file = std::fs::File::open(flist).context("couldn't open file")?;
744 let mut reader = BufReader::new(filt_file);
745
746 reader
748 .for_byte_line(|line| {
749 let mut bnk = BitNuclKmer::new(line, bclen as u8, false);
750 let (_, k, _) = bnk.next().unwrap();
751 fset.insert(k.0);
752 Ok(true)
753 })
754 .unwrap();
755
756 Ok(fset)
757}
758
759pub fn is_velo_mode(input_dir: &PathBuf) -> bool {
760 let parent = std::path::Path::new(input_dir);
761 let meta_data_file = File::open(parent.join("generate_permit_list.json"))
763 .expect("could not open the generate_permit_list.json file.");
764 let mdata: serde_json::Value = serde_json::from_reader(meta_data_file)
765 .expect("could not deseralize generate_permit_list.json");
766 let vm = mdata.get("velo_mode");
767 match vm {
768 Some(v) => v.as_bool().unwrap_or(false),
769 None => false,
770 }
771}
772
/// A `major.minor.patch` version triple, used to check whether results were
/// produced by a compatible alevin-fry version.
// NOTE(review): the reason for `dead_code` is not visible here — confirm it
// is still needed.
#[allow(dead_code)]
#[derive(Debug, PartialEq, Eq)]
pub struct InternalVersionInfo {
    pub major: u32,
    pub minor: u32,
    pub patch: u32,
}

impl InternalVersionInfo {
    /// Checks that `other` (the version that produced some results) shares
    /// this version's major and minor numbers; patch levels are ignored.
    ///
    /// # Errors
    /// Returns a human-readable message asking the user to regenerate the
    /// results when the major or minor numbers differ.
    pub fn is_compatible_with(&self, other: &InternalVersionInfo) -> Result<(), String> {
        let same_release_line = (self.major, self.minor) == (other.major, other.minor);
        if !same_release_line {
            return Err(format!(
                "running alevin-fry {} on {} results, please regenerate the results using alevin-fry {} or greater",
                self, other, self
            ));
        }
        Ok(())
    }
}

impl fmt::Display for InternalVersionInfo {
    /// Formats as `v<major>.<minor>.<patch>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            major,
            minor,
            patch,
        } = self;
        write!(f, "v{}.{}.{}", major, minor, patch)
    }
}
800
801#[derive(Error, Debug)]
802pub enum VersionParseError {
803 #[error("The version string should be of the format x.y.z; it was `{0}`")]
804 IncorrectFormat(String),
805}
806
807impl FromStr for InternalVersionInfo {
808 type Err = VersionParseError;
809
810 fn from_str(vs: &str) -> Result<Self, Self::Err> {
811 let versions: Vec<u32> = vs.split('.').map(|s| s.parse::<u32>().unwrap()).collect();
812 if versions.len() != 3 {
813 return Err(VersionParseError::IncorrectFormat(vs.to_string()));
814 }
815 Ok(Self {
816 major: versions[0],
817 minor: versions[1],
818 patch: versions[2],
819 })
820 }
821}
822
#[cfg(test)]
mod tests {
    use crate::utils::{
        generate_whitelist_set, get_all_indels, get_all_one_edit_neighbors, get_all_snps,
        get_bit_mask, InternalVersionInfo,
    };
    use std::collections::HashSet;
    use std::str::FromStr;

    /// Expected one-edit neighbourhood of barcode 7 with length 3, shared by
    /// the neighbour-set and whitelist tests.
    const NEIGHBORS_OF_7: [u64; 18] = [
        1, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 23, 28, 29, 30, 31, 39, 55,
    ];

    /// Collects `values`, then sorts and deduplicates them.
    fn sorted_unique(values: impl IntoIterator<Item = u64>) -> Vec<u64> {
        let mut out: Vec<u64> = values.into_iter().collect();
        out.sort_unstable();
        out.dedup();
        out
    }

    #[test]
    fn test_version_info() {
        let expected = InternalVersionInfo {
            major: 1,
            minor: 2,
            patch: 3,
        };
        assert_eq!(InternalVersionInfo::from_str("1.2.3").unwrap(), expected);
    }

    #[test]
    fn test_get_bit_mask() {
        let masks: Vec<u64> = (0..=3).map(|fill| get_bit_mask(2, fill)).collect();
        assert_eq!(masks, vec![0, 4, 8, 12]);
    }

    #[test]
    fn test_get_all_snps() {
        let mut snps = get_all_snps(7, 3);
        snps.sort_unstable();
        assert_eq!(snps, vec![3, 4, 5, 6, 11, 15, 23, 39, 55]);
    }

    #[test]
    fn test_get_all_indels() {
        assert_eq!(
            sorted_unique(get_all_indels(7, 3)),
            vec![1, 4, 5, 6, 9, 12, 13, 14, 15, 28, 29, 30, 31]
        );
    }

    #[test]
    fn test_get_all_one_edit_neighbors() {
        let mut neighbors = HashSet::new();
        get_all_one_edit_neighbors(7, 3, &mut neighbors).unwrap();
        assert_eq!(sorted_unique(neighbors), NEIGHBORS_OF_7.to_vec());
    }

    #[test]
    fn test_generate_whitelist_hash() {
        let whitelist = generate_whitelist_set(&[7], 3).unwrap();
        assert_eq!(sorted_unique(whitelist), NEIGHBORS_OF_7.to_vec());
    }
}