use num_traits::Float;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::iter::Sum;
use crate::prelude::*;
/// Reusable scratch state for a search over indexed nodes.
///
/// Bundles a per-node visited table with the candidate heap and working
/// buffers, so the same allocations can be reused across searches (`reset`
/// clears them rather than reallocating).
///
/// Visited tracking uses a generation counter: a node counts as visited only
/// if its slot in `visited` equals the current `visit_id`. Advancing
/// `visit_id` therefore invalidates every previous mark without touching the
/// table.
pub struct SearchState<T> {
/// Per-node generation stamps. Slot `i` equals `visit_id` exactly when node
/// `i` was marked in the current generation. Newly allocated slots are 0.
pub visited: Vec<usize>,
/// Current generation. Starts at 1 and is never 0 after `new` or `reset`,
/// so a zeroed slot always reads as "unvisited".
pub visit_id: usize,
/// Candidate queue. Wrapping entries in `Reverse` turns the max-heap
/// `BinaryHeap` into a min-heap, so the smallest `(score, node)` pair pops
/// first. NOTE(review): the float is presumably a distance/score; confirm
/// with callers.
pub candidates: BinaryHeap<Reverse<(OrderedFloat<T>, usize)>>,
/// Working set of `(score, node)` pairs kept in a `SortedBuffer`.
/// NOTE(review): ordering and capacity semantics are defined by the prelude
/// type; confirm there.
pub working_sorted: SortedBuffer<(OrderedFloat<T>, usize)>,
/// Scratch buffer of `(score, node)` pairs, emptied on every `reset`.
/// NOTE(review): exact role is set by callers not visible here.
pub scratch_working: Vec<(OrderedFloat<T>, usize)>,
/// Scratch buffer of `(score, node)` pairs, emptied on every `reset`.
/// NOTE(review): name suggests it holds entries dropped from the working
/// set; confirm with callers.
pub scratch_discarded: Vec<(OrderedFloat<T>, usize)>,
}
impl<T: Float + Sum> SearchState<T> {
    /// Creates a search state sized for `capacity` nodes.
    ///
    /// The visited table gets `capacity` zeroed slots, and the heap, sorted
    /// buffer and scratch vectors each pre-reserve `capacity` entries. The
    /// generation counter starts at 1, so every node initially reads as
    /// unvisited.
    pub fn new(capacity: usize) -> Self {
        let visited = vec![0; capacity];
        Self {
            visited,
            visit_id: 1,
            candidates: BinaryHeap::with_capacity(capacity),
            working_sorted: SortedBuffer::with_capacity(capacity),
            scratch_working: Vec::with_capacity(capacity),
            scratch_discarded: Vec::with_capacity(capacity),
        }
    }

    /// Prepares the state for a fresh search over `n` nodes.
    ///
    /// Grows (never shrinks) the visited table to at least `n` slots and
    /// advances the generation counter so all earlier marks become stale.
    /// It then empties the candidate heap, the sorted buffer and both
    /// scratch vectors.
    ///
    /// If advancing the counter would overflow, the whole table is zeroed
    /// and the counter restarts at 1, so old stamps can never alias the new
    /// generation.
    pub fn reset(&mut self, n: usize) {
        // Resizing to the current length is a no-op, so this only ever grows.
        let target_len = self.visited.len().max(n);
        self.visited.resize(target_len, 0);

        match self.visit_id.checked_add(1) {
            Some(next_generation) => self.visit_id = next_generation,
            None => {
                // 0 is never a live generation, so wiping every slot to 0
                // invalidates all existing marks at once.
                self.visited.fill(0);
                self.visit_id = 1;
            }
        }

        self.candidates.clear();
        self.working_sorted.clear();
        self.scratch_working.clear();
        self.scratch_discarded.clear();
    }

    /// Reports whether `node` was marked during the current generation.
    ///
    /// # Panics
    ///
    /// Panics if `node >= self.visited.len()`.
    #[inline(always)]
    pub fn is_visited(&self, node: usize) -> bool {
        let stamp = self.visited[node];
        stamp == self.visit_id
    }

    /// Marks `node` as visited in the current generation.
    ///
    /// # Panics
    ///
    /// Panics if `node >= self.visited.len()`.
    #[inline(always)]
    pub fn mark_visited(&mut self, node: usize) {
        let generation = self.visit_id;
        self.visited[node] = generation;
    }
}