use crate::input_iterator::{InputIterator, PatternRecord};
use sassy::{
RcSearchAble, Searcher, Strand,
profiles::{Iupac, Profile},
};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
use std::{
io::{BufWriter, Write},
path::PathBuf,
sync::Mutex,
};
#[derive(clap::Parser)]
pub struct CrisprArgs {
#[arg(long, short = 'g')]
guide: String,
#[arg(short, long)]
k: usize,
#[arg(short = 'o', long)]
output: Option<PathBuf>,
#[arg(long)]
max_n_frac: f32,
#[arg(short = 'j', long)]
threads: Option<usize>,
#[arg(long, default_value_t = 3)]
pam_length: usize,
#[arg(long)]
allow_pam_edits: bool,
#[arg(long)]
no_rc: bool,
path: PathBuf,
}
fn get_output_writer(args: &CrisprArgs) -> Box<dyn Write + Send> {
if let Some(output_path) = &args.output {
Box::new(BufWriter::new(File::create(output_path).unwrap())) as Box<dyn Write + Send>
} else {
Box::new(std::io::stdout()) as Box<dyn Write + Send>
}
}
pub fn check_n_frac(max_n_frac: f32, match_slice: &[u8]) -> bool {
let n_count = match_slice
.iter()
.filter(|c| (**c & 0xDF) == b'N') .count() as f32;
let n_frac = n_count / match_slice.len() as f32;
n_frac <= max_n_frac
}
fn print_and_check_params(args: &CrisprArgs, guide_sequences: &[Vec<u8>]) -> (String, f32) {
if !(0.0..=1.0).contains(&args.max_n_frac) {
eprintln!("[N-chars] Error: max_n_frac must be between 0 and 1.0");
std::process::exit(1);
}
if guide_sequences.is_empty() {
eprintln!(
"[PAM] Error: No guide sequences provided, please check your input file (one guide sequence per line)"
);
std::process::exit(1);
}
let pam = if !guide_sequences.is_empty() {
let guide = &guide_sequences[0];
let pam = &guide[guide.len() - args.pam_length..];
println!("[PAM] Sequence: [{}]", String::from_utf8_lossy(pam));
println!(
"[PAM] If the above PAM is incorrect, please make sure that the guide sequence ENDs with the PAM-sequence, i.e. XXXXXGGN (not it's reverse complement)"
);
pam
} else {
unreachable!("No guide sequences provided");
};
if guide_sequences.len() > 1 {
for guide_sequence in guide_sequences {
let guide_pam = &guide_sequence[guide_sequence.len() - args.pam_length..];
if pam != guide_pam {
eprintln!(
"[PAM] One of the guide sequences has a PAM different than the provided PAM"
);
eprintln!(
"[PAM] provided PAM {}, detected PAM {}",
String::from_utf8_lossy(pam),
String::from_utf8_lossy(guide_pam)
);
std::process::exit(1);
}
}
}
println!("[PAM] PAM used to filter: {}", String::from_utf8_lossy(pam));
println!("[PAM] Edits in PAM are allowed: {}", args.allow_pam_edits);
println!(
"[N-chars] Allowing up to {}% N characters",
args.max_n_frac * 100.0
);
(String::from_utf8_lossy(pam).into_owned(), args.max_n_frac)
}
pub fn read_guide_sequences(path: &str) -> Vec<Vec<u8>> {
let file = File::open(path).expect("Failed to open guide file");
let reader = BufReader::new(file);
reader
.lines()
.map(|l| l.unwrap().as_bytes().to_vec())
.collect::<Vec<_>>()
.into_iter()
.filter(|seq| !seq.is_empty())
.collect()
}
pub fn matching_seq<P: Profile>(seq1: &[u8], seq2: &[u8]) -> bool {
for (c1, c2) in seq1.iter().zip(seq2.iter()) {
if !P::is_match(*c1, *c2) {
return false;
}
}
true
}
pub fn crispr(args: &CrisprArgs) {
let guide_sequences = read_guide_sequences(&args.guide);
println!("[GUIDES] Found {} guides", guide_sequences.len());
if guide_sequences.is_empty() {
return;
}
let writer = &Mutex::new(get_output_writer(args));
let header = format!(
"{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
"guide", "text_id", "cost", "strand", "start", "end", "match_region", "cigar"
);
writer.lock().unwrap().write_all(header.as_bytes()).unwrap();
let (pam, max_n_frac) = print_and_check_params(args, &guide_sequences);
let pam = pam.as_bytes();
let pam_compl = Iupac::complement(pam);
let pam_compl = pam_compl.as_slice();
let total_found = AtomicUsize::new(0);
let num_threads = args.threads.unwrap_or_else(num_cpus::get);
println!("[Threads] Using {num_threads} threads");
let queries: Vec<PatternRecord> = guide_sequences
.iter()
.enumerate()
.map(|(i, seq)| PatternRecord {
id: format!("guide_{}", i + 1),
seq: seq.clone(),
})
.collect();
let paths = vec![args.path.clone()];
let task_iter = InputIterator::new(&paths, &queries, None, None, true);
let start = Instant::now();
std::thread::scope(|scope| {
for _ in 0..num_threads {
scope.spawn(|| {
let mut searcher = if args.no_rc {
Searcher::<Iupac>::new_fwd()
} else {
Searcher::<Iupac>::new_rc()
};
let filter_fn = |_q: &[u8], text_up_to_end: &[u8], strand: Strand| {
let pam_slice = &text_up_to_end[text_up_to_end.len() - args.pam_length..];
if strand == Strand::Fwd {
matching_seq::<Iupac>(pam_slice, pam)
} else {
matching_seq::<Iupac>(pam_slice, pam_compl)
}
};
while let Some((_batch_id, batch)) = task_iter.next_batch() {
for text in &*batch.2 {
for pattern in batch.1 {
let guide_sequence = &pattern.seq;
let guide_string = String::from_utf8_lossy(guide_sequence);
let id_text = &text;
let id = &id_text.id;
let text = &id_text.seq;
let matches = if !args.allow_pam_edits {
searcher.search_with_fn(guide_sequence, text, args.k, true, filter_fn)
} else {
searcher.search_all(guide_sequence, text, args.k)
};
if matches.is_empty() {
continue;
}
total_found.fetch_add(matches.len(), Ordering::Relaxed);
let mut writer_guard = writer.lock().unwrap();
for m in matches {
let start = m.text_start;
let end = m.text_end;
let text = text.text();
let slice = &text.as_ref()[start..end];
let n_ok = if max_n_frac < 100.0 {
check_n_frac(max_n_frac, slice)
} else {
true
};
if !n_ok {
continue;
}
total_found.fetch_add(1, Ordering::Relaxed);
let match_region = if m.strand == Strand::Rc {
let rc = <Iupac as Profile>::reverse_complement(slice);
String::from_utf8_lossy(&rc).into_owned()
} else {
String::from_utf8_lossy(slice).into_owned()
};
let cost = m.cost;
let cigar = m.cigar.to_string();
let strand = match m.strand {
Strand::Fwd => "+",
Strand::Rc => "-",
};
writeln!(
writer_guard,
"{guide_string}\t{id}\t{cost}\t{strand}\t{start}\t{end}\t{match_region}\t{cigar}"
)
.unwrap();
}
drop(writer_guard);
}
}
}
});
}
});
println!("\nSummary");
println!(
" Total targets found: {}",
total_found.load(Ordering::Relaxed)
);
println!(" Time taken: {:?}", start.elapsed());
}