use std::{
fs::File,
io::Seek,
path::PathBuf,
sync::atomic::{AtomicUsize, Ordering::Relaxed},
};
use clap::Parser;
use indicatif::ParallelProgressIterator;
use itertools::Itertools;
use log::info;
use packed_seq::{PackedNSeqVec, PackedSeqVec, Seq, SeqVec};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use simd_sketch::SketchParams;
#[derive(clap::Parser)]
struct Args {
#[command(subcommand)]
command: Command,
}
#[derive(clap::Subcommand)]
enum Command {
Sketch {
#[command(flatten)]
params: SketchParams,
paths: Vec<PathBuf>,
#[arg(long, short = 'j')]
threads: Option<usize>,
},
Dist {
#[command(flatten)]
params: SketchParams,
path_a: PathBuf,
path_b: PathBuf,
#[arg(long, short = 'j')]
threads: Option<usize>,
},
Triangle {
#[command(flatten)]
params: SketchParams,
paths: Vec<PathBuf>,
#[arg(long)]
output: Option<PathBuf>,
#[arg(long)]
save_sketches: bool,
#[arg(long, short = 'j')]
threads: Option<usize>,
},
}
const BINCODE_CONFIG: bincode::config::Configuration<
bincode::config::LittleEndian,
bincode::config::Fixint,
> = bincode::config::standard().with_fixed_int_encoding();
const EXTENSION: &str = "ssketch";
const SKETCH_VERSION: usize = 1;
#[derive(bincode::Encode, bincode::Decode)]
pub struct VersionedSketch {
version: usize,
sketch: simd_sketch::Sketch,
}
fn main() {
env_logger::init();
let args = Args::parse();
let (Command::Sketch { threads, .. }
| Command::Dist { threads, .. }
| Command::Triangle { threads, .. }) = &args.command;
if let Some(threads) = threads {
rayon::ThreadPoolBuilder::new()
.num_threads(*threads)
.build_global()
.unwrap();
}
let (params, paths) = match &args.command {
Command::Dist {
params,
path_a,
path_b,
..
} => (params, vec![path_a.clone(), path_b.clone()]),
Command::Sketch { params, paths, .. } | Command::Triangle { params, paths, .. } => {
(params, collect_paths(&paths))
}
};
let save_sketches = match &args.command {
Command::Sketch { .. } => true,
Command::Dist { .. } => false,
Command::Triangle { save_sketches, .. } => *save_sketches,
};
let q = paths.len();
let sketcher = params.build();
let params = sketcher.params();
let style = indicatif::ProgressStyle::with_template(
"{msg:.bold} [{elapsed_precise:.cyan}] {bar} {pos}/{len} ({percent:>3}%)",
)
.unwrap()
.progress_chars("##-");
let start = std::time::Instant::now();
let num_sketched = AtomicUsize::new(0);
let num_read = AtomicUsize::new(0);
let num_written = AtomicUsize::new(0);
let total_bytes = AtomicUsize::new(0);
let sketches: Vec<_> = paths
.par_iter()
.progress_with_style(style.clone())
.with_message("Sketching")
.with_finish(indicatif::ProgressFinish::AndLeave)
.map(|path| {
let read_sketch = |path| {
num_read.fetch_add(1, Relaxed);
let mut file = File::open(path).unwrap();
let version: usize =
bincode::decode_from_std_read(&mut file, BINCODE_CONFIG).unwrap();
if version != SKETCH_VERSION {
panic!("Unsupported sketch version: {version}. Only version {SKETCH_VERSION} is supported.");
}
file.seek(std::io::SeekFrom::Start(0)).unwrap();
let VersionedSketch {
version,
sketch,
} = bincode::decode_from_std_read(&mut file, BINCODE_CONFIG).unwrap();
assert_eq!(version, SKETCH_VERSION);
let mut sketch_params = sketch.to_params();
sketch_params.filter_empty = params.filter_empty;
if *params != sketch_params {
panic!(
"Sketch parameters do not match:\nCommand line: {params:?}\nOn disk: {sketch_params:?}",
);
}
return sketch;
};
if path.extension().is_some_and(|ext| ext == EXTENSION) {
return read_sketch(path);
}
let ssketch_path = path.with_extension(EXTENSION);
if ssketch_path.exists() {
return read_sketch(&ssketch_path);
}
let mut reader = needletail::parse_fastx_file(&path).unwrap();
let mut sketch;
if params.filter_out_n {
let mut seqs = vec![];
while let Some(r) = reader.next() {
seqs.push(PackedNSeqVec::from_ascii(&r.unwrap().seq()));
}
let slices = seqs.iter().map(|s| s.as_slice()).collect_vec();
let size = slices.iter().map(|s| s.seq.len()).sum::<usize>();
total_bytes.fetch_add(size, Relaxed);
sketch = sketcher.sketch_seqs(&slices);
}else {
let mut seqs = vec![];
while let Some(r) = reader.next() {
seqs.push(PackedSeqVec::from_ascii(&r.unwrap().seq()));
}
let slices = seqs.iter().map(|s| s.as_slice()).collect_vec();
let size = slices.iter().map(|s| s.len()).sum::<usize>();
total_bytes.fetch_add(size, Relaxed);
sketch = sketcher.sketch_seqs(&slices);
}
num_sketched.fetch_add(1, Relaxed);
if save_sketches {
num_written.fetch_add(1, Relaxed);
let versioned_sketch = VersionedSketch { version: SKETCH_VERSION, sketch };
bincode::encode_into_std_write(
&versioned_sketch,
&mut File::create(ssketch_path).unwrap(),
BINCODE_CONFIG,
)
.unwrap();
sketch = versioned_sketch.sketch;
}
sketch
})
.collect();
let t_sketch = start.elapsed();
info!(
"Sketching {q} seqs took {t_sketch:?} ({:?} avg, {} MiB/s)",
t_sketch / q as u32,
total_bytes.into_inner() as f32 / t_sketch.as_secs_f32() / (1 << 20) as f32
);
let num_read = num_read.into_inner();
let num_sketched = num_sketched.into_inner();
let num_written = num_written.into_inner();
info!("Read {num_read} sketches from disk.");
info!("Newly sketched {num_sketched} files.");
info!("Wrote {num_written} sketches to disk.");
if matches!(args.command, Command::Sketch { .. }) {
return;
}
let num_pairs = q * (q - 1) / 2;
let mut pairs = Vec::with_capacity(num_pairs);
for i in 0..q {
for j in 0..i {
pairs.push((i, j));
}
}
let start = std::time::Instant::now();
let dists: Vec<_> = pairs
.into_par_iter()
.progress_with_style(style.clone())
.with_message("Distances")
.with_finish(indicatif::ProgressFinish::AndLeave)
.map(|(i, j)| sketches[i].mash_distance(&sketches[j]))
.collect();
let t_dist = start.elapsed();
let cnt = q * (q - 1) / 2;
info!(
"Computing {cnt} dists took {t_dist:?} ({:?} avg)",
t_dist / cnt.max(1) as u32
);
match &args.command {
Command::Sketch { .. } => {
unreachable!();
}
Command::Dist { .. } => {
println!("Distance: {:.4}", dists[0]);
return;
}
Command::Triangle { output, .. } => {
use std::io::Write;
let mut out = Vec::new();
writeln!(out, "{q}").unwrap();
let mut d = dists.iter();
for i in 0..q {
write!(out, "{}", paths[i].to_str().unwrap()).unwrap();
for _ in 0..i {
write!(out, "\t{:.7}", d.next().unwrap()).unwrap();
}
writeln!(out).unwrap();
}
match output {
Some(output) => std::fs::write(output, out).unwrap(),
None => println!("{}", str::from_utf8(&out).unwrap()),
}
}
}
}
fn collect_paths(paths: &Vec<PathBuf>) -> Vec<PathBuf> {
let mut res = vec![];
for path in paths {
if path.is_dir() {
res.extend(path.read_dir().unwrap().map(|entry| entry.unwrap().path()));
} else {
res.push(path.clone());
}
}
res.sort();
res
}