use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::fs::{self, File};
use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Instant;
use gbz::{support, Orientation};
use simple_sds::serialize;
use crate::formats;
use crate::utils;
#[cfg(test)]
mod tests;
/// The method used to derive a 64-bit sort key from a GAF record.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Key is the minimum and maximum oriented node handle on the path,
    /// packed into the high/low 32 bits. Records without a parsable path
    /// field get `GAFRecord::MISSING_KEY` and sort last.
    NodeInterval,
    /// Key is a hash of the entire line; identical lines get identical keys.
    Hash,
}
/// A single GAF line together with its precomputed 64-bit sort key.
#[derive(Clone, Debug)]
struct GAFRecord {
    // Sort key; `MISSING_KEY` (`u64::MAX`) when no key could be extracted.
    key: u64,
    // The raw line bytes, including the trailing newline when the input
    // line had one (lines are read with `read_until(b'\n', ..)`).
    value: Vec<u8>,
}
impl GAFRecord {
    /// Key used when no key can be extracted; equal to `u64::MAX`, so such
    /// records sort after all keyed records.
    const MISSING_KEY: u64 = u64::MAX;
    /// 0-based index of the path column in a tab-separated GAF line.
    const PATH_FIELD: usize = 5;

    /// Wraps a raw GAF line and computes its sort key using `key_type`.
    fn new(value: Vec<u8>, key_type: KeyType) -> Self {
        let mut record = Self {
            key: Self::MISSING_KEY,
            value,
        };
        record.set_key(key_type);
        record
    }

    /// Recomputes `self.key` from the stored line using `key_type`.
    fn set_key(&mut self, key_type: KeyType) {
        self.key = match key_type {
            KeyType::NodeInterval => self.extract_node_interval_key(),
            KeyType::Hash => self.extract_hash_key(),
        };
    }

    /// Builds a key from the path field: minimum oriented node handle in the
    /// high 32 bits, maximum in the low 32 bits. The path must consist of
    /// `>id` / `<id` steps; anything else yields `MISSING_KEY`.
    fn extract_node_interval_key(&self) -> u64 {
        let path = match self.get_field(Self::PATH_FIELD) {
            Some(p) => p,
            None => return Self::MISSING_KEY,
        };
        let mut min_handle: u32 = u32::MAX;
        let mut max_handle: u32 = 0;
        let mut i = 0;
        while i < path.len() {
            // Every step must start with an orientation marker.
            let orientation = if path[i] == b'>' {
                i += 1;
                Orientation::Forward
            } else if path[i] == b'<' {
                i += 1;
                Orientation::Reverse
            } else {
                return Self::MISSING_KEY;
            };
            // Consume the following run of digits as the node id.
            let start = i;
            while i < path.len() && path[i].is_ascii_digit() {
                i += 1;
            }
            if start < i {
                if let Ok(id_str) = std::str::from_utf8(&path[start..i]) {
                    if let Ok(id) = id_str.parse::<usize>() {
                        // NOTE(review): the encoded handle is truncated to 32
                        // bits here; this assumes node ids fit in 31 bits —
                        // confirm against the graphs this tool targets.
                        let handle = support::encode_node(id, orientation) as u32;
                        min_handle = min_handle.min(handle);
                        max_handle = max_handle.max(handle);
                    } else {
                        return Self::MISSING_KEY;
                    }
                } else {
                    return Self::MISSING_KEY;
                }
            }
            // A marker with no digits is skipped here; if it was followed by
            // a non-marker byte, the next iteration rejects it above.
        }
        if min_handle == u32::MAX {
            // No handles were seen (empty or marker-only path).
            Self::MISSING_KEY
        } else {
            ((min_handle as u64) << 32) | (max_handle as u64)
        }
    }

    /// Hashes the entire line with the standard library's default hasher, so
    /// identical lines receive identical keys.
    fn extract_hash_key(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        self.value.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the 0-based `field_index`-th tab-separated field, if present.
    fn get_field(&self, field_index: usize) -> Option<&[u8]> {
        self.value.split(|&b| b == b'\t').nth(field_index)
    }

    /// Writes the record in the internal temporary-file format:
    /// little-endian key, little-endian line length, then the line bytes.
    fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.key.to_le_bytes())?;
        let len = self.value.len() as u64;
        writer.write_all(&len.to_le_bytes())?;
        writer.write_all(&self.value)?;
        Ok(())
    }

    /// Writes the raw GAF line exactly as it was read from the input.
    fn write_gaf_line<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.value)?;
        Ok(())
    }

    /// Reads one record in the format produced by `serialize`.
    fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let mut key_bytes = [0u8; 8];
        reader.read_exact(&mut key_bytes)?;
        let key = u64::from_le_bytes(key_bytes);
        let mut len_bytes = [0u8; 8];
        reader.read_exact(&mut len_bytes)?;
        let len = u64::from_le_bytes(len_bytes) as usize;
        let mut value = vec![0u8; len];
        reader.read_exact(&mut value)?;
        Ok(Self { key, value })
    }

    /// Maps the key to `u64::MAX - key`, inverting the sort order. The merge
    /// functions flip keys before pushing records into a max-heap (making it
    /// behave as a min-heap) and flip them again to restore the original.
    fn flip_key(&mut self) {
        self.key = u64::MAX - self.key;
    }
}
// Equality considers only the sort key, never the line content: two
// different records with the same key compare as equal. This matches the
// `Ord` impl below and is what the sorting and heap code relies on.
impl PartialEq for GAFRecord {
    fn eq(&self, other: &Self) -> bool {
        self.key == other.key
    }
}

impl Eq for GAFRecord {}
impl PartialOrd for GAFRecord {
    // Delegates to `Ord`; the key ordering is total.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for GAFRecord {
    // Records are ordered by key alone; records with equal keys retain
    // their input order only when the stable sort option is used.
    fn cmp(&self, other: &Self) -> Ordering {
        self.key.cmp(&other.key)
    }
}
/// A compressed temporary file holding a sorted run of `GAFRecord`s.
/// The underlying file is removed when the handle is dropped.
struct TempFile {
    // Location of the temporary file on disk.
    path: PathBuf,
    // Number of records stored in the file.
    records: usize,
}
impl TempFile {
    /// Zstandard compression level used for temporary files.
    const COMPRESSION_LEVEL: i32 = 3;

    /// Creates a uniquely named temporary file handle. The file itself is
    /// created lazily when `writer()` is called.
    fn create() -> io::Result<Self> {
        let path = serialize::temp_file_name("gaf-sort");
        Ok(Self { path, records: 0 })
    }

    /// Opens the file for writing through a buffered zstd encoder. The
    /// caller must finish the encoder (`into_inner()?.finish()?`), or the
    /// compressed stream is left truncated.
    fn writer(&self) -> io::Result<BufWriter<zstd::Encoder<'static, File>>> {
        let file = File::create(&self.path)?;
        let encoder = zstd::Encoder::new(file, Self::COMPRESSION_LEVEL)?;
        Ok(BufWriter::new(encoder))
    }

    /// Opens the file for reading through a zstd decoder.
    fn reader(&self) -> io::Result<zstd::Decoder<'static, BufReader<File>>> {
        let file = File::open(&self.path)?;
        let decoder = zstd::Decoder::new(file)?;
        Ok(decoder)
    }
}
impl Drop for TempFile {
    fn drop(&mut self) {
        // Best-effort cleanup: a removal failure is deliberately ignored,
        // since there is no reasonable way to report an error from `drop`.
        let _ = fs::remove_file(&self.path);
    }
}
/// Outcome of the initial sorting pass.
enum InitialSortResult {
    /// The input needed more than one batch: sorted runs in temporary files.
    TempFiles(Vec<Arc<TempFile>>),
    /// The entire input fit into the first batch; the raw, still-unsorted
    /// lines are returned so they can be sorted straight to the output.
    SingleBatch(Vec<Vec<u8>>),
}
/// Parameters for `sort_gaf`.
#[derive(Clone, Debug)]
pub struct SortParameters {
    /// How sort keys are derived from GAF records.
    pub key_type: KeyType,
    /// Number of records sorted in memory per temporary file; must be > 0.
    pub records_per_file: usize,
    /// Maximum number of temporary files merged at once; must be >= 2.
    pub files_per_merge: usize,
    /// Records buffered per file while reading/writing merges; must be > 0.
    pub buffer_size: usize,
    /// Number of worker threads; must be > 0.
    pub threads: usize,
    /// Use a stable sort, preserving input order for records with equal keys.
    pub stable: bool,
    /// Print progress information to stderr.
    pub progress: bool,
}
impl SortParameters {
    /// Default number of records sorted in memory per temporary file.
    pub const DEFAULT_RECORDS_PER_FILE: usize = 1_000_000;
    /// Default number of temporary files merged in one batch.
    pub const DEFAULT_FILES_PER_MERGE: usize = 32;
    /// Default number of records buffered per file during a merge.
    pub const DEFAULT_BUFFER_SIZE: usize = 1000;

    /// Returns the default parameters; equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Checks that every parameter is within its valid range, returning a
    /// human-readable message for the first violation found.
    pub fn validate(&self) -> Result<(), String> {
        if self.records_per_file == 0 {
            Err(String::from("SortParameters: records_per_file must be greater than 0"))
        } else if self.files_per_merge < 2 {
            Err(String::from("SortParameters: files_per_merge must be at least 2"))
        } else if self.buffer_size == 0 {
            Err(String::from("SortParameters: buffer_size must be greater than 0"))
        } else if self.threads == 0 {
            Err(String::from("SortParameters: threads must be greater than 0"))
        } else {
            Ok(())
        }
    }
}
impl Default for SortParameters {
    /// Single-threaded, unstable sorting by node interval, with the default
    /// batch/merge/buffer sizes and no progress output.
    fn default() -> Self {
        Self {
            key_type: KeyType::NodeInterval,
            records_per_file: Self::DEFAULT_RECORDS_PER_FILE,
            files_per_merge: Self::DEFAULT_FILES_PER_MERGE,
            buffer_size: Self::DEFAULT_BUFFER_SIZE,
            threads: 1,
            stable: false,
            progress: false,
        }
    }
}
/// Sorts the GAF records in `input_file` by their sort key and writes the
/// result to `output_file` (`-` for stdout), with the header lines passed
/// through at the top. Returns the number of records sorted.
///
/// This is an external sort: the input is split into batches that are sorted
/// in memory and written to compressed temporary files, which are merged in
/// rounds of at most `params.files_per_merge` files until a final merge
/// writes the output. An input that fits into a single batch is sorted
/// directly to the output, skipping the temporary files entirely.
pub fn sort_gaf<P: AsRef<Path>, Q: AsRef<Path>>(
    input_file: P,
    output_file: Q,
    params: &SortParameters,
) -> Result<usize, String> {
    params.validate()?;
    let start_time = Instant::now();
    if params.progress {
        eprintln!("Sorting GAF records with {} worker thread(s)", params.threads);
    }

    // Consume the header before the record-sorting pass begins.
    let mut reader = utils::open_file(input_file.as_ref())?;
    let header_lines = formats::read_gaf_header_lines(&mut reader)
        .map_err(|e| format!("Failed to read GAF header: {}", e))?;

    let sort_result = initial_sort(reader, params)?;
    let mut temp_files = match sort_result {
        InitialSortResult::TempFiles(files) => files,
        // Small input: sort the single batch straight to the output.
        InitialSortResult::SingleBatch(lines) => {
            let total_records = lines.len();
            sort_to_output(lines, &header_lines, output_file.as_ref(), params)?;
            if params.progress {
                let elapsed = start_time.elapsed().as_secs_f64();
                eprintln!("Sorted {} records in {:.2} seconds", total_records, elapsed);
            }
            return Ok(total_records);
        }
    };

    // Merge rounds until one final batch of files remains.
    let mut round = 0;
    while temp_files.len() > params.files_per_merge {
        temp_files = merge_round(temp_files, round, params)?;
        round += 1;
    }
    if params.progress {
        eprintln!("Starting the final merge");
    }
    let total_records = merge_to_output(&temp_files, &header_lines, output_file.as_ref(), params)?;
    if params.progress {
        let elapsed = start_time.elapsed().as_secs_f64();
        eprintln!("Sorted {} records in {:.2} seconds", total_records, elapsed);
    }
    Ok(total_records)
}
/// Reads the input in batches of up to `params.records_per_file` lines,
/// sorts each batch on a worker thread, and writes it to a compressed
/// temporary file. If the whole input fits into the first batch, the
/// unsorted lines are returned instead so the caller can sort them straight
/// to the output. Lines starting with '@' are skipped; the caller is
/// expected to have consumed the leading header already.
fn initial_sort(mut reader: Box<dyn BufRead>, params: &SortParameters) -> Result<InitialSortResult, String> {
    if params.progress {
        eprintln!("Initial sort: {} records per file", params.records_per_file);
    }

    // One slot per worker thread; batches are assigned round-robin.
    let mut workers: Vec<Option<JoinHandle<Result<Arc<TempFile>, String>>>> = Vec::with_capacity(params.threads);
    for _ in 0..params.threads {
        workers.push(None);
    }
    let mut outputs = Vec::new();

    // Joins one worker (if present) and collects its temporary file.
    // Returns false if the worker failed or panicked.
    let mut join = |worker: Option<JoinHandle<Result<Arc<TempFile>, String>>>, thread: usize| -> bool {
        if let Some(worker) = worker {
            match worker.join() {
                Ok(Ok(sorted)) => outputs.push(sorted),
                Ok(Err(e)) => {
                    eprintln!("Worker thread {} failed: {}", thread, e);
                    return false;
                },
                Err(_) => {
                    eprintln!("Worker thread {} panicked", thread);
                    return false;
                },
            };
        }
        true
    };

    let mut batch = 0;
    let mut total_records = 0;
    let mut ok = true;
    loop {
        // Read the next batch of lines, skipping '@' header lines.
        let mut lines = Vec::new();
        for _ in 0..params.records_per_file {
            let mut line = Vec::new();
            match reader.read_until(b'\n', &mut line) {
                Ok(0) => break,
                Ok(_) => {
                    if !line.is_empty() && line[0] != b'@' {
                        // The buffer is freshly allocated each iteration, so
                        // it can be moved into the batch without a copy.
                        lines.push(line);
                    }
                }
                Err(e) => {
                    eprintln!("Failed to read input: {}", e);
                    ok = false;
                    break;
                },
            }
        }
        if !ok {
            break;
        }
        total_records += lines.len();
        if lines.is_empty() {
            break;
        }
        // If the first batch already exhausted the input, skip the
        // temporary files and let the caller sort directly to the output.
        if batch == 0 {
            let buffer = reader.fill_buf().map_err(|e| format!("Failed to read input: {}", e))?;
            if buffer.is_empty() {
                return Ok(InitialSortResult::SingleBatch(lines));
            }
        }
        // Reuse the round-robin slot, waiting for its previous batch first.
        let thread = batch % params.threads;
        if workers[thread].is_some() {
            if !join(workers[thread].take(), thread) {
                ok = false;
                break;
            }
        }
        let key_type = params.key_type;
        let stable = params.stable;
        workers[thread] = Some(std::thread::spawn(move || sort_to_temp(lines, key_type, stable)));
        batch += 1;
    }

    // Join the remaining workers even after a failure, so that no worker
    // thread outlives this call.
    for (thread, worker) in workers.into_iter().enumerate() {
        if !join(worker, thread) {
            ok = false;
        }
    }
    if !ok {
        return Err(String::from("Initial sort failed"));
    }
    if params.progress {
        eprintln!(
            "Initial sort finished with {} records in {} files",
            total_records,
            outputs.len()
        );
    }
    Ok(InitialSortResult::TempFiles(outputs))
}
/// Performs one merge round: consecutive batches of up to
/// `params.files_per_merge` temporary files are merged on worker threads,
/// producing one output file per batch. A trailing batch consisting of a
/// single file is passed through unchanged.
fn merge_round(inputs: Vec<Arc<TempFile>>, round: usize, params: &SortParameters) -> Result<Vec<Arc<TempFile>>, String> {
    if params.progress {
        eprintln!("Round {}: {} files per batch", round, params.files_per_merge);
    }

    // One slot per worker thread; batches are assigned round-robin.
    let mut workers: Vec<Option<JoinHandle<Result<Arc<TempFile>, String>>>> = Vec::with_capacity(params.threads);
    for _ in 0..params.threads {
        workers.push(None);
    }
    let mut outputs = Vec::new();

    // Joins one worker (if present) and collects its merged file.
    // Returns false if the worker failed or panicked.
    let mut join = |worker: Option<JoinHandle<Result<Arc<TempFile>, String>>>, thread: usize| -> bool {
        if let Some(worker) = worker {
            match worker.join() {
                Ok(Ok(merged)) => outputs.push(merged),
                Ok(Err(e)) => {
                    eprintln!("Worker thread {} failed: {}", thread, e);
                    return false;
                },
                Err(_) => {
                    eprintln!("Worker thread {} panicked", thread);
                    return false;
                },
            };
        }
        true
    };

    let mut i = 0;
    let mut batch = 0;
    let mut ok = true;
    // Only start a batch when at least two files remain; a single leftover
    // file is handled after the loop.
    while i + 1 < inputs.len() {
        let end = (i + params.files_per_merge).min(inputs.len());
        // Reuse the round-robin slot, waiting for its previous batch first.
        let thread = batch % params.threads;
        if workers[thread].is_some() {
            if !join(workers[thread].take(), thread) {
                ok = false;
                break;
            }
        }
        let batch_files = inputs[i..end].to_vec();
        let buffer_size = params.buffer_size;
        workers[thread] = Some(std::thread::spawn(move || merge_files(batch_files, buffer_size)));
        i = end;
        batch += 1;
    }

    // Join the remaining workers even after a failure, so that no worker
    // thread outlives this call.
    for (thread, worker) in workers.into_iter().enumerate() {
        if !join(worker, thread) {
            ok = false;
        }
    }
    // Pass a single leftover input through to the next round unchanged.
    if i + 1 == inputs.len() {
        outputs.push(inputs[i].clone());
    }
    if !ok {
        let msg = format!("Merge round {} failed", round);
        return Err(msg);
    }
    if params.progress {
        eprintln!("Round {} finished with {} files", round, outputs.len());
    }
    Ok(outputs)
}
/// Opens the output destination for writing: standard output when the path
/// is "-", otherwise a newly created (or truncated) file. The returned
/// writer is buffered in both cases.
fn create_output_writer(output_file: &Path) -> Result<Box<dyn Write>, String> {
    match output_file == Path::new("-") {
        true => Ok(Box::new(BufWriter::new(io::stdout()))),
        false => {
            let file = File::create(output_file)
                .map_err(|e| format!("Failed to create output file: {}", e))?;
            Ok(Box::new(BufWriter::new(file)))
        }
    }
}
/// Sorts a single in-memory batch of GAF lines and writes the header lines
/// followed by the sorted records to `output_file`.
fn sort_to_output(
    lines: Vec<Vec<u8>>,
    header_lines: &[String],
    output_file: &Path,
    params: &SortParameters,
) -> Result<(), String> {
    // Attach a sort key to every line.
    let mut records = Vec::with_capacity(lines.len());
    for line in lines {
        records.push(GAFRecord::new(line, params.key_type));
    }
    if params.stable {
        records.sort();
    } else {
        records.sort_unstable();
    }

    let mut writer = create_output_writer(output_file)?;
    for line in header_lines {
        writeln!(writer, "{}", line)
            .map_err(|e| format!("Failed to write header: {}", e))?;
    }
    for record in &records {
        record.write_gaf_line(&mut writer)
            .map_err(|e| format!("Failed to write output: {}", e))?;
    }
    writer.flush()
        .map_err(|e| format!("Failed to flush output: {}", e))?;
    Ok(())
}
/// Sorts one batch of GAF lines in memory and writes the resulting records
/// to a new compressed temporary file, returning a handle to it.
fn sort_to_temp(lines: Vec<Vec<u8>>, key_type: KeyType, stable: bool) -> Result<Arc<TempFile>, String> {
    // Attach a sort key to every line.
    let mut records = Vec::with_capacity(lines.len());
    for line in lines {
        records.push(GAFRecord::new(line, key_type));
    }
    if stable {
        records.sort();
    } else {
        records.sort_unstable();
    }

    let mut temp = TempFile::create()
        .map_err(|e| format!("Failed to create temporary file: {}", e))?;
    temp.records = records.len();
    let mut writer = temp.writer()
        .map_err(|e| format!("Failed to open temporary file for writing: {}", e))?;
    for record in &records {
        record.serialize(&mut writer)
            .map_err(|e| format!("Failed to write to temporary file: {}", e))?;
    }
    // Flush the buffer and finalize the zstd stream; dropping the writer
    // without this would leave the compressed file truncated.
    writer.into_inner()
        .map_err(|e| format!("Failed to finish compression: {}", e))?
        .finish()
        .map_err(|e| format!("Failed to finish compression: {}", e))?;
    Ok(Arc::new(temp))
}
/// Maps `index` to its mirror position within `0..total` (0 <-> total - 1).
/// The merge functions flip source indices before pushing them into the
/// max-heap, so that key ties break toward the lower original index.
/// The function is its own inverse.
fn flip_source_index(index: usize, total: usize) -> usize {
    total - (index + 1)
}
/// Merges the given sorted temporary files into a new sorted temporary file
/// with a buffered k-way merge. Up to `buffer_size` records per input are
/// read ahead, and output records are written in batches of `buffer_size`.
fn merge_files(inputs: Vec<Arc<TempFile>>, buffer_size: usize) -> Result<Arc<TempFile>, String> {
    let mut output = TempFile::create()
        .map_err(|e| format!("Failed to create temporary file: {}", e))?;
    let mut readers: Vec<_> = inputs
        .iter()
        .map(|temp| temp.reader())
        .collect::<Result<_, _>>()
        .map_err(|e| format!("Failed to open temporary file for reading: {}", e))?;
    // Per-input read-ahead buffers, and how many records remain on disk.
    let mut buffers: Vec<VecDeque<GAFRecord>> = vec![VecDeque::new(); readers.len()];
    let mut remaining: Vec<usize> = inputs.iter().map(|t| t.records).collect();
    // Refills one input's buffer. Keys are flipped on the way in so that
    // the max-heap below pops the smallest original key first.
    let read_buffer = |reader_idx: usize,
        readers: &mut Vec<_>,
        buffers: &mut Vec<VecDeque<GAFRecord>>,
        remaining: &mut Vec<usize>|
        -> Result<(), String> {
        let count = remaining[reader_idx].min(buffer_size);
        if count > 0 {
            buffers[reader_idx].clear();
            for _ in 0..count {
                let mut record = GAFRecord::deserialize(&mut readers[reader_idx])
                    .map_err(|e| format!("Failed to read from temporary file: {}", e))?;
                record.flip_key();
                buffers[reader_idx].push_back(record);
            }
            remaining[reader_idx] -= count;
        }
        Ok(())
    };
    for i in 0..readers.len() {
        read_buffer(i, &mut readers, &mut buffers, &mut remaining)?;
    }
    // Seed the heap with the first record from each input. Source indices
    // are flipped too, so that on key ties the max-heap prefers the lower
    // original index, keeping the merge stable across inputs.
    let mut heap = BinaryHeap::new();
    for (i, buffer) in buffers.iter_mut().enumerate() {
        if let Some(record) = buffer.pop_front() {
            heap.push((record, flip_source_index(i, inputs.len())));
        }
    }
    let mut writer = output.writer()
        .map_err(|e| format!("Failed to open output temporary file for writing: {}", e))?;
    let mut out_buffer = Vec::new();
    while let Some((mut record, mut source)) = heap.pop() {
        // Undo both flips before using the record and its source index.
        record.flip_key();
        source = flip_source_index(source, inputs.len());
        out_buffer.push(record);
        if out_buffer.len() >= buffer_size {
            for rec in out_buffer.drain(..) {
                rec.serialize(&mut writer)
                    .map_err(|e| format!("Failed to write to output: {}", e))?;
                output.records += 1;
            }
        }
        // Advance the input this record came from, refilling if needed.
        if buffers[source].is_empty() && remaining[source] > 0 {
            read_buffer(source, &mut readers, &mut buffers, &mut remaining)?;
        }
        if let Some(next) = buffers[source].pop_front() {
            heap.push((next, flip_source_index(source, inputs.len())));
        }
    }
    // Write out any records still buffered after the heap is exhausted.
    for rec in out_buffer {
        rec.serialize(&mut writer)
            .map_err(|e| format!("Failed to write to output: {}", e))?;
        output.records += 1;
    }
    // Flush and finalize the zstd stream; otherwise the file is truncated.
    writer.into_inner()
        .map_err(|e| format!("Failed to finish compression: {}", e))?
        .finish()
        .map_err(|e| format!("Failed to finish compression: {}", e))?;
    Ok(Arc::new(output))
}
/// Final merge: merges the sorted temporary files directly into the output,
/// writing the header lines first and the records as plain GAF lines.
/// Returns the total number of records written (the sum of the inputs'
/// record counts). Uses the same flipped-key/flipped-index heap scheme as
/// `merge_files`.
fn merge_to_output(
    inputs: &[Arc<TempFile>],
    header_lines: &[String],
    output_file: &Path,
    params: &SortParameters,
) -> Result<usize, String> {
    let mut writer = create_output_writer(output_file)?;
    for line in header_lines {
        writeln!(writer, "{}", line)
            .map_err(|e| format!("Failed to write header: {}", e))?;
    }
    let mut readers: Vec<_> = inputs
        .iter()
        .map(|temp| temp.reader())
        .collect::<Result<_, _>>()
        .map_err(|e| format!("Failed to open temporary file for reading: {}", e))?;
    let total_records: usize = inputs.iter().map(|t| t.records).sum();
    // Per-input read-ahead buffers, and how many records remain on disk.
    let mut buffers: Vec<VecDeque<GAFRecord>> = vec![VecDeque::new(); readers.len()];
    let mut remaining: Vec<usize> = inputs.iter().map(|t| t.records).collect();
    // Refills one input's buffer. Keys are flipped on the way in so that
    // the max-heap below pops the smallest original key first.
    let read_buffer = |reader_idx: usize,
        readers: &mut Vec<_>,
        buffers: &mut Vec<VecDeque<GAFRecord>>,
        remaining: &mut Vec<usize>|
        -> Result<(), String> {
        let count = remaining[reader_idx].min(params.buffer_size);
        if count > 0 {
            buffers[reader_idx].clear();
            for _ in 0..count {
                let mut record = GAFRecord::deserialize(&mut readers[reader_idx])
                    .map_err(|e| format!("Failed to read from temporary file: {}", e))?;
                record.flip_key();
                buffers[reader_idx].push_back(record);
            }
            remaining[reader_idx] -= count;
        }
        Ok(())
    };
    for i in 0..readers.len() {
        read_buffer(i, &mut readers, &mut buffers, &mut remaining)?;
    }
    // Seed the heap with the first record from each input; source indices
    // are flipped so key ties prefer the lower original index.
    let mut heap = BinaryHeap::new();
    for (i, buffer) in buffers.iter_mut().enumerate() {
        if let Some(record) = buffer.pop_front() {
            heap.push((record, flip_source_index(i, inputs.len())));
        }
    }
    let mut out_buffer = Vec::new();
    while let Some((mut record, mut source)) = heap.pop() {
        // Undo both flips before using the record and its source index.
        record.flip_key();
        source = flip_source_index(source, inputs.len());
        out_buffer.push(record);
        if out_buffer.len() >= params.buffer_size {
            for rec in out_buffer.drain(..) {
                rec.write_gaf_line(&mut writer)
                    .map_err(|e| format!("Failed to write to output: {}", e))?;
            }
        }
        // Advance the input this record came from, refilling if needed.
        if buffers[source].is_empty() && remaining[source] > 0 {
            read_buffer(source, &mut readers, &mut buffers, &mut remaining)?;
        }
        if let Some(next) = buffers[source].pop_front() {
            heap.push((next, flip_source_index(source, inputs.len())));
        }
    }
    // Write out any records still buffered after the heap is exhausted.
    for rec in out_buffer {
        rec.write_gaf_line(&mut writer)
            .map_err(|e| format!("Failed to write to output: {}", e))?;
    }
    writer.flush()
        .map_err(|e| format!("Failed to flush output: {}", e))?;
    Ok(total_records)
}