1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
use std::cmp::max;
use std::ops::Range;
use std::sync::Mutex;
use crate::kmer::KmerIterator;
use crate::kmer::LongKmer;
use crate::util::is_dna;
use crossbeam::channel::Receiver;
use crossbeam::channel::Sender;
use rayon::iter::IntoParallelIterator;
use rayon::iter::IntoParallelRefMutIterator;
use rayon::iter::ParallelIterator;
// There is one bin for each 3-mer, so there are 4^3 = 64 bins. This is a
// hardcoded assumption in few places in this file, so don't touch these constants
// if you're not willing to go through the whole file and find those places.
const BIN_PREFIX_LEN: usize = 3;
const N_BINS: usize = 64;
// A batch of sequences stored as one concatenated byte buffer plus start
// offsets, so a whole batch can be moved across a channel as two allocations.
struct SeqBatch {
    pub concat: Vec<u8>, // All sequences of the batch, back to back
    pub starts: Vec<usize>, // Also has concat.len() at the end
}
impl SeqBatch {
    /// Borrowing iterator over the sequences of this batch, in index order.
    fn iter<'a>(&'a self) -> SeqBatchIterator<'a> {
        SeqBatchIterator { batch: self, index: 0 }
    }

    /// Slice of the idx'th sequence of the batch.
    fn get_seq(&self, idx: usize) -> &[u8] {
        let (from, to) = (self.starts[idx], self.starts[idx + 1]);
        &self.concat[from..to]
    }

    /// Number of sequences in the batch.
    fn len(&self) -> usize {
        self.starts.len() - 1 // Has concat.len() at the end
    }

    /// Reverses the concatenation in place. This reverses both the order of
    /// the sequences and the contents of each individual sequence; the start
    /// offsets are rebuilt to describe the reversed concatenation.
    fn reverse_all(&mut self) {
        // There should be at least 1 start (in that case the batch is empty).
        // If starts is empty, then the user forgot to add the end sentinel.
        assert!(!self.starts.is_empty());
        // New offsets: prefix sums of the sequence lengths taken in reverse.
        let mut rev_starts = Vec::<usize>::with_capacity(self.starts.len());
        let mut cursor = 0usize;
        rev_starts.push(cursor);
        for i in (0..self.len()).rev() {
            cursor += self.starts[i + 1] - self.starts[i];
            rev_starts.push(cursor);
        }
        self.concat.reverse();
        self.starts = rev_starts;
    }
}
// Borrowing iterator over the sequences of a SeqBatch, yielded in index order.
struct SeqBatchIterator<'a> {
    batch: &'a SeqBatch,
    index: usize, // Current iteration index
}
impl<'a> Iterator for SeqBatchIterator<'a> {
    type Item = &'a [u8];

    /// Returns the next sequence slice of the batch, or None when all
    /// sequences have been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.batch.len() {
            let i = self.index;
            self.index += 1;
            Some(self.batch.get_seq(i))
        } else {
            None
        }
    }
}
// Producer: reads sequences from `seqs` and groups them into SeqBatches of at
// least `buf_cap` concatenated bytes each, sending every finished batch into
// `out`. A final partial batch is sent as well. The channel is closed (by
// dropping the sender) when the stream is exhausted.
fn input_parsing_thread<IN: crate::SeqStream + Send>(mut seqs: IN, buf_cap: usize, out: Sender<SeqBatch>){
    let mut concat = Vec::<u8>::with_capacity(buf_cap);
    let mut starts = Vec::<usize>::new();
    while let Some(seq) = seqs.stream_next(){
        // Append to the current batch
        starts.push(concat.len());
        concat.extend(seq);
        // Flush once the batch has reached the target size
        if concat.len() >= buf_cap {
            starts.push(concat.len()); // End sentinel, as required
            out.send(SeqBatch{concat, starts}).unwrap();
            // Start a fresh batch
            concat = Vec::<u8>::with_capacity(buf_cap);
            starts = Vec::<usize>::new();
        }
    }
    // Flush the final partial batch, if any
    if !concat.is_empty() {
        starts.push(concat.len()); // End sentinel, as required
        out.send(SeqBatch{concat, starts}).unwrap();
    }
    log::info!("Producer thread: all work pushed to work queue");
    drop(out);
}
// The return value is Some if store_first_mers is true.
//
// Encoder worker: consumes sequence batches from `input`, bitpacks every
// k-mer, and routes it into one of the 64 bins keyed by its first 3
// nucleotides. K-mers are staged in thread-local per-bin buffers to reduce
// mutex contention, spilled into the shared per-bin buffers when a local
// buffer reaches `thread_local_buf_caps`, and a shared buffer is flushed to
// the collector via `output` once it reaches `shared_buf_caps` elements.
// If `dedup_batches` is set, flushed batches are sorted and deduplicated
// before sending. If `store_first_mers` is set, also records the (reversed)
// first up-to-k characters of every ACGT run together with its length.
fn kmer_encoder_thread<const B: usize>(input: Receiver<SeqBatch>, output: Sender<Vec<LongKmer<B>>>, shared_bin_buffers: &[Mutex::<Vec::<LongKmer::<B>>>], k: usize, thread_local_buf_caps: usize, shared_buf_caps: usize, dedup_batches: bool, store_first_mers: bool) -> Option<Vec<(LongKmer<B>, u8)>> {
    assert!(shared_bin_buffers.len() == N_BINS);
    // One thread-private staging buffer per bin
    let mut this_thread_bin_buffers = vec![Vec::<LongKmer::<B>>::new(); N_BINS];
    let mut first_mers = if store_first_mers {
        Some(Vec::<(LongKmer::<B>, u8)>::new())
    } else {
        None
    };
    let first_mers_ref = &mut first_mers; // To capture in a closure
    while let Ok(mut batch) = input.recv(){
        // Reverse to get colex sorting
        batch.reverse_all();
        for seq in batch.iter(){
            if store_first_mers {
                // Visit every maximal run of same-keyed characters; keep only
                // the DNA (ACGT) runs, truncated to their last k characters.
                crate::util::for_each_run_with_key(seq, |c| is_dna(*c), |mut run_range: Range<usize>| {
                    if !run_range.is_empty() && is_dna(seq[run_range.start]) {
                        // Take the last up to k characters
                        if run_range.len() > k {
                            run_range = run_range.end-k..run_range.end;
                        }
                        let mer = LongKmer::<B>::from_ascii(&seq[run_range.clone()]).unwrap();
                        first_mers_ref.as_mut().unwrap().push((mer, run_range.len() as u8));
                    }
                });
            }
            for kmer in KmerIterator::<B>::new(seq, k) {
                // This assumes BIN_PREFIX_LEN = 3 and N_BINS = 64
                let bin_id = kmer.get_from_left(0) as usize * 16 + kmer.get_from_left(1) as usize * 4 + kmer.get_from_left(2) as usize; // Interpret nucleotides in base-4
                this_thread_bin_buffers[bin_id].push(kmer);
                if this_thread_bin_buffers[bin_id].len() >= thread_local_buf_caps {
                    // Move this local bin buffer to a shared buffer
                    let mut shared_bin = shared_bin_buffers[bin_id].lock().unwrap();
                    shared_bin.extend(&this_thread_bin_buffers[bin_id]);
                    this_thread_bin_buffers[bin_id].clear();
                    if shared_bin.len() >= shared_buf_caps {
                        // Flush shared bin to the collector thread
                        if dedup_batches {
                            // NOTE(review): mem::take leaves the shared buffer
                            // with zero capacity, so the up-front preallocation
                            // is lost after the first flush — confirm intended.
                            let mut shared_bin_owned: Vec<LongKmer<B>> = std::mem::take(shared_bin.as_mut());
                            drop(shared_bin); // Release the mutex and proceed to sort
                            let len_before = shared_bin_owned.len();
                            shared_bin_owned.sort_unstable();
                            shared_bin_owned.dedup();
                            shared_bin_owned.shrink_to_fit(); // This will probably lock the heap memory manager but that's ok
                            let len_after = shared_bin_owned.len();
                            log::debug!("Deduplicated batch of {} kmers ({:.2}% kept)", len_before, len_after as f64 / len_before as f64 * 100.0);
                            output.send(shared_bin_owned).unwrap();
                        } else {
                            output.send(shared_bin.clone()).unwrap();
                            shared_bin.clear();
                        }
                    }
                }
            }
        }
    }
    // Send remaining internal buffers of this thread
    for mut b in this_thread_bin_buffers.into_iter() {
        if dedup_batches {
            log::debug!("Sorting remaining thread-local batch of {} kmers", b.len());
            b.sort_unstable();
            b.dedup();
        }
        output.send(b).unwrap();
    }
    first_mers
}
// Collector: gathers flushed batches from the encoder threads until all
// senders are dropped. Each non-empty batch is appended to the list of batches
// of the bin its first k-mer belongs to; the caller concatenates the lists.
fn collector_thread<const B: usize>(input: Receiver<Vec<LongKmer<B>>>) -> Vec<Vec<Vec<LongKmer<B>>>> {
    // Vec of shared bin batches for each of the 64 3-mers (concatenated at the end)
    let mut collected_shared_bins = vec![Vec::<Vec::<LongKmer::<B>>>::new(); N_BINS];
    // The receiver's iterator terminates once every sender has been dropped
    for batch in input.iter() {
        if batch.is_empty() {
            continue;
        }
        // This part assumes that BIN_PREFIX_LEN = 3 and N_BINS = 64.
        // Interpret the first three nucleotides of the first k-mer in base-4.
        let b0 = batch[0].get_from_left(0) as usize;
        let b1 = batch[0].get_from_left(1) as usize;
        let b2 = batch[0].get_from_left(2) as usize;
        collected_shared_bins[b0 * 16 + b1 * 4 + b2].push(batch);
    }
    collected_shared_bins
}
// Chooses buffer capacities (counted in k-mers) for the binning pipeline,
// aiming to stay within approx_mem_gb gigabytes. Assumes there are 64 bins.
// Returns (producer_buf_capacity, thread_local_bin_buf_capacity, shared_bin_buf_capacity):
// * producer_buf_capacity: how many k-mers are pushed to the work queue at a
//   time. Large enough to avoid parallel contention, small enough to
//   distribute work quickly and evenly. A good default is 2^20.
// * thread_local_bin_buf_capacity: there is one thread-local buffer for each
//   of the 64 distinct 3-mers, each holding up to this many k-mers. With
//   capacity C_t the total space is C_t * n_threads * 64 * sizeof(LongKmer<B>)
//   = 512 * C_t * n_threads * B bytes. Bigger C_t means less parallel
//   contention. E.g. with 1GB over 48 threads and k = 31 (B = 1), C_t should
//   be about 40,000.
// * shared_bin_buf_capacity: there is one shared buffer per 3-mer bin, each
//   holding up to this many k-mers. If `dedup_batches` is enabled, shared
//   buffers are sorted and deduplicated before storing to memory, so larger
//   buffers find more duplicates. With capacity C_s the total space is
//   64 * C_s * sizeof(LongKmer<B>) = 512 * C_s * B bytes. If dedup_batches is
//   off, or the data has few duplicates, this matters little; otherwise put
//   all remaining memory here to shrink the peak. E.g. with 512GB available
//   and k = 31 (B = 1), set C_s to 512GB / 512 = 1 GB.
fn determine_buf_capacities<const B: usize>(approx_mem_gb: usize, n_threads: usize) -> (usize, usize, usize) {
    let kmer_bytes = std::mem::size_of::<crate::kmer::LongKmer<B>>();
    let producer_cap = (1_usize << 23) / kmer_bytes; // 8 MB of k-mers
    let local_cap = 1_usize << 16;
    // Whatever is left of the budget after the producer and thread-local
    // buffers goes to the shared bin buffers — but always at least 1 GB.
    let budget_bytes = approx_mem_gb as isize * (1 << 30);
    let spent_bytes = (producer_cap * kmer_bytes) as isize + (local_cap * n_threads * N_BINS * kmer_bytes) as isize;
    let bytes_remaining = max(budget_bytes - spent_bytes, 1_isize << 30) as usize; // Use at least 1 GB
    // Total shared-buffer memory is N_BINS * cap * kmer_bytes; solve for cap:
    let shared_cap = bytes_remaining / N_BINS / kmer_bytes;
    assert!(shared_cap > 0); // Should be because bytes_remaining >= 1GB
    // One bin per thread per 3-mer
    let local_total_bytes = local_cap * n_threads * N_BINS * kmer_bytes;
    let shared_total_bytes = shared_cap * N_BINS * kmer_bytes;
    log::info!("Producer buffer capacity: {}", human_bytes::human_bytes((producer_cap * kmer_bytes) as f64));
    log::info!("Thread-local buffer capacity: {} ({} total)", human_bytes::human_bytes((local_cap * kmer_bytes) as f64), human_bytes::human_bytes(local_total_bytes as f64));
    log::info!("Shared bin buffer capacity: {} ({} total)", human_bytes::human_bytes((shared_cap * kmer_bytes) as f64), human_bytes::human_bytes(shared_total_bytes as f64));
    if producer_cap * kmer_bytes + local_total_bytes + shared_total_bytes > approx_mem_gb * (1_usize << 30) {
        log::warn!("Exceeding memory budget");
    }
    (producer_cap, local_cap, shared_cap)
}
// Rayon thread pool must be initialized before calling.
// Returns the bitpacked, sorted, distinct k-mers and, if requested, the
// *reverse* of the start of every run of ACGT characters, up to length k.
// For example, if the sequence is ACGTNNNNNACNN and k = 3, returns
// (GCA, 3) and (CA, 2).
//
// Pipeline: one producer thread parses input into batches, n_threads encoder
// threads bitpack and bin the k-mers, and one collector thread gathers the
// flushed per-bin batches. Bins are then sorted and deduplicated in parallel
// with rayon and finally concatenated into one sorted vector.
pub fn get_bitpacked_sorted_distinct_kmers<const B: usize, IN: crate::SeqStream + Send>(
    seqs: IN,
    k: usize,
    n_threads: usize,
    dedup_batches: bool,
    store_first_mers: bool,
    approx_mem_gb: usize
) -> (Vec<LongKmer<B>>, Option<Vec<(LongKmer<B>, u8)>>) {
    assert!(k >= BIN_PREFIX_LEN); // Every k-mer must have a 3-mer prefix to bin on
    let (producer_buf_cap, thread_local_buf_caps, shared_buf_caps) = determine_buf_capacities::<B>(approx_mem_gb, n_threads);
    log::info!("Allocating shared buffers");
    let mut shared_bin_buffers_vec = Vec::<Mutex::<Vec::<LongKmer::<B>>>>::new();
    for _ in 0..N_BINS {
        let buf = Vec::<LongKmer::<B>>::with_capacity(shared_buf_caps);
        shared_bin_buffers_vec.push(Mutex::new(buf));
    }
    let shared_bin_buffers = &shared_bin_buffers_vec; // This is shared with threads
    log::info!("Bitpacking and binning k-mers");
    // Wrap to scope to be able to borrow seqs for the producer thread even when it's not 'static.
    let (mut bins, first_mers) = std::thread::scope(|scope| {
        // Use the same module path as the top-of-file imports
        // (was the deprecated crossbeam::crossbeam_channel re-export).
        use crossbeam::channel::unbounded;
        let (parser_out, encoder_in) = unbounded();
        let (encoder_out, collector_in) = unbounded();
        // Create producer
        let producer_handle = scope.spawn(move || {
            input_parsing_thread(seqs, producer_buf_cap, parser_out);
        });
        // Create encoders
        let mut encoder_handles = Vec::<std::thread::ScopedJoinHandle::<_>>::new();
        for _ in 0..n_threads {
            let encoder_in_clone = encoder_in.clone();
            let encoder_out_clone = encoder_out.clone();
            encoder_handles.push(scope.spawn(move || {
                kmer_encoder_thread(encoder_in_clone, encoder_out_clone, shared_bin_buffers, k, thread_local_buf_caps, shared_buf_caps, dedup_batches, store_first_mers)
            }));
        }
        // Spawn a collector that reads from the encoders and pushes to final bins.
        // Spawned in the same scope as the other workers (the closure only
        // moves channel endpoints, so this is equivalent to std::thread::spawn).
        let collector_handle = scope.spawn(move || {
            collector_thread(collector_in)
        });
        producer_handle.join().unwrap(); // Wait for the producer to finish
        drop(encoder_in); // Close the channel
        let mut first_mers = if store_first_mers { Some(Vec::<(LongKmer<B>, u8)>::new()) } else { None };
        for h in encoder_handles { // Wait for the encoders to finish
            if let Some(mers) = h.join().unwrap() {
                first_mers.as_mut().unwrap().extend(mers); // Unwrap is okay because we end up here only if store_first_mers is set
            }
        }
        // Send remaining shared batches
        shared_bin_buffers.into_par_iter().for_each(|sb| {
            let mut sb = sb.lock().unwrap();
            if dedup_batches && !sb.is_empty() {
                log::debug!("Sorting remaining shared batch of {} kmers", sb.len());
                sb.sort_unstable();
                sb.dedup();
            }
            encoder_out.send(sb.clone()).unwrap(); // TODO: no clone.
            sb.clear();
            sb.shrink_to_fit();
        });
        drop(encoder_out); // Close the channel
        // Wait for the collector to finish and concatenate each sequence of shared bins
        let collected_shared_bin_seqs = collector_handle.join().unwrap();
        let bins: Vec<Vec<LongKmer<B>>> = collected_shared_bin_seqs.into_par_iter().map(|pieces| {
            // Concatenate all to the first piece (TODO: here if the
            // pieces are sorted, we can just merge the runs and the final sort
            // is done also).
            if pieces.is_empty() {
                vec![]
            } else {
                let mut piece_iter = pieces.into_iter();
                let mut first = piece_iter.next().unwrap();
                for next_piece in piece_iter {
                    first.extend(next_piece);
                }
                first.shrink_to_fit();
                first
            }
        }).collect();
        (bins, first_mers)
    });
    // Sort bins in parallel: todo: largest first
    log::info!("Sorting k-mer bins");
    bins.par_iter_mut().for_each(|bin| {
        if !bin.is_empty() {
            let label = &bin.first().unwrap().to_string()[0..BIN_PREFIX_LEN];
            log::info!("Sorting bin {} of size {}", label, bin.len());
            // Sort the bin. Here if dedup_batches is enabled, we could instead
            // merge sorted runs using a min-heap data structure, but then we need
            // extra auxiliary space, up to 2x the total space. We could also use
            // the standard stable sort, which is said to perform well on partially
            // sorted inputs, but that also needs auxiliary space. At this point space
            // is critical because this is likely near the space peak of the whole
            // algorithm. So we use an unstable sort, which is fast in practice and
            // does not require any extra space.
            bin.sort_unstable();
            bin.dedup();
            bin.shrink_to_fit();
        } else {
            log::info!("Empty bin -> not sorting.");
        }
    });
    let mut bin_concat = Vec::<LongKmer::<B>>::new();
    log::info!("Concatenating k-mer bins");
    // Concat bins. Each of them is sorted, and they are bucketed by the first 3-mer
    // in order, so the concatenation is sorted. This is single-threaded so this
    // takes a while. We could multithread this part, but then we need to allocate
    // space for the concatenation up front, which up to doubles the space if we don't
    // do any extra trick to save space.
    for bin in bins {
        bin_concat.extend(bin.iter());
    }
    (bin_concat, first_mers)
}
#[allow(dead_code)] // Might be useful later
// Merges sorted vector `b` into sorted vector `a` in place, dropping
// duplicates, so that `a` ends up as the sorted union of the two inputs.
// Assumes both inputs are sorted and internally duplicate-free.
// Works back to front so no auxiliary buffer is needed beyond growing `a`.
fn merge_sorted_unique_in_place<const B: usize>(a: &mut Vec<LongKmer<B>>, b: Vec<LongKmer<B>>) {
    // Step 1: Extend `a` with enough space for the worst case (no duplicates)
    let original_len = a.len();
    let total_capacity = original_len + b.len();
    a.resize(total_capacity, LongKmer::from_u64_data([0; B])); // Fill with placeholders
    // Step 2: Pointers for merge (isize so they can go to -1 when exhausted)
    let mut i = original_len as isize - 1; // Last element of original `a`
    let mut j = b.len() as isize - 1; // Last element of `b`
    let mut k = total_capacity as isize - 1; // End of resized `a`
    // Step 3: Merge from back to front, always writing the larger element.
    // Invariant: k = i + j + 1 + (number of duplicates merged so far), so the
    // write position never overtakes the unread part of `a`.
    while i >= 0 && j >= 0 {
        let va = a[i as usize];
        let vb = b[j as usize];
        #[allow(clippy::comparison_chain)]
        if va > vb {
            a[k as usize] = va;
            i -= 1;
        } else if va < vb {
            a[k as usize] = vb;
            j -= 1;
        } else {
            // Equal values, keep one
            a[k as usize] = va;
            i -= 1;
            j -= 1;
        }
        k -= 1;
    }
    // Copy remaining elements (only one of these loops can run)
    while i >= 0 {
        a[k as usize] = a[i as usize];
        i -= 1;
        k -= 1;
    }
    while j >= 0 {
        a[k as usize] = b[j as usize];
        j -= 1;
        k -= 1;
    }
    // Step 4: Remove duplicates and shift left. Todo: no Option
    // The merged data now occupies a[k+1..total_capacity]; move it to the
    // front. (Cross-input duplicates were already collapsed during the merge,
    // so with duplicate-free inputs this pass only shifts.)
    let mut last: Option<LongKmer<B>> = None;
    let mut write = 0;
    for read in (k + 1) as usize..total_capacity {
        if Some(a[read]) != last {
            a[write] = a[read];
            last = Some(a[read]);
            write += 1;
        }
    }
    // Step 5: Truncate the vector to the merged length
    a.truncate(write);
    a.shrink_to_fit();
}