use std::{
collections::BinaryHeap,
sync::atomic::{AtomicBool, Ordering},
};
use datasize::DataSize;
use crate::utils::lengths::TotalLength;
use super::{path::Paths, prefix::CompPrefix, Progress, Search, Update};
/// Number of search iterations between polls of the abort flag.
const ITERS_BETWEEN_ABORT_CHECKS: usize = 10_000;
/// Number of search iterations between periodic `Update::Progress` reports.
const ITERS_BETWEEN_PROGRESS_UPDATES: usize = 100_000;
/// Number of search iterations between periodic `Paths` garbage collections (in addition to the
/// GCs triggered when the memory limit is hit).
const ITERS_BETWEEN_PATH_GCS: usize = 100_000_000;
/// Runs a best-first search over `CompPrefix`es, reporting results through `update_fn`.
///
/// The frontier is a max-heap (`BinaryHeap`), so the greatest prefix is always expanded next.
/// `CompPrefix::expand` receives `&mut frontier` and may return a finished comp, which is sent as
/// `Update::Comp`.
///
/// The loop stops when any of these happens:
/// - the frontier is exhausted;
/// - `search.query.num_comps` comps have been emitted;
/// - `abort_flag` is observed set, in which case `Update::Aborting` is sent first.
///
/// The flag is only polled every `ITERS_BETWEEN_ABORT_CHECKS` iterations. A final
/// `Update::Progress` is always sent after the loop.
///
/// When the estimated memory usage reaches `search.config.mem_limit`, the frontier is cut to half
/// its length (keeping its greatest prefixes) and `paths` is garbage-collected. That step is
/// bracketed by progress updates with `truncating_queue = true` and then `false`.
pub(crate) fn search(search: &Search, mut update_fn: impl FnMut(Update), abort_flag: &AtomicBool) {
    let mut paths = Paths::new();
    let mut frontier: BinaryHeap<CompPrefix> = CompPrefix::starts(&search.graph, &mut paths);
    // Per-prefix size used for the memory estimate below. This assumes every prefix has the same
    // size as the first one. If there are no starting prefixes, the loop never runs and this value
    // is never used, so fall back to 0 instead of panicking on an empty frontier.
    let prefix_size = frontier.peek().map_or(0, |prefix| prefix.size());
    let mut iter_count = 0;
    let mut num_comps = 0;

    // Borrows the loop-local state so each call site only has to specify the truncation flag.
    macro_rules! send_progress_update {
        (truncating_queue = $truncating_queue: expr) => {
            send_progress_update(
                &frontier,
                &mut update_fn,
                iter_count,
                num_comps,
                $truncating_queue,
            );
        };
    }

    while let Some(prefix) = frontier.pop() {
        let maybe_comp = prefix.expand(search, &mut paths, &mut frontier);
        if let Some(comp) = maybe_comp {
            update_fn(Update::Comp(comp));
            num_comps += 1;
            if num_comps == search.query.num_comps {
                break;
            }
        }

        // Approximate memory footprint: frontier entries plus the shared path storage.
        let mem_usage = frontier.len() * prefix_size + paths.estimate_heap_size();
        if mem_usage >= search.config.mem_limit {
            send_progress_update!(truncating_queue = true);
            truncate_queue(frontier.len() / 2, &mut frontier);
            // Paths no longer referenced by any surviving prefix can now be freed.
            paths.gc(frontier.iter().map(|prefix| prefix.path_head()));
            send_progress_update!(truncating_queue = false);
        }

        iter_count += 1;
        if iter_count % ITERS_BETWEEN_ABORT_CHECKS == 0 && abort_flag.load(Ordering::Relaxed) {
            update_fn(Update::Aborting);
            break;
        }
        if iter_count % ITERS_BETWEEN_PROGRESS_UPDATES == 0 {
            send_progress_update!(truncating_queue = false);
        }
        if iter_count % ITERS_BETWEEN_PATH_GCS == 0 {
            paths.gc(frontier.iter().map(|prefix| prefix.path_head()));
        }
    }

    send_progress_update!(truncating_queue = false);
    // Optionally skip dropping the (potentially very large) frontier. Only `frontier` is leaked;
    // `paths` is dropped normally.
    if search.config.leak_search_memory {
        std::mem::forget(frontier);
    }
}
/// Summarises the current frontier and hands it to `update_fn` as an `Update::Progress`.
///
/// The report contains:
/// - the queue length;
/// - the mean `length()` of the queued prefixes (`0.0` when the queue is empty);
/// - the maximum `length()` of the queued prefixes (`TotalLength::ZERO` when empty);
/// - the caller-supplied counters and `truncating_queue` flag.
fn send_progress_update(
    frontier: &BinaryHeap<CompPrefix>,
    update_fn: &mut impl FnMut(Update),
    iter_count: usize,
    num_comps: usize,
    truncating_queue: bool,
) {
    let queue_len = frontier.len();

    // One pass over the queue gathers both the sum and the maximum of the prefix lengths.
    let mut sum_of_lengths: u64 = 0;
    let mut longest = TotalLength::ZERO;
    for prefix in frontier {
        sum_of_lengths += prefix.length().as_usize() as u64;
        longest = longest.max(prefix.length());
    }

    let avg_length = if queue_len == 0 {
        0.0
    } else {
        sum_of_lengths as f32 / queue_len as f32
    };

    update_fn(Update::Progress(Progress {
        iter_count,
        num_comps,
        queue_len,
        avg_length,
        max_length: longest.as_usize(),
        truncating_queue,
    }));
}
/// Shrinks `queue` so that it keeps only its `len` greatest elements, i.e. the ones
/// `BinaryHeap::pop` would have returned first.
///
/// Does nothing if the queue already holds at most `len` elements. When several equal elements
/// straddle the cut-off, which of them survive is unspecified. The heap's allocation is reused
/// rather than shrunk.
fn truncate_queue<T: Ord>(len: usize, queue: &mut BinaryHeap<T>) {
    // Nothing would be dropped, so skip the take/partition/rebuild entirely.
    if len >= queue.len() {
        return;
    }
    let mut items = std::mem::take(queue).into_vec();
    // Partition with a descending comparator. This puts the `len` greatest items in
    // `items[..len]`, in arbitrary order, without fully sorting the vector.
    // `len < items.len()` holds here, so the index is in bounds.
    items.select_nth_unstable_by(len, |a, b| b.cmp(a));
    items.truncate(len);
    *queue = BinaryHeap::from(items);
}